import math
from typing import List, Tuple, Union
from tinygrad.helpers import dtypes, DType
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.renderer.cstyle import CStyleLanguage
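
# WGSL scalar type names for the tinygrad dtypes this renderer supports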
type_map = {dtypes.float: "f32", dtypes.half: "f16", dtypes.int32: "i32", dtypes.uint32: "u32", dtypes.bool: "bool"}
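
# Renders tinygrad's linearized ops as WGSL (WebGPU Shading Language) compute shaders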
class WGSLLanguage(CStyleLanguage):
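  # workgroup_id/local_invocation_id arrive as vec3<u32> builtins; cast to i32 for signed indexing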
gid = [f"i32(gindex.{'xyz'[x]})" for x in range(3)]
lid = [f"i32(lindex.{'xyz'[x]})" for x in range(3)]
size_prefix = "let"
barrier="workgroupBarrier();"
generic_var_prefix = "var "
external_local_bufs = True
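
  # WGSL expression templates for each tinygrad op; CMPLT yields f32 since ALU values are floats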
  code_for_op = {
    UnaryOps.NEG: lambda x: f"(-{x})",
    UnaryOps.EXP2: lambda x: f"exp2({x})", UnaryOps.LOG2: lambda x: f"log2({x})",
    UnaryOps.SIN: lambda x: f"sin({x})", UnaryOps.SQRT: lambda x: f"sqrt({x})",
    BinaryOps.ADD: lambda x,y: f"({x}+{y})", BinaryOps.SUB: lambda x,y: f"({x}-{y})", BinaryOps.MUL: lambda x,y: f"({x}*{y})",
    BinaryOps.DIV: lambda x,y: f"({x}/{y})", BinaryOps.MOD: lambda x,y: f"({x}%{y})",
    BinaryOps.MAX: lambda x,y: f"max({x},{y})", BinaryOps.CMPLT: lambda x,y: f"f32({x}<{y})",
    TernaryOps.MULACC: lambda x,y,z: f"fma({x},{y},{z})", TernaryOps.WHERE: lambda a,b,c: f"select({c},{b},{a}!=0.)"
  }
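
  # WGSL requires workgroup (shared) memory to be declared at module scope, hence external_local_bufs above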
  def render_local(self, name: str, size: int):
    return f"var<workgroup> {name}: array<f32,{size}>;"
  def render_const(self, x:Union[float,int], var_dtype) -> str:
    if math.isnan(x): val = "nan()"
    elif math.isinf(x): val = ("-" if x < 0 else "") + "0x1.fffffep+127f"
    else: val = f"({x}" + ("" if dtypes.is_int(var_dtype) else "f") + ")"
    return self.render_cast([val]*var_dtype.sz, var_dtype) if var_dtype.sz > 1 else val
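
  # assemble the full shader: nan() helper, one storage binding per buffer, then the @compute entry point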
  def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], local_size:List[int], prekernel:List[str]) -> str:
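    # tinygrad lists local dims in the opposite order to WGSL's @workgroup_size(x, y, z)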
    local_size = local_size[::-1] if local_size else [1]
    bind_it = iter(range(len(bufs)))
    prg = "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }\n"
    prg += "\n".join(prekernel+[f"@group(0) @binding({next(bind_it)}) var<storage,read_write> {name}: array<{type_map[dtype]}>;" for name,dtype in bufs])
    prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n" + "\n".join(kernel) + "\n}"
    return prg
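
  # loops and ifs render only their opening line; the closing "}" is emitted by the base renderer when the scope ends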
  def render_for(self, expr:str, _min:Union[int,str], _max:Union[int,str]) -> str:
    return f"for(var {expr} = {_min}; {expr} < {_max}; {expr}++) {{"

  def render_if(self, cond: str):
    return f"if (bool({cond})) {{"

  def render_conditional(self, cond:str, x:str, y:str) -> str:
    return f"select(f32({y}), {x}, bool({cond}))"
  def render_cast(self, x:List[str], var_dtype:DType) -> str:
    # membership test, not type_map[var_dtype]: a plain lookup would raise KeyError before reaching the NotImplementedError below
    if var_dtype in type_map: return f"{type_map[var_dtype]}({x[0]})"
    raise NotImplementedError(f"no cast for {var_dtype}")
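
  # loads are widened to f32 because this renderer does its ALU math in f32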
  def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
    return f"f32({super().render_load(output_dtype, buf_name, buf_dtype, idx, local)})"
  def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx, local=False) -> str:
    if buf_dtype != var_dtype:
      var_name = f"{type_map[buf_dtype]}({var_name})"
    return f"{buf_name}[{idx}] = {var_name};"