mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-23 00:18:20 +01:00
Implement stopping string search in string space (#2847)
This commit is contained in:
parent
0f9088f730
commit
8bb3bb39b3
@ -350,4 +350,4 @@ The presets that are included by default are the result of a contest that receiv
|
|||||||
|
|
||||||
- Gradio dropdown menu refresh button, code for reloading the interface: https://github.com/AUTOMATIC1111/stable-diffusion-webui
|
- Gradio dropdown menu refresh button, code for reloading the interface: https://github.com/AUTOMATIC1111/stable-diffusion-webui
|
||||||
- Godlike preset: https://github.com/KoboldAI/KoboldAI-Client/wiki/Settings-Presets
|
- Godlike preset: https://github.com/KoboldAI/KoboldAI-Client/wiki/Settings-Presets
|
||||||
- Code for early stopping in chat mode, code for some of the sliders: https://github.com/PygmalionAI/gradio-ui/
|
- Code for some of the sliders: https://github.com/PygmalionAI/gradio-ui/
|
||||||
|
@ -9,33 +9,6 @@ import transformers
|
|||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
|
||||||
|
|
||||||
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
|
|
||||||
|
|
||||||
def __init__(self, sentinel_token_ids: list, starting_idx: int):
|
|
||||||
transformers.StoppingCriteria.__init__(self)
|
|
||||||
self.sentinel_token_ids = sentinel_token_ids
|
|
||||||
self.starting_idx = starting_idx
|
|
||||||
self.shortest = min([x.shape[-1] for x in sentinel_token_ids])
|
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
|
|
||||||
for sample in input_ids:
|
|
||||||
trimmed_sample = sample[self.starting_idx:]
|
|
||||||
trimmed_len = trimmed_sample.shape[-1]
|
|
||||||
if trimmed_len < self.shortest:
|
|
||||||
continue
|
|
||||||
|
|
||||||
for sentinel in self.sentinel_token_ids:
|
|
||||||
sentinel_len = sentinel.shape[-1]
|
|
||||||
if trimmed_len < sentinel_len:
|
|
||||||
continue
|
|
||||||
|
|
||||||
window = trimmed_sample[-sentinel_len:]
|
|
||||||
if torch.all(torch.eq(sentinel, window)):
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class Stream(transformers.StoppingCriteria):
|
class Stream(transformers.StoppingCriteria):
|
||||||
def __init__(self, callback_func=None):
|
def __init__(self, callback_func=None):
|
||||||
self.callback_func = callback_func
|
self.callback_func = callback_func
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import ast
|
|
||||||
import base64
|
import base64
|
||||||
import copy
|
import copy
|
||||||
import functools
|
import functools
|
||||||
@ -144,40 +143,10 @@ def get_stopping_strings(state):
|
|||||||
f"\n{state['name2']}:"
|
f"\n{state['name2']}:"
|
||||||
]
|
]
|
||||||
|
|
||||||
stopping_strings += ast.literal_eval(f"[{state['custom_stopping_strings']}]")
|
|
||||||
return stopping_strings
|
|
||||||
|
|
||||||
|
|
||||||
def extract_message_from_reply(reply, state):
|
|
||||||
next_character_found = False
|
|
||||||
stopping_strings = get_stopping_strings(state)
|
|
||||||
|
|
||||||
if state['stop_at_newline']:
|
if state['stop_at_newline']:
|
||||||
lines = reply.split('\n')
|
stopping_strings.append("\n")
|
||||||
reply = lines[0].strip()
|
|
||||||
if len(lines) > 1:
|
|
||||||
next_character_found = True
|
|
||||||
else:
|
|
||||||
for string in stopping_strings:
|
|
||||||
idx = reply.find(string)
|
|
||||||
if idx != -1:
|
|
||||||
reply = reply[:idx]
|
|
||||||
next_character_found = True
|
|
||||||
|
|
||||||
# If something like "\nYo" is generated just before "\nYou:"
|
return stopping_strings
|
||||||
# is completed, trim it
|
|
||||||
if not next_character_found:
|
|
||||||
for string in stopping_strings:
|
|
||||||
for j in range(len(string) - 1, 0, -1):
|
|
||||||
if reply[-j:] == string[:j]:
|
|
||||||
reply = reply[:-j]
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
|
|
||||||
break
|
|
||||||
|
|
||||||
return reply, next_character_found
|
|
||||||
|
|
||||||
|
|
||||||
def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loading_message=True):
|
def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loading_message=True):
|
||||||
@ -191,7 +160,6 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
|
|||||||
# Defining some variables
|
# Defining some variables
|
||||||
just_started = True
|
just_started = True
|
||||||
visible_text = None
|
visible_text = None
|
||||||
eos_token = '\n' if state['stop_at_newline'] else None
|
|
||||||
stopping_strings = get_stopping_strings(state)
|
stopping_strings = get_stopping_strings(state)
|
||||||
|
|
||||||
# Preparing the input
|
# Preparing the input
|
||||||
@ -231,11 +199,10 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
|
|||||||
cumulative_reply = ''
|
cumulative_reply = ''
|
||||||
for i in range(state['chat_generation_attempts']):
|
for i in range(state['chat_generation_attempts']):
|
||||||
reply = None
|
reply = None
|
||||||
for j, reply in enumerate(generate_reply(prompt + cumulative_reply, state, eos_token=eos_token, stopping_strings=stopping_strings, is_chat=True)):
|
for j, reply in enumerate(generate_reply(prompt + cumulative_reply, state, stopping_strings=stopping_strings, is_chat=True)):
|
||||||
reply = cumulative_reply + reply
|
reply = cumulative_reply + reply
|
||||||
|
|
||||||
# Extract the reply
|
# Extract the reply
|
||||||
reply, next_character_found = extract_message_from_reply(reply, state)
|
|
||||||
visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
|
visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
|
||||||
visible_reply = apply_extensions("output", visible_reply)
|
visible_reply = apply_extensions("output", visible_reply)
|
||||||
|
|
||||||
@ -262,9 +229,6 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
|
|||||||
if state['stream']:
|
if state['stream']:
|
||||||
yield output
|
yield output
|
||||||
|
|
||||||
if next_character_found:
|
|
||||||
break
|
|
||||||
|
|
||||||
if reply in [None, cumulative_reply]:
|
if reply in [None, cumulative_reply]:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
@ -281,7 +245,6 @@ def impersonate_wrapper(text, start_with, state):
|
|||||||
|
|
||||||
# Defining some variables
|
# Defining some variables
|
||||||
cumulative_reply = ''
|
cumulative_reply = ''
|
||||||
eos_token = '\n' if state['stop_at_newline'] else None
|
|
||||||
prompt = generate_chat_prompt('', state, impersonate=True)
|
prompt = generate_chat_prompt('', state, impersonate=True)
|
||||||
stopping_strings = get_stopping_strings(state)
|
stopping_strings = get_stopping_strings(state)
|
||||||
|
|
||||||
@ -289,16 +252,12 @@ def impersonate_wrapper(text, start_with, state):
|
|||||||
cumulative_reply = text
|
cumulative_reply = text
|
||||||
for i in range(state['chat_generation_attempts']):
|
for i in range(state['chat_generation_attempts']):
|
||||||
reply = None
|
reply = None
|
||||||
for reply in generate_reply(prompt + cumulative_reply, state, eos_token=eos_token, stopping_strings=stopping_strings, is_chat=True):
|
for reply in generate_reply(prompt + cumulative_reply, state, stopping_strings=stopping_strings, is_chat=True):
|
||||||
reply = cumulative_reply + reply
|
reply = cumulative_reply + reply
|
||||||
reply, next_character_found = extract_message_from_reply(reply, state)
|
|
||||||
yield reply.lstrip(' ')
|
yield reply.lstrip(' ')
|
||||||
if shared.stop_everything:
|
if shared.stop_everything:
|
||||||
return
|
return
|
||||||
|
|
||||||
if next_character_found:
|
|
||||||
break
|
|
||||||
|
|
||||||
if reply in [None, cumulative_reply]:
|
if reply in [None, cumulative_reply]:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
@ -9,8 +9,7 @@ import torch
|
|||||||
import transformers
|
import transformers
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules.callbacks import (Iteratorize, Stream,
|
from modules.callbacks import Iteratorize, Stream
|
||||||
_SentinelTokenStoppingCriteria)
|
|
||||||
from modules.extensions import apply_extensions
|
from modules.extensions import apply_extensions
|
||||||
from modules.html_generator import generate_4chan_html, generate_basic_html
|
from modules.html_generator import generate_4chan_html, generate_basic_html
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
@ -42,11 +41,6 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
|
|||||||
if not add_bos_token and input_ids[0][0] == shared.tokenizer.bos_token_id:
|
if not add_bos_token and input_ids[0][0] == shared.tokenizer.bos_token_id:
|
||||||
input_ids = input_ids[:, 1:]
|
input_ids = input_ids[:, 1:]
|
||||||
|
|
||||||
# Llama adds this extra token when the first character is '\n', and this
|
|
||||||
# compromises the stopping criteria, so we just remove it
|
|
||||||
if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
|
|
||||||
input_ids = input_ids[:, 1:]
|
|
||||||
|
|
||||||
# Handling truncation
|
# Handling truncation
|
||||||
if truncation_length is not None:
|
if truncation_length is not None:
|
||||||
input_ids = input_ids[:, -truncation_length:]
|
input_ids = input_ids[:, -truncation_length:]
|
||||||
@ -139,15 +133,43 @@ def stop_everything_event():
|
|||||||
shared.stop_everything = True
|
shared.stop_everything = True
|
||||||
|
|
||||||
|
|
||||||
def generate_reply_wrapper(question, state, eos_token=None, stopping_strings=None):
|
def generate_reply_wrapper(question, state, stopping_strings=None):
|
||||||
for reply in generate_reply(question, state, eos_token, stopping_strings, is_chat=False):
|
reply = question if not shared.is_seq2seq else ''
|
||||||
|
yield formatted_outputs(reply, shared.model_name)
|
||||||
|
|
||||||
|
for reply in generate_reply(question, state, stopping_strings, is_chat=False):
|
||||||
if not shared.is_seq2seq:
|
if not shared.is_seq2seq:
|
||||||
reply = question + reply
|
reply = question + reply
|
||||||
|
|
||||||
yield formatted_outputs(reply, shared.model_name)
|
yield formatted_outputs(reply, shared.model_name)
|
||||||
|
|
||||||
|
|
||||||
def _generate_reply(question, state, eos_token=None, stopping_strings=None, is_chat=False):
|
def apply_stopping_strings(reply, all_stop_strings):
|
||||||
|
stop_found = False
|
||||||
|
for string in all_stop_strings:
|
||||||
|
idx = reply.find(string)
|
||||||
|
if idx != -1:
|
||||||
|
reply = reply[:idx]
|
||||||
|
stop_found = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if not stop_found:
|
||||||
|
# If something like "\nYo" is generated just before "\nYou:"
|
||||||
|
# is completed, trim it
|
||||||
|
for string in all_stop_strings:
|
||||||
|
for j in range(len(string) - 1, 0, -1):
|
||||||
|
if reply[-j:] == string[:j]:
|
||||||
|
reply = reply[:-j]
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
break
|
||||||
|
|
||||||
|
return reply, stop_found
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_reply(question, state, stopping_strings=None, is_chat=False):
|
||||||
state = apply_extensions('state', state)
|
state = apply_extensions('state', state)
|
||||||
generate_func = apply_extensions('custom_generate_reply')
|
generate_func = apply_extensions('custom_generate_reply')
|
||||||
if generate_func is None:
|
if generate_func is None:
|
||||||
@ -168,29 +190,39 @@ def _generate_reply(question, state, eos_token=None, stopping_strings=None, is_c
|
|||||||
if not is_chat:
|
if not is_chat:
|
||||||
question = apply_extensions('input', question)
|
question = apply_extensions('input', question)
|
||||||
|
|
||||||
|
# Finding the stopping strings
|
||||||
|
all_stop_strings = []
|
||||||
|
for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")):
|
||||||
|
if type(st) is list and len(st) > 0:
|
||||||
|
all_stop_strings += st
|
||||||
|
|
||||||
if shared.args.verbose:
|
if shared.args.verbose:
|
||||||
print(f'\n\n{question}\n--------------------\n')
|
print(f'\n\n{question}\n--------------------\n')
|
||||||
|
|
||||||
shared.stop_everything = False
|
shared.stop_everything = False
|
||||||
clear_torch_cache()
|
clear_torch_cache()
|
||||||
seed = set_manual_seed(state['seed'])
|
seed = set_manual_seed(state['seed'])
|
||||||
is_stream = state['stream']
|
|
||||||
last_update = -1
|
last_update = -1
|
||||||
reply = ''
|
reply = ''
|
||||||
for reply in generate_func(question, original_question, seed, state, eos_token, stopping_strings, is_chat=is_chat):
|
is_stream = state['stream']
|
||||||
|
if len(all_stop_strings) > 0 and not state['stream']:
|
||||||
|
state['stream'] = True
|
||||||
|
|
||||||
|
for reply in generate_func(question, original_question, seed, state, stopping_strings, is_chat=is_chat):
|
||||||
|
reply, stop_found = apply_stopping_strings(reply, all_stop_strings)
|
||||||
if is_stream:
|
if is_stream:
|
||||||
cur_time = time.time()
|
cur_time = time.time()
|
||||||
if cur_time - last_update > 0.041666666666666664: # Limit streaming to 24 fps
|
if cur_time - last_update > 0.041666666666666664: # Limit streaming to 24 fps
|
||||||
last_update = cur_time
|
last_update = cur_time
|
||||||
yield reply
|
yield reply
|
||||||
else:
|
|
||||||
yield reply
|
|
||||||
|
|
||||||
if is_stream:
|
if stop_found:
|
||||||
|
break
|
||||||
|
|
||||||
yield reply
|
yield reply
|
||||||
|
|
||||||
|
|
||||||
def generate_reply_HF(question, original_question, seed, state, eos_token=None, stopping_strings=None, is_chat=False):
|
def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
|
||||||
generate_params = {}
|
generate_params = {}
|
||||||
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta']:
|
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta']:
|
||||||
generate_params[k] = state[k]
|
generate_params[k] = state[k]
|
||||||
@ -213,11 +245,6 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
|
|||||||
output = input_ids[0]
|
output = input_ids[0]
|
||||||
cuda = not any((shared.args.cpu, shared.args.deepspeed))
|
cuda = not any((shared.args.cpu, shared.args.deepspeed))
|
||||||
|
|
||||||
# Find the eos tokens
|
|
||||||
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
|
|
||||||
if eos_token is not None:
|
|
||||||
eos_token_ids.append(int(encode(eos_token)[0][-1]))
|
|
||||||
|
|
||||||
# Add the encoded tokens to generate_params
|
# Add the encoded tokens to generate_params
|
||||||
question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
|
question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
|
||||||
original_input_ids = input_ids
|
original_input_ids = input_ids
|
||||||
@ -225,17 +252,10 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
|
|||||||
if inputs_embeds is not None:
|
if inputs_embeds is not None:
|
||||||
generate_params.update({'inputs_embeds': inputs_embeds})
|
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||||
|
|
||||||
# Create the StoppingCriteriaList with the stopping strings (needs to be done after tokenizer extensions)
|
# Find the eos tokens
|
||||||
stopping_criteria_list = transformers.StoppingCriteriaList()
|
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
|
||||||
for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")):
|
|
||||||
if type(st) is list and len(st) > 0:
|
|
||||||
sentinel_token_ids = [encode(string, add_special_tokens=False) for string in st]
|
|
||||||
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=sentinel_token_ids, starting_idx=len(input_ids[0])))
|
|
||||||
break
|
|
||||||
|
|
||||||
# Update generate_params with the eos token and the stopping strings
|
|
||||||
generate_params['eos_token_id'] = eos_token_ids
|
generate_params['eos_token_id'] = eos_token_ids
|
||||||
generate_params['stopping_criteria'] = stopping_criteria_list
|
generate_params['stopping_criteria'] = transformers.StoppingCriteriaList()
|
||||||
|
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
try:
|
try:
|
||||||
@ -280,7 +300,7 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
|
|||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
def generate_reply_custom(question, original_question, seed, state, eos_token=None, stopping_strings=None, is_chat=False):
|
def generate_reply_custom(question, original_question, seed, state, stopping_strings=None, is_chat=False):
|
||||||
seed = set_manual_seed(state['seed'])
|
seed = set_manual_seed(state['seed'])
|
||||||
|
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
@ -312,7 +332,7 @@ def generate_reply_custom(question, original_question, seed, state, eos_token=No
|
|||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
def generate_reply_flexgen(question, original_question, seed, state, eos_token=None, stopping_strings=None, is_chat=False):
|
def generate_reply_flexgen(question, original_question, seed, state, stopping_strings=None, is_chat=False):
|
||||||
generate_params = {}
|
generate_params = {}
|
||||||
for k in ['max_new_tokens', 'do_sample', 'temperature']:
|
for k in ['max_new_tokens', 'do_sample', 'temperature']:
|
||||||
generate_params[k] = state[k]
|
generate_params[k] = state[k]
|
||||||
@ -326,8 +346,8 @@ def generate_reply_flexgen(question, original_question, seed, state, eos_token=N
|
|||||||
|
|
||||||
# Find the eos tokens
|
# Find the eos tokens
|
||||||
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
|
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
|
||||||
if eos_token is not None:
|
if not state['ban_eos_token']:
|
||||||
eos_token_ids.append(int(encode(eos_token)[0][-1]))
|
generate_params['stop'] = eos_token_ids[-1]
|
||||||
|
|
||||||
# Add the encoded tokens to generate_params
|
# Add the encoded tokens to generate_params
|
||||||
question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
|
question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
|
||||||
@ -336,9 +356,6 @@ def generate_reply_flexgen(question, original_question, seed, state, eos_token=N
|
|||||||
if inputs_embeds is not None:
|
if inputs_embeds is not None:
|
||||||
generate_params.update({'inputs_embeds': inputs_embeds})
|
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||||
|
|
||||||
# Update generate_params with the eos token and the stopping strings
|
|
||||||
generate_params['stop'] = eos_token_ids[-1]
|
|
||||||
|
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
try:
|
try:
|
||||||
if not is_chat:
|
if not is_chat:
|
||||||
|
Loading…
Reference in New Issue
Block a user