carrot/tinygrad_repo/extra/gemm/tinygrad_nv_matmul.py

38 lines
1.3 KiB
Python
Raw Normal View History

from tinygrad import Tensor, dtypes, Device
from tinygrad.helpers import getenv, DEBUG
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
from tinygrad.engine.realize import CompiledRunner, ExecItem
from dataclasses import replace
# Problem size. The GEMM path multiplies two N x N matrices; the GEMV path uses N as the
# dimension shared by the (1, N) row and the transposed (14336, N) matrix.
N = 4096

if __name__ == "__main__":
  # Read the mode switch once; both the input construction and the opt list depend on it.
  gemv = getenv("GEMV")

  # Build the matmul to benchmark. Inputs come from Tensor.empty, so presumably their contents
  # are not initialized and only the generated kernel (not its numeric result) is of interest.
  if gemv:
    # (1, N) float32 row times the transpose of a (14336, N) float16 matrix.
    A, B = Tensor.empty(1, N, dtype=dtypes.float), Tensor.empty(14336, N, dtype=dtypes.float16).T
  else:
    # (N, N) float16 times (N, N) float16.
    A, B = Tensor.empty(N, N, dtype=dtypes.float16), Tensor.empty(N, N, dtype=dtypes.float16)
  C = A.matmul(B)

  # Only the last schedule item is turned into a kernel here.
  # NOTE(review): assumes that item is the matmul itself -- confirm if the graph changes.
  si = C.schedule()[-1]
  k = Kernel(si.ast, opts=Device[Device.DEFAULT].renderer)

  # Hand-picked optimizations for each mode, applied to the kernel in list order.
  if gemv:
    opts = [
      Opt(op=OptOps.UNROLL, axis=0, amt=8),
      Opt(op=OptOps.GROUP, axis=0, amt=32),
    ]
  else:
    opts = [
      Opt(op=OptOps.TC, axis=0, amt=0),
      Opt(op=OptOps.UPCAST, axis=0, amt=4),
      Opt(op=OptOps.UPCAST, axis=1, amt=8),
      Opt(op=OptOps.LOCAL, axis=0, amt=2),
      Opt(op=OptOps.LOCAL, axis=1, amt=2),
      Opt(op=OptOps.LOCAL, axis=0, amt=2),
    ]
  for opt in opts: k.apply_opt(opt)

  # Render the optimized kernel to a program. new_src is a hook for hand-editing the generated
  # source; the program is rebuilt from it via dataclasses.replace before being run.
  prg = k.to_program()
  new_src = prg.src
  # can mod source here
  prg = replace(prg, src=new_src)

  # Wrap the program in a CompiledRunner bound to the schedule item's (allocated) buffers,
  # then run it 5 times; wait=True presumably blocks until each run finishes.
  ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata)
  for _ in range(5): ei.run(wait=True)