carrot/tinygrad_repo/examples/handcode_opt.py
carrot efee1712aa
KerryGoldModel, AGNOS12.3, ButtonMode3, autoDetectLFA2, (#181)
* fix.. speed_limit error...

* draw tpms settings.

* fix.. traffic light stopping only..

* fix.. waze cam

* fix.. waze...

* add setting (Enable comma connect )

* auto detect LFA2

* fix.. cruisespeed1

* vff2 driving model.

* fix..

* agnos 12.3

* fix..

* ff

* ff

* test

* ff

* fix.. drawTurnInfo..

* Update drive_helpers.py

* fix..

support eng  voice

eng sounds

fix settings... english

fix.. mph..

fix.. roadlimit speed bug..

* new vff model.. 250608

* fix soundd..

* fix safe exit speed..

* fix.. sounds.

* fix.. radar timeStep..

* KerryGold model

* Update drive_helpers.py

* fix.. model.

* fix..

* fix..

* Revert "fix.."

This reverts commit b09ec459afb855c533d47fd7e8a1a6b1a09466e7.

* Revert "fix.."

This reverts commit 290bec6b83a4554ca232d531a911edccf94a2156.

* fix esim

* add more acc table. 10kph

* kg update..

* fix cruisebutton mode3

* test atc..cond.

* fix.. canfd

* fix.. angle control limit
2025-06-13 15:59:36 +09:00

134 lines
5.3 KiB
Python

from extra.models.resnet import ResNet50
from extra.mcts_search import mcts_search
from examples.mlperf.helpers import get_mlperf_bert_model
from tinygrad import Tensor, Device, dtypes, nn
from tinygrad.codegen.kernel import Kernel
from tinygrad.codegen.heuristic import hand_coded_optimizations
from tinygrad.uop.ops import Ops, sym_infer
from tinygrad.device import Compiled
from tinygrad.engine.search import beam_search, bufs_from_lin
from tinygrad.helpers import DEBUG, ansilen, getenv, colored, TRACEMETA
from extra.optimization.helpers import time_linearizer
def get_sched_resnet():
mdl = ResNet50()
optim = (nn.optim.LARS if getenv("LARS") else nn.optim.SGD)(nn.state.get_parameters(mdl))
BS = getenv("BS", 64)
# run model twice to get only what changes, these are the kernels of the model
for _ in range(2):
out = mdl(Tensor.empty(BS, 3, 224, 224))
targets = [out]
if getenv("BACKWARD"):
optim.zero_grad()
out.sparse_categorical_crossentropy(Tensor.empty(BS, dtype=dtypes.int)).backward()
targets += [x for x in optim.schedule_step()]
sched = Tensor.schedule(*targets)
print(f"schedule length {len(sched)}")
return sched
def get_sched_bert():
mdl = get_mlperf_bert_model()
optim = nn.optim.LAMB(nn.state.get_parameters(mdl))
# fake data
BS = getenv("BS", 9)
input_ids = Tensor.empty((BS, 512), dtype=dtypes.float32)
segment_ids = Tensor.empty((BS, 512), dtype=dtypes.float32)
attention_mask = Tensor.empty((BS, 512), dtype=dtypes.default_float)
masked_positions = Tensor.empty((BS, 76), dtype=dtypes.float32)
masked_lm_ids = Tensor.empty((BS, 76), dtype=dtypes.float32)
masked_lm_weights = Tensor.empty((BS, 76), dtype=dtypes.float32)
next_sentence_labels = Tensor.empty((BS, 1), dtype=dtypes.float32)
# run model twice to get only what changes, these are the kernels of the model
for _ in range(2):
lm_logits, seq_relationship_logits = mdl(input_ids, attention_mask, masked_positions, segment_ids)
targets = [lm_logits, seq_relationship_logits]
if getenv("BACKWARD"):
optim.zero_grad()
loss = mdl.loss(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
# ignore grad norm and loss scaler for now
loss.backward()
targets += [x for x in optim.schedule_step()]
sched = Tensor.schedule(*targets)
print(f"schedule length {len(sched)}")
return sched
if __name__ == "__main__":
if getenv("HALF", 1):
dtypes.default_float = dtypes.half
# the device we are optimizing for
device: Compiled = Device[Device.DEFAULT]
if getenv("BACKWARD"): Tensor.training = True
print(f"optimizing for {Device.DEFAULT}")
sched = globals()[f"get_sched_{getenv('MODEL', 'resnet')}"]()
sched = [x for x in sched if x.ast.op is Ops.SINK]
# focus on one kernel
if getenv("KERNEL", -1) >= 0: sched = sched[getenv("KERNEL", -1):getenv("KERNEL", -1)+1]
# work with the schedule
total_tm = 0
running_gflops = 0
usage = {}
for i,si in enumerate(sched):
if DEBUG >= 3: print(si.ast)
rawbufs = bufs_from_lin(Kernel(si.ast))
# "linearize" the op into uops in different ways
lins: list[tuple[Kernel, str]] = []
# always try hand coded opt
lin = Kernel(si.ast, opts=device.renderer)
lin.apply_opts(hand_coded_optimizations(lin))
lins.append((lin, "HC"))
# maybe try tensor cores
lin = Kernel(si.ast, opts=device.renderer)
if lin.apply_tensor_cores():
lins.append((lin, "TC"))
# try a beam search
if beam:=getenv("BEAM"):
lin = Kernel(si.ast, opts=device.renderer)
lin = beam_search(lin, rawbufs, beam, bool(getenv("BEAM_ESTIMATE", 1)))
lins.append((lin, "BEAM"))
# try MCTS
if mcts:=getenv("MCTS"):
lin = Kernel(si.ast, opts=device.renderer)
lin = mcts_search(lin, rawbufs, mcts)
lins.append((lin, "MCTS"))
# benchmark the programs
choices = []
for lin, nm in lins:
tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10, disable_cache=True)
ops = (prg:=lin.to_program()).estimates.ops
gflops = sym_infer(ops, {k:k.min for k in lin.ast.variables()})*1e-9/tm
choices.append((tm, gflops, lin, prg, nm))
sorted_choices = sorted(choices, key=lambda x: x[0])
if DEBUG >= 1: # print all kernels
for tm, gflops, lin, prg, nm in choices:
print(f" kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(prg.global_size):18s} {str(prg.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS -- {colored(nm, 'green') if lin is sorted_choices[0][2] else nm}")
tm, gflops, lin, prg, nm = sorted_choices[0]
if getenv("SRC"):
print(si.ast)
print(lin.applied_opts)
print(lin.to_program().src)
total_tm += tm
running_gflops += gflops * tm
if (key := str([str(m) for m in si.metadata])) not in usage: usage[key] = (0, 0)
usage[key] = (usage[key][0] + tm, usage[key][1] + 1)
print(f"*** {total_tm*1000:7.2f} ms : kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(prg.global_size):18s} {str(prg.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS {[repr(m) if TRACEMETA >= 2 else str(m) for m in si.metadata]}")
print(f"******* total {total_tm*1000:.2f} ms, {running_gflops/total_tm:6.0f} GFLOPS")
print("usage:")
for k in sorted(usage, key=lambda x: -usage[x][0])[:10]:
print(f"{usage[k][0]*1000:.2f} ms: {k} ({usage[k][1]} times)")