From d06c34dea5dc3abed1bc2df96b9616ec50d1040d Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 31 Jul 2023 13:34:41 -0300 Subject: [PATCH] Add an extension that makes chat replies longer (#3363) --- extensions/long_replies/script.py | 143 ++++++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) create mode 100644 extensions/long_replies/script.py diff --git a/extensions/long_replies/script.py b/extensions/long_replies/script.py new file mode 100644 index 00000000..035e8c9e --- /dev/null +++ b/extensions/long_replies/script.py @@ -0,0 +1,143 @@ +import torch +from modules import chat, shared +from modules.text_generation import ( + decode, + encode, + generate_reply, +) +from transformers import LogitsProcessor +import gradio as gr + +params = { + "display_name": "Long replies", + "is_tab": False, + "min_length": 120, +} + +initial_size = 0 + +class MyLogits(LogitsProcessor): + """ + Manipulates the probabilities for the next token before it gets sampled. + Used in the logits_processor_modifier function below. + """ + def __init__(self): + self.newline_id = shared.tokenizer.encode('\n')[-1] + pass + + def __call__(self, input_ids, scores): + if input_ids.shape[-1] - initial_size < params["min_length"]: + scores[...,self.newline_id] = -1000 + # scores[...,shared.tokenizer.eos_token_id] = -1000 + + # probs = torch.softmax(scores, dim=-1, dtype=torch.float) + # probs[0] /= probs[0].sum() + # scores = torch.log(probs / (1 - probs)) + return scores + +def history_modifier(history): + """ + Modifies the chat history. + Only used in chat mode. + """ + return history + +def state_modifier(state): + """ + Modifies the state variable, which is a dictionary containing the input + values in the UI like sliders and checkboxes. + """ + return state + +def chat_input_modifier(text, visible_text, state): + """ + Modifies the user input string in chat mode (visible_text). + You can also modify the internal representation of the user + input (text) to change how it will appear in the prompt. + """ + return text, visible_text + +def input_modifier(string, state): + """ + In default/notebook modes, modifies the whole prompt. + + In chat mode, it is the same as chat_input_modifier but only applied + to "text", here called "string", and not to "visible_text". + """ + return string + +def bot_prefix_modifier(string, state): + """ + Modifies the prefix for the next bot reply in chat mode. + By default, the prefix will be something like "Bot Name:". + """ + return string + +def tokenizer_modifier(state, prompt, input_ids, input_embeds): + """ + Modifies the input ids and embeds. + Used by the multimodal extension to put image embeddings in the prompt. + Only used by loaders that use the transformers library for sampling. + """ + + global initial_size + initial_size = input_ids.shape[-1] + + return prompt, input_ids, input_embeds + +def logits_processor_modifier(processor_list, input_ids): + """ + Adds logits processors to the list, allowing you to access and modify + the next token probabilities. + Only used by loaders that use the transformers library for sampling. + """ + processor_list.append(MyLogits()) + return processor_list + +def output_modifier(string, state): + """ + Modifies the LLM output before it gets presented. + + In chat mode, the modified version goes into history['visible'], + and the original version goes into history['internal']. + """ + return string + +def custom_generate_chat_prompt(user_input, state, **kwargs): + """ + Replaces the function that generates the prompt from the chat history. + Only used in chat mode. + """ + result = chat.generate_chat_prompt(user_input, state, **kwargs) + return result + +def custom_css(): + """ + Returns a CSS string that gets appended to the CSS for the webui. + """ + return '' + +def custom_js(): + """ + Returns a javascript string that gets appended to the javascript + for the webui. + """ + return '' + +def setup(): + """ + Gets executed only once, when the extension is imported. + """ + pass + +def ui(): + """ + Gets executed when the UI is drawn. Custom gradio elements and + their corresponding event handlers should be defined here. + + To learn about gradio components, check out the docs: + https://gradio.app/docs/ + """ + + min_length = gr.Slider(0, 800, step=10, value=params['min_length'], label='Minimum reply length') + min_length.change(lambda x: params.update({'min_length': x}), min_length, None)