mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Stop generating at \nYou: in chat mode
This commit is contained in:
parent
54e77acac4
commit
3b8f0021cc
@ -153,6 +153,6 @@ Pull requests, suggestions, and issue reports are welcome.
|
||||
## Credits
|
||||
|
||||
- NovelAI and KoboldAI presets: https://github.com/KoboldAI/KoboldAI-Client/wiki/Settings-Presets
|
||||
- Pygmalion preset: https://github.com/PygmalionAI/gradio-ui/blob/master/src/gradio_ui.py
|
||||
- Pygmalion preset, code for early stopping in chat mode: https://github.com/PygmalionAI/gradio-ui/
|
||||
- Verbose preset: Anonymous 4chan user.
|
||||
- Gradio dropdown menu refresh button: https://github.com/AUTOMATIC1111/stable-diffusion-webui
|
||||
|
27
server.py
27
server.py
@ -14,6 +14,7 @@ import transformers
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from modules.html_generator import *
|
||||
from modules.ui import *
|
||||
from modules.stopping_criteria import _SentinelTokenStoppingCriteria
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
@ -135,12 +136,12 @@ def fix_galactica(s):
|
||||
s = s.replace(r'$$', r'$')
|
||||
return s
|
||||
|
||||
def encode(prompt, tokens):
|
||||
def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
|
||||
if args.cpu:
|
||||
input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens)
|
||||
input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens_to_generate, add_special_tokens=add_special_tokens)
|
||||
else:
|
||||
torch.cuda.empty_cache()
|
||||
input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens).cuda()
|
||||
input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens_to_generate, add_special_tokens=add_special_tokens).cuda()
|
||||
return input_ids
|
||||
|
||||
def decode(output_ids):
|
||||
@ -161,7 +162,7 @@ def formatted_outputs(reply, model_name):
|
||||
else:
|
||||
return reply
|
||||
|
||||
def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None):
|
||||
def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None, stopping_string=None):
|
||||
global model, tokenizer, model_name, loaded_preset, preset
|
||||
|
||||
if selected_model != model_name:
|
||||
@ -179,11 +180,22 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
|
||||
cuda = "" if args.cpu else ".cuda()"
|
||||
n = None if eos_token is None else tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
|
||||
input_ids = encode(question, tokens)
|
||||
# The stopping_criteria code below was copied from
|
||||
# https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
|
||||
if stopping_string is not None:
|
||||
t = encode(stopping_string, 0, add_special_tokens=False)
|
||||
stopping_criteria_list = transformers.StoppingCriteriaList([
|
||||
_SentinelTokenStoppingCriteria(
|
||||
sentinel_token_ids=t,
|
||||
starting_idx=len(input_ids[0]))
|
||||
])
|
||||
else:
|
||||
stopping_criteria_list = None
|
||||
|
||||
# Generate the entire reply at once
|
||||
if args.no_stream:
|
||||
t0 = time.time()
|
||||
output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}){cuda}")
|
||||
output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}")
|
||||
reply = decode(output[0])
|
||||
t1 = time.time()
|
||||
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output[0])-len(input_ids[0]))/(t1-t0):.2f} it/s)")
|
||||
@ -194,11 +206,10 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
|
||||
yield formatted_outputs(question, model_name)
|
||||
preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1')
|
||||
for i in tqdm(range(tokens)):
|
||||
output = eval(f"model.generate(input_ids, {preset}){cuda}")
|
||||
output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}")
|
||||
reply = decode(output[0])
|
||||
if eos_token is not None and reply[-1] == eos_token:
|
||||
break
|
||||
|
||||
yield formatted_outputs(reply, model_name)
|
||||
input_ids = output
|
||||
|
||||
@ -289,7 +300,7 @@ if args.chat or args.cai_chat:
|
||||
question = generate_chat_prompt(text, tokens, name1, name2, context, history_size)
|
||||
history.append(['', ''])
|
||||
eos_token = '\n' if check else None
|
||||
for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token):
|
||||
for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token, stopping_string=f"\n{name1}:"):
|
||||
next_character_found = False
|
||||
|
||||
previous_idx = [m.start() for m in re.finditer(f"(^|\n){name2}:", question)]
|
||||
|
Loading…
Reference in New Issue
Block a user