40 lines
1.6 KiB
Python
40 lines
1.6 KiB
Python
import sys, onnx, time
|
|
from tinygrad import Tensor, TinyJit, Device, GlobalCounters, fetch
|
|
from tinygrad.tensor import _from_np_dtype
|
|
from extra.onnx import get_run_onnx
|
|
|
|
def _random_inputs(input_shapes, input_types):
  """Build one fresh, realized random Tensor per graph input.

  input_shapes maps input name -> shape tuple, input_types maps input name -> numpy dtype.
  Each tensor is `Tensor.randn` of that shape, created with the matching tinygrad dtype
  and multiplied by 8. Inputs are built in sorted-name order. They are realized here, before
  the caller starts its timer, so creating them is not part of any timed region.
  """
  return {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())}

if __name__ == "__main__":
  if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} <onnx model path or url>")
  # NOTE(review): fetch() presumably downloads a URL (or passes a local path through) and returns a local path
  onnx_file = fetch(sys.argv[1])
  print(onnx_file)
  onnx_model = onnx.load(onnx_file)
  # inference only: no gradient tracking, not in training mode
  Tensor.no_grad = True
  Tensor.training = False
  run_onnx = get_run_onnx(onnx_model)
  print("loaded model")

  # graph inputs that are also initializers already have values; only the rest must be fed
  initted_tensors = {inp.name for inp in onnx_model.graph.initializer}
  expected_inputs = [inp for inp in onnx_model.graph.input if inp.name not in initted_tensors]

  # static shape and numpy dtype of every real input. A dimension without a concrete
  # dim_value (e.g. a symbolic dim given via dim_param) reads back as 0 and would yield an
  # empty tensor, so such dimensions are given size 1.
  input_shapes = {inp.name:tuple(d.dim_value if d.HasField("dim_value") else 1 for d in inp.type.tensor_type.shape.dim)
                  for inp in expected_inputs}
  input_types = {inp.name:onnx.helper.tensor_dtype_to_np_dtype(inp.type.tensor_type.elem_type) for inp in expected_inputs}

  # JIT-wrapped step: move the inputs to the default device, run the graph, return its first output
  run_onnx_jit = TinyJit(lambda **kwargs: next(iter(run_onnx({k:v.to(Device.DEFAULT) for k,v in kwargs.items()}).values())), prune=True)

  # warm-up runs, untimed; NOTE(review): presumably lets TinyJit capture the graph before it is timed
  for i in range(3):
    new_inputs = _random_inputs(input_shapes, input_types)
    GlobalCounters.reset()
    print(f"run {i}")
    run_onnx_jit(**new_inputs)

  # run 20 times
  for _ in range(20):
    new_inputs = _random_inputs(input_shapes, input_types)
    GlobalCounters.reset()
    st = time.perf_counter()
    out = run_onnx_jit(**new_inputs)
    mt = time.perf_counter()
    # copy the result to host; presumably this waits for the queued device work, so "total" covers execution
    out.numpy()
    et = time.perf_counter()
    print(f"enqueue {(mt-st)*1e3:6.2f} ms -- total run {(et-st)*1e3:6.2f} ms")