Fix llamacpp_HF loading

This commit is contained in:
oobabooga 2023-08-26 22:15:06 -07:00
parent 7f5370a272
commit 8aeae3b3f4

View File

@ -45,7 +45,7 @@ def llama_cpp_lib(model_file: Union[str, Path] = None):
class LlamacppHF(PreTrainedModel): class LlamacppHF(PreTrainedModel):
def __init__(self, model): def __init__(self, model, path):
super().__init__(PretrainedConfig()) super().__init__(PretrainedConfig())
self.model = model self.model = model
self.generation_config = GenerationConfig() self.generation_config = GenerationConfig()
@ -64,7 +64,7 @@ class LlamacppHF(PreTrainedModel):
'n_tokens': self.model.n_tokens, 'n_tokens': self.model.n_tokens,
'input_ids': self.model.input_ids.copy(), 'input_ids': self.model.input_ids.copy(),
'scores': self.model.scores.copy(), 'scores': self.model.scores.copy(),
'ctx': llama_cpp_lib().llama_new_context_with_model(model.model, model.params) 'ctx': llama_cpp_lib(path).llama_new_context_with_model(model.model, model.params)
} }
def _validate_model_class(self): def _validate_model_class(self):
@ -217,4 +217,4 @@ class LlamacppHF(PreTrainedModel):
Llama = llama_cpp_lib(model_file).Llama Llama = llama_cpp_lib(model_file).Llama
model = Llama(**params) model = Llama(**params)
return LlamacppHF(model) return LlamacppHF(model, model_file)