llama.cpp/scripts/fetch_server_test_models.py
Olivier Chafik 8b576b6c55
Tool call support (generic + native for Llama, Functionary, Hermes, Mistral, Firefunction, DeepSeek) w/ lazy grammars (#9639)
---------

Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
2025-01-30 19:13:58 +00:00

106 lines
3.9 KiB
Python
Executable File

#!/usr/bin/env python
'''
This script fetches all the models used in the server tests.
This is useful for slow tests that use larger models, to avoid them timing out on the model downloads.
It is meant to be run from the root of the repository.
Example:
python scripts/fetch_server_test_models.py
( cd examples/server/tests && ./tests.sh -v -x -m slow )
'''
import ast
import glob
import logging
import os
from typing import Generator
from pydantic import BaseModel
from typing import Optional
import subprocess
class HuggingFaceModel(BaseModel):
    '''
    A Hugging Face model reference (repository + optional file) found in the server tests.

    Instances are frozen, which also makes them hashable, so duplicates collected from
    several test files can be removed by putting them in a set.
    '''

    class Config:
        # Immutable instances (pydantic then provides __hash__/__eq__ on the field values).
        frozen = True

    # Repository id; handed to llama-cli via `-hfr` by the fetch loop in this script.
    hf_repo: str
    # File inside the repository (`-hff`); None when the test does not specify one.
    hf_file: Optional[str] = None
def collect_hf_model_test_parameters(test_file) -> Generator[HuggingFaceModel, None, None]:
    '''
    Statically extract the Hugging Face models referenced by a pytest test file.

    The file is parsed with `ast`; it is never imported or executed. Every
    `@<...>.parametrize(argnames, argvalues)` decorator on a (sync or async) function
    whose argnames include `hf_repo` contributes one HuggingFaceModel per argvalues
    entry. `hf_file` is filled in when it is also one of the argnames, else left None.

    argnames may be a comma-separated string (whitespace around names is ignored, as
    pytest does) or a list/tuple of strings, and both arguments may be passed
    positionally or as the `argnames=` / `argvalues=` keywords. Entries that cannot be
    evaluated statically (e.g. `pytest.param(...)`, variable references) or that do not
    have one value per argname are skipped with a warning.

    Args:
        test_file: path of the Python test file to scan.

    Yields:
        HuggingFaceModel instances; the same model may be yielded more than once.
    '''
    try:
        with open(test_file, encoding='utf-8') as f:
            tree = ast.parse(f.read())
    except (OSError, SyntaxError, ValueError) as e:
        # Best effort: one unreadable/unparsable file must not abort the whole scan.
        logging.error(f'collect_hf_model_test_parameters failed on {test_file}: {e}')
        return

    # Exceptions ast.literal_eval raises for nodes that are not plain literals.
    literal_errors = (ValueError, TypeError, SyntaxError)

    for node in ast.walk(tree):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        for dec in node.decorator_list:
            if not (isinstance(dec, ast.Call)
                    and isinstance(dec.func, ast.Attribute)
                    and dec.func.attr == 'parametrize'):
                continue

            # Accept both positional and keyword forms of parametrize(argnames, argvalues).
            keywords = {kw.arg: kw.value for kw in dec.keywords if kw.arg is not None}
            raw_names = dec.args[0] if len(dec.args) > 0 else keywords.get('argnames')
            raw_param_values = dec.args[1] if len(dec.args) > 1 else keywords.get('argvalues')
            if raw_names is None or raw_param_values is None:
                continue

            try:
                names = ast.literal_eval(raw_names)
            except literal_errors:
                continue
            if isinstance(names, str):
                # pytest splits on ',' and strips whitespace, so "a, hf_repo" names "hf_repo".
                param_names = [n.strip() for n in names.split(',') if n.strip()]
            elif isinstance(names, (list, tuple)) and all(isinstance(n, str) for n in names):
                param_names = [n.strip() for n in names]
            else:
                continue
            if 'hf_repo' not in param_names:
                continue

            if not isinstance(raw_param_values, (ast.List, ast.Tuple)):
                logging.warning(f'Skipping non-list parametrize entry at {test_file}:{node.lineno}')
                continue

            hf_repo_idx = param_names.index('hf_repo')
            hf_file_idx = param_names.index('hf_file') if 'hf_file' in param_names else None
            single_name = len(param_names) == 1

            for t in raw_param_values.elts:
                if single_name:
                    # With a single argname pytest uses each value as-is (not tuple-wrapped).
                    values = [t]
                elif isinstance(t, (ast.Tuple, ast.List)):
                    values = t.elts
                else:
                    logging.warning(f'Skipping non-tuple parametrize entry at {test_file}:{node.lineno}')
                    continue
                if len(values) != len(param_names):
                    logging.warning(
                        f'Skipping parametrize entry with {len(values)} values for '
                        f'{len(param_names)} names at {test_file}:{node.lineno}')
                    continue
                try:
                    hf_repo = ast.literal_eval(values[hf_repo_idx])
                    hf_file = ast.literal_eval(values[hf_file_idx]) if hf_file_idx is not None else None
                except literal_errors:
                    logging.warning(f'Skipping non-literal parametrize entry at {test_file}:{node.lineno}')
                    continue
                if not isinstance(hf_repo, str) or not (hf_file is None or isinstance(hf_file, str)):
                    logging.warning(
                        f'Skipping parametrize entry with non-string hf_repo/hf_file at {test_file}:{node.lineno}')
                    continue
                yield HuggingFaceModel(hf_repo=hf_repo, hf_file=hf_file)
if __name__ == '__main__':
    # Script entry point: collect every hf_repo/hf_file pair parametrized in the server unit
    # tests, then run llama-cli once per model so that it downloads (and caches) the model.
    import sys

    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

    # De-duplicate (HuggingFaceModel is frozen, hence hashable) and sort for stable output.
    # hf_file may be None, which cannot be compared with a str: within a repo, order the
    # "no file" entry first instead of raising TypeError.
    models = sorted(
        {
            model
            for test_file in glob.glob('examples/server/tests/unit/test_*.py')
            for model in collect_hf_model_test_parameters(test_file)
        },
        key=lambda m: (m.hf_repo, m.hf_file is not None, m.hf_file or ''))

    logging.info(f'Found {len(models)} models in parameterized tests:')
    for m in models:
        logging.info(f' - {m.hf_repo} / {m.hf_file}')

    # LLAMA_CLI_BIN_PATH overrides the llama-cli location. LLAMA_SERVER_BIN_PATH is still
    # honored for backward compatibility.
    # NOTE(review): whichever variable is used, it must point at llama-cli. The command below
    # passes llama-cli flags (-p, -n, -no-cnv) and waits for the process to exit.
    cli_path = (
        os.environ.get('LLAMA_CLI_BIN_PATH')
        or os.environ.get('LLAMA_SERVER_BIN_PATH')
        or os.path.join(
            os.path.dirname(__file__),
            '../build/bin/Release/llama-cli.exe' if os.name == 'nt' else '../build/bin/llama-cli'))

    for m in models:
        # Entries containing '<' are skipped (presumably placeholder values, not real names).
        if '<' in m.hf_repo or (m.hf_file is not None and '<' in m.hf_file):
            continue
        if m.hf_file is not None and '-of-' in m.hf_file:
            logging.warning(f'Skipping model at {m.hf_repo} / {m.hf_file} because it is a split file')
            continue
        logging.info(f'Using llama-cli to ensure model {m.hf_repo}/{m.hf_file} was fetched')
        # Generate a single token from a tiny prompt: enough to force the download.
        cmd = [
            cli_path,
            '-hfr', m.hf_repo,
            *([] if m.hf_file is None else ['-hff', m.hf_file]),
            '-n', '1',
            '-p', 'Hey',
            '--no-warmup',
            '--log-disable',
            '-no-cnv']
        # `-fa` is not passed for these models. NOTE(review): the reason is not visible here.
        if m.hf_file != 'tinyllamas/stories260K.gguf' and 'Mistral-Nemo' not in m.hf_repo:
            cmd.append('-fa')
        try:
            subprocess.check_call(cmd)
        except subprocess.CalledProcessError:
            logging.error(f'Failed to fetch model at {m.hf_repo} / {m.hf_file} with command:\n {" ".join(cmd)}')
            sys.exit(1)
        except OSError as e:
            # e.g. no executable at cli_path: report it instead of dumping a traceback.
            logging.error(f'Could not run {cli_path}: {e}. Build llama-cli or set LLAMA_CLI_BIN_PATH.')
            sys.exit(1)