carrot/tinygrad_repo/test/unit/test_upat_compile.py

54 lines
1.4 KiB
Python
Raw Permalink Normal View History

import unittest
from tinygrad.helpers import DEBUG
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UPat, track_rewrites, GroupOp, Ops
from tinygrad.uop.upat import _get_code, upat_compile
import dis
@track_rewrites()
def do_compile(up):
  """Compile the pattern `up`, print what was generated for it, and return that.

  Calls `_get_code(up, False)` and `upat_compile(up, <no-op callback>)`.
  The first element of the `_get_code` result (presumably the generated
  matcher source -- confirm against `tinygrad.uop.upat`) is printed and
  returned. When `DEBUG >= 2` the compiled matcher is also disassembled
  to stdout with `dis.dis`.
  """
  print("\n***** COMPILE", up)
  generated = _get_code(up, False)
  # the callback accepts any keyword arguments and always returns None
  matcher = upat_compile(up, lambda **kwargs: None)
  first = generated[0]
  print(first)
  if DEBUG >= 2:
    dis.dis(matcher)
  return first
class TestUPatCompile(unittest.TestCase):
  """Smoke tests: each case builds a UPat and runs it through `do_compile`.

  No test asserts on the result; a case passes as long as `do_compile`
  does not raise.
  """

  # x*c0 + x*c1, with the same variable name "x" on both sides of the add
  def test_double(self):
    lhs = UPat.var("x") * UPat.cvar("c0")
    rhs = UPat.var("x") * UPat.cvar("c1")
    do_compile(lhs + rhs)

  # x + y, two differently named variables
  def test_single(self):
    pat = UPat.var("x") + UPat.var("y")
    do_compile(pat)

  # x + x, the same variable name used twice
  def test_xpx(self):
    pat = UPat.var("x") + UPat.var("x")
    do_compile(pat)

  # x + 0, a variable plus a Python int literal
  def test_xp0(self):
    pat = UPat.var("x") + 0
    do_compile(pat)

  # x * y where both variables are constrained to dtypes.bool
  def test_bool(self):
    bool_x = UPat.var("x", dtype=dtypes.bool)
    bool_y = UPat.var("y", dtype=dtypes.bool)
    do_compile(bool_x * bool_y)

  # (x + y) * c
  def test_single_c(self):
    total = UPat.var("x") + UPat.var("y")
    do_compile(total * UPat.var("c"))

  # any op in GroupOp.ALU except THREEFRY, named "a", with src matching VCONST or CONST
  def test_const_folding(self):
    alu_ops = GroupOp.ALU - {Ops.THREEFRY}
    const_src = UPat((Ops.VCONST, Ops.CONST))
    do_compile(UPat(alu_ops, name="a", src=const_src))

  @unittest.skip("fix this")
  def test_range_named(self):
    # this should be one src, but this form should also still work
    bf16_x = UPat.var("x", dtypes.bfloat16)
    do_compile(UPat(Ops.CAST, dtypes.float, bf16_x))
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
  unittest.main()