use dict for kv overrides

This commit is contained in:
phiharri 2024-02-13 13:13:50 +00:00
parent 673361bffc
commit e886490101
2 changed files with 32 additions and 2 deletions

View File

@ -201,6 +201,21 @@ class LlamacppHF(PreTrainedModel):
else: else:
tensor_split_list = [float(x) for x in shared.args.tensor_split.strip().split(",")] tensor_split_list = [float(x) for x in shared.args.tensor_split.strip().split(",")]
# Turn the whitespace-separated "key=type:value" entries of
# shared.args.kv_overrides into a dict of typed Python values, or None when
# no overrides were supplied.
raw_kv_overrides = shared.args.kv_overrides
if raw_kv_overrides is None or raw_kv_overrides.strip() == '':
    kv_overrides = None
else:
    kv_overrides = {}
    for entry in raw_kv_overrides.split():
        # partition() never raises, so malformed entries can be reported
        # with a readable message instead of a tuple-unpacking error.
        k, eq, tv = entry.partition('=')
        t, colon, v = tv.partition(':')
        if not k or not eq or not colon:
            raise ValueError(f'Invalid KV override {entry!r}: expected key=type:value')

        if t == 'int':
            kv_overrides[k] = int(v)
        elif t == 'float':
            kv_overrides[k] = float(v)
        elif t == 'bool':
            # bool('false') is True because any non-empty string is truthy,
            # so the textual value has to be compared explicitly.
            lowered = v.strip().lower()
            if lowered in ('true', '1'):
                kv_overrides[k] = True
            elif lowered in ('false', '0'):
                kv_overrides[k] = False
            else:
                raise ValueError(f'Invalid bool value for KV override {entry!r}')
        else:
            raise ValueError('Invalid type for KV override')
params = { params = {
'model_path': str(model_file), 'model_path': str(model_file),
'n_ctx': shared.args.n_ctx, 'n_ctx': shared.args.n_ctx,
@ -218,7 +233,7 @@ class LlamacppHF(PreTrainedModel):
'logits_all': shared.args.logits_all, 'logits_all': shared.args.logits_all,
'offload_kqv': not shared.args.no_offload_kqv, 'offload_kqv': not shared.args.no_offload_kqv,
'split_mode': 1 if not shared.args.row_split else 2, 'split_mode': 1 if not shared.args.row_split else 2,
'kv_overrides': shared.args.kv_overrides,i 'kv_overrides': kv_overrides,
} }
Llama = llama_cpp_lib().Llama Llama = llama_cpp_lib().Llama

View File

@ -81,6 +81,21 @@ class LlamaCppModel:
else: else:
tensor_split_list = [float(x) for x in shared.args.tensor_split.strip().split(",")] tensor_split_list = [float(x) for x in shared.args.tensor_split.strip().split(",")]
# Turn the whitespace-separated "key=type:value" entries of
# shared.args.kv_overrides into a dict of typed Python values, or None when
# no overrides were supplied.
raw_kv_overrides = shared.args.kv_overrides
if raw_kv_overrides is None or raw_kv_overrides.strip() == '':
    kv_overrides = None
else:
    kv_overrides = {}
    for entry in raw_kv_overrides.split():
        # partition() never raises, so malformed entries can be reported
        # with a readable message instead of a tuple-unpacking error.
        k, eq, tv = entry.partition('=')
        t, colon, v = tv.partition(':')
        if not k or not eq or not colon:
            raise ValueError(f'Invalid KV override {entry!r}: expected key=type:value')

        if t == 'int':
            kv_overrides[k] = int(v)
        elif t == 'float':
            kv_overrides[k] = float(v)
        elif t == 'bool':
            # bool('false') is True because any non-empty string is truthy,
            # so the textual value has to be compared explicitly.
            lowered = v.strip().lower()
            if lowered in ('true', '1'):
                kv_overrides[k] = True
            elif lowered in ('false', '0'):
                kv_overrides[k] = False
            else:
                raise ValueError(f'Invalid bool value for KV override {entry!r}')
        else:
            raise ValueError('Invalid type for KV override')
params = { params = {
'model_path': str(path), 'model_path': str(path),
'n_ctx': shared.args.n_ctx, 'n_ctx': shared.args.n_ctx,
@ -97,7 +112,7 @@ class LlamaCppModel:
'rope_freq_scale': 1.0 / shared.args.compress_pos_emb, 'rope_freq_scale': 1.0 / shared.args.compress_pos_emb,
'offload_kqv': not shared.args.no_offload_kqv, 'offload_kqv': not shared.args.no_offload_kqv,
'split_mode': 1 if not shared.args.row_split else 2, 'split_mode': 1 if not shared.args.row_split else 2,
'kv_overrides': shared.args.kv_overrides, 'kv_overrides': kv_overrides,
} }
result.model = Llama(**params) result.model = Llama(**params)