208 lines
13 KiB
Python
Raw Normal View History

from __future__ import annotations
import os, pathlib, struct, ctypes, tempfile, functools, decimal
from typing import Any, Union, cast
from tinygrad.helpers import prod, to_mv, getenv, round_up, cache_dir, T, init_c_struct_t, PROFILE
from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator, cpu_profile, ProfileDeviceEvent, ProfileRangeEvent
from tinygrad.renderer.cstyle import MetalRenderer
class objc_id(ctypes.c_void_p):
  """Opaque Objective-C object pointer.

  Subclassing ``ctypes.c_void_p`` prevents ctypes from converting a returned pointer to a plain int, and the
  value-based ``__hash__``/``__eq__`` let ``dict.fromkeys()`` dedup handles that hold the same pointer value.
  """
  def __hash__(self): return hash(self.value)
  def __eq__(self, other):
    # Comparing against something that is not pointer-like (no ``value`` attribute, e.g. None or an int) must not
    # raise: returning NotImplemented lets Python fall back to its default comparison, which yields False.
    if not hasattr(other, "value"): return NotImplemented
    return self.value == other.value
class objc_instance(objc_id):
  """Handle for objects that should be freed after use (those obtained from methods named "new" or "alloc")."""
  def __del__(self):
    # Send "release" to the wrapped object when Python finalizes this handle.
    msg(self, "release")
@functools.lru_cache(maxsize=None)
def sel(name: str):
  """Return the selector registered under `name`; results are memoized per name."""
  encoded_name = name.encode()
  return libobjc.sel_registerName(encoded_name)
class MTLResourceOptions:
  """Named option values for Metal resource creation."""
  # Both options evaluate to 0; the storage mode is written as a value shifted to bit 4.
  MTLResourceCPUCacheModeDefaultCache = 0
  MTLResourceStorageModeShared = 0 << 4
class MTLPipelineOption:
  """Named option values for Metal pipeline creation."""
  # No extra pipeline options requested.
  MTLPipelineOptionNone = 0
# 13 is the requestType that Metal uses to compile source code into MTLB; there aren't any docs or symbols for it.
REQUEST_TYPE_COMPILE = 13
# Handles to the system libraries used by this module, loaded by absolute path at import time.
libobjc = ctypes.CDLL("/usr/lib/libobjc.dylib")
libmetal = ctypes.CDLL("/System/Library/Frameworks/Metal.framework/Metal")
compiler = ctypes.CDLL("/System/Library/PrivateFrameworks/MTLCompiler.framework/MTLCompiler")
# Must be loaded for default Metal Device: https://developer.apple.com/documentation/metal/1433401-mtlcreatesystemdefaultdevice?language=objc
ctypes.CDLL("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics")
libdispatch = ctypes.CDLL("/usr/lib/libSystem.dylib") # libdispatch is part of libSystem on mac
# Declare return types for the pointer-returning functions (ctypes assumes a C int return by default).
# objc_id / objc_instance keep the result as a handle object rather than a plain int.
libobjc.objc_getClass.restype = objc_id
libobjc.sel_registerName.restype = objc_id
libmetal.MTLCreateSystemDefaultDevice.restype = objc_instance
compiler.MTLCodeGenServiceCreate.restype = ctypes.c_void_p
libdispatch.dispatch_data_create.restype = objc_instance
# Ignore mypy error reporting incompatible default, because typevar default only works on python 3.12
def msg(ptr: objc_id, selector: str, /, *args: Any, restype: type[T] = objc_id) -> T: # type: ignore [assignment]
  """Send the Objective-C message `selector` to `ptr` with `args` and return the result as `restype`."""
  # Item access (libobjc["..."]) yields a new function object on every lookup, so setting restype on it is safe
  objc_msg_send = libobjc["objc_msgSend"]
  objc_msg_send.restype = restype
  return objc_msg_send(ptr, sel(selector), *args)
def to_ns_str(s: str):
  """Create an NSString from the encoded bytes of the Python string `s`."""
  ns_string_class = libobjc.objc_getClass(b"NSString")
  return msg(ns_string_class, "stringWithUTF8String:", s.encode(), restype=objc_instance)
def from_ns_str(s):
  """Convert the NSString `s` to a Python str by decoding its UTF8String bytes."""
  utf8_bytes = msg(s, "UTF8String", restype=ctypes.c_char_p)
  return bytes(utf8_bytes).decode()
def to_struct(*t: int, _type: type = ctypes.c_ulong):
  """Pack the values `t` into a C struct with one `_type` field per value, named field0, field1, ..."""
  field_spec = tuple((f"field{idx}", _type) for idx, _ in enumerate(t))
  struct_type = init_c_struct_t(field_spec)
  return struct_type(*t)
def wait_check(cbuf: Any):
  """Wait for the command buffer `cbuf` to complete, then raise if it reports an error."""
  msg(cbuf, "waitUntilCompleted")
  cbuf_error = msg(cbuf, "error", restype=objc_instance)
  error_check(cbuf_error)
def cmdbuf_label(cbuf: objc_id) -> str:
  """Return the label of the command buffer `cbuf` as a Python string."""
  ns_label = msg(cbuf, "label", restype=objc_id)
  return from_ns_str(ns_label)
def cmdbuf_st_time(cbuf: objc_id) -> float:
  """Return the GPUStartTime value reported by the command buffer `cbuf`."""
  gpu_start = msg(cbuf, "GPUStartTime", restype=ctypes.c_double)
  return cast(float, gpu_start)
def cmdbuf_en_time(cbuf: objc_id) -> float:
  """Return the GPUEndTime value reported by the command buffer `cbuf`."""
  gpu_end = msg(cbuf, "GPUEndTime", restype=ctypes.c_double)
  return cast(float, gpu_end)
def error_check(error: objc_instance, error_constructor: type[Exception] = RuntimeError):
  """Raise `error_constructor` with the error's localizedDescription; return None when `error` is a null pointer."""
  if error.value is not None:
    description = msg(error, "localizedDescription", restype=objc_instance)
    raise error_constructor(from_ns_str(description))
  return None
def metal_src_to_library(device:MetalDevice, src:str) -> objc_instance:
  """Compile the Metal source `src` into a library on `device`; raises CompileError if the compile reports an error."""
  compile_options = msg(libobjc.objc_getClass(b"MTLCompileOptions"), "new", restype=objc_instance)
  msg(compile_options, "setFastMathEnabled:", getenv("METAL_FAST_MATH"))
  compile_error = objc_instance()  # out-parameter: passed by reference to the call below
  library = msg(device.sysdevice, "newLibraryWithSource:options:error:", to_ns_str(src), compile_options,
                ctypes.byref(compile_error), restype=objc_instance)
  error_check(compile_error, CompileError)
  return library
class MetalCompiler(Compiler):
  """Compiler that turns Metal source into MTLB binaries through the MTLCodeGenService functions of the `compiler` library."""
  def __init__(self):
    # Handle to the code generation service, reused for every compile() call on this instance.
    self.cgs = ctypes.c_void_p(compiler.MTLCodeGenServiceCreate(b"tinygrad"))
    super().__init__("compile_metal_direct")
  def __reduce__(self): return (MetalCompiler,()) # force pickle to create new instance for each multiprocessing fork
  def compile(self, src:str) -> bytes:
    """Compile the Metal source `src` and return the resulting MTLB library bytes.

    Raises CompileError when the service reports an error or returns without invoking the callback.
    """
    ret: Union[Exception, bytes] = CompileError("MTLCodeGenServiceBuildRequest returned without calling the callback")
    @ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_int32, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_char_p)
    def callback(blockptr, error, dataPtr, dataLen, errorMessage):
      nonlocal ret
      if error == 0:
        reply = bytes(to_mv(dataPtr, dataLen))
        # offset from beginning to data = header size + warning size
        ret = reply[sum(struct.unpack('<LL', reply[8:16])):]
      else:
        ret = CompileError(errorMessage.decode())
    # llvm will create modules.timestamp in cache path and cache compilation of metal stdlib (250ms => 8ms compilation time)
    # note that llvm won't necessarily create anything else here as apple has prebuilt versions of many standard libraries
    params = f'-fno-fast-math -std=metal3.1 --driver-mode=metal -x metal -fmodules-cache-path="{cache_dir}"'
    # source blob has to be padded to multiple of 4 but at least one 'b\x00' should be added, params blob just has to be null terminated
    # The padding is computed from the encoded byte length: len(src) counts characters, which differs from the number
    # of bytes as soon as the source contains a non-ASCII character and would leave the blob misaligned.
    src_bytes = src.encode()
    src_padded = src_bytes + b'\x00'*(round_up(len(src_bytes) + 1, 4) - len(src_bytes))
    params_padded = params.encode() + b'\x00'
    request = struct.pack('<QQ', len(src_padded), len(params_padded)) + src_padded + params_padded
    # The callback is actually not a callback but a block which is apple's non-standard extension to add closures to C.
    # See https://clang.llvm.org/docs/Block-ABI-Apple.html#high-level for struct layout.
    # Fields other than invoke are unused in this case so we can just use ctypes.byref with negative offset to invoke field, add blockptr as a first
    # argument and pretend it's a normal callback
    compiler.MTLCodeGenServiceBuildRequest(self.cgs, None, REQUEST_TYPE_COMPILE, request, len(request), ctypes.byref(callback, -0x10))
    if isinstance(ret, Exception): raise ret
    assert ret[:4] == b"MTLB" and ret[-4:] == b"ENDT", f"Invalid Metal library. {ret!r}"
    return ret
  def disassemble(self, lib:bytes):
    """Write `lib` to a temporary file and run the applegpu disassembler script on it, printing a hint if that fails."""
    with tempfile.NamedTemporaryFile(delete=True) as shader:
      shader.write(lib)
      shader.flush()  # make sure the bytes are on disk before the external script reads the file
      ret = os.system(f"cd {pathlib.Path(__file__).parents[2]}/extra/disassemblers/applegpu && python3 compiler_explorer.py {shader.name}")
      if ret: print("Disassembler Error: Make sure you have https://github.com/dougallj/applegpu cloned to tinygrad/extra/disassemblers/applegpu")
class MetalProgram:
  """A compute function loaded onto a Metal device; calling the instance dispatches it."""
  def __init__(self, dev:MetalDevice, name:str, lib:bytes):
    """Load `lib` (an MTLB binary, or Metal source otherwise) and build the compute pipeline state for function `name`."""
    self.dev, self.name, self.lib = dev, name, lib
    if lib[:4] != b"MTLB":
      # metal source. rely on OS caching
      try: self.library = metal_src_to_library(self.dev, lib.decode())
      except CompileError as e: raise RuntimeError from e
    else:
      # binary metal library
      lib_data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
      lib_error = objc_instance()
      self.library = msg(self.dev.sysdevice, "newLibraryWithData:error:", lib_data, ctypes.byref(lib_error), restype=objc_instance)
      error_check(lib_error)
    self.fxn = msg(self.library, "newFunctionWithName:", to_ns_str(name), restype=objc_instance)
    pipeline_desc = msg(libobjc.objc_getClass(b"MTLComputePipelineDescriptor"), "new", restype=objc_instance)
    msg(pipeline_desc, "setComputeFunction:", self.fxn)
    msg(pipeline_desc, "setSupportIndirectCommandBuffers:", True)
    pipeline_error = objc_instance()
    self.pipeline_state = msg(self.dev.sysdevice, "newComputePipelineStateWithDescriptor:options:reflection:error:",
                              pipeline_desc, MTLPipelineOption.MTLPipelineOptionNone, None, ctypes.byref(pipeline_error), restype=objc_instance)
    error_check(pipeline_error)
  def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
    """Encode and commit one dispatch of this program.

    `bufs` are bound at indices 0..len(bufs)-1 and `vals` follow as 4-byte ints. With `wait` set, blocks until the
    command buffer completes and returns GPUEndTime - GPUStartTime; otherwise returns None.
    Raises RuntimeError when prod(local_size) exceeds the pipeline's maxTotalThreadsPerThreadgroup.
    """
    max_total_threads = msg(self.pipeline_state, "maxTotalThreadsPerThreadgroup", restype=ctypes.c_ulong)
    if prod(local_size) > cast(int, max_total_threads):
      # Gather extra pipeline details purely for the error message.
      exec_width = msg(self.pipeline_state, "threadExecutionWidth", restype=ctypes.c_ulong)
      memory_length = msg(self.pipeline_state, "staticThreadgroupMemoryLength", restype=ctypes.c_ulong)
      raise RuntimeError(f"local size {local_size} bigger than {max_total_threads} with exec width {exec_width} memory length {memory_length}")
    cmd_buf = msg(self.dev.mtl_queue, "commandBuffer", restype=objc_instance)
    enc = msg(cmd_buf, "computeCommandEncoder", restype=objc_instance)
    msg(enc, "setComputePipelineState:", self.pipeline_state)
    for idx, buf in enumerate(bufs):
      msg(enc, "setBuffer:offset:atIndex:", buf.buf, buf.offset, idx)
    # Scalar values take the argument slots right after the buffers.
    for idx, val in enumerate(vals, start=len(bufs)):
      msg(enc, "setBytes:length:atIndex:", bytes(ctypes.c_int(val)), 4, idx)
    msg(enc, "dispatchThreadgroups:threadsPerThreadgroup:", to_struct(*global_size), to_struct(*local_size))
    msg(enc, "endEncoding")
    msg(cmd_buf, "setLabel:", to_ns_str(self.name))
    msg(cmd_buf, "commit")
    self.dev.mtl_buffers_in_flight.append(cmd_buf)
    if not wait: return None
    wait_check(cmd_buf)
    return cmdbuf_en_time(cmd_buf) - cmdbuf_st_time(cmd_buf)
class MetalBuffer:
  """Bundles a buffer handle with a size and an offset (offset defaults to 0)."""
  def __init__(self, buf:Any, size:int, offset=0):
    self.buf = buf
    self.size = size
    self.offset = offset
class MetalAllocator(LRUAllocator):
  """Allocator that creates, frees, copies and views buffers on a MetalDevice."""
  def __init__(self, dev:MetalDevice):
    self.dev:MetalDevice = dev
    super().__init__()
  def _alloc(self, size:int, options) -> MetalBuffer:
    """Allocate a buffer of `size` bytes on the device; raises MemoryError when the device returns a null buffer."""
    # Buffer is explicitly released in _free() rather than garbage collected via reference count
    mtl_buf = msg(self.dev.sysdevice, "newBufferWithLength:options:", ctypes.c_ulong(size), MTLResourceOptions.MTLResourceStorageModeShared,
                  restype=objc_id)
    if mtl_buf.value is not None: return MetalBuffer(mtl_buf, size)
    raise MemoryError(f"Metal OOM while allocating {size=}")
  def _free(self, opaque:MetalBuffer, options):
    """Release the underlying buffer handle."""
    msg(opaque.buf, "release")
  def _transfer(self, dest:MetalBuffer, src:MetalBuffer, sz:int, src_dev:MetalDevice, dest_dev:MetalDevice):
    """Encode a copy of `sz` bytes from `src` to `dest` on the source device's queue."""
    dest_dev.synchronize()
    src_cbuf = msg(src_dev.mtl_queue, "commandBuffer", restype=objc_instance)
    blit = msg(src_cbuf, "blitCommandEncoder", restype=objc_instance)
    msg(blit, "copyFromBuffer:sourceOffset:toBuffer:destinationOffset:size:", src.buf, ctypes.c_ulong(src.offset),
        dest.buf, ctypes.c_ulong(dest.offset), ctypes.c_ulong(sz))
    msg(blit, "endEncoding")
    if src_dev != dest_dev:
      # Cross-device copy: the source signals its timeline event and the destination queue waits on that value.
      msg(src_cbuf, "encodeSignalEvent:value:", src_dev.timeline_signal, src_dev.timeline_value)
      dest_cbuf = msg(dest_dev.mtl_queue, "commandBuffer", restype=objc_instance)
      msg(dest_cbuf, "encodeWaitForEvent:value:", src_dev.timeline_signal, src_dev.timeline_value)
      msg(dest_cbuf, "commit")
      dest_dev.mtl_buffers_in_flight.append(dest_cbuf)
      src_dev.timeline_value += 1
    msg(src_cbuf, "commit")
    src_dev.mtl_buffers_in_flight.append(src_cbuf)
  def _cp_mv(self, dst, src, prof_desc):
    """Copy `src` into `dst` inside a cpu_profile range labelled `prof_desc`."""
    with cpu_profile(prof_desc, self.dev.device, is_copy=True):
      dst[:] = src
  def _as_buffer(self, src:MetalBuffer) -> memoryview:
    """Synchronize the device and return a memoryview of the buffer contents starting at the buffer's offset."""
    self.dev.synchronize()
    contents_ptr = cast(int, msg(src.buf, "contents", restype=objc_id).value)
    # Map offset + size bytes from the base pointer, then drop the leading `offset` bytes.
    whole_view = to_mv(contents_ptr, src.size + src.offset)
    return whole_view[src.offset:]
  def _copyin(self, dest:MetalBuffer, src:memoryview):
    """Copy host memory `src` into the device buffer `dest`."""
    self._cp_mv(self._as_buffer(dest), src, "CPU -> METAL")
  def _copyout(self, dest:memoryview, src:MetalBuffer):
    """Copy the device buffer `src` into host memory `dest`."""
    self._cp_mv(dest, self._as_buffer(src), "METAL -> CPU")
  def _offset(self, buf:MetalBuffer, size:int, offset:int):
    """Return a MetalBuffer over the same handle with the given `size` and `offset`."""
    return MetalBuffer(buf.buf, size, offset)
class MetalDevice(Compiled):
  """Compiled device built on the system default Metal device, with its own command queue and in-flight buffer list."""
  def __init__(self, device:str):
    """Create the command queue and shared timeline event, then register allocator, renderer, compiler and runtime.

    Raises RuntimeError when the command queue cannot be created.
    """
    self.sysdevice = libmetal.MTLCreateSystemDefaultDevice()
    self.mtl_queue = msg(self.sysdevice, "newCommandQueueWithMaxCommandBufferCount:", 1024, restype=objc_instance)
    # msg() returns an objc_instance wrapper, never Python None, so a failed allocation is detected by a null
    # pointer value (the same check used for buffer allocation and error objects in this file).
    if self.mtl_queue.value is None: raise RuntimeError("Cannot allocate a new command queue")
    self.mtl_buffers_in_flight: list[Any] = []
    self.timeline_signal = msg(self.sysdevice, "newSharedEvent", restype=objc_instance)
    self.timeline_value = 0
    Compiled.profile_events += [ProfileDeviceEvent(device)]
    from tinygrad.runtime.graph.metal import MetalGraph
    super().__init__(device, MetalAllocator(self), MetalRenderer(), MetalCompiler() if getenv("METAL_DIRECT", 1) else Compiler(),
                     functools.partial(MetalProgram, self), MetalGraph)
  def synchronize(self):
    """Wait for every in-flight command buffer (raising on errors), record profile events if enabled, then clear the list."""
    for cbuf in self.mtl_buffers_in_flight:
      wait_check(cbuf)
      if PROFILE:
        # GPU timestamps are only queried when profiling is enabled; they are scaled by 1e6 before being recorded.
        st, en = decimal.Decimal(cmdbuf_st_time(cbuf)) * 1000000, decimal.Decimal(cmdbuf_en_time(cbuf)) * 1000000
        Compiled.profile_events += [ProfileRangeEvent(self.device, cmdbuf_label(cbuf), st, en, is_copy=False)]
    self.mtl_buffers_in_flight.clear()