# carrot/tinygrad_repo/test/test_viz.py
# Vehicle Researcher 8eb8330d95 openpilot v0.9.9 release
# date: 2025-03-08T09:09:29
# master commit: ce355250be726f9bc8f0ac165a6cde41586a983d
# 2025-03-08 09:09:31 +00:00
#
# 216 lines
# 9.8 KiB
# Python

from typing import Dict, List, Optional
import unittest, decimal, json
from tinygrad.dtype import dtypes
from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic
from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys
from tinygrad.device import ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry
from tinygrad.viz.serve import get_details, get_metadata, uop_to_json, to_perfetto
@track_rewrites(named=True)
def rewrite(sink:UOp, pm:PatternMatcher, **kwargs):
  """Apply `pm` to `sink` via `graph_rewrite` inside a named `track_rewrites` scope.

  Extra keyword arguments are forwarded to `graph_rewrite` unchanged.
  """
  return graph_rewrite(sink, pm, **kwargs)
def helper_test_viz(sink:UOp, pm:PatternMatcher, **kwargs) -> List[UOp]:
  """Run one tracked rewrite of `sink` with `pm` and return the graphs recorded for it.

  The first entry of the recorded graphs is dropped (presumably the un-rewritten input -- confirm
  against `get_details`), so the result holds one graph per rewrite step.
  Extra keyword arguments are forwarded to `rewrite`.
  """
  rewrite(sink, pm, **kwargs)
  # the tracked call must have produced exactly one context holding exactly one graph_rewrite
  assert len(contexts) == 1
  assert len(contexts[0]) == 1
  metadata = get_metadata(keys, contexts)
  first_entry = metadata[0][0]
  details = get_details(*first_entry)
  return details.graphs[1:]
class TestViz(unittest.TestCase):
  """Tests for tracked graph rewrites and the VIZ helpers that read them back."""

  def setUp(self):
    # start every test with no tracked rewrites left over from a previous test
    contexts.clear()
    keys.clear()
    # remember the current level so tearDown can restore it, then force level 2 for the test
    self.tms = TRACK_MATCH_STATS.value
    TRACK_MATCH_STATS.value = 2

  def tearDown(self): TRACK_MATCH_STATS.value = self.tms

  def test_viz_simple(self):
    # a single x*1 -> x rewrite should be recorded as exactly one graph
    pm = PatternMatcher([
      (UPat.var("x")*1, lambda x:x),
    ])
    a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
    uops = helper_test_viz(a*1, pm)
    self.assertEqual(len(uops), 1)
    self.assertEqual(uops[0], a)

  def test_rewrite_twice(self):
    # a+a -> a*2 -> a<<1: two chained rewrites should be recorded as two graphs
    pm = PatternMatcher([
      (UPat.var("x")+UPat.var("x"), lambda x:x*2),
      (UPat.var("x", dtypes.int)*2, lambda x:x.alu(Ops.SHL, UOp.const(dtypes.int, 1))),
    ])
    a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
    uops = helper_test_viz(a+a, pm)
    self.assertEqual(len(uops), 2)
    self.assertEqual(uops[0], a*2)
    self.assertEqual(uops[1], graph_rewrite(a+a, pm))

  def test_rewrite_with_ctx(self):
    a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
    b = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1), UOp.const(dtypes.int, 0)))
    # ctx is used as a "seen" set so every LOAD is wrapped in a store at most once
    def store_load(ctx:Dict[UOp, None], x:UOp) -> Optional[UOp]:
      if x in ctx: return None
      ctx[x] = None
      return UOp.store(*x.src, x)
    pm = PatternMatcher([
      (UPat(Ops.LOAD, name="x"), store_load),
    ])
    uops = helper_test_viz(a+b, pm, ctx={})
    self.assertEqual(len(uops), 2)
    # a fresh ctx is passed here so the reference rewrite does not see the LOADs marked above
    self.assertEqual(uops[-1], graph_rewrite(a+b, pm, {}))

  def test_track_rewrites(self):
    simple = PatternMatcher([(UPat.var("x")*1, lambda x:x)])
    # NOTE: the function name is part of the asserted keys below ("do_rewrite_1", "do_rewrite_2")
    @track_rewrites(named=True)
    def do_rewrite(x:UOp): return graph_rewrite(x, simple)
    ld = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0)))
    do_rewrite(ld*1)
    do_rewrite(ld*2)
    ret = get_metadata(keys, contexts)
    self.assertEqual(len(ret), 2)
    # first call: ld*1 matches the x*1 pattern once
    key, _, m = ret[0][0]
    self.assertEqual(key, "do_rewrite_1")
    self.assertEqual(len(m.upats), 1)
    # second call: ld*2 matches nothing
    key, _, m = ret[1][0]
    self.assertEqual(key, "do_rewrite_2")
    self.assertEqual(len(m.upats), 0)

  def test_track_rewrites_with_exception(self):
    simple = PatternMatcher([(UPat.var("x")*1, lambda x:x)])
    @track_rewrites()
    def do_rewrite(x:UOp):
      x = graph_rewrite(x, simple) # NOTE: viz tracks this
      raise Exception("test")
    ld = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0)))
    with self.assertRaises(Exception): do_rewrite(ld*1)
    # the rewrite done before the exception must still be reported
    ret = get_metadata(keys, contexts)
    self.assertEqual(len(ret), 1)

  def test_fold_const(self):
    a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
    graph = uop_to_json(a)
    # unittest assertions instead of bare `assert`: they survive `python -O` and match the rest of the class
    # no entry's first field may start with "CONST" ...
    self.assertFalse(any(v[0].startswith("CONST") for v in graph.values()))
    # ... but exactly one entry's first field must still contain "CONST" somewhere
    self.assertEqual(len([x for x in graph.values() if "CONST" in x[0]]), 1)

  @unittest.skip("TODO: bring this back with better testing")
  def test_bottom_up_rewrite(self):
    a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
    n1 = a.sin()
    uop = n1.sin()
    # the pattern matches any op and swaps in whatever ctx maps that UOp to (None when it is not a key)
    pm = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
    ret = helper_test_viz(uop, pm, ctx={a.sin():a.sqrt(), n1.sin():n1.sqrt()}, bottom_up=True)
    self.assertEqual(len(ret), 2)
    self.assertIs(ret[0], a.sin().sqrt()) # first rewrite
    self.assertIs(ret[1], a.sqrt().sqrt()) # second one

  def test_top_down_rewrite(self):
    a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
    n1 = a.sin()
    uop = n1.sin()
    # the pattern matches any op and swaps in whatever ctx maps that UOp to (None when it is not a key)
    pm = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
    # if it wasn't bottom_up, it's rewritten once
    ret = helper_test_viz(uop, pm, ctx={a.sin():a.sqrt(), n1.sin():n1.sqrt()}, bottom_up=False)
    self.assertEqual(len(ret), 1)
    self.assertIs(ret[0], a.sqrt().sin()) # only rewrite

  # NOTE: calling graph_rewrite when the function isn't decorated with track_rewrites should not VIZ
  def test_rewrite_without_context(self):
    def untracked_graph_rewrite(sink): return graph_rewrite(sink, symbolic)
    @track_rewrites(named=True)
    def tracked_graph_rewrite(sink): return graph_rewrite(sink, symbolic)
    # test
    add = UOp.const(dtypes.int, 2) + UOp.const(dtypes.int, 1)
    untracked_graph_rewrite(add)
    self.assertEqual(len(contexts), 0)
    tracked_graph_rewrite(add)
    self.assertEqual(len(contexts), 1)

  def test_inner_rewrite_location(self):
    # inner rewrite gets tracked in another context
    # NOTE(review): keep inner_rewrite on a single line -- the assertion below compares the recorded line
    # number to co_firstlineno, which presumably only holds while the graph_rewrite call is on the def line
    def inner_rewrite(sink): return graph_rewrite(sink, symbolic)
    @track_rewrites(named=True)
    def tracked_graph_rewrite(sink): return inner_rewrite(sink)
    # test
    add = UOp.const(dtypes.int, 2) + UOp.const(dtypes.int, 1)
    tracked_graph_rewrite(add)
    self.assertEqual(len(contexts), 1)
    # location of context is inner_rewrite
    fp, lineno = contexts[0][0].loc
    self.assertEqual(lineno, inner_rewrite.__code__.co_firstlineno)
    self.assertEqual(fp, inner_rewrite.__code__.co_filename)
class TextVizProfiler(unittest.TestCase):
  """Checks the trace events that `to_perfetto` emits for profiler events."""

  def test_perfetto_node(self):
    # one compute (is_copy=False) range event on device NV, plus the NV device record
    range_ev = ProfileRangeEvent(device='NV', name='E_2', st=decimal.Decimal(1000), en=decimal.Decimal(1010), is_copy=False)
    device_ev = ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100))
    events = json.loads(to_perfetto([range_ev, device_ev]))['traceEvents']

    # Device regs always first
    process_meta = events[0]
    self.assertEqual(process_meta['name'], 'process_name')
    self.assertEqual(process_meta['ph'], 'M')
    self.assertEqual(process_meta['args']['name'], 'NV')

    compute_meta = events[1]
    self.assertEqual(compute_meta['name'], 'thread_name')
    self.assertEqual(compute_meta['ph'], 'M')
    self.assertEqual(compute_meta['pid'], events[0]['pid'])
    self.assertEqual(compute_meta['tid'], 0)
    self.assertEqual(compute_meta['args']['name'], 'COMPUTE')

    copy_meta = events[2]
    self.assertEqual(copy_meta['name'], 'thread_name')
    self.assertEqual(copy_meta['ph'], 'M')
    self.assertEqual(copy_meta['pid'], events[0]['pid'])
    self.assertEqual(copy_meta['tid'], 1)
    self.assertEqual(copy_meta['args']['name'], 'COPY')

    # the range event itself: expected on the COMPUTE thread (tid 0) with ts 0 and dur 10
    kernel = events[3]
    self.assertEqual(kernel['name'], 'E_2')
    self.assertEqual(kernel['ts'], 0)
    self.assertEqual(kernel['dur'], 10)
    self.assertEqual(kernel['ph'], 'X')
    self.assertEqual(kernel['pid'], events[0]['pid'])
    self.assertEqual(kernel['tid'], 0)

  def test_perfetto_copy_node(self):
    # same setup as the compute case, but the range event is flagged is_copy=True
    range_ev = ProfileRangeEvent(device='NV', name='COPYxx', st=decimal.Decimal(1000), en=decimal.Decimal(1010), is_copy=True)
    device_ev = ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100))
    events = json.loads(to_perfetto([range_ev, device_ev]))['traceEvents']

    copy_ev = events[3]
    self.assertEqual(copy_ev['name'], 'COPYxx')
    self.assertEqual(copy_ev['ts'], 900) # diff clock
    self.assertEqual(copy_ev['dur'], 10)
    self.assertEqual(copy_ev['ph'], 'X')
    self.assertEqual(copy_ev['tid'], 1)

  def test_perfetto_graph(self):
    # two devices and one graph event with a compute entry on NV and a copy entry on NV:1
    nv_dev = ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100))
    nv1_dev = ProfileDeviceEvent(device='NV:1', comp_tdiff=decimal.Decimal(-500), copy_tdiff=decimal.Decimal(-50))
    graph_ev = ProfileGraphEvent(
      ents=[ProfileGraphEntry(device='NV', name='E_25_4n2', st_id=0, en_id=1, is_copy=False),
            ProfileGraphEntry(device='NV:1', name='NV -> NV:1', st_id=2, en_id=3, is_copy=True)],
      deps=[[], [0]],
      sigs=[decimal.Decimal(1000), decimal.Decimal(1002), decimal.Decimal(1004), decimal.Decimal(1008)])
    events = json.loads(to_perfetto([nv_dev, nv1_dev, graph_ev]))['traceEvents']

    # Device regs always first
    self.assertEqual(events[0]['args']['name'], 'NV')
    self.assertEqual(events[1]['args']['name'], 'COMPUTE')
    self.assertEqual(events[2]['args']['name'], 'COPY')
    self.assertEqual(events[3]['args']['name'], 'NV:1')
    self.assertEqual(events[4]['args']['name'], 'COMPUTE')
    self.assertEqual(events[5]['args']['name'], 'COPY')

    # compute entry: expected under the NV process
    compute_entry = events[6]
    self.assertEqual(compute_entry['name'], 'E_25_4n2')
    self.assertEqual(compute_entry['ts'], 0)
    self.assertEqual(compute_entry['dur'], 2)
    self.assertEqual(compute_entry['pid'], events[0]['pid'])

    # copy entry: expected under the NV:1 process
    copy_entry = events[7]
    self.assertEqual(copy_entry['name'], 'NV -> NV:1')
    self.assertEqual(copy_entry['ts'], 954)
    self.assertEqual(copy_entry['dur'], 4)
    self.assertEqual(copy_entry['pid'], events[3]['pid'])
# run the unittest runner when this file is executed directly as a script
if __name__ == "__main__":
  unittest.main()