Vehicle Researcher 4fca6dec8e openpilot v0.9.8 release
date: 2025-01-29T09:09:56
master commit: 227bb68e1891619b360b89809e6822d50d34228f
2025-01-29 09:09:58 +00:00

204 lines
7.0 KiB
Python

"""This is where the forwards and backwards passes live."""
import math
from tinygrad.helpers import argsort
from tinygrad.dtype import dtypes, DType, sum_acc_dtype
from tinygrad.ops import Ops, resolve, sint, UOp
from tinygrad.tensor import Function
class Contiguous(Function):
def forward(self, x:UOp) -> UOp: return x.contiguous()
def backward(self, grad_output:UOp) -> UOp: return grad_output
class ContiguousBackward(Function):
def forward(self, x:UOp) -> UOp: return x
def backward(self, grad_output:UOp) -> UOp: return grad_output.contiguous()
class Cast(Function):
def forward(self, x:UOp, dtype:DType, bitcast:bool=False) -> UOp:
self.input_dtype, self.bitcast = x.dtype, bitcast
return x.bitcast(dtype) if self.bitcast else x.cast(dtype)
def backward(self, grad_output:UOp) -> UOp:
if self.bitcast: raise RuntimeError("bitcast cannot backward")
return grad_output.cast(self.input_dtype)
# ************* unary ops *************
class Reciprocal(Function):
def forward(self, x:UOp) -> UOp:
self.ret = x.reciprocal()
return self.ret
def backward(self, grad_output:UOp) -> UOp: return -grad_output * self.ret * self.ret
class Sin(Function):
def forward(self, x:UOp) -> UOp:
self.x = x
return x.sin()
def backward(self, grad_output:UOp) -> UOp: return (math.pi/2 - self.x).sin() * grad_output
class Relu(Function):
def forward(self, x:UOp) -> UOp:
self.ret = (x>0).where(x, 0)
return self.ret
def backward(self, grad_output:UOp) -> UOp: return (self.ret>0).cast(grad_output.dtype) * grad_output
class Log(Function):
def forward(self, x:UOp) -> UOp:
self.x = x
return x.log2() * math.log(2)
def backward(self, grad_output:UOp) -> UOp: return grad_output / self.x
class Exp(Function):
def forward(self, x:UOp) -> UOp:
self.ret = (x * (1/math.log(2))).exp2()
return self.ret
def backward(self, grad_output:UOp) -> UOp: return self.ret * grad_output
class Sqrt(Function):
def forward(self, x:UOp) -> UOp:
self.ret = x.sqrt()
return self.ret
def backward(self, grad_output:UOp) -> UOp: return grad_output / (self.ret*2)
class Sign(Function):
# NOTE: the x*0 is to match torch behavior without function.py
def forward(self, x:UOp) -> UOp: return x.ne(0).where((x<0).where(x.const_like(-1), x.const_like(1)), x.const_like(0)) + x*0
# backward always return 0 to match torch
def backward(self, grad_output:UOp) -> UOp: return grad_output.const_like(0)
# ************* binary ops *************
class Less(Function):
def forward(self, x:UOp, y:UOp) -> UOp: return x<y
def backward(self, grad_output:UOp) -> tuple[UOp|None, UOp|None]: return None, None
class Neq(Function):
def forward(self, x:UOp, y:UOp) -> UOp: return x.ne(y)
def backward(self, grad_output:UOp) -> tuple[UOp|None, UOp|None]: return None, None
class Xor(Function):
def forward(self, x:UOp, y:UOp) -> UOp: return x^y
class BitwiseAnd(Function):
def forward(self, x:UOp, y:UOp) -> UOp: return x&y
class BitwiseOr(Function):
def forward(self, x:UOp, y:UOp) -> UOp: return x|y
class Threefry(Function):
def forward(self, x:UOp, seed:UOp) -> UOp: return x.threefry(seed)
class Add(Function):
def forward(self, x:UOp, y:UOp) -> UOp: return x+y
def backward(self, grad_output:UOp) -> tuple[UOp|None, UOp|None]:
return grad_output if self.needs_input_grad[0] else None, \
grad_output if self.needs_input_grad[1] else None
class Mul(Function):
def forward(self, x:UOp, y:UOp) -> UOp:
self.x, self.y = x, y
return x * y
def backward(self, grad_output:UOp) -> tuple[UOp|None, UOp|None]:
return (self.y * grad_output) if self.needs_input_grad[0] else None, \
(self.x * grad_output) if self.needs_input_grad[1] else None
class IDiv(Function):
def forward(self, x:UOp, y:UOp) -> UOp: return x // y
class Mod(Function):
def forward(self, x:UOp, y:UOp) -> UOp: return x % y
# ************* ternary ops *************
class Where(Function):
def forward(self, x:UOp, y:UOp, z:UOp) -> UOp:
self.x = x
return self.x.where(y, z)
def backward(self, grad_output:UOp) -> tuple[None, UOp|None, UOp|None]:
return None, \
self.x.where(grad_output, grad_output.const_like(0)) if self.needs_input_grad[1] else None, \
self.x.where(grad_output.const_like(0), grad_output) if self.needs_input_grad[2] else None
# ************* reduce ops *************
class Sum(Function):
def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp:
self.input_shape = x.shape
return x.r(Ops.ADD, axis)
def backward(self, grad_output:UOp) -> UOp: return grad_output.expand(self.input_shape)
class Prod(Function):
def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp:
self.x, self.ret = x, x.r(Ops.MUL, axis)
return self.ret
def backward(self, grad_output:UOp) -> UOp:
return (grad_output * self.ret).expand(self.x.shape) / self.x
class Max(Function):
def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp:
self.x, self.ret, self.axis = x, x.r(Ops.MAX, axis), axis
return self.ret
def backward(self, grad_output:UOp) -> UOp:
# 1s in locations where the max was chosen (can be two locations)
max_is_1s = self.x.ne(self.ret.expand(self.x.shape)).ne(self.x.const_like(1).cast(dtypes.bool)).cast(grad_output.dtype)
div = max_is_1s.r(Ops.ADD, self.axis).expand(self.x.shape)
return (max_is_1s/div) * grad_output.expand(self.x.shape)
# ************* movement ops *************
# NOTE: this is sum in reverse
class Expand(Function):
def forward(self, x:UOp, shape:tuple[int, ...]) -> UOp:
self.expanded_axis = tuple(i for i, (si, so) in enumerate(zip(x.shape, shape)) if resolve(si != so))
return x.expand(shape)
def backward(self, grad_output:UOp) -> UOp:
return grad_output.cast(sum_acc_dtype(grad_output.dtype)).r(Ops.ADD, self.expanded_axis).cast(grad_output.dtype)
class Reshape(Function):
def forward(self, x:UOp, shape:tuple[int, ...]) -> UOp:
self.input_shape = x.shape
return x.reshape(shape)
def backward(self, grad_output:UOp) -> UOp: return grad_output.reshape(self.input_shape)
class Permute(Function):
def forward(self, x:UOp, order:tuple[int, ...]) -> UOp:
self.input_order = order
return x.permute(order)
def backward(self, grad_output:UOp) -> UOp: return grad_output.permute(argsort(self.input_order))
class Pad(Function):
def forward(self, x:UOp, arg:tuple[tuple[int, int], ...]) -> UOp:
self.narg = tuple([(p[0], s+p[0]) for s,p in zip(x.shape, arg)])
return x.pad(arg)
def backward(self, grad_output:UOp) -> UOp: return grad_output.shrink(self.narg)
class Shrink(Function):
def forward(self, x:UOp, arg:tuple[tuple[sint, sint], ...]) -> UOp:
self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)])
return x.shrink(arg)
def backward(self, grad_output:UOp) -> UOp: return grad_output.pad(self.narg)
class Flip(Function):
def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp:
self.arg = tuple([-1 if i in axis else 1 for i in range(len(x.shape))])
return x.stride(self.arg)
def backward(self, grad_output:UOp) -> UOp: return grad_output.stride(self.arg)