Vehicle Researcher 8eb8330d95 openpilot v0.9.9 release
date: 2025-03-08T09:09:29
master commit: ce355250be726f9bc8f0ac165a6cde41586a983d
2025-03-08 09:09:31 +00:00

218 lines
11 KiB
Python
Executable File

#!/usr/bin/env python3
import multiprocessing, pickle, functools, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, decimal
from http.server import HTTPServer, BaseHTTPRequestHandler
from urllib.parse import parse_qs, urlparse
from dataclasses import asdict, dataclass
from typing import Any, Callable, Optional
from tinygrad.helpers import colored, getenv, to_function_name, tqdm, unwrap, word_wrap
from tinygrad.ops import TrackedGraphRewrite, UOp, Ops, lines, GroupOp
from tinygrad.codegen.kernel import Kernel
from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent
# Fill color for each node in the VIZ graph, keyed by Ops; uop_to_json uses "#ffffff" for any op not listed here.
# NOTE: later keys override earlier ones in a dict literal, so the entries written after the GroupOp.Movement /
# GroupOp.ALU comprehensions (e.g. Ops.THREEFRY, presumably also an ALU op — confirm) win. Keep this ordering when editing.
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.PRELOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0",
Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_ACC: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B",
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#e0ffc0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0"}
# ** API spec
@dataclass
class GraphRewriteMetadata:
  """Summary of one tracked graph_rewrite call, as listed in the VIZ sidebar."""
  # (file path, line number) of the graph_rewrite call site
  loc: tuple[str, int]
  # source text of the line that called graph_rewrite
  code_line: str
  # name of the kernel this rewrite belongs to
  kernel_name: str
  # every UPat that was applied: (location, printable form, time)
  upats: list[tuple[tuple[str, int], str, float]]
@dataclass
class GraphRewriteDetails(GraphRewriteMetadata):
  """Sidebar metadata plus everything needed to replay a single graph_rewrite call step by step."""
  # the starting sink, followed by the sink after every rewrite step
  graphs: list[UOp]
  # per step: unified-diff lines between the matched UOp and its replacement
  diffs: list[list[str]]
  # per step: ids of the nodes that the step produced
  changed_nodes: list[list[int]]
  # rendered program source (or an "ERROR: ..." string) when the key is a Kernel, else None
  kernel_code: Optional[str]
# ** API functions
# NOTE: if any extra rendering in VIZ fails, we don't crash
def pcall(fxn:Callable[..., str], *args, **kwargs) -> str:
  """Call fxn(*args, **kwargs) best-effort: any Exception is turned into the string "ERROR: <message>"."""
  try:
    result = fxn(*args, **kwargs)
  except Exception as err:
    return f"ERROR: {err}"
  return result
def get_metadata(keys:list[Any], contexts:list[list[TrackedGraphRewrite]]) -> list[list[tuple[Any, TrackedGraphRewrite, GraphRewriteMetadata]]]:
  """Group tracked rewrites by kernel name.

  keys[i] owns the rewrites in contexts[i]. A Kernel key is named with to_function_name(key.name); any other key
  is named with str(key). Rewrites whose starting sink unpickles to a CONST are skipped. Returns one list per
  kernel name, in first-seen order, of (key, ctx, GraphRewriteMetadata) triples."""
  grouped: dict[str, list[tuple[Any, TrackedGraphRewrite, GraphRewriteMetadata]]] = {}
  for key, ctx_list in tqdm(zip(keys, contexts), desc="preparing kernels"):
    kname = to_function_name(key.name) if isinstance(key, Kernel) else str(key)
    for ctx in ctx_list:
      # a rewrite that starts from a bare CONST has nothing to show
      if pickle.loads(ctx.sink).op is Ops.CONST: continue
      applied = [(upat.location, upat.printable(), tm) for _, _, upat, tm in ctx.matches if upat is not None]
      bucket = grouped.setdefault(kname, [])
      call_line = lines(ctx.loc[0])[ctx.loc[1]-1].strip()
      bucket.append((key, ctx, GraphRewriteMetadata(ctx.loc, call_line, kname, applied)))
  return list(grouped.values())
def uop_to_json(x:UOp) -> dict[int, tuple[str, str, list[int], str, str]]:
  """Flatten the graph under `x` into {id(node): (label, dtype, src ids, arg, color)} for the frontend.

  CONST and DEVICE nodes get no entry of their own. Each consumer instead lists them in its label as
  CONST<i> / DEVICE<i>, where i is the source position. Their ids are also filtered out of the consumer's src-id list
  (this assumes toposort yields sources before their consumers)."""
  assert isinstance(x, UOp)
  graph: dict[int, tuple[str, str, list[int], str, str]] = {}
  excluded = set()
  for node in x.toposort:
    if node.op in {Ops.CONST, Ops.DEVICE}:
      excluded.add(node)
      continue
    if node.op is Ops.VIEW:
      # one "shape / strides[ / offset]" line per view
      argst = "\n".join([f"{v.shape} / {v.strides}"+(f" / {v.offset}" if v.offset else "") for v in node.arg.views])
    else:
      argst = str(node.arg)
    label = f"{str(node.op).split('.')[1]}{(' '+word_wrap(argst.replace(':', ''))) if node.arg is not None else ''}\n{str(node.dtype)}"
    for pos, src in enumerate(node.src):
      if src.op is Ops.CONST: label += f"\nCONST{pos} {src.arg:g}"
      if src.op is Ops.DEVICE: label += f"\nDEVICE{pos} {src.arg}"
    src_ids = [id(s) for s in node.src if s not in excluded]
    graph[id(node)] = (label, str(node.dtype), src_ids, str(node.arg), uops_colors.get(node.op, "#ffffff"))
  return graph
def _replace_uop(base:UOp, replaces:dict[UOp, UOp]) -> UOp:
  """Return `base` with `replaces` applied throughout its subgraph.

  `replaces` doubles as the memo and is mutated:
  - a node already in it returns its mapping directly;
  - otherwise the node is rebuilt from its recursively replaced sources;
  - if the rebuilt node is itself a key, its mapping is returned;
  - otherwise base -> rebuilt is recorded.
  Recursive, so very deep graphs may hit Python's recursion limit."""
  found = replaces.get(base)
  if found is not None: return found
  rebuilt = base.replace(src=tuple(_replace_uop(s, replaces) for s in base.src))
  final = replaces.get(rebuilt)
  if final is None:
    replaces[base] = rebuilt
    return rebuilt
  return final
@functools.lru_cache(maxsize=None)
def _prg(k:Kernel):
  """Return the program source for `k`; memoized per Kernel so repeated VIZ requests don't re-render it."""
  program = k.to_program()
  return program.src
def get_details(k:Any, ctx:TrackedGraphRewrite, metadata:GraphRewriteMetadata) -> GraphRewriteDetails:
  """Replay the recorded matches of `ctx` and collect what VIZ shows for each rewrite step.

  graphs[0] is the unpickled starting sink. Every match that produced a rewrite appends three things:
  - the new sink;
  - the unified diff of the matched UOp before/after;
  - the ids of the non-CONST nodes in the replacement.
  Matches without a recorded result are pinned to themselves. Raises AssertionError if a recorded rewrite leaves
  the sink unchanged."""
  first = pickle.loads(ctx.sink)
  code = pcall(_prg, k) if isinstance(k, Kernel) else None
  details = GraphRewriteDetails(**asdict(metadata), graphs=[first], diffs=[], changed_nodes=[], kernel_code=code)
  seen: dict[UOp, UOp] = {}
  current = first
  for step, (before_b, after_b, upat, _) in enumerate(ctx.matches):
    before = pickle.loads(before_b)
    if after_b is None:
      # the pattern matched but returned nothing: keep the node as-is and move on
      seen[before] = before
      continue
    after = pickle.loads(after_b)
    seen[before] = after
    # rebuild the current sink from a fresh copy of every substitution recorded so far
    rewritten = _replace_uop(current, dict(seen))
    if rewritten is current: raise AssertionError(f"rewritten sink wasn't rewritten! {step} {unwrap(upat).location}")
    details.changed_nodes.append([id(n) for n in after.toposort if n.op is not Ops.CONST])
    details.diffs.append(list(difflib.unified_diff(pcall(str, before).splitlines(), pcall(str, after).splitlines())))
    details.graphs.append(rewritten)
    current = rewritten
  return details
# Profiler API
# device name -> (compute time offset, copy time offset, perfetto pid); filled in by dev_ev_to_perfetto_json
devices:dict[str, tuple[decimal.Decimal, decimal.Decimal, int]] = dict()
def prep_ts(device:str, ts:decimal.Decimal, is_copy):
  """Shift `ts` by the device's offset and truncate it to an int.

  The offset is indexed by is_copy: False selects the compute offset, True selects the copy offset."""
  shifted = decimal.Decimal(ts) + devices[device][is_copy]
  return int(shifted)
def dev_to_pid(device:str, is_copy=False):
  """Return perfetto's {"pid", "tid"} for `device`; tid 0 is the compute thread and tid 1 is the copy thread."""
  return dict(pid=devices[device][2], tid=int(is_copy))
def dev_ev_to_perfetto_json(ev:ProfileDeviceEvent):
  """Register `ev.device` in `devices` and return perfetto metadata events naming its process and its two threads.

  The stored tuple is (compute time offset, copy time offset, pid). The copy offset falls back to the compute offset
  when the event has none. A device that is registered again keeps its original pid; only its time offsets are
  refreshed. Previously the re-registered device got pid len(devices), and the next new device was handed that same
  pid, so the two tracks collided."""
  pid = devices[ev.device][2] if ev.device in devices else len(devices)
  devices[ev.device] = (ev.comp_tdiff, ev.copy_tdiff if ev.copy_tdiff is not None else ev.comp_tdiff, pid)
  return [{"name": "process_name", "ph": "M", "pid": pid, "args": {"name": ev.device}},
          {"name": "thread_name", "ph": "M", "pid": pid, "tid": 0, "args": {"name": "COMPUTE"}},
          {"name": "thread_name", "ph": "M", "pid": pid, "tid": 1, "args": {"name": "COPY"}}]
def range_ev_to_perfetto_json(ev:ProfileRangeEvent):
  """Convert one timed range into a single perfetto complete ("X") event on its device's compute or copy thread."""
  start = prep_ts(ev.device, ev.st, ev.is_copy)
  duration = float(ev.en - ev.st)
  return [{"name": ev.name, "ph": "X", "ts": start, "dur": duration, **dev_to_pid(ev.device, ev.is_copy)}]
def graph_ev_to_perfetto_json(ev:ProfileGraphEvent, reccnt):
  """Convert a recorded graph execution into perfetto events.

  Each entry becomes one complete ("X") event. Each dependency adds a flow start ("s") at the dependency's end
  and a matching flow finish ("f") at the entry's start. Flow ids are offset by `reccnt`; to_perfetto passes the
  number of events already emitted, so ids don't repeat across calls."""
  out = []
  for idx, ent in enumerate(ev.ents):
    start, end = ev.sigs[ent.st_id], ev.sigs[ent.en_id]
    out.append({"name": ent.name, "ph": "X", "ts": prep_ts(ent.device, start, ent.is_copy), "dur": float(end-start),
                **dev_to_pid(ent.device, ent.is_copy)})
    for dep_idx in ev.deps[idx]:
      src = ev.ents[dep_idx]
      flow_id = reccnt + len(out)
      out.append({"ph": "s", **dev_to_pid(src.device, src.is_copy), "id": flow_id,
                  "ts": prep_ts(src.device, ev.sigs[src.en_id], src.is_copy), "bp": "e"})
      out.append({"ph": "f", **dev_to_pid(ent.device, ent.is_copy), "id": flow_id,
                  "ts": prep_ts(ent.device, start, ent.is_copy), "bp": "e"})
  return out
def to_perfetto(profile:list[ProfileEvent]):
  """Encode `profile` as perfetto trace-event JSON bytes, or return None when it produces no events.

  All device events are converted first, so every device is registered before range or graph events look it up."""
  events = []
  for dev_ev in profile:
    if isinstance(dev_ev, ProfileDeviceEvent): events.extend(dev_ev_to_perfetto_json(dev_ev))
  for ev in tqdm(profile, desc="preparing profile"):
    if isinstance(ev, ProfileRangeEvent): events.extend(range_ev_to_perfetto_json(ev))
    elif isinstance(ev, ProfileGraphEvent): events.extend(graph_ev_to_perfetto_json(ev, reccnt=len(events)))
  if not events: return None
  return json.dumps({"traceEvents": events}).encode()
# ** HTTP server
class Handler(BaseHTTPRequestHandler):
  """HTTP handler for VIZ: frontend pages, static assets, kernel rewrite JSON and the perfetto profile.

  Reads the module-level `kernels` and `perfetto_profile` globals that the main block sets up."""
  def do_GET(self):
    """Answer a GET request.

    Unknown routes and missing or out-of-tree assets get 404. A /kernels query with a missing, non-integer or
    out-of-range kernel/idx gets 400."""
    ret, status_code, content_type = b"", 200, "text/html"
    url = urlparse(self.path)
    root = os.path.dirname(__file__)
    if url.path == "/":
      with open(os.path.join(root, "index.html"), "rb") as f: ret = f.read()
    elif url.path == "/profiler":
      with open(os.path.join(root, "perfetto.html"), "rb") as f: ret = f.read()
    elif url.path.startswith("/assets/"):
      # Use url.path, not self.path, so a query string like "?v=1" isn't treated as part of the file name.
      # Refuse anything that normalizes to a location outside the assets directory (covers ".." and platform separators).
      assets_dir = os.path.abspath(os.path.join(root, "assets"))
      fp = os.path.abspath(os.path.join(root, url.path.strip('/')))
      if not fp.startswith(assets_dir + os.sep): status_code = 404
      else:
        try:
          with open(fp, "rb") as f: ret = f.read()
          if url.path.endswith(".js"): content_type = "application/javascript"
          if url.path.endswith(".css"): content_type = "text/css"
        except OSError: status_code = 404  # missing file, or a directory such as "/assets/sub/"
    elif url.path == "/kernels":
      query = parse_qs(url.query)
      if (qkernel:=query.get("kernel")) is not None:
        if (selected:=self._select_rewrite(qkernel[0], query.get("idx"))) is None: status_code = 400
        else:
          g = get_details(*selected)
          jret: Any = {**asdict(g), "graphs": [uop_to_json(x) for x in g.graphs], "uops": [pcall(str,x) for x in g.graphs]}
          ret, content_type = json.dumps(jret).encode(), "application/json"
      else: ret, content_type = json.dumps([list(map(lambda x:asdict(x[2]), v)) for v in kernels]).encode(), "application/json"
    elif url.path == "/get_profile" and perfetto_profile is not None: ret, content_type = perfetto_profile, "application/json"
    else: status_code = 404
    # send response
    self.send_response(status_code)
    self.send_header('Content-Type', content_type)
    self.send_header('Content-Length', str(len(ret)))
    self.end_headers()
    return self.wfile.write(ret)

  @staticmethod
  def _select_rewrite(kernel_idx:str, idx:Optional[list[str]]):
    """Return kernels[int(kernel_idx)][int(idx[0])], or None when idx is missing, not an integer, or out of range."""
    if idx is None: return None
    try: return kernels[int(kernel_idx)][int(idx[0])]
    except (IndexError, ValueError): return None
# ** main loop
def reloader():
  """Poll this file's mtime every 0.1s and re-exec the interpreter with the same argv once it changes.

  Loops until the module-level `stop_reloader` event is set."""
  initial_mtime = os.stat(__file__).st_mtime
  while not stop_reloader.is_set():
    if os.stat(__file__).st_mtime != initial_mtime:
      print("reloading server...")
      os.execv(sys.executable, [sys.executable, *sys.argv])
    time.sleep(0.1)
def load_pickle(path:Optional[str]):
  """Unpickle and return the object stored at `path`; return None when `path` is None or does not exist.

  NOTE: pickle.load can execute arbitrary code, so only pass files you produced yourself."""
  if path is None or not os.path.exists(path): return None
  with open(path, "rb") as f: return pickle.load(f)
if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument('--kernels', type=str, help='Path to kernels', default=None)
  parser.add_argument('--profile', type=str, help='Path profile', default=None)
  args = parser.parse_args()

  # refuse to start if something is already listening on the port
  with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    if s.connect_ex(((HOST:="http://127.0.0.1").replace("http://", ""), PORT:=getenv("PORT", 8000))) == 0:
      raise RuntimeError(f"{HOST}:{PORT} is occupied! use PORT= to change.")

  stop_reloader = threading.Event()
  multiprocessing.current_process().name = "VizProcess" # disallow opening of devices
  st = time.perf_counter()
  print("*** viz is starting")
  contexts, profile = load_pickle(args.kernels), load_pickle(args.profile)
  kernels = get_metadata(*contexts) if contexts is not None else []
  if getenv("FUZZ_VIZ"):
    # replay every rewrite once; `rw` avoids reusing the name of the argparse namespace `args`
    ret = [get_details(*rw) for v in tqdm(kernels) for rw in v]
    print(f"fuzzed {len(ret)} rewrite details")
  perfetto_profile = to_perfetto(profile) if profile is not None else None

  # NOTE(review): binds every interface ('') while HOST advertises 127.0.0.1 — presumably for remote access; confirm intended
  server = HTTPServer(('', PORT), Handler)
  reloader_thread = threading.Thread(target=reloader)
  reloader_thread.start()
  print(f"*** started viz on {HOST}:{PORT}")
  print(colored(f"*** ready in {(time.perf_counter()-st)*1e3:4.2f}ms", "green"))
  if len(getenv("BROWSER", "")) > 0: webbrowser.open(f"{HOST}:{PORT}{'/profiler' if contexts is None else ''}")
  try: server.serve_forever()
  except KeyboardInterrupt: print("*** viz is shutting down...")
  finally:
    # Stop the (non-daemon) reloader thread on any exit; otherwise an unexpected serve_forever error would leave the
    # process hanging on that thread. Also release the listening socket.
    stop_reloader.set()
    server.server_close()