mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-01 20:04:04 +01:00
use dict for kv overrides
This commit is contained in:
parent
673361bffc
commit
e886490101
@ -201,6 +201,21 @@ class LlamacppHF(PreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
tensor_split_list = [float(x) for x in shared.args.tensor_split.strip().split(",")]
|
tensor_split_list = [float(x) for x in shared.args.tensor_split.strip().split(",")]
|
||||||
|
|
||||||
|
if shared.args.kv_overrides is None or shared.args.kv_overrides.strip() == '':
|
||||||
|
kv_overrides = None
|
||||||
|
else:
|
||||||
|
kv_overrides = {}
|
||||||
|
for k, tv in [x.split('=') for x in shared.args.kv_overrides.split()]:
|
||||||
|
t, v = tv.split(':')
|
||||||
|
if t == 'int':
|
||||||
|
kv_overrides[k] = int(v)
|
||||||
|
elif t == 'float':
|
||||||
|
kv_overrides[k] = float(v)
|
||||||
|
elif t == 'bool':
|
||||||
|
kv_overrides[k] = bool(v)
|
||||||
|
else:
|
||||||
|
raise ValueError('Invalid type for KV override')
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'model_path': str(model_file),
|
'model_path': str(model_file),
|
||||||
'n_ctx': shared.args.n_ctx,
|
'n_ctx': shared.args.n_ctx,
|
||||||
@ -218,7 +233,7 @@ class LlamacppHF(PreTrainedModel):
|
|||||||
'logits_all': shared.args.logits_all,
|
'logits_all': shared.args.logits_all,
|
||||||
'offload_kqv': not shared.args.no_offload_kqv,
|
'offload_kqv': not shared.args.no_offload_kqv,
|
||||||
'split_mode': 1 if not shared.args.row_split else 2,
|
'split_mode': 1 if not shared.args.row_split else 2,
|
||||||
'kv_overrides': shared.args.kv_overrides,i
|
'kv_overrides': kv_overrides,
|
||||||
}
|
}
|
||||||
|
|
||||||
Llama = llama_cpp_lib().Llama
|
Llama = llama_cpp_lib().Llama
|
||||||
|
@ -81,6 +81,21 @@ class LlamaCppModel:
|
|||||||
else:
|
else:
|
||||||
tensor_split_list = [float(x) for x in shared.args.tensor_split.strip().split(",")]
|
tensor_split_list = [float(x) for x in shared.args.tensor_split.strip().split(",")]
|
||||||
|
|
||||||
|
if shared.args.kv_overrides is None or shared.args.kv_overrides.strip() == '':
|
||||||
|
kv_overrides = None
|
||||||
|
else:
|
||||||
|
kv_overrides = {}
|
||||||
|
for k, tv in [x.split('=') for x in shared.args.kv_overrides.split()]:
|
||||||
|
t, v = tv.split(':')
|
||||||
|
if t == 'int':
|
||||||
|
kv_overrides[k] = int(v)
|
||||||
|
elif t == 'float':
|
||||||
|
kv_overrides[k] = float(v)
|
||||||
|
elif t == 'bool':
|
||||||
|
kv_overrides[k] = bool(v)
|
||||||
|
else:
|
||||||
|
raise ValueError('Invalid type for KV override')
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'model_path': str(path),
|
'model_path': str(path),
|
||||||
'n_ctx': shared.args.n_ctx,
|
'n_ctx': shared.args.n_ctx,
|
||||||
@ -97,7 +112,7 @@ class LlamaCppModel:
|
|||||||
'rope_freq_scale': 1.0 / shared.args.compress_pos_emb,
|
'rope_freq_scale': 1.0 / shared.args.compress_pos_emb,
|
||||||
'offload_kqv': not shared.args.no_offload_kqv,
|
'offload_kqv': not shared.args.no_offload_kqv,
|
||||||
'split_mode': 1 if not shared.args.row_split else 2,
|
'split_mode': 1 if not shared.args.row_split else 2,
|
||||||
'kv_overrides': shared.args.kv_overrides,
|
'kv_overrides': kv_overrides,
|
||||||
}
|
}
|
||||||
|
|
||||||
result.model = Llama(**params)
|
result.model = Llama(**params)
|
||||||
|
Loading…
Reference in New Issue
Block a user