# carrot/tinygrad_repo/extra/replay_pkl.py
# (code-viewer metadata: 76 lines, 2.7 KiB, Python; last change 2025-04-18 20:38:55 +09:00)

import pickle, sys
from dataclasses import replace
from tinygrad import Device, Context, Tensor, GlobalCounters
2025-04-18 20:38:55 +09:00
from tinygrad.device import Buffer
from tinygrad.helpers import getenv, BEAM
from tinygrad.engine.jit import TinyJit
from tinygrad.engine.realize import CompiledRunner, ExecItem, ScheduleItem, lower_schedule_item
2025-04-18 20:38:55 +09:00
from tinygrad.renderer import ProgramSpec
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
from tinygrad.codegen.heuristic import hand_coded_optimizations
import numpy as np
def move_jit_captured_to_dev(captured, device="DSP"):
    """Re-create every buffer referenced by a captured JIT on ``device``.

    The object is modified in place and also returned:

    * each entry of ``captured.expected_st_vars_dtype_device`` keeps its
      first three fields and gets ``device`` as its fourth;
    * each distinct non-``None`` buffer found in ``captured.jit_cache`` is
      replaced by a new ``Buffer`` on ``device`` with the same size and dtype;
    * ``captured.jit_cache`` is rebuilt as new ``ExecItem`` objects that keep
      the original ``prg`` and point at the replacement buffers.

    Args:
        captured: object exposing ``expected_st_vars_dtype_device`` (a list of
            tuples) and ``jit_cache`` (items with ``prg`` and ``bufs``).
        device: name of the target device, ``"DSP"`` by default.

    Returns:
        The same ``captured`` object that was passed in.
    """
    captured.expected_st_vars_dtype_device = [
        x[:3] + (device,) for x in captured.expected_st_vars_dtype_device
    ]

    # Maps each original buffer to its replacement, so a buffer shared by
    # several kernels (or used as a base by several views) is moved only once.
    assign = {}

    def move_buffer(b):
        """Return the replacement for ``b`` on ``device``, creating it on first use."""
        if b in assign:
            return assign[b]

        if b._base is not None:
            # A buffer with a base is rebuilt on top of the moved base at the
            # same offset.
            newbuf = Buffer(device, b.size, b.dtype, base=move_buffer(b._base), offset=b.offset)
        else:
            newbuf = Buffer(device, b.size, b.dtype)
            # NOTE(review): the source this was recovered from had lost its
            # indentation; the copy is placed in the base-buffer branch on the
            # assumption that views share their base's storage -- confirm.
            if b.is_allocated():
                newbuf.ensure_allocated().copyin(b.as_buffer())
        assign[b] = newbuf
        return newbuf

    for item in captured.jit_cache:
        for b in item.bufs:
            if b is not None:
                move_buffer(b)

    # ``None`` entries are never keys of ``assign``, so ``get(b, b)`` leaves
    # them as ``None``.
    captured.jit_cache = [
        ExecItem(item.prg, [assign.get(b, b) for b in item.bufs])
        for item in captured.jit_cache
    ]
    return captured
2025-04-18 20:38:55 +09:00
if __name__ == "__main__":
    # Replay the kernels of a pickled TinyJit on the DSP device.
    #
    # usage: replay_pkl.py <pickled_jit>
    # Environment knobs read below:
    #   KNUM=n      only rebuild and run the n-th (1-based) eligible kernel; 0 = all
    #   VALIDATE=1  compare each optimized kernel against a NOOPT=1 run
    #   NOOPT=1     skip hand_coded_optimizations
    #   RUN_JIT=1   afterwards run the whole JIT with the rebuilt kernels
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <pickled_jit>")

    with Context(DEBUG=0):
        # SECURITY: pickle.load can execute arbitrary code; only use this on
        # pickle files from a trusted source.
        with open(sys.argv[1], "rb") as f:
            fxn: TinyJit = pickle.load(f)
            # f.tell() after loading is the number of bytes consumed.
            print(f"{f.tell()/1e6:.2f}M loaded")
        print(type(fxn))

    # NOTE(review): indentation was lost in the recovered source; everything
    # from here on is assumed to run outside the DEBUG=0 context -- confirm.

    # Move all buffers to DSP device.
    fxn.captured = move_jit_captured_to_dev(fxn.captured, "DSP")
    new_jit = []

    # Loop-invariant environment lookups.
    pknum = getenv("KNUM", 0)
    validate = getenv("VALIDATE")

    knum = 1
    for ei in fxn.captured.jit_cache:
        # skip the copy and the first kernel
        # (only CompiledRunner items whose buffers are all present are replayed)
        if isinstance(ei.prg, CompiledRunner) and all(x is not None for x in ei.bufs):
            if knum == pknum or pknum == 0:
                p: ProgramSpec = ei.prg.p
                k = Kernel(p.ast, Device["DSP"].renderer)

                if validate:
                    # Reference result: lower and run the same AST under NOOPT=1.
                    with Context(NOOPT=1):
                        lower_schedule_item(ScheduleItem(p.ast, ei.bufs)).run()
                    correct = ei.bufs[0].numpy()
                    # Zero bufs[0] so a kernel that writes nothing cannot pass
                    # by leaving the reference data in place.
                    ei.bufs[0].copyin(memoryview(bytearray(b'\x00'*ei.bufs[0].nbytes)))
                    # presumably keeps the reference run out of the kernel count -- verify
                    GlobalCounters.kernel_count -= 1

                if not getenv("NOOPT"):
                    k.apply_opts(hand_coded_optimizations(k))
                p2 = k.to_program()
                new_ei = replace(ei, prg=CompiledRunner(p2))
                new_ei.run()
                new_jit.append(new_ei)
                test = ei.bufs[0].numpy()

                if validate:
                    np.testing.assert_allclose(correct, test, rtol=1e-3, atol=1e-3)
            # NOTE(review): knum is assumed to count every eligible kernel,
            # selected or not; inside the KNUM check a KNUM > 1 could never
            # match -- confirm the exact nesting against upstream.
            knum += 1

    if getenv("RUN_JIT", 0):
        fxn.captured.free_intermediates()
        fxn.captured.jit_cache = new_jit
        fxn(input=Tensor(np.zeros((1, 3, 224, 224), dtype=np.float32), device="DSP"))