carrot/tinygrad_repo/extra/gemm/torch_gemm.py
FrogAi 659adb6457 openpilot v0.9.7 release
date: 2024-03-17T10:14:38
master commit: 7e9a909e0e57ecb31df4c87c5b9a06b1204fd034
2024-05-24 17:43:27 -07:00

18 lines
536 B
Python

import time
import torch
# Benchmark torch's dense GEMM (square N x N @ N x N) on the current CUDA
# device for fp16 and fp32, printing the best-of-20 wall time per size and
# the throughput that time implies. Requires a CUDA-capable GPU.


def torch_prog(b, c):
  """Time one ``b @ c`` on the GPU and return the elapsed seconds.

  CUDA kernels are launched asynchronously, so the device is synchronized
  both before the clock starts (so work already queued, such as creating the
  inputs, is not charged to this matmul) and after the matmul (so the
  measurement covers kernel completion, not just the launch).
  """
  torch.cuda.synchronize()
  st = time.perf_counter()
  _ = b @ c  # result unused; only the time matters
  torch.cuda.synchronize()
  return time.perf_counter() - st


for dtype in [torch.float16, torch.float32]:
  for N in [256, 512, 1024, 2048, 4096]:
    # an N x N @ N x N matmul does N^3 multiply-adds = 2*N^3 floating-point ops
    FLOPS = N*N*N*2
    # allocate directly on the GPU rather than rand-on-CPU + .cuda(): avoids a
    # host->device copy and CPU-side fp16 random generation
    b = torch.rand((N,N), dtype=dtype, device="cuda")
    c = torch.rand((N,N), dtype=dtype, device="cuda")
    # best of 20 runs: taking the minimum filters out one-off overheads that
    # can inflate early calls (e.g. library initialization)
    tm = min(torch_prog(b, c) for _ in range(20))
    print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS {N:4d}x{N:4d}x{N:4d} matmul in {dtype}")