128 lines
5.4 KiB
Python
Raw Permalink Normal View History

from typing import Dict, Set
import yaml
from tinygrad.codegen.uops import UOpGraph, UOps, UOp
from tinygrad.uop.ops import BinaryOps
from tinygrad.dtype import dtypes
def uops_to_rdna(function_name:str, uops:UOpGraph) -> str:
replace: Dict[UOp, UOp] = {}
seen: Set[UOp] = set()
for u in uops:
if u in seen: continue
seen.add(u)
for o,n in replace.items():
if o in u.vin and u is not n:
u.vin = tuple(n if x == o else x for x in u.vin)
# pointer indexing
if u.uop in {UOps.LOAD, UOps.STORE} and u.vin[0].dtype.itemsize > 1:
val = UOp(UOps.CONST, dtypes.int, tuple(), arg=u.vin[0].dtype.itemsize, insert_before=uops.uops.index(u))
ptr = UOp(UOps.ALU, dtypes.int, (u.vin[1], val), arg=BinaryOps.MUL, insert_before=uops.uops.index(u))
u.vin = (u.vin[0], ptr) + u.vin[2:]
#uops.print()
args = []
ins = []
v_cnt = 3 # v[0:2] is local_xyz
s_cnt = 5 # s[0:1] is the address, s[2:4] is global_xyz
r: Dict[UOp, str] = {}
for u in uops:
if u.uop == UOps.SPECIAL:
if u.arg[1].startswith("lidx"):
r[u] = f'v{u.arg[0]}'
elif u.arg[1].startswith("gidx"):
r[u] = f's{2+u.arg[0]}'
else:
raise NotImplementedError
elif u.uop == UOps.CONST:
#r[u] = u.arg
# TODO: sometimes we can use s
#r[u] = f"s{s_cnt}"
#s_cnt += 1
#ins.append(f"s_mov_b32 {r[u]}, {u.arg}")
r[u] = f"v{v_cnt}"
v_cnt += 1
ins.append(f"v_mov_b32 {r[u]}, {u.arg}")
elif u.uop == UOps.ALU:
if u.arg == BinaryOps.ADD:
r[u] = f"v{v_cnt}"
v_cnt += 1
ins.append(f"v_add_f32_e32 {r[u]}, {r[u.vin[0]]}, {r[u.vin[1]]}")
elif u.arg == BinaryOps.MUL:
r[u] = f"v{v_cnt}"
v_cnt += 1
if dtypes.is_float(u.dtype):
ins.append(f"v_mul_f32_e32 {r[u]}, {r[u.vin[0]]}, {r[u.vin[1]]}")
else:
ins.append(f"v_mul_u32_u24 {r[u]}, {r[u.vin[0]]}, {r[u.vin[1]]}")
else:
raise NotImplementedError
elif u.uop == UOps.LOAD:
r[u] = f"v{v_cnt}"
v_cnt += 1
ins.append(f"global_load_b32 {r[u]}, {r[u.vin[1]]}, {r[u.vin[0]]}")
ins.append("s_waitcnt vmcnt(0)")
elif u.uop == UOps.STORE:
ins.append(f"global_store_b32 {r[u.vin[1]]}, {r[u.vin[2]]}, {r[u.vin[0]]}")
elif u.uop == UOps.DEFINE_GLOBAL:
i = u.arg[0]
args.append({'.address_space': 'global', '.name': f'buf_{i}', '.offset': i*8, '.size': 8,
'.type_name': u.dtype.name+"*", '.value_kind': 'global_buffer'})
s_cnt += s_cnt%2 # skip
r[u] = f"s[{s_cnt}:{s_cnt+1}]"
s_cnt += 2
ins.append(f"s_load_b64 {r[u]}, s[0:1], {i*8}")
ins.append("s_waitcnt lgkmcnt(0)")
else:
raise NotImplementedError(f"can't render {u.uop}")
# *** boilerplate rendering ***
metadata = {
'amdhsa.kernels': [{'.args': args,
'.group_segment_fixed_size': 0, '.kernarg_segment_align': 8, '.kernarg_segment_size': args[-1][".offset"] + args[-1][".size"],
'.language': 'OpenCL C', '.language_version': [1, 2], '.max_flat_workgroup_size': 256,
'.name': function_name, '.private_segment_fixed_size': 0, '.sgpr_count': s_cnt, '.sgpr_spill_count': 0,
'.symbol': f'{function_name}.kd', '.uses_dynamic_stack': False, '.vgpr_count': v_cnt, '.vgpr_spill_count': 0,
'.wavefront_size': 32}],
'amdhsa.target': 'amdgcn-amd-amdhsa--gfx1100', 'amdhsa.version': [1, 2]}
boilerplate_start = f"""
.rodata
.global {function_name}.kd
.type {function_name}.kd,STT_OBJECT
.align 0x10
.amdhsa_kernel {function_name}"""
kernel_desc = {
'.amdhsa_group_segment_fixed_size': 0, '.amdhsa_private_segment_fixed_size': 0, '.amdhsa_kernarg_size': 0,
'.amdhsa_next_free_vgpr': v_cnt, # this matters!
'.amdhsa_reserve_vcc': 0, '.amdhsa_reserve_xnack_mask': 0,
'.amdhsa_next_free_sgpr': s_cnt,
'.amdhsa_float_round_mode_32': 0, '.amdhsa_float_round_mode_16_64': 0, '.amdhsa_float_denorm_mode_32': 3, '.amdhsa_float_denorm_mode_16_64': 3,
'.amdhsa_dx10_clamp': 1, '.amdhsa_ieee_mode': 1, '.amdhsa_fp16_overflow': 0,
'.amdhsa_workgroup_processor_mode': 1, '.amdhsa_memory_ordered': 1, '.amdhsa_forward_progress': 0, '.amdhsa_enable_private_segment': 0,
'.amdhsa_system_sgpr_workgroup_id_x': 1, '.amdhsa_system_sgpr_workgroup_id_y': 1, '.amdhsa_system_sgpr_workgroup_id_z': 1,
'.amdhsa_system_sgpr_workgroup_info': 0, '.amdhsa_system_vgpr_workitem_id': 2, # is amdhsa_system_vgpr_workitem_id real?
'.amdhsa_exception_fp_ieee_invalid_op': 0, '.amdhsa_exception_fp_denorm_src': 0,
'.amdhsa_exception_fp_ieee_div_zero': 0, '.amdhsa_exception_fp_ieee_overflow': 0, '.amdhsa_exception_fp_ieee_underflow': 0,
'.amdhsa_exception_fp_ieee_inexact': 0, '.amdhsa_exception_int_div_zero': 0,
'.amdhsa_user_sgpr_dispatch_ptr': 0, '.amdhsa_user_sgpr_queue_ptr': 0, '.amdhsa_user_sgpr_kernarg_segment_ptr': 1,
'.amdhsa_user_sgpr_dispatch_id': 0, '.amdhsa_user_sgpr_private_segment_size': 0, '.amdhsa_wavefront_size32': 1, '.amdhsa_uses_dynamic_stack': 0}
code_start = f""".end_amdhsa_kernel
.text
.global {function_name}
.type {function_name},@function
.p2align 8
{function_name}:
"""
ins += ['s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)', 's_endpgm', 's_code_end']
return ".amdgpu_metadata\n" + yaml.dump(metadata) + ".end_amdgpu_metadata" + \
boilerplate_start + "\n" + '\n'.join("%s %d" % x for x in kernel_desc.items()) + "\n" + code_start + \
'\n'.join(ins) + f"\n.size {function_name}, .-{function_name}"