2025-04-19 08:05:49 +09:00

29 lines
1.4 KiB
Python

import yaml, argparse
from pathlib import Path
from huggingface_hub import snapshot_download
def download_models(yaml_file: str, download_dir: str) -> None:
    """Download every repository listed in a YAML manifest from the Hugging Face Hub.

    The manifest is expected to have a top-level ``repositories`` mapping.
    Each key is a Hub repo id and each value is a mapping holding a ``files``
    list; every entry of that list provides a ``file`` key that is used as a
    download pattern. After a repository is fetched, the local snapshot
    directory is stored under its ``download_path`` key, and once all
    repositories are done the manifest is written back to ``yaml_file``.

    NOTE: the manifest is re-serialised with ``yaml.dump``; key order is kept
    (``sort_keys=False``) but comments in the original file are not.

    Args:
        yaml_file: Path to the YAML manifest; it is read and then overwritten.
        download_dir: Directory passed to ``snapshot_download`` as ``cache_dir``.

    Raises:
        KeyError: If ``repositories``, ``files`` or ``file`` is missing.
    """
    # Read as UTF-8 explicitly: the platform default encoding (e.g. cp1252 on
    # Windows) would otherwise break on non-ASCII content.
    with open(yaml_file, "r", encoding="utf-8") as f:
        metadata = yaml.safe_load(f)

    repositories = metadata["repositories"]
    n = len(repositories)
    for i, (model_id, model_data) in enumerate(repositories.items(), start=1):
        print(f"Downloading {i}/{n}: {model_id}...")
        allow_patterns = [file_info["file"] for file_info in model_data["files"]]
        # Download configs too (the sizes are small). Requesting them in the
        # same call avoids a second round-trip to the Hub and guarantees the
        # model files and configs come from the same snapshot revision.
        allow_patterns.append("*config.json")
        root_path = Path(
            snapshot_download(
                repo_id=model_id,
                allow_patterns=allow_patterns,
                cache_dir=download_dir,
            )
        )
        print(f"Downloaded model files to: {root_path}")
        model_data["download_path"] = str(root_path)

    # Save the updated metadata back to the YAML file
    with open(yaml_file, "w", encoding="utf-8") as f:
        yaml.dump(metadata, f, sort_keys=False)
    print("Download completed according to YAML file.")
if __name__ == "__main__":
    # Script entry point: take the manifest path from the command line and
    # download everything into a "models" directory next to this file.
    cli = argparse.ArgumentParser(
        description="Download models from Huggingface Hub based on a YAML configuration file.",
    )
    cli.add_argument(
        "input",
        type=str,
        help="Path to the input YAML configuration file containing model information.",
    )
    options = cli.parse_args()

    # Create the target directory up front so the download has somewhere to go.
    target_dir = Path(__file__).parent.joinpath("models")
    target_dir.mkdir(parents=True, exist_ok=True)

    download_models(options.input, str(target_dir))