mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-29 10:59:32 +01:00
use dict for kv overrides
This commit is contained in:
parent
673361bffc
commit
e886490101
@ -201,6 +201,21 @@ class LlamacppHF(PreTrainedModel):
|
||||
else:
|
||||
tensor_split_list = [float(x) for x in shared.args.tensor_split.strip().split(",")]
|
||||
|
||||
if shared.args.kv_overrides is None or shared.args.kv_overrides.strip() == '':
|
||||
kv_overrides = None
|
||||
else:
|
||||
kv_overrides = {}
|
||||
for k, tv in [x.split('=') for x in shared.args.kv_overrides.split()]:
|
||||
t, v = tv.split(':')
|
||||
if t == 'int':
|
||||
kv_overrides[k] = int(v)
|
||||
elif t == 'float':
|
||||
kv_overrides[k] = float(v)
|
||||
elif t == 'bool':
|
||||
kv_overrides[k] = bool(v)
|
||||
else:
|
||||
raise ValueError('Invalid type for KV override')
|
||||
|
||||
params = {
|
||||
'model_path': str(model_file),
|
||||
'n_ctx': shared.args.n_ctx,
|
||||
@ -218,7 +233,7 @@ class LlamacppHF(PreTrainedModel):
|
||||
'logits_all': shared.args.logits_all,
|
||||
'offload_kqv': not shared.args.no_offload_kqv,
|
||||
'split_mode': 1 if not shared.args.row_split else 2,
|
||||
'kv_overrides': shared.args.kv_overrides,i
|
||||
'kv_overrides': kv_overrides,
|
||||
}
|
||||
|
||||
Llama = llama_cpp_lib().Llama
|
||||
|
@ -81,6 +81,21 @@ class LlamaCppModel:
|
||||
else:
|
||||
tensor_split_list = [float(x) for x in shared.args.tensor_split.strip().split(",")]
|
||||
|
||||
if shared.args.kv_overrides is None or shared.args.kv_overrides.strip() == '':
|
||||
kv_overrides = None
|
||||
else:
|
||||
kv_overrides = {}
|
||||
for k, tv in [x.split('=') for x in shared.args.kv_overrides.split()]:
|
||||
t, v = tv.split(':')
|
||||
if t == 'int':
|
||||
kv_overrides[k] = int(v)
|
||||
elif t == 'float':
|
||||
kv_overrides[k] = float(v)
|
||||
elif t == 'bool':
|
||||
kv_overrides[k] = bool(v)
|
||||
else:
|
||||
raise ValueError('Invalid type for KV override')
|
||||
|
||||
params = {
|
||||
'model_path': str(path),
|
||||
'n_ctx': shared.args.n_ctx,
|
||||
@ -97,7 +112,7 @@ class LlamaCppModel:
|
||||
'rope_freq_scale': 1.0 / shared.args.compress_pos_emb,
|
||||
'offload_kqv': not shared.args.no_offload_kqv,
|
||||
'split_mode': 1 if not shared.args.row_split else 2,
|
||||
'kv_overrides': shared.args.kv_overrides,
|
||||
'kv_overrides': kv_overrides,
|
||||
}
|
||||
|
||||
result.model = Llama(**params)
|
||||
|
Loading…
Reference in New Issue
Block a user