131 lines
7.0 KiB
Python
Raw Permalink Normal View History

from typing import Dict, List, Final, Callable, DefaultDict
from collections import defaultdict
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Op
2025-01-07 19:31:23 +09:00
from tinygrad.helpers import DType, dtypes, ImageDType, DEBUG, getenv
from tinygrad.codegen.linearizer import UOp, UOps
from triton.compiler import compile as triton_compile # type: ignore
import linecache
import math
import re
triton_dtypes = {dtypes.double: "tl.float64", dtypes.float32: "tl.float32", dtypes.float16: "tl.float16", dtypes.bool: "tl.int1", dtypes.int8: "tl.int8", dtypes.uint8: "tl.uint8", dtypes.int32: "tl.int32", dtypes.int64: "tl.int64", dtypes.uint32: "tl.uint32", dtypes.uint64: "tl.uint64", dtypes.int16: "tl.int16", dtypes.uint16: "tl.uint16"}
2025-01-07 19:31:23 +09:00
signature_dtypes = {dtypes.double: "*fp64",dtypes.float32: "*fp32", dtypes.float16: "*fp16", dtypes.bool: "*i8", dtypes.int8: "*i1", dtypes.uint8: "*u8", dtypes._arg_int32: "i32", dtypes.int32: "*i32", dtypes.int64: "*i64", dtypes.uint32: "*u32", dtypes.uint64: "*u64", dtypes.int16: "*i16", dtypes.uint16: "*u16"}
def next_power_of_2(x):
return 1 << (x - 1).bit_length()
def render_valid(valid):
return '(' * (len(valid) -1) + ') and '.join(valid) if len(valid) else 'True'
#NOTE Triton requires matching dimensions for load/store, disable this and see TestOps::test_output_padded_conv_transpose2d fail to compile
def fill_dims_for_idx(idx, dims):
return "(" + idx + "+ (" + (f"0*({'+'.join(d for d in dims)})))") if len(dims) else idx
def get_max(var):
if isinstance(var, int): return var
return re.sub(r'\[(.*?)\]', '', str(var))[1:-1]
#NOTE can be removed after https://github.com/gpuocelot/gpuocelot/issues/8 gets resolved
def remove_single_scalar_curly_braces(ptx_code):
return '\n'.join([re.sub(r'\{\s*(%\w+)\s*\}', r'\1', line) for line in ptx_code.split('\n')])
2025-01-07 19:31:23 +09:00
def render_const(args):
return (('-' if args<0 else '') + 'tl.where(1,float("inf"),0)') if math.isinf(args) else ('tl.where(1,float("nan"),0)' if math.isnan(args) else str(args))
2025-01-07 19:31:23 +09:00
def render_cast(x:str, dtype:DType):
return f"{x}.to({triton_dtypes[dtype]})"
def define_scalar(local_size, dtype, args):
2025-01-07 19:31:23 +09:00
if len(local_size) > 0: return f"tl.full(({','.join([str(next_power_of_2(x)) for x in local_size])},),{render_const(args)}, dtype={triton_dtypes[dtype]})"
return render_const(args)
def uops_to_triton(function_name:str, uops:List[UOp]):
local_size: List[int] = []
depth = 1
signatures, dims, bufs, kernel, valid = [], [], [], [], [] #type: ignore
c: DefaultDict[str, int] = defaultdict(int)
r: Dict[UOp, str] = {}
def ssa(u, prefix="t"):
nonlocal c, r
c[prefix] += 1
r[u]=f"{prefix}{c[prefix]-1}"
return r[u]
child_count: DefaultDict[UOp, int] = defaultdict(int)
for ru in uops:
for v in ru.vin:
child_count[v] += 1
def kk(s): kernel.append(" "*depth+s)
code_for_op: Final[Dict[Op, Callable]] = {
2025-01-07 19:31:23 +09:00
UnaryOps.EXP2: lambda x,: f"tl.math.exp2({x})",
UnaryOps.LOG2: lambda x,: f"tl.math.log2({x})",
UnaryOps.SIN: lambda x,: f"tl.sin({x})",
UnaryOps.SQRT: lambda x,: f"tl.sqrt({x})",
UnaryOps.NEG: lambda x,: f"-{x}",
BinaryOps.ADD: lambda x,y,: f"({x}+{y})", BinaryOps.SUB: lambda x,y,: f"({x}-{y})",
BinaryOps.MUL: lambda x,y,: f"({x}*{y})", BinaryOps.DIV: lambda x,y,: f"({x}/{y})" if y != '0.0' else f"{x}*tl.where({x}==0.0, float('nan'), float('inf'))",
BinaryOps.MAX: lambda x,y,: f"tl.maximum({x},{y})",
BinaryOps.CMPLT: lambda x,y,: f"({x}<{y})",
BinaryOps.MOD: lambda x,y,: f"tl.abs({x})%tl.abs({y})*tl.where({x}<0,-1,1)",
TernaryOps.MULACC: lambda x,y,z,: f"(({x}*{y})+{z})",
TernaryOps.WHERE: lambda x,y,z,: f"tl.where({x},{y},{z})",
}
def int_div(x,y): return f"({x}//{y})" if y != '0' else f"{x}*tl.where({x}==0, float('nan'), float('inf'))"
for u in uops:
2025-01-07 19:31:23 +09:00
uop,dtype,vin,args,_ = u
if uop == UOps.LOOP:
kk(f"for {ssa(u, 'ridx')} in range({vin[0].arg}, {r[vin[1]]}):")
depth += 1
2025-01-07 19:31:23 +09:00
elif uop == UOps.END: depth -= 1
elif uop == UOps.ALU:
assert dtype is not None
val = code_for_op[args](*[r[x] for x in vin])
if child_count[u] <=1 or dtypes.is_int(dtype): r[u] = int_div(*[r[x] for x in vin]) if args == BinaryOps.DIV and dtypes.is_int(dtype) else val
else: kk(f"{ssa(u, 'alu')} = ({val})")
2025-01-07 19:31:23 +09:00
elif uop == UOps.LOAD:
assert dtype is not None
if len(vin) == 2: kk(f"{ssa(u, 'val')} = {render_cast(f'tl.load({r[vin[0]]} + { fill_dims_for_idx(r[vin[1]], dims)}, mask = {render_valid(valid)})', dtype)}")
else: kk(f"{ssa(u, 'val')} = {render_cast(f'tl.where({r[vin[2]]}, tl.load({r[vin[0]]}+{fill_dims_for_idx(r[vin[1]],dims)} , mask={render_valid(valid+[r[vin[2]]])}), 0.0)', dtype)}")
2025-01-07 19:31:23 +09:00
elif uop == UOps.DEFINE_ACC: kk(f"{ssa(u, 'acc')} = {define_scalar(local_size, dtype, args).replace('//', '/')}")
elif uop == UOps.CONST: r[u] = define_scalar([], dtype, args)
elif uop == UOps.PHI:
kk(f"{r[vin[0]]} = {r[vin[1]].replace('//', '/')}")
r[u] = r[vin[0]]
2025-01-07 19:31:23 +09:00
elif uop == UOps.STORE:
assert not isinstance(dtype, ImageDType), "unimplemented: image store"
2025-01-07 19:31:23 +09:00
kk(f"tl.store({r[vin[0]]} + {r[vin[1]]}, {r[vin[2]].replace('//', '/')}, mask = {render_valid(valid)}) ")
elif uop == UOps.DEFINE_GLOBAL:
bufs.append(args)
2025-01-07 19:31:23 +09:00
signatures.append(signature_dtypes[args[1]])
r[u] = args[0]
elif uop == UOps.SPECIAL:
dims.append(args[1])
valid.append(f"{args[1]}<{get_max(args[2])}")
if args[1].startswith("g"): kk(f"{args[1]} = tl.program_id({args[0]}) # {args[2]}")
elif args[1].startswith("l"):
kk(f"{args[1]} = tl.arange({0}, {next_power_of_2(args[2])})")
local_size.append(args[2])
r[u] = args[1]
else: raise NotImplementedError(f"unimplemented: {uop}")
2025-01-07 19:31:23 +09:00
prg = f"import triton\nimport triton.language as tl\ntl.core.TRITON_MAX_TENSOR_NUMEL = float('inf')\n@triton.jit\ndef {function_name}("+','.join(f"{buf[0]}" for buf in bufs)+"):\n"
for i, line in enumerate(list(filter(lambda line: "tl.arange" in line, kernel))): kernel[kernel.index(line)] += f"[{', '.join([':' if i == j else 'None' for j in range(len(local_size))])}]"
prg += "\n".join(kernel)
acc_local_size = 1
for x in local_size: acc_local_size *= next_power_of_2(x)
local_size = [acc_local_size] + [1] * (len(local_size) - 1)
if DEBUG >= 4: print(prg)
getlines = linecache.getlines
linecache.getlines = lambda filename, module_globals=None: prg.splitlines(keepends=True) if "<triton>" == filename else getlines(filename, module_globals)
exec(compile(prg, "<triton>", "exec"), globals()) # pylint: disable=W0122\
compiled = triton_compile(globals()[function_name], signature=",".join(signatures), device_type="cuda", debug=False, cc=(35 if getenv("CUDACPU", 0) else None))
prg = remove_single_scalar_curly_braces(compiled.asm["ptx"].split(".file")[0].split(".visible .func")[0])
max_local_size = [int(x) for x in prg.split(".maxntid ")[1].split("\n")[0].split(", ")]
for i in range(len(local_size)): local_size[i] = min(local_size[i], max_local_size[i])
return prg, {"shared":compiled.metadata["shared"], "local_size":local_size + [1]*(3-len(local_size))}