mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Fix CFG with llamacpp_HF (2nd attempt)
This commit is contained in:
parent
c203c57c18
commit
59032140b5
@ -48,7 +48,7 @@ class LlamacppHF(PreTrainedModel):
|
|||||||
'n_tokens': self.model.n_tokens,
|
'n_tokens': self.model.n_tokens,
|
||||||
'input_ids': self.model.input_ids,
|
'input_ids': self.model.input_ids,
|
||||||
'scores': self.model.scores,
|
'scores': self.model.scores,
|
||||||
'ctx': self.model._ctx
|
'ctx': self.model._ctx.ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
if shared.args.cfg_cache:
|
if shared.args.cfg_cache:
|
||||||
@ -57,7 +57,7 @@ class LlamacppHF(PreTrainedModel):
|
|||||||
'n_tokens': self.model.n_tokens,
|
'n_tokens': self.model.n_tokens,
|
||||||
'input_ids': self.model.input_ids.copy(),
|
'input_ids': self.model.input_ids.copy(),
|
||||||
'scores': self.model.scores.copy(),
|
'scores': self.model.scores.copy(),
|
||||||
'ctx': llama_cpp_lib()._internals._LlamaContext(model=model._model, params=model.context_params)
|
'ctx': llama_cpp_lib().llama_new_context_with_model(model.model, model.context_params)
|
||||||
}
|
}
|
||||||
|
|
||||||
def _validate_model_class(self):
|
def _validate_model_class(self):
|
||||||
@ -74,7 +74,7 @@ class LlamacppHF(PreTrainedModel):
|
|||||||
'n_tokens': self.model.n_tokens,
|
'n_tokens': self.model.n_tokens,
|
||||||
'input_ids': self.model.input_ids,
|
'input_ids': self.model.input_ids,
|
||||||
'scores': self.model.scores,
|
'scores': self.model.scores,
|
||||||
'ctx': self.model._ctx
|
'ctx': self.model._ctx.ctx
|
||||||
})
|
})
|
||||||
|
|
||||||
def save_negative_cache(self):
|
def save_negative_cache(self):
|
||||||
@ -82,20 +82,20 @@ class LlamacppHF(PreTrainedModel):
|
|||||||
'n_tokens': self.model.n_tokens,
|
'n_tokens': self.model.n_tokens,
|
||||||
'input_ids': self.model.input_ids,
|
'input_ids': self.model.input_ids,
|
||||||
'scores': self.model.scores,
|
'scores': self.model.scores,
|
||||||
'ctx': self.model._ctx
|
'ctx': self.model._ctx.ctx
|
||||||
})
|
})
|
||||||
|
|
||||||
def load_cache(self):
|
def load_cache(self):
|
||||||
self.model.n_tokens = self.llamacpp_cache['n_tokens']
|
self.model.n_tokens = self.llamacpp_cache['n_tokens']
|
||||||
self.model.input_ids = self.llamacpp_cache['input_ids']
|
self.model.input_ids = self.llamacpp_cache['input_ids']
|
||||||
self.model.scores = self.llamacpp_cache['scores']
|
self.model.scores = self.llamacpp_cache['scores']
|
||||||
self.model._ctx = self.llamacpp_cache['ctx']
|
self.model._ctx.ctx = self.llamacpp_cache['ctx']
|
||||||
|
|
||||||
def load_negative_cache(self):
|
def load_negative_cache(self):
|
||||||
self.model.n_tokens = self.llamacpp_cache_negative['n_tokens']
|
self.model.n_tokens = self.llamacpp_cache_negative['n_tokens']
|
||||||
self.model.input_ids = self.llamacpp_cache_negative['input_ids']
|
self.model.input_ids = self.llamacpp_cache_negative['input_ids']
|
||||||
self.model.scores = self.llamacpp_cache_negative['scores']
|
self.model.scores = self.llamacpp_cache_negative['scores']
|
||||||
self.model._ctx = self.llamacpp_cache_negative['ctx']
|
self.model._ctx.ctx = self.llamacpp_cache_negative['ctx']
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def device(self) -> torch.device:
|
def device(self) -> torch.device:
|
||||||
|
Loading…
Reference in New Issue
Block a user