From 3f3e42e26cb6b8e56af7eada4f441d846b5f5969 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 6 Apr 2023 01:22:15 -0300 Subject: [PATCH] Refactor several function calls and the API --- api-example-stream.py | 18 ++-------- api-example.py | 21 +++--------- extensions/api/script.py | 35 ++++++++++--------- extensions/send_pictures/script.py | 5 ++- modules/api.py | 38 +++++++++++++++++++++ modules/chat.py | 43 +++++++++++++----------- modules/text_generation.py | 54 +++++++++++++----------------- server.py | 51 ++++++++++++++++++---------- 8 files changed, 147 insertions(+), 118 deletions(-) create mode 100644 modules/api.py diff --git a/api-example-stream.py b/api-example-stream.py index e87fb74c..32eefc7e 100644 --- a/api-example-stream.py +++ b/api-example-stream.py @@ -36,6 +36,7 @@ async def run(context): 'early_stopping': False, 'seed': -1, } + payload = json.dumps([context, params]) session = random_hash() async with websockets.connect(f"ws://{server}:7860/queue/join") as websocket: @@ -54,22 +55,7 @@ async def run(context): "session_hash": session, "fn_index": 12, "data": [ - context, - params['max_new_tokens'], - params['do_sample'], - params['temperature'], - params['top_p'], - params['typical_p'], - params['repetition_penalty'], - params['encoder_repetition_penalty'], - params['top_k'], - params['min_length'], - params['no_repeat_ngram_size'], - params['num_beams'], - params['penalty_alpha'], - params['length_penalty'], - params['early_stopping'], - params['seed'], + payload ] })) case "process_starts": diff --git a/api-example.py b/api-example.py index 0349824b..10be0a88 100644 --- a/api-example.py +++ b/api-example.py @@ -10,6 +10,8 @@ Optionally, you can also add the --share flag to generate a public gradio URL, allowing you to use the API remotely. 
''' +import json + import requests # Server address @@ -38,24 +40,11 @@ params = { # Input prompt prompt = "What I would like to say is the following: " +payload = json.dumps([prompt, params]) + response = requests.post(f"http://{server}:7860/run/textgen", json={ "data": [ - prompt, - params['max_new_tokens'], - params['do_sample'], - params['temperature'], - params['top_p'], - params['typical_p'], - params['repetition_penalty'], - params['encoder_repetition_penalty'], - params['top_k'], - params['min_length'], - params['no_repeat_ngram_size'], - params['num_beams'], - params['penalty_alpha'], - params['length_penalty'], - params['early_stopping'], - params['seed'], + payload ] }).json() diff --git a/extensions/api/script.py b/extensions/api/script.py index 20562cc6..6726d61d 100644 --- a/extensions/api/script.py +++ b/extensions/api/script.py @@ -40,24 +40,27 @@ class Handler(BaseHTTPRequestHandler): prompt_lines.pop(0) prompt = '\n'.join(prompt_lines) + generate_params = { + 'max_new_tokens': int(body.get('max_length', 200)), + 'do_sample': bool(body.get('do_sample', True)), + 'temperature': float(body.get('temperature', 0.5)), + 'top_p': float(body.get('top_p', 1)), + 'typical_p': float(body.get('typical', 1)), + 'repetition_penalty': float(body.get('rep_pen', 1.1)), + 'encoder_repetition_penalty': 1, + 'top_k': int(body.get('top_k', 0)), + 'min_length': int(body.get('min_length', 0)), + 'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size',0)), + 'num_beams': int(body.get('num_beams',1)), + 'penalty_alpha': float(body.get('penalty_alpha', 0)), + 'length_penalty': float(body.get('length_penalty', 1)), + 'early_stopping': bool(body.get('early_stopping', False)), + 'seed': int(body.get('seed', -1)), + } generator = generate_reply( - question = prompt, - max_new_tokens = int(body.get('max_length', 200)), - do_sample=bool(body.get('do_sample', True)), - temperature=float(body.get('temperature', 0.5)), - top_p=float(body.get('top_p', 1)), - typical_p=float(body.get('typical', 1)), - repetition_penalty=float(body.get('rep_pen', 1.1)), - encoder_repetition_penalty=1, - top_k=int(body.get('top_k', 0)), - min_length=int(body.get('min_length', 0)), - no_repeat_ngram_size=int(body.get('no_repeat_ngram_size',0)), - num_beams=int(body.get('num_beams',1)), - penalty_alpha=float(body.get('penalty_alpha', 0)), - length_penalty=float(body.get('length_penalty', 1)), - early_stopping=bool(body.get('early_stopping', False)), - seed=int(body.get('seed', -1)), + prompt, + generate_params, stopping_strings=body.get('stopping_strings', []), ) diff --git a/extensions/send_pictures/script.py b/extensions/send_pictures/script.py index b6305bdc..d2401dff 100644 --- a/extensions/send_pictures/script.py +++ b/extensions/send_pictures/script.py @@ -2,12 +2,11 @@ import base64 from io import BytesIO import gradio as gr -import modules.chat as chat -import modules.shared as shared import torch -from PIL import Image from transformers import BlipForConditionalGeneration, BlipProcessor +from modules import chat, shared + # If 'state' is True, will hijack the next chat generation with # custom input text given by 'value' in the format [text, visible_text] input_hijack = { diff --git a/modules/api.py b/modules/api.py new file mode 100644 index 00000000..26249fd7 --- /dev/null +++ b/modules/api.py @@ -0,0 +1,38 @@ +import json + +import gradio as gr + +from modules import shared +from modules.text_generation import generate_reply + + +def generate_reply_wrapper(string): + generate_params = { + 'do_sample': True, + 
'temperature': 1, + 'top_p': 1, + 'typical_p': 1, + 'repetition_penalty': 1, + 'encoder_repetition_penalty': 1, + 'top_k': 50, + 'num_beams': 1, + 'penalty_alpha': 0, + 'min_length': 0, + 'length_penalty': 1, + 'no_repeat_ngram_size': 0, + 'early_stopping': False, + } + params = json.loads(string) + for k in params[1]: + generate_params[k] = params[1][k] + for i in generate_reply(params[0], generate_params): + yield i + +def create_apis(): + t1 = gr.Textbox(visible=False) + t2 = gr.Textbox(visible=False) + dummy = gr.Button(visible=False) + + input_params = [t1] + output_params = [t2] + [shared.gradio[k] for k in ['markdown', 'html']] + dummy.click(generate_reply_wrapper, input_params, output_params, api_name='textgen') diff --git a/modules/chat.py b/modules/chat.py index 1140b5fa..f4ddf427 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -18,7 +18,12 @@ from modules.text_generation import (encode, generate_reply, get_max_prompt_length) -def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, is_instruct, end_of_turn="", impersonate=False, also_return_rows=False): +def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, **kwargs): + is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False + end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else '' + impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False + also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False + user_input = fix_newlines(user_input) rows = [f"{context.strip()}\n"] @@ -91,9 +96,9 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline): reply = fix_newlines(reply) return reply, next_character_found -def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, regenerate=False, mode="cai-chat", end_of_turn=""): +def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False): just_started = True - eos_token = '\n' if stop_at_newline else None + eos_token = '\n' if generate_state['stop_at_newline'] else None name1_original = name1 if 'pygmalion' in shared.model_name.lower(): name1 = "You" @@ -112,11 +117,11 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical visible_text = text text = apply_extensions(text, "input") - is_instruct = mode == 'instruct' + kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'} if custom_generate_chat_prompt is None: - prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, is_instruct, end_of_turn=end_of_turn) + prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs) else: - prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, is_instruct, end_of_turn=end_of_turn) + prompt = custom_generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs) # Yield *Is typing...* if not regenerate: @@ -124,13 +129,13 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical # Generate cumulative_reply = '' - for i in 
range(chat_generation_attempts):
+    for i in range(generate_state['chat_generation_attempts']):
         reply = None
-        for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
+        for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
             reply = cumulative_reply + reply
 
             # Extracting the reply
-            reply, next_character_found = extract_message_from_reply(reply, name1, name2, stop_at_newline)
+            reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
             visible_reply = re.sub("(<USER>|<user>|{{user}})", name1_original, reply)
             visible_reply = apply_extensions(visible_reply, "output")
 
@@ -155,23 +160,23 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
         yield shared.history['visible']
 
-def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, mode="cai-chat", end_of_turn=""):
-    eos_token = '\n' if stop_at_newline else None
+def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
+    eos_token = '\n' if generate_state['stop_at_newline'] else None
 
     if 'pygmalion' in shared.model_name.lower():
         name1 = "You"
 
-    prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True, end_of_turn=end_of_turn)
+    prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], impersonate=True, end_of_turn=end_of_turn)
 
     # Yield *Is typing...*
     yield shared.processing_message
 
     cumulative_reply = ''
-    for i in range(chat_generation_attempts):
+    for i in range(generate_state['chat_generation_attempts']):
         reply = None
-        for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
+        for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
             reply = cumulative_reply + reply
-            reply, next_character_found = extract_message_from_reply(reply, name1, name2, stop_at_newline)
+            reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
             yield reply
             if next_character_found:
                 break
@@ -181,11 +186,11 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
     yield reply
 
-def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, 
seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, mode="cai-chat", end_of_turn=""): - for history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts, regenerate=False, mode=mode, end_of_turn=end_of_turn): +def cai_chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): + for history in chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False): yield chat_html_wrapper(history, name1, name2, mode) -def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, mode="cai-chat", end_of_turn=""): +def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0: yield chat_html_wrapper(shared.history['visible'], name1, name2, mode) else: @@ -193,7 +198,7 @@ def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typi last_internal = shared.history['internal'].pop() # Yield '*Is typing...*' yield chat_html_wrapper(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, mode) - for history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts, regenerate=True, mode=mode, end_of_turn=end_of_turn): + for history in chatbot_wrapper(last_internal[0], generate_state, name1, name2, context, mode, end_of_turn, regenerate=True): shared.history['visible'][-1] = [last_visible[0], history[-1][1]] yield chat_html_wrapper(shared.history['visible'], name1, name2, mode) diff --git a/modules/text_generation.py b/modules/text_generation.py index 406c4548..93f0789b 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -102,10 +102,11 @@ def set_manual_seed(seed): def stop_everything_event(): shared.stop_everything = True -def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=None, stopping_strings=[]): +def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]): clear_torch_cache() - set_manual_seed(seed) + set_manual_seed(generate_state['seed']) shared.stop_everything = False + generate_params = {} t0 = time.time() original_question = question @@ -117,9 +118,12 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi # These models are not part of Hugging Face, so we handle them # separately and terminate the function call earlier if any((shared.is_RWKV, shared.is_llamacpp)): + for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']: + generate_params[k] = 
generate_state[k] + generate_params["token_count"] = generate_state["max_new_tokens"] try: if shared.args.no_stream: - reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty) + reply = shared.model.generate(context=question, **generate_params) output = original_question+reply if not shared.is_chat(): reply = original_question + apply_extensions(reply, "output") @@ -130,7 +134,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi # RWKV has proper streaming, which is very nice. # No need to generate 8 tokens at a time. - for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty): + for reply in shared.model.generate_with_streaming(context=question, **generate_params): output = original_question+reply if not shared.is_chat(): reply = original_question + apply_extensions(reply, "output") @@ -145,7 +149,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})") return - input_ids = encode(question, max_new_tokens) + input_ids = encode(question, generate_state['max_new_tokens']) original_input_ids = input_ids output = input_ids[0] @@ -158,33 +162,21 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi t = [encode(string, 0, add_special_tokens=False) for string in stopping_strings] stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0]))) - generate_params = {} + generate_params["max_new_tokens"] = generate_state['max_new_tokens'] if not shared.args.flexgen: - generate_params.update({ - "max_new_tokens": max_new_tokens, - "eos_token_id": eos_token_ids, - "stopping_criteria": stopping_criteria_list, - "do_sample": do_sample, - "temperature": temperature, - "top_p": top_p, - "typical_p": typical_p, - "repetition_penalty": repetition_penalty, - "encoder_repetition_penalty": encoder_repetition_penalty, - "top_k": top_k, - "min_length": min_length if shared.args.no_stream else 0, - "no_repeat_ngram_size": no_repeat_ngram_size, - "num_beams": num_beams, - "penalty_alpha": penalty_alpha, - "length_penalty": length_penalty, - "early_stopping": early_stopping, - }) + for k in ["do_sample", "temperature", "top_p", "typical_p", "repetition_penalty", "encoder_repetition_penalty", "top_k", "min_length", "no_repeat_ngram_size", "num_beams", "penalty_alpha", "length_penalty", "early_stopping"]: + generate_params[k] = generate_state[k] + generate_params["eos_token_id"] = eos_token_ids + generate_params["stopping_criteria"] = stopping_criteria_list + if shared.args.no_stream: + generate_params["min_length"] = 0 else: - generate_params.update({ - "max_new_tokens": max_new_tokens if shared.args.no_stream else 8, - "do_sample": do_sample, - "temperature": temperature, - "stop": eos_token_ids[-1], - }) + for k in ["do_sample", "temperature"]: + generate_params[k] = generate_state[k] + generate_params["stop"] = generate_state["eos_token_ids"][-1] + if not shared.args.no_stream: + generate_params["max_new_tokens"] = 8 + if shared.args.no_cache: generate_params.update({"use_cache": False}) if shared.args.deepspeed: @@ -244,7 +236,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, 
top_p, typi # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria' else: - for i in range(max_new_tokens//8+1): + for i in range(generate_state['max_new_tokens']//8+1): clear_torch_cache() with torch.no_grad(): output = shared.model.generate(**generate_params)[0] diff --git a/server.py b/server.py index 8bcb6502..f00e412c 100644 --- a/server.py +++ b/server.py @@ -15,7 +15,7 @@ import gradio as gr from PIL import Image import modules.extensions as extensions_module -from modules import chat, shared, training, ui +from modules import chat, shared, training, ui, api from modules.html_generator import chat_html_wrapper from modules.LoRA import add_lora_to_model from modules.models import load_model, load_soft_prompt @@ -85,7 +85,7 @@ def load_lora_wrapper(selected_lora): add_lora_to_model(selected_lora) return selected_lora -def load_preset_values(preset_menu, return_dict=False): +def load_preset_values(preset_menu, state, return_dict=False): generate_params = { 'do_sample': True, 'temperature': 1, @@ -107,13 +107,13 @@ def load_preset_values(preset_menu, return_dict=False): i = i.rstrip(',').strip().split('=') if len(i) == 2 and i[0].strip() != 'tokens': generate_params[i[0].strip()] = eval(i[1].strip()) - generate_params['temperature'] = min(1.99, generate_params['temperature']) if return_dict: return generate_params else: - return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['encoder_repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping'] + state.update(generate_params) + return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']] def upload_soft_prompt(file): with zipfile.ZipFile(io.BytesIO(file)) as zf: @@ -170,7 +170,10 @@ def create_prompt_menus(): shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False) def create_settings_menus(default_preset): - generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True) + generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True) + for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']: + generate_params[k] = shared.settings[k] + shared.gradio['generate_state'] = gr.State(generate_params) with gr.Row(): with gr.Column(): @@ -221,17 +224,16 @@ def create_settings_menus(default_preset): with gr.Row(): shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip']) - shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True) - shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]) - 
shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu']], show_progress=True) - shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True) - shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']]) + shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True) + shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'generate_state']], [shared.gradio[k] for k in ['generate_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]) + shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True) + shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True) + shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu']) def set_interface_arguments(interface_mode, extensions, bool_active): modes = ["default", "notebook", "chat", "cai_chat"] cmd_list = vars(shared.args) bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes] - #int_list = [k for k in cmd_list if type(k) is int] shared.args.extensions = extensions for k in modes[1:]: @@ -372,11 +374,11 @@ def create_interface(): shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size']) with gr.Column(): shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)') - shared.gradio['check'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?') + shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?') create_settings_menus(default_preset) - shared.input_params = [shared.gradio[k] for k in ['Chat input', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts', 'Chat mode', 'end_of_turn']] + shared.input_params = [shared.gradio[k] for k in ['Chat input', 'generate_state', 'name1', 'name2', 'context', 'Chat mode', 'end_of_turn']] def set_chat_input(textbox): return textbox, "" @@ -456,9 +458,9 @@ def create_interface(): with gr.Tab("Parameters", elem_id="parameters"): create_settings_menus(default_preset) - shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 
'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed']] + shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']] output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']] - gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen')) + gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None) shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}") @@ -489,9 +491,9 @@ def create_interface(): with gr.Tab("Parameters", elem_id="parameters"): create_settings_menus(default_preset) - shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed']] + shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']] output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']] - gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen')) + gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream)) shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None) @@ -524,6 +526,21 @@ def create_interface(): if shared.args.extensions is not None: extensions_module.create_extensions_block() + def change_dict_value(d, key, value): + d[key] = value + return d + + for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size_slider', 'chat_generation_attempts']: + if k not in shared.gradio: + continue + if type(shared.gradio[k]) in [gr.Checkbox, gr.Number]: + shared.gradio[k].change(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state']) + else: + shared.gradio[k].release(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state']) + + if not shared.is_chat(): + api.create_apis() + # Authentication auth = None if shared.args.gradio_auth_path is not None:
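
For reference, a minimal blocking-API client distilled from the updated api-example.py above. After this refactor the entire request travels as a single JSON-encoded string, json.dumps([prompt, params]), sent as the only element of "data", instead of the prompt plus one positional field per parameter. The /run/textgen endpoint is registered by api.create_apis(), which server.py calls only in the non-chat interface modes. The server address and port come from the example script; the prompt and parameter values below are illustrative. Note that generate_reply_wrapper in modules/api.py fills in defaults for most sampling options but not for max_new_tokens or seed, so those two keys should always be sent.

import json

import requests

server = "127.0.0.1"  # default local Gradio server, as in api-example.py

# Illustrative prompt and parameters; any key understood by generate_reply_wrapper
# may be included. max_new_tokens and seed have no server-side defaults.
prompt = "What I would like to say is the following: "
params = {
    'max_new_tokens': 200,
    'do_sample': True,
    'temperature': 0.7,
    'top_p': 0.9,
    'seed': -1,
}

# The whole request is one JSON string: [prompt, params]
payload = json.dumps([prompt, params])

response = requests.post(f"http://{server}:7860/run/textgen", json={
    "data": [payload]
}).json()

# The endpoint has three outputs (text, markdown, html); the generated text is the first.
print(response["data"][0])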
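
Similarly, for code inside the project, generate_reply now takes the prompt followed by a single dictionary of generation parameters (plus the optional eos_token and stopping_strings keyword arguments), which is how extensions/api/script.py and modules/api.py call it after this patch. The sketch below illustrates the new calling convention; the model folder name is a placeholder, the parameter values are illustrative (they mirror the defaults built in extensions/api/script.py), and it assumes a regular Hugging Face model in the default non-chat mode, where generate_reply yields (text, markdown, html) tuples matching the three output components wired up in modules/api.py.

from modules import shared
from modules.models import load_model
from modules.text_generation import generate_reply

# Placeholder model folder name; load_model is the same helper server.py uses.
shared.model_name = 'opt-1.3b'
shared.model, shared.tokenizer = load_model(shared.model_name)

# generate_reply reads every one of these keys, so all of them must be present.
generate_params = {
    'max_new_tokens': 200,
    'do_sample': True,
    'temperature': 0.5,
    'top_p': 1,
    'typical_p': 1,
    'repetition_penalty': 1.1,
    'encoder_repetition_penalty': 1,
    'top_k': 0,
    'min_length': 0,
    'no_repeat_ngram_size': 0,
    'num_beams': 1,
    'penalty_alpha': 0,
    'length_penalty': 1,
    'early_stopping': False,
    'seed': -1,
}

prompt = 'Common sense questions and answers\n\nQuestion: What is the capital of France?\nAnswer:'

reply = ''
for output in generate_reply(prompt, generate_params, stopping_strings=['\nQuestion:']):
    # Each yield is a (text, markdown, html) tuple in non-chat mode; keep the text.
    reply = output[0]

print(reply)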