mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-23 21:18:00 +01:00
Add skip_special_tokens checkbox for Dolly model (#1218)
This commit is contained in:
parent
a9c7ef4159
commit
b937c9d8c2
@ -42,9 +42,10 @@ async def run(context):
|
||||
'early_stopping': False,
|
||||
'seed': -1,
|
||||
'add_bos_token': True,
|
||||
'truncation_length': 2048,
|
||||
'custom_stopping_strings': [],
|
||||
'ban_eos_token': False
|
||||
'truncation_length': 2048,
|
||||
'ban_eos_token': False,
|
||||
'skip_special_tokens': True,
|
||||
}
|
||||
payload = json.dumps([context, params])
|
||||
session = random_hash()
|
||||
|
@ -39,6 +39,7 @@ params = {
|
||||
'custom_stopping_strings': [],
|
||||
'truncation_length': 2048,
|
||||
'ban_eos_token': False,
|
||||
'skip_special_tokens': True,
|
||||
}
|
||||
|
||||
# Input prompt
|
||||
|
@ -61,6 +61,7 @@ class Handler(BaseHTTPRequestHandler):
|
||||
'custom_stopping_strings': body.get('custom_stopping_strings', []),
|
||||
'truncation_length': int(body.get('truncation_length', 2048)),
|
||||
'ban_eos_token': bool(body.get('ban_eos_token', False)),
|
||||
'skip_special_tokens': bool(body.get('skip_special_tokens', True)),
|
||||
}
|
||||
|
||||
generator = generate_reply(
|
||||
|
@ -4,6 +4,8 @@
|
||||
groupsize: 'None'
|
||||
pre_layer: 0
|
||||
mode: 'cai-chat'
|
||||
skip_special_tokens: true
|
||||
custom_stopping_strings: ''
|
||||
llama-[0-9]*b-4bit$:
|
||||
wbits: 4
|
||||
model_type: 'llama'
|
||||
@ -33,3 +35,10 @@ llama-[0-9]*b-4bit$:
|
||||
instruction_template: 'Alpaca'
|
||||
wbits: 4
|
||||
groupsize: 128
|
||||
.*(galactica|oasst):
|
||||
skip_special_tokens: false
|
||||
.*dolly-v[0-9]-[0-9]*b:
|
||||
mode: 'instruct'
|
||||
instruction_template: 'Alpaca'
|
||||
skip_special_tokens: false
|
||||
custom_stopping_strings: '"### End"'
|
||||
|
@ -41,6 +41,7 @@ settings = {
|
||||
'stop_at_newline': False,
|
||||
'add_bos_token': True,
|
||||
'ban_eos_token': False,
|
||||
'skip_special_tokens': True,
|
||||
'truncation_length': 2048,
|
||||
'truncation_length_min': 0,
|
||||
'truncation_length_max': 4096,
|
||||
|
@ -57,14 +57,13 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
|
||||
return input_ids.cuda()
|
||||
|
||||
|
||||
def decode(output_ids):
|
||||
# Open Assistant relies on special tokens like <|endoftext|>
|
||||
if re.match('.*(oasst|galactica)-*', shared.model_name.lower()):
|
||||
return shared.tokenizer.decode(output_ids, skip_special_tokens=False)
|
||||
else:
|
||||
def decode(output_ids, skip_special_tokens=True):
|
||||
if skip_special_tokens:
|
||||
reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
|
||||
reply = reply.replace(r'<|endoftext|>', '')
|
||||
return reply
|
||||
else:
|
||||
return shared.tokenizer.decode(output_ids, skip_special_tokens=False)
|
||||
|
||||
|
||||
def generate_softprompt_input_tensors(input_ids):
|
||||
@ -184,7 +183,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
output = input_ids[0]
|
||||
|
||||
if shared.args.verbose:
|
||||
print(f'\n\n{decode(input_ids[0])}\n--------------------\n')
|
||||
print(f'\n\n{decode(input_ids[0], state["skip_special_tokens"])}\n--------------------\n')
|
||||
|
||||
cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen))
|
||||
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
|
||||
@ -231,11 +230,12 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
output = shared.model.generate(**generate_params)[0]
|
||||
if cuda:
|
||||
output = output.cuda()
|
||||
|
||||
if shared.soft_prompt:
|
||||
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||
|
||||
new_tokens = len(output) - len(input_ids[0])
|
||||
reply = decode(output[-new_tokens:])
|
||||
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, 'output')
|
||||
|
||||
@ -256,18 +256,20 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
|
||||
if not shared.is_chat():
|
||||
yield formatted_outputs(original_question, shared.model_name)
|
||||
|
||||
with generate_with_streaming(**generate_params) as generator:
|
||||
for output in generator:
|
||||
if shared.soft_prompt:
|
||||
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||
|
||||
new_tokens = len(output) - len(input_ids[0])
|
||||
reply = decode(output[-new_tokens:])
|
||||
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, 'output')
|
||||
|
||||
if output[-1] in eos_token_ids:
|
||||
break
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
|
||||
@ -276,18 +278,19 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
clear_torch_cache()
|
||||
with torch.no_grad():
|
||||
output = shared.model.generate(**generate_params)[0]
|
||||
|
||||
if shared.soft_prompt:
|
||||
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||
|
||||
new_tokens = len(output) - len(original_input_ids[0])
|
||||
reply = decode(output[-new_tokens:])
|
||||
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, 'output')
|
||||
|
||||
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
|
||||
break
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
input_ids = np.reshape(output, (1, output.shape[0]))
|
||||
if shared.soft_prompt:
|
||||
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
||||
|
@ -25,7 +25,7 @@ def list_model_elements():
|
||||
|
||||
|
||||
def list_interface_input_elements(chat=False):
|
||||
elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings']
|
||||
elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens']
|
||||
if chat:
|
||||
elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template']
|
||||
elements += list_model_elements()
|
||||
|
@ -424,7 +424,9 @@ def create_settings_menus(default_preset):
|
||||
with gr.Group():
|
||||
with gr.Row():
|
||||
shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')
|
||||
shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='This forces the model to never end the generation prematurely.')
|
||||
shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='Forces the model to never end the generation prematurely.')
|
||||
|
||||
shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.')
|
||||
shared.gradio['truncation_length'] = gr.Slider(value=shared.settings['truncation_length'], minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=1, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.')
|
||||
shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=1, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='In addition to the defaults. Written between "" and separated by commas. For instance: "\\nYour Assistant:", "\\nThe assistant:"')
|
||||
|
||||
@ -766,7 +768,7 @@ def create_interface():
|
||||
chat.redraw_html, reload_inputs, shared.gradio['display'])
|
||||
|
||||
shared.gradio['instruction_template'].change(
|
||||
lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['instruction_template', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]).then(
|
||||
chat.load_character, [shared.gradio[k] for k in ['instruction_template', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]).then(
|
||||
chat.redraw_html, reload_inputs, shared.gradio['display'])
|
||||
|
||||
shared.gradio['upload_chat_history'].upload(
|
||||
@ -784,6 +786,7 @@ def create_interface():
|
||||
shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'mode']], shared.gradio['display'])
|
||||
|
||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
|
||||
shared.gradio['interface'].load(chat.load_character, [shared.gradio[k] for k in ['instruction_template', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
|
||||
shared.gradio['interface'].load(chat.load_default_history, [shared.gradio[k] for k in ['name1', 'name2']], None)
|
||||
shared.gradio['interface'].load(chat.redraw_html, reload_inputs, shared.gradio['display'], show_progress=True)
|
||||
|
||||
|
@ -12,6 +12,7 @@
|
||||
"stop_at_newline": false,
|
||||
"add_bos_token": true,
|
||||
"ban_eos_token": false,
|
||||
"skip_special_tokens": true,
|
||||
"truncation_length": 2048,
|
||||
"truncation_length_min": 0,
|
||||
"truncation_length_max": 4096,
|
||||
|
Loading…
Reference in New Issue
Block a user