Merge remote-tracking branch 'refs/remotes/origin/main'

This commit is contained in:
oobabooga 2023-06-20 00:46:29 -03:00
commit 017884132f

View File

@ -7,14 +7,20 @@ https://abetlen.github.io/llama-cpp-python/
'''
import re
from functools import partial
from llama_cpp import Llama, LlamaCache
from llama_cpp import Llama, LlamaCache, LogitsProcessorList
from modules import shared
from modules.callbacks import Iteratorize
from modules.logging_colors import logger
def ban_eos_logits_processor(eos_token, input_ids, logits):
logits[eos_token] = -float('inf')
return logits
class LlamaCppModel:
def __init__(self):
self.initialized = False
@ -72,7 +78,10 @@ class LlamaCppModel:
mirostat_mode=int(state['mirostat_mode']),
mirostat_tau=state['mirostat_tau'],
mirostat_eta=state['mirostat_eta'],
stream=True
stream=True,
logits_processor=LogitsProcessorList([
partial(ban_eos_logits_processor, self.model.token_eos()),
]) if state['ban_eos_token'] else None,
)
output = ""