carrot/tinygrad_repo/extra/hip_wrapper.py

693 lines
26 KiB
Python
Raw Permalink Normal View History

import ctypes
from tinygrad.helpers import DEBUG
import sys
import numpy as np
from typing import Any, Dict, List, Tuple
from dataclasses import dataclass
# The HIP bindings below are set up inside a single try statement: the two shared
# libraries are loaded first, and every later binding reads symbols from these handles.
try:
  _libhip = ctypes.cdll.LoadLibrary("libamdhip64.so")
  _libhiprtc = ctypes.cdll.LoadLibrary("libhiprtc.so")
_libhip.hipGetErrorString.argtypes = [ctypes.c_int]
_libhip.hipGetErrorString.restype = ctypes.c_char_p
def hipGetErrorString(status):
  """Return the library's message for `status`, decoded from UTF-8 bytes into a `str`."""
  raw_message = _libhip.hipGetErrorString(status)
  return raw_message.decode("utf-8")
def hipCheckStatus(status):
  """Do nothing when `status` is 0; otherwise raise RuntimeError with the code and its message."""
  if status == 0:
    return
  raise RuntimeError("HIP error %s: %s" % (status, hipGetErrorString(status)))
_libhip.hipDeviceSynchronize.argtypes = []
_libhip.hipDeviceSynchronize.restype = int
def hipDeviceSynchronize():
  """Call `hipDeviceSynchronize` and raise through hipCheckStatus on a non-zero status."""
  hipCheckStatus(_libhip.hipDeviceSynchronize())
_libhip.hipStreamSynchronize.argtypes = [ctypes.c_void_p]
_libhip.hipStreamSynchronize.restype = int
def hipStreamSynchronize(stream):
  """Call `hipStreamSynchronize` on `stream` and raise through hipCheckStatus on failure."""
  hipCheckStatus(_libhip.hipStreamSynchronize(stream))
_libhip.hipEventCreate.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
_libhip.hipEventCreate.restype = int
def hipEventCreate():
  """Create an event and return the handle the library wrote, as a `ctypes.c_void_p`."""
  event = ctypes.c_void_p()
  hipCheckStatus(_libhip.hipEventCreate(ctypes.byref(event)))
  return event
_libhip.hipEventRecord.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
_libhip.hipEventRecord.restype = int
def hipEventRecord(event, stream=None):
  """Call `hipEventRecord(event, stream)`; `stream` defaults to None (passed as a NULL pointer)."""
  hipCheckStatus(_libhip.hipEventRecord(event, stream))
_libhip.hipEventDestroy.argtypes = [ctypes.c_void_p]
_libhip.hipEventDestroy.restype = int
def hipEventDestroy(event):
  """Call `hipEventDestroy` on `event` and raise through hipCheckStatus on failure."""
  hipCheckStatus(_libhip.hipEventDestroy(event))
_libhip.hipEventSynchronize.argtypes = [ctypes.c_void_p]
_libhip.hipEventSynchronize.restype = int
def hipEventSynchronize(event):
  """Call `hipEventSynchronize` on `event` and raise through hipCheckStatus on failure."""
  hipCheckStatus(_libhip.hipEventSynchronize(event))
_libhip.hipEventElapsedTime.argtypes = [ctypes.POINTER(ctypes.c_float), ctypes.c_void_p, ctypes.c_void_p]
_libhip.hipEventElapsedTime.restype = int
def hipEventElapsedTime(start, stop):
  """Return the float the library reports for the (`start`, `stop`) event pair.

  NOTE(review): the unit is presumably milliseconds - confirm against the HIP documentation.
  """
  elapsed = ctypes.c_float()
  hipCheckStatus(_libhip.hipEventElapsedTime(ctypes.byref(elapsed), start, stop))
  return elapsed.value
## Stream Management
# Stream capture modes:
# (integer values for the `mode` argument of hipStreamBeginCapture, which defaults to the Global one)
hipStreamCaptureModeGlobal = 0
hipStreamCaptureModeThreadLocal = 1
hipStreamCaptureModeRelaxed = 2
_libhip.hipStreamCreate.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
_libhip.hipStreamCreate.restype = int
def hipStreamCreate():
  """Create a stream and return the handle the library wrote, as a `ctypes.c_void_p`."""
  stream = ctypes.c_void_p()
  hipCheckStatus(_libhip.hipStreamCreate(ctypes.byref(stream)))
  return stream
_libhip.hipStreamDestroy.argtypes = [ctypes.c_void_p]
_libhip.hipStreamDestroy.restype = int
def hipStreamDestroy(stream):
  """Call `hipStreamDestroy` on `stream` and raise through hipCheckStatus on failure."""
  hipCheckStatus(_libhip.hipStreamDestroy(stream))
_libhip.hipStreamBeginCapture.restype = int
_libhip.hipStreamBeginCapture.argtypes = [ctypes.c_void_p, ctypes.c_int]
def hipStreamBeginCapture(stream, mode=hipStreamCaptureModeGlobal):
  """Begin capture on `stream` using `mode` (one of the hipStreamCaptureMode* integers).

  Raises RuntimeError through hipCheckStatus when the call returns a non-zero status.
  """
  # The original allocated an unused ctypes.c_float here; the call takes no output argument.
  status = _libhip.hipStreamBeginCapture(stream, mode)
  hipCheckStatus(status)
_libhip.hipStreamEndCapture.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
_libhip.hipStreamEndCapture.restype = int
def hipStreamEndCapture(stream):
  """End capture on `stream` and return the handle the library wrote, as a `ctypes.c_void_p`.

  NOTE(review): the returned handle is presumably the captured graph - confirm against the HIP docs.
  """
  captured = ctypes.c_void_p()
  hipCheckStatus(_libhip.hipStreamEndCapture(stream, ctypes.byref(captured)))
  return captured
_libhip.hipStreamGetCaptureInfo_v2.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
_libhip.hipStreamGetCaptureInfo_v2.restype = int
def hipStreamGetCaptureInfo_v2(stream):
  """Query capture information for `stream`.

  Returns a 4-tuple: the status slot (`c_void_p`), the id as a Python int, the graph slot
  (`c_void_p`), and a list with one `c_void_p` per reported dependency.
  """
  capture_status = ctypes.c_void_p()
  capture_id = ctypes.c_ulonglong()
  graph = ctypes.c_void_p()
  dep_nodes = ctypes.POINTER(ctypes.c_void_p)()
  dep_count = ctypes.c_size_t()
  # every output slot is handed to the library by reference, in declaration order
  out_refs = [ctypes.byref(slot) for slot in (capture_status, capture_id, graph, dep_nodes, dep_count)]
  hipCheckStatus(_libhip.hipStreamGetCaptureInfo_v2(stream, *out_refs))
  deps = []
  for i in range(dep_count.value):
    deps.append(ctypes.cast(dep_nodes[i], ctypes.c_void_p))
  return capture_status, capture_id.value, graph, deps
# values for the `flags` argument of hipStreamUpdateCaptureDependencies
hipStreamAddCaptureDependencies = 0
hipStreamSetCaptureDependencies = 1
_libhip.hipStreamUpdateCaptureDependencies.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_uint]
_libhip.hipStreamUpdateCaptureDependencies.restype = int
def hipStreamUpdateCaptureDependencies(stream, deps, flags=hipStreamAddCaptureDependencies):
  """Pass the sequence `deps` (as a `c_void_p` array) and `flags` to the library for `stream`."""
  dep_array = (ctypes.c_void_p * len(deps))()
  dep_array[:] = deps
  dep_count = ctypes.c_size_t(len(deps))
  c_flags = ctypes.c_uint(flags)
  hipCheckStatus(_libhip.hipStreamUpdateCaptureDependencies(stream, dep_array, dep_count, c_flags))
## Graph Management
_libhip.hipGraphCreate.argtypes = [ctypes.c_void_p, ctypes.c_uint]
_libhip.hipGraphCreate.restype = int
def hipGraphCreate():
  """Create a graph (flags argument fixed at 0) and return its handle as a `ctypes.c_void_p`."""
  graph = ctypes.c_void_p()
  hipCheckStatus(_libhip.hipGraphCreate(ctypes.byref(graph), 0))
  return graph
_libhip.hipGraphInstantiate.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
_libhip.hipGraphInstantiate.restype = int
def hipGraphInstantiate(graph):
  """Instantiate `graph` and return the resulting handle as a `ctypes.c_void_p`."""
  graph_exec = ctypes.c_void_p()
  # the three trailing arguments are always passed as 0
  hipCheckStatus(_libhip.hipGraphInstantiate(ctypes.byref(graph_exec), graph, 0, 0, 0))
  return graph_exec
_libhip.hipGraphDestroy.argtypes = [ctypes.c_void_p]
_libhip.hipGraphDestroy.restype = int
def hipGraphDestroy(graph):
  """Call `hipGraphDestroy` on `graph` and raise through hipCheckStatus on failure."""
  hipCheckStatus(_libhip.hipGraphDestroy(graph))
_libhip.hipGraphExecDestroy.argtypes = [ctypes.c_void_p]
_libhip.hipGraphExecDestroy.restype = int
def hipGraphExecDestroy(gexec):
  """Call `hipGraphExecDestroy` on `gexec` and raise through hipCheckStatus on failure."""
  hipCheckStatus(_libhip.hipGraphExecDestroy(gexec))
_libhip.hipGraphLaunch.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
_libhip.hipGraphLaunch.restype = int
def hipGraphLaunch(graph_exec, stream=0):
  """Call `hipGraphLaunch(graph_exec, stream)`; `stream` defaults to 0."""
  hipCheckStatus(_libhip.hipGraphLaunch(graph_exec, stream))
class hipKernelNodeParams(ctypes.Structure):
  """ctypes structure holding block dims, extra, func, grid dims, kernelParams and sharedMemBytes.

  The field order below is the memory layout and must not be rearranged.
  """
  _fields_ = [
    *[(f"blockDim{axis}", ctypes.c_uint32) for axis in "XYZ"],
    ("extra", ctypes.POINTER(ctypes.c_void_p)),
    ("func", ctypes.c_void_p),
    *[(f"gridDim{axis}", ctypes.c_uint32) for axis in "XYZ"],
    ("kernelParams", ctypes.POINTER(ctypes.c_void_p)),
    ("sharedMemBytes", ctypes.c_uint32),
  ]
@dataclass
class kernelNodeParamsWrapper:
  """Pairs a kernel-node parameter struct with optional companion objects."""
  # the parameter structure itself (buildKernelNodeParams stores a hipKernelNodeParams here)
  c_struct: Any
  # optional extra data; buildKernelNodeParams stores a (size, struct, config) tuple here
  context: Any = None
# Struct types are reused often and slow to create, so each one is built once and kept here,
# keyed by its signature string.
struct_type_cache: Dict[str, Any] = {}
def __get_struct(name, field_list):
  """Return the cached Structure subclass stored under `name`, building it from `field_list` on a miss."""
  try:
    return struct_type_cache[name]
  except KeyError:
    pass
  class CStructure(ctypes.Structure):
    _fields_ = field_list
  struct_type_cache[name] = CStructure
  return CStructure
def getStructTypeForArgs(*args):
  """Return a Structure subclass with one field per argument, named field0, field1, ...

  An argument whose class is exactly `int` becomes a `c_int` field (signature letter 'i');
  anything else becomes a `c_void_p` field (signature letter 'P').
  """
  signature: List[str] = []
  fields: List[Tuple[str, Any]] = []
  for idx, arg in enumerate(args):
    is_plain_int = arg.__class__ is int
    signature.append('i' if is_plain_int else 'P')
    fields.append((f'field{idx}', ctypes.c_int if is_plain_int else ctypes.c_void_p))
  return __get_struct(''.join(signature), fields)
def updateKernelNodeParams(npwrapper:kernelNodeParamsWrapper, *args, grid=(1,1,1), block=(1,1,1), updated_args=None):
  """Rewrite the argument struct and launch dimensions held by `npwrapper` in place.

  `npwrapper.context` must be a 3-tuple whose middle element is the argument struct.
  With `updated_args` given, only those argument indices are rewritten; otherwise all are.
  Arguments whose class is exactly `int` are stored as-is, everything else via its `_buf`.
  """
  _size, arg_struct, _config = npwrapper.context
  indices = range(len(args)) if updated_args is None else updated_args
  for i in indices:
    value = args[i]
    setattr(arg_struct, f'field{i}', value if value.__class__ is int else value._buf)
  launch = npwrapper.c_struct
  # block dimensions first, then grid dimensions, each in X, Y, Z order
  for axis, letter in enumerate("XYZ"):
    setattr(launch, f"blockDim{letter}", block[axis])
  for axis, letter in enumerate("XYZ"):
    setattr(launch, f"gridDim{letter}", grid[axis])
def buildKernelNodeParams(*args, func=None, grid=(1,1,1), block=(1,1,1), sharedMemBytes=0, argsStructType=None):
  """Pack `args` into an argument struct and wrap a hipKernelNodeParams that points at it.

  Arguments whose class is exactly `int` are used directly, everything else via its `_buf`.
  `argsStructType` may supply the struct class; by default it comes from getStructTypeForArgs.
  Returns a kernelNodeParamsWrapper whose `context` is the (size, struct, config) tuple.
  """
  values = []
  for arg in args:
    values.append(arg if arg.__class__ is int else arg._buf)
  if argsStructType is None:
    argsStructType = getStructTypeForArgs(*args)
  arg_struct = argsStructType(*values)
  arg_size = ctypes.c_size_t(ctypes.sizeof(arg_struct))
  struct_addr = ctypes.c_void_p(ctypes.addressof(arg_struct))
  size_addr = ctypes.c_void_p(ctypes.addressof(arg_size))
  # five-slot config array: marker 1 + struct address, marker 2 + size address, then marker 3
  # (presumably the HIP_LAUNCH_PARAM_* buffer-pointer / buffer-size / end tags - confirm)
  config = (ctypes.c_void_p * 5)(ctypes.c_void_p(1), struct_addr,
                                 ctypes.c_void_p(2), size_addr,
                                 ctypes.c_void_p(3))
  node_params = hipKernelNodeParams(block[0], block[1], block[2], config, func,
                                    grid[0], grid[1], grid[2], None, sharedMemBytes)
  # config holds raw addresses of arg_struct and arg_size, so both are stored alongside it in
  # the wrapper (presumably to keep them alive as long as the params are in use)
  return kernelNodeParamsWrapper(c_struct=node_params, context=(arg_size, arg_struct, config))
_libhip.hipGraphAddKernelNode.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_void_p]
_libhip.hipGraphAddKernelNode.restype = int
def hipGraphAddKernelNode(graph, deps, params:kernelNodeParamsWrapper):
  """Add a kernel node to `graph` with dependencies `deps` and return its handle (`c_void_p`).

  `params.c_struct` is handed to the library by reference.
  """
  node = ctypes.c_void_p()
  dep_array = (ctypes.c_void_p * len(deps))()
  dep_array[:] = deps
  dep_count = ctypes.c_size_t(len(deps))
  hipCheckStatus(_libhip.hipGraphAddKernelNode(ctypes.byref(node), graph, dep_array, dep_count, ctypes.byref(params.c_struct)))
  return node
_libhip.hipGraphExecKernelNodeSetParams.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
_libhip.hipGraphExecKernelNodeSetParams.restype = int
def hipGraphExecKernelNodeSetParams(gexec, node, params:kernelNodeParamsWrapper):
  """Pass `params.c_struct` by reference to the library for `node` of `gexec`."""
  params_ref = ctypes.byref(params.c_struct)
  hipCheckStatus(_libhip.hipGraphExecKernelNodeSetParams(gexec, node, params_ref))
_libhip.hipMalloc.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]
_libhip.hipMalloc.restype = int
def hipMalloc(count):
  """Call `hipMalloc` with size `count` and return the resulting pointer's `.value`."""
  device_ptr = ctypes.c_void_p()
  hipCheckStatus(_libhip.hipMalloc(ctypes.byref(device_ptr), count))
  return device_ptr.value
_libhip.hipFree.argtypes = [ctypes.c_void_p]
_libhip.hipFree.restype = int
def hipFree(ptr):
  """Call `hipFree` on `ptr` and raise through hipCheckStatus on failure."""
  hipCheckStatus(_libhip.hipFree(ptr))
# memory copy modes (values for the `direction` argument of the copy wrappers)
hipMemcpyHostToHost = 0
hipMemcpyHostToDevice = 1
hipMemcpyDeviceToHost = 2
hipMemcpyDeviceToDevice = 3
hipMemcpyDefault = 4
_libhip.hipMemcpy.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int]
_libhip.hipMemcpy.restype = int
def hipMemcpy(dst, src, count, direction):
  """Call `hipMemcpy(dst, src, count, direction)`; `count` is wrapped in a `c_size_t`."""
  byte_count = ctypes.c_size_t(count)
  hipCheckStatus(_libhip.hipMemcpy(dst, src, byte_count, direction))
_libhip.hipMemcpyAsync.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int, ctypes.c_void_p]
_libhip.hipMemcpyAsync.restype = int
def hipMemcpyAsync(dst, src, count, direction, stream):
  """Call `hipMemcpyAsync(dst, src, count, direction, stream)`; `count` is wrapped in a `c_size_t`."""
  byte_count = ctypes.c_size_t(count)
  hipCheckStatus(_libhip.hipMemcpyAsync(dst, src, byte_count, direction, stream))
_libhip.hipDeviceEnablePeerAccess.argtypes = [ctypes.c_int, ctypes.c_uint]
_libhip.hipDeviceEnablePeerAccess.restype = int
def hipDeviceEnablePeerAccess(peerDevice, flags):
  """Call `hipDeviceEnablePeerAccess(peerDevice, flags)` and raise through hipCheckStatus on failure."""
  hipCheckStatus(_libhip.hipDeviceEnablePeerAccess(peerDevice, flags))
_libhip.hipMemGetInfo.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
_libhip.hipMemGetInfo.restype = int
def hipMemGetInfo():
  """Return `(free, total)` as Python ints, read from the two `c_size_t` slots the library fills."""
  free_bytes, total_bytes = ctypes.c_size_t(), ctypes.c_size_t()
  hipCheckStatus(_libhip.hipMemGetInfo(ctypes.byref(free_bytes), ctypes.byref(total_bytes)))
  return free_bytes.value, total_bytes.value
class hipIpcMemHandle_t(ctypes.Structure):
  """IPC memory handle, stored as a single 64-element char array named `reserved`."""
  _fields_ = [
    ("reserved", ctypes.c_char * 64),
  ]
_libhip.hipIpcGetMemHandle.argtypes = [ctypes.POINTER(hipIpcMemHandle_t), ctypes.c_void_p]
_libhip.hipIpcGetMemHandle.restype = int
def hipIpcGetMemHandle(ptr):
  """Return a `hipIpcMemHandle_t` filled in by the library for `ptr`."""
  ipc_handle = hipIpcMemHandle_t()
  hipCheckStatus(_libhip.hipIpcGetMemHandle(ctypes.byref(ipc_handle), ptr))
  return ipc_handle
_libhip.hipIpcOpenMemHandle.argtypes = [ctypes.POINTER(ctypes.c_void_p), hipIpcMemHandle_t, ctypes.c_uint]
_libhip.hipIpcOpenMemHandle.restype = int
def hipIpcOpenMemHandle(handle, flags):
  """Open `handle` (passed by value) with `flags` and return the resulting pointer's `.value`."""
  mapped_ptr = ctypes.c_void_p()
  hipCheckStatus(_libhip.hipIpcOpenMemHandle(ctypes.byref(mapped_ptr), handle, flags))
  return mapped_ptr.value
_libhip.hipIpcCloseMemHandle.argtypes = [ctypes.c_void_p]
_libhip.hipIpcCloseMemHandle.restype = int
def hipIpcCloseMemHandle(ptr):
  """Call `hipIpcCloseMemHandle` on `ptr` and raise through hipCheckStatus on failure."""
  hipCheckStatus(_libhip.hipIpcCloseMemHandle(ptr))
_libhip.hipSetDevice.argtypes = [ctypes.c_int]
_libhip.hipSetDevice.restype = int
def hipSetDevice(dev):
  """Call `hipSetDevice(dev)` and raise through hipCheckStatus on failure."""
  hipCheckStatus(_libhip.hipSetDevice(dev))
_libhip.hipGetDevice.argtypes = [ctypes.POINTER(ctypes.c_int)]
_libhip.hipGetDevice.restype = int
def hipGetDevice():
  """Return, as a Python int, the device id the library writes."""
  device_id = ctypes.c_int()
  hipCheckStatus(_libhip.hipGetDevice(ctypes.byref(device_id)))
  return device_id.value
_libhip.hipGetDeviceCount.argtypes = [ctypes.POINTER(ctypes.c_int)]
_libhip.hipGetDeviceCount.restype = int
def hipGetDeviceCount():
  """Return, as a Python int, the device count the library writes."""
  n_devices = ctypes.c_int()
  hipCheckStatus(_libhip.hipGetDeviceCount(ctypes.byref(n_devices)))
  return n_devices.value
class hipDeviceArch(ctypes.Structure):
  """Device feature flags; every entry is a 1-bit field declared on `ctypes.c_uint`.

  The order of the names is the bit layout and must not be rearranged.
  """
  _fields_ = [(flag_name, ctypes.c_uint, 1) for flag_name in (
    # *32-bit Atomics*
    "hasGlobalInt32Atomics",     # 32-bit integer atomics for global memory.
    "hasGlobalFloatAtomicExch",  # 32-bit float atomic exch for global memory.
    "hasSharedInt32Atomics",     # 32-bit integer atomics for shared memory.
    "hasSharedFloatAtomicExch",  # 32-bit float atomic exch for shared memory.
    "hasFloatAtomicAdd",         # 32-bit float atomic add in global and shared memory.
    # *64-bit Atomics*
    "hasGlobalInt64Atomics",     # 64-bit integer atomics for global memory.
    "hasSharedInt64Atomics",     # 64-bit integer atomics for shared memory.
    # *Doubles*
    "hasDoubles",                # Double-precision floating point.
    # *Warp cross-lane operations*
    "hasWarpVote",               # Warp vote instructions (__any, __all).
    "hasWarpBallot",             # Warp ballot instructions (__ballot).
    "hasWarpShuffle",            # Warp shuffle operations. (__shfl_*).
    "hasFunnelShift",            # Funnel two words into one with shift&mask caps.
    # *Sync*
    "hasThreadFenceSystem",      # __threadfence_system.
    "hasSyncThreadsExt",         # __syncthreads_count, syncthreads_and, syncthreads_or.
    # *Misc*
    "hasSurfaceFuncs",           # Surface functions.
    "has3dGrid",                 # Grid and group dims are 3D (rather than 2D).
    "hasDynamicParallelism",     # Dynamic parallelism.
  )]
class hipDeviceProperties(ctypes.Structure):
  # ctypes layout of the device-properties record. The raw `_name` and `_gcnArchName` byte
  # fields are exposed as decoded strings through the `name` and `gcnArchName` properties.
  # NOTE(review): field order and types define the memory layout the library writes into;
  # confirm they match the struct declared by the installed HIP headers before editing.
  _fields_ = [
    # Device name
    ("_name", ctypes.c_char * 256),
    # Size of global memory region (in bytes)
    ("totalGlobalMem", ctypes.c_size_t),
    # Size of shared memory region (in bytes).
    ("sharedMemPerBlock", ctypes.c_size_t),
    # Registers per block.
    ("regsPerBlock", ctypes.c_int),
    # Warp size.
    ("warpSize", ctypes.c_int),
    # Max work items per work group or workgroup max size.
    ("maxThreadsPerBlock", ctypes.c_int),
    # Max number of threads in each dimension (XYZ) of a block.
    ("maxThreadsDim", ctypes.c_int * 3),
    # Max grid dimensions (XYZ).
    ("maxGridSize", ctypes.c_int * 3),
    # Max clock frequency of the multiProcessors in khz.
    ("clockRate", ctypes.c_int),
    # Max global memory clock frequency in khz.
    ("memoryClockRate", ctypes.c_int),
    # Global memory bus width in bits.
    ("memoryBusWidth", ctypes.c_int),
    # Size of constant memory region (in bytes).
    # NOTE(review): the previous comment repeated the shared-memory text; confirm this wording.
    ("totalConstMem", ctypes.c_size_t),
    # Major compute capability. On HCC, this is an approximation and features may
    # differ from CUDA CC. See the arch feature flags for portable ways to query
    # feature caps.
    ("major", ctypes.c_int),
    # Minor compute capability. On HCC, this is an approximation and features may
    # differ from CUDA CC. See the arch feature flags for portable ways to query
    # feature caps.
    ("minor", ctypes.c_int),
    # Number of multi-processors (compute units).
    ("multiProcessorCount", ctypes.c_int),
    # L2 cache size.
    ("l2CacheSize", ctypes.c_int),
    # Maximum resident threads per multi-processor.
    ("maxThreadsPerMultiProcessor", ctypes.c_int),
    # Compute mode.
    ("computeMode", ctypes.c_int),
    # Frequency in khz of the timer used by the device-side "clock*"
    # instructions. New for HIP.
    ("clockInstructionRate", ctypes.c_int),
    # Architectural feature flags. New for HIP.
    ("arch", hipDeviceArch),
    # Device can possibly execute multiple kernels concurrently.
    ("concurrentKernels", ctypes.c_int),
    # PCI Domain ID
    ("pciDomainID", ctypes.c_int),
    # PCI Bus ID.
    ("pciBusID", ctypes.c_int),
    # PCI Device ID.
    ("pciDeviceID", ctypes.c_int),
    # Maximum Shared Memory Per Multiprocessor.
    ("maxSharedMemoryPerMultiProcessor", ctypes.c_size_t),
    # 1 if device is on a multi-GPU board, 0 if not.
    ("isMultiGpuBoard", ctypes.c_int),
    # Check whether HIP can map host memory
    ("canMapHostMemory", ctypes.c_int),
    # DEPRECATED: use gcnArchName instead
    ("gcnArch", ctypes.c_int),
    # AMD GCN Arch Name.
    ("_gcnArchName", ctypes.c_char * 256),
    # APU vs dGPU
    ("integrated", ctypes.c_int),
    # HIP device supports cooperative launch
    ("cooperativeLaunch", ctypes.c_int),
    # HIP device supports cooperative launch on multiple devices
    ("cooperativeMultiDeviceLaunch", ctypes.c_int),
    # Maximum size for 1D textures bound to linear memory
    ("maxTexture1DLinear", ctypes.c_int),
    # Maximum number of elements in 1D images
    ("maxTexture1D", ctypes.c_int),
    # Maximum dimensions (width, height) of 2D images, in image elements
    ("maxTexture2D", ctypes.c_int * 2),
    # Maximum dimensions (width, height, depth) of 3D images, in image elements
    ("maxTexture3D", ctypes.c_int * 3),
    # Address of HDP_MEM_COHERENCY_FLUSH_CNTL register
    ("hdpMemFlushCntl", ctypes.POINTER(ctypes.c_uint)),
    # Address of HDP_REG_COHERENCY_FLUSH_CNTL register
    ("hdpRegFlushCntl", ctypes.POINTER(ctypes.c_uint)),
    # Maximum pitch in bytes allowed by memory copies
    ("memPitch", ctypes.c_size_t),
    # Alignment requirement for textures
    ("textureAlignment", ctypes.c_size_t),
    # Pitch alignment requirement for texture references bound to pitched memory
    ("texturePitchAlignment", ctypes.c_size_t),
    # Run time limit for kernels executed on the device
    ("kernelExecTimeoutEnabled", ctypes.c_int),
    # Device has ECC support enabled
    ("ECCEnabled", ctypes.c_int),
    # 1:If device is Tesla device using TCC driver, else 0
    ("tccDriver", ctypes.c_int),
    # HIP device supports cooperative launch on multiple
    # devices with unmatched functions
    ("cooperativeMultiDeviceUnmatchedFunc", ctypes.c_int),
    # HIP device supports cooperative launch on multiple
    # devices with unmatched grid dimensions
    ("cooperativeMultiDeviceUnmatchedGridDim", ctypes.c_int),
    # HIP device supports cooperative launch on multiple
    # devices with unmatched block dimensions
    ("cooperativeMultiDeviceUnmatchedBlockDim", ctypes.c_int),
    # HIP device supports cooperative launch on multiple
    # devices with unmatched shared memories
    ("cooperativeMultiDeviceUnmatchedSharedMem", ctypes.c_int),
    # 1: if it is a large PCI bar device, else 0
    ("isLargeBar", ctypes.c_int),
    # Revision of the GPU in this device
    ("asicRevision", ctypes.c_int),
    # Device supports allocating managed memory on this system
    ("managedMemory", ctypes.c_int),
    # Host can directly access managed memory on the device without migration
    ("directManagedMemAccessFromHost", ctypes.c_int),
    # Device can coherently access managed memory concurrently with the CPU
    ("concurrentManagedAccess", ctypes.c_int),
    # Device supports coherently accessing pageable memory
    # without calling hipHostRegister on it
    ("pageableMemoryAccess", ctypes.c_int),
    # Device accesses pageable memory via the host's page tables
    ("pageableMemoryAccessUsesHostPageTables", ctypes.c_int),
  ]
  # Device name decoded from the raw `_name` bytes as UTF-8.
  @property
  def name(self):
    return self._name.decode("utf-8")
  # GCN architecture name decoded from the raw `_gcnArchName` bytes as UTF-8.
  @property
  def gcnArchName(self):
    return self._gcnArchName.decode("utf-8")
_libhip.hipGetDeviceProperties.argtypes = [ctypes.POINTER(hipDeviceProperties), ctypes.c_int]
_libhip.hipGetDeviceProperties.restype = int
def hipGetDeviceProperties(deviceId: int):
  """Return a `hipDeviceProperties` instance filled in by the library for `deviceId`."""
  props = hipDeviceProperties()
  hipCheckStatus(_libhip.hipGetDeviceProperties(ctypes.pointer(props), deviceId))
  return props
_libhip.hipModuleLoadData.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_void_p]
_libhip.hipModuleLoadData.restype = int
def hipModuleLoadData(data):
  """Load a module from `data` and return its handle as a `ctypes.c_void_p`."""
  module_handle = ctypes.c_void_p()
  hipCheckStatus(_libhip.hipModuleLoadData(ctypes.byref(module_handle), data))
  return module_handle
_libhip.hipModuleGetFunction.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_void_p, ctypes.POINTER(ctypes.c_char)]
_libhip.hipModuleGetFunction.restype = int
def hipModuleGetFunction(module, func_name):
  """Look up `func_name` (a `str`, sent UTF-8 encoded) in `module`; return its handle (`c_void_p`)."""
  function_handle = ctypes.c_void_p()
  encoded_name = func_name.encode("utf-8")
  hipCheckStatus(_libhip.hipModuleGetFunction(ctypes.byref(function_handle), module, encoded_name))
  return function_handle
_libhip.hipModuleUnload.argtypes = [ctypes.c_void_p]
_libhip.hipModuleUnload.restype = int
def hipModuleUnload(module):
  """Call `hipModuleUnload` on `module` and raise through hipCheckStatus on failure."""
  hipCheckStatus(_libhip.hipModuleUnload(module))
_libhip.hipModuleLaunchKernel.argtypes = [
  ctypes.c_void_p,                                 # kernel
  ctypes.c_uint, ctypes.c_uint, ctypes.c_uint,     # bx, by, bz
  ctypes.c_uint, ctypes.c_uint, ctypes.c_uint,     # tx, ty, tz
  ctypes.c_uint,                                   # shared
  ctypes.c_void_p,                                 # stream
  ctypes.POINTER(ctypes.c_void_p),                 # always passed as None below
  ctypes.POINTER(ctypes.c_void_p),                 # config array
]
_libhip.hipModuleLaunchKernel.restype = int
def hipModuleLaunchKernel(kernel, bx, by, bz, tx, ty, tz, shared, stream, struct):
  """Launch `kernel` with the given dimensions, passing `struct` as one packed argument buffer.

  `struct` must be a ctypes object: its address and size are placed in a five-slot config
  array, and the separate kernel-parameter pointer argument is left as None.
  """
  block_dims = [ctypes.c_uint(v) for v in (bx, by, bz)]
  thread_dims = [ctypes.c_uint(v) for v in (tx, ty, tz)]
  shared_bytes = ctypes.c_uint(shared)
  # markers tagging the config slots: buffer pointer, buffer size, end of list
  marker_ptr, marker_size, marker_end = ctypes.c_void_p(1), ctypes.c_void_p(2), ctypes.c_void_p(3)
  struct_size = ctypes.c_size_t(ctypes.sizeof(struct))
  size_addr = ctypes.c_void_p(ctypes.addressof(struct_size))
  struct_addr = ctypes.c_void_p(ctypes.addressof(struct))
  config = (ctypes.c_void_p * 5)(marker_ptr, struct_addr, marker_size, size_addr, marker_end)
  hipCheckStatus(_libhip.hipModuleLaunchKernel(kernel, *block_dims, *thread_dims, shared_bytes, stream, None, config))
_libhiprtc.hiprtcCreateProgram.argtypes = [
  ctypes.POINTER(ctypes.c_void_p), ctypes.POINTER(ctypes.c_char), ctypes.POINTER(ctypes.c_char),
  ctypes.c_int, ctypes.POINTER(ctypes.c_char_p), ctypes.POINTER(ctypes.c_char_p),
]
_libhiprtc.hiprtcCreateProgram.restype = int
def hiprtcCreateProgram(source, name, header_names, header_sources):
  """Create a program from `source` and return its handle as a `ctypes.c_void_p`.

  All strings are sent UTF-8 encoded. The header count passed to the library is
  `len(header_names)`, and the header sources array is passed before the header names array.
  """
  names_array = (ctypes.c_char_p * len(header_names))()
  sources_array = (ctypes.c_char_p * len(header_sources))()
  names_array[:] = [header.encode("utf-8") for header in header_names]
  sources_array[:] = [header.encode("utf-8") for header in header_sources]
  program = ctypes.c_void_p()
  hipCheckStatus(_libhiprtc.hiprtcCreateProgram(ctypes.byref(program), source.encode("utf-8"), name.encode("utf-8"), len(header_names), sources_array, names_array))
  return program
_libhiprtc.hiprtcDestroyProgram.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
_libhiprtc.hiprtcDestroyProgram.restype = int
def hiprtcDestroyProgram(prog):
  """Pass `prog` by reference to `hiprtcDestroyProgram`; raise through hipCheckStatus on failure."""
  hipCheckStatus(_libhiprtc.hiprtcDestroyProgram(ctypes.byref(prog)))
_libhiprtc.hiprtcGetProgramLogSize.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_size_t)]
_libhiprtc.hiprtcGetProgramLogSize.restype = int
_libhiprtc.hiprtcGetProgramLog.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
_libhiprtc.hiprtcGetProgramLog.restype = int
def hiprtcGetProgramLog(prog):
  """Return the log of `prog` as a `str`: query its size, fetch it into a buffer, decode it."""
  log_size = ctypes.c_size_t()
  hipCheckStatus(_libhiprtc.hiprtcGetProgramLogSize(prog, ctypes.byref(log_size)))
  log_buffer = ctypes.create_string_buffer(log_size.value)
  hipCheckStatus(_libhiprtc.hiprtcGetProgramLog(prog, log_buffer))
  # .value stops at the first NUL byte in the buffer
  return log_buffer.value.decode()
_libhiprtc.hiprtcCompileProgram.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.POINTER(ctypes.c_char_p)]
_libhiprtc.hiprtcCompileProgram.restype = int
def hiprtcCompileProgram(prog, options):
  """Compile `prog` with `options` (a sequence of `str`, sent UTF-8 encoded).

  On status 6 the program log is printed first; any non-zero status then raises
  through hipCheckStatus.
  """
  option_array = (ctypes.c_char_p * len(options))()
  option_array[:] = [opt.encode("utf-8") for opt in options]
  result = _libhiprtc.hiprtcCompileProgram(prog, len(options), option_array)
  # NOTE(review): 6 is presumably HIPRTC_ERROR_COMPILATION - confirm against hiprtc.h
  if result == 6:
    print(hiprtcGetProgramLog(prog))
  hipCheckStatus(result)
_libhiprtc.hiprtcGetCodeSize.restype = int
_libhiprtc.hiprtcGetCodeSize.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_size_t)]
_libhiprtc.hiprtcGetCode.restype = int
_libhiprtc.hiprtcGetCode.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_char)]
def hiprtcGetCode(prog):
  """Return the compiled code of `prog` as `bytes`.

  The size is queried first, then the code is fetched into a buffer of exactly that size.
  Raises RuntimeError through hipCheckStatus if either library call reports a non-zero status.
  """
  code_size = ctypes.c_size_t()
  status = _libhiprtc.hiprtcGetCodeSize(prog, ctypes.byref(code_size))
  hipCheckStatus(status)
  # The library writes into this buffer, so it must be mutable. The previous code handed it a
  # Python bytes object, which is immutable and must never be modified behind the interpreter.
  code_buffer = ctypes.create_string_buffer(code_size.value)
  status = _libhiprtc.hiprtcGetCode(prog, code_buffer)
  hipCheckStatus(status)
  # .raw returns every byte of the buffer, including embedded NULs
  return code_buffer.raw
except:
if DEBUG >= 1: print("WARNING: libamdhip64.so or libhiprtc.so not found. HIP support will not work.")