
* fix.. speed_limit error... * draw tpms settings. * fix.. traffic light stopping only.. * fix.. waze cam * fix.. waze... * add setting (Enable comma connect ) * auto detect LFA2 * fix.. cruisespeed1 * vff2 driving model. * fix.. * agnos 12.3 * fix.. * ff * ff * test * ff * fix.. drawTurnInfo.. * Update drive_helpers.py * fix.. support eng voice eng sounds fix settings... english fix.. mph.. fix.. roadlimit speed bug.. * new vff model.. 250608 * fix soundd.. * fix safe exit speed.. * fix.. sounds. * fix.. radar timeStep.. * KerryGold model * Update drive_helpers.py * fix.. model. * fix.. * fix.. * Revert "fix.." This reverts commit b09ec459afb855c533d47fd7e8a1a6b1a09466e7. * Revert "fix.." This reverts commit 290bec6b83a4554ca232d531a911edccf94a2156. * fix esim * add more acc table. 10kph * kg update.. * fix cruisebutton mode3 * test atc..cond. * fix.. canfd * fix.. angle control limit
46 lines
1.7 KiB
Python
46 lines
1.7 KiB
Python
import numpy as np
|
|
import unittest
|
|
from tinygrad import Tensor
|
|
from tinygrad.helpers import get_single_element
|
|
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
|
|
from tinygrad.engine.realize import CompiledRunner, ExecItem
|
|
|
|
class TestOptGemm(unittest.TestCase):
    """Compare an N x N transposed matmul kernel against NumPy under UPCAST opts.

    Each test builds ``a.T @ b.T``, applies a specific list of ``Opt`` entries
    to the resulting kernel, runs it, and checks the output against a NumPy
    reference computed once in ``setUpClass``.
    """

    @classmethod
    def setUpClass(cls):
        """Create two random N x N operands and the NumPy reference product."""
        N = 64
        cls.a = Tensor.randn(N, N).contiguous().realize()
        cls.b = Tensor.randn(N, N).contiguous().realize()
        # Reference result: the same transposed product, evaluated by NumPy.
        cls.res = cls.a.T.numpy() @ cls.b.T.numpy()

    def _test_gemm_unrolled_permute_l(self, opts=None):
        """Run ``a.T @ b.T`` with ``opts`` applied and compare it with NumPy.

        Args:
            opts: Optional list of ``Opt`` entries handed to
                ``Kernel.apply_opts``. ``None`` (the default) means an empty
                list, i.e. no optimizations are requested.

        Raises:
            AssertionError: If the kernel output differs from the NumPy
                reference by more than ``atol=1e-4``.
        """
        # Use a fresh list per call instead of a shared mutable default.
        opts = [] if opts is None else opts

        t = self.a.T @ self.b.T
        # TODO: this should be a generic test helper
        # NOTE(review): presumably get_single_element requires the schedule to
        # contain exactly one item -- confirm against tinygrad.helpers.
        si = get_single_element(t.schedule())
        k = Kernel(si.ast)
        k.apply_opts(opts)
        run = CompiledRunner(k.to_program())
        ExecItem(run, si.bufs).run()
        # NOTE(review): bufs[0] is read back as the result; presumably it is
        # the kernel's output buffer -- confirm.
        test = si.bufs[0].numpy().reshape(self.res.shape)
        np.testing.assert_allclose(self.res, test, atol=1e-4)

    # UPCAST axis 0 by 4, then axis 1 by 4.
    def test_gemm_unrolled_permute_l_44(self):
        opts = [
            Opt(op=OptOps.UPCAST, axis=0, arg=4),
            Opt(op=OptOps.UPCAST, axis=1, arg=4),
        ]
        self._test_gemm_unrolled_permute_l(opts)

    # UPCAST axis 0 by 4, axis 1 by 2, then axis 0 by 4 again.
    def test_gemm_unrolled_permute_l_424(self):
        # was failing with LLVM
        opts = [
            Opt(op=OptOps.UPCAST, axis=0, arg=4),
            Opt(op=OptOps.UPCAST, axis=1, arg=2),
            Opt(op=OptOps.UPCAST, axis=0, arg=4),
        ]
        self._test_gemm_unrolled_permute_l(opts)

    # UPCAST axis 0 by 4, then axis 1 by 2.
    def test_gemm_unrolled_permute_l_42(self):
        opts = [
            Opt(op=OptOps.UPCAST, axis=0, arg=4),
            Opt(op=OptOps.UPCAST, axis=1, arg=2),
        ]
        self._test_gemm_unrolled_permute_l(opts)

    # UPCAST axis 0 by 2, then axis 1 by 2.
    def test_gemm_unrolled_permute_l_22(self):
        opts = [
            Opt(op=OptOps.UPCAST, axis=0, arg=2),
            Opt(op=OptOps.UPCAST, axis=1, arg=2),
        ]
        self._test_gemm_unrolled_permute_l(opts)
|
|
|
|
# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
|