Mirror of https://github.com/oobabooga/text-generation-webui.git
Synced 2024-11-22 16:17:57 +01:00
llama.cpp: implement ban_eos_token via logits_processor (#2765)
This commit is contained in:
parent 0d9d70ec7e
commit 59e7ecb198
@@ -7,14 +7,20 @@ https://abetlen.github.io/llama-cpp-python/
 '''
 
 import re
+from functools import partial
 
-from llama_cpp import Llama, LlamaCache
+from llama_cpp import Llama, LlamaCache, LogitsProcessorList
 
 from modules import shared
 from modules.callbacks import Iteratorize
 from modules.logging_colors import logger
 
 
+def ban_eos_logits_processor(eos_token, input_ids, logits):
+    logits[eos_token] = -float('inf')
+    return logits
+
+
 class LlamaCppModel:
     def __init__(self):
         self.initialized = False
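The new helper follows the llama-cpp-python logits-processor convention: it receives the running input_ids and the raw next-token logits, and returns the (possibly modified) logits. Setting the EOS logit to -inf gives that token zero probability after softmax, so it can never be sampled. A standalone sketch of that effect, not part of the commit, with made-up token ids and numpy standing in for the real sampler:

# Sketch only: demonstrates that a -inf logit yields zero probability.
import numpy as np

def ban_eos_logits_processor(eos_token, input_ids, logits):
    logits[eos_token] = -float('inf')
    return logits

eos_id = 2                                    # hypothetical EOS token id
logits = np.ones(8, dtype=np.float32)         # hypothetical next-token logits
logits = ban_eos_logits_processor(eos_id, np.array([], dtype=np.intc), logits)

# Softmax over the masked logits: exp(-inf) == 0, so EOS gets probability 0.
probs = np.exp(logits - logits[np.isfinite(logits)].max())
probs /= probs.sum()
assert probs[eos_id] == 0.0                   # EOS can never be sampled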
@@ -72,7 +78,10 @@ class LlamaCppModel:
             mirostat_mode=int(state['mirostat_mode']),
             mirostat_tau=state['mirostat_tau'],
             mirostat_eta=state['mirostat_eta'],
-            stream=True
+            stream=True,
+            logits_processor=LogitsProcessorList([
+                partial(ban_eos_logits_processor, self.model.token_eos()),
+            ]) if state['ban_eos_token'] else None,
         )
 
         output = ""
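The same pattern works outside the webui, directly against llama-cpp-python: partial-apply the model's EOS token id and wrap the result in a LogitsProcessorList, exactly as the diff does. A hedged sketch under the assumption of a local model file (the path, prompt, and max_tokens value are placeholders, not from the commit; the API calls are the ones the diff itself uses):

# Sketch: ban the EOS token for a one-off completion.
from functools import partial

from llama_cpp import Llama, LogitsProcessorList

def ban_eos_logits_processor(eos_token, input_ids, logits):
    logits[eos_token] = -float('inf')
    return logits

model = Llama(model_path='./model.bin')       # hypothetical model path
completion = model.create_completion(
    prompt='Once upon a time',
    max_tokens=64,
    stream=False,
    # token_eos() returns the model's EOS token id; with the processor
    # installed, generation can only stop at the max_tokens limit.
    logits_processor=LogitsProcessorList([
        partial(ban_eos_logits_processor, model.token_eos()),
    ]),
)
print(completion['choices'][0]['text'])

Because the processor zeroes out the EOS probability, generation never terminates naturally, which is the behavior the webui's ban_eos_token option is meant to provide.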