mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-29 19:09:32 +01:00
Stop generating at \nYou: in chat mode
This commit is contained in:
parent
54e77acac4
commit
3b8f0021cc
@ -153,6 +153,6 @@ Pull requests, suggestions, and issue reports are welcome.
|
|||||||
## Credits
|
## Credits
|
||||||
|
|
||||||
- NovelAI and KoboldAI presets: https://github.com/KoboldAI/KoboldAI-Client/wiki/Settings-Presets
|
- NovelAI and KoboldAI presets: https://github.com/KoboldAI/KoboldAI-Client/wiki/Settings-Presets
|
||||||
- Pygmalion preset: https://github.com/PygmalionAI/gradio-ui/blob/master/src/gradio_ui.py
|
- Pygmalion preset, code for early stopping in chat mode: https://github.com/PygmalionAI/gradio-ui/
|
||||||
- Verbose preset: Anonymous 4chan user.
|
- Verbose preset: Anonymous 4chan user.
|
||||||
- Gradio dropdown menu refresh button: https://github.com/AUTOMATIC1111/stable-diffusion-webui
|
- Gradio dropdown menu refresh button: https://github.com/AUTOMATIC1111/stable-diffusion-webui
|
||||||
|
27
server.py
27
server.py
@ -14,6 +14,7 @@ import transformers
|
|||||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||||
from modules.html_generator import *
|
from modules.html_generator import *
|
||||||
from modules.ui import *
|
from modules.ui import *
|
||||||
|
from modules.stopping_criteria import _SentinelTokenStoppingCriteria
|
||||||
|
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
|
|
||||||
@ -135,12 +136,12 @@ def fix_galactica(s):
|
|||||||
s = s.replace(r'$$', r'$')
|
s = s.replace(r'$$', r'$')
|
||||||
return s
|
return s
|
||||||
|
|
||||||
def encode(prompt, tokens):
|
def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
|
||||||
if args.cpu:
|
if args.cpu:
|
||||||
input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens)
|
input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens_to_generate, add_special_tokens=add_special_tokens)
|
||||||
else:
|
else:
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens).cuda()
|
input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens_to_generate, add_special_tokens=add_special_tokens).cuda()
|
||||||
return input_ids
|
return input_ids
|
||||||
|
|
||||||
def decode(output_ids):
|
def decode(output_ids):
|
||||||
@ -161,7 +162,7 @@ def formatted_outputs(reply, model_name):
|
|||||||
else:
|
else:
|
||||||
return reply
|
return reply
|
||||||
|
|
||||||
def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None):
|
def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None, stopping_string=None):
|
||||||
global model, tokenizer, model_name, loaded_preset, preset
|
global model, tokenizer, model_name, loaded_preset, preset
|
||||||
|
|
||||||
if selected_model != model_name:
|
if selected_model != model_name:
|
||||||
@ -179,11 +180,22 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
|
|||||||
cuda = "" if args.cpu else ".cuda()"
|
cuda = "" if args.cpu else ".cuda()"
|
||||||
n = None if eos_token is None else tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
|
n = None if eos_token is None else tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
|
||||||
input_ids = encode(question, tokens)
|
input_ids = encode(question, tokens)
|
||||||
|
# The stopping_criteria code below was copied from
|
||||||
|
# https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
|
||||||
|
if stopping_string is not None:
|
||||||
|
t = encode(stopping_string, 0, add_special_tokens=False)
|
||||||
|
stopping_criteria_list = transformers.StoppingCriteriaList([
|
||||||
|
_SentinelTokenStoppingCriteria(
|
||||||
|
sentinel_token_ids=t,
|
||||||
|
starting_idx=len(input_ids[0]))
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
stopping_criteria_list = None
|
||||||
|
|
||||||
# Generate the entire reply at once
|
# Generate the entire reply at once
|
||||||
if args.no_stream:
|
if args.no_stream:
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}){cuda}")
|
output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}")
|
||||||
reply = decode(output[0])
|
reply = decode(output[0])
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output[0])-len(input_ids[0]))/(t1-t0):.2f} it/s)")
|
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output[0])-len(input_ids[0]))/(t1-t0):.2f} it/s)")
|
||||||
@ -194,11 +206,10 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
|
|||||||
yield formatted_outputs(question, model_name)
|
yield formatted_outputs(question, model_name)
|
||||||
preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1')
|
preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1')
|
||||||
for i in tqdm(range(tokens)):
|
for i in tqdm(range(tokens)):
|
||||||
output = eval(f"model.generate(input_ids, {preset}){cuda}")
|
output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}")
|
||||||
reply = decode(output[0])
|
reply = decode(output[0])
|
||||||
if eos_token is not None and reply[-1] == eos_token:
|
if eos_token is not None and reply[-1] == eos_token:
|
||||||
break
|
break
|
||||||
|
|
||||||
yield formatted_outputs(reply, model_name)
|
yield formatted_outputs(reply, model_name)
|
||||||
input_ids = output
|
input_ids = output
|
||||||
|
|
||||||
@ -289,7 +300,7 @@ if args.chat or args.cai_chat:
|
|||||||
question = generate_chat_prompt(text, tokens, name1, name2, context, history_size)
|
question = generate_chat_prompt(text, tokens, name1, name2, context, history_size)
|
||||||
history.append(['', ''])
|
history.append(['', ''])
|
||||||
eos_token = '\n' if check else None
|
eos_token = '\n' if check else None
|
||||||
for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token):
|
for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token, stopping_string=f"\n{name1}:"):
|
||||||
next_character_found = False
|
next_character_found = False
|
||||||
|
|
||||||
previous_idx = [m.start() for m in re.finditer(f"(^|\n){name2}:", question)]
|
previous_idx = [m.start() for m in re.finditer(f"(^|\n){name2}:", question)]
|
||||||
|
Loading…
Reference in New Issue
Block a user