llama.cpp: implement ban_eos_token via logits_processor (#2765)

Cebtenzzre committed on 2023-06-19 20:31:19 -04:00 (committed via GitHub)
parent 0d9d70ec7e
commit 59e7ecb198


@@ -7,14 +7,20 @@ https://abetlen.github.io/llama-cpp-python/
 '''
 
 import re
+from functools import partial
 
-from llama_cpp import Llama, LlamaCache
+from llama_cpp import Llama, LlamaCache, LogitsProcessorList
 
 from modules import shared
 from modules.callbacks import Iteratorize
 from modules.logging_colors import logger
 
 
+def ban_eos_logits_processor(eos_token, input_ids, logits):
+    logits[eos_token] = -float('inf')
+    return logits
+
+
 class LlamaCppModel:
     def __init__(self):
         self.initialized = False
@@ -72,7 +78,10 @@ class LlamaCppModel:
             mirostat_mode=int(state['mirostat_mode']),
             mirostat_tau=state['mirostat_tau'],
             mirostat_eta=state['mirostat_eta'],
-            stream=True
+            stream=True,
+            logits_processor=LogitsProcessorList([
+                partial(ban_eos_logits_processor, self.model.token_eos()),
+            ]) if state['ban_eos_token'] else None,
         )
 
         output = ""