carrot/tinygrad_repo/test/unit/test_attention.py
import unittest
from tinygrad import Tensor, dtypes
# TODO: test_scheduler, but just in unit

class TestAttention(unittest.TestCase):
  def test_half_qkv_buffers(self):
    BS, seqlen, dim = 10, 4, 100
    # half-precision q/k/v, made contiguous and realized so they show up as real input buffers in the schedule
    q = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
    k = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
    v = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
    attn = q.scaled_dot_product_attentionation(k, v) if False else q.scaled_dot_product_attention(k, v)
    sched = attn.schedule()
    # attention has 5 kernels now
    self.assertEqual(len(sched), 5)
    # the middle three kernels are the softmax; all of their buffers should stay in half
    softmax_inputs = sched[1:4]
    for si in softmax_inputs:
      assert all(b.dtype == dtypes.half for b in si.bufs), f"non half {si.bufs=}"

if __name__ == '__main__':
  unittest.main()