mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 16:17:57 +01:00
Fix llamacpp_HF loading
This commit is contained in:
parent
7f5370a272
commit
8aeae3b3f4
@ -45,7 +45,7 @@ def llama_cpp_lib(model_file: Union[str, Path] = None):
|
|||||||
|
|
||||||
|
|
||||||
class LlamacppHF(PreTrainedModel):
|
class LlamacppHF(PreTrainedModel):
|
||||||
def __init__(self, model):
|
def __init__(self, model, path):
|
||||||
super().__init__(PretrainedConfig())
|
super().__init__(PretrainedConfig())
|
||||||
self.model = model
|
self.model = model
|
||||||
self.generation_config = GenerationConfig()
|
self.generation_config = GenerationConfig()
|
||||||
@ -64,7 +64,7 @@ class LlamacppHF(PreTrainedModel):
|
|||||||
'n_tokens': self.model.n_tokens,
|
'n_tokens': self.model.n_tokens,
|
||||||
'input_ids': self.model.input_ids.copy(),
|
'input_ids': self.model.input_ids.copy(),
|
||||||
'scores': self.model.scores.copy(),
|
'scores': self.model.scores.copy(),
|
||||||
'ctx': llama_cpp_lib().llama_new_context_with_model(model.model, model.params)
|
'ctx': llama_cpp_lib(path).llama_new_context_with_model(model.model, model.params)
|
||||||
}
|
}
|
||||||
|
|
||||||
def _validate_model_class(self):
|
def _validate_model_class(self):
|
||||||
@ -217,4 +217,4 @@ class LlamacppHF(PreTrainedModel):
|
|||||||
Llama = llama_cpp_lib(model_file).Llama
|
Llama = llama_cpp_lib(model_file).Llama
|
||||||
model = Llama(**params)
|
model = Llama(**params)
|
||||||
|
|
||||||
return LlamacppHF(model)
|
return LlamacppHF(model, model_file)
|
||||||
|
Loading…
Reference in New Issue
Block a user