# carrot/tinygrad_repo/examples/handcode_resnet50_opt.py
# (repository web-view metadata: 73 lines, 2.9 KiB, Python)

from typing import List
from models.resnet import ResNet50
from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps, Device, Compiled
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.features.search import time_linearizer, beam_search
from tinygrad.helpers import ansilen, DEBUG, getenv
from tinygrad.lazy import vars_from_ast
from tinygrad.shape.symbolic import sym_infer
if __name__ == "__main__":
    # Benchmark every kernel of a ResNet50 forward pass under several
    # optimization strategies (hand-coded opts, tensor cores, optional beam
    # search), report the fastest variant per kernel, then print the total time
    # and the time-weighted average GFLOPS.
    #
    # Environment variables read below:
    #   KERNEL        - if >= 0, only the kernel at that schedule index is tuned
    #   BEAM          - if truthy, also try beam_search with this value
    #   BEAM_ESTIMATE - passed (as a bool, default 1) to beam_search
    #   DEBUG         - if >= 1, print the timing of every candidate, not just the best
    mdl = ResNet50()
    seen = set()

    # the device we are optimizing for
    device: Compiled = Device[Device.DEFAULT]
    print(f"optimizing for {Device.DEFAULT}")

    # first model run to init the weights, they are saved in seen
    mdl(Tensor.empty(64, 3, 224, 224)).lazydata.schedule(seen)

    # run model again to get only what changes, these are the kernels of the model
    x = Tensor.empty(64, 3, 224, 224)
    out = mdl(x)
    sched = [item for item in out.lazydata.schedule(seen) if item.ast.op not in LoadOps]

    # focus on one kernel; an index past the end simply leaves an empty schedule
    kernel_idx = getenv("KERNEL", -1)
    if kernel_idx >= 0:
        sched = sched[kernel_idx:kernel_idx+1]

    # read the beam setting once instead of on every kernel
    beam = getenv("BEAM")

    # work with the schedule
    total_tm = 0
    running_gflops = 0  # accumulates gflops * tm, so dividing by total_tm gives a time-weighted average
    for i, si in enumerate(sched):
        # create output/input buffers (NOTE: bufs_from_lin(lin) is slower, so we don't use it. TODO: fix)
        rawbufs = [device.buffer(si.out.st.size(), si.out.dtype)]
        rawbufs += [device.buffer(inp.st.size(), inp.dtype) for inp in si.inputs]

        # "linearize" the op into uops in different ways
        lins: List[Linearizer] = []

        # always try hand coded opt
        lin = Linearizer(si.ast, device.linearizer_opts)
        lin.hand_coded_optimizations()
        lins.append(lin)

        # maybe try tensor cores; only kept when apply_tensor_cores() returns a truthy value
        lin = Linearizer(si.ast, device.linearizer_opts)
        if lin.apply_tensor_cores():
            lins.append(lin)

        # try a beam search
        if beam:
            lin = Linearizer(si.ast, device.linearizer_opts)
            lin = beam_search(lin, rawbufs, beam, bool(getenv("BEAM_ESTIMATE", 1)))
            lins.append(lin)

        # benchmark the programs
        choices = []
        for lin in lins:
            tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10)
            # flop count is evaluated with every symbolic variable at its minimum
            gflops = sym_infer(lin.info.flops, {k: k.min for k in vars_from_ast(lin.ast)})*1e-9/tm
            choices.append((tm, gflops, lin.linearize()))

            # print all kernels
            if DEBUG >= 1:
                print(f" kernel {i:2d} {lin.display_name+' '*(37-ansilen(lin.display_name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS")

        # fastest candidate wins; lins always holds at least the hand-coded variant,
        # and min() keeps the first entry on ties just like a stable sort would
        tm, gflops, lin = min(choices, key=lambda choice: choice[0])
        print(f"*** {total_tm*1000:7.2f} ms : kernel {i:2d} {lin.display_name+' '*(37-ansilen(lin.display_name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS")
        total_tm += tm
        running_gflops += gflops * tm

    # an empty schedule (e.g. KERNEL out of range) leaves total_tm at 0; avoid dividing by it
    avg_gflops = running_gflops/total_tm if total_tm > 0 else 0.0
    print(f"******* total {total_tm*1000:.2f} ms, {avg_gflops:6.0f} GFLOPS")