From 59032140b54a2dff9aa314d6720abf42c5651b6a Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 19 Feb 2024 18:35:42 -0800
Subject: [PATCH] Fix CFG with llamacpp_HF (2nd attempt)

---
 modules/llamacpp_hf.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/modules/llamacpp_hf.py b/modules/llamacpp_hf.py
index e1bdd208..e5a05f6e 100644
--- a/modules/llamacpp_hf.py
+++ b/modules/llamacpp_hf.py
@@ -48,7 +48,7 @@ class LlamacppHF(PreTrainedModel):
             'n_tokens': self.model.n_tokens,
             'input_ids': self.model.input_ids,
             'scores': self.model.scores,
-            'ctx': self.model._ctx
+            'ctx': self.model._ctx.ctx
         }
 
         if shared.args.cfg_cache:
@@ -57,7 +57,7 @@ class LlamacppHF(PreTrainedModel):
                 'n_tokens': self.model.n_tokens,
                 'input_ids': self.model.input_ids.copy(),
                 'scores': self.model.scores.copy(),
-                'ctx': llama_cpp_lib()._internals._LlamaContext(model=model._model, params=model.context_params)
+                'ctx': llama_cpp_lib().llama_new_context_with_model(model.model, model.context_params)
             }
 
     def _validate_model_class(self):
@@ -74,7 +74,7 @@ class LlamacppHF(PreTrainedModel):
             'n_tokens': self.model.n_tokens,
             'input_ids': self.model.input_ids,
             'scores': self.model.scores,
-            'ctx': self.model._ctx
+            'ctx': self.model._ctx.ctx
         })
 
     def save_negative_cache(self):
@@ -82,20 +82,20 @@ class LlamacppHF(PreTrainedModel):
             'n_tokens': self.model.n_tokens,
             'input_ids': self.model.input_ids,
             'scores': self.model.scores,
-            'ctx': self.model._ctx
+            'ctx': self.model._ctx.ctx
         })
 
     def load_cache(self):
         self.model.n_tokens = self.llamacpp_cache['n_tokens']
         self.model.input_ids = self.llamacpp_cache['input_ids']
         self.model.scores = self.llamacpp_cache['scores']
-        self.model._ctx = self.llamacpp_cache['ctx']
+        self.model._ctx.ctx = self.llamacpp_cache['ctx']
 
     def load_negative_cache(self):
         self.model.n_tokens = self.llamacpp_cache_negative['n_tokens']
         self.model.input_ids = self.llamacpp_cache_negative['input_ids']
         self.model.scores = self.llamacpp_cache_negative['scores']
-        self.model._ctx = self.llamacpp_cache_negative['ctx']
+        self.model._ctx.ctx = self.llamacpp_cache_negative['ctx']
 
     @property
     def device(self) -> torch.device: