diff --git a/modules/chat.py b/modules/chat.py
index 0ef61f8c..428e8430 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -18,35 +18,35 @@ from modules.text_generation import (encode, generate_reply,
                                      get_max_prompt_length)
 
 
-def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, **kwargs):
-    is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False
-    end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
+def generate_chat_prompt(user_input, state, **kwargs):
     impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
     _continue = kwargs['_continue'] if '_continue' in kwargs else False
     also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
-    rows = [f"{context.strip()}\n"]
+    is_instruct = state['mode'] == 'instruct'
+    rows = [f"{state['context'].strip()}\n"]
 
     # Finding the maximum prompt size
+    chat_prompt_size = state['chat_prompt_size']
     if shared.soft_prompt:
         chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
-    max_length = min(get_max_prompt_length(max_new_tokens), chat_prompt_size)
+    max_length = min(get_max_prompt_length(state), chat_prompt_size)
 
     if is_instruct:
-        prefix1 = f"{name1}\n"
-        prefix2 = f"{name2}\n"
+        prefix1 = f"{state['name1']}\n"
+        prefix2 = f"{state['name2']}\n"
     else:
-        prefix1 = f"{name1}: "
-        prefix2 = f"{name2}: "
+        prefix1 = f"{state['name1']}: "
+        prefix2 = f"{state['name2']}: "
 
     i = len(shared.history['internal']) - 1
-    while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
+    while i >= 0 and len(encode(''.join(rows))[0]) < max_length:
         if _continue and i == len(shared.history['internal']) - 1:
             rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}")
         else:
-            rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n")
+            rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{state['end_of_turn']}\n")
         string = shared.history['internal'][i][0]
         if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
-            rows.insert(1, f"{prefix1}{string.strip()}{end_of_turn}\n")
+            rows.insert(1, f"{prefix1}{string.strip()}{state['end_of_turn']}\n")
         i -= 1
 
     if impersonate:
@@ -58,13 +58,13 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
     # Adding the user message
     user_input = fix_newlines(user_input)
     if len(user_input) > 0:
-        rows.append(f"{prefix1}{user_input}{end_of_turn}\n")
+        rows.append(f"{prefix1}{user_input}{state['end_of_turn']}\n")
 
     # Adding the Character prefix
     rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix"))
 
     limit = 3
-    while len(rows) > limit and len(encode(''.join(rows), max_new_tokens)[0]) >= max_length:
+    while len(rows) > limit and len(encode(''.join(rows))[0]) >= max_length:
         rows.pop(1)
 
     prompt = ''.join(rows)
@@ -139,15 +139,10 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
     text = apply_extensions(text, "input")
 
     # Generating the prompt
-    kwargs = {
-        'end_of_turn': state['end_of_turn'],
-        'is_instruct': state['mode'] == 'instruct',
-        '_continue': _continue
-    }
     if custom_generate_chat_prompt is None:
-        prompt = generate_chat_prompt(text, state['max_new_tokens'], state['name1'], state['name2'], state['context'], state['chat_prompt_size'], **kwargs)
+        prompt = generate_chat_prompt(text, state)
     else:
-        prompt = custom_generate_chat_prompt(text, state['max_new_tokens'], state['name1'], state['name2'], state['context'], state['chat_prompt_size'], **kwargs)
+        prompt = custom_generate_chat_prompt(text, state)
 
     # Yield *Is typing...*
     if not any((regenerate, _continue)):
@@ -197,7 +192,7 @@ def impersonate_wrapper(text, state):
     # Defining some variables
     cumulative_reply = ''
     eos_token = '\n' if state['stop_at_newline'] else None
-    prompt = generate_chat_prompt(text, state['max_new_tokens'], state['name1'], state['name2'], state['context'], state['chat_prompt_size'], end_of_turn=state['end_of_turn'], impersonate=True)
+    prompt = generate_chat_prompt(text, state, impersonate=True)
     stopping_strings = get_stopping_strings(state)
 
     # Yield *Is typing...*
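Note on the chat.py change: the individual prompt arguments (name1, name2, context, chat_prompt_size, end_of_turn) are folded into a single `state` dictionary, while the history is still walked from newest to oldest until the token budget runs out. The sketch below is a minimal, self-contained illustration of that loop assuming a dict-based `state`; the whitespace token counter, the `build_prompt` name, and the sample history are stand-ins invented for this example, not code from the repository.

```python
# Minimal sketch of the history-walking logic with a dict-based `state`,
# in the spirit of the new generate_chat_prompt(user_input, state) call.

def build_prompt(user_input, history, state, max_length=48):
    def count_tokens(text):
        return len(text.split())  # crude stand-in for len(encode(text)[0])

    rows = [f"{state['context'].strip()}\n"]
    prefix1 = f"{state['name1']}: "
    prefix2 = f"{state['name2']}: "

    # Walk the history from newest to oldest until the budget is exhausted.
    i = len(history) - 1
    while i >= 0 and count_tokens(''.join(rows)) < max_length:
        rows.insert(1, f"{prefix2}{history[i][1].strip()}{state['end_of_turn']}\n")
        rows.insert(1, f"{prefix1}{history[i][0].strip()}{state['end_of_turn']}\n")
        i -= 1

    rows.append(f"{prefix1}{user_input}{state['end_of_turn']}\n")
    rows.append(prefix2.strip())  # the character prefix the model will complete

    # If still over budget, drop the oldest exchanges first (index 1, right after the context).
    while len(rows) > 3 and count_tokens(''.join(rows)) >= max_length:
        rows.pop(1)

    return ''.join(rows)


if __name__ == '__main__':
    state = {'name1': 'You', 'name2': 'Assistant',
             'context': 'A chat between a user and an assistant.', 'end_of_turn': ''}
    history = [['Hi', 'Hello!'], ['How are you?', 'Doing well.']]
    print(build_prompt('Tell me a joke', history, state))
```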
diff --git a/modules/models.py b/modules/models.py
index 7ec93df8..3467f4f2 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -189,7 +189,6 @@ def load_model(model_name):
             pass
         else:
             tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))
-            tokenizer.truncation_side = 'left'
 
     print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
     return model, tokenizer
diff --git a/modules/shared.py b/modules/shared.py
index b278e2fd..a47a13f1 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -37,6 +37,10 @@ settings = {
     'custom_stopping_strings': '',
     'stop_at_newline': False,
     'add_bos_token': True,
+    'ban_eos_token': False,
+    'truncation_length': 2048,
+    'truncation_length_min': 0,
+    'truncation_length_max': 4096,
     'chat_prompt_size': 2048,
     'chat_prompt_size_min': 0,
     'chat_prompt_size_max': 2048,
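Note on the shared.py defaults: the new keys only establish defaults (ban_eos_token off, a 2048-token truncation window with a 0-4096 slider range); in the web UI they are expected to be overridable by a user settings file. The snippet below is a hypothetical sketch of that defaults-plus-overrides pattern; the `load_settings` helper, the `DEFAULTS` subset, and the `settings.json` path are assumptions for illustration, not this project's loader.

```python
# Hypothetical sketch of layering a user settings file over shipped defaults.
import json
from pathlib import Path

DEFAULTS = {
    'add_bos_token': True,
    'ban_eos_token': False,
    'truncation_length': 2048,
    'truncation_length_min': 0,
    'truncation_length_max': 4096,
}

def load_settings(path='settings.json'):
    settings = dict(DEFAULTS)                       # start from the shipped defaults
    p = Path(path)
    if p.exists():
        settings.update(json.loads(p.read_text()))  # user values take precedence
    return settings

print(load_settings())
```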
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 7d62dddb..15b88264 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -15,20 +15,20 @@ from modules.html_generator import generate_4chan_html, generate_basic_html
 from modules.models import clear_torch_cache, local_rank
 
 
-def get_max_prompt_length(tokens):
-    max_length = 2048 - tokens
+def get_max_prompt_length(state):
+    max_length = state['truncation_length'] - state['max_new_tokens']
     if shared.soft_prompt:
         max_length -= shared.soft_prompt_tensor.shape[1]
     return max_length
 
 
-def encode(prompt, tokens_to_generate=0, add_special_tokens=True, add_bos_token=True):
+def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
     if any((shared.is_RWKV, shared.is_llamacpp)):
         input_ids = shared.tokenizer.encode(str(prompt))
         input_ids = np.array(input_ids).reshape(1, len(input_ids))
         return input_ids
     else:
-        input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
+        input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens)
 
         # This is a hack for making replies more creative.
         if not add_bos_token and input_ids[0][0] == shared.tokenizer.bos_token_id:
@@ -39,17 +39,21 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True, add_bos_token=
         if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
             input_ids = input_ids[:, 1:]
 
-    if shared.args.cpu:
-        return input_ids
-    elif shared.args.flexgen:
-        return input_ids.numpy()
-    elif shared.args.deepspeed:
-        return input_ids.to(device=local_rank)
-    elif torch.has_mps:
-        device = torch.device('mps')
-        return input_ids.to(device)
-    else:
-        return input_ids.cuda()
+    # Handling truncation
+    if truncation_length is not None:
+        input_ids = input_ids[:, -truncation_length:]
+
+    if any((shared.is_RWKV, shared.is_llamacpp, shared.args.cpu)):
+        return input_ids
+    elif shared.args.flexgen:
+        return input_ids.numpy()
+    elif shared.args.deepspeed:
+        return input_ids.to(device=local_rank)
+    elif torch.has_mps:
+        device = torch.device('mps')
+        return input_ids.to(device)
+    else:
+        return input_ids.cuda()
 
 
 def decode(output_ids):
@@ -129,12 +133,14 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
     original_question = question
     if not shared.is_chat():
         question = apply_extensions(question, 'input')
-        if shared.args.verbose:
-            print(f'\n\n{question}\n--------------------\n')
 
     # These models are not part of Hugging Face, so we handle them
     # separately and terminate the function call earlier
     if any((shared.is_RWKV, shared.is_llamacpp)):
+
+        if shared.args.verbose:
+            print(f'\n\n{question}\n--------------------\n')
+
         for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
             generate_params[k] = state[k]
         generate_params['token_count'] = state['max_new_tokens']
@@ -166,10 +172,13 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
             print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
             return
 
-    input_ids = encode(question, state['max_new_tokens'], add_bos_token=state['add_bos_token'])
+    input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
     original_input_ids = input_ids
     output = input_ids[0]
+    if shared.args.verbose:
+        print(f'\n\n{decode(input_ids[0])}\n--------------------\n')
+
     cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen))
     eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
     if eos_token is not None:
@@ -179,7 +188,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
     stopping_criteria_list = transformers.StoppingCriteriaList()
     for st in [stopping_strings, state['custom_stopping_strings']]:
         if type(st) is list and len(st) > 0:
-            sentinel_token_ids = [encode(string, 0, add_special_tokens=False) for string in st]
+            sentinel_token_ids = [encode(string, add_special_tokens=False) for string in st]
             stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=sentinel_token_ids, starting_idx=len(input_ids[0])))
             break
 
@@ -188,6 +197,8 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
             generate_params[k] = state[k]
         generate_params['eos_token_id'] = eos_token_ids
         generate_params['stopping_criteria'] = stopping_criteria_list
+        if state['ban_eos_token']:
+            generate_params['suppress_tokens'] = [shared.tokenizer.eos_token_id]
     else:
         for k in ['max_new_tokens', 'do_sample', 'temperature']:
             generate_params[k] = state[k]
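Note on the text_generation.py change: instead of letting the tokenizer truncate during encoding, the prompt budget is computed as `truncation_length - max_new_tokens` and only the rightmost tokens of the encoded prompt are kept (`input_ids[:, -truncation_length:]`). The toy example below mirrors that arithmetic with a whitespace "tokenizer"; it is an illustration only (it also skips the soft-prompt adjustment), not the module's real `encode`.

```python
# Toy illustration of the new truncation scheme: the prompt budget is
# truncation_length - max_new_tokens, and only the rightmost tokens are kept.

def get_max_prompt_length(state):
    return state['truncation_length'] - state['max_new_tokens']

def encode(prompt, truncation_length=None):
    input_ids = prompt.split()                      # pretend each word is one token
    if truncation_length is not None:
        input_ids = input_ids[-truncation_length:]  # drop the leftmost tokens
    return input_ids

state = {'truncation_length': 8, 'max_new_tokens': 3}
tokens = encode('one two three four five six seven eight nine ten',
                truncation_length=get_max_prompt_length(state))
print(tokens)  # ['six', 'seven', 'eight', 'nine', 'ten'] -- the last 5 tokens survive
```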
diff --git a/server.py b/server.py
index aa8180ac..36665a46 100644
--- a/server.py
+++ b/server.py
@@ -263,7 +263,7 @@ def create_settings_menus(default_preset):
             with gr.Box():
                 gr.Markdown('Contrastive search')
                 shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
-            with gr.Box():
+
                 gr.Markdown('Beam search (uses a lot of VRAM)')
                 with gr.Row():
                     with gr.Column():
@@ -272,10 +272,11 @@
                         shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
                         shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
 
-            with gr.Row():
-                shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')
-
-            with gr.Row():
+            with gr.Group():
+                with gr.Row():
+                    shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')
+                    shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos token', info='This forces the model to never end the generation prematurely.')
+                shared.gradio['truncation_length'] = gr.Slider(value=shared.settings['truncation_length'], minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=1, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.')
                 shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=1, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='In addition to the defaults. Written between "" and separated by commas. For instance: "\\nYour Assistant:", "\\nThe assistant:"')
 
     with gr.Accordion('Soft prompt', open=False):
@@ -361,7 +362,7 @@ title = 'Text generation web UI'
 
 
 def list_interface_input_elements(chat=False):
-    elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'custom_stopping_strings']
+    elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings']
     if chat:
         elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode']
     return elements
diff --git a/settings-template.json b/settings-template.json
index e38293df..896c91f8 100644
--- a/settings-template.json
+++ b/settings-template.json
@@ -11,6 +11,10 @@
     "custom_stopping_strings": "",
     "stop_at_newline": false,
     "add_bos_token": true,
+    "ban_eos_token": false,
+    "truncation_length": 2048,
+    "truncation_length_min": 0,
+    "truncation_length_max": 4096,
     "chat_prompt_size": 2048,
     "chat_prompt_size_min": 0,
     "chat_prompt_size_max": 2048,
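Note on banning the EOS token: when the new checkbox is enabled, `suppress_tokens=[shared.tokenizer.eos_token_id]` is added to the generation parameters, which keeps the EOS token from ever being selected, so generation runs until max_new_tokens or a stopping string is hit. The sketch below reproduces that effect on a plain list of logits to show what suppressing a token amounts to; it is illustrative only and does not use the transformers API.

```python
# Minimal sketch of what "banning the eos token" amounts to: before sampling,
# the banned logits are pushed to -inf so those tokens can never be chosen.
import math

def mask_banned_tokens(logits, banned_ids):
    # Any id in banned_ids gets probability zero after a softmax/argmax.
    return [-math.inf if i in banned_ids else score
            for i, score in enumerate(logits)]

eos_token_id = 2
logits = [0.1, 1.5, 3.0, 0.7]  # token 2 (the EOS id here) would otherwise win
print(mask_banned_tokens(logits, {eos_token_id}))  # [0.1, 1.5, -inf, 0.7]
```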