carrot/tinygrad_repo/examples/test_onnx_imagenet.py
carrot efee1712aa
KerryGoldModel, AGNOS12.3, ButtonMode3, autoDetectLFA2, (#181)
* fix.. speed_limit error...

* draw tpms settings.

* fix.. traffic light stopping only..

* fix.. waze cam

* fix.. waze...

* add setting (Enable comma connect )

* auto detect LFA2

* fix.. cruisespeed1

* vff2 driving model.

* fix..

* agnos 12.3

* fix..

* ff

* ff

* test

* ff

* fix.. drawTurnInfo..

* Update drive_helpers.py

* fix..

support English voice

eng sounds

fix settings... english

fix.. mph..

fix.. roadlimit speed bug..

* new vff model.. 250608

* fix soundd..

* fix safe exit speed..

* fix.. sounds.

* fix.. radar timeStep..

* KerryGold model

* Update drive_helpers.py

* fix.. model.

* fix..

* fix..

* Revert "fix.."

This reverts commit b09ec459afb855c533d47fd7e8a1a6b1a09466e7.

* Revert "fix.."

This reverts commit 290bec6b83a4554ca232d531a911edccf94a2156.

* fix esim

* add more acc table. 10kph

* kg update..

* fix cruisebutton mode3

* test atc..cond.

* fix.. canfd

* fix.. angle control limit
2025-06-13 15:59:36 +09:00

83 lines
3.7 KiB
Python

import os, random, sys

import numpy as np
from PIL import Image

from tinygrad import Tensor, dtypes, GlobalCounters
from tinygrad.helpers import fetch, getenv
from extra.datasets.imagenet import get_imagenet_categories, get_val_files, center_crop
from examples.benchmark_onnx import load_onnx_model
# works:
# ~70% - https://github.com/onnx/models/raw/refs/heads/main/validated/vision/classification/resnet/model/resnet50-v2-7.onnx
# ~43% - https://github.com/onnx/models/raw/refs/heads/main/Computer_Vision/alexnet_Opset16_torch_hub/alexnet_Opset16.onnx
# ~72% - https://github.com/xamcat/mobcat-samples/raw/refs/heads/master/onnx_runtime/InferencingSample/InferencingSample/mobilenetv2-7.onnx
# ~71% - https://github.com/axinc-ai/onnx-quantization/raw/refs/heads/main/models/mobilenetv2_1.0.opt.onnx
# ~67% - https://github.com/xamcat/mobcat-samples/raw/refs/heads/master/onnx_runtime/InferencingSample/InferencingSample/mobilenetv2-7-quantized.onnx
# broken:
# https://github.com/MTlab/onnx2caffe/raw/refs/heads/master/model/MobileNetV2.onnx
# https://huggingface.co/qualcomm/MobileNet-v2-Quantized/resolve/main/MobileNet-v2-Quantized.onnx
# ~35% - https://github.com/axinc-ai/onnx-quantization/raw/refs/heads/main/models/mobilenev2_quantized.onnx
# QUANT=1 python3 examples/test_onnx_imagenet.py https://github.com/xamcat/mobcat-samples/raw/refs/heads/master/onnx_runtime/InferencingSample/InferencingSample/mobilenetv2-7.onnx
# DONT_REALIZE_EXPAND=1 python3 examples/test_onnx_imagenet.py /tmp/model.quant.onnx
# VIZ=1 DONT_REALIZE_EXPAND=1 python3 examples/benchmark_onnx.py /tmp/model.quant.onnx
def imagenet_dataloader(cnt=0):
  """Yield (img, y) pairs from the ImageNet validation files in random order.

  Each image is opened with PIL, converted to RGB if needed, passed through `center_crop`, and turned
  into a Tensor of shape (1, 3, 224, 224), scaled to [0, 1], then normalized with the per-channel
  mean/std below. `y` is the entry of `get_imagenet_categories()` for the image's parent directory name.

  cnt: number of images to yield. If cnt <= 0 (the default is 0), the whole shuffled set is yielded.
  Previously the default sliced to an empty list and yielded nothing.
  """
  input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
  input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1)
  # Shuffle a copy. The list returned by get_val_files() may be shared (e.g. cached), and other
  # callers should not see it reordered.
  files = list(get_val_files())
  random.shuffle(files)
  if cnt > 0: files = files[:cnt]
  cir = get_imagenet_categories()
  for fn in files:
    # `with` closes the file handle once the cropped pixels have been copied into numpy.
    with Image.open(fn) as src:
      rgb = src.convert('RGB') if src.mode != "RGB" else src
      pixels = np.array(center_crop(rgb))
    # HWC uint8 -> NCHW float32, then per-channel normalization
    img = Tensor(pixels).permute(2,0,1).reshape(1,3,224,224)
    img = ((img.cast(dtypes.float32)/255.0) - input_mean) / input_std
    # The class folder is the file's parent directory. os.path also handles non-"/" separators.
    y = cir[os.path.basename(os.path.dirname(fn))]
    yield img, y
if __name__ == "__main__":
  # Run an ONNX ImageNet classifier with tinygrad and print the running top-1 accuracy per image.
  # Env vars:
  #   QUANT=1    quantize the model with onnxruntime first; DYNAMIC=1 selects dynamic quantization,
  #              otherwise static QDQ quantization calibrated on 1000 validation images.
  #   CNT=n      number of validation images to score (default 100).
  if len(sys.argv) < 2: sys.exit(f"usage: {sys.argv[0]} <model.onnx>")
  fn = sys.argv[1]
  if getenv("QUANT"):
    import onnxruntime
    from onnxruntime.quantization import quantize_dynamic, quantize_static, QuantFormat, QuantType, CalibrationDataReader
    model_fp32 = fetch(fn)
    fn = '/tmp/model.quant.onnx'
    if getenv("DYNAMIC"):
      quantize_dynamic(model_fp32, fn)
    else:
      # Calibration feeds must be keyed by the model's real input name, and not every model calls it
      # "input". Read the name from the fp32 model instead of hard-coding it.
      calib_input = onnxruntime.InferenceSession(str(model_fp32), providers=["CPUExecutionProvider"]).get_inputs()[0].name
      class ImagenetReader(CalibrationDataReader):
        """Calibration reader that serves up to 1000 preprocessed ImageNet validation images."""
        def __init__(self):
          self.iter = imagenet_dataloader(cnt=1000)
        def get_next(self) -> dict|None:
          # Returning None tells the quantizer there is no more calibration data.
          try: img, _ = next(self.iter)
          except StopIteration: return None
          return {calib_input: img.numpy()}
      quantize_static(model_fp32, fn, ImagenetReader(), quant_format=QuantFormat.QDQ, per_channel=False,
                      activation_type=QuantType.QUInt8, weight_type=QuantType.QUInt8,
                      extra_options={"ActivationSymmetric": False})
  run_onnx_jit, input_specs = load_onnx_model(fetch(fn))
  # Feed images through the model's first declared input, which must be 3x224x224 per sample.
  t_name, t_spec = next(iter(input_specs.items()))
  assert t_spec.shape[1:] == (3,224,224), f"shape is {t_spec.shape}"
  cnt = getenv("CNT", 100)
  hit = seen = 0
  for img, y in imagenet_dataloader(cnt):
    GlobalCounters.reset()
    p = run_onnx_jit(**{t_name:img})
    assert p.shape == (1,1000)
    t = p.to('cpu').argmax().item()
    seen += 1
    hit += y == t
    print(f"target: {y:3d} pred: {t:3d} acc: {hit/seen*100:.2f}%")
  MS_TARGET = 13.4
  # GlobalCounters is reset before each inference, so this uses the op count of the last forward pass.
  print(f"need {GlobalCounters.global_ops/1e9*(1000/MS_TARGET):.2f} GFLOPS for {MS_TARGET:.2f} ms")
  # Gate on images actually run, not images requested (the dataset may be smaller than CNT).
  # Presumably the JIT needs at least two calls before it is worth pickling -- TODO confirm.
  if seen >= 2:
    import pickle
    with open("/tmp/im.pkl", "wb") as f: pickle.dump(run_onnx_jit, f)