use dict for kv overrides

This commit is contained in:
phiharri 2024-02-13 13:13:50 +00:00
parent 673361bffc
commit e886490101
2 changed files with 32 additions and 2 deletions

View File

@ -201,6 +201,21 @@ class LlamacppHF(PreTrainedModel):
else:
tensor_split_list = [float(x) for x in shared.args.tensor_split.strip().split(",")]
if shared.args.kv_overrides is None or shared.args.kv_overrides.strip() == '':
kv_overrides = None
else:
kv_overrides = {}
for k, tv in [x.split('=') for x in shared.args.kv_overrides.split()]:
t, v = tv.split(':')
if t == 'int':
kv_overrides[k] = int(v)
elif t == 'float':
kv_overrides[k] = float(v)
elif t == 'bool':
kv_overrides[k] = bool(v)
else:
raise ValueError('Invalid type for KV override')
params = {
'model_path': str(model_file),
'n_ctx': shared.args.n_ctx,
@ -218,7 +233,7 @@ class LlamacppHF(PreTrainedModel):
'logits_all': shared.args.logits_all,
'offload_kqv': not shared.args.no_offload_kqv,
'split_mode': 1 if not shared.args.row_split else 2,
'kv_overrides': shared.args.kv_overrides,i
'kv_overrides': kv_overrides,
}
Llama = llama_cpp_lib().Llama

View File

@ -81,6 +81,21 @@ class LlamaCppModel:
else:
tensor_split_list = [float(x) for x in shared.args.tensor_split.strip().split(",")]
if shared.args.kv_overrides is None or shared.args.kv_overrides.strip() == '':
kv_overrides = None
else:
kv_overrides = {}
for k, tv in [x.split('=') for x in shared.args.kv_overrides.split()]:
t, v = tv.split(':')
if t == 'int':
kv_overrides[k] = int(v)
elif t == 'float':
kv_overrides[k] = float(v)
elif t == 'bool':
kv_overrides[k] = bool(v)
else:
raise ValueError('Invalid type for KV override')
params = {
'model_path': str(path),
'n_ctx': shared.args.n_ctx,
@ -97,7 +112,7 @@ class LlamaCppModel:
'rope_freq_scale': 1.0 / shared.args.compress_pos_emb,
'offload_kqv': not shared.args.no_offload_kqv,
'split_mode': 1 if not shared.args.row_split else 2,
'kv_overrides': shared.args.kv_overrides,
'kv_overrides': kv_overrides,
}
result.model = Llama(**params)