carrot/tinygrad_repo/test/unit/test_tensor_uop_representation.py
carrot 9c7833faf9
KerryGold Model, AGNOS12.4, AdjustLaneChange, EnglishSound (#182)
* Vegetarian Filet o Fish model

* fix.. atc..

* test cluster_speed_limit

* fix.. cluster_speed_limit.. 2

* fix.. clusterspeedlimit3

* cruise speed to roadlimit speed

* fix..

* fix.. eng

* deltaUp/Down for lanechange

* fix.. atc desire...

* fix..

* ff

* ff

* fix..

* fix.. eng

* fix engsound

* Update desire_helper.py

* fix.. connect...

* fix curve_min speed

* Revert "fix curve_min speed"

This reverts commit fcc9c2eb14eb3504abef3e420db93e8882e56f37.

* Reapply "fix curve_min speed"

This reverts commit 2d2bba476c58a7b4e13bac3c3ad0e4694c95515d.

* fix.. auto speed up.. roadlimit

* fix.. atc auto lanechange...

* Update desire_helper.py

* Update cruise.py

* debug atc...

* fix.. waze alert offset..

* fix..

* test atc..

* fix..

* fix.. atc

* atc test..

* fix.. atc

* fix.. atc2

* fix.. atc3

* KerryGold Model.  latsmooth_sec = 0.0

* lat smooth seconds 0.13

* fix comment

* fix.. auto cruise, and speed unit

* change lanemode switching.

* erase mazda lkas button.
2025-06-22 10:51:42 +09:00

128 lines
4.9 KiB
Python

import unittest
from tinygrad import Tensor
from tinygrad.uop.ops import UPat, Ops, UOp
# NOTE: unlike before, the base of a realized tensor is always a plain BUFFER
realized_pattern = UPat(Ops.BUFFER)

# after realization, base tensor uops take the form RESHAPE(BUFFER)
buffer_view_pattern = UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),))

# CONST over VIEW(DEVICE).
# NOTE(review): the outer `src` is a bare UPat, not a 1-tuple -- presumably UPat applies a bare
# pattern to every source of the CONST; confirm against UPat before changing it to a tuple.
_device_view_pattern = UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),))
const_pattern = UPat(Ops.CONST, src=_device_view_pattern)
def is_pattern_uop(u:UOp, pat:UPat):
  """Assert that `pat` matches the UOp `u`.

  Raises AssertionError when `pat.match(u, {})` is falsy; the message shows both `u` and `pat`.
  """
  matched = pat.match(u, {})
  assert matched, f"{u}\nis not\n{pat}"
def is_pattern(ten:Tensor, pat:UPat):
  """Assert that `pat` matches the tensor's current `uop` (delegates to `is_pattern_uop`)."""
  tensor_uop = ten.uop
  is_pattern_uop(tensor_uop, pat)
class TestTensorMutates(unittest.TestCase):
  """Checks how a Tensor's `uop` changes once the tensor is scheduled or realized."""

  def test_mutate_add(self):
    """After scheduling the sum, both operands and the result hold a different uop object."""
    lhs, rhs = Tensor([1,2,3]), Tensor([4,5,6])
    total = lhs+rhs
    tensors = (lhs, rhs, total)
    # remember the uop objects from before scheduling so identity can be compared afterwards
    uops_before = [t.uop for t in tensors]
    total.schedule()
    for old_uop, t in zip(uops_before, tensors):
      self.assertIsNot(old_uop, t.uop)
    for t in tensors:
      is_pattern_uop(t.uop.base, realized_pattern)

  def test_reshape_is_same_parent(self):
    """Realizing the reshaped sum leaves the separately built plain sum with a BUFFER base too."""
    lhs, rhs = Tensor([1,2,3]), Tensor([4,5,6])
    plain = lhs+rhs
    # built from a second `lhs+rhs` expression on purpose, not from `plain`
    reshaped = (lhs+rhs).reshape(3,1)
    reshaped.realize()
    for t in (reshaped, plain):
      is_pattern_uop(t.uop.base, realized_pattern)
    # NOTE: we keep movement ops on top of the buffer view
    is_pattern_uop(plain.uop, UPat(Ops.BUFFER))
    is_pattern_uop(reshaped.uop, UPat(Ops.VIEW, src=(realized_pattern,)))

  def test_reshape_is_same_child(self):
    """Realizing the plain sum leaves the separately built reshaped sum with a BUFFER base too."""
    lhs, rhs = Tensor([1,2,3]), Tensor([4,5,6])
    plain = lhs+rhs
    reshaped = (lhs+rhs).reshape(3,1)
    plain.realize()
    for t in (plain, reshaped):
      is_pattern_uop(t.uop.base, realized_pattern)
class TestTensorUopRepresentation(unittest.TestCase):
  """Pins the UOp graph shape behind realized tensors, constants, copies and empty buffers."""

  def test_realized(self):
    """The base uop of a realized tensor matches a BUFFER."""
    a = Tensor([1.,2,3]).realize()
    print(a.uop)
    is_pattern_uop(a.uop.base, realized_pattern)

  def test_add_realized(self):
    """The sum of two realized tensors is an ADD whose two sources are BUFFERs."""
    a = Tensor([1.,2,3]).realize()
    b = Tensor([4.,5,6]).realize()
    c = a+b
    print(c.uop)
    is_pattern(c, UPat(Ops.ADD, src=(realized_pattern, realized_pattern)))

  def test_const_pattern(self):
    """A scalar Tensor matches both `const_pattern` and `UPat.cvar`."""
    a = Tensor(1)
    print(a.uop)
    is_pattern(a, const_pattern) # const in tensor has a DEVICE and VIEW src
    is_pattern(a, UPat.cvar("x")) # even cvar works!

  def test_consts_do_not_realize(self):
    """Calling realize() on a const tensor leaves its uop as the very same object."""
    a = Tensor(1)
    print(a.uop)
    pre_realize = a.uop
    a.realize()
    # assertIs (rather than a bare assert) survives `python -O` and reports both operands on failure
    self.assertIs(a.uop, pre_realize)

  def test_viewed_consts_do_not_realize(self):
    """Tensor.ones still matches `const_pattern` after realize() and keeps its (10, 10) shape."""
    a = Tensor.ones(10, 10)
    print(a.uop)
    a.realize()
    is_pattern(a, const_pattern)
    self.assertEqual(a.uop.shape, (10, 10))

  # currently, CONSTs have a "fake" BUFFER. this should be fixed
  # current:
  # UOp(Ops.EXPAND, dtypes.float, arg=(10, 10), src=(
  #   UOp(Ops.RESHAPE, dtypes.float, arg=(1, 1), src=(
  #     UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(), strides=(), offset=0, mask=None, contiguous=True),)), src=(
  #       UOp(Ops.BUFFER, dtypes.float, arg=(-1, 'METAL', 1), src=()),
  #       UOp(Ops.CONST, dtypes.float, arg=1.0, src=()),)),)),))
  # expected:
  # UOp(Ops.EXPAND, dtypes.float, arg=(10, 10), src=(
  #   UOp(Ops.RESHAPE, dtypes.float, arg=(1, 1), src=(
  #     UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(), strides=(), offset=0, mask=None, contiguous=True),)), src=(
  #       UOp(Ops.CONST, dtypes.float, arg=1.0, src=(
  #         UOp(Ops.DEVICE, dtypes.void, arg="METAL", src=()),)),)),))
  def test_consts_dont_have_buffers(self):
    """No BUFFER op appears anywhere in the toposort of a Tensor.ones uop."""
    a = Tensor.ones(10, 10)
    print(a.uop)
    buffers_in_parents = [x.op for x in a.uop.toposort() if x.op is Ops.BUFFER]
    self.assertEqual(len(buffers_in_parents), 0)

  # currently, COPY has an extra BUFFER on the output
  # current:
  # UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=(
  #   UOp(Ops.BUFFER, dtypes.float, arg=(2, 'TEST', 3), src=()),
  #   UOp(Ops.COPY, dtypes.float, arg=('TEST', False), src=(
  #     UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=(
  #       UOp(Ops.BUFFER, dtypes.float, arg=(1, 'METAL', 3), src=()),)),)),))
  # expected:
  # UOp(Ops.COPY, dtypes.float, arg=('TEST', False), src=(
  #   UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=(
  #     UOp(Ops.BUFFER, dtypes.float, arg=(1, 'METAL', 3), src=()),))
  # update: now the arg is just a single bool, and the device is a source of the COPY
  # (the pattern asserted below has it as the second source, after the BUFFER).
  def test_copyin(self):
    """Moving a realized tensor to another device gives COPY(BUFFER, DEVICE)."""
    a = Tensor([1.,2,3]).realize()
    c = a.to("TEST") # NOTE: this isn't checked (presumably the device name -- confirm)
    print(c.uop)
    is_pattern(c, UPat(Ops.COPY, src=(realized_pattern, UPat(Ops.DEVICE))))

  def test_empty_buf(self):
    """Tensor.empty is RESHAPE(BUFFER), for both a concrete and a bound-variable dimension."""
    a = Tensor.empty(3, 3)
    is_pattern(a, buffer_view_pattern)
    # presumably variable(name, min, max); the variable is bound to 1 here
    vi = UOp.variable("i", 1, 3).bind(1)
    a = Tensor.empty(3, vi)
    is_pattern(a, buffer_view_pattern)
    # the asserted size is 9 == 3*3, i.e. it does not follow the bound value 1
    self.assertEqual(a.uop.base.buffer.size, 9)
# run the test cases in this module when the file is executed directly
if __name__ == "__main__":
  unittest.main()