Fix CFG with llamacpp_HF (2nd attempt)

This commit is contained in:
oobabooga 2024-02-19 18:35:42 -08:00
parent c203c57c18
commit 59032140b5

View File

@ -48,7 +48,7 @@ class LlamacppHF(PreTrainedModel):
'n_tokens': self.model.n_tokens, 'n_tokens': self.model.n_tokens,
'input_ids': self.model.input_ids, 'input_ids': self.model.input_ids,
'scores': self.model.scores, 'scores': self.model.scores,
'ctx': self.model._ctx 'ctx': self.model._ctx.ctx
} }
if shared.args.cfg_cache: if shared.args.cfg_cache:
@ -57,7 +57,7 @@ class LlamacppHF(PreTrainedModel):
'n_tokens': self.model.n_tokens, 'n_tokens': self.model.n_tokens,
'input_ids': self.model.input_ids.copy(), 'input_ids': self.model.input_ids.copy(),
'scores': self.model.scores.copy(), 'scores': self.model.scores.copy(),
'ctx': llama_cpp_lib()._internals._LlamaContext(model=model._model, params=model.context_params) 'ctx': llama_cpp_lib().llama_new_context_with_model(model.model, model.context_params)
} }
def _validate_model_class(self): def _validate_model_class(self):
@ -74,7 +74,7 @@ class LlamacppHF(PreTrainedModel):
'n_tokens': self.model.n_tokens, 'n_tokens': self.model.n_tokens,
'input_ids': self.model.input_ids, 'input_ids': self.model.input_ids,
'scores': self.model.scores, 'scores': self.model.scores,
'ctx': self.model._ctx 'ctx': self.model._ctx.ctx
}) })
def save_negative_cache(self): def save_negative_cache(self):
@ -82,20 +82,20 @@ class LlamacppHF(PreTrainedModel):
'n_tokens': self.model.n_tokens, 'n_tokens': self.model.n_tokens,
'input_ids': self.model.input_ids, 'input_ids': self.model.input_ids,
'scores': self.model.scores, 'scores': self.model.scores,
'ctx': self.model._ctx 'ctx': self.model._ctx.ctx
}) })
def load_cache(self): def load_cache(self):
self.model.n_tokens = self.llamacpp_cache['n_tokens'] self.model.n_tokens = self.llamacpp_cache['n_tokens']
self.model.input_ids = self.llamacpp_cache['input_ids'] self.model.input_ids = self.llamacpp_cache['input_ids']
self.model.scores = self.llamacpp_cache['scores'] self.model.scores = self.llamacpp_cache['scores']
self.model._ctx = self.llamacpp_cache['ctx'] self.model._ctx.ctx = self.llamacpp_cache['ctx']
def load_negative_cache(self): def load_negative_cache(self):
self.model.n_tokens = self.llamacpp_cache_negative['n_tokens'] self.model.n_tokens = self.llamacpp_cache_negative['n_tokens']
self.model.input_ids = self.llamacpp_cache_negative['input_ids'] self.model.input_ids = self.llamacpp_cache_negative['input_ids']
self.model.scores = self.llamacpp_cache_negative['scores'] self.model.scores = self.llamacpp_cache_negative['scores']
self.model._ctx = self.llamacpp_cache_negative['ctx'] self.model._ctx.ctx = self.llamacpp_cache_negative['ctx']
@property @property
def device(self) -> torch.device: def device(self) -> torch.device: