#!/usr/bin/env python3
"""Benchmark a small Apple AMX kernel emitted through llvmlite.

The script builds one LLVM function with a single loop that issues AMX
load (ldx/ldy) and fma32 instructions over buffer `a`, stores the Z
accumulator into `c`, runs the compiled function 50 times, prints the
min/median time and the implied bandwidth, and finally compares the first
32 output values with `an.reshape(-1, 32).sum(axis=0)`.

Two disabled code paths (a tiled GEMM kernel and its check) are kept below
as inert string literals.
"""
import sys
import time
from functools import partialmethod

import numpy as np
from llvmlite import ir  # type: ignore

from tinygrad.runtime.ops_llvm import LLVM, LLVMBuffer, int_const

np.set_printoptions(linewidth=1000, threshold=10000000000, suppress=False)


# https://github.com/corsix/amx/blob/main/Instructions.md
class AMX:
  """Emitters for AMX instructions, written as raw `.word` inline assembly.

  Every public attribute is a `partialmethod` that binds an opcode number to
  one of the two static emitters below, so call sites look like
  `AMX.ldx(builder, operand)` or `AMX.set(builder)`.
  """

  @staticmethod
  def nop_op_imm5(op, imm5, builder):
    """Emit opcode `op` with the immediate `imm5` and no register operand."""
    builder.asm(ir.FunctionType(ir.VoidType(), []), f".word (0x201000 + ({op} << 5) + {imm5}); amx op {op} imm {imm5}", "", tuple(), True)

  @staticmethod
  def op_gpr(op, builder, gpr):
    """Emit opcode `op` taking the 64-bit value `gpr` as its register operand."""
    # "0$0" prefixes the substituted operand with "0"; this presumably relies on
    # the operand being printed as xN so that the assembler reads 0xN as hex,
    # and "- ((v >> 4) * 6)" then maps that hex reading back to decimal N.
    builder.asm(ir.FunctionType(ir.VoidType(), [ir.IntType(64)]), f".word (0x201000 + ({op} << 5) + 0$0 - ((0$0 >> 4) * 6)); amx op {op} reg $0", "r", (gpr,), True)

  # NOTE: `set` shadows the builtin inside this class body only; the name is
  # part of the class's public interface and is therefore kept.
  set, clr = partialmethod(nop_op_imm5, 17, 0), partialmethod(nop_op_imm5, 17, 1)
  ldx, ldy, stx, sty = partialmethod(op_gpr, 0), partialmethod(op_gpr, 1), partialmethod(op_gpr, 2), partialmethod(op_gpr, 3)
  ldz, stz, ldzi, stzi = partialmethod(op_gpr, 4), partialmethod(op_gpr, 5), partialmethod(op_gpr, 6), partialmethod(op_gpr, 7)
  extrx, extry = partialmethod(op_gpr, 8), partialmethod(op_gpr, 9)
  fma64, fms64, fma32, fms32 = partialmethod(op_gpr, 10), partialmethod(op_gpr, 11), partialmethod(op_gpr, 12), partialmethod(op_gpr, 13)
  mac16, fma16, fms16 = partialmethod(op_gpr, 14), partialmethod(op_gpr, 15), partialmethod(op_gpr, 16)
  vecint, vecfp, matint, matfp, genlut = partialmethod(op_gpr, 18), partialmethod(op_gpr, 19), partialmethod(op_gpr, 20), partialmethod(op_gpr, 21), partialmethod(op_gpr, 22)


N = 4096
#N = 1024
#N = 64

# The emitted loop advances by 64 floats per iteration and leaves only when
# its counter is exactly N*N, so any other size would never terminate.
if N <= 0 or (N * N) % 64 != 0:
  raise ValueError(f"N*N must be a positive multiple of 64, got N={N}")

#an = np.arange(N*N).reshape(N, N) - 43*64
#bn = np.arange(N*N).reshape(N, N)
#an = np.ones((N, N)).astype(np.float32)
#bn = np.ones((N, N)).astype(np.float32)

# matrix is 64M, max load bandwidth is 57 GB/s
# cache line looks like 256 bytes (64 floats)

an = np.random.randn(N, N).astype(np.float32)
bn = np.random.randn(N, N).astype(np.float32)

# reference for the active kernel: column sums of `an` viewed as rows of 32 floats
sn = an.reshape(-1, 32).sum(axis=0)
# reference for the disabled GEMM kernel; only used by the disabled check at the end
cn = (an.T @ bn).T

a = LLVMBuffer.fromCPU(an)
b = LLVMBuffer.fromCPU(bn)
#c = LLVMBuffer.fromCPU(np.zeros((N, N)))
c = LLVMBuffer.fromCPU(np.zeros(256))
bufs = [c,a,b]

module = ir.Module(name=__file__)
func = ir.Function(module, ir.FunctionType(ir.IntType(64), [ir.FloatType().as_pointer()]*3), name='exec')

# load all
entry = ir.IRBuilder(func.append_basic_block(name="entry"))
# integer addresses of the three pointer arguments, in `bufs` order (c, a, b)
zm, xm, ym = [entry.ptrtoint(func.args[i], ir.IntType(64)) for i in range(3)]

loop_1 = ir.IRBuilder(func.append_basic_block(name="loop_y"))
loop_1_exit = ir.IRBuilder(func.append_basic_block(name="loop_y_exit"))
exit_bb = ir.IRBuilder(func.append_basic_block(name="exit"))

# y is the float index into `a`; it starts at 0 and grows by 64 per iteration
y = loop_1.phi(ir.IntType(64), name="y")
y.add_incoming(int_const(0), entry._block)
yp = loop_1_exit.add(y, int_const(32*2))
y.add_incoming(yp, loop_1_exit._block)

# declared for the (currently commented-out) prefetch call below
prefetch_function = ir.Function(module, ir.FunctionType(ir.VoidType(), [ir.PointerType(ir.FloatType()), ir.IntType(32), ir.IntType(32), ir.IntType(32)]), name="llvm.prefetch")

xptr = y
addr = loop_1_exit.add(xm, loop_1_exit.mul(int_const(4), xptr))

#prefetch_ptr = loop_1_exit.inttoptr(loop_1_exit.add(addr, int_const(128)), ir.PointerType(ir.FloatType()))
#loop_1_exit.call(prefetch_function, [prefetch_ptr, ir.IntType(32)(0), ir.IntType(32)(2), ir.IntType(32)(1)])

# Both loads read from `a` (xm): ldx at float index y, ldy at y + 32.  `b` (ym)
# is passed to the function but not read by this kernel.
AMX.ldx(loop_1_exit, loop_1_exit.add(int_const(1<<62), addr))
xptr = loop_1_exit.add(xptr, int_const(32))
AMX.ldy(loop_1_exit, loop_1_exit.add(int_const(1<<62), loop_1_exit.add(xm, loop_1_exit.mul(int_const(4), xptr))))

# NOTE(review): operand bit meanings are defined in the instruction reference
# linked above the AMX class -- confirm there before changing any of these.
AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 28))
AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 28 | 1 << 20 | (16*4)<<10))
AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 29))
AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 29 | 1 << 20 | (16*4)))

AMX.set(entry)

AMX.stz(exit_bb, exit_bb.add(zm, int_const(1 << 62 | (0 << 56) | 0)))
AMX.clr(exit_bb)

entry.branch(loop_1._block)
loop_1.branch(loop_1_exit._block)
loop_1_exit.cbranch(loop_1_exit.icmp_unsigned("==", yp, int_const(N*N)), exit_bb._block, loop_1._block)
exit_bb.ret(int_const(0))

cfunc = LLVM().exec(module, bufs, N**2)

"""
loop_1 = ir.IRBuilder(func.append_basic_block(name="loop_y"))
loop_2 = ir.IRBuilder(func.append_basic_block(name="loop_x"))
loop_3 = ir.IRBuilder(func.append_basic_block(name="loop_k"))
loop_3_exit = ir.IRBuilder(func.append_basic_block(name="loop_k_exit"))
loop_2_exit = ir.IRBuilder(func.append_basic_block(name="loop_x_exit"))
loop_1_exit = ir.IRBuilder(func.append_basic_block(name="loop_y_exit"))

y = loop_1.phi(ir.IntType(64), name="y")
x = loop_2.phi(ir.IntType(64), name="x")
k = loop_3.phi(ir.IntType(64), name="k")

exit = ir.IRBuilder(func.append_basic_block(name="exit"))

AMX.set(loop_2)

# stride
xptr = loop_3_exit.add(x, loop_3_exit.mul(k, int_const(N)))
yptr = loop_3_exit.add(y, loop_3_exit.mul(k, int_const(N)))

# if you are okay with the wrong answer, this is faster
#xptr = loop_3_exit.add(x, loop_3_exit.mul(k, int_const(32)))
#yptr = loop_3_exit.add(y, loop_3_exit.mul(k, int_const(32)))

# double loads load 32 floats
AMX.ldx(loop_3_exit, loop_3_exit.add(int_const(1<<62), loop_3_exit.add(xm, loop_3_exit.mul(int_const(4), xptr))))
AMX.ldy(loop_3_exit, loop_3_exit.add(int_const(1<<62), loop_3_exit.add(ym, loop_3_exit.mul(int_const(4), yptr))))

# NOTE(review): the original comment text here was lost; the operand looks like
# (Z row << 20) | (X offset << 10) | (Y offset) -- confirm against the reference
AMX.fma32(loop_3_exit, int_const(0<<20 | (0*16*4)<<10 | (0*16*4)))
AMX.fma32(loop_3_exit, int_const(1<<20 | (1*16*4)<<10 | (0*16*4)))
AMX.fma32(loop_3_exit, int_const(2<<20 | (0*16*4)<<10 | (1*16*4)))
AMX.fma32(loop_3_exit, int_const(3<<20 | (1*16*4)<<10 | (1*16*4)))

# store
gptr = loop_2_exit.mul(loop_2_exit.add(loop_2.mul(y, int_const(N)), x), int_const(4))
zmp = loop_2_exit.add(zm, gptr)
for j in range(2):
  for r in range(16):
    z_row = j*2
    ptr = ((j*16)+r)*N
    AMX.stz(loop_2_exit, loop_2_exit.add(zmp, int_const(1 << 62 | ((r*4+z_row) << 56) | ptr*4)))
AMX.clr(loop_2_exit)

yp = loop_1_exit.add(y, int_const(32))
xp = loop_2_exit.add(x, int_const(32))
kp = loop_3_exit.add(k, int_const(1))

y.add_incoming(int_const(0), entry._block)
x.add_incoming(int_const(0), loop_1._block)
k.add_incoming(int_const(0), loop_2._block)
y.add_incoming(yp, loop_1_exit._block)
x.add_incoming(xp, loop_2_exit._block)
k.add_incoming(kp, loop_3_exit._block)

entry.branch(loop_1._block)
loop_1.branch(loop_2._block)
loop_2.branch(loop_3._block)
loop_3.branch(loop_3_exit._block)
loop_3_exit.cbranch(loop_3_exit.icmp_unsigned("==", kp, int_const(N)), loop_2_exit._block, loop_3._block)
loop_2_exit.cbranch(loop_2_exit.icmp_unsigned("==", xp, int_const(N)), loop_1_exit._block, loop_2._block)
loop_1_exit.cbranch(loop_1_exit.icmp_unsigned("==", yp, int_const(N)), exit._block, loop_1._block)
exit.ret(int_const(0))
cfunc = LLVM().exec(module, bufs, N**3 * 2)
"""

# Build the argument list once so that it is not part of the timed region.
call_args = [buf._buf for buf in bufs]
times = []
for _ in range(50):
  # perf_counter is the clock intended for measuring short durations
  st = time.perf_counter()
  cfunc(*call_args)
  times.append(time.perf_counter() - st)

print(f"{min(times)*1000:.2f} ms min time, {np.median(times)*1000:.2f} ms median time")
# bytes of `a` (N*N float32 values) divided by the best run time
print(f"{(N*N*4*1e-9)/min(times):.2f} GB/s")

print(c.toCPU().astype(np.int64)[:sn.shape[0]])
print(sn.astype(np.int64))

np.testing.assert_allclose(c.toCPU()[:sn.shape[0]], sn, atol=1e-4, rtol=1e-4)

"""
print(cn.astype(np.int64))
np.testing.assert_allclose(c.toCPU(), cn, atol=1e-4, rtol=1e-5)
"""