51 lines
3.2 KiB
Python
Raw Normal View History

from collections import defaultdict
from tinygrad.engine.schedule import ScheduleItem
from tinygrad.device import Device, Buffer
from tinygrad.helpers import NO_MEMORY_PLANNER, dedup, DEBUG
from tinygrad.ops import Ops
# **************** memory planning ****************
def _internal_memory_planner(buffers:list[list[Buffer]|tuple[Buffer, ...]], noopt_buffers=None, debug_prefix="") -> dict[Buffer, Buffer]:
  """
  Plan buffer reuse so that buffers whose lifetimes never overlap can share backing memory.

  buffers: one list/tuple of Buffers per step, in order. A base buffer's lifetime is the inclusive span of step
    indices between its first and last appearance; a view's appearance counts toward its base.
  noopt_buffers: optional container; any buffer whose base is in it is left out of planning.
  debug_prefix: prepended to the summary line printed when DEBUG >= 1 and the number of base buffers changed.

  Returns a dict mapping every planned Buffer (bases and views) to the Buffer that should replace it,
  or {} when NO_MEMORY_PLANNER is set.
  """
  if NO_MEMORY_PLANNER: return {}

  def excluded(b) -> bool:
    # already allocated, positive lb_refcount, or explicitly opted out: never remapped
    return b.is_allocated() or b.lb_refcount > 0 or (noopt_buffers is not None and b.base in noopt_buffers)

  # lifetime[base] = [first step, last step]; dict insertion order == order of first appearance
  lifetime: dict[Buffer, list[int]] = {}
  for step, step_bufs in enumerate(buffers):
    for b in step_bufs:
      if excluded(b): continue
      lifetime.setdefault(b.base, [step, step])[1] = step

  last_step = len(buffers) - 1
  # free_segs[key] holds (start, end, owner): owner's memory is idle for steps start..end (inclusive)
  free_segs: dict[tuple, list[tuple[int, int, Buffer]]] = defaultdict(list)

  def find_replace_buffer(buf, st, en):
    # Candidates must match device/dtype/options; if the allocator has no `offset` attribute the byte size
    # must match too (presumably because a smaller buffer can't be placed as a view inside a larger one then).
    key = (buf.device, buf.dtype, buf.options)
    if not hasattr(Device[buf.device].allocator, "offset"): key = key + (buf.nbytes,)
    segs = free_segs[key]
    for idx, (seg_st, seg_en, seg_buf) in enumerate(segs):
      if seg_st <= st and en <= seg_en:
        segs.pop(idx)
        break
    else:
      # no idle window covers [st, en]: buf owns its own memory, nominally across every step
      seg_st, seg_en, seg_buf = 0, last_step, buf
    # hand back whatever is left of the window before and after [st, en]
    if st - 1 >= seg_st: segs.append((seg_st, st - 1, seg_buf))
    if seg_en >= en + 1: segs.append((en + 1, seg_en, seg_buf))
    if seg_buf.nbytes == buf.nbytes: return seg_buf
    return Buffer(buf.device, buf.size, buf.dtype, base=seg_buf)

  # largest buffers claim memory first; the sort is stable, so equal sizes keep first-appearance order
  by_size = sorted(lifetime.items(), key=lambda item: -item[0].nbytes)
  assigned = {b: find_replace_buffer(b, span[0], span[1]) for b, span in by_size}

  # map every occurrence; views are rebuilt on top of the underlying base of their base's replacement
  for step_bufs in buffers:
    for b in step_bufs:
      if excluded(b): continue
      if b._base is None: assigned[b] = assigned.get(b, b)
      else: assigned[b] = Buffer(b.device, b.size, b.dtype, base=assigned.get(b.base, b.base).base, offset=b.offset)

  if DEBUG >= 1:
    before = dedup(x for x in assigned.keys() if x._base is None)
    after = dedup(x for x in assigned.values() if x._base is None)
    if len(before) != len(after):
      print(debug_prefix+f"memory reduced from {sum(x.nbytes for x in before)/1e6:.2f} MB -> {sum(x.nbytes for x in after)/1e6:.2f} MB,",
            f"{len(before)} -> {len(after)} bufs")
  return assigned
def memory_planner(schedule:list[ScheduleItem]) -> list[ScheduleItem]:
  """
  Apply the buffer-reuse plan from _internal_memory_planner to a schedule.

  Each item's bufs form one planning step. Buffers of any item whose ast.op is not Ops.SINK (e.g. transfers)
  are excluded from planning to preserve parallelism in graphs. Returns new ScheduleItems with the same ast,
  metadata and assign_preloads, and with each buffer swapped for its planned replacement (or kept as-is).
  """
  per_step = [si.bufs for si in schedule]
  excluded = set()
  for si in schedule:
    if si.ast.op is not Ops.SINK: excluded.update(si.bufs)
  assigned = _internal_memory_planner(per_step, noopt_buffers=excluded)

  planned = []
  for si in schedule:
    new_bufs = tuple(assigned.get(b, b) for b in si.bufs)
    planned.append(ScheduleItem(si.ast, new_bufs, si.metadata, si.assign_preloads))
  return planned