carrot/tinygrad_repo/test/test_masked_st.py
Vehicle Researcher 8eb8330d95 openpilot v0.9.9 release
date: 2025-03-08T09:09:29
master commit: ce355250be726f9bc8f0ac165a6cde41586a983d
2025-03-08 09:09:31 +00:00

33 lines
912 B
Python

import unittest
from tinygrad.tensor import Tensor
class TestMaskedShapeTracker(unittest.TestCase):
  """Elementwise ops on zero-padded tensors must keep the right shape and values.

  Each test builds a length-2 tensor of ones and extends it with trailing zeros
  via ``Tensor.pad(((0, k),))``. It combines that tensor elementwise with
  another tensor, then realizes the result with ``.data()`` and compares it
  against the expected values.

  NOTE(review): these tests do not assert that the result's first view carries
  a mask (``c.lazydata.st.views[0].mask is not None``). That check reaches into
  internal lazy-buffer state and was disabled. TODO: reinstate it if that
  internal API is meant to be covered here.
  """

  def test_mul_masked(self):
    # Only b is padded: [1, 1] -> [1, 1, 0, 0, 0], so it lines up with a.
    a = Tensor([1, 1, 1, 1, 1])
    b = Tensor([1, 1]).pad(((0, 3),))
    c = a * b
    self.assertEqual(c.shape, a.shape)
    # The padded positions of b are zero, so the product is zero there.
    self.assertEqual(c.data().tolist(), [1.0, 1.0, 0.0, 0.0, 0.0])

  def test_mul_both_masked(self):
    # Both operands are padded identically to length 5.
    a = Tensor([1, 1]).pad(((0, 3),))
    b = Tensor([1, 1]).pad(((0, 3),))
    c = a * b
    self.assertEqual(c.shape, a.shape)
    self.assertEqual(c.data().tolist(), [1.0, 1.0, 0.0, 0.0, 0.0])

  def test_add_masked(self):
    # Both operands are padded to length 4; padded slots sum to 0 + 0.
    a = Tensor([1, 1]).pad(((0, 2),))
    b = Tensor([1, 1]).pad(((0, 2),))
    c = a + b
    # Shape check added for consistency with the multiplication tests.
    self.assertEqual(c.shape, a.shape)
    self.assertEqual(c.data().tolist(), [2.0, 2.0, 0.0, 0.0])
if __name__ == '__main__':
  # Executed directly (not imported): hand control to unittest's CLI runner,
  # which discovers and runs the TestCase classes defined in this module.
  unittest.main()