Account for deprecated GGML parameters

This commit is contained in:
jllllll 2023-08-26 14:07:46 -05:00
parent 4a999e3bcd
commit 4d61a7d9da
No known key found for this signature in database
GPG Key ID: 7FCD00C417935797
2 changed files with 14 additions and 4 deletions

View File

@ -203,11 +203,16 @@ class LlamacppHF(PreTrainedModel):
'rope_freq_base': RoPE.get_rope_freq_base(shared.args.alpha_value, shared.args.rope_freq_base),
'tensor_split': tensor_split_list,
'rope_freq_scale': 1.0 / shared.args.compress_pos_emb,
'n_gqa': shared.args.n_gqa or None,
'rms_norm_eps': shared.args.rms_norm_eps or None,
'logits_all': True,
}
if not is_gguf(model_file):
ggml_params = {
'n_gqa': shared.args.n_gqa or None,
'rms_norm_eps': shared.args.rms_norm_eps or None,
}
params = params | ggml_params
Llama = llama_cpp_lib(model_file).Llama
model = Llama(**params)

View File

@ -92,10 +92,15 @@ class LlamaCppModel:
'rope_freq_base': RoPE.get_rope_freq_base(shared.args.alpha_value, shared.args.rope_freq_base),
'tensor_split': tensor_split_list,
'rope_freq_scale': 1.0 / shared.args.compress_pos_emb,
'n_gqa': shared.args.n_gqa or None,
'rms_norm_eps': shared.args.rms_norm_eps or None,
}
if not is_gguf(str(path)):
ggml_params = {
'n_gqa': shared.args.n_gqa or None,
'rms_norm_eps': shared.args.rms_norm_eps or None,
}
params = params | ggml_params
result.model = Llama(**params)
if cache_capacity > 0:
result.model.set_cache(LlamaCache(capacity_bytes=cache_capacity))