69 lines
2.9 KiB
Python
Raw Normal View History

import itertools
from collections import defaultdict
import numpy as np
from dataclasses import replace
from typing import DefaultDict, Dict, List, Tuple
from tinygrad.ops import UOp, print_uops, Ops
from tinygrad.device import Buffer, Device
from tinygrad.engine.realize import CompiledRunner
from tinygrad.helpers import DEBUG, colored
from tinygrad.ops import Variable
from tinygrad.tensor import _to_np_dtype
from test.external.fuzz_schedule import FUZZ_SCHEDULE_MAX_PATHS, find_all_toposorts
# Maps each block-opening op to a pair of ops. fuzz_uops treats the keys as "start a new block"
# markers and the second element of every pair as the op that closes a block.
# NOTE(review): the first element of each pair is not read by the code visible in this file.
END_FOR_UOP = {
  Ops.IF: (Ops.STORE, Ops.ENDIF),
  Ops.RANGE: (Ops.ASSIGN, Ops.ENDRANGE),
}
def fuzz_uops(uops:List[UOp]) -> List[Tuple[UOp, ...]]:
  """Enumerate alternative orderings of a linearized uop list.

  The list is cut into consecutive blocks at control-flow boundaries: a uop whose op is a key
  of END_FOR_UOP starts a new block, and a uop whose op is the second element of an
  END_FOR_UOP value is isolated in a block of its own. Uops are only reordered inside their
  own block (block order is preserved); the orderings of a block are whatever
  find_all_toposorts returns for the `src` edges that stay inside that block.

  Args:
    uops: the uops in their original order.

  Returns:
    The distinct orderings, each a tuple containing every uop. Enumeration stops as soon as
    at least FUZZ_SCHEDULE_MAX_PATHS orderings have been collected.
  """
  # loop-invariant, so build it once instead of once per uop
  end_ops = {pair[1] for pair in END_FOR_UOP.values()}

  blocks: List[List[UOp]] = [[]]
  for u in uops:
    if u.op in END_FOR_UOP: blocks.append([u])
    elif u.op in end_ops: blocks.extend([[u], []])
    else: blocks[-1].append(u)

  paths_for_block: List[List[Tuple[UOp, ...]]] = []
  for bb in blocks:
    # set membership keeps the edge scan linear; uops are already used as dict keys below
    members = set(bb)
    children: DefaultDict[UOp, List[UOp]] = defaultdict(list)
    in_degree: Dict[UOp, int] = {}
    for u in bb:
      in_degree[u] = 0
      for x in u.src:
        # sources that live in another block do not constrain the order inside this one
        if x in members:
          children[x].append(u)
          in_degree[u] += 1
    paths_for_block.append(find_all_toposorts(children, in_degree))

  # dict keys give de-duplication while keeping first-seen order
  paths: Dict[Tuple[UOp, ...], None] = {}
  for up in itertools.product(*paths_for_block):
    paths[tuple(itertools.chain.from_iterable(up))] = None
    if len(paths) >= FUZZ_SCHEDULE_MAX_PATHS: break
  return list(paths)
class UOpsFuzzerRunner(CompiledRunner):
  """CompiledRunner that re-runs its program under every ordering from fuzz_uops and compares outputs."""

  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False):
    """Run the program once as reference, then once per fuzzed uop ordering, checking the buffers.

    For each ordering the program is rendered, compiled and loaded in place of the current
    one (self.p, self.lib and self.clprg are overwritten and not restored), the buffers are
    reset to the contents they had on entry, the program is run, and each buffer is compared
    to the reference run with atol=rtol=1e-6.

    Raises:
      AssertionError: re-raised from the comparison on the first mismatching buffer, after
        printing the failing program name in red.
    """
    assert self.p.uops is not None
    fuzz_paths = fuzz_uops(self.p.uops)
    # snapshot the buffer contents so every ordering starts from the same state
    init_rawbufs = {buf: buf.as_buffer() for buf in rawbufs}
    init_name = self.p.function_name
    init_globals = dict(zip(self.p.globals, rawbufs))
    if DEBUG >= 1:
      print(colored(f"fuzzing {len(fuzz_paths)} uop permutations for {init_name}", "yellow"))

    # reference run with the unmodified program
    super().__call__(rawbufs, var_vals, wait)
    ground_truth = {buf: np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)) for buf in rawbufs}

    for i, path in enumerate(fuzz_paths):
      # setup prg
      uops = list(path)
      if DEBUG >= 5:
        print_uops(uops)
      name = f"{init_name}fuzz{i}"
      dev = Device[self.p.device]
      self.p = replace(self.p, name=name, src=dev.renderer.render(name, uops), uops=uops)
      if DEBUG >= 4:
        print(self.p.src)
      self.lib = dev.compiler.compile_cached(self.p.src)
      self.clprg = dev.runtime(name, self.lib)
      # only the buffers named by the program's globals are reset and passed to the run
      rawbufs = [init_globals[g] for g in self.p.globals]
      for buf in rawbufs:
        buf.copyin(init_rawbufs[buf])
      # verify
      super().__call__(rawbufs, var_vals, wait)
      for buf in rawbufs:
        try:
          np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)), ground_truth[buf], atol=1e-6, rtol=1e-6)
        except AssertionError:
          print(colored(name, "red"))
          raise
        else:
          if DEBUG >= 2:
            print(colored(name, "green"))