63 lines
30 KiB
Python
63 lines
30 KiB
Python
![]() |
# ruff: noqa: E501
|
||
|
import os
|
||
|
# Enable VALIDATE_HCQ; this is placed before the tinygrad imports below, presumably so the runtime sees it at import time (verify).
os.environ.update({"VALIDATE_HCQ": "1"})
|
||
|
|
||
|
import unittest, random
|
||
|
import numpy as np
|
||
|
from tinygrad.codegen.kernel import Kernel, KernelOptError
|
||
|
from tinygrad.device import is_dtype_supported
|
||
|
from tinygrad.ops import UOp, Ops
|
||
|
from tinygrad.engine.search import Opt, OptOps
|
||
|
from tinygrad import Device, dtypes, Tensor
|
||
|
from test.external.fuzz_linearizer import compare_linearizer, compare_states, get_fuzz_rawbuf_like
|
||
|
|
||
|
from tinygrad.shape.shapetracker import ShapeTracker
|
||
|
from tinygrad.shape.view import View
|
||
|
|
||
|
def helper_test_lin(lin: Kernel, opts, failed_platforms, validate_device, rtol=1e-2, atol=1e-2):
  """Optimize `lin`, run it on Device.DEFAULT, then re-run it with `validate_device`'s renderer and cross-check.

  Every outcome is checked against `failed_platforms`: passing on a listed platform is an "unexpected success",
  and failing on an unlisted platform is a real failure.

  Args:
    lin: the kernel under test; each opt is applied to it via `lin.apply_opt` before it is run.
    opts: optimizations to apply, in order.
    failed_platforms: device names on which this kernel is currently expected to fail.
    validate_device: device whose renderer and `dname` are used for the validation run.
    rtol, atol: tolerances forwarded to both `compare_linearizer` calls.

  Returns:
    The optimized kernel, or None when the test was skipped (half/bfloat16 unsupported) or the opts were rejected.
  """
  # skip kernels using half/bfloat16 buffers when is_dtype_supported reports that dtype as unsupported
  if any(b.dtype.base == dtypes.half for b in lin.membufs) and not is_dtype_supported(dtypes.half): return
  if any(b.dtype.base == dtypes.bfloat16 for b in lin.membufs) and not is_dtype_supported(dtypes.bfloat16): return

  for opt in opts:
    try:
      lin.apply_opt(opt)
    except KernelOptError:
      # it's considered fixed if we invalidated the opts
      assert Device.DEFAULT not in failed_platforms, f"unexpected success on {Device.DEFAULT}"
      return

  (msg, rawbufs, var_vals, ground_truth, state1) = compare_linearizer(lin, rtol=rtol, atol=atol)
  if msg in ["PASS", "KernelOptError"]:
    # it's considered fixed if we invalidated the opts
    assert Device.DEFAULT not in failed_platforms, f"unexpected success on {Device.DEFAULT}"
    # Rejected opts leave nothing to cross-check. Validating anyway would typically hit the failure branch below,
    # which asserts the opposite of the assertion just made. Stop here, as the opt loop above does.
    if msg == "KernelOptError": return
  else:
    assert Device.DEFAULT in failed_platforms, f"failed on {Device.DEFAULT} with {msg}"

  # Re-run a copy of the kernel with the validation device's renderer.
  # Inputs are copies of rawbufs (copy=True) forced onto that device, and var_vals/ground_truth come from the first run.
  validate_lin = lin.copy()
  validate_lin.opts = validate_device.renderer
  validate_rawbufs = [get_fuzz_rawbuf_like(x, copy=True, force_device=validate_device.dname) for x in rawbufs]
  (validate_msg, _, _, _, state2) = compare_linearizer(validate_lin, validate_rawbufs, var_vals, ground_truth, rtol=rtol, atol=atol)

  if validate_msg == "PASS" and compare_states(state1, state2):
    assert Device.DEFAULT not in failed_platforms, f"unexpected success on {Device.DEFAULT}"
  else:
    # Report the validation run's outcome. The default-device `msg` may well be "PASS" here.
    reason = validate_msg if validate_msg != "PASS" else f"compare_states mismatch vs {validate_device.dname}"
    assert Device.DEFAULT in failed_platforms, f"failed on {Device.DEFAULT} with {reason}"

  return lin
|
||
|
|
||
|
class TestHCQFuzzFailures(unittest.TestCase):
|
||
|
def setUp(self):
  """Reseed Python's `random`, NumPy, and tinygrad's Tensor RNG with 42 before every test for reproducibility."""
  seed = 42
  for reseed in (random.seed, np.random.seed, Tensor.manual_seed):
    reseed(seed)
|
||
|
|
||
|
@unittest.skipUnless(Device.DEFAULT in {"QCOM"}, "for QCOM")
|
||
|
def test_failure_1(self):
|
||
|
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 2, 4)), arg=1, src=()), x39:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=0, mask=((0, 1), (0, 6)), 
contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), x39,)),)),)), UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||
|
|
||
|
opts = [Opt(op=OptOps.UPCAST, axis=0, amt=4)]
|
||
|
helper_test_lin(Kernel(ast), opts, failed_platforms=[], validate_device=Device["GPU"])
|
||
|
|
||
|
if __name__ == "__main__":
  # when executed directly, let unittest discover and run the TestCase classes in this module
  unittest.main()
|