diff --git a/modules/llamacpp_model.py b/modules/llamacpp_model.py
index 4944c64c..9f6122d9 100644
--- a/modules/llamacpp_model.py
+++ b/modules/llamacpp_model.py
@@ -7,14 +7,20 @@
 https://abetlen.github.io/llama-cpp-python/
 '''
 import re
+from functools import partial

-from llama_cpp import Llama, LlamaCache
+from llama_cpp import Llama, LlamaCache, LogitsProcessorList

 from modules import shared
 from modules.callbacks import Iteratorize
 from modules.logging_colors import logger


+def ban_eos_logits_processor(eos_token, input_ids, logits):
+    logits[eos_token] = -float('inf')
+    return logits
+
+
 class LlamaCppModel:
     def __init__(self):
         self.initialized = False
@@ -72,7 +78,10 @@ class LlamaCppModel:
             mirostat_mode=int(state['mirostat_mode']),
             mirostat_tau=state['mirostat_tau'],
             mirostat_eta=state['mirostat_eta'],
-            stream=True
+            stream=True,
+            logits_processor=LogitsProcessorList([
+                partial(ban_eos_logits_processor, self.model.token_eos()),
+            ]) if state['ban_eos_token'] else None,
         )

         output = ""
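
For reference, a minimal standalone sketch of the same technique against llama-cpp-python directly, outside the webui. It assumes a llama-cpp-python version that exposes `LogitsProcessorList`; the model path and prompt are placeholders:

```python
from functools import partial

from llama_cpp import Llama, LogitsProcessorList


def ban_eos_logits_processor(eos_token, input_ids, logits):
    # Set the EOS token's logit to -inf so the sampler can never pick it,
    # which keeps generation going until max_tokens is reached.
    logits[eos_token] = -float('inf')
    return logits


model = Llama(model_path='models/example.bin')  # placeholder path

completion_chunks = model.create_completion(
    prompt='Once upon a time',  # placeholder prompt
    max_tokens=64,
    stream=True,
    # partial() binds the model's EOS token id up front, leaving the
    # (input_ids, logits) signature that llama-cpp-python calls with.
    logits_processor=LogitsProcessorList([
        partial(ban_eos_logits_processor, model.token_eos()),
    ]),
)

for chunk in completion_chunks:
    print(chunk['choices'][0]['text'], end='', flush=True)
```

In the patch itself, the processor list is only passed when `state['ban_eos_token']` is enabled; otherwise `logits_processor` stays `None` and sampling is unchanged.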