diff --git a/api-example-stream.py b/api-example-stream.py index 7a87114b..a046fabd 100644 --- a/api-example-stream.py +++ b/api-example-stream.py @@ -42,9 +42,10 @@ async def run(context): 'early_stopping': False, 'seed': -1, 'add_bos_token': True, - 'truncation_length': 2048, 'custom_stopping_strings': [], - 'ban_eos_token': False + 'truncation_length': 2048, + 'ban_eos_token': False, + 'skip_special_tokens': True, } payload = json.dumps([context, params]) session = random_hash() diff --git a/api-example.py b/api-example.py index 60b95d75..5138eb88 100644 --- a/api-example.py +++ b/api-example.py @@ -39,6 +39,7 @@ params = { 'custom_stopping_strings': [], 'truncation_length': 2048, 'ban_eos_token': False, + 'skip_special_tokens': True, } # Input prompt diff --git a/extensions/api/script.py b/extensions/api/script.py index 458f60e7..17878c5d 100644 --- a/extensions/api/script.py +++ b/extensions/api/script.py @@ -61,6 +61,7 @@ class Handler(BaseHTTPRequestHandler): 'custom_stopping_strings': body.get('custom_stopping_strings', []), 'truncation_length': int(body.get('truncation_length', 2048)), 'ban_eos_token': bool(body.get('ban_eos_token', False)), + 'skip_special_tokens': bool(body.get('skip_special_tokens', True)), } generator = generate_reply( diff --git a/models/config.yaml b/models/config.yaml index 843e1a9e..7701ecc4 100644 --- a/models/config.yaml +++ b/models/config.yaml @@ -4,6 +4,8 @@ groupsize: 'None' pre_layer: 0 mode: 'cai-chat' + skip_special_tokens: true + custom_stopping_strings: '' llama-[0-9]*b-4bit$: wbits: 4 model_type: 'llama' @@ -33,3 +35,10 @@ llama-[0-9]*b-4bit$: instruction_template: 'Alpaca' wbits: 4 groupsize: 128 +.*(galactica|oasst): + skip_special_tokens: false +.*dolly-v[0-9]-[0-9]*b: + mode: 'instruct' + instruction_template: 'Alpaca' + skip_special_tokens: false + custom_stopping_strings: '"### End"' diff --git a/modules/shared.py b/modules/shared.py index 374942e9..4a113502 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -41,6 +41,7 @@ settings = { 'stop_at_newline': False, 'add_bos_token': True, 'ban_eos_token': False, + 'skip_special_tokens': True, 'truncation_length': 2048, 'truncation_length_min': 0, 'truncation_length_max': 4096, diff --git a/modules/text_generation.py b/modules/text_generation.py index 7bfaafe0..6313785c 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -57,14 +57,13 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt return input_ids.cuda() -def decode(output_ids): - # Open Assistant relies on special tokens like <|endoftext|> - if re.match('.*(oasst|galactica)-*', shared.model_name.lower()): - return shared.tokenizer.decode(output_ids, skip_special_tokens=False) - else: +def decode(output_ids, skip_special_tokens=True): + if skip_special_tokens: reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True) reply = reply.replace(r'<|endoftext|>', '') return reply + else: + return shared.tokenizer.decode(output_ids, skip_special_tokens=False) def generate_softprompt_input_tensors(input_ids): @@ -184,7 +183,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]): output = input_ids[0] if shared.args.verbose: - print(f'\n\n{decode(input_ids[0])}\n--------------------\n') + print(f'\n\n{decode(input_ids[0], state["skip_special_tokens"])}\n--------------------\n') cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen)) eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else [] @@ -231,11 +230,12 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]): output = shared.model.generate(**generate_params)[0] if cuda: output = output.cuda() + if shared.soft_prompt: output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) new_tokens = len(output) - len(input_ids[0]) - reply = decode(output[-new_tokens:]) + reply = decode(output[-new_tokens:], state['skip_special_tokens']) if not shared.is_chat(): reply = original_question + apply_extensions(reply, 'output') @@ -256,18 +256,20 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]): if not shared.is_chat(): yield formatted_outputs(original_question, shared.model_name) + with generate_with_streaming(**generate_params) as generator: for output in generator: if shared.soft_prompt: output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) new_tokens = len(output) - len(input_ids[0]) - reply = decode(output[-new_tokens:]) + reply = decode(output[-new_tokens:], state['skip_special_tokens']) if not shared.is_chat(): reply = original_question + apply_extensions(reply, 'output') if output[-1] in eos_token_ids: break + yield formatted_outputs(reply, shared.model_name) # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria' @@ -276,18 +278,19 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]): clear_torch_cache() with torch.no_grad(): output = shared.model.generate(**generate_params)[0] + if shared.soft_prompt: output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) new_tokens = len(output) - len(original_input_ids[0]) - reply = decode(output[-new_tokens:]) + reply = decode(output[-new_tokens:], state['skip_special_tokens']) if not shared.is_chat(): reply = original_question + apply_extensions(reply, 'output') if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)): break - yield formatted_outputs(reply, shared.model_name) + yield formatted_outputs(reply, shared.model_name) input_ids = np.reshape(output, (1, output.shape[0])) if shared.soft_prompt: inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids) diff --git a/modules/ui.py b/modules/ui.py index beeac8f5..df4c24f1 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -25,7 +25,7 @@ def list_model_elements(): def list_interface_input_elements(chat=False): - elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings'] + elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens'] if chat: elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template'] elements += list_model_elements() diff --git a/server.py b/server.py index 9bbc4237..84b53ce4 100644 --- a/server.py +++ b/server.py @@ -424,7 +424,9 @@ def create_settings_menus(default_preset): with gr.Group(): with gr.Row(): shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.') - shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='This forces the model to never end the generation prematurely.') + shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='Forces the model to never end the generation prematurely.') + + shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.') shared.gradio['truncation_length'] = gr.Slider(value=shared.settings['truncation_length'], minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=1, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.') shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=1, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='In addition to the defaults. Written between "" and separated by commas. For instance: "\\nYour Assistant:", "\\nThe assistant:"') @@ -766,7 +768,7 @@ def create_interface(): chat.redraw_html, reload_inputs, shared.gradio['display']) shared.gradio['instruction_template'].change( - lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['instruction_template', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]).then( + chat.load_character, [shared.gradio[k] for k in ['instruction_template', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]).then( chat.redraw_html, reload_inputs, shared.gradio['display']) shared.gradio['upload_chat_history'].upload( @@ -784,6 +786,7 @@ def create_interface(): shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'mode']], shared.gradio['display']) shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}") + shared.gradio['interface'].load(chat.load_character, [shared.gradio[k] for k in ['instruction_template', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]) shared.gradio['interface'].load(chat.load_default_history, [shared.gradio[k] for k in ['name1', 'name2']], None) shared.gradio['interface'].load(chat.redraw_html, reload_inputs, shared.gradio['display'], show_progress=True) diff --git a/settings-template.json b/settings-template.json index 244e0a40..bfcd7a91 100644 --- a/settings-template.json +++ b/settings-template.json @@ -12,6 +12,7 @@ "stop_at_newline": false, "add_bos_token": true, "ban_eos_token": false, + "skip_special_tokens": true, "truncation_length": 2048, "truncation_length_min": 0, "truncation_length_max": 4096,