carrot/tinygrad_repo/test/test_sample.py
Vehicle Researcher 4fca6dec8e openpilot v0.9.8 release
date: 2025-01-29T09:09:56
master commit: 227bb68e1891619b360b89809e6822d50d34228f
2025-01-29 09:09:58 +00:00

19 lines
616 B
Python

import unittest
import numpy as np
from tinygrad import Tensor, Variable
class TestSample(unittest.TestCase):
    """Regression test: gather random rows through shrinks with bound Variables."""

    def test_sample(self):
        """Select random rows with symbolic shrinks and compare to numpy indexing.

        Each sampled row index is bound to its own ``Variable`` and used as
        the shrink bounds on the first axis; the single-row slices are
        concatenated and checked against ``numpy`` fancy indexing of the
        same data.
        """
        table = Tensor.rand(10000, 50).realize()
        batch_size = 16
        num_rows = table.shape[0]
        idxs = np.random.randint(0, num_rows, size=batch_size)

        # this uncovered a bug with arg sort order
        # One Variable per sample, named idx0, idx1, ..., ranged over the
        # valid row indices and bound to the sampled row.
        bound_vars = []
        for pos, row in enumerate(idxs.tolist()):
            bound_vars.append(Variable(f'idx{pos}', 0, num_rows - 1).bind(row))

        # Keep rows [v, v + 1) on the first axis; None is passed for the
        # second axis so no bounds are applied to the columns.
        row_slices = [table.shrink(((v, v + 1), None)) for v in bound_vars]
        gathered = Tensor.cat(*row_slices)

        print(idxs)
        ret = gathered.numpy()
        base = table.numpy()[idxs]
        np.testing.assert_equal(ret, base)
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()