carrot/tinygrad_repo/extra/gemm/metal_matvec.py
carrot 9c7833faf9
KerryGold Model, AGNOS12.4, AdjustLaneChange, EnglighSound (#182)
* Vegetarian Filet o Fish model

* fix.. atc..

* test cluster_speed_limit

* fix.. cluster_speed_limit.. 2

* fix.. clusterspeedlimit3

* cruise speed to roadlimit speed

* fix..

* fix.. eng

* deltaUp/Down for lanechange

* fix.. atc desire...

* fix..

* ff

* ff

* fix..

* fix.. eng

* fix engsound

* Update desire_helper.py

* fix.. connect...

* fix curve_min speed

* Revert "fix curve_min speed"

This reverts commit fcc9c2eb14eb3504abef3e420db93e8882e56f37.

* Reapply "fix curve_min speed"

This reverts commit 2d2bba476c58a7b4e13bac3c3ad0e4694c95515d.

* fix.. auto speed up.. roadlimit

* fix.. atc auto lanechange...

* Update desire_helper.py

* Update cruise.py

* debug atc...

* fix.. waze alert offset..

* fix..

* test atc..

* fix..

* fix.. atc

* atc test..

* fix.. atc

* fix.. atc2

* fix.. atc3

* KerryGold Model.  latsmooth_sec = 0.0

* lat smooth seconds 0.13

* fix comment

* fix.. auto cruise, and speed unit

* change lanemode switching.

* erase mazda lkas button.
2025-06-22 10:51:42 +09:00

113 lines
4.9 KiB
Python

import numpy as np
import time, torch, torch.mps
from tinygrad import Tensor, TinyJit, Device
from tinygrad.helpers import flat_mv
from tinygrad.runtime.ops_metal import MetalAllocator, MetalDevice, MetalProgram, MetalCompiler
# Problem size: a length-N vector multiplied by an N x M matrix.
N, M = 16384, 4096
# Operation count used for the GFLOPS figures printed by each benchmark
# (two operations per matrix element -- presumably one multiply and one add).
FLOPS = 2 * N * M
# Host-side float32 inputs drawn from a standard normal distribution.
# For debugging, integer-valued inputs can be had by appending
# .astype(np.int32).astype(np.float32) to either line.
nb = np.random.default_rng().standard_normal(size=N, dtype=np.float32)
nc = np.random.default_rng().standard_normal(size=(N, M), dtype=np.float32)
# torch copies of the same data, moved to the 'mps' device.
b, c = (torch.from_numpy(host).to('mps') for host in (nb, nc))
def torch_prog(b, c):
  """Run `b@c` once in torch and return the elapsed wall-clock seconds.

  The clock is read with `time.perf_counter()` before the product is issued
  and again after `torch.mps.synchronize()` returns. The product itself is
  not returned.
  """
  began = time.perf_counter()
  # Keep the result bound until after the synchronize, as the timing loop only needs the duration.
  product = b @ c
  torch.mps.synchronize()
  elapsed = time.perf_counter() - began
  return elapsed
# Best (minimum) time over 200 torch runs.
tm = min(torch_prog(b, c) for _ in range(200))
print(f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in torch")
# Reference result, brought back to the CPU; the later backends are checked against it.
torch_a = (b @ c).cpu()
# Metal device handle and allocator used by the hand-written kernel in this script.
device = MetalDevice("METAL")
metalalloc = MetalAllocator(device)
# Threadgroup shape as [x, y, z]; the kernel reads these as lid.x, lid.y and lid.z.
WORKSIZE_ROW = 16
WORKSIZE_COL = 1
LOCAL_SIZE = [32, WORKSIZE_COL, WORKSIZE_ROW]
# Number of threadgroups along x. The factor 4 presumably matches the float4
# (4 outputs) that each x/y thread stores to data0 -- confirm before changing.
GLOBAL_SIZE = [M//(LOCAL_SIZE[0]*LOCAL_SIZE[1]*4), 1, 1]
# The kernel source is an f-string: every {...} expression is evaluated here in
# Python and baked into the Metal source as a constant, while {{ and }} are
# literal braces. Changing N, M, LOCAL_SIZE or GLOBAL_SIZE therefore requires
# recompiling this program.
#
# Kernel signature: data0 is written, data1 and data2 are `const device` inputs.
# Each thread keeps a float4 accumulator in the threadgroup array acc0, stages
# float4 loads of data1 in the threadgroup array val1 (with a
# threadgroup_barrier after the load and after the inner loop), and finally the
# threads with lid.z == 0 sum the LOCAL_SIZE[2] accumulators for their slot and
# store one float4 to data0.
prog = MetalProgram(device, "test", MetalCompiler().compile(f"""
#include <metal_stdlib>
using namespace metal;
kernel void test(device float* data0, const device float* data1, const device float* data2, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {{
int gidx0 = gid.x; /* {GLOBAL_SIZE[0]} */
int lidx1 = lid.x; /* {LOCAL_SIZE[0]} */
int lidx2 = lid.y; /* {LOCAL_SIZE[1]} */
int lidx3 = lid.z; /* {LOCAL_SIZE[2]} */
// 4 rows per thread
threadgroup float4 acc0[{LOCAL_SIZE[0]*LOCAL_SIZE[1]*LOCAL_SIZE[2]}];
int acc0_index = ((lidx1*{LOCAL_SIZE[1]})+lidx2)+({LOCAL_SIZE[0]*LOCAL_SIZE[1]}*lidx3);
acc0[acc0_index] = float4(0.0f,0.0f,0.0f,0.0f);
threadgroup float4 val1[{LOCAL_SIZE[0]*LOCAL_SIZE[1]*LOCAL_SIZE[2]}];
// iterate over the columns
for (int ridx2 = 0; ridx2 < {N//(4*LOCAL_SIZE[0]*LOCAL_SIZE[1]*(LOCAL_SIZE[2]))}; ++ridx2) {{
// load 4*threadgroup_size columns into shared memory
int col_1 = (((lidx3*{N//(4*LOCAL_SIZE[0]*LOCAL_SIZE[1]*(LOCAL_SIZE[2]))})+ridx2)*{LOCAL_SIZE[0]*LOCAL_SIZE[1]})+(lidx1*{LOCAL_SIZE[1]})+lidx2;
val1[(lidx3*{LOCAL_SIZE[1]*LOCAL_SIZE[0]})+((lidx1*{LOCAL_SIZE[1]})+lidx2)] = *((device float4*)(data1+(col_1*4)));
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int ridx3 = 0; ridx3 < {LOCAL_SIZE[0]*LOCAL_SIZE[1]}; ++ridx3) {{
int col = ((((lidx3*{N//(4*LOCAL_SIZE[0]*LOCAL_SIZE[1]*(LOCAL_SIZE[2]))})+ridx2)*{LOCAL_SIZE[0]*LOCAL_SIZE[1]})+ridx3);
float4 val1_0 = val1[(lidx3*{LOCAL_SIZE[1]*LOCAL_SIZE[0]})+ridx3];
float4 val2_0 = (float4)(*((device float4*)(data2+(gidx0*{M//GLOBAL_SIZE[0]})+(((lidx1*{LOCAL_SIZE[1]})+lidx2)*4)+(col*{M*4})+{M*0})));
float4 val2_1 = (float4)(*((device float4*)(data2+(gidx0*{M//GLOBAL_SIZE[0]})+(((lidx1*{LOCAL_SIZE[1]})+lidx2)*4)+(col*{M*4})+{M*1})));
float4 val2_2 = (float4)(*((device float4*)(data2+(gidx0*{M//GLOBAL_SIZE[0]})+(((lidx1*{LOCAL_SIZE[1]})+lidx2)*4)+(col*{M*4})+{M*2})));
float4 val2_3 = (float4)(*((device float4*)(data2+(gidx0*{M//GLOBAL_SIZE[0]})+(((lidx1*{LOCAL_SIZE[1]})+lidx2)*4)+(col*{M*4})+{M*3})));
acc0[acc0_index] = ((val1_0.x*val2_0)+acc0[acc0_index]);
acc0[acc0_index] = ((val1_0.y*val2_1)+acc0[acc0_index]);
acc0[acc0_index] = ((val1_0.z*val2_2)+acc0[acc0_index]);
acc0[acc0_index] = ((val1_0.w*val2_3)+acc0[acc0_index]);
}}
threadgroup_barrier(mem_flags::mem_threadgroup);
}} /* reduce */
if (lidx3 == 0) {{
float4 out = float4(0.0f,0.0f,0.0f,0.0f);
for (int n = 0; n < {LOCAL_SIZE[2]}; n++) {{
out += acc0[((lidx1*{LOCAL_SIZE[1]})+lidx2)+({LOCAL_SIZE[0]*LOCAL_SIZE[1]}*n)];
}}
*( (device float4 *) (data0 + (gidx0*{M//GLOBAL_SIZE[0]}) + ( ( (lidx1*{LOCAL_SIZE[1]})+lidx2 ) * 4 ) ) ) = out;
}}
}}
"""))
# Device buffers, sized in bytes from the float32 host arrays:
# `a` holds the M outputs, `b` and `c` mirror nb and nc.
a = metalalloc.alloc(M * nb.itemsize)
b = metalalloc.alloc(nb.nbytes)
c = metalalloc.alloc(nc.nbytes)
# Upload the inputs; `a` is left for the kernel to fill.
metalalloc._copyin(b, nb.tobytes())
metalalloc._copyin(c, nc.tobytes())
def metalrun():
  """Launch the compiled `test` kernel once on buffers (a, b, c) and return `a`.

  The launch uses the module-level GLOBAL_SIZE / LOCAL_SIZE and passes
  wait=True (presumably blocking until the kernel has finished -- confirm
  against MetalProgram). Whatever the launch call returns is discarded.
  """
  prog(a, b, c, local_size=LOCAL_SIZE, global_size=GLOBAL_SIZE, wait=True)
  return a
def timeit(fxn):
  """Call `fxn()` once and return the elapsed wall-clock time in seconds.

  Args:
    fxn: zero-argument callable to time. Its return value is ignored.

  Returns:
    The difference between two `time.perf_counter()` readings taken
    immediately before and after the call.
  """
  st = time.perf_counter()
  # NOTE: the wall clock is measured around the whole call, so everything the
  # callable does (including any launch overhead) is part of the result.
  fxn()
  return time.perf_counter() - st
# Best (minimum) time over 200 launches of the hand-written kernel.
tm = min(timeit(metalrun) for _ in range(200))
print(f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in metal")
# Read the output buffer back into a host array and check it against the torch result.
metal_a = np.zeros((M,), dtype=np.float32)
metalalloc._copyout(flat_mv(metal_a.data), a)
np.testing.assert_allclose(actual=metal_a, desired=torch_a, atol=5e-3)
# Rebind b and c once more, this time to tinygrad Tensors built from the same
# host arrays (nb, nc); the earlier values bound to these names are no longer
# reachable through them after this point.
b = Tensor(nb)
c = Tensor(nc)
# TODO: slowness without the JIT I suspect comes from a lack of a caching allocator
@TinyJit
def tiny_jit(b, c):
  """Compute `b@c` with tinygrad and return the realized result."""
  product = b @ c
  return product.realize()
def tiny_prog(b, c):
  """Run `tiny_jit(b, c)` once and return the elapsed wall-clock seconds.

  The clock is read with `time.perf_counter()` before the call and again
  after `Device["METAL"].synchronize()` returns. The product itself is not
  returned.
  """
  began = time.perf_counter()
  # Keep the result bound until after the synchronize; only the duration is reported.
  product = tiny_jit(b, c)
  Device["METAL"].synchronize()
  elapsed = time.perf_counter() - began
  return elapsed
# Best (minimum) time over 200 jitted tinygrad runs.
tm = min(tiny_prog(b, c) for _ in range(200))
print(f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in tinygrad")
# Pull the tinygrad result to the host and check it against the torch result.
tiny_a = tiny_jit(b, c).numpy()
np.testing.assert_allclose(actual=tiny_a, desired=torch_a, atol=5e-3)