Fix llamacpp_HF loading

commit 8aeae3b3f4
parent 7f5370a272
@@ -45,7 +45,7 @@ def llama_cpp_lib(model_file: Union[str, Path] = None):
 
 
 class LlamacppHF(PreTrainedModel):
-    def __init__(self, model):
+    def __init__(self, model, path):
         super().__init__(PretrainedConfig())
         self.model = model
         self.generation_config = GenerationConfig()
@@ -64,7 +64,7 @@ class LlamacppHF(PreTrainedModel):
             'n_tokens': self.model.n_tokens,
             'input_ids': self.model.input_ids.copy(),
             'scores': self.model.scores.copy(),
-            'ctx': llama_cpp_lib().llama_new_context_with_model(model.model, model.params)
+            'ctx': llama_cpp_lib(path).llama_new_context_with_model(model.model, model.params)
         }
 
     def _validate_model_class(self):
@@ -217,4 +217,4 @@ class LlamacppHF(PreTrainedModel):
         Llama = llama_cpp_lib(model_file).Llama
         model = Llama(**params)
 
-        return LlamacppHF(model)
+        return LlamacppHF(model, model_file)
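Taken together, the three hunks thread the model path from `from_pretrained` into the wrapper, so the second call to `llama_cpp_lib` receives the model path instead of being called with no argument. Below is a minimal sketch of the patched constructor based only on what the hunks show; the attribute name `llamacpp_cache` and the `llama_cpp_lib` stub are assumptions for illustration, since the diff shows only fragments of the file.

# Minimal sketch of the patched wrapper, based only on what the hunks above show.
# Assumptions not visible in this diff: the cache dict is stored as
# self.llamacpp_cache, and llama_cpp_lib() chooses a llama-cpp-python build for
# the given file; the stub below simply returns the plain `llama_cpp` package.
from pathlib import Path
from typing import Union

from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel


def llama_cpp_lib(model_file: Union[str, Path] = None):
    # Stand-in for the module-level helper named in the first hunk header.
    import llama_cpp
    return llama_cpp


class LlamacppHF(PreTrainedModel):
    def __init__(self, model, path):
        super().__init__(PretrainedConfig())
        self.model = model
        self.generation_config = GenerationConfig()

        # The fix: forward `path` so the cached llama.cpp context is created by
        # the same backend helper that loaded the model in the first place.
        self.llamacpp_cache = {
            'n_tokens': self.model.n_tokens,
            'input_ids': self.model.input_ids.copy(),
            'scores': self.model.scores.copy(),
            'ctx': llama_cpp_lib(path).llama_new_context_with_model(model.model, model.params),
        }

On the caller side (third hunk), `from_pretrained` now hands the same `model_file` it passed to `llama_cpp_lib` into the constructor, which is what closes the loop.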