from typing import cast, Optional, Callable
import itertools, functools, random, math, time, multiprocessing, traceback, signal
from collections import defaultdict
from dataclasses import replace
from tinygrad.ops import UOp, Ops, Variable, sym_infer
from tinygrad.device import Device, Buffer, Compiler
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
from tinygrad.dtype import ImageDType, PtrDType
from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
from tinygrad.tensor import Tensor
from tinygrad.engine.realize import CompiledRunner
from tinygrad.renderer import ProgramSpec

actions = [Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,4,5,7] for axis in range(6)]
actions += [Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4,7] for axis in range(5)]
actions += [Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,29] for axis in range(6)]
actions += [Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,28,29,32,49,64,256] for axis in range(3)]
actions += [Opt(op=OptOps.GROUP, axis=axis, amt=amt) for amt in [0,4,8,16] for axis in range(3)]
if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32] for axis in range(7)]
actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.LOCAL, axis=6, amt=2)]
actions += [Opt(op=OptOps.TC, axis=0, amt=0)]
actions += [Opt(op=OptOps.TC, axis=axis, amt=getenv("TC_OPT", 2)) for axis in range(9)] # covers resnet kernels (3 global * 3 reduce)
actions += [Opt(op=OptOps.SWAP, axis=axis, amt=amt) for axis in range(5) for amt in range(axis+1, 5)]
if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]

def _get_test_global_size(global_size, max_global_size, var_vals):
  test_global_size, factor = [sym_infer(sz, var_vals) for sz in global_size], 1
  while prod(test_global_size) > max_global_size:
    for j in range(len(global_size)-1,-1,-1):
      if test_global_size[j] > 16:
        test_global_size[j] //= 2
        factor *= 2
        break
  return test_global_size, factor

def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[Variable, int], rawbufs:list[Buffer], early_stop:Optional[float]=None,
                  max_global_size:Optional[int]=65536, clear_l2=False, cnt=3, name="test") -> list[float]:
  factor = 1
  if p.global_size is not None and max_global_size is not None:
    global_size, factor = _get_test_global_size(p.global_size, max_global_size, var_vals)
    p = replace(p, global_size=global_size)
  try: car = CompiledRunner(p, precompiled=lib)
  except AssertionError: return [math.inf] * cnt
  tms = []
  input_bufs = [rawbufs[i] for i in car.p.globals]
  for _ in range(cnt):
    if clear_l2:
      if hasattr(dev:=Device[p.device], 'invalidate_caches'): dev.invalidate_caches()
      else:
        with Context(DEBUG=0, BEAM=0, CAPTURING=0, TRACK_MATCH_STATS=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
    tms.append(cast(float, car(input_bufs, var_vals, wait=True))*factor)
    if early_stop is not None and early_stop < min(tms): break
  return tms

class TimeoutException(Exception): pass
def timeout_handler(signum, frame): raise TimeoutException()

def _try_compile_linearized_w_idx(x:tuple[int,Kernel], compiler:Compiler) -> tuple[int, Optional[tuple[ProgramSpec, bytes, float]]]:
  if hasattr(signal, "alarm"):
    signal.signal(getattr(signal, 'SIGALRM'), timeout_handler) # set timeout
    signal.alarm(getenv("BEAM_TIMEOUT_SEC", 10))
  ret = None
  try:
    p = x[1].to_program(name_override="test")
    assert p.uops is not None, "uop list wasn't generated?"
    if len(p.uops) >= getenv("BEAM_UOPS_MAX", 3000) > 0: raise RuntimeError("too many uops")
    st = time.perf_counter()
    prog = compiler.compile(p.src)
    et = time.perf_counter() - st
    ret = (p, prog, et)
  except RuntimeError:
    if DEBUG >= 4: traceback.print_exc()
  except Exception as e:
    if getenv("BEAM_STRICT_MODE"): raise e
  finally:
    if hasattr(signal, "alarm"): signal.alarm(0)
  return x[0], ret

# workers should ignore ctrl c
def _init_worker(): signal.signal(signal.SIGINT, signal.SIG_IGN)

def _ensure_buffer_alloc(bufs:list[Buffer]) -> list[Buffer]: return [buf.ensure_allocated() for buf in bufs]

# *** external API ***

# get (scrap) buffers for timing the linearizer
def bufs_from_lin(lin:Kernel, allocate:bool=True) -> list[Buffer]:
  bufsts: defaultdict[int, list[UOp]] = defaultdict(list)
  for x in lin.bufs:
    if x.src[0].op is Ops.DEFINE_GLOBAL: bufsts[x.src[0].arg].append(x)
  rawbufs: list[Optional[Buffer]] = [None]*len(bufsts)
  for k,lx in bufsts.items():
    buf_size = prod(dtype.shape) if isinstance(dtype:=lx[0].src[0].dtype, ImageDType) else max(y.st_arg.real_size() for y in lx)
    assert isinstance(dtype, (PtrDType, ImageDType))
    if buf_size == 0: buf_size = 1  # create a size 1 buffer if no cell is accessed in kernel. # TODO: remove from kernel input in this case.
    buf_dtype = dtype if isinstance(dtype, ImageDType) else dtype.base
    rawbufs[k] = Buffer(lin.opts.device, buf_size, buf_dtype).allocate() if allocate else Buffer(lin.opts.device, buf_size, buf_dtype)
  assert all(r is not None for r in rawbufs)
  return cast(list[Buffer], rawbufs)

# get dictionary of all possible actions
def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]:
  acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024)
  for i,a in enumerate(actions):
    if a.axis is not None and a.op is not OptOps.TC:
      if ((ax:=a.real_axis(lin)) >= lin.shape_len) or (lin.full_shape[ax] == a.amt and Opt(a.op, ax, 0) in actions): continue
    lin2 = lin.copy()
    try:
      lin2.apply_opt(a)
      up, lcl, tc_up = 1, 1, prod(tc.dims)//tc.threads if (tc:=lin2.tensor_core) else 1
      for s,c in zip(lin2.full_shape, lin2.colors()):
        if c in {"magenta", "yellow"}: up *= s
        elif c in {"cyan", "green", "white"}: lcl *= s
      if up//tc_up > max_up or lcl > max_lcl: continue
      acted_lins[i+1] = lin2
    except KernelOptError: pass
  return acted_lins

beam_pool, BEAM_DEBUG = None, getenv("BEAM_DEBUG")
def beam_search(lin:Kernel, rawbufs:list[Buffer], amt:int, allow_test_size=True, disable_cache=getenv("IGNORE_BEAM_CACHE")) -> Kernel:
  global beam_pool
  key = {"ast": lin.ast.key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix}
  if not disable_cache and CACHELEVEL >= 1 and (val:=diskcache_get("beam_search", key)) is not None:
    ret = lin.copy()
    for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
    return ret

  beam: list[tuple[Kernel, float]] = [(lin, float("inf"))]
  seen_libs = set()

  default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV", "METAL"} else 0
  if beam_pool is None and (workers := getenv("PARALLEL", default_parallel)):
    beam_pool = multiprocessing.get_context("spawn").Pool(workers, _init_worker, (), getenv("BEAM_MAX_TASKS_PER_CHILD", 16))

  min_progress = getenv("BEAM_MIN_PROGRESS", 0.01)/1e6
  if BEAM_DEBUG: print(f"BEAM_SEARCH:\n{lin.ast}")
  if DEBUG >= 2: print(f"   0.00s:                 from   1 ->   1 actions {lin.colored_shape()}")

  try:
    rawbufs = _ensure_buffer_alloc(rawbufs)
    var_vals: dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
    exiting, st = False, time.perf_counter()
    dev = Device[lin.opts.device]
    while not exiting:
      # expand the beam: compile and time every legal action applied to each kernel currently in the beam
      acted_lins: list[Kernel] = flatten([get_kernel_actions(lin, include_0=False).values() for lin,_ in beam])
      timed_lins: list[tuple[Kernel, float]] = []
      _compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=dev.compiler)
      least_compute_ops = math.inf
      for i,proc in (map(_compile_fn, enumerate(acted_lins)) if beam_pool is None else beam_pool.imap_unordered(_compile_fn, enumerate(acted_lins))):
        if proc is None: continue
        p, lib, compile_et = proc
        if lib in seen_libs: continue
        # filter out kernels that use 1000x more compute than the smallest
        least_compute_ops = min(this_compute_ops:=sym_infer(p.estimates.ops, var_vals), least_compute_ops)
        if least_compute_ops*1000 < this_compute_ops: continue
        seen_libs.add(lib)
        try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0, clear_l2=hasattr(dev, 'invalidate_caches'))
        except RuntimeError: continue # for runtime issues
        timed_lins.append((acted_lins[i], min(tms)))
        if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(list, p.uops)):5d} uops {compile_et*1e6:12.2f} us compile/{timed_lins[-1][1]*1e6:12.2f} us run  {len(timed_lins):4d}/{len(acted_lins):4d}  {timed_lins[-1][0].colored_shape()}")  # noqa: E501
        elif DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us  {len(timed_lins):4d}/{len(acted_lins):4d}  {timed_lins[-1][0].colored_shape()}\033[K", end="")  # noqa: E501

      # done timing this round: exit if nothing timed, or if the best candidate no longer beats the incumbent by at least min_progress
      opts = sorted(timed_lins, key=lambda x: x[1])
      exiting = len(opts) == 0 or (opts[0][1] < min_progress) or (len(beam) > 0 and ((beam[0][1]-opts[0][1]) < min_progress))
      if not exiting: beam = opts[:amt]
      elif len(opts) > 0 and opts[0][1] < beam[0][1]: beam = opts[:1]
      if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape())  # noqa: E501
  except KeyboardInterrupt as e:
    if beam_pool is not None: beam_pool.terminate()
    raise e

  if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
  if BEAM_DEBUG: print(f"BEAM_SEARCH: final tm={beam[0][1]*1e6:0.2f} us, applied_opts={beam[0][0].applied_opts}")
  return beam[0][0]

# brute-force the fastest local work size for a raw program by timing valid candidates in random order
def optimize_local_size(_prg:Callable, global_size:list[int], rawbufs:list[Buffer]) -> list[int]:
  test_rawbuffers = [Buffer(rawbufs[0].device, rawbufs[0].size, rawbufs[0].dtype).allocate(), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs
  MAX_WORKGROUP = 1024
  local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size]
  local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2  # try each valid size twice
  def try_exec(local_size):
    try: return _prg(*[x._buf for x in test_rawbuffers], global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)], local_size=local_size, wait=True)  # noqa: E501
    except Exception: return float('inf')
  ret = min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])
  assert not math.isinf(ret[0]), "all optimize_local_size exec failed"
  return ret[1]

def time_linearizer(lin:Kernel, rawbufs:list[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float:  # noqa: E501
  key = {"ast": lin.ast.key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size,
         "max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device, "suffix": lin.opts.suffix}
  if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)

  dev = Device[lin.opts.device]
  assert dev.compiler is not None

  rawbufs = _ensure_buffer_alloc(rawbufs)
  var_vals: dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
  p = lin.to_program()
  tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs,
                      max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))

  if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms)
  return min(tms)
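
# Usage sketch (illustrative only, kept as comments so nothing runs on import). It assumes you
# already have a kernel AST `ast` for the default device; how you obtain one, and the exact Kernel
# constructor signature, may differ between tinygrad versions.
#
#   from tinygrad.device import Device
#   from tinygrad.codegen.kernel import Kernel
#
#   lin = Kernel(ast, opts=Device[Device.DEFAULT].renderer)  # wrap the AST for the default backend
#   rawbufs = bufs_from_lin(lin)                             # scrap buffers sized for the kernel
#   best = beam_search(lin, rawbufs, amt=4)                  # beam width of 4 over `actions`
#   print(time_linearizer(best, rawbufs, cnt=10))            # re-time the winning kernel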