carrot/tinygrad_repo/test/external/external_jit_failure.py
Vehicle Researcher 8eb8330d95 openpilot v0.9.9 release
date: 2025-03-08T09:09:29
master commit: ce355250be726f9bc8f0ac165a6cde41586a983d
2025-03-08 09:09:31 +00:00

18 lines
414 B
Python

from tinygrad import Tensor, TinyJit, Device
import numpy as np
# Repro setup: a random N x N x N tensor sharded along axis 0 across GPUS
# devices of the default backend, plus a host-side NumPy copy of it that
# serves as the reference result.
GPUS = 4
N = 128
# Device names "<DEFAULT>:0" .. "<DEFAULT>:<GPUS-1>", canonicalized by tinygrad.
ds = tuple(Device.canonicalize(f"{Device.DEFAULT}:{i}") for i in range(GPUS))
t = Tensor.rand(N, N, N).shard(ds, 0)
n = t.numpy()
@TinyJit
def allreduce(t:Tensor) -> Tensor:
  """Sum ``t`` over axis 0 inside a TinyJit-compiled function.

  Called on the axis-0-sharded ``t`` built above, the reduction spans every
  shard (hence the name "allreduce").
  """
  # NOTE(review): an explicit `.realize()` on the result was left commented
  # out here; presumably its presence/absence matters for reproducing the
  # JIT failure -- confirm before adding it back.
  return t.sum(0)

# Host-side reference result; loop-invariant, so compute it once.
expected = n.sum(0)
# Call the jitted function repeatedly. Presumably 10 calls is enough to get
# past TinyJit's warm-up/capture calls into replay -- confirm.
for i in range(10):
  print(i)
  tn = allreduce(t).numpy()
  np.testing.assert_allclose(tn, expected, atol=1e-4, rtol=1e-4)