From 474dc7355aad11103a952ea936c2fed1476ff0c1 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 13 Jun 2023 20:34:35 -0300 Subject: [PATCH] Allow API requests to use parameter presets --- api-examples/api-example-chat-stream.py | 6 ++- api-examples/api-example-chat.py | 6 ++- api-examples/api-example-stream.py | 7 ++- api-examples/api-example.py | 7 ++- extensions/api/util.py | 8 ++++ modules/presets.py | 55 +++++++++++++++++++++++ modules/utils.py | 5 +++ server.py | 60 +++---------------------- 8 files changed, 96 insertions(+), 58 deletions(-) create mode 100644 modules/presets.py diff --git a/api-examples/api-example-chat-stream.py b/api-examples/api-example-chat-stream.py index 7314563f..1b80a33d 100644 --- a/api-examples/api-example-chat-stream.py +++ b/api-examples/api-example-chat-stream.py @@ -19,6 +19,7 @@ async def run(user_input, history): # Note: the selected defaults change from time to time. request = { 'user_input': user_input, + 'max_new_tokens': 250, 'history': history, 'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct' 'character': 'Example', @@ -32,7 +33,9 @@ async def run(user_input, history): 'chat_generation_attempts': 1, 'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>', - 'max_new_tokens': 250, + # Generation params. If 'preset' is set to different than 'None', the values + # in presets/preset-name.yaml are used instead of the individual numbers. + 'preset': 'None', 'do_sample': True, 'temperature': 0.7, 'top_p': 0.1, @@ -52,6 +55,7 @@ async def run(user_input, history): 'mirostat_mode': 0, 'mirostat_tau': 5, 'mirostat_eta': 0.1, + 'seed': -1, 'add_bos_token': True, 'truncation_length': 2048, diff --git a/api-examples/api-example-chat.py b/api-examples/api-example-chat.py index 8ea6ed1e..fb2847d4 100644 --- a/api-examples/api-example-chat.py +++ b/api-examples/api-example-chat.py @@ -13,6 +13,7 @@ URI = f'http://{HOST}/api/v1/chat' def run(user_input, history): request = { 'user_input': user_input, + 'max_new_tokens': 250, 'history': history, 'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct' 'character': 'Example', @@ -26,7 +27,9 @@ def run(user_input, history): 'chat_generation_attempts': 1, 'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>', - 'max_new_tokens': 250, + # Generation params. If 'preset' is set to different than 'None', the values + # in presets/preset-name.yaml are used instead of the individual numbers. + 'preset': 'None', 'do_sample': True, 'temperature': 0.7, 'top_p': 0.1, @@ -46,6 +49,7 @@ def run(user_input, history): 'mirostat_mode': 0, 'mirostat_tau': 5, 'mirostat_eta': 0.1, + 'seed': -1, 'add_bos_token': True, 'truncation_length': 2048, diff --git a/api-examples/api-example-stream.py b/api-examples/api-example-stream.py index 43cde299..64d4e05f 100644 --- a/api-examples/api-example-stream.py +++ b/api-examples/api-example-stream.py @@ -20,8 +20,12 @@ async def run(context): request = { 'prompt': context, 'max_new_tokens': 250, + + # Generation params. If 'preset' is set to different than 'None', the values + # in presets/preset-name.yaml are used instead of the individual numbers. + 'preset': 'None', 'do_sample': True, - 'temperature': 1.3, + 'temperature': 0.7, 'top_p': 0.1, 'typical_p': 1, 'epsilon_cutoff': 0, # In units of 1e-4 @@ -39,6 +43,7 @@ async def run(context): 'mirostat_mode': 0, 'mirostat_tau': 5, 'mirostat_eta': 0.1, + 'seed': -1, 'add_bos_token': True, 'truncation_length': 2048, diff --git a/api-examples/api-example.py b/api-examples/api-example.py index 4e4a7f66..54a4f371 100644 --- a/api-examples/api-example.py +++ b/api-examples/api-example.py @@ -12,8 +12,12 @@ def run(prompt): request = { 'prompt': prompt, 'max_new_tokens': 250, + + # Generation params. If 'preset' is set to different than 'None', the values + # in presets/preset-name.yaml are used instead of the individual numbers. + 'preset': 'None', 'do_sample': True, - 'temperature': 1.3, + 'temperature': 0.7, 'top_p': 0.1, 'typical_p': 1, 'epsilon_cutoff': 0, # In units of 1e-4 @@ -31,6 +35,7 @@ def run(prompt): 'mirostat_mode': 0, 'mirostat_tau': 5, 'mirostat_eta': 0.1, + 'seed': -1, 'add_bos_token': True, 'truncation_length': 2048, diff --git a/extensions/api/util.py b/extensions/api/util.py index 59a015de..2174dd81 100644 --- a/extensions/api/util.py +++ b/extensions/api/util.py @@ -5,6 +5,7 @@ from typing import Callable, Optional from modules import shared from modules.chat import load_character_memoized +from modules.presets import load_preset_memoized def build_parameters(body, chat=False): @@ -40,6 +41,13 @@ def build_parameters(body, chat=False): 'stopping_strings': body.get('stopping_strings', []), } + preset_name = body.get('preset', 'None') + if preset_name not in ['None', None, '']: + print(preset_name) + preset = load_preset_memoized(preset_name) + print(preset) + generate_params.update(preset) + if chat: character = body.get('character') instruction_template = body.get('instruction_template') diff --git a/modules/presets.py b/modules/presets.py new file mode 100644 index 00000000..b954d38a --- /dev/null +++ b/modules/presets.py @@ -0,0 +1,55 @@ +import functools +from pathlib import Path + +import yaml + + +def load_preset(name): + generate_params = { + 'do_sample': True, + 'temperature': 1, + 'top_p': 1, + 'typical_p': 1, + 'epsilon_cutoff': 0, + 'eta_cutoff': 0, + 'tfs': 1, + 'top_a': 0, + 'repetition_penalty': 1, + 'encoder_repetition_penalty': 1, + 'top_k': 0, + 'num_beams': 1, + 'penalty_alpha': 0, + 'min_length': 0, + 'length_penalty': 1, + 'no_repeat_ngram_size': 0, + 'early_stopping': False, + 'mirostat_mode': 0, + 'mirostat_tau': 5.0, + 'mirostat_eta': 0.1, + } + + with open(Path(f'presets/{name}.yaml'), 'r') as infile: + preset = yaml.safe_load(infile) + + for k in preset: + generate_params[k] = preset[k] + + generate_params['temperature'] = min(1.99, generate_params['temperature']) + return generate_params + + +@functools.cache +def load_preset_memoized(name): + return load_preset(name) + + +def load_preset_for_ui(name, state): + generate_params = load_preset(name) + state.update(generate_params) + return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']] + + +def generate_preset_yaml(state): + data = {k: state[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']} + return yaml.dump(data, sort_keys=False) + diff --git a/modules/utils.py b/modules/utils.py index 2fe72525..1535ecdc 100644 --- a/modules/utils.py +++ b/modules/utils.py @@ -1,5 +1,6 @@ import os import re +from datetime import datetime from pathlib import Path from modules import shared @@ -41,6 +42,10 @@ def delete_file(fname): logger.info(f'Deleted {fname}.') +def current_time(): + return f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}" + + def atoi(text): return int(text) if text.isdigit() else text.lower() diff --git a/server.py b/server.py index 398bedf8..e3149a19 100644 --- a/server.py +++ b/server.py @@ -33,7 +33,6 @@ import re import sys import time import traceback -from datetime import datetime from functools import partial from pathlib import Path from threading import Lock @@ -44,7 +43,7 @@ import yaml from PIL import Image import modules.extensions as extensions_module -from modules import chat, shared, training, ui, utils +from modules import chat, presets, shared, training, ui, utils from modules.extensions import apply_extensions from modules.github import clone_or_pull_repository from modules.html_generator import chat_html_wrapper @@ -80,53 +79,6 @@ def load_lora_wrapper(selected_loras): yield ("Successfuly applied the LoRAs") -def load_preset_values(preset_menu, state, return_dict=False): - generate_params = { - 'do_sample': True, - 'temperature': 1, - 'top_p': 1, - 'typical_p': 1, - 'epsilon_cutoff': 0, - 'eta_cutoff': 0, - 'tfs': 1, - 'top_a': 0, - 'repetition_penalty': 1, - 'encoder_repetition_penalty': 1, - 'top_k': 0, - 'num_beams': 1, - 'penalty_alpha': 0, - 'min_length': 0, - 'length_penalty': 1, - 'no_repeat_ngram_size': 0, - 'early_stopping': False, - 'mirostat_mode': 0, - 'mirostat_tau': 5.0, - 'mirostat_eta': 0.1, - } - - with open(Path(f'presets/{preset_menu}.yaml'), 'r') as infile: - preset = yaml.safe_load(infile) - - for k in preset: - generate_params[k] = preset[k] - - generate_params['temperature'] = min(1.99, generate_params['temperature']) - if return_dict: - return generate_params - else: - state.update(generate_params) - return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']] - - -def generate_preset_yaml(state): - data = {k: state[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']} - return yaml.dump(data, sort_keys=False) - - -def current_time(): - return f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}" - - def load_prompt(fname): if fname in ['None', '']: return '' @@ -251,7 +203,7 @@ def get_model_specific_settings(model): return model_settings -def load_model_specific_settings(model, state, return_dict=False): +def load_model_specific_settings(model, state): model_settings = get_model_specific_settings(model) for k in model_settings: if k in state: @@ -448,7 +400,7 @@ def create_chat_settings_menus(): def create_settings_menus(default_preset): - generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True) + generate_params = presets.load_preset(default_preset) with gr.Row(): with gr.Column(): with gr.Row(): @@ -515,7 +467,7 @@ def create_settings_menus(default_preset): shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.') shared.gradio['stream'] = gr.Checkbox(value=not shared.args.no_stream, label='Activate text streaming') - shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]) + shared.gradio['preset_menu'].change(presets.load_preset_for_ui, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]) def create_file_saving_menus(): @@ -578,7 +530,7 @@ def create_file_saving_event_handlers(): shared.gradio['save_preset'].click( ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( - generate_preset_yaml, shared.gradio['interface_state'], shared.gradio['save_contents']).then( + presets.generate_preset_yaml, shared.gradio['interface_state'], shared.gradio['save_contents']).then( lambda: 'presets/', None, shared.gradio['save_root']).then( lambda: 'My Preset.yaml', None, shared.gradio['save_filename']).then( lambda: gr.update(visible=True), None, shared.gradio['file_saver']) @@ -1043,7 +995,7 @@ def create_interface(): shared.gradio['save_prompt'].click( lambda x: x, shared.gradio['textbox'], shared.gradio['save_contents']).then( lambda: 'prompts/', None, shared.gradio['save_root']).then( - lambda: current_time() + '.txt', None, shared.gradio['save_filename']).then( + lambda: utils.current_time() + '.txt', None, shared.gradio['save_filename']).then( lambda: gr.update(visible=True), None, shared.gradio['file_saver']) shared.gradio['delete_prompt'].click(