# (file info from the repository viewer: 91 lines, 3.0 KiB, Python)
# names that must be in module scope so eval() can unpack a serialized kernel AST
|
|
from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer
|
|
from tinygrad.helpers import dtypes
|
|
from tinygrad.shape.shapetracker import ShapeTracker
|
|
from tinygrad.shape.view import View
|
|
from tinygrad.shape.symbolic import Variable
|
|
# float special values that may appear in serialized ASTs; eval() needs them as module-level names
inf = float('inf')
nan = float('nan')
|
|
|
|
# kernel unpacker
|
|
from tinygrad.codegen.linearizer import Linearizer
|
|
def ast_str_to_lin(ast_str):
  """Rebuild a Linearizer from the string form of a kernel AST.

  The string is evaluated with this module's globals, so every name it references
  must be defined (e.g. imported) in this module.
  NOTE(review): eval() runs arbitrary code -- only pass trusted strings such as the bundled dataset.
  """
  ast = eval(ast_str)
  return Linearizer(ast)
|
|
|
|
# load worlds, a dataset of about 12k kernels
|
|
import gzip
|
|
from pathlib import Path
|
|
import random
|
|
from tinygrad.helpers import dedup
|
|
def load_worlds(filter_reduce=True, filter_noimage=True, filter_novariable=True):
  """Load the deduplicated, shuffled list of serialized kernel ASTs from datasets/sops.gz.

  Args:
    filter_reduce: keep only AST strings that contain "ReduceOps".
    filter_noimage: drop AST strings that contain "dtypes.image".
    filter_novariable: drop AST strings that contain "Variable".

  Returns:
    A list of AST strings, one per distinct line of the dataset. The list is shuffled
    after calling random.seed(1337), so the order is reproducible. As a side effect,
    the global `random` module is left reseeded.
  """
  fn = Path(__file__).parent.parent / "datasets/sops.gz"
  # close the gzip handle deterministically instead of leaking it until garbage collection
  with gzip.open(fn) as f:
    ast_strs = dedup(f.read().decode('utf-8').strip().split("\n"))
  if filter_reduce: ast_strs = [x for x in ast_strs if "ReduceOps" in x]
  if filter_noimage: ast_strs = [x for x in ast_strs if "dtypes.image" not in x]
  if filter_novariable: ast_strs = [x for x in ast_strs if "Variable" not in x]
  # seed the global RNG (not a private Random instance) to keep the existing side effect for callers
  random.seed(1337)
  random.shuffle(ast_strs)
  return ast_strs
|
|
|
|
def assert_same_lin(l1, l2):
  """Assert that two Linearizer-like objects have the same colored shape and equal ShapeTrackers.

  Raises AssertionError on any mismatch, including a different number of entries in `sts`.
  """
  assert l1.colored_shape() == l2.colored_shape()
  # zip() stops at the shorter list, so compare lengths first or trailing sts would go unchecked
  assert len(l1.sts) == len(l2.sts), f"sts count differs: {len(l1.sts)} != {len(l2.sts)}"
  assert all(x==y for x,y in zip(l1.sts, l2.sts))
|
|
|
|
# get features
|
|
import math
|
|
from tinygrad.shape.symbolic import Node
|
|
|
|
# Padding bounds used by lin_to_feats: shape features are zero-padded to MAX_DIMS axes,
# ShapeTracker features to MAX_BUFS entries.
MAX_DIMS: int = 16
MAX_BUFS: int = 9
|
|
def lin_to_feats(lin:Linearizer, use_sts=True):
  """Flatten a Linearizer into a fixed-length list of numeric features.

  Layout of the returned list:
    * 2 generic values: lin.upcasted, lin.local_dims
    * 17 values per axis of lin.full_shape, zero-padded to MAX_DIMS axes (see _lin_shape_feats)
    * only if use_sts: 3 + 5*MAX_DIMS values per distinct ShapeTracker summary,
      zero-padded to MAX_BUFS summaries (see _lin_st_feats)

  The shape part is converted to float. The ShapeTracker part is appended without
  float() conversion, so it may contain bools and ints.

  Returns a list of length 1021 when use_sts is True, else 274. These lengths are
  hard-coded and correspond to MAX_DIMS=16, MAX_BUFS=9; changing either constant
  trips the length assertion.

  Raises:
    AssertionError: lin has MAX_DIMS or more axes, there are MAX_BUFS or more distinct
      ShapeTracker summaries, or the final length is unexpected.
    ValueError: lin.colors() reports a color outside the known set (from list.index).
  """
  # NOTE(review): the padding would fit exactly MAX_DIMS axes, but that case is rejected here; kept as-is
  assert lin.shape_len < MAX_DIMS, "too many dims"
  ret = _lin_shape_feats(lin)
  if use_sts:
    ret += _lin_st_feats(lin)
    assert len(ret) == 1021, f"wrong len {len(ret)}"
  else:
    assert len(ret) == 274, f"wrong len {len(ret)}"
  return ret

def _lin_shape_feats(lin:Linearizer) -> list:
  """Generic plus per-axis shape/color features for lin, zero-padded to MAX_DIMS axes, all as floats.

  Starts with lin.upcasted and lin.local_dims, then 17 values per axis:
  [is concrete, log2(size), min(33, size), log2(output size), min(33, output size),
   size % 2 == 0, % 3 == 0, % 4 == 0, % 8 == 0, % 16 == 0, one-hot of the axis color (7 slots)].
  Symbolic (Node) axis sizes get the flag False and zeros in the 9 numeric slots.
  """
  all_colors = ["blue", "cyan", "white", "green", "red", "magenta", "yellow"]
  color_idx = [all_colors.index(x) for x in lin.colors()]
  feats = [lin.upcasted, lin.local_dims]
  for s, out_s, c in zip(lin.full_shape, lin.output_shape, color_idx):
    if isinstance(s, Node):
      feats += [False] + [0]*9
    else:
      feats += [True, math.log2(s), min(33, s), math.log2(out_s), min(33, out_s)]
      feats += [s % k == 0 for k in (2, 3, 4, 8, 16)]
    one_hot = [0]*len(all_colors)
    one_hot[c] = 1
    feats += one_hot
  # zero-fill the slots of axes lin does not have
  feats += [0] * (17*(MAX_DIMS-len(lin.full_shape)))
  return [float(x) for x in feats]

def _lin_st_feats(lin:Linearizer) -> list:
  """ShapeTracker features for lin: 3 + 5*MAX_DIMS values per distinct summary, zero-padded to MAX_BUFS summaries.

  Each entry of lin.sts is summarized as (shape == lin.full_shape, real_strides(),
  any view has a mask, number of views); duplicate summaries are dropped via dedup.
  Per summary: [shape == full_shape, has mask, number of views], then for each stride d:
  [d is None, d == 0, d == 1, min(33, d) (or -1 if None), log2(d) if d >= 1 else -1],
  zero-padded to MAX_DIMS strides. Values are not converted to float.
  """
  my_sts = dedup([(x.shape == lin.full_shape, x.real_strides(), any(v.mask is not None for v in x.views), len(x.views)) for x in lin.sts])
  assert len(my_sts) < MAX_BUFS
  sts_len = 3 + 5*MAX_DIMS
  feats = []
  # first field: presumably distinguishes full-shape (reduce input) sts from the reduced output -- verify
  for full_shape_match, strides, has_mask, n_views in my_sts:
    feats += [full_shape_match, has_mask, n_views]
    for d in strides:
      feats += [d is None, d == 0, d == 1,
                min(33, d) if d is not None else -1,
                math.log2(d) if d is not None and d >= 1 else -1]
    feats += [0] * (5*(MAX_DIMS - len(strides)))
  feats += [0] * (sts_len*(MAX_BUFS - len(my_sts)))
  return feats