Add CFG to llamacpp_HF (second attempt) (#3678)

oobabooga 2023-08-24 20:32:21 -03:00 committed by GitHub
parent d6934bc7bc
commit 3320accfdc
3 changed files with 14 additions and 6 deletions

README.md

@@ -280,6 +280,7 @@ Optionally, you can use the following command-line flags:
 | `--n_gqa N_GQA` | grouped-query attention. Must be 8 for llama-2 70b. |
 | `--rms_norm_eps RMS_NORM_EPS` | 5e-6 is a good value for llama-2 models. |
 | `--cpu` | Use the CPU version of llama-cpp-python instead of the GPU-accelerated version. |
+| `--cfg-cache` | llamacpp_HF: Create an additional cache for CFG negative prompts. |

 #### ctransformers
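The new flag only has an effect with the llamacpp_HF loader. A typical launch (model name hypothetical) would be `python server.py --loader llamacpp_HF --model llama-2-13b --cfg-cache`; without the flag, CFG requests hit the error path in modules/llamacpp_hf.py below.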

modules/llamacpp_hf.py

@@ -38,16 +38,17 @@ class LlamacppHF(PreTrainedModel):
         self.llamacpp_cache = {
             'n_tokens': self.model.n_tokens,
             'input_ids': self.model.input_ids,
-            'scores': self.model.scores
+            'scores': self.model.scores,
+            'ctx': self.model.ctx
         }

         if shared.args.cfg_cache:
-            logger.warning('CFG is currently bugged and not functional for llamacpp_HF. Contributions are welcome.')
             self.past_seq_negative = None
             self.llamacpp_cache_negative = {
                 'n_tokens': self.model.n_tokens,
                 'input_ids': self.model.input_ids.copy(),
-                'scores': self.model.scores.copy()
+                'scores': self.model.scores.copy(),
+                'ctx': llama_cpp_lib().llama_new_context_with_model(model.model, model.params)
             }

     def _validate_model_class(self):
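The decisive change relative to the first attempt is the new 'ctx' entry: llama.cpp stores the KV cache inside the native context object, so snapshotting n_tokens, input_ids, and scores alone cannot isolate the negative prompt. A minimal sketch of what the extra allocation provides, using the same call as the diff (illustrative, not the file's exact code):

```python
def allocate_negative_ctx(model):
    # A second native context over the same weights: model.model is the shared
    # low-level weights handle, model.params the existing context parameters,
    # and llama_cpp_lib() resolves the installed llama_cpp module as elsewhere
    # in this file. Only the context state (including the KV cache) is
    # duplicated, so the negative prompt can be evaluated without clobbering
    # the positive prompt's state.
    return llama_cpp_lib().llama_new_context_with_model(model.model, model.params)
```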
@@ -63,25 +64,29 @@ class LlamacppHF(PreTrainedModel):
         self.llamacpp_cache.update({
             'n_tokens': self.model.n_tokens,
             'input_ids': self.model.input_ids,
-            'scores': self.model.scores
+            'scores': self.model.scores,
+            'ctx': self.model.ctx
         })

     def save_negative_cache(self):
         self.llamacpp_cache_negative.update({
             'n_tokens': self.model.n_tokens,
             'input_ids': self.model.input_ids,
-            'scores': self.model.scores
+            'scores': self.model.scores,
+            'ctx': self.model.ctx
         })

     def load_cache(self):
         self.model.n_tokens = self.llamacpp_cache['n_tokens']
         self.model.input_ids = self.llamacpp_cache['input_ids']
         self.model.scores = self.llamacpp_cache['scores']
+        self.model.ctx = self.llamacpp_cache['ctx']

     def load_negative_cache(self):
         self.model.n_tokens = self.llamacpp_cache_negative['n_tokens']
         self.model.input_ids = self.llamacpp_cache_negative['input_ids']
         self.model.scores = self.llamacpp_cache_negative['scores']
+        self.model.ctx = self.llamacpp_cache_negative['ctx']

     @property
     def device(self) -> torch.device:
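Together, the save_*/load_* pairs implement a full state swap: evaluating tokens mutates n_tokens, input_ids, scores, and the native ctx, so callers must load the matching snapshot before an eval and save it back afterwards. A hedged sketch of the intended usage (eval_on is a hypothetical helper, not a method of this class):

```python
def eval_on(hf_model, tokens, negative=False):
    # Point the shared llama_cpp.Llama object at the right snapshot, run one
    # incremental eval, then write the mutated state back into that snapshot.
    if negative:
        hf_model.load_negative_cache()
        hf_model.model.eval(tokens)  # llama-cpp-python's Llama.eval
        hf_model.save_negative_cache()
    else:
        hf_model.load_cache()
        hf_model.model.eval(tokens)
        hf_model.save_cache()
```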
@@ -95,7 +100,6 @@ class LlamacppHF(PreTrainedModel):
         if len(args) > 0:
             if not shared.args.cfg_cache:
                 logger.error("Please enable the cfg-cache option to use CFG with llamacpp_HF.")
-                logger.warning('CFG is currently bugged and not functional for llamacpp_HF. Contributions are welcome.')
                 return

             input_ids = args[0]
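For orientation: positional arguments reach __call__ only on the CFG pass (transformers' CFG sampler invokes the model with the negative sequence positionally), which is why the flag check guards the len(args) > 0 branch. A paraphrased sketch of the dispatch (an assumption based on the cache pairs above, not the file's exact code):

```python
if len(args) > 0:
    input_ids = args[0]          # negative-prompt (CFG) pass
    self.load_negative_cache()   # evaluate against the negative KV cache
else:
    input_ids = kwargs['input_ids']  # normal pass
    self.load_cache()                # evaluate against the positive KV cache
```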

modules/loaders.py

@@ -95,6 +95,7 @@ loaders_and_params = OrderedDict({
         'alpha_value',
         'compress_pos_emb',
         'cpu',
+        'cfg_cache',
         'llamacpp_HF_info',
     ],
     'ctransformers': [
@@ -268,6 +269,8 @@ loaders_samplers = {
         'mirostat_mode',
         'mirostat_tau',
         'mirostat_eta',
+        'guidance_scale',
+        'negative_prompt',
         'ban_eos_token',
         'add_bos_token',
         'skip_special_tokens',
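Exposing guidance_scale and negative_prompt wires this loader into standard classifier-free guidance. As a reminder of the computation these two parameters drive, a minimal illustrative implementation of the usual CFG mixing formula (not the webui's exact sampling code):

```python
import torch

def cfg_logits(cond: torch.Tensor, uncond: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # guidance_scale = 1.0 reproduces the conditional logits unchanged;
    # larger values push sampling away from the negative-prompt distribution.
    return uncond + guidance_scale * (cond - uncond)
```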