148 lines
8.4 KiB
Python
Raw Normal View History

from typing import Final, Dict, Callable, Any, List, Optional, Tuple
from llvmlite import ir # type: ignore
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.helpers import dtypes
from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
LLVM_FAST_MATH_FLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf
code_for_op: Final[Dict[Op, Callable]] = {
UnaryOps.NEG: lambda builder,x: builder.neg(x) if isinstance(x.type, ir.IntType) else builder.fneg(x, flags=LLVM_FAST_MATH_FLAGS),
UnaryOps.EXP2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
UnaryOps.LOG2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log2', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
UnaryOps.SIN: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
UnaryOps.SQRT: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sqrt', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
BinaryOps.ADD: lambda builder,x,y: builder.add(x,y) if isinstance(x.type, ir.IntType) else builder.fadd(x,y, flags=LLVM_FAST_MATH_FLAGS),
BinaryOps.SUB: lambda builder,x,y: builder.sub(x,y) if isinstance(x.type, ir.IntType) else builder.fsub(x,y, flags=LLVM_FAST_MATH_FLAGS),
BinaryOps.MUL: lambda builder,x,y: builder.mul(x,y) if isinstance(x.type, ir.IntType) else builder.fmul(x,y, flags=LLVM_FAST_MATH_FLAGS),
BinaryOps.DIV: lambda builder,x,y: builder.sdiv(x,y) if isinstance(x.type, ir.IntType) else builder.fdiv(x,y, flags=LLVM_FAST_MATH_FLAGS),
# TODO: this should be casted
BinaryOps.CMPLT: lambda builder,x,y: builder.zext(builder.icmp_signed("<", x, y),ir.IntType(32)) if isinstance(x.type, ir.IntType) else builder.uitofp(builder.fcmp_ordered("<", x, y, flags=LLVM_FAST_MATH_FLAGS), ir.FloatType()),
BinaryOps.MAX: lambda builder,x,y: builder.select(builder.fcmp_unordered(">", x, y, flags=LLVM_FAST_MATH_FLAGS), x, y, flags=LLVM_FAST_MATH_FLAGS),
BinaryOps.MOD: lambda builder,x,y: builder.srem(x,y),
TernaryOps.MULACC: lambda builder,x,y,z: builder.fadd(builder.fmul(x,y, flags=LLVM_FAST_MATH_FLAGS), z, flags=LLVM_FAST_MATH_FLAGS),
TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=LLVM_FAST_MATH_FLAGS) if isinstance(x.type, ir.FloatType) else builder.trunc(x, ir.IntType(1)), y, z, flags=LLVM_FAST_MATH_FLAGS),
}
dtype_to_llvm_dtype = {dtypes.float64:ir.DoubleType(), dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32), dtypes._arg_int32: ir.IntType(32), dtypes.int16:ir.IntType(16), dtypes.uint16:ir.IntType(16), dtypes.uint32:ir.IntType(32), dtypes.uint64:ir.IntType(64)}
def cast(bb, val, input_type, output_type):
if input_type == output_type: return val
if output_type == dtypes.float32:
if dtypes.is_int(input_type) or input_type == dtypes.bool:
val = bb[-1].uitofp(val, ir.FloatType()) if dtypes.is_unsigned(input_type) or input_type == dtypes.bool else bb[-1].sitofp(val, ir.FloatType())
elif input_type == dtypes.bfloat16:
val = bb[-1].sext(val, ir.IntType(32))
val = bb[-1].shl(val, ir.Constant(ir.IntType(32), 16))
val = bb[-1].bitcast(val, ir.FloatType())
elif input_type == dtypes.float64:
val = bb[-1].fptrunc(val, ir.FloatType())
else:
val = bb[-1].fpext(val, ir.FloatType())
return val
if input_type == dtypes.float32:
if dtypes.is_int(output_type) or output_type == dtypes.bool:
if dtypes.is_unsigned(output_type): val = bb[-1].fptoui(val, dtype_to_llvm_dtype[output_type])
elif output_type == dtypes.bool: val = bb[-1].fcmp_ordered("!=", val, ir.Constant(ir.FloatType(), 0), flags=LLVM_FAST_MATH_FLAGS)
else: val = bb[-1].fptosi(val, dtype_to_llvm_dtype[output_type])
elif output_type == dtypes.bfloat16:
val = bb[-1].bitcast(val, ir.IntType(32))
val = bb[-1].lshr(val, ir.Constant(ir.IntType(32), 16))
val = bb[-1].trunc(val, ir.IntType(16))
elif output_type == dtypes.float64:
val = bb[-1].fpext(val, ir.DoubleType())
else:
val = bb[-1].fptrunc(val, dtype_to_llvm_dtype[output_type])
return val
raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")
def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
# all llvm stuff goes into a module
module = ir.Module(name=__file__)
# extract global buffers
buf_to_dtype = {args[0]:args[1] for uop,_,_,args,_ in uops if uop == UOps.DEFINE_GLOBAL}
buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
# create llvm function
func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values()]
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if dt!=dtypes._arg_int32 else x for x,dt in func_dtypes]), name=function_name)
for a in func.args:
if a.type.is_pointer: a.add_attribute("noalias")
# add the function attribute "no-nans-fp-math"="true", which informs llvm that it allowed to use vectorization optimizations
func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"']))
func.attributes.add('"no-nans-fp-math"="true"')
bb = [ir.IRBuilder(func.append_basic_block("entry"))]
loop_blocks: List = []
reduce_phis: List = []
# TODO: newvar probably shouldn't be optional
lvars: Dict[Optional[UOp], Any] = {} # this Any is an llvm type
for bufname,dtype in buf_to_dtype.items():
if dtype == dtypes._arg_int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
for u in uops:
uop,dtype,vin,args,_ = u
if uop == UOps.LOOP:
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
bb[-2].branch(bb[-1]._block)
phis = []
for rp in reduce_phis:
incoming = lvars[rp]
lvars[rp] = bb[-1].phi(ir.FloatType())
lvars[rp].add_incoming(incoming, bb[-2]._block)
phis.append((rp, lvars[rp]))
lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
lvars[u].add_incoming(lvars[vin[0]], bb[-2]._block)
loop_blocks.append((bb[-1], phis))
if uop == UOps.END:
block, phis = loop_blocks.pop()
idx_p1 = bb[-1].add(lvars[vin[0]], ir.Constant(ir.IntType(32), 1))
lvars[vin[0]].add_incoming(idx_p1, bb[-1]._block)
for n,phi in phis: phi.add_incoming(lvars[n], bb[-1]._block)
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), block._block, bb[-1]._block)
if uop == UOps.DEFINE_GLOBAL:
lvars[u] = func.args[buf_index[args[0]]]
if uop == UOps.DEFINE_ACC:
lvars[u] = ir.Constant(dtype_to_llvm_dtype[dtype], args)
reduce_phis.append(u)
if uop == UOps.SPECIAL:
lvars[u] = lvars[args.expr]
if uop == UOps.CONST:
value = int(args) if dtypes.is_int(dtype) else bool(args) if dtype == dtypes.bool else args
lvars[u] = ir.Constant(dtype_to_llvm_dtype[dtype], value)
if uop == UOps.LOAD:
assert dtype is not None
if len(vin) > 2:
gate = bb[-1].trunc(lvars[vin[2]], ir.IntType(1))
aug_idx = bb[-1].select(gate, lvars[vin[1]], ir.Constant(ir.IntType(32), 0))
val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [aug_idx], inbounds=True))
val = cast(bb, val, vin[0].dtype, dtype)
val = bb[-1].select(gate, val, lvars[vin[3]])
else:
val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
val = cast(bb, val, vin[0].dtype, dtype)
lvars[u] = val
if uop == UOps.PHI:
lvars[u] = lvars[vin[1]]
# PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC
backward = vin[0]
while backward.uop == UOps.PHI: backward = backward.vin[0]
lvars[backward] = lvars[u]
if uop == UOps.STORE:
element = cast(bb, lvars[vin[2]], vin[2].dtype, vin[0].dtype)
bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
if uop == UOps.ALU:
lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin])
bb[-1].ret_void()
return str(module), {}