Merge branch 'main' into Honkware-main

oobabooga 2023-07-04 18:50:07 -07:00
commit 84d6c93d0d
35 changed files with 821 additions and 453 deletions

View File

@ -193,6 +193,7 @@ Optionally, you can use the following command-line flags:
| `-h`, `--help` | Show this help message and exit. | | `-h`, `--help` | Show this help message and exit. |
| `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. | | `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. |
| `--chat` | Launch the web UI in chat mode. | | `--chat` | Launch the web UI in chat mode. |
| `--multi-user` | Multi-user mode. Chat histories are not saved or automatically loaded. WARNING: this is highly experimental. |
| `--character CHARACTER` | The name of the character to load in chat mode by default. | | `--character CHARACTER` | The name of the character to load in chat mode by default. |
| `--model MODEL` | Name of the model to load by default. | | `--model MODEL` | Name of the model to load by default. |
| `--lora LORA [LORA ...]` | The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces. | | `--lora LORA [LORA ...]` | The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces. |
@ -268,6 +269,7 @@ Optionally, you can use the following command-line flags:
|`--gpu-split` | Comma-separated list of VRAM (in GB) to use per GPU device for model layers, e.g. `20,7,7` | |`--gpu-split` | Comma-separated list of VRAM (in GB) to use per GPU device for model layers, e.g. `20,7,7` |
|`--max_seq_len MAX_SEQ_LEN` | Maximum sequence length. | |`--max_seq_len MAX_SEQ_LEN` | Maximum sequence length. |
|`--compress_pos_emb COMPRESS_POS_EMB` | Positional embeddings compression factor. Should typically be set to max_seq_len / 2048. | |`--compress_pos_emb COMPRESS_POS_EMB` | Positional embeddings compression factor. Should typically be set to max_seq_len / 2048. |
|`--alpha_value ALPHA_VALUE` | Positional embeddings alpha factor for NTK RoPE scaling. Same as above. Use either this or compress_pos_emb, not both. |
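A typical launch command combining the new flags might look like the line below (the model name is a placeholder, not something taken from this commit):

```
python server.py --chat --multi-user --model <your-model> --max_seq_len 4096 --alpha_value 2
```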
#### GPTQ-for-LLaMa #### GPTQ-for-LLaMa

View File

@ -44,6 +44,7 @@ async def run(user_input, history):
'tfs': 1, 'tfs': 1,
'top_a': 0, 'top_a': 0,
'repetition_penalty': 1.18, 'repetition_penalty': 1.18,
'repetition_penalty_range': 0,
'top_k': 40, 'top_k': 40,
'min_length': 0, 'min_length': 0,
'no_repeat_ngram_size': 0, 'no_repeat_ngram_size': 0,

View File

@ -38,6 +38,7 @@ def run(user_input, history):
'tfs': 1, 'tfs': 1,
'top_a': 0, 'top_a': 0,
'repetition_penalty': 1.18, 'repetition_penalty': 1.18,
'repetition_penalty_range': 0,
'top_k': 40, 'top_k': 40,
'min_length': 0, 'min_length': 0,
'no_repeat_ngram_size': 0, 'no_repeat_ngram_size': 0,

View File

@ -33,6 +33,7 @@ async def run(context):
'tfs': 1, 'tfs': 1,
'top_a': 0, 'top_a': 0,
'repetition_penalty': 1.18, 'repetition_penalty': 1.18,
'repetition_penalty_range': 0,
'top_k': 40, 'top_k': 40,
'min_length': 0, 'min_length': 0,
'no_repeat_ngram_size': 0, 'no_repeat_ngram_size': 0,

View File

@ -25,6 +25,7 @@ def run(prompt):
'tfs': 1, 'tfs': 1,
'top_a': 0, 'top_a': 0,
'repetition_penalty': 1.18, 'repetition_penalty': 1.18,
'repetition_penalty_range': 0,
'top_k': 40, 'top_k': 40,
'min_length': 0, 'min_length': 0,
'no_repeat_ngram_size': 0, 'no_repeat_ngram_size': 0,

View File

@ -6,6 +6,12 @@
padding-top: 2.5rem padding-top: 2.5rem
} }
.small-button {
max-width: 171px;
height: 39.594px;
align-self: end;
}
.refresh-button { .refresh-button {
max-width: 4.4em; max-width: 4.4em;
min-width: 2.2em !important; min-width: 2.2em !important;
@ -50,7 +56,7 @@ ol li p, ul li p {
display: inline-block; display: inline-block;
} }
#main, #parameters, #chat-settings, #interface-mode, #lora, #training-tab, #model-tab { #main, #parameters, #chat-settings, #lora, #training-tab, #model-tab, #session-tab {
border: 0; border: 0;
} }
@ -121,10 +127,6 @@ button {
font-size: 14px !important; font-size: 14px !important;
} }
.small-button {
max-width: 171px;
}
.file-saver { .file-saver {
position: fixed !important; position: fixed !important;
top: 50%; top: 50%;
@ -149,3 +151,7 @@ button {
.checkboxgroup-table div { .checkboxgroup-table div {
display: grid !important; display: grid !important;
} }
.markdown ul ol {
font-size: 100% !important;
}

View File

@ -38,11 +38,11 @@ script.py may define the special functions and variables below.
| `def ui()` | Creates custom gradio elements when the UI is launched. | | `def ui()` | Creates custom gradio elements when the UI is launched. |
| `def custom_css()` | Returns custom CSS as a string. It is applied whenever the web UI is loaded. | | `def custom_css()` | Returns custom CSS as a string. It is applied whenever the web UI is loaded. |
| `def custom_js()` | Same as above but for javascript. | | `def custom_js()` | Same as above but for javascript. |
| `def input_modifier(string)` | Modifies the input string before it enters the model. In chat mode, it is applied to the user message. Otherwise, it is applied to the entire prompt. | | `def input_modifier(string, state)` | Modifies the input string before it enters the model. In chat mode, it is applied to the user message. Otherwise, it is applied to the entire prompt. |
| `def output_modifier(string)` | Modifies the output string before it is presented in the UI. In chat mode, it is applied to the bot's reply. Otherwise, it is applied to the entire output. | | `def output_modifier(string, state)` | Modifies the output string before it is presented in the UI. In chat mode, it is applied to the bot's reply. Otherwise, it is applied to the entire output. |
| `def bot_prefix_modifier(string, state)` | Applied in chat mode to the prefix for the bot's reply. |
| `def state_modifier(state)` | Modifies the dictionary containing the UI input parameters before it is used by the text generation functions. | | `def state_modifier(state)` | Modifies the dictionary containing the UI input parameters before it is used by the text generation functions. |
| `def history_modifier(history)` | Modifies the chat history before the text generation in chat mode begins. | | `def history_modifier(history)` | Modifies the chat history before the text generation in chat mode begins. |
| `def bot_prefix_modifier(string)` | Applied in chat mode to the prefix for the bot's reply. |
| `def custom_generate_reply(...)` | Overrides the main text generation function. | | `def custom_generate_reply(...)` | Overrides the main text generation function. |
| `def custom_generate_chat_prompt(...)` | Overrides the prompt generator in chat mode. | | `def custom_generate_chat_prompt(...)` | Overrides the prompt generator in chat mode. |
| `def tokenizer_modifier(state, prompt, input_ids, input_embeds)` | Modifies the `input_ids`/`input_embeds` fed to the model. Should return `prompt`, `input_ids`, `input_embeds`. See the `multimodal` extension for an example. | | `def tokenizer_modifier(state, prompt, input_ids, input_embeds)` | Modifies the `input_ids`/`input_embeds` fed to the model. Should return `prompt`, `input_ids`, `input_embeds`. See the `multimodal` extension for an example. |
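As an illustration of the updated signatures in the table above, a minimal extension `script.py` could now look like the following sketch (the extension name, `params` contents, and function bodies are hypothetical):

```python
# extensions/example/script.py -- minimal sketch of the two-argument extension API
params = {'suffix': ' [processed]'}

def input_modifier(string, state):
    # "state" carries the UI input parameters (mode, character, etc.) alongside the text
    return string

def bot_prefix_modifier(string, state):
    # applied in chat mode to the prefix of the bot's reply
    return string

def output_modifier(string, state):
    # applied to the reply before it is shown in the UI
    return string + params['suffix']
```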

View File

@ -23,6 +23,7 @@ ExLlama only uses the following parameters:
* top_p * top_p
* top_k * top_k
* repetition_penalty * repetition_penalty
* repetition_penalty_range
* typical_p * typical_p
### RWKV ### RWKV

View File

@ -18,12 +18,16 @@ from pathlib import Path
import requests import requests
import tqdm import tqdm
from requests.adapters import HTTPAdapter
from tqdm.contrib.concurrent import thread_map from tqdm.contrib.concurrent import thread_map
class ModelDownloader: class ModelDownloader:
def __init__(self): def __init__(self, max_retries):
self.s = requests.Session() self.s = requests.Session()
if max_retries:
self.s.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=max_retries))
self.s.mount('https://huggingface.co', HTTPAdapter(max_retries=max_retries))
if os.getenv('HF_USER') is not None and os.getenv('HF_PASS') is not None: if os.getenv('HF_USER') is not None and os.getenv('HF_PASS') is not None:
self.s.auth = (os.getenv('HF_USER'), os.getenv('HF_PASS')) self.s.auth = (os.getenv('HF_USER'), os.getenv('HF_PASS'))
@ -212,6 +216,7 @@ if __name__ == '__main__':
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.') parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.') parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.') parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
parser.add_argument('--max-retries', type=int, default=5, help='Maximum number of retries when a download error occurs.')
args = parser.parse_args() args = parser.parse_args()
branch = args.branch branch = args.branch
@ -221,7 +226,7 @@ if __name__ == '__main__':
print("Error: Please specify the model you'd like to download (e.g. 'python download-model.py facebook/opt-1.3b').") print("Error: Please specify the model you'd like to download (e.g. 'python download-model.py facebook/opt-1.3b').")
sys.exit() sys.exit()
downloader = ModelDownloader() downloader = ModelDownloader(max_retries=args.max_retries)
# Cleaning up the model/branch names # Cleaning up the model/branch names
try: try:
model, branch = downloader.sanitize_model_and_branch_names(model, branch) model, branch = downloader.sanitize_model_and_branch_names(model, branch)
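The retry behaviour added above is standard `requests` functionality, exposed through the new flag as in `python download-model.py facebook/opt-1.3b --max-retries 10`. A standalone sketch of the same wiring (the request at the end is only for illustration):

```python
# Mounting an HTTPAdapter with max_retries on the Hugging Face hosts makes
# transient download errors retry automatically, as ModelDownloader does above.
import requests
from requests.adapters import HTTPAdapter

session = requests.Session()
session.mount('https://huggingface.co', HTTPAdapter(max_retries=5))
session.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=5))
response = session.get('https://huggingface.co/api/models/facebook/opt-1.3b')
print(response.status_code)
```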

View File

@ -72,7 +72,6 @@ class Handler(BaseHTTPRequestHandler):
self.end_headers() self.end_headers()
user_input = body['user_input'] user_input = body['user_input']
history = body['history']
regenerate = body.get('regenerate', False) regenerate = body.get('regenerate', False)
_continue = body.get('_continue', False) _continue = body.get('_continue', False)
@ -80,9 +79,9 @@ class Handler(BaseHTTPRequestHandler):
generate_params['stream'] = False generate_params['stream'] = False
generator = generate_chat_reply( generator = generate_chat_reply(
user_input, history, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False) user_input, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)
answer = history answer = generate_params['history']
for a in generator: for a in generator:
answer = a answer = a

View File

@ -55,14 +55,13 @@ async def _handle_connection(websocket, path):
body = json.loads(message) body = json.loads(message)
user_input = body['user_input'] user_input = body['user_input']
history = body['history']
generate_params = build_parameters(body, chat=True) generate_params = build_parameters(body, chat=True)
generate_params['stream'] = True generate_params['stream'] = True
regenerate = body.get('regenerate', False) regenerate = body.get('regenerate', False)
_continue = body.get('_continue', False) _continue = body.get('_continue', False)
generator = generate_chat_reply( generator = generate_chat_reply(
user_input, history, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False) user_input, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)
message_num = 0 message_num = 0
for a in generator: for a in generator:

View File

@ -21,6 +21,7 @@ def build_parameters(body, chat=False):
'tfs': float(body.get('tfs', 1)), 'tfs': float(body.get('tfs', 1)),
'top_a': float(body.get('top_a', 0)), 'top_a': float(body.get('top_a', 0)),
'repetition_penalty': float(body.get('repetition_penalty', body.get('rep_pen', 1.1))), 'repetition_penalty': float(body.get('repetition_penalty', body.get('rep_pen', 1.1))),
'repetition_penalty_range': int(body.get('repetition_penalty_range', 0)),
'encoder_repetition_penalty': float(body.get('encoder_repetition_penalty', 1.0)), 'encoder_repetition_penalty': float(body.get('encoder_repetition_penalty', 1.0)),
'top_k': int(body.get('top_k', 0)), 'top_k': int(body.get('top_k', 0)),
'min_length': int(body.get('min_length', 0)), 'min_length': int(body.get('min_length', 0)),
@ -64,6 +65,7 @@ def build_parameters(body, chat=False):
'context_instruct': context_instruct, 'context_instruct': context_instruct,
'turn_template': turn_template, 'turn_template': turn_template,
'chat-instruct_command': str(body.get('chat-instruct_command', shared.settings['chat-instruct_command'])), 'chat-instruct_command': str(body.get('chat-instruct_command', shared.settings['chat-instruct_command'])),
'history': body.get('history', {'internal': [], 'visible': []})
}) })
return generate_params return generate_params
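A request body exercising the new fields handled by `build_parameters()` might look like the sketch below; the endpoint and port follow the blocking-API examples in this repository and should be treated as assumptions:

```python
import requests

body = {
    'user_input': 'Hello!',
    'history': {'internal': [], 'visible': []},  # the history now travels inside the request body
    'mode': 'chat',
    'repetition_penalty': 1.18,
    'repetition_penalty_range': 0,  # 0 = apply the penalty over the whole context
    'regenerate': False,
    '_continue': False,
}
response = requests.post('http://localhost:5000/api/v1/chat', json=body)
print(response.json()['results'][0]['history'])
```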

View File

@ -3,7 +3,9 @@ from pathlib import Path
import elevenlabs import elevenlabs
import gradio as gr import gradio as gr
from modules import chat, shared from modules import chat, shared
from modules.utils import gradio
params = { params = {
'activate': True, 'activate': True,
@ -35,24 +37,24 @@ def refresh_voices_dd():
return gr.Dropdown.update(value=all_voices[0], choices=all_voices) return gr.Dropdown.update(value=all_voices[0], choices=all_voices)
def remove_tts_from_history(): def remove_tts_from_history(history):
for i, entry in enumerate(shared.history['internal']): for i, entry in enumerate(history['internal']):
shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]] history['visible'][i] = [history['visible'][i][0], entry[1]]
return history
def toggle_text_in_history(): def toggle_text_in_history(history):
for i, entry in enumerate(shared.history['visible']): for i, entry in enumerate(history['visible']):
visible_reply = entry[1] visible_reply = entry[1]
if visible_reply.startswith('<audio'): if visible_reply.startswith('<audio'):
if params['show_text']: if params['show_text']:
reply = shared.history['internal'][i][1] reply = history['internal'][i][1]
shared.history['visible'][i] = [ history['visible'][i] = [history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"]
shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"
]
else: else:
shared.history['visible'][i] = [ history['visible'][i] = [history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"]
shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"
] return history
def remove_surrounded_chars(string): def remove_surrounded_chars(string):
@ -150,25 +152,24 @@ def ui():
convert_cancel = gr.Button('Cancel', visible=False) convert_cancel = gr.Button('Cancel', visible=False)
convert_confirm = gr.Button('Confirm (cannot be undone)', variant="stop", visible=False) convert_confirm = gr.Button('Confirm (cannot be undone)', variant="stop", visible=False)
if shared.is_chat():
# Convert history with confirmation # Convert history with confirmation
convert_arr = [convert_confirm, convert, convert_cancel] convert_arr = [convert_confirm, convert, convert_cancel]
convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr) convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
convert_confirm.click( convert_confirm.click(
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr).then( lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr).then(
remove_tts_from_history, None, None).then( remove_tts_from_history, gradio('history'), gradio('history')).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False).then( chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
chat.redraw_html, shared.reload_inputs, shared.gradio['display']) chat.redraw_html, shared.reload_inputs, gradio('display'))
convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr) convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
# Toggle message text in history # Toggle message text in history
show_text.change( show_text.change(
lambda x: params.update({"show_text": x}), show_text, None).then( lambda x: params.update({"show_text": x}), show_text, None).then(
toggle_text_in_history, None, None).then( toggle_text_in_history, gradio('history'), gradio('history')).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False).then( chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
chat.redraw_html, shared.reload_inputs, shared.gradio['display']) chat.redraw_html, shared.reload_inputs, gradio('display'))
convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
# Event functions to update the parameters in the backend # Event functions to update the parameters in the backend
activate.change(lambda x: params.update({'activate': x}), activate, None) activate.change(lambda x: params.update({'activate': x}), activate, None)

View File

@ -29,6 +29,7 @@ default_req_params = {
'top_p': 1.0, 'top_p': 1.0,
'top_k': 1, 'top_k': 1,
'repetition_penalty': 1.18, 'repetition_penalty': 1.18,
'repetition_penalty_range': 0,
'encoder_repetition_penalty': 1.0, 'encoder_repetition_penalty': 1.0,
'suffix': None, 'suffix': None,
'stream': False, 'stream': False,

View File

@ -10,7 +10,7 @@ import requests
import torch import torch
from PIL import Image from PIL import Image
import modules.shared as shared from modules import shared
from modules.models import reload_model, unload_model from modules.models import reload_model, unload_model
from modules.ui import create_refresh_button from modules.ui import create_refresh_button
@ -126,7 +126,7 @@ def input_modifier(string):
return string return string
# Get and save the Stable Diffusion-generated picture # Get and save the Stable Diffusion-generated picture
def get_SD_pictures(description): def get_SD_pictures(description, character):
global params global params
@ -160,7 +160,7 @@ def get_SD_pictures(description):
if params['save_img']: if params['save_img']:
img_data = base64.b64decode(img_str) img_data = base64.b64decode(img_str)
variadic = f'{date.today().strftime("%Y_%m_%d")}/{shared.character}_{int(time.time())}' variadic = f'{date.today().strftime("%Y_%m_%d")}/{character}_{int(time.time())}'
output_file = Path(f'extensions/sd_api_pictures/outputs/{variadic}.png') output_file = Path(f'extensions/sd_api_pictures/outputs/{variadic}.png')
output_file.parent.mkdir(parents=True, exist_ok=True) output_file.parent.mkdir(parents=True, exist_ok=True)
@ -186,7 +186,7 @@ def get_SD_pictures(description):
# TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history) # TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
# and replace it with 'text' for the purposes of logging? # and replace it with 'text' for the purposes of logging?
def output_modifier(string): def output_modifier(string, state):
""" """
This function is applied to the model outputs. This function is applied to the model outputs.
""" """
@ -213,7 +213,7 @@ def output_modifier(string):
else: else:
text = string text = string
string = get_SD_pictures(string) + "\n" + text string = get_SD_pictures(string, state['character_menu']) + "\n" + text
return string return string

View File

@ -3,9 +3,10 @@ from pathlib import Path
import gradio as gr import gradio as gr
import torch import torch
from modules import chat, shared
from extensions.silero_tts import tts_preprocessor from extensions.silero_tts import tts_preprocessor
from modules import chat, shared
from modules.utils import gradio
torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_mode(False)
@ -56,20 +57,24 @@ def load_model():
return model return model
def remove_tts_from_history(): def remove_tts_from_history(history):
for i, entry in enumerate(shared.history['internal']): for i, entry in enumerate(history['internal']):
shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]] history['visible'][i] = [history['visible'][i][0], entry[1]]
return history
def toggle_text_in_history(): def toggle_text_in_history(history):
for i, entry in enumerate(shared.history['visible']): for i, entry in enumerate(history['visible']):
visible_reply = entry[1] visible_reply = entry[1]
if visible_reply.startswith('<audio'): if visible_reply.startswith('<audio'):
if params['show_text']: if params['show_text']:
reply = shared.history['internal'][i][1] reply = history['internal'][i][1]
shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"] history['visible'][i] = [history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"]
else: else:
shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"] history['visible'][i] = [history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"]
return history
def state_modifier(state): def state_modifier(state):
@ -80,7 +85,7 @@ def state_modifier(state):
return state return state
def input_modifier(string): def input_modifier(string, state):
if not params['activate']: if not params['activate']:
return string return string
@ -99,7 +104,7 @@ def history_modifier(history):
return history return history
def output_modifier(string): def output_modifier(string, state):
global model, current_params, streaming_state global model, current_params, streaming_state
for i in params: for i in params:
if params[i] != current_params[i]: if params[i] != current_params[i]:
@ -116,7 +121,7 @@ def output_modifier(string):
if string == '': if string == '':
string = '*Empty reply, try regenerating*' string = '*Empty reply, try regenerating*'
else: else:
output_file = Path(f'extensions/silero_tts/outputs/{shared.character}_{int(time.time())}.wav') output_file = Path(f'extensions/silero_tts/outputs/{state["character_menu"]}_{int(time.time())}.wav')
prosody = '<prosody rate="{}" pitch="{}">'.format(params['voice_speed'], params['voice_pitch']) prosody = '<prosody rate="{}" pitch="{}">'.format(params['voice_speed'], params['voice_pitch'])
silero_input = f'<speak>{prosody}{xmlesc(string)}</prosody></speak>' silero_input = f'<speak>{prosody}{xmlesc(string)}</prosody></speak>'
model.save_wav(ssml_text=silero_input, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file)) model.save_wav(ssml_text=silero_input, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
@ -155,23 +160,24 @@ def ui():
gr.Markdown('[Click here for Silero audio samples](https://oobabooga.github.io/silero-samples/index.html)') gr.Markdown('[Click here for Silero audio samples](https://oobabooga.github.io/silero-samples/index.html)')
if shared.is_chat():
# Convert history with confirmation # Convert history with confirmation
convert_arr = [convert_confirm, convert, convert_cancel] convert_arr = [convert_confirm, convert, convert_cancel]
convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr) convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
convert_confirm.click( convert_confirm.click(
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr).then( lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr).then(
remove_tts_from_history, None, None).then( remove_tts_from_history, gradio('history'), gradio('history')).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False).then( chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
chat.redraw_html, shared.reload_inputs, shared.gradio['display']) chat.redraw_html, shared.reload_inputs, gradio('display'))
convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr) convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
# Toggle message text in history # Toggle message text in history
show_text.change( show_text.change(
lambda x: params.update({"show_text": x}), show_text, None).then( lambda x: params.update({"show_text": x}), show_text, None).then(
toggle_text_in_history, None, None).then( toggle_text_in_history, gradio('history'), gradio('history')).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False).then( chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
chat.redraw_html, shared.reload_inputs, shared.gradio['display']) chat.redraw_html, shared.reload_inputs, gradio('display'))
# Event functions to update the parameters in the backend # Event functions to update the parameters in the backend
activate.change(lambda x: params.update({"activate": x}), activate, None) activate.change(lambda x: params.update({"activate": x}), activate, None)

View File

@ -96,6 +96,8 @@ def apply_settings(chunk_count, chunk_count_initial, time_weight):
def custom_generate_chat_prompt(user_input, state, **kwargs): def custom_generate_chat_prompt(user_input, state, **kwargs):
global chat_collector global chat_collector
history = state['history']
if state['mode'] == 'instruct': if state['mode'] == 'instruct':
results = collector.get_sorted(user_input, n_results=params['chunk_count']) results = collector.get_sorted(user_input, n_results=params['chunk_count'])
additional_context = '\nYour reply should be based on the context below:\n\n' + '\n'.join(results) additional_context = '\nYour reply should be based on the context below:\n\n' + '\n'.join(results)
@ -104,29 +106,29 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
def make_single_exchange(id_): def make_single_exchange(id_):
output = '' output = ''
output += f"{state['name1']}: {shared.history['internal'][id_][0]}\n" output += f"{state['name1']}: {history['internal'][id_][0]}\n"
output += f"{state['name2']}: {shared.history['internal'][id_][1]}\n" output += f"{state['name2']}: {history['internal'][id_][1]}\n"
return output return output
if len(shared.history['internal']) > params['chunk_count'] and user_input != '': if len(history['internal']) > params['chunk_count'] and user_input != '':
chunks = [] chunks = []
hist_size = len(shared.history['internal']) hist_size = len(history['internal'])
for i in range(hist_size-1): for i in range(hist_size-1):
chunks.append(make_single_exchange(i)) chunks.append(make_single_exchange(i))
add_chunks_to_collector(chunks, chat_collector) add_chunks_to_collector(chunks, chat_collector)
query = '\n'.join(shared.history['internal'][-1] + [user_input]) query = '\n'.join(history['internal'][-1] + [user_input])
try: try:
best_ids = chat_collector.get_ids_sorted(query, n_results=params['chunk_count'], n_initial=params['chunk_count_initial'], time_weight=params['time_weight']) best_ids = chat_collector.get_ids_sorted(query, n_results=params['chunk_count'], n_initial=params['chunk_count_initial'], time_weight=params['time_weight'])
additional_context = '\n' additional_context = '\n'
for id_ in best_ids: for id_ in best_ids:
if shared.history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>': if history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>':
additional_context += make_single_exchange(id_) additional_context += make_single_exchange(id_)
logger.warning(f'Adding the following new context:\n{additional_context}') logger.warning(f'Adding the following new context:\n{additional_context}')
state['context'] = state['context'].strip() + '\n' + additional_context state['context'] = state['context'].strip() + '\n' + additional_context
kwargs['history'] = { kwargs['history'] = {
'internal': [shared.history['internal'][i] for i in range(hist_size) if i not in best_ids], 'internal': [history['internal'][i] for i in range(hist_size) if i not in best_ids],
'visible': '' 'visible': ''
} }
except RuntimeError: except RuntimeError:

View File

@ -106,6 +106,8 @@ def add_lora_transformers(lora_names):
# If any LoRA needs to be removed, start over # If any LoRA needs to be removed, start over
if len(removed_set) > 0: if len(removed_set) > 0:
# shared.model may no longer be PeftModel
if hasattr(shared.model, 'disable_adapter'):
shared.model.disable_adapter() shared.model.disable_adapter()
shared.model = shared.model.base_model.model shared.model = shared.model.base_model.model

View File

@ -1,12 +1,11 @@
import base64 import base64
import copy import copy
import functools import functools
import io
import json import json
import re import re
from datetime import datetime
from pathlib import Path from pathlib import Path
import gradio as gr
import yaml import yaml
from PIL import Image from PIL import Image
@ -19,7 +18,12 @@ from modules.text_generation import (
get_encoded_length, get_encoded_length,
get_max_prompt_length get_max_prompt_length
) )
from modules.utils import delete_file, replace_all, save_file from modules.utils import (
delete_file,
get_available_characters,
replace_all,
save_file
)
def get_turn_substrings(state, instruct=False): def get_turn_substrings(state, instruct=False):
@ -53,7 +57,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
impersonate = kwargs.get('impersonate', False) impersonate = kwargs.get('impersonate', False)
_continue = kwargs.get('_continue', False) _continue = kwargs.get('_continue', False)
also_return_rows = kwargs.get('also_return_rows', False) also_return_rows = kwargs.get('also_return_rows', False)
history = kwargs.get('history', shared.history)['internal'] history = kwargs.get('history', state['history'])['internal']
is_instruct = state['mode'] == 'instruct' is_instruct = state['mode'] == 'instruct'
# Find the maximum prompt size # Find the maximum prompt size
@ -75,10 +79,10 @@ def generate_chat_prompt(user_input, state, **kwargs):
if impersonate: if impersonate:
wrapper += substrings['user_turn_stripped'].rstrip(' ') wrapper += substrings['user_turn_stripped'].rstrip(' ')
elif _continue: elif _continue:
wrapper += apply_extensions("bot_prefix", substrings['bot_turn_stripped']) wrapper += apply_extensions('bot_prefix', substrings['bot_turn_stripped'], state)
wrapper += history[-1][1] wrapper += history[-1][1]
else: else:
wrapper += apply_extensions("bot_prefix", substrings['bot_turn_stripped'].rstrip(' ')) wrapper += apply_extensions('bot_prefix', substrings['bot_turn_stripped'].rstrip(' '), state)
else: else:
wrapper = '<|prompt|>' wrapper = '<|prompt|>'
@ -112,7 +116,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
# Add the character prefix # Add the character prefix
if state['mode'] != 'chat-instruct': if state['mode'] != 'chat-instruct':
rows.append(apply_extensions("bot_prefix", substrings['bot_turn_stripped'].rstrip(' '))) rows.append(apply_extensions('bot_prefix', substrings['bot_turn_stripped'].rstrip(' '), state))
while len(rows) > min_rows and get_encoded_length(wrapper.replace('<|prompt|>', ''.join(rows))) >= max_length: while len(rows) > min_rows and get_encoded_length(wrapper.replace('<|prompt|>', ''.join(rows))) >= max_length:
rows.pop(1) rows.pop(1)
@ -152,7 +156,8 @@ def get_stopping_strings(state):
return stopping_strings return stopping_strings
def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loading_message=True): def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_message=True):
history = state['history']
output = copy.deepcopy(history) output = copy.deepcopy(history)
output = apply_extensions('history', output) output = apply_extensions('history', output)
state = apply_extensions('state', state) state = apply_extensions('state', state)
@ -173,11 +178,11 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
if visible_text is None: if visible_text is None:
visible_text = text visible_text = text
text = apply_extensions('input', text, state)
# *Is typing...* # *Is typing...*
if loading_message: if loading_message:
yield {'visible': output['visible'] + [[visible_text, shared.processing_message]], 'internal': output['internal']} yield {'visible': output['visible'] + [[visible_text, shared.processing_message]], 'internal': output['internal']}
text = apply_extensions('input', text)
else: else:
text, visible_text = output['internal'][-1][0], output['visible'][-1][0] text, visible_text = output['internal'][-1][0], output['visible'][-1][0]
if regenerate: if regenerate:
@ -214,7 +219,7 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
# We need this global variable to handle the Stop event, # We need this global variable to handle the Stop event,
# otherwise gradio gets confused # otherwise gradio gets confused
if shared.stop_everything: if shared.stop_everything:
output['visible'][-1][1] = apply_extensions("output", output['visible'][-1][1]) output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state)
yield output yield output
return return
@ -240,7 +245,7 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
else: else:
cumulative_reply = reply cumulative_reply = reply
output['visible'][-1][1] = apply_extensions("output", output['visible'][-1][1]) output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state)
yield output yield output
@ -273,14 +278,15 @@ def impersonate_wrapper(text, start_with, state):
yield cumulative_reply.lstrip(' ') yield cumulative_reply.lstrip(' ')
def generate_chat_reply(text, history, state, regenerate=False, _continue=False, loading_message=True): def generate_chat_reply(text, state, regenerate=False, _continue=False, loading_message=True):
history = state['history']
if regenerate or _continue: if regenerate or _continue:
text = '' text = ''
if (len(history['visible']) == 1 and not history['visible'][0][0]) or len(history['internal']) == 0: if (len(history['visible']) == 1 and not history['visible'][0][0]) or len(history['internal']) == 0:
yield history yield history
return return
for history in chatbot_wrapper(text, history, state, regenerate=regenerate, _continue=_continue, loading_message=loading_message): for history in chatbot_wrapper(text, state, regenerate=regenerate, _continue=_continue, loading_message=loading_message):
yield history yield history
@ -288,151 +294,127 @@ def generate_chat_reply(text, history, state, regenerate=False, _continue=False,
def generate_chat_reply_wrapper(text, start_with, state, regenerate=False, _continue=False): def generate_chat_reply_wrapper(text, start_with, state, regenerate=False, _continue=False):
if start_with != '' and not _continue: if start_with != '' and not _continue:
if regenerate: if regenerate:
text = remove_last_message() text, state['history'] = remove_last_message(state['history'])
regenerate = False regenerate = False
_continue = True _continue = True
send_dummy_message(text) send_dummy_message(text, state)
send_dummy_reply(start_with) send_dummy_reply(start_with, state)
for i, history in enumerate(generate_chat_reply(text, shared.history, state, regenerate, _continue, loading_message=True)): for i, history in enumerate(generate_chat_reply(text, state, regenerate, _continue, loading_message=True)):
if i != 0: yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style']), history
shared.history = copy.deepcopy(history)
yield chat_html_wrapper(history['visible'], state['name1'], state['name2'], state['mode'], state['chat_style'])
def remove_last_message(): def remove_last_message(history):
if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>': if len(history['visible']) > 0 and history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
last = shared.history['visible'].pop() last = history['visible'].pop()
shared.history['internal'].pop() history['internal'].pop()
else: else:
last = ['', ''] last = ['', '']
return last[0] return last[0], history
def send_last_reply_to_input(): def send_last_reply_to_input(history):
if len(shared.history['internal']) > 0: if len(history['internal']) > 0:
return shared.history['internal'][-1][1] return history['internal'][-1][1]
else: else:
return '' return ''
def replace_last_reply(text): def replace_last_reply(text, state):
if len(shared.history['visible']) > 0: history = state['history']
shared.history['visible'][-1][1] = text if len(history['visible']) > 0:
shared.history['internal'][-1][1] = apply_extensions("input", text) history['visible'][-1][1] = text
history['internal'][-1][1] = apply_extensions('input', text, state)
return history
def send_dummy_message(text): def send_dummy_message(text, state):
shared.history['visible'].append([text, '']) history = state['history']
shared.history['internal'].append([apply_extensions("input", text), '']) history['visible'].append([text, ''])
history['internal'].append([apply_extensions('input', text, state), ''])
return history
def send_dummy_reply(text): def send_dummy_reply(text, state):
if len(shared.history['visible']) > 0 and not shared.history['visible'][-1][1] == '': history = state['history']
shared.history['visible'].append(['', '']) if len(history['visible']) > 0 and not history['visible'][-1][1] == '':
shared.history['internal'].append(['', '']) history['visible'].append(['', ''])
history['internal'].append(['', ''])
shared.history['visible'][-1][1] = text history['visible'][-1][1] = text
shared.history['internal'][-1][1] = apply_extensions("input", text) history['internal'][-1][1] = apply_extensions('input', text, state)
return history
def clear_chat_log(greeting, mode): def clear_chat_log(state):
shared.history['visible'] = [] greeting = state['greeting']
shared.history['internal'] = [] mode = state['mode']
history = state['history']
history['visible'] = []
history['internal'] = []
if mode != 'instruct': if mode != 'instruct':
if greeting != '': if greeting != '':
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]] history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
shared.history['visible'] += [['', apply_extensions("output", greeting)]] history['visible'] += [['', apply_extensions('output', greeting, state)]]
save_history(mode)
def redraw_html(name1, name2, mode, style, reset_cache=False):
return chat_html_wrapper(shared.history['visible'], name1, name2, mode, style, reset_cache=reset_cache)
def tokenize_dialogue(dialogue, name1, name2):
history = []
messages = []
dialogue = re.sub('<START>', '', dialogue)
dialogue = re.sub('<start>', '', dialogue)
dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
dialogue = re.sub('(\n|^)\[CHARACTER\]:', f'\\g<1>{name2}:', dialogue)
idx = [m.start() for m in re.finditer(f"(^|\n)({re.escape(name1)}|{re.escape(name2)}):", dialogue)]
if len(idx) == 0:
return history
for i in range(len(idx) - 1):
messages.append(dialogue[idx[i]:idx[i + 1]].strip())
messages.append(dialogue[idx[-1]:].strip())
entry = ['', '']
for i in messages:
if i.startswith(f'{name1}:'):
entry[0] = i[len(f'{name1}:'):].strip()
elif i.startswith(f'{name2}:'):
entry[1] = i[len(f'{name2}:'):].strip()
if not (len(entry[0]) == 0 and len(entry[1]) == 0):
history.append(entry)
entry = ['', '']
print("\033[1;32;1m\nDialogue tokenized to:\033[0;37;0m\n", end='')
for row in history:
for column in row:
print("\n")
for line in column.strip().split('\n'):
print("| " + line + "\n")
print("|\n")
print("------------------------------")
return history return history
def save_history(mode, timestamp=False, user_request=False): def redraw_html(history, name1, name2, mode, style, reset_cache=False):
# Instruct mode histories should not be saved as if return chat_html_wrapper(history, name1, name2, mode, style, reset_cache=reset_cache)
# Alpaca or Vicuna were characters
if mode == 'instruct':
if not timestamp:
return
fname = f"Instruct_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
else:
if shared.character == 'None' and not user_request:
return
if timestamp:
fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
else:
fname = f"{shared.character}_persistent.json"
if not Path('logs').exists():
Path('logs').mkdir()
with open(Path(f'logs/{fname}'), 'w', encoding='utf-8') as f:
f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
return Path(f'logs/{fname}')
def load_history(file, name1, name2): def save_history(history, path=None):
file = file.decode('utf-8') p = path or Path('logs/exported_history.json')
with open(p, 'w', encoding='utf-8') as f:
f.write(json.dumps(history, indent=4))
return p
def load_history(file, history):
try: try:
file = file.decode('utf-8')
j = json.loads(file) j = json.loads(file)
if 'data' in j: if 'internal' in j and 'visible' in j:
shared.history['internal'] = j['data'] return j
if 'data_visible' in j:
shared.history['visible'] = j['data_visible']
else: else:
shared.history['visible'] = copy.deepcopy(shared.history['internal']) return history
except: except:
shared.history['internal'] = tokenize_dialogue(file, name1, name2) return history
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
def save_persistent_history(history, character, mode):
if mode in ['chat', 'chat-instruct'] and character not in ['', 'None', None] and not shared.args.multi_user:
save_history(history, path=Path(f'logs/{character}_persistent.json'))
def load_persistent_history(state):
if state['mode'] == 'instruct':
return state['history']
character = state['character_menu']
greeting = state['greeting']
p = Path(f'logs/{character}_persistent.json')
if not shared.args.multi_user and character not in ['None', '', None] and p.exists():
f = json.loads(open(p, 'rb').read())
if 'internal' in f and 'visible' in f:
history = f
else:
history = {'internal': [], 'visible': []}
history['internal'] = f['data']
history['visible'] = f['data_visible']
else:
history = {'internal': [], 'visible': []}
if greeting != "":
history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
history['visible'] += [['', apply_extensions('output', greeting, state)]]
return history
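For reference, the persistent log written by `save_persistent_history()` is a plain JSON file at `logs/<character>_persistent.json`; older logs that use the `data`/`data_visible` keys are still converted by `load_persistent_history()` above. A minimal example of its layout (the greeting text is illustrative):

```python
# Shape of logs/<character>_persistent.json after this change
example_history = {
    'internal': [['<|BEGIN-VISIBLE-CHAT|>', 'Hello! How can I help you today?']],
    'visible': [['', 'Hello! How can I help you today?']],
}
```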
def replace_character_names(text, name1, name2): def replace_character_names(text, name1, name2):
@ -467,7 +449,6 @@ def generate_pfp_cache(character):
def load_character(character, name1, name2, instruct=False): def load_character(character, name1, name2, instruct=False):
shared.character = character
context = greeting = turn_template = "" context = greeting = turn_template = ""
greeting_field = 'greeting' greeting_field = 'greeting'
picture = None picture = None
@ -476,7 +457,7 @@ def load_character(character, name1, name2, instruct=False):
if Path("cache/pfp_character.png").exists(): if Path("cache/pfp_character.png").exists():
Path("cache/pfp_character.png").unlink() Path("cache/pfp_character.png").unlink()
if character != 'None': if character not in ['None', '', None]:
folder = 'characters' if not instruct else 'characters/instruction-following' folder = 'characters' if not instruct else 'characters/instruction-following'
picture = generate_pfp_cache(character) picture = generate_pfp_cache(character)
for extension in ["yml", "yaml", "json"]: for extension in ["yml", "yaml", "json"]:
@ -526,20 +507,6 @@ def load_character(character, name1, name2, instruct=False):
greeting = shared.settings['greeting'] greeting = shared.settings['greeting']
turn_template = shared.settings['turn_template'] turn_template = shared.settings['turn_template']
if not instruct:
shared.history['internal'] = []
shared.history['visible'] = []
if shared.character != 'None' and Path(f'logs/{shared.character}_persistent.json').exists():
load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
else:
# Insert greeting if it exists
if greeting != "":
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
shared.history['visible'] += [['', apply_extensions("output", greeting)]]
# Create .json log files since they don't already exist
save_history('instruct' if instruct else 'chat')
return name1, name2, picture, greeting, context, turn_template.replace("\n", r"\n") return name1, name2, picture, greeting, context, turn_template.replace("\n", r"\n")
@ -567,16 +534,22 @@ def upload_character(json_file, img, tavern=False):
img.save(Path(f'characters/{outfile_name}.png')) img.save(Path(f'characters/{outfile_name}.png'))
logger.info(f'New character saved to "characters/{outfile_name}.json".') logger.info(f'New character saved to "characters/{outfile_name}.json".')
return outfile_name return gr.update(value=outfile_name, choices=get_available_characters())
def upload_tavern_character(img, name1, name2): def upload_tavern_character(img, _json):
_img = Image.open(io.BytesIO(img))
_img.getexif()
decoded_string = base64.b64decode(_img.info['chara'])
_json = json.loads(decoded_string)
_json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']} _json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']}
return upload_character(json.dumps(_json), _img, tavern=True) return upload_character(json.dumps(_json), img, tavern=True)
def check_tavern_character(img):
if "chara" not in img.info:
return "Not a TavernAI card", None, None, gr.update(interactive=False)
decoded_string = base64.b64decode(img.info['chara'])
_json = json.loads(decoded_string)
if "data" in _json:
_json = _json["data"]
return _json['name'], _json['description'], _json, gr.update(interactive=True)
def upload_your_profile_picture(img): def upload_your_profile_picture(img):

View File

@ -1,6 +1,8 @@
import sys import sys
from pathlib import Path from pathlib import Path
from torch import version as torch_version
from modules import shared from modules import shared
from modules.logging_colors import logger from modules.logging_colors import logger
@ -52,6 +54,16 @@ class ExllamaModel:
config.set_auto_map(shared.args.gpu_split) config.set_auto_map(shared.args.gpu_split)
config.gpu_peer_fix = True config.gpu_peer_fix = True
if shared.args.alpha_value:
config.alpha_value = shared.args.alpha_value
config.calculate_rotary_embedding_base()
if torch_version.hip:
config.rmsnorm_no_half2 = True
config.rope_no_half2 = True
config.matmul_no_half2 = True
config.silu_no_half2 = True
model = ExLlama(config) model = ExLlama(config)
tokenizer = ExLlamaTokenizer(str(tokenizer_model_path)) tokenizer = ExLlamaTokenizer(str(tokenizer_model_path))
cache = ExLlamaCache(model) cache = ExLlamaCache(model)
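The `config.calculate_rotary_embedding_base()` call above applies ExLlama's NTK-aware RoPE scaling. A common formulation of that rescaling, stated here as an assumption rather than quoted from this commit, is:

```python
# Hedged sketch of NTK-aware RoPE scaling: the rotary base grows with alpha_value,
# stretching the usable context length. head_dim = 128 is typical for LLaMA models.
def scaled_rope_base(alpha_value, base=10000.0, head_dim=128):
    return base * alpha_value ** (head_dim / (head_dim - 2))

print(scaled_rope_base(2.0))  # ~20221 instead of 10000
```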
@ -71,6 +83,7 @@ class ExllamaModel:
self.generator.settings.top_k = state['top_k'] self.generator.settings.top_k = state['top_k']
self.generator.settings.typical = state['typical_p'] self.generator.settings.typical = state['typical_p']
self.generator.settings.token_repetition_penalty_max = state['repetition_penalty'] self.generator.settings.token_repetition_penalty_max = state['repetition_penalty']
self.generator.settings.token_repetition_penalty_sustain = -1 if state['repetition_penalty_range'] <= 0 else state['repetition_penalty_range']
if state['ban_eos_token']: if state['ban_eos_token']:
self.generator.disallow_tokens([self.tokenizer.eos_token_id]) self.generator.disallow_tokens([self.tokenizer.eos_token_id])
else: else:

View File

@ -98,6 +98,16 @@ class ExllamaHF(PreTrainedModel):
config.set_auto_map(shared.args.gpu_split) config.set_auto_map(shared.args.gpu_split)
config.gpu_peer_fix = True config.gpu_peer_fix = True
if shared.args.alpha_value:
config.alpha_value = shared.args.alpha_value
config.calculate_rotary_embedding_base()
if torch.version.hip:
config.rmsnorm_no_half2 = True
config.rope_no_half2 = True
config.matmul_no_half2 = True
config.silu_no_half2 = True
# This slows down a bit but aligns better with autogptq generation. # This slows down a bit but aligns better with autogptq generation.
# TODO: Should give user choice to tune the exllama config # TODO: Should give user choice to tune the exllama config
# config.fused_attn = False # config.fused_attn = False

View File

@ -6,6 +6,8 @@ import gradio as gr
import extensions import extensions
import modules.shared as shared import modules.shared as shared
from modules.logging_colors import logger from modules.logging_colors import logger
from inspect import signature
state = {} state = {}
available_extensions = [] available_extensions = []
@ -52,10 +54,14 @@ def iterator():
# Extension functions that map string -> string # Extension functions that map string -> string
def _apply_string_extensions(function_name, text): def _apply_string_extensions(function_name, text, state):
for extension, _ in iterator(): for extension, _ in iterator():
if hasattr(extension, function_name): if hasattr(extension, function_name):
text = getattr(extension, function_name)(text) func = getattr(extension, function_name)
if len(signature(func).parameters) == 2:
text = func(text, state)
else:
text = func(text)
return text return text
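The `inspect.signature` check above keeps older single-argument extensions working. A self-contained illustration of the same dispatch (both functions are hypothetical):

```python
from inspect import signature

def old_style(text):             # legacy hook: text only
    return text.upper()

def new_style(text, state):      # updated hook: text plus the state dict
    return f"[{state.get('mode', 'default')}] {text}"

for func in (old_style, new_style):
    if len(signature(func).parameters) == 2:
        print(func('hello', {'mode': 'chat'}))
    else:
        print(func('hello'))
```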

View File

@ -14,16 +14,20 @@ def clone_or_pull_repository(github_url):
# Check if the repository is already cloned # Check if the repository is already cloned
if os.path.exists(repo_path): if os.path.exists(repo_path):
yield f"Updating {github_url}..."
# Perform a 'git pull' to update the repository # Perform a 'git pull' to update the repository
try: try:
pull_output = subprocess.check_output(["git", "-C", repo_path, "pull"], stderr=subprocess.STDOUT) pull_output = subprocess.check_output(["git", "-C", repo_path, "pull"], stderr=subprocess.STDOUT)
yield "Done."
return pull_output.decode() return pull_output.decode()
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
return str(e) return str(e)
# Clone the repository # Clone the repository
try: try:
yield f"Cloning {github_url}..."
clone_output = subprocess.check_output(["git", "clone", github_url, repo_path], stderr=subprocess.STDOUT) clone_output = subprocess.check_output(["git", "clone", github_url, repo_path], stderr=subprocess.STDOUT)
yield "Done."
return clone_output.decode() return clone_output.decode()
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
return str(e) return str(e)

View File

@ -129,7 +129,7 @@ def generate_4chan_html(f):
def make_thumbnail(image): def make_thumbnail(image):
image = image.resize((350, round(image.size[1] / image.size[0] * 350)), Image.Resampling.LANCZOS) image = image.resize((350, round(image.size[1] / image.size[0] * 350)), Image.Resampling.LANCZOS)
if image.size[1] > 470: if image.size[1] > 470:
image = ImageOps.fit(image, (350, 470), Image.ANTIALIAS) image = ImageOps.fit(image, (350, 470), Image.LANCZOS)
return image return image
@ -266,8 +266,8 @@ def generate_chat_html(history, name1, name2, reset_cache=False):
def chat_html_wrapper(history, name1, name2, mode, style, reset_cache=False): def chat_html_wrapper(history, name1, name2, mode, style, reset_cache=False):
if mode == 'instruct': if mode == 'instruct':
return generate_instruct_html(history) return generate_instruct_html(history['visible'])
elif style == 'wpp': elif style == 'wpp':
return generate_chat_html(history, name1, name2) return generate_chat_html(history['visible'], name1, name2)
else: else:
return generate_cai_chat_html(history, name1, name2, style, reset_cache) return generate_cai_chat_html(history['visible'], name1, name2, style, reset_cache)

View File

@ -57,12 +57,14 @@ loaders_and_params = {
'gpu_split', 'gpu_split',
'max_seq_len', 'max_seq_len',
'compress_pos_emb', 'compress_pos_emb',
'alpha_value',
'exllama_info', 'exllama_info',
], ],
'ExLlama_HF' : [ 'ExLlama_HF' : [
'gpu_split', 'gpu_split',
'max_seq_len', 'max_seq_len',
'compress_pos_emb', 'compress_pos_emb',
'alpha_value',
'exllama_HF_info', 'exllama_HF_info',
] ]
} }

View File

@ -326,6 +326,7 @@ def clear_torch_cache():
def unload_model(): def unload_model():
shared.model = shared.tokenizer = None shared.model = shared.tokenizer = None
shared.lora_names = []
clear_torch_cache() clear_torch_cache()

View File

@ -15,6 +15,7 @@ def load_preset(name):
'tfs': 1, 'tfs': 1,
'top_a': 0, 'top_a': 0,
'repetition_penalty': 1, 'repetition_penalty': 1,
'repetition_penalty_range': 0,
'encoder_repetition_penalty': 1, 'encoder_repetition_penalty': 1,
'top_k': 0, 'top_k': 0,
'num_beams': 1, 'num_beams': 1,
@ -28,6 +29,7 @@ def load_preset(name):
'mirostat_eta': 0.1, 'mirostat_eta': 0.1,
} }
if name not in ['None', None, '']:
with open(Path(f'presets/{name}.yaml'), 'r') as infile: with open(Path(f'presets/{name}.yaml'), 'r') as infile:
preset = yaml.safe_load(infile) preset = yaml.safe_load(infile)
@ -46,9 +48,9 @@ def load_preset_memoized(name):
def load_preset_for_ui(name, state): def load_preset_for_ui(name, state):
generate_params = load_preset(name) generate_params = load_preset(name)
state.update(generate_params) state.update(generate_params)
return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']] return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]
def generate_preset_yaml(state): def generate_preset_yaml(state):
data = {k: state[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']} data = {k: state[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']}
return yaml.dump(data, sort_keys=False) return yaml.dump(data, sort_keys=False)

View File

@ -5,6 +5,7 @@ import transformers
from transformers import LogitsWarper from transformers import LogitsWarper
from transformers.generation.logits_process import ( from transformers.generation.logits_process import (
LogitNormalization, LogitNormalization,
LogitsProcessor,
LogitsProcessorList, LogitsProcessorList,
TemperatureLogitsWarper TemperatureLogitsWarper
) )
@ -121,6 +122,29 @@ class MirostatLogitsWarper(LogitsWarper):
return scores return scores
class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
'''
Copied from the transformers library
'''
def __init__(self, penalty: float, _range: int):
if not isinstance(penalty, float) or not (penalty > 0):
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
self.penalty = penalty
self._range = _range
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
input_ids = input_ids[:, -self._range:]
score = torch.gather(scores, 1, input_ids)
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
scores.scatter_(1, input_ids, score)
return scores
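A standalone run of the processor's core steps (not part of the webui code; shapes and values are illustrative) shows how only the last `_range` tokens are penalized:

```python
import torch

penalty, _range = 1.18, 4
scores = torch.zeros(1, 10)
scores[0, 3] = 2.0                         # logit of a token generated recently
input_ids = torch.tensor([[7, 1, 3, 3]])   # the last four generated token ids

input_ids = input_ids[:, -_range:]         # same steps as __call__ above
score = torch.gather(scores, 1, input_ids)
score = torch.where(score < 0, score * penalty, score / penalty)
scores.scatter_(1, input_ids, score)
print(scores[0, 3])  # ~1.69: the repeated token becomes less likely
```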
def get_logits_warper_patch(self, generation_config): def get_logits_warper_patch(self, generation_config):
warpers = self._get_logits_warper_old(generation_config) warpers = self._get_logits_warper_old(generation_config)
warpers_to_add = LogitsProcessorList() warpers_to_add = LogitsProcessorList()
@ -146,6 +170,19 @@ def get_logits_warper_patch(self, generation_config):
return warpers return warpers
def get_logits_processor_patch(self, **kwargs):
result = self._get_logits_processor_old(**kwargs)
repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range
repetition_penalty = kwargs['generation_config'].repetition_penalty
if repetition_penalty_range > 0:
for i in range(len(result)):
if result[i].__class__.__name__ == 'RepetitionPenaltyLogitsProcessor':
result[i] = RepetitionPenaltyLogitsProcessorWithRange(repetition_penalty, repetition_penalty_range)
return result
def generation_config_init_patch(self, **kwargs): def generation_config_init_patch(self, **kwargs):
self.__init___old(**kwargs) self.__init___old(**kwargs)
self.tfs = kwargs.pop("tfs", 1.0) self.tfs = kwargs.pop("tfs", 1.0)
@ -153,11 +190,15 @@ def generation_config_init_patch(self, **kwargs):
self.mirostat_mode = kwargs.pop("mirostat_mode", 0) self.mirostat_mode = kwargs.pop("mirostat_mode", 0)
self.mirostat_eta = kwargs.pop("mirostat_eta", 0.1) self.mirostat_eta = kwargs.pop("mirostat_eta", 0.1)
self.mirostat_tau = kwargs.pop("mirostat_tau", 5) self.mirostat_tau = kwargs.pop("mirostat_tau", 5)
self.repetition_penalty_range = kwargs.pop("repetition_penalty_range", 0)
def hijack_samplers(): def hijack_samplers():
transformers.GenerationMixin._get_logits_warper_old = transformers.GenerationMixin._get_logits_warper transformers.GenerationMixin._get_logits_warper_old = transformers.GenerationMixin._get_logits_warper
transformers.GenerationMixin._get_logits_warper = get_logits_warper_patch transformers.GenerationMixin._get_logits_warper = get_logits_warper_patch
transformers.GenerationMixin._get_logits_processor_old = transformers.GenerationMixin._get_logits_processor
transformers.GenerationMixin._get_logits_processor = get_logits_processor_patch
transformers.GenerationConfig.__init___old = transformers.GenerationConfig.__init__ transformers.GenerationConfig.__init___old = transformers.GenerationConfig.__init__
transformers.GenerationConfig.__init__ = generation_config_init_patch transformers.GenerationConfig.__init__ = generation_config_init_patch

View File

@ -14,8 +14,6 @@ model_name = "None"
lora_names = [] lora_names = []
# Chat variables # Chat variables
history = {'internal': [], 'visible': []}
character = 'None'
stop_everything = False stop_everything = False
processing_message = '*Is typing...*' processing_message = '*Is typing...*'
@ -83,6 +81,7 @@ parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpForma
# Basic settings # Basic settings
parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.') parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode with a style similar to the Character.AI website.') parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode with a style similar to the Character.AI website.')
parser.add_argument('--multi-user', action='store_true', help='Multi-user mode. Chat histories are not saved or automatically loaded. WARNING: this is highly experimental.')
parser.add_argument('--character', type=str, help='The name of the character to load in chat mode by default.') parser.add_argument('--character', type=str, help='The name of the character to load in chat mode by default.')
parser.add_argument('--model', type=str, help='Name of the model to load by default.') parser.add_argument('--model', type=str, help='Name of the model to load by default.')
parser.add_argument('--lora', type=str, nargs="+", help='The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces.') parser.add_argument('--lora', type=str, nargs="+", help='The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces.')
@ -151,6 +150,7 @@ parser.add_argument('--desc_act', action='store_true', help='For models that don
parser.add_argument('--gpu-split', type=str, help="Comma-separated list of VRAM (in GB) to use per GPU device for model layers, e.g. 20,7,7") parser.add_argument('--gpu-split', type=str, help="Comma-separated list of VRAM (in GB) to use per GPU device for model layers, e.g. 20,7,7")
parser.add_argument('--max_seq_len', type=int, default=2048, help="Maximum sequence length.") parser.add_argument('--max_seq_len', type=int, default=2048, help="Maximum sequence length.")
parser.add_argument('--compress_pos_emb', type=int, default=1, help="Positional embeddings compression factor. Should typically be set to max_seq_len / 2048.") parser.add_argument('--compress_pos_emb', type=int, default=1, help="Positional embeddings compression factor. Should typically be set to max_seq_len / 2048.")
parser.add_argument('--alpha_value', type=int, default=1, help="Positional embeddings alpha factor for NTK RoPE scaling. Same as above. Use either this or compress_pos_emb, not both.")
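For readers who have not met the `--alpha_value` flag before: NTK RoPE scaling extends the context by enlarging the rotary base instead of compressing position ids. The exact formula lives in the individual loaders and is not part of this diff; a commonly used NTK-aware variant (an assumption for illustration, not a claim about this codebase) looks like this:

```python
# Assumed NTK-aware scaling rule, shown only to give the flag some intuition.
def scaled_rope_base(alpha: float, head_dim: int = 128, base: float = 10000.0) -> float:
    return base * alpha ** (head_dim / (head_dim - 2))

print(scaled_rope_base(1))   # 10000.0 -> alpha_value=1 leaves RoPE unchanged
print(scaled_rope_base(2))   # ~20221  -> a larger base stretches the usable context
```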
# FlexGen # FlexGen
parser.add_argument('--flexgen', action='store_true', help='DEPRECATED') parser.add_argument('--flexgen', action='store_true', help='DEPRECATED')
@ -204,6 +204,8 @@ if args.trust_remote_code:
logger.warning("trust_remote_code is enabled. This is dangerous.") logger.warning("trust_remote_code is enabled. This is dangerous.")
if args.share: if args.share:
logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.") logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.")
if args.multi_user:
logger.warning("The multi-user mode is highly experimental. DO NOT EXPOSE IT TO THE INTERNET.")
def fix_loader_name(name): def fix_loader_name(name):
@ -246,6 +248,15 @@ def is_chat():
return args.chat return args.chat
def get_mode():
if args.chat:
return 'chat'
elif args.notebook:
return 'notebook'
else:
return 'default'
# Loading model-specific settings # Loading model-specific settings
with Path(f'{args.model_dir}/config.yaml') as p: with Path(f'{args.model_dir}/config.yaml') as p:
if p.exists(): if p.exists():

View File

@ -190,7 +190,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False):
original_question = question original_question = question
if not is_chat: if not is_chat:
state = apply_extensions('state', state) state = apply_extensions('state', state)
question = apply_extensions('input', question) question = apply_extensions('input', question, state)
# Finding the stopping strings # Finding the stopping strings
all_stop_strings = [] all_stop_strings = []
@ -223,14 +223,14 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False):
break break
if not is_chat: if not is_chat:
reply = apply_extensions('output', reply) reply = apply_extensions('output', reply, state)
yield reply yield reply
def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False): def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
generate_params = {} generate_params = {}
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta']: for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta']:
generate_params[k] = state[k] generate_params[k] = state[k]
for k in ['epsilon_cutoff', 'eta_cutoff']: for k in ['epsilon_cutoff', 'eta_cutoff']:
@ -262,7 +262,7 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else [] eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
generate_params['eos_token_id'] = eos_token_ids generate_params['eos_token_id'] = eos_token_ids
generate_params['stopping_criteria'] = transformers.StoppingCriteriaList() generate_params['stopping_criteria'] = transformers.StoppingCriteriaList()
generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria()); generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria())
t0 = time.time() t0 = time.time()
try: try:

View File

@ -240,6 +240,21 @@ def backup_adapter(input_folder):
except Exception as e: except Exception as e:
print("An error occurred in backup_adapter:", str(e)) print("An error occurred in backup_adapter:", str(e))
def calc_trainable_parameters(model):
trainable_params = 0
all_param = 0
for _, param in model.named_parameters():
num_params = param.numel()
# if using DS Zero 3 and the weights are initialized empty
if num_params == 0 and hasattr(param, "ds_numel"):
num_params = param.ds_numel
all_param += num_params
if param.requires_grad:
trainable_params += num_params
return trainable_params, all_param
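A quick sanity check of this helper on a toy module (illustrative; assumes `calc_trainable_parameters` from the hunk above is in scope and that torch is installed):

```python
# Illustrative only: freeze everything except the bias, then count what is left trainable.
import torch.nn as nn

toy = nn.Linear(4, 2)            # 8 weight parameters + 2 bias parameters = 10 in total
for p in toy.parameters():
    p.requires_grad_(False)
toy.bias.requires_grad_(True)

trainable, total = calc_trainable_parameters(toy)
print(trainable, total)          # 2 10
```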
def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float): def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float):
@ -268,7 +283,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
else: else:
model_id = "llama" model_id = "llama"
if model_type == "PeftModelForCausalLM": if model_type == "PeftModelForCausalLM":
if len(shared.args.lora_names) > 0: if len(shared.lora_names) > 0:
yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*" yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
logger.warning("Training LoRA over top of another LoRA. May have unexpected effects.") logger.warning("Training LoRA over top of another LoRA. May have unexpected effects.")
else: else:
@ -431,6 +446,9 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
if not always_override: if not always_override:
backup_adapter(lora_file_path) backup_adapter(lora_file_path)
# == get model trainable params
model_trainable_params, model_all_params = calc_trainable_parameters(shared.model)
try: try:
logger.info("Creating LoRA model...") logger.info("Creating LoRA model...")
lora_model = get_peft_model(shared.model, config) lora_model = get_peft_model(shared.model, config)
@ -540,6 +558,12 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
logger.info("Starting training...") logger.info("Starting training...")
yield "Starting..." yield "Starting..."
lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model)
if lora_all_param > 0:
print(f"Trainable params: {lora_trainable_param:,d} ({100 * lora_trainable_param / lora_all_param:.4f} %), All params: {lora_all_param:,d} (Model: {model_all_params:,d})")
train_log.update({"base_model_name": shared.model_name}) train_log.update({"base_model_name": shared.model_name})
train_log.update({"base_model_class": shared.model.__class__.__name__}) train_log.update({"base_model_class": shared.model.__class__.__name__})
train_log.update({"base_loaded_in_4bit": getattr(lora_model, "is_loaded_in_4bit", False)}) train_log.update({"base_loaded_in_4bit": getattr(lora_model, "is_loaded_in_4bit", False)})

View File

@ -1,3 +1,4 @@
import json
from pathlib import Path from pathlib import Path
import gradio as gr import gradio as gr
@ -5,6 +6,7 @@ import torch
from modules import shared from modules import shared
with open(Path(__file__).resolve().parent / '../css/main.css', 'r') as f: with open(Path(__file__).resolve().parent / '../css/main.css', 'r') as f:
css = f.read() css = f.read()
with open(Path(__file__).resolve().parent / '../css/chat.css', 'r') as f: with open(Path(__file__).resolve().parent / '../css/chat.css', 'r') as f:
@ -14,7 +16,7 @@ with open(Path(__file__).resolve().parent / '../css/main.js', 'r') as f:
with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f: with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f:
chat_js = f.read() chat_js = f.read()
refresh_symbol = '\U0001f504' # 🔄 refresh_symbol = '🔄'
delete_symbol = '🗑️' delete_symbol = '🗑️'
save_symbol = '💾' save_symbol = '💾'
@ -30,17 +32,105 @@ theme = gr.themes.Default(
def list_model_elements(): def list_model_elements():
elements = ['loader', 'cpu_memory', 'auto_devices', 'disk', 'cpu', 'bf16', 'load_in_8bit', 'trust_remote_code', 'load_in_4bit', 'compute_dtype', 'quant_type', 'use_double_quant', 'wbits', 'groupsize', 'model_type', 'pre_layer', 'triton', 'desc_act', 'no_inject_fused_attention', 'no_inject_fused_mlp', 'no_use_cuda_fp16', 'threads', 'n_batch', 'no_mmap', 'mlock', 'n_gpu_layers', 'n_ctx', 'llama_cpp_seed', 'gpu_split', 'max_seq_len', 'compress_pos_emb'] elements = [
'loader',
'cpu_memory',
'auto_devices',
'disk',
'cpu',
'bf16',
'load_in_8bit',
'trust_remote_code',
'load_in_4bit',
'compute_dtype',
'quant_type',
'use_double_quant',
'wbits',
'groupsize',
'model_type',
'pre_layer',
'triton',
'desc_act',
'no_inject_fused_attention',
'no_inject_fused_mlp',
'no_use_cuda_fp16',
'threads',
'n_batch',
'no_mmap',
'mlock',
'n_gpu_layers',
'n_ctx',
'llama_cpp_seed',
'gpu_split',
'max_seq_len',
'compress_pos_emb',
'alpha_value'
]
for i in range(torch.cuda.device_count()): for i in range(torch.cuda.device_count()):
elements.append(f'gpu_memory_{i}') elements.append(f'gpu_memory_{i}')
return elements return elements
def list_interface_input_elements(chat=False): def list_interface_input_elements():
elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu', 'stream', 'tfs', 'top_a'] elements = [
if chat: 'preset_menu',
elements += ['name1', 'name2', 'greeting', 'context', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template', 'character_menu', 'name1_instruct', 'name2_instruct', 'context_instruct', 'turn_template', 'chat_style', 'chat-instruct_command'] 'max_new_tokens',
'seed',
'temperature',
'top_p',
'top_k',
'typical_p',
'epsilon_cutoff',
'eta_cutoff',
'repetition_penalty',
'repetition_penalty_range',
'encoder_repetition_penalty',
'no_repeat_ngram_size',
'min_length',
'do_sample',
'penalty_alpha',
'num_beams',
'length_penalty',
'early_stopping',
'mirostat_mode',
'mirostat_tau',
'mirostat_eta',
'add_bos_token',
'ban_eos_token',
'truncation_length',
'custom_stopping_strings',
'skip_special_tokens',
'stream',
'tfs',
'top_a',
]
if shared.args.chat:
elements += [
'character_menu',
'history',
'name1',
'name2',
'greeting',
'context',
'chat_generation_attempts',
'stop_at_newline',
'mode',
'instruction_template',
'name1_instruct',
'name2_instruct',
'context_instruct',
'turn_template',
'chat_style',
'chat-instruct_command',
]
else:
elements.append('textbox')
if not shared.args.notebook:
elements.append('output_textbox')
elements += list_model_elements() elements += list_model_elements()
return elements return elements
@ -48,10 +138,15 @@ def list_interface_input_elements(chat=False):
def gather_interface_values(*args): def gather_interface_values(*args):
output = {} output = {}
for i, element in enumerate(shared.input_elements): for i, element in enumerate(list_interface_input_elements()):
output[element] = args[i] output[element] = args[i]
if not shared.args.multi_user:
shared.persistent_interface_state = output shared.persistent_interface_state = output
Path('logs').mkdir(exist_ok=True)
with open(Path(f'logs/session_{shared.get_mode()}_autosave.json'), 'w') as f:
f.write(json.dumps(output, indent=4))
return output return output
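The autosave added above is keyed on `shared.get_mode()`, so each interface mode gets its own file. A sketch of what a single call ends up writing (the 'chat' mode and the values are illustrative):

```python
# Illustrative only: the per-mode autosave written on every call when --multi-user is off.
import json
from pathlib import Path

output = {'temperature': 0.7, 'top_k': 40, 'mode': 'chat'}   # stand-in interface state
Path('logs').mkdir(exist_ok=True)
with open(Path('logs/session_chat_autosave.json'), 'w') as f:
    f.write(json.dumps(output, indent=4))
# The Session tab later lists logs/session_chat*.json files, autosaves first.
```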
@ -59,11 +154,12 @@ def apply_interface_values(state, use_persistent=False):
if use_persistent: if use_persistent:
state = shared.persistent_interface_state state = shared.persistent_interface_state
elements = list_interface_input_elements(chat=shared.is_chat()) elements = list_interface_input_elements()
if len(state) == 0: if len(state) == 0:
return [gr.update() for k in elements] # Dummy, do nothing return [gr.update() for k in elements] # Dummy, do nothing
else: else:
return [state[k] if k in state else gr.update() for k in elements] ans = [state[k] if k in state else gr.update() for k in elements]
return ans
class ToolButton(gr.Button, gr.components.FormComponent): class ToolButton(gr.Button, gr.components.FormComponent):
@ -92,6 +188,7 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args, ele
inputs=[], inputs=[],
outputs=[refresh_component] outputs=[refresh_component]
) )
return refresh_button return refresh_button

View File

@ -7,6 +7,14 @@ from modules import shared
from modules.logging_colors import logger from modules.logging_colors import logger
# Helper function to get multiple values from shared.gradio
def gradio(*keys):
if len(keys) == 1 and type(keys[0]) is list:
keys = keys[0]
return [shared.gradio[k] for k in keys]
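Much of this commit rewrites event handlers in terms of this helper, so it is worth seeing the two call styles it accepts. A self-contained sketch (the dictionary stands in for `shared.gradio`):

```python
# Illustrative only: variadic keys and a single list of keys both work.
shared_gradio = {'model_menu': 'm', 'loader': 'l', 'autoload_model': True}

def gradio_demo(*keys):
    if len(keys) == 1 and type(keys[0]) is list:
        keys = keys[0]
    return [shared_gradio[k] for k in keys]

print(gradio_demo('model_menu', 'loader'))            # ['m', 'l']
print(gradio_demo(['model_menu', 'autoload_model']))  # ['m', True]
```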
def save_file(fname, contents): def save_file(fname, contents):
if fname == '': if fname == '':
logger.error('File name is empty!') logger.error('File name is empty!')
@ -111,3 +119,8 @@ def get_datasets(path: str, ext: str):
def get_available_chat_styles(): def get_available_chat_styles():
return sorted(set(('-'.join(k.stem.split('-')[1:]) for k in Path('css').glob('chat_style*.css'))), key=natural_keys) return sorted(set(('-'.join(k.stem.split('-')[1:]) for k in Path('css').glob('chat_style*.css'))), key=natural_keys)
def get_available_sessions():
items = sorted(set(k.stem for k in Path('logs').glob(f'session_{shared.get_mode()}*')), key=natural_keys, reverse=True)
return [item for item in items if 'autosave' in item] + [item for item in items if 'autosave' not in item]

View File

@ -23,5 +23,5 @@ llama-cpp-python==0.1.66; platform_system != "Windows"
https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.66/llama_cpp_python-0.1.66-cp310-cp310-win_amd64.whl; platform_system == "Windows" https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.66/llama_cpp_python-0.1.66-cp310-cp310-win_amd64.whl; platform_system == "Windows"
https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.2.2/auto_gptq-0.2.2+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows" https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.2.2/auto_gptq-0.2.2+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows"
https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.2.2/auto_gptq-0.2.2+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.2.2/auto_gptq-0.2.2+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/jllllll/exllama/releases/download/0.0.4/exllama-0.0.4+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows" https://github.com/jllllll/exllama/releases/download/0.0.5/exllama-0.0.5+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows"
https://github.com/jllllll/exllama/releases/download/0.0.4/exllama-0.0.4+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" https://github.com/jllllll/exllama/releases/download/0.0.5/exllama-0.0.5+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"

server.py
View File

@ -49,6 +49,7 @@ from modules.text_generation import (
get_encoded_length, get_encoded_length,
stop_everything_event stop_everything_event
) )
from modules.utils import gradio
def load_model_wrapper(selected_model, loader, autoload=False): def load_model_wrapper(selected_model, loader, autoload=False):
@ -225,6 +226,7 @@ def create_model_menus():
shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7') shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7')
shared.gradio['max_seq_len'] = gr.Slider(label='max_seq_len', minimum=2048, maximum=16384, step=256, info='Maximum sequence length.', value=shared.args.max_seq_len) shared.gradio['max_seq_len'] = gr.Slider(label='max_seq_len', minimum=2048, maximum=16384, step=256, info='Maximum sequence length.', value=shared.args.max_seq_len)
shared.gradio['compress_pos_emb'] = gr.Slider(label='compress_pos_emb', minimum=1, maximum=8, step=1, info='Positional embeddings compression factor. Should typically be set to max_seq_len / 2048.', value=shared.args.compress_pos_emb) shared.gradio['compress_pos_emb'] = gr.Slider(label='compress_pos_emb', minimum=1, maximum=8, step=1, info='Positional embeddings compression factor. Should typically be set to max_seq_len / 2048.', value=shared.args.compress_pos_emb)
shared.gradio['alpha_value'] = gr.Slider(label='alpha_value', minimum=1, maximum=8, step=1, info='Positional embeddings alpha factor for NTK RoPE scaling. Same as above. Use either this or compress_pos_emb, not both.', value=shared.args.alpha_value)
with gr.Column(): with gr.Column():
shared.gradio['triton'] = gr.Checkbox(label="triton", value=shared.args.triton) shared.gradio['triton'] = gr.Checkbox(label="triton", value=shared.args.triton)
@ -257,40 +259,43 @@ def create_model_menus():
with gr.Row(): with gr.Row():
shared.gradio['model_status'] = gr.Markdown('No model is loaded' if shared.model_name == 'None' else 'Ready') shared.gradio['model_status'] = gr.Markdown('No model is loaded' if shared.model_name == 'None' else 'Ready')
shared.gradio['loader'].change(loaders.make_loader_params_visible, shared.gradio['loader'], [shared.gradio[k] for k in loaders.get_all_params()]) shared.gradio['loader'].change(loaders.make_loader_params_visible, gradio('loader'), gradio(loaders.get_all_params()))
# In this event handler, the interface state is read and updated # In this event handler, the interface state is read and updated
# with the model defaults (if any), and then the model is loaded # with the model defaults (if any), and then the model is loaded
# unless "autoload_model" is unchecked # unless "autoload_model" is unchecked
shared.gradio['model_menu'].change( shared.gradio['model_menu'].change(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
apply_model_settings_to_state, [shared.gradio[k] for k in ['model_menu', 'interface_state']], shared.gradio['interface_state']).then( apply_model_settings_to_state, gradio('model_menu', 'interface_state'), gradio('interface_state')).then(
ui.apply_interface_values, shared.gradio['interface_state'], [shared.gradio[k] for k in ui.list_interface_input_elements(chat=shared.is_chat())], show_progress=False).then( ui.apply_interface_values, gradio('interface_state'), gradio(ui.list_interface_input_elements()), show_progress=False).then(
update_model_parameters, shared.gradio['interface_state'], None).then( update_model_parameters, gradio('interface_state'), None).then(
load_model_wrapper, [shared.gradio[k] for k in ['model_menu', 'loader', 'autoload_model']], shared.gradio['model_status'], show_progress=False) load_model_wrapper, gradio('model_menu', 'loader', 'autoload_model'), gradio('model_status'), show_progress=False)
load.click( load.click(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
update_model_parameters, shared.gradio['interface_state'], None).then( update_model_parameters, gradio('interface_state'), None).then(
partial(load_model_wrapper, autoload=True), [shared.gradio[k] for k in ['model_menu', 'loader']], shared.gradio['model_status'], show_progress=False) partial(load_model_wrapper, autoload=True), gradio('model_menu', 'loader'), gradio('model_status'), show_progress=False).then(
lambda: shared.lora_names, None, gradio('lora_menu'))
unload.click( unload.click(
unload_model, None, None).then( unload_model, None, None).then(
lambda: "Model unloaded", None, shared.gradio['model_status']) lambda: "Model unloaded", None, gradio('model_status')).then(
lambda: shared.lora_names, None, gradio('lora_menu'))
reload.click( reload.click(
unload_model, None, None).then( unload_model, None, None).then(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
update_model_parameters, shared.gradio['interface_state'], None).then( update_model_parameters, gradio('interface_state'), None).then(
partial(load_model_wrapper, autoload=True), [shared.gradio[k] for k in ['model_menu', 'loader']], shared.gradio['model_status'], show_progress=False) partial(load_model_wrapper, autoload=True), gradio('model_menu', 'loader'), gradio('model_status'), show_progress=False).then(
lambda: shared.lora_names, None, gradio('lora_menu'))
save_settings.click( save_settings.click(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
save_model_settings, [shared.gradio[k] for k in ['model_menu', 'interface_state']], shared.gradio['model_status'], show_progress=False) save_model_settings, gradio('model_menu', 'interface_state'), gradio('model_status'), show_progress=False)
shared.gradio['lora_menu_apply'].click(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['model_status'], show_progress=False) shared.gradio['lora_menu_apply'].click(load_lora_wrapper, gradio('lora_menu'), gradio('model_status'), show_progress=False)
shared.gradio['download_model_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['model_status'], show_progress=True) shared.gradio['download_model_button'].click(download_model_wrapper, gradio('custom_model_menu'), gradio('model_status'), show_progress=True)
shared.gradio['autoload_model'].change(lambda x: gr.update(visible=not x), shared.gradio['autoload_model'], load) shared.gradio['autoload_model'].change(lambda x: gr.update(visible=not x), gradio('autoload_model'), load)
def create_chat_settings_menus(): def create_chat_settings_menus():
@ -327,23 +332,75 @@ def create_settings_menus(default_preset):
gr.Markdown('Main parameters') gr.Markdown('Main parameters')
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature', info='Primary factor to control randomness of outputs. 0 = deterministic (only the most likely token is used). Higher value = more randomness.') shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature')
shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p', info='If not set to 1, select tokens with probabilities adding up to less than this number. Higher value = higher range of possible random results.') shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p')
shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k', info='Similar to top_p, but select instead only the top_k most likely tokens. Higher value = higher range of possible random results.') shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k')
shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label='typical_p', info='If not set to 1, select only tokens that are at least this much more likely to appear than random tokens, given the prior text.') shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label='typical_p')
shared.gradio['epsilon_cutoff'] = gr.Slider(0, 9, value=generate_params['epsilon_cutoff'], step=0.01, label='epsilon_cutoff', info='In units of 1e-4; a reasonable value is 3. This sets a probability floor below which tokens are excluded from being sampled. Should be used with top_p, top_k, and eta_cutoff set to 0.') shared.gradio['epsilon_cutoff'] = gr.Slider(0, 9, value=generate_params['epsilon_cutoff'], step=0.01, label='epsilon_cutoff')
shared.gradio['eta_cutoff'] = gr.Slider(0, 20, value=generate_params['eta_cutoff'], step=0.01, label='eta_cutoff', info='In units of 1e-4; a reasonable value is 3. Should be used with top_p, top_k, and epsilon_cutoff set to 0.') shared.gradio['eta_cutoff'] = gr.Slider(0, 20, value=generate_params['eta_cutoff'], step=0.01, label='eta_cutoff')
with gr.Column(): with gr.Column():
shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty', info='Exponential penalty factor for repeating prior tokens. 1 means no penalty, higher value = less repetition, lower value = more repetition.') shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty')
shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'], step=0.01, label='encoder_repetition_penalty', info='Also known as the "Hallucinations filter". Used to penalize tokens that are *not* in the prior text. Higher value = more likely to stay in context, lower value = more likely to diverge.') shared.gradio['repetition_penalty_range'] = gr.Slider(0, 4096, step=64, value=generate_params['repetition_penalty_range'], label='repetition_penalty_range')
shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size', info='If not set to 0, specifies the length of token sets that are completely blocked from repeating at all. Higher values = blocks larger phrases, lower values = blocks words or letters from repeating. Only 0 or high values are a good idea in most cases.') shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'], step=0.01, label='encoder_repetition_penalty')
shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'], label='min_length', info='Minimum generation length in tokens.') shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size')
shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'], label='min_length')
shared.gradio['tfs'] = gr.Slider(0.0, 1.0, value=generate_params['tfs'], step=0.01, label='tfs') shared.gradio['tfs'] = gr.Slider(0.0, 1.0, value=generate_params['tfs'], step=0.01, label='tfs')
shared.gradio['top_a'] = gr.Slider(0.0, 1.0, value=generate_params['top_a'], step=0.01, label='top_a') shared.gradio['top_a'] = gr.Slider(0.0, 1.0, value=generate_params['top_a'], step=0.01, label='top_a')
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample') shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
gr.Markdown('[Click here for more information.](https://github.com/oobabooga/text-generation-webui/blob/main/docs/Generation-parameters.md)') with gr.Accordion("Learn more", open=False):
gr.Markdown("""
Not all parameters are used by all loaders. See [this page](https://github.com/oobabooga/text-generation-webui/blob/main/docs/Generation-parameters.md) for details.
For a technical description of the parameters, the [transformers documentation](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig) is a good reference.
The best presets, according to the [Preset Arena](https://github.com/oobabooga/oobabooga.github.io/blob/main/arena/results.md) experiment, are:
* Instruction following:
1) Divine Intellect
2) Big O
3) simple-1
4) Space Alien
5) StarChat
6) Titanic
7) tfs-with-top-a
8) Asterism
9) Contrastive Search
* Chat:
1) Midnight Enigma
2) Yara
3) Shortwave
4) Kobold-Godlike
### Temperature
Primary factor to control randomness of outputs. 0 = deterministic (only the most likely token is used). Higher value = more randomness.
### top_p
If not set to 1, select tokens with probabilities adding up to less than this number. Higher value = higher range of possible random results.
### top_k
Similar to top_p, but select instead only the top_k most likely tokens. Higher value = higher range of possible random results.
### typical_p
If not set to 1, select only tokens that are at least this much more likely to appear than random tokens, given the prior text.
### epsilon_cutoff
In units of 1e-4; a reasonable value is 3. This sets a probability floor below which tokens are excluded from being sampled. Should be used with top_p, top_k, and eta_cutoff set to 0.
### eta_cutoff
In units of 1e-4; a reasonable value is 3. Should be used with top_p, top_k, and epsilon_cutoff set to 0.
### repetition_penalty
Exponential penalty factor for repeating prior tokens. 1 means no penalty, higher value = less repetition, lower value = more repetition.
### repetition_penalty_range
The number of most recent tokens to consider for repetition penalty. 0 means all tokens are considered.
### encoder_repetition_penalty
Also known as the "Hallucinations filter". Used to penalize tokens that are *not* in the prior text. Higher value = more likely to stay in context, lower value = more likely to diverge.
### no_repeat_ngram_size
If not set to 0, specifies the length of token sets that are completely blocked from repeating at all. Higher values = blocks larger phrases, lower values = blocks words or letters from repeating. Only 0 or high values are a good idea in most cases.
### min_length
Minimum generation length in tokens.
### penalty_alpha
Contrastive Search is enabled by setting this to greater than zero and unchecking "do_sample". It should be used with a low value of top_k, for instance, top_k = 4.
""", elem_classes="markdown")
with gr.Column(): with gr.Column():
create_chat_settings_menus() create_chat_settings_menus()
@ -351,7 +408,7 @@ def create_settings_menus(default_preset):
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
gr.Markdown('Contrastive search') gr.Markdown('Contrastive search')
shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha', info='Contrastive Search is enabled by setting this to greater than zero and unchecking "do_sample". It should be used with a low value of top_k, for instance, top_k = 4.') shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
gr.Markdown('Beam search') gr.Markdown('Beam search')
shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams') shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams')
@ -376,7 +433,7 @@ def create_settings_menus(default_preset):
shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.') shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.')
shared.gradio['stream'] = gr.Checkbox(value=not shared.args.no_stream, label='Activate text streaming') shared.gradio['stream'] = gr.Checkbox(value=not shared.args.no_stream, label='Activate text streaming')
shared.gradio['preset_menu'].change(presets.load_preset_for_ui, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]) shared.gradio['preset_menu'].change(presets.load_preset_for_ui, gradio('preset_menu', 'interface_state'), gradio('interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a'))
def create_file_saving_menus(): def create_file_saving_menus():
@ -415,39 +472,80 @@ def create_file_saving_menus():
def create_file_saving_event_handlers(): def create_file_saving_event_handlers():
shared.gradio['save_confirm'].click( shared.gradio['save_confirm'].click(
lambda x, y, z: utils.save_file(x + y, z), [shared.gradio[k] for k in ['save_root', 'save_filename', 'save_contents']], None).then( lambda x, y, z: utils.save_file(x + y, z), gradio('save_root', 'save_filename', 'save_contents'), None).then(
lambda: gr.update(visible=False), None, shared.gradio['file_saver']) lambda: gr.update(visible=False), None, gradio('file_saver'))
shared.gradio['delete_confirm'].click( shared.gradio['delete_confirm'].click(
lambda x, y: utils.delete_file(x + y), [shared.gradio[k] for k in ['delete_root', 'delete_filename']], None).then( lambda x, y: utils.delete_file(x + y), gradio('delete_root', 'delete_filename'), None).then(
lambda: gr.update(visible=False), None, shared.gradio['file_deleter']) lambda: gr.update(visible=False), None, gradio('file_deleter'))
shared.gradio['delete_cancel'].click(lambda: gr.update(visible=False), None, shared.gradio['file_deleter']) shared.gradio['delete_cancel'].click(lambda: gr.update(visible=False), None, gradio('file_deleter'))
shared.gradio['save_cancel'].click(lambda: gr.update(visible=False), None, shared.gradio['file_saver']) shared.gradio['save_cancel'].click(lambda: gr.update(visible=False), None, gradio('file_saver'))
if shared.is_chat(): if shared.is_chat():
shared.gradio['save_character_confirm'].click( shared.gradio['save_character_confirm'].click(
chat.save_character, [shared.gradio[k] for k in ['name2', 'greeting', 'context', 'character_picture', 'save_character_filename']], None).then( chat.save_character, gradio('name2', 'greeting', 'context', 'character_picture', 'save_character_filename'), None).then(
lambda: gr.update(visible=False), None, shared.gradio['character_saver']) lambda: gr.update(visible=False), None, gradio('character_saver'))
shared.gradio['delete_character_confirm'].click( shared.gradio['delete_character_confirm'].click(
chat.delete_character, shared.gradio['character_menu'], None).then( chat.delete_character, gradio('character_menu'), None).then(
lambda: gr.update(visible=False), None, shared.gradio['character_deleter']).then( lambda: gr.update(visible=False), None, gradio('character_deleter')).then(
lambda: gr.update(choices=utils.get_available_characters()), outputs=shared.gradio['character_menu']) lambda: gr.update(choices=utils.get_available_characters()), None, gradio('character_menu'))
shared.gradio['save_character_cancel'].click(lambda: gr.update(visible=False), None, shared.gradio['character_saver']) shared.gradio['save_character_cancel'].click(lambda: gr.update(visible=False), None, gradio('character_saver'))
shared.gradio['delete_character_cancel'].click(lambda: gr.update(visible=False), None, shared.gradio['character_deleter']) shared.gradio['delete_character_cancel'].click(lambda: gr.update(visible=False), None, gradio('character_deleter'))
shared.gradio['save_preset'].click( shared.gradio['save_preset'].click(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
presets.generate_preset_yaml, shared.gradio['interface_state'], shared.gradio['save_contents']).then( presets.generate_preset_yaml, gradio('interface_state'), gradio('save_contents')).then(
lambda: 'presets/', None, shared.gradio['save_root']).then( lambda: 'presets/', None, gradio('save_root')).then(
lambda: 'My Preset.yaml', None, shared.gradio['save_filename']).then( lambda: 'My Preset.yaml', None, gradio('save_filename')).then(
lambda: gr.update(visible=True), None, shared.gradio['file_saver']) lambda: gr.update(visible=True), None, gradio('file_saver'))
shared.gradio['delete_preset'].click( shared.gradio['delete_preset'].click(
lambda x: f'{x}.yaml', shared.gradio['preset_menu'], shared.gradio['delete_filename']).then( lambda x: f'{x}.yaml', gradio('preset_menu'), gradio('delete_filename')).then(
lambda: 'presets/', None, shared.gradio['delete_root']).then( lambda: 'presets/', None, gradio('delete_root')).then(
lambda: gr.update(visible=True), None, shared.gradio['file_deleter']) lambda: gr.update(visible=True), None, gradio('file_deleter'))
if not shared.args.multi_user:
def load_session(session, state):
with open(Path(f'logs/{session}.json'), 'r') as f:
state.update(json.loads(f.read()))
if shared.is_chat():
chat.save_persistent_history(state['history'], state['character_menu'], state['mode'])
return state
if shared.is_chat():
shared.gradio['save_session'].click(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
lambda x: json.dumps(x, indent=4), gradio('interface_state'), gradio('save_contents')).then(
lambda: 'logs/', None, gradio('save_root')).then(
lambda x: f'session_{shared.get_mode()}_{x + "_" if x not in ["None", None, ""] else ""}{utils.current_time()}.json', gradio('character_menu'), gradio('save_filename')).then(
lambda: gr.update(visible=True), None, gradio('file_saver'))
shared.gradio['session_menu'].change(
load_session, gradio('session_menu', 'interface_state'), gradio('interface_state')).then(
ui.apply_interface_values, gradio('interface_state'), gradio(ui.list_interface_input_elements()), show_progress=False).then(
chat.redraw_html, shared.reload_inputs, gradio('display'))
else:
shared.gradio['save_session'].click(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
lambda x: json.dumps(x, indent=4), gradio('interface_state'), gradio('save_contents')).then(
lambda: 'logs/', None, gradio('save_root')).then(
lambda: f'session_{shared.get_mode()}_{utils.current_time()}.json', None, gradio('save_filename')).then(
lambda: gr.update(visible=True), None, gradio('file_saver'))
shared.gradio['session_menu'].change(
load_session, gradio('session_menu', 'interface_state'), gradio('interface_state')).then(
ui.apply_interface_values, gradio('interface_state'), gradio(ui.list_interface_input_elements()), show_progress=False)
shared.gradio['delete_session'].click(
lambda x: f'{x}.json', gradio('session_menu'), gradio('delete_filename')).then(
lambda: 'logs/', None, gradio('delete_root')).then(
lambda: gr.update(visible=True), None, gradio('file_deleter'))
def set_interface_arguments(interface_mode, extensions, bool_active): def set_interface_arguments(interface_mode, extensions, bool_active):
@ -511,13 +609,17 @@ def create_interface():
# Create chat mode interface # Create chat mode interface
if shared.is_chat(): if shared.is_chat():
shared.input_elements = ui.list_interface_input_elements(chat=True) shared.input_elements = ui.list_interface_input_elements()
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
shared.gradio['Chat input'] = gr.State() shared.gradio.update({
shared.gradio['dummy'] = gr.State() 'interface_state': gr.State({k: None for k in shared.input_elements}),
'Chat input': gr.State(),
'dummy': gr.State(),
'history': gr.State({'internal': [], 'visible': []}),
})
with gr.Tab('Text generation', elem_id='main'): with gr.Tab('Text generation', elem_id='main'):
shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'chat', 'cai-chat')) shared.gradio['display'] = gr.HTML(value=chat_html_wrapper({'internal': [], 'visible': []}, shared.settings['name1'], shared.settings['name2'], 'chat', 'cai-chat'))
shared.gradio['textbox'] = gr.Textbox(label='Input') shared.gradio['textbox'] = gr.Textbox(label='Input')
with gr.Row(): with gr.Row():
shared.gradio['Stop'] = gr.Button('Stop', elem_id='stop') shared.gradio['Stop'] = gr.Button('Stop', elem_id='stop')
@ -553,7 +655,7 @@ def create_interface():
with gr.Row(): with gr.Row():
with gr.Column(scale=8): with gr.Column(scale=8):
with gr.Row(): with gr.Row():
shared.gradio['character_menu'] = gr.Dropdown(choices=utils.get_available_characters(), label='Character', elem_id='character-menu', info='Used in chat and chat-instruct modes.', elem_classes='slim-dropdown') shared.gradio['character_menu'] = gr.Dropdown(value='None', choices=utils.get_available_characters(), label='Character', elem_id='character-menu', info='Used in chat and chat-instruct modes.', elem_classes='slim-dropdown')
ui.create_refresh_button(shared.gradio['character_menu'], lambda: None, lambda: {'choices': utils.get_available_characters()}, 'refresh-button') ui.create_refresh_button(shared.gradio['character_menu'], lambda: None, lambda: {'choices': utils.get_available_characters()}, 'refresh-button')
shared.gradio['save_character'] = gr.Button('💾', elem_classes='refresh-button') shared.gradio['save_character'] = gr.Button('💾', elem_classes='refresh-button')
shared.gradio['delete_character'] = gr.Button('🗑️', elem_classes='refresh-button') shared.gradio['delete_character'] = gr.Button('🗑️', elem_classes='refresh-button')
@ -597,18 +699,25 @@ def create_interface():
shared.gradio['upload_json'] = gr.File(type='binary', file_types=['.json'], label='JSON File') shared.gradio['upload_json'] = gr.File(type='binary', file_types=['.json'], label='JSON File')
shared.gradio['upload_img_bot'] = gr.Image(type='pil', label='Profile Picture (optional)') shared.gradio['upload_img_bot'] = gr.Image(type='pil', label='Profile Picture (optional)')
shared.gradio['Upload character'] = gr.Button(value='Submit', interactive=False) shared.gradio['Submit character'] = gr.Button(value='Submit', interactive=False)
with gr.Tab('TavernAI'): with gr.Tab('TavernAI'):
shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image'], label='TavernAI PNG File') with gr.Row():
shared.gradio['Upload tavern character'] = gr.Button(value='Submit', interactive=False) with gr.Column():
shared.gradio['upload_img_tavern'] = gr.Image(type='pil', label='TavernAI PNG File', elem_id="upload_img_tavern")
shared.gradio['tavern_json'] = gr.State()
with gr.Column():
shared.gradio['tavern_name'] = gr.Textbox(value='', lines=1, label='Name', interactive=False)
shared.gradio['tavern_desc'] = gr.Textbox(value='', lines=4, max_lines=4, label='Description', interactive=False)
shared.gradio['Submit tavern character'] = gr.Button(value='Submit', interactive=False)
with gr.Tab("Parameters", elem_id="parameters"): with gr.Tab("Parameters", elem_id="parameters"):
create_settings_menus(default_preset) create_settings_menus(default_preset)
# Create notebook mode interface # Create notebook mode interface
elif shared.args.notebook: elif shared.args.notebook:
shared.input_elements = ui.list_interface_input_elements(chat=False) shared.input_elements = ui.list_interface_input_elements()
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements}) shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
shared.gradio['last_input'] = gr.State('') shared.gradio['last_input'] = gr.State('')
with gr.Tab("Text generation", elem_id="main"): with gr.Tab("Text generation", elem_id="main"):
@ -647,7 +756,7 @@ def create_interface():
# Create default mode interface # Create default mode interface
else: else:
shared.input_elements = ui.list_interface_input_elements(chat=False) shared.input_elements = ui.list_interface_input_elements()
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements}) shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
shared.gradio['last_input'] = gr.State('') shared.gradio['last_input'] = gr.State('')
with gr.Tab("Text generation", elem_id="main"): with gr.Tab("Text generation", elem_id="main"):
@ -691,8 +800,8 @@ def create_interface():
with gr.Tab("Training", elem_id="training-tab"): with gr.Tab("Training", elem_id="training-tab"):
training.create_train_interface() training.create_train_interface()
# Interface mode tab # Session tab
with gr.Tab("Interface mode", elem_id="interface-mode"): with gr.Tab("Session", elem_id="session-tab"):
modes = ["default", "notebook", "chat"] modes = ["default", "notebook", "chat"]
current_mode = "default" current_mode = "default"
for mode in modes[1:]: for mode in modes[1:]:
@ -705,9 +814,12 @@ def create_interface():
bool_active = [k for k in bool_list if vars(shared.args)[k]] bool_active = [k for k in bool_list if vars(shared.args)[k]]
with gr.Row(): with gr.Row():
shared.gradio['interface_modes_menu'] = gr.Dropdown(choices=modes, value=current_mode, label="Mode")
shared.gradio['reset_interface'] = gr.Button("Apply and restart the interface", elem_classes="small-button") with gr.Column():
shared.gradio['toggle_dark_mode'] = gr.Button('Toggle dark/light mode', elem_classes="small-button") with gr.Row():
shared.gradio['interface_modes_menu'] = gr.Dropdown(choices=modes, value=current_mode, label="Mode", elem_classes='slim-dropdown')
shared.gradio['reset_interface'] = gr.Button("Apply and restart", elem_classes="small-button", variant="primary")
shared.gradio['toggle_dark_mode'] = gr.Button('Toggle 💡', elem_classes="small-button")
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
@ -716,212 +828,237 @@ def create_interface():
with gr.Column(): with gr.Column():
shared.gradio['bool_menu'] = gr.CheckboxGroup(choices=bool_list, value=bool_active, label="Boolean command-line flags", elem_classes='checkboxgroup-table') shared.gradio['bool_menu'] = gr.CheckboxGroup(choices=bool_list, value=bool_active, label="Boolean command-line flags", elem_classes='checkboxgroup-table')
with gr.Column():
if not shared.args.multi_user:
with gr.Row(): with gr.Row():
extension_name = gr.Textbox(lines=1, label='Install or update an extension', info='Enter the GitHub URL below. For a list of extensions, see: https://github.com/oobabooga/text-generation-webui-extensions ⚠️ WARNING ⚠️ : extensions can execute arbitrary code. Make sure to inspect their source code before activating them.') shared.gradio['session_menu'] = gr.Dropdown(choices=utils.get_available_sessions(), value='None', label='Session', elem_classes='slim-dropdown', info='When saving a session, make sure to keep the initial part of the filename (session_chat, session_notebook, or session_default), otherwise it will not appear on this list afterwards.')
extension_install = gr.Button('Install or update', elem_classes="small-button") ui.create_refresh_button(shared.gradio['session_menu'], lambda: None, lambda: {'choices': utils.get_available_sessions()}, ['refresh-button'])
shared.gradio['save_session'] = gr.Button('💾', elem_classes=['refresh-button'])
shared.gradio['delete_session'] = gr.Button('🗑️', elem_classes=['refresh-button'])
extension_name = gr.Textbox(lines=1, label='Install or update an extension', info='Enter the GitHub URL below and press Enter. For a list of extensions, see: https://github.com/oobabooga/text-generation-webui-extensions ⚠️ WARNING ⚠️ : extensions can execute arbitrary code. Make sure to inspect their source code before activating them.')
extension_status = gr.Markdown() extension_status = gr.Markdown()
extension_install.click( extension_name.submit(
clone_or_pull_repository, extension_name, extension_status, show_progress=False).then( clone_or_pull_repository, extension_name, extension_status, show_progress=False).then(
lambda: gr.update(choices=utils.get_available_extensions(), value=shared.args.extensions), outputs=shared.gradio['extensions_menu']) lambda: gr.update(choices=utils.get_available_extensions(), value=shared.args.extensions), None, gradio('extensions_menu'))
# Reset interface event
shared.gradio['reset_interface'].click(
    set_interface_arguments, gradio('interface_modes_menu', 'extensions_menu', 'bool_menu'), None).then(
    lambda: None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500); return []}')

shared.gradio['toggle_dark_mode'].click(lambda: None, None, None, _js='() => {document.getElementsByTagName("body")[0].classList.toggle("dark")}')

# chat mode event handlers
if shared.is_chat():
    shared.input_params = gradio('Chat input', 'start_with', 'interface_state')
    clear_arr = gradio('Clear history-confirm', 'Clear history', 'Clear history-cancel')
    shared.reload_inputs = gradio('history', 'name1', 'name2', 'mode', 'chat_style')

    gen_events.append(shared.gradio['Generate'].click(
        ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
        lambda x: (x, ''), gradio('textbox'), gradio('Chat input', 'textbox'), show_progress=False).then(
        chat.generate_chat_reply_wrapper, shared.input_params, gradio('display', 'history'), show_progress=False).then(
        ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
        chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
        lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}")
    )
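    # Each generation event below is a chain of .then() steps: gather the current
    # UI state into 'interface_state', move the textbox contents into 'Chat input',
    # stream the reply into 'display'/'history', re-gather the state, persist the
    # chat history, and finally fire the audio notification JS.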
    gen_events.append(shared.gradio['textbox'].submit(
        ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
        lambda x: (x, ''), gradio('textbox'), gradio('Chat input', 'textbox'), show_progress=False).then(
        chat.generate_chat_reply_wrapper, shared.input_params, gradio('display', 'history'), show_progress=False).then(
        ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
        chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
        lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}")
    )

    gen_events.append(shared.gradio['Regenerate'].click(
        ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
        partial(chat.generate_chat_reply_wrapper, regenerate=True), shared.input_params, gradio('display', 'history'), show_progress=False).then(
        ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
        chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
        lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}")
    )

    gen_events.append(shared.gradio['Continue'].click(
        ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
        partial(chat.generate_chat_reply_wrapper, _continue=True), shared.input_params, gradio('display', 'history'), show_progress=False).then(
        ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
        chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
        lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}")
    )

    gen_events.append(shared.gradio['Impersonate'].click(
        ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
        lambda x: x, gradio('textbox'), gradio('Chat input'), show_progress=False).then(
        chat.impersonate_wrapper, shared.input_params, gradio('textbox'), show_progress=False).then(
        ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
        lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}")
    )

    shared.gradio['Replace last reply'].click(
        ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
        chat.replace_last_reply, gradio('textbox', 'interface_state'), gradio('history')).then(
        lambda: '', None, gradio('textbox'), show_progress=False).then(
        chat.redraw_html, shared.reload_inputs, gradio('display')).then(
        chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None)

    shared.gradio['Send dummy message'].click(
        ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
        chat.send_dummy_message, gradio('textbox', 'interface_state'), gradio('history')).then(
        lambda: '', None, gradio('textbox'), show_progress=False).then(
        chat.redraw_html, shared.reload_inputs, gradio('display')).then(
        chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None)

    shared.gradio['Send dummy reply'].click(
        ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
        chat.send_dummy_reply, gradio('textbox', 'interface_state'), gradio('history')).then(
        lambda: '', None, gradio('textbox'), show_progress=False).then(
        chat.redraw_html, shared.reload_inputs, gradio('display')).then(
        chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None)

    shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
    shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
    shared.gradio['Clear history-confirm'].click(
        ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
        lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr).then(
        chat.clear_chat_log, gradio('interface_state'), gradio('history')).then(
        chat.redraw_html, shared.reload_inputs, gradio('display')).then(
        chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None)
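    # Clearing the history is a two-step confirmation: 'Clear history' only swaps
    # the visibility of the confirm/cancel buttons, and the chat log is actually
    # wiped (and the emptied history persisted) only when 'Clear history-confirm'
    # fires.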
    shared.gradio['Remove last'].click(
        ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
        chat.remove_last_message, gradio('history'), gradio('textbox', 'history'), show_progress=False).then(
        chat.redraw_html, shared.reload_inputs, gradio('display')).then(
        chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None)

    shared.gradio['character_menu'].change(
        partial(chat.load_character, instruct=False), gradio('character_menu', 'name1', 'name2'), gradio('name1', 'name2', 'character_picture', 'greeting', 'context', 'dummy')).then(
        ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
        chat.load_persistent_history, gradio('interface_state'), gradio('history')).then(
        chat.redraw_html, shared.reload_inputs, gradio('display'))

    shared.gradio['Stop'].click(
        stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None).then(
        chat.redraw_html, shared.reload_inputs, gradio('display'))

    shared.gradio['mode'].change(
        lambda x: gr.update(visible=x != 'instruct'), gradio('mode'), gradio('chat_style'), show_progress=False).then(
        chat.redraw_html, shared.reload_inputs, gradio('display'))

    shared.gradio['chat_style'].change(chat.redraw_html, shared.reload_inputs, gradio('display'))
    shared.gradio['instruction_template'].change(
        partial(chat.load_character, instruct=True), gradio('instruction_template', 'name1_instruct', 'name2_instruct'), gradio('name1_instruct', 'name2_instruct', 'dummy', 'dummy', 'context_instruct', 'turn_template'))

    shared.gradio['upload_chat_history'].upload(
        chat.load_history, gradio('upload_chat_history', 'history'), gradio('history')).then(
        chat.redraw_html, shared.reload_inputs, gradio('display'))

    shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, gradio('history'), gradio('textbox'), show_progress=False)

    # Save/delete a character
    shared.gradio['save_character'].click(
        lambda x: x, gradio('name2'), gradio('save_character_filename')).then(
        lambda: gr.update(visible=True), None, gradio('character_saver'))

    shared.gradio['delete_character'].click(lambda: gr.update(visible=True), None, gradio('character_deleter'))

    shared.gradio['save_template'].click(
        lambda: 'My Template.yaml', None, gradio('save_filename')).then(
        lambda: 'characters/instruction-following/', None, gradio('save_root')).then(
        chat.generate_instruction_template_yaml, gradio('name1_instruct', 'name2_instruct', 'context_instruct', 'turn_template'), gradio('save_contents')).then(
        lambda: gr.update(visible=True), None, gradio('file_saver'))

    shared.gradio['delete_template'].click(
        lambda x: f'{x}.yaml', gradio('instruction_template'), gradio('delete_filename')).then(
        lambda: 'characters/instruction-following/', None, gradio('delete_root')).then(
        lambda: gr.update(visible=True), None, gradio('file_deleter'))

    shared.gradio['download_button'].click(chat.save_history, gradio('history'), gradio('download'))
    shared.gradio['Submit character'].click(chat.upload_character, gradio('upload_json', 'upload_img_bot'), gradio('character_menu'))
    shared.gradio['upload_json'].upload(lambda: gr.update(interactive=True), None, gradio('Submit character'))
    shared.gradio['upload_json'].clear(lambda: gr.update(interactive=False), None, gradio('Submit character'))
    shared.gradio['Submit tavern character'].click(chat.upload_tavern_character, gradio('upload_img_tavern', 'tavern_json'), gradio('character_menu'))
    shared.gradio['upload_img_tavern'].upload(chat.check_tavern_character, gradio('upload_img_tavern'), gradio('tavern_name', 'tavern_desc', 'tavern_json', 'Submit tavern character'), show_progress=False)
    shared.gradio['upload_img_tavern'].clear(lambda: (None, None, None, gr.update(interactive=False)), None, gradio('tavern_name', 'tavern_desc', 'tavern_json', 'Submit tavern character'), show_progress=False)
    shared.gradio['your_picture'].change(
        chat.upload_your_profile_picture, gradio('your_picture'), None).then(
        partial(chat.redraw_html, reset_cache=True), shared.reload_inputs, gradio('display'))
# notebook/default modes event handlers
else:
    shared.input_params = gradio('textbox', 'interface_state')
    if shared.args.notebook:
        output_params = gradio('textbox', 'html')
    else:
        output_params = gradio('output_textbox', 'html')

    gen_events.append(shared.gradio['Generate'].click(
        lambda x: x, gradio('textbox'), gradio('last_input')).then(
        ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
        generate_reply_wrapper, shared.input_params, output_params, show_progress=False).then(
        ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
        lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}")
        # lambda: None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
    )

    gen_events.append(shared.gradio['textbox'].submit(
        lambda x: x, gradio('textbox'), gradio('last_input')).then(
        ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
        generate_reply_wrapper, shared.input_params, output_params, show_progress=False).then(
        ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
        lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}")
        # lambda: None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
    )

    if shared.args.notebook:
        shared.gradio['Undo'].click(lambda x: x, gradio('last_input'), gradio('textbox'), show_progress=False)
        shared.gradio['markdown_render'].click(lambda x: x, gradio('textbox'), gradio('markdown'), queue=False)
        gen_events.append(shared.gradio['Regenerate'].click(
            lambda x: x, gradio('last_input'), gradio('textbox'), show_progress=False).then(
            ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
            generate_reply_wrapper, shared.input_params, output_params, show_progress=False).then(
            ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
            lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}")
            # lambda: None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
        )
    else:
        shared.gradio['markdown_render'].click(lambda x: x, gradio('output_textbox'), gradio('markdown'), queue=False)
        gen_events.append(shared.gradio['Continue'].click(
            ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
            generate_reply_wrapper, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=False).then(
            ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
            lambda: None, None, None, _js=f"() => {{{audio_notification_js}}}")
            # lambda: None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}")
        )

    shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
    shared.gradio['prompt_menu'].change(load_prompt, gradio('prompt_menu'), gradio('textbox'), show_progress=False)

    shared.gradio['save_prompt'].click(
        lambda x: x, gradio('textbox'), gradio('save_contents')).then(
        lambda: 'prompts/', None, gradio('save_root')).then(
        lambda: utils.current_time() + '.txt', None, gradio('save_filename')).then(
        lambda: gr.update(visible=True), None, gradio('file_saver'))

    shared.gradio['delete_prompt'].click(
        lambda: 'prompts/', None, gradio('delete_root')).then(
        lambda x: x + '.txt', gradio('prompt_menu'), gradio('delete_filename')).then(
        lambda: gr.update(visible=True), None, gradio('file_deleter'))

    shared.gradio['count_tokens'].click(count_tokens, gradio('textbox'), gradio('status'), show_progress=False)

create_file_saving_event_handlers()

shared.gradio['interface'].load(lambda: None, None, None, _js=f"() => {{{js}}}")
shared.gradio['interface'].load(partial(ui.apply_interface_values, {}, use_persistent=True), None, gradio(ui.list_interface_input_elements()), show_progress=False)
if shared.settings['dark_theme']:
    shared.gradio['interface'].load(lambda: None, None, None, _js="() => document.getElementsByTagName('body')[0].classList.add('dark')")

if shared.is_chat():
    shared.gradio['interface'].load(chat.redraw_html, shared.reload_inputs, gradio('display'))
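# On page load: persisted interface values are re-applied to every input element,
# the dark theme class is added if enabled, and in chat mode the chat display is
# redrawn.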
# Extensions tabs
extensions_module.create_extensions_tabs()

@@ -1018,7 +1155,11 @@ if __name__ == "__main__":
    if shared.args.lora:
        add_lora_to_model(shared.args.lora)

    # Forcing some events to be triggered on page load
    shared.persistent_interface_state.update({
        'loader': shared.args.loader or 'Transformers',
    })
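    # Seeding shared.persistent_interface_state here means the interface.load
    # handler above (ui.apply_interface_values with use_persistent=True) pushes
    # these values into the UI on page load, which is presumably how the events
    # mentioned in the comment above get triggered.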
    if shared.is_chat():
        shared.persistent_interface_state.update({
            'mode': shared.settings['mode'],
@@ -1026,11 +1167,11 @@ if __name__ == "__main__":
            'instruction_template': shared.settings['instruction_template']
        })

        if Path("cache/pfp_character.png").exists():
            Path("cache/pfp_character.png").unlink()

    shared.generation_lock = Lock()

    # Launch the web UI
    create_interface()
    while True: