137 lines
6.3 KiB
Python
Raw Normal View History

2025-04-18 20:38:55 +09:00
import onnx, yaml, tempfile, time, collections, pprint, argparse, json
from pathlib import Path
from tinygrad.frontend.onnx import OnnxRunner
from extra.onnx import get_onnx_ops
from extra.onnx_helpers import validate, get_example_inputs
def get_config(root_path: Path):
    """Merge every ``*config.json`` file found under *root_path* into one dict.

    Files are visited in sorted path order so the result is deterministic; when two
    files define the same key, the one visited later wins. Files whose top-level JSON
    value is not an object are skipped.

    Args:
        root_path: directory to search recursively.

    Returns:
        The merged key/value mapping (empty if no matching file exists).
    """
    ret = {}
    # rglob order depends on the filesystem; sort so key collisions resolve the same way every run
    for path in sorted(root_path.rglob("*config.json")):
        # read_text closes the file right away (json.load(path.open()) leaked the handle)
        config = json.loads(path.read_text(encoding="utf-8"))
        if isinstance(config, dict):
            ret.update(config)
    return ret
def run_huggingface_validate(onnx_model_path, config, rtol, atol):
    """Build example inputs for an ONNX model and run ``validate`` on it.

    Args:
        onnx_model_path: path of the ``.onnx`` file to check.
        config: merged HuggingFace config used by ``get_example_inputs`` to shape the inputs.
        rtol: relative tolerance forwarded to ``validate``.
        atol: absolute tolerance forwarded to ``validate``.
    """
    # the runner is only needed here to discover the graph's declared inputs
    runner = OnnxRunner(onnx.load(onnx_model_path))
    example_inputs = get_example_inputs(runner.graph_inputs, config)
    validate(onnx_model_path, example_inputs, rtol=rtol, atol=atol)
def get_tolerances(file_name):
    """Pick ``(rtol, atol)`` for validating a model, based on markers in its file name.

    - names containing "fp16" get (9e-2, 9e-2);
    - otherwise names containing "int8", "uint8" or "quantized" get (4, 4);
    - everything else gets (4e-3, 3e-2).
    """
    # TODO very high rtol atol
    quantized_markers = ("int8", "uint8", "quantized")
    if "fp16" in file_name:
        rtol = atol = 9e-2
    elif any(marker in file_name for marker in quantized_markers):
        rtol = atol = 4
    else:
        rtol, atol = 4e-3, 3e-2
    return rtol, atol
def validate_repos(models: dict[str, tuple[Path, Path]]):
    """Validate every model in *models*, printing progress and per-model timing.

    Args:
        models: maps a display id (the script builds ``"<repo id>/<file>"``) to
            ``(root_path, relative_path)``, where ``root_path`` is the downloaded repo
            directory and ``relative_path`` is the ``.onnx`` file inside it.
    """
    # Count the argument, not the script-level ``model_paths`` global: that name only
    # exists when this file runs as __main__ and need not match *models*.
    print(f"** Validating {len(models)} models **")
    for model_id, (root_path, relative_path) in models.items():
        print(f"validating model {model_id}")
        model_path = root_path / relative_path
        config = get_config(root_path)
        rtol, atol = get_tolerances(model_path.stem)
        # perf_counter is monotonic and meant for measuring elapsed time
        start = time.perf_counter()
        run_huggingface_validate(model_path, config, rtol, atol)
        print(f"passed, took {time.perf_counter() - start:.2f}s")
def retrieve_op_stats(models: dict[str, tuple[Path, Path]]) -> dict:
    """Count ONNX op usage across *models* and record which ops are unsupported.

    Args:
        models: maps a display id to ``(root_path, relative_path)`` of an ``.onnx`` file.

    Returns:
        A dict with two keys:
          - ``"unsupported_ops"``: op name -> sorted list of model ids using an op that
            is not in ``get_onnx_ops()``;
          - ``"op_counter"``: list of ``(op name, count)`` pairs, most frequent first.
    """
    op_counter = collections.Counter()
    unsupported_ops = collections.defaultdict(set)
    supported_ops = get_onnx_ops()
    # Count the argument, not the script-level ``model_paths`` global.
    print(f"** Retrieving stats from {len(models)} models **")
    for model_id, (root_path, relative_path) in models.items():
        print(f"examining {model_id}")
        onnx_runner = OnnxRunner(onnx.load(root_path / relative_path))
        for node in onnx_runner.graph_nodes:
            op_counter[node.op] += 1
            if node.op not in supported_ops:
                unsupported_ops[node.op].add(model_id)
        # release this runner before the next model is loaded
        del onnx_runner
    return {
        # sorted so the printed report is stable across runs (set order is arbitrary)
        "unsupported_ops": {op: sorted(ids) for op, ids in unsupported_ops.items()},
        "op_counter": op_counter.most_common(),
    }
def debug_run(model_path, truncate, config, rtol, atol):
    """Validate *model_path*, optionally after cutting the graph down to a prefix.

    Args:
        model_path: the ``.onnx`` file; its ``.suffix`` is read, so a ``Path`` is expected.
        truncate: ``-1`` validates the whole model. Otherwise only nodes
            ``0..truncate`` (inclusive) are kept, and the outputs of the last kept node
            become the graph outputs, so that intermediate values get compared.
        config: merged HuggingFace config passed on to the validator.
        rtol: relative tolerance passed on to the validator.
        atol: absolute tolerance passed on to the validator.
    """
    if truncate == -1:
        run_huggingface_validate(model_path, config, rtol, atol)
        return
    model = onnx.load(model_path)
    kept_nodes = list(model.graph.node)[:truncate + 1]
    # expose the last kept node's outputs so the validator compares them
    last_outputs = [onnx.helper.make_empty_tensor_value_info(name) for name in kept_nodes[-1].output]
    model.graph.ClearField("node")
    model.graph.node.extend(kept_nodes)
    model.graph.ClearField("output")
    model.graph.output.extend(last_outputs)
    # the truncated graph is written to a temporary file that is deleted after validation
    with tempfile.NamedTemporaryFile(suffix=model_path.suffix) as tmp:
        onnx.save(model, tmp.name)
        run_huggingface_validate(tmp.name, config, rtol, atol)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Huggingface ONNX Model Validator and Ops Checker")
    # Optional so that --debug really works "without explicitly needing a YAML", as its help says.
    parser.add_argument("input", type=str, nargs="?", default=None,
                        help="Path to the input YAML configuration file containing model information.")
    parser.add_argument("--check_ops", action="store_true", default=False,
                        help="Check support for ONNX operations in models from the YAML file")
    parser.add_argument("--validate", action="store_true", default=False,
                        help="Validate correctness of models from the YAML file")
    parser.add_argument("--debug", type=str, default="",
                        help="""Validates without explicitly needing a YAML or models pre-installed.
provide repo id (e.g. "minishlab/potion-base-8M") to validate all onnx models inside the repo
provide onnx model path (e.g. "minishlab/potion-base-8M/onnx/model.onnx") to validate only that one model
""")
    parser.add_argument("--truncate", type=int, default=-1, help="Truncate the ONNX model so intermediate results can be validated")
    args = parser.parse_args()

    # Reject bad flag combinations up front, before any download or model loading.
    if not (args.check_ops or args.validate or args.debug):
        parser.error("Please provide either --validate, --check_ops, or --debug.")
    if args.truncate != -1 and not args.debug:
        parser.error("--truncate and --debug should be used together for debugging")
    if (args.check_ops or args.validate) and args.input is None:
        parser.error("--check_ops and --validate require the input YAML file")
    if args.debug:
        path: list[str] = args.debug.split("/")
        if len(path) == 2:
            # repo mode validates every model untruncated, so a truncate value would be silently ignored
            if args.truncate != -1:
                parser.error("--truncate requires --debug to name a single .onnx model, not a whole repo")
        elif len(path) < 3 or not path[-1].endswith(".onnx"):
            parser.error('--debug expects a repo id ("org/repo") or an onnx model path ("org/repo/.../model.onnx")')

    if args.check_ops or args.validate:
        with open(args.input, encoding="utf-8") as f:
            data = yaml.safe_load(f)
        if any(repo.get("download_path") is None for repo in data["repositories"].values()):
            parser.error("please run `download_models.py` for this yaml")
        model_paths = {
            model_id + "/" + model["file"]: (Path(repo["download_path"]), Path(model["file"]))
            for model_id, repo in data["repositories"].items()
            for model in repo["files"]
            if model["file"].endswith(".onnx")
        }
        if args.check_ops:
            pprint.pprint(retrieve_op_stats(model_paths))
        if args.validate:
            validate_repos(model_paths)

    if args.debug:
        from huggingface_hub import snapshot_download
        download_dir = Path(__file__).parent / "models"
        if len(path) == 2:
            # repo id: validate all onnx models inside the repo
            repo_id = "/".join(path)
            # "*.onnx_data" (the old ".onnx_data" only matched a file literally named so)
            # fetches external-data weight files stored next to the models
            root_path = Path(snapshot_download(repo_id=repo_id, allow_patterns=["*.onnx", "*.onnx_data", "*config.json"],
                                               cache_dir=download_dir))
            config = get_config(root_path)
            for onnx_model in root_path.rglob("*.onnx"):
                rtol, atol = get_tolerances(onnx_model.name)
                print(f"validating {onnx_model.relative_to(root_path)} with truncate={args.truncate}, {rtol=}, {atol=}")
                debug_run(onnx_model, -1, config, rtol, atol)
        else:
            # model path: only validate the specified onnx model
            onnx_model = path[-1]
            repo_id, relative_path = "/".join(path[:2]), "/".join(path[2:])
            # also request "<model>.onnx_data" in case the weights are stored as external data
            root_path = Path(snapshot_download(repo_id=repo_id, allow_patterns=[relative_path, f"{relative_path}_data", "*config.json"],
                                               cache_dir=download_dir))
            config = get_config(root_path)
            rtol, atol = get_tolerances(onnx_model)
            print(f"validating {relative_path} with truncate={args.truncate}, {rtol=}, {atol=}")
            debug_run(root_path / relative_path, args.truncate, config, rtol, atol)