import sys, atexit, pickle
from collections import defaultdict, deque
from dataclasses import dataclass
from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites, buffers
from tinygrad.ops import can_pad, identity_element, resolve, view_left, merge_views
from tinygrad.codegen.symbolic import symbolic_simple
from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, unwrap, flatten, getenv, pluralize
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY, DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES, SPLIT_REDUCEOP
from tinygrad.dtype import ImageDType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View, strides_for_shape
from tinygrad.device import Buffer
from tinygrad.spec import type_verify, kernel_spec
# creation can recurse a lot
sys.setrecursionlimit(10000)
# **** schedule simplifier
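# a reduce over axes that are all stride 0 (broadcasted) reads the same element repeatedly, so it collapses to an
# elementwise op on the shrunk input: sum becomes a multiply by the axis size, prod becomes a pow, max is unchanged
# (at the tensor level, e.g. x.reshape(4,1).expand(4,3).sum(axis=1) is the same as x*3)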
def simplify_stride0_reduce(reduce:UOp, x:UOp):
# must be unmasked (NOTE: can be relaxed if not masked on stride 0 axis)
if any(v.mask is not None for v in unwrap(x.st).views): return None
# must have all stride 0 in the relevant axis (NOTE: can do partial)
if not all(unwrap(x.st).views[-1].strides[axis] == 0 for axis in reduce.arg[1]) or not all_int(x.shape): return None
prshape = prod(x.shape[i] for i in reduce.arg[1])
ret = x.shrink(tuple((0,s) if i not in reduce.arg[1] else (0,1) for i,s in enumerate(x.shape)))
match reduce.arg[0]:
case Ops.ADD: return ret*prshape
case Ops.MUL: return ret.pow(prshape)
case Ops.MAX: return ret # NOTE: Ops.MAX is passthrough
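# split one big reduce into two: reshape a reduce axis into (divisor, rest), move the divisor axis to the end,
# do the bulk of the reduction in the first kernel, then reduce the trailing divisor axis in a second small kernel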
def split_reduceop(reduce:UOp, x:UOp):
if not SPLIT_REDUCEOP or not all_int(x.shape) or (prod(x.shape)//prod(reduce.shape))<getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return None
# if there are few globals, make some reduces into globals by splitting into two kernels
# cap output buffer to 2**22: heuristic number of global outputs to achieve max occupancy with enough locals+upcasts for gemm
# ~2**10 should be enough if GROUP is used
# the split is capped at 256 so the second-phase reduce stays negligible for low prod(reduce.shape); the minimum split is 8.
# split is moved to the end to provide maximum locality for the second phase reduce.
real_strides = unwrap(x.st).real_strides(ignore_valid=True)
if not (split_candidates:=[(i,d) for i in reduce.arg[1] for d in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(reduce.shape)),8-1,-1)
if x.shape[i]%d==0 and real_strides[i]!=0]): return None
dim_to_split, divisor = split_candidates[0]
splitted_shape = x.shape[:dim_to_split]+(divisor,)+(x.shape[dim_to_split]//divisor,)+x.shape[dim_to_split+1:]
splitted = x.reshape(splitted_shape).permute(tuple([d for d in range(len(splitted_shape)) if d!=dim_to_split]+[dim_to_split]))
if DEBUG >= 3: print(f"split {divisor}: {x.shape} -> {splitted.shape} -> {reduce.shape}")
# reduce original axes, then split
return splitted.r(*reduce.arg).r(reduce.arg[0], (len(reduce.shape),)).reshape(reduce.shape)
sym = symbolic_simple+PatternMatcher([
# UOp with size 0 is zero
(UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \
and not (root.base.op is Ops.CONST and root.base.arg == 0) else None),
# DETACH and CONTIGUOUS_BACKWARD are NOOPs here
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
# reduce of size 0 is the identity element
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
# reduce on stride 0 is collapsed
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), simplify_stride0_reduce),
# split_reduceop
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), split_reduceop),
# COPY(CONST) creates a new CONST on the destination device
(UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.arg)),
# no COPY to same device, except clone (arg is True)
(UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"),
lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None),
# remove cast to image when it's already a contiguous image
(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"))),)),
lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
# make things that can't be images not images
(UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW, Ops.CONST, Ops.DEVICE}, name="u"), lambda u: u.replace(dtype=dt.base) if isinstance(dt:=u.dtype,ImageDType)
and (prod(u.shape) != prod(dt.shape) or not any(u.shape[x]%4 == 0 for x in u.st.unit_stride_axes())) else None),
# remove contiguous if we can just view the buffer
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
# contiguous/buffer/copy is already contiguous
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER, Ops.COPY)),)), lambda root: root.src[0]),
# substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
(UPat((Ops.BITCAST, Ops.CONTIGUOUS), src=(UPat.var("x"),), name="t"),
lambda x,t: UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (t.size, x.st.views[0].offset)).reshape(t.shape) if x.device.startswith("DISK") else None),
# remove CONST/BIND/VIEW from SINK
(UPat(Ops.SINK, name="x"), lambda x: x.replace(src=new_src)
if (new_src:=tuple(dedup(s.base for s in x.src if s.op not in {Ops.CONST,Ops.BIND}))) != x.src else None),
])
# support for using a contiguous permuted view instead of the parent view if one exists
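# e.g. if the graph contains both x and x.permute(p).contiguous(), ALU users of x are rewritten to read the
# contiguous result through the inverse permute instead of the non-contiguous parent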
def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti)
replace_contiguous = PatternMatcher([
(UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, name="src"),), name="contig"), found_contiguous),
(UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
])
# reorder view
reorder_view = PatternMatcher([
# put CAST to smaller dtype before EXPAND
(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm"),)), lambda cast,vm: vm.base.cast(cast.dtype).view(vm.st)
if (not getenv("CAST_AFTER_EXPAND") or vm.base.op is not Ops.BUFFER) and cast.dtype.itemsize <= vm.dtype.itemsize
and resolve(prod(vm.shape) > vm.st.real_size()) else None),
# store a shrink before COPY, otherwise view after the COPY
(UPat(Ops.COPY, src=(UPat(), UPat(Ops.VIEW, name="v")), name="copy"), lambda copy,v: v.contiguous().copy_to_device(copy.device) \
if prod(v.shape) < prod(v.base.shape) else v.base.copy_to_device(copy.device, clone=copy.arg).view(v.st)),
# put UnaryOps before EXPANDs
(UPat(GroupOp.Unary, src=UPat(Ops.VIEW, src=(UPat.var("inp"),), name="v"), name="alu"),
lambda inp,v,alu: inp.alu(alu.op).view(v.st) if resolve(prod(alu.shape) > v.st.real_size()) else None),
# put CAST after expanding BUFFER
(UPat(Ops.VIEW, src=(UPat(Ops.CAST, src=(UPat.var("x"),)),), name="v"), lambda x,v: x.view(x.st+v.st).cast(v.dtype) if getenv("CAST_AFTER_EXPAND")
and x.base.op is Ops.BUFFER and resolve(prod(v.shape) > prod(x.shape)) else None),
])
# **** UOp realization
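# the grouper decides which simplified tensor uops get realized, i.e. written out to their own Buffer;
# anything not realized is later inlined into the kernel of its consumer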
DONT_PUSH_VIEWS = {Ops.BUFFER, Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.ASSIGN, Ops.SINK, Ops.CONTIGUOUS, Ops.COPY}
@dataclass(frozen=True)
class GrouperContext:
assigns: dict[UOp, UOp] # maps realized buffers to assigns
realizes: dict[UOp, None] # all the simplified tensor uops we realize
children: defaultdict[UOp, dict[UOp, None]] # children graph of tensor uops
def realize(ctx:GrouperContext, tr:UOp) -> None: ctx.realizes[tr] = None
def realize_before_view(ctx:GrouperContext, view:UOp, tr:UOp) -> None:
st = unwrap(view.st)
# always realize unsafe pad ops before masked view
if any(v.mask is not None for v in st.views) and not can_pad(tr, ctx.realizes, cache=dict()): return realize(ctx, tr)
# fold simple pads
if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(tr.shape) and resolve(prod(tr.shape) >= prod([y-x for x,y in m])): return
# realize before expand
if resolve(prod(tr.shape) < prod(st.shape)) and not DONT_REALIZE_EXPAND: return realize(ctx, tr)
do_realize = PatternMatcher([
# always realize SINK parents
(UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.realizes.update((x, None) for x in s.src if x.op not in DONT_PUSH_VIEWS)),
# always realize ASSIGN/CONTIGUOUS/GroupOp.Meta
(UPat({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}, name="tr"), realize),
# realize before expand or unsafe pad ops
(UPat(Ops.VIEW, name="view", src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="tr"),)), realize_before_view),
# realize before COPY
(UPat(Ops.COPY, src=(UPat(), UPat(GroupOp.All-DONT_PUSH_VIEWS, name="tr"))), realize),
])
def append_uop(ctx:GrouperContext, u:UOp) -> None:
if u.op is Ops.ASSIGN: ctx.assigns[u.buf_uop] = u
for s in u.src: ctx.children[s.base][u] = None
create_ctx = PatternMatcher([(UPat(GroupOp.All-{Ops.SINK, Ops.VIEW}, name="u"), append_uop)])
def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, dict[UOp, None]], realizes:dict[UOp, None],
reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None:
"""recursively search the uop for groupable children, realize the UOp if a child can't group"""
if (tr, st) in cache: return
cache.setdefault((tr, st))
rsize = unwrap(r.st).size
if tr in realizes and tr is not r:
# can only fuse contiguous
# max one reduceop per kernel
if not st.contiguous or st.size != rsize or tr in reduce_for_op: group.setdefault(r)
return group.setdefault(tr)
for tr_next in children[tr]:
# max one reduceop per kernel
if tr_next.op is Ops.REDUCE_AXIS: return group.setdefault(r)
# can only fuse contiguous
if len(st_childs:=dedup(unwrap(x.st) for x in tr_next.src if x.base == tr)) > 1: return group.setdefault(r)
recursive_group(tr_next, st+st_childs[0], r, children, realizes, reduce_for_op, group, cache)
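# walk the graph, pair every REDUCE_AXIS with the elementwise group it can fuse into, and return the set of uops
# that must be realized (each entry becomes one kernel output)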
def group_realizes(sink:UOp) -> dict[UOp, None]:
# start by adding uops that always realize
sink = graph_rewrite(sink, do_realize+create_ctx, ctx:=GrouperContext({}, {}, defaultdict(dict)))
if DONT_GROUP_REDUCES: return ctx.realizes
# find all reduces, and pair them to an elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
reduce_for_op: dict[UOp, UOp] = {}
double_reduces: list[UOp] = []
for r in sink.toposort:
if r.op is not Ops.REDUCE_AXIS: continue
if FUSE_CONV_BW and r.src[0].base.op is Ops.REDUCE_AXIS and r.src[0] is not r.src[0].base: double_reduces.append(r)
if r in ctx.realizes: continue
group: dict[UOp, None] = {}
recursive_group(r, unwrap(r.st), r, ctx.children, ctx.realizes, reduce_for_op, group, cache={})
# max one reduceop per kernel
can_chase = all(tr not in reduce_for_op for tr in group)
# TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
forced_realize = r in group
# can only have one output
if not forced_realize and len(group) > 1: forced_realize = True
# can only fuse assign if no other assign_target is used in the kernel
if not forced_realize and any(x.op is Ops.ASSIGN for x in group):
parents = deque((r, *group))
while parents and not forced_realize:
p = parents.pop().base
if (assign:=ctx.assigns.get(p)) is not None and assign not in group: forced_realize, can_chase = True, False
if p in ctx.realizes: continue
parents.extend(p.src)
if forced_realize or not group:
tr = r
if can_chase:
# can chase this down to contiguous children
st = unwrap(tr.st)
while len(ctx.children[tr]) == 1:
tr_next = next(iter(ctx.children[tr]))
st_childs = dedup(unwrap(s.st) for s in tr_next.src if s.base is tr)
if len(st_childs) > 1: break
if st.size != st_childs[0].size: break
st = st + st_childs[0]
if not st.contiguous or tr_next.op is Ops.REDUCE_AXIS: break
tr = tr_next
# don't cast to higher size before store (tr cannot be realized if forced_realize)
if tr.op is Ops.CAST and tr.dtype.itemsize > tr.src[0].dtype.itemsize:
tr = tr.src[0].base
group = {tr: None}
ctx.realizes[tr] = None
reduce_for_op.update((tr, r) for tr in group)
if FUSE_ARANGE and r.arg[0] is Ops.ADD and r.src[0].base.op is Ops.CONST:
# maybe fuse arange with its children
if len(flatten(ctx.children[tr] for tr in group)) != 0:
for tr in group: del ctx.realizes[tr]
# fuse double reduces with no other child
for reduceop in double_reduces:
top_reduce = reduceop.src[0].base
if len(ctx.children[top_reduce]) == 1: del ctx.realizes[top_reduce]
return ctx.realizes
# **** create kernels
@dataclass(frozen=True)
class Kernel:
ast: UOp
metadata: tuple[Metadata, ...] = ()
def __repr__(self):
return f"<Kernel {len(list(self.ast.toposort))} {[s.op for s in self.ast.src] if self.ast.op is Ops.SINK else self.ast.op} {self.metadata}>"
@dataclass(frozen=True)
class KernelContext:
realizes: dict[UOp, None]
ops_metadata: dict[UOp, Metadata]
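# wrap x in a KERNEL uop that writes to buffer b, and return an ASSIGN of that buffer to the kernel,
# reshaped back to x's original shape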
def create_kernel(ctx:KernelContext, x:UOp, b:UOp):
kernel = UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x.sink(), (m,) if (m:=ctx.ops_metadata.get(x)) else ()))
buffer = b.base if b.size == b.base.size else UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset))
return UOp(Ops.ASSIGN, x.dtype, (buffer, kernel)).reshape(x.shape)
DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER}
def append_to_kernel(ctx:KernelContext, x:UOp):
new_srcs: list[UOp] = []
metadata = dict.fromkeys(x.arg.metadata)
for s in x.src:
if s.op in DONT_PLACE_IN_KERNEL or s in ctx.realizes: new_srcs.append(s)
else:
new_srcs.extend(s.src)
if (m:=ctx.ops_metadata.get(s)) is not None: metadata[m] = None
if (new_src:=tuple(dedup(new_srcs))) != x.src: return x.replace(src=new_src, arg=Kernel(x.arg.ast, tuple(metadata)))
create_kernels = PatternMatcher([
# always give assign/contiguous a kernel
(UPat.assign(UPat.var("b"), UPat(GroupOp.All-{Ops.KERNEL}), name="x"), create_kernel),
(UPat(Ops.CONTIGUOUS, name="x"), lambda ctx,x: create_kernel(ctx, x, UOp.new_buffer(x.device, x.size, x.dtype))),
# create a buffer for COPY on the new device
(UPat(Ops.COPY, src=(UPat(Ops.DEVICE, name="d"), UPat()), name="x"), lambda ctx,d,x: create_kernel(ctx, x, UOp.new_buffer(d.arg, x.size, x.dtype))),
# otherwise check the context if we're realizing this UOp
(UPat(GroupOp.All-DONT_PLACE_IN_KERNEL, name="x"),
lambda ctx,x: create_kernel(ctx, x, UOp.new_buffer(x.device, x.size, x.dtype)) if x in ctx.realizes else None),
# walk back the local graph until we reach a buffer/assign parent
(UPat(Ops.KERNEL, name="x"), append_to_kernel),
# remove downstream reshapes from SINK
(UPat(Ops.SINK, name="x"), lambda x:x.replace(src=tuple(s.base for s in x.src)) if any(s.op is Ops.VIEW for s in x.src) else None),
])
# **** swizzler
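# the swizzler moves VIEWs out of the middle of the kernel AST: view_left pushes them up toward the loads,
# view_right pushes them down toward the store, so views end up at the edges next to the buffer ops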
def apply_swizzle(u:UOp) -> UOp:
with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left)
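# a non-contiguous VIEW on top of a REDUCE_AXIS can't simply be pushed down, so the view is applied to the reduce
# input instead (with the reduce axes appended to every View) and the reduction is redone over the new trailing axes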
def swizzle_reduceop(r:UOp, src:UOp, view:UOp):
if (st:=unwrap(view.st)).contiguous: return None
input_st = ShapeTracker.from_shape(src.shape)
tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg)
prshape = prod(rshape:=tmp.shape[-len(r.axis_arg):])
strides = strides_for_shape(rshape)
nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in st.views]
# create a new reduceop for the swizzled input
new_input_st = tmp + ShapeTracker(tuple(nv))
new_axis = tuple(range(len(st.shape), len(st.shape) + len(r.axis_arg)))
return UOp(Ops.REDUCE_AXIS, r.dtype, (apply_swizzle(src.view(src.arg+new_input_st if src.op is Ops.VIEW else new_input_st)),),
(r.arg[0], new_axis)).view(ShapeTracker.from_shape(st.shape))
def reduceop_view_right(src:UOp, v:UOp, r:UOp):
assert unwrap(v.st).contiguous and v.size == src.size, f"can't compute new axis for {src.shape} -> {r.shape}"
return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, r.shape)) if s != u)).view(ShapeTracker.from_shape(r.shape))
def elementwise_view_right(root:UOp):
if not (swizzles:=[x for x in root.src if x.op is Ops.VIEW and x.base.op not in DONT_PUSH_VIEWS]): return None
assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}"
# place view after applying the elementwise op
new_st = ShapeTracker.from_shape(swizzles[0].base.shape)
new_src = [x.base if x.base.shape==new_st.shape else apply_swizzle(x.view(x.arg+new_st) if x.op is Ops.VIEW else x.view(new_st)) for x in root.src]
# reshape to match downstream shapes
return root.replace(src=tuple(new_src)).reshape(root.shape)
# push VIEW to children
view_right = merge_views+PatternMatcher([
# push a non contiguous ShapeTracker through reduceop
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), swizzle_reduceop),
# apply view after reduceops
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="src"),), name="v"),), name="r"), reduceop_view_right),
# apply view after elementwise ops
(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="root"), elementwise_view_right),
# merge axes for double reduce (inverse of SPLIT_REDUCEOP=1)
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"),
lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] == r2.arg[0] else None),
])
# **** unbind variables
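# symbolic Variables show up bound (BIND) inside ShapeTrackers; unbinding strips the binding and collects the
# concrete values into the var_vals dict that create_schedule_with_vars returns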
def unbind_shapetracker(ctx:tuple[dict[Variable, int], tuple[UOp, ...]], x:UOp):
st = unwrap(x.st).simplify()
if any(x.op is Ops.BIND for x in st.vars()):
st, var_vals = st.unbind()
ctx[0].update(var_vals)
return x.replace(arg=st) if st != x.st else None
def unbind_variable(ctx:tuple[dict[Variable, int], tuple[UOp, ...]], var:UOp, val:UOp):
ctx[0][var.replace(src=())] = val.arg
return var
# **** fix kernel AST
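# make the AST executable: BUFFER sources become DEFINE_GLOBAL+LOAD (indexed by their position in the kernel's
# buffers) and the SINK source gets a STORE to output buffer 0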
add_buffer_ops = PatternMatcher([
# LOAD
(UPat(Ops.BUFFER, name="x"), lambda ctx,x:UOp.load(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx[1].index(x)), x.st.to_uop(), dtype=x.dtype)),
# STORE (except for meta ops)
(UPat(Ops.SINK, src=(UPat(GroupOp.Meta, name="x"),)), lambda x:x),
# partial assign can store to a non-contiguous ShapeTracker
(UPat(Ops.SINK, src=(UPat(Ops.ASSIGN, name="x"),)),
lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.src[0].base.size), (), 0), x.src[0].st.to_uop(), x.src[1]).sink()),
# otherwise the store is contiguous
(UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()),
# VALID
(UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),), name="view"), lambda x,view: x.valid(view.arg)),
# if the last child is a VIEW we merge the ShapeTrackers and store the base
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="x"),)))),
lambda x,b,st: UOp.store(b, (st.arg+x.st).to_uop(), x)),
])
def check_load_st(glbl:UOp, view:UOp):
if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return
# if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: return
# if it has a single masked view, and shrinking it to the mask matches shrinking a contiguous ShapeTracker to the same mask, it's fine
if len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): return
# otherwise, it's not fine
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
fix_kernel_ops = PatternMatcher([
# remove CONTIGUOUS/DEVICE from kernel AST
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
# BIND in shapetracker becomes DEFINE_VAR
(UPat(Ops.VIEW, name="x"), unbind_shapetracker),
(UPat(Ops.BIND, src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),
# no ImageDType after load
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
# if this kernel also assigns to the loaded buffer, ensure we can index it correctly
(UPat(Ops.LOAD, src=(UPat.var("glbl"), UPat.var("view"))), check_load_st),
])
def fix_kernel_ast(k:UOp, var_vals:dict[Variable, int]) -> UOp:
assert k.op is Ops.KERNEL, f"kernel isn't kernel, it's {k}"
# substitute kernel sources for the target buffer + apply reshapes
parents_rep: dict[UOp, UOp] = {}
for s in k.src:
if s.op is Ops.ASSIGN:
for out in s.src[1].arg.ast.src: parents_rep[out] = s.buf_uop.view(unwrap(out.st))
ast = k.arg.ast.substitute(parents_rep)
# push views to edges
ast = graph_rewrite(graph_rewrite(ast, view_left), view_right)
# add buffer ops + fix_kernel_ops
ast = graph_rewrite(ast, merge_views+add_buffer_ops+fix_kernel_ops, ctx=(var_vals, bufs:=tuple(s.buf_uop for s in k.src)), bottom_up=True)
if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}")
# create subbuffer (TODO: this does not belong here)
if ast.op is Ops.BUFFER_VIEW: buffers[bufs[0]] = (base:=bufs[1].buffer).view(ast.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
return k.replace(arg=Kernel(ast, k.arg.metadata))
PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
if CAPTURE_PROCESS_REPLAY:
@atexit.register
def save_process_replay():
for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
# **** schedule creation and toposort
@dataclass(frozen=True)
class ScheduleItem:
ast: UOp
bufs: tuple[Buffer, ...]
metadata: tuple[Metadata, ...]
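# overall flow: simplify the tensor graph (sym), pick realize points (group_realizes), group uops into KERNELs
# (create_kernels), lower each kernel AST (fix_kernel_ast), then toposort the ASSIGNs into ScheduleItems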
@track_rewrites(name_fxn=lambda r: f"Schedule {pluralize('Kernel', len(r[0]))}"+(f" (with_{pluralize('Var', len(r[1]))})" if len(r[1]) != 0 else ""))
def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
# merge_views + sym + reorder_view + replace_contiguous
tensor_map = graph_rewrite_map(big_sink, merge_views+sym+reorder_view+replace_contiguous, ctx={})
# display the cleaned up tensor graph
if getenv("VIZ"): graph_rewrite(tensor_map[big_sink], PatternMatcher([]), name="View Tensor Graph")
# get realizes
sink = tensor_map[big_sink]
realize_map = group_realizes(sink)
# map tensor metadata to simplified ops
ops_metadata = {v:k.metadata for k,v in tensor_map.items() if k.base.op not in {Ops.CONST, Ops.DEVICE} and isinstance(k.metadata, Metadata)}
# merge_views + create_kernels
kernel_map = graph_rewrite_map(sink, merge_views+create_kernels, ctx=KernelContext(realize_map, ops_metadata), bottom_up=True)
sched_sink = kernel_map[sink]
type_verify(list(sched_sink.toposort), kernel_spec)
# map tensors to buffer/const, optionally apply a VIEW on top
becomes_map: dict[UOp, UOp] = {}
for k,v in tensor_map.items():
# ASSIGN always becomes the target buffer
if v.op is Ops.ASSIGN: becomes_map[k] = v.src[0]
# if we created a new buffer for this tensor, map it to the assigned buffer
elif (a:=kernel_map.get(v.base)) is not None and (a:=a.base).op is Ops.ASSIGN:
becomes_map[k] = a.src[0] if a.src[0].st == v.st else a.src[0].view(unwrap(v.st))
# tensors can also simplify to an existing buffer/const
else:
if k is v: continue
if v.base.op is Ops.BUFFER: becomes_map[k] = v
if v.base.op is Ops.CONST and all_int(v.shape): becomes_map[k] = v
# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
kernel_assign: dict[UOp, UOp] = {}
assign_rep: dict[UOp, UOp] = {}
for u in sched_sink.toposort:
if u.op is not Ops.ASSIGN: continue
kernel_assign[u.buf_uop] = u
for s in u.src[1].src:
if s.op is not Ops.BUFFER or s is u.buf_uop or (a:=kernel_assign.get(s)) is None: continue
if any(x.op is Ops.ASSIGN and x.buf_uop is s for x in u.toposort):
raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on ASSIGN or BUFFER")
assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
if assign_rep:
sched_sink = sched_sink.substitute(assign_rep)
type_verify(list(sched_sink.toposort), kernel_spec)
# display the final graph
if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph")
if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Memory Graph")
# final toposort (bfs)
children: dict[UOp, list[UOp]] = {}
in_degree: dict[UOp, int] = {}
for u in sched_sink.toposort:
if u.op is not Ops.ASSIGN: continue
in_degree[u] = 0
for s in u.src[1].src:
if s.op is not Ops.ASSIGN: continue
children.setdefault(s, []).append(u)
in_degree[u] += 1
queue = deque(k for k,v in in_degree.items() if v == 0)
schedule: list[ScheduleItem] = []
var_vals: dict[Variable, int] = {}
while queue:
u = queue.popleft()
# TODO: move this to create_kernels
k = fix_kernel_ast(u.src[1], var_vals)
schedule.append(ScheduleItem(k.arg.ast, tuple(s.buf_uop.buffer for s in k.src), k.arg.metadata))
for x in children.get(u, []):
in_degree[x] -= 1
if in_degree[x] == 0: queue.append(x)
# confirm everything was scheduled correctly
if len(schedule) != (kc:=len(in_degree)): raise RuntimeError(f"cycle detected in graph, created {kc} kernels but only scheduled {len(schedule)}")
if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
# capture process replay
if CAPTURE_PROCESS_REPLAY:
with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(big_sink.key)] = pickle.dumps((big_sink, ContextVar._cache, [x.ast for x in schedule]))
return schedule, var_vals, becomes_map