from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv

from extra.dist import world

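# ring allreduce in two phases: a scatter-reduce pass leaves each rank holding one
# fully reduced chunk, then an allgather pass circulates those reduced chunks around
# the ring until every rank has all of them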
def allreduce(t:Tensor, cache_id=None) -> Tensor:
  RANK, WORLD_SIZE = getenv("RANK"), getenv("WORLD_SIZE")
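  # make the cache id unique to this rank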
  cache_id = f"{RANK}-{cache_id}" if cache_id is not None else None

  # flatten to 1D so the tensor can be split into equal per-rank chunks
  flattened = t.flatten()

  # pad so the element count divides evenly by WORLD_SIZE; the uninitialized padding
  # is shrunk off before returning
  if flattened.shape[0] % WORLD_SIZE != 0:
    flattened = Tensor.cat(flattened, Tensor.empty(WORLD_SIZE - (flattened.shape[0] % WORLD_SIZE)))

  # split into one chunk per rank
  chunks = flattened.chunk(WORLD_SIZE, dim=0)

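  # ring neighbors: sends always go to next_rank, receives always come from prev_rank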
  next_rank = (RANK + 1) % WORLD_SIZE
  prev_rank = ((RANK - 1) + WORLD_SIZE) % WORLD_SIZE

  # scatter-reduce: after WORLD_SIZE - 1 steps, each rank holds one fully reduced chunk
  current_chunk_index = RANK
  for i in range(WORLD_SIZE - 1):
    # send our current chunk to the next rank in the ring
    world.send(chunks[current_chunk_index], next_rank, cache_id=f"{cache_id}-{i}-s" if cache_id is not None else None)
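    # receive the previous rank's chunk and accumulate it into ours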
    current_chunk_index = ((current_chunk_index - 1) + WORLD_SIZE) % WORLD_SIZE
    recv_buf = Tensor.empty(*chunks[current_chunk_index].shape)
    world.recv(recv_buf, prev_rank)
    chunks[current_chunk_index] += recv_buf

  # allgather: circulate the reduced chunks so every rank ends up with all of them
  current_chunk_index = (RANK + 1) % WORLD_SIZE
  for i in range(WORLD_SIZE - 1):
    world.send(chunks[current_chunk_index], next_rank, cache_id=f"{cache_id}-{i}-g" if cache_id is not None else None)
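    # receive an already-reduced chunk and overwrite ours (no accumulation needed)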
    current_chunk_index = ((current_chunk_index - 1) + WORLD_SIZE) % WORLD_SIZE
    recv_buf = Tensor.empty(*chunks[current_chunk_index].shape)
    world.recv(recv_buf, prev_rank)
    chunks[current_chunk_index].assign(recv_buf)

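  # concatenate the chunks, drop the padding, and restore the original shape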
  return Tensor.cat(*chunks, dim=0).shrink(((0, t.numel()),)).reshape(*t.shape)
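
# a minimal usage sketch (hypothetical; assumes extra.dist has spawned this process
# with the RANK and WORLD_SIZE environment variables set and initialized world):
#
#   grads = Tensor.rand(1048577)  # deliberately not divisible by WORLD_SIZE
#   grads = allreduce(grads, cache_id="grads")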