45 lines
1.6 KiB
Python
45 lines
1.6 KiB
Python
![]() |
#!/usr/bin/env python
|
||
|
import unittest
|
||
|
from tinygrad.ops import LazyOp, BinaryOps, ReduceOps, get_lazyop_info, BufferOps, MemBuffer
|
||
|
from tinygrad.shape.shapetracker import ShapeTracker
|
||
|
from tinygrad.helpers import dtypes
|
||
|
|
||
|
class TestFlopCounter(unittest.TestCase):
  """Flop-count checks for get_lazyop_info on tiny LazyOp graphs.

  setUp creates two float32 MEM buffers, each viewed through
  ShapeTracker.from_shape((4,)). Each test wires them into a small graph of
  BinaryOps / ReduceOps and compares info.flops against a hand-computed total.
  The "*_self" tests feed the same LazyOp object into a graph more than once.
  Their expected totals count that shared sub-op's work only once.
  """

  def setUp(self):
    def mem(buffer_idx):
      # One float32 MEM load; the two buffers differ only in their index.
      return LazyOp(BufferOps.MEM, (), MemBuffer(buffer_idx, dtypes.float32, ShapeTracker.from_shape((4,))))
    self.buf0, self.buf1 = mem(1), mem(2)

  @staticmethod
  def _binary(op, lhs, rhs):
    """Return a two-source LazyOp for `op` with no extra argument."""
    return LazyOp(op, (lhs, rhs), None)

  def test_flops_add(self):
    # A single ADD over the two 4-element buffers.
    summed = self._binary(BinaryOps.ADD, self.buf0, self.buf1)
    self.assertEqual(get_lazyop_info(summed).flops, 4)

  def test_flops_add_twice(self):
    # Two chained ADDs, each over 4 elements: 4 + 4.
    first = self._binary(BinaryOps.ADD, self.buf0, self.buf1)
    second = self._binary(BinaryOps.ADD, first, self.buf1)
    self.assertEqual(get_lazyop_info(second).flops, 8)

  def test_flops_add_self(self):
    # `inner` is both operands of `outer`; 8 == 4 (inner) + 4 (outer), so
    # the shared sub-op is expected to be counted a single time.
    inner = self._binary(BinaryOps.ADD, self.buf0, self.buf1)
    outer = self._binary(BinaryOps.ADD, inner, inner)
    self.assertEqual(get_lazyop_info(outer).flops, 8)

  def test_flops_add_roundabout_self(self):
    # `base` is reachable directly and via `mid`; 12 == three distinct
    # 4-element ADDs, i.e. `base` again counted once.
    base = self._binary(BinaryOps.ADD, self.buf0, self.buf1)
    mid = self._binary(BinaryOps.ADD, base, self.buf1)
    top = self._binary(BinaryOps.ADD, base, mid)
    self.assertEqual(get_lazyop_info(top).flops, 12)

  def test_flops_red(self):
    # MUL -> SUM (reduce arg (1,)) -> ADD of the reduced value with itself.
    # Presumably 9 == 4 (MUL) + 4 (SUM over 4 inputs) + 1 (ADD on the
    # reduced result) -- breakdown inferred from the total, confirm in
    # get_lazyop_info.
    product = self._binary(BinaryOps.MUL, self.buf0, self.buf1)
    reduced = LazyOp(ReduceOps.SUM, (product,), (1,))
    doubled = self._binary(BinaryOps.ADD, reduced, reduced)
    self.assertEqual(get_lazyop_info(doubled).flops, 9)
|
||
|
|
||
|
if __name__ == '__main__':
  # Run this module's TestCase classes through unittest's command-line runner
  # when the file is executed directly.
  unittest.main()
|