carrot/tinygrad_repo/extra/replay_pkl.py
2025-04-19 08:05:49 +09:00

31 lines
1.1 KiB
Python

import pickle, sys
from dataclasses import replace
from tinygrad import Device, Context
from tinygrad.device import Buffer
from tinygrad.helpers import getenv, BEAM
from tinygrad.engine.jit import TinyJit
from tinygrad.engine.realize import CompiledRunner
from tinygrad.renderer import ProgramSpec
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
if __name__ == "__main__":
    # Replay the kernels captured in a pickled TinyJit on the "DSP" device.
    #
    # usage: replay_pkl.py <jit.pkl>
    # env:   KNUM=<n> replays only the n-th (1-based) eligible kernel;
    #        KNUM unset or 0 replays every eligible kernel.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <jit.pkl>")

    # DEBUG is forced to 0 only while the pickle is being loaded.
    with Context(DEBUG=0):
        # SECURITY: pickle.load can execute arbitrary code embedded in the file;
        # only pass pickles that come from a trusted source.
        with open(sys.argv[1], "rb") as f:
            fxn: TinyJit = pickle.load(f)
            # f.tell() after the load is the number of bytes consumed from the file.
            print(f"{f.tell()/1e6:.2f}M loaded")
        print(type(fxn))

    # The selected kernel number does not change during the run, so read it once.
    pknum = getenv("KNUM", 0)

    knum = 1
    for ei in fxn.captured.jit_cache:
        # Only entries backed by a CompiledRunner whose buffers are all present are
        # replayed (and counted); anything else is skipped.
        # NOTE(review): presumably non-CompiledRunner entries are copies and a None
        # buffer is a JIT input placeholder -- confirm against TinyJit.
        if not isinstance(ei.prg, CompiledRunner) or any(x is None for x in ei.bufs):
            continue

        if pknum == 0 or knum == pknum:
            p: ProgramSpec = ei.prg.p
            # Rebuild the kernel from its AST using the DSP renderer.
            k = Kernel(p.ast, Device["DSP"].renderer)
            # Fresh DSP buffers: each one is a view at offset 4096 into an allocation
            # that is 8192 larger than the original size. The original buffer
            # contents are not copied over.
            # NOTE(review): the extra space looks like alignment/guard padding -- confirm.
            dsp_bufs = [
                Buffer("DSP", 8192+b.size, b.dtype).view(b.size, b.dtype, 4096)
                for b in ei.bufs
            ]
            k.hand_coded_optimizations()
            p2 = k.to_program()
            # Same exec item, but with the re-lowered program and the DSP buffers.
            new_ei = replace(ei, prg=CompiledRunner(p2), bufs=dsp_bufs)
            new_ei.run()

        # Count every eligible kernel, replayed or not, so KNUM indexes stay stable.
        knum += 1