carrot/tinygrad_repo/extra/dist/collectives.py
FrogAi 659adb6457 openpilot v0.9.7 release
date: 2024-03-17T10:14:38
master commit: 7e9a909e0e57ecb31df4c87c5b9a06b1204fd034
2024-05-24 17:43:27 -07:00

42 lines
1.5 KiB
Python

from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv
from extra.dist import world
def allreduce(t:Tensor, cache_id=None) -> Tensor:
RANK, WORLD_SIZE = getenv("RANK"), getenv("WORLD_SIZE")
cache_id = f"{RANK}-{cache_id}" if cache_id is not None else None
# flatten
flattened = t.flatten()
# pad to evenly divide
if flattened.shape[0] % WORLD_SIZE != 0:
flattened = Tensor.cat(flattened, Tensor.empty(WORLD_SIZE - (flattened.shape[0] % WORLD_SIZE)))
# chunk
chunks = flattened.chunk(WORLD_SIZE, dim=0)
next_rank = (RANK + 1) % WORLD_SIZE
prev_rank = ((RANK - 1) + WORLD_SIZE) % WORLD_SIZE
# scatter reduce
current_chunk_index = RANK
for i in range(WORLD_SIZE - 1):
world.send(chunks[current_chunk_index], next_rank, cache_id=f"{cache_id}-{i}-s" if cache_id is not None else None)
current_chunk_index = ((current_chunk_index - 1) + WORLD_SIZE) % WORLD_SIZE
recv_buf = Tensor.empty(*chunks[current_chunk_index].shape)
world.recv(recv_buf, prev_rank)
chunks[current_chunk_index] += recv_buf
# gather
current_chunk_index = (RANK + 1) % WORLD_SIZE
for i in range(WORLD_SIZE - 1):
world.send(chunks[current_chunk_index], next_rank, cache_id=f"{cache_id}-{i}-g" if cache_id is not None else None)
current_chunk_index = ((current_chunk_index - 1) + WORLD_SIZE) % WORLD_SIZE
recv_buf = Tensor.empty(*chunks[current_chunk_index].shape)
world.recv(recv_buf, prev_rank)
chunks[current_chunk_index].assign(recv_buf)
return Tensor.cat(*chunks, dim=0).shrink(((0, t.numel()),)).reshape(*t.shape)