From 59e7ecb1985a36f8a9ac6a40236b22533efc0b56 Mon Sep 17 00:00:00 2001
From: Cebtenzzre
Date: Mon, 19 Jun 2023 20:31:19 -0400
Subject: [PATCH] llama.cpp: implement ban_eos_token via logits_processor
 (#2765)

---
 modules/llamacpp_model.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/modules/llamacpp_model.py b/modules/llamacpp_model.py
index 4944c64c..9f6122d9 100644
--- a/modules/llamacpp_model.py
+++ b/modules/llamacpp_model.py
@@ -7,14 +7,20 @@ https://abetlen.github.io/llama-cpp-python/
 '''
 
 import re
+from functools import partial
 
-from llama_cpp import Llama, LlamaCache
+from llama_cpp import Llama, LlamaCache, LogitsProcessorList
 
 from modules import shared
 from modules.callbacks import Iteratorize
 from modules.logging_colors import logger
 
 
+def ban_eos_logits_processor(eos_token, input_ids, logits):
+    logits[eos_token] = -float('inf')
+    return logits
+
+
 class LlamaCppModel:
     def __init__(self):
         self.initialized = False
@@ -72,7 +78,10 @@ class LlamaCppModel:
             mirostat_mode=int(state['mirostat_mode']),
             mirostat_tau=state['mirostat_tau'],
             mirostat_eta=state['mirostat_eta'],
-            stream=True
+            stream=True,
+            logits_processor=LogitsProcessorList([
+                partial(ban_eos_logits_processor, self.model.token_eos()),
+            ]) if state['ban_eos_token'] else None,
         )
 
         output = ""
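
For illustration, a minimal standalone sketch of what the new processor does
to a logits array. This is not part of the patch: the EOS token id (2) and
the score values are invented, and a plain Python list stands in for the
array that llama-cpp-python passes at sampling time.

    from functools import partial

    def ban_eos_logits_processor(eos_token, input_ids, logits):
        # Forcing the EOS logit to -inf gives that token zero probability
        # after softmax, so the sampler can never end generation early.
        logits[eos_token] = -float('inf')
        return logits

    EOS_TOKEN = 2                    # hypothetical EOS token id
    logits = [0.1, 3.2, 5.0, -1.3]   # hypothetical per-token scores

    # partial() pins the eos_token argument, leaving the two-argument
    # (input_ids, logits) shape expected of a logits processor.
    processor = partial(ban_eos_logits_processor, EOS_TOKEN)
    logits = processor([], logits)

    print(logits)  # [0.1, 3.2, -inf, -1.3]

In the patch itself, the same partial is wrapped in a LogitsProcessorList and
handed to the generation call only when the ban_eos_token setting is enabled;
otherwise logits_processor is None and sampling is unchanged.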