import numpy as np from tinygrad.tensor import Tensor from tinygrad.helpers import dtypes from tinygrad.ops import Device # TODO: will be better when tinygrad does math in the target dtype, can remove the floor and use a mul def bit_extract(x, s, e) -> Tensor: # extract the top bits we don't want top_bits = (x / (1<<(s+1))).floor() * (1<<(s+1)) x = (x - top_bits) / (1<