Mirror of https://github.com/oobabooga/text-generation-webui.git
(synced 2025-01-15 23:01:04 +01:00) — 387 lines, 16 KiB, Python
import base64
|
|
import io
|
|
import re
|
|
import time
|
|
from datetime import date
|
|
from pathlib import Path
|
|
|
|
import gradio as gr
|
|
import requests
|
|
import torch
|
|
from PIL import Image
|
|
|
|
from modules import shared
|
|
from modules.models import reload_model, unload_model
|
|
from modules.ui import create_refresh_button
|
|
|
|
# NOTE(review): turns off TorchScript JIT profiling mode through a private torch API
# (torch._C) at import time; the reason is not visible in this file — confirm it is
# still needed for this extension.
torch._C._jit_set_profiling_mode(False)
|
|
|
|
# Parameters which can be customized in settings.json of the webui.
# Every UI control in ui() writes its value back into this dict, and the
# generation code reads the current values from here.
params = {
    'address': 'http://127.0.0.1:7860',
    # Modes of operation: 0 (Manual only), 1 (Immersive/Interactive - looks for
    # words to trigger), 2 (Picturebook Adventure - Always on)
    'mode': 0,
    'manage_VRAM': False,
    'save_img': False,
    'SD_model': 'NeverEndingDream',  # not used right now
    'prompt_prefix': '(Masterpiece:1.1), detailed, intricate, colorful',
    'negative_prompt': '(worst quality, low quality:1.3)',
    'width': 512,
    'height': 512,
    'denoising_strength': 0.61,
    'restore_faces': False,
    'enable_hr': False,
    'hr_upscaler': 'ESRGAN_4x',
    # Numeric, like the other scale settings and the 'Upscale by' slider that edits it.
    'hr_scale': 1.0,
    'seed': -1,
    'sampler_name': 'DPM++ 2M Karras',
    'steps': 32,
    'cfg_scale': 7,
    'textgen_prefix': 'Please provide a detailed and vivid description of [subject]',
    'sd_checkpoint': ' ',
    'checkpoint_list': [" "]
}
|
|
|
|
|
|
def give_VRAM_priority(actor):
    """Shift VRAM between the local text-generation model and Auto1111's checkpoint.

    actor:
        'SD'    -- unload the local LLM, then ask Auto1111 to reload its last checkpoint.
        'LLM'   -- ask Auto1111 to unload its checkpoint, then reload the local LLM.
        'set'   -- VRAM management switched on: ask Auto1111 to unload its checkpoint.
        'reset' -- VRAM management switched off: ask Auto1111 to reload its checkpoint.

    Raises:
        RuntimeError: for any other ``actor`` value.
        requests.HTTPError: if Auto1111 answers a request with an error status.
    """
    if actor == 'SD':
        unload_model()
        _post_sd_vram_request('reload-checkpoint', "Requesting Auto1111 to re-load last checkpoint used...")
    elif actor == 'LLM':
        # Free Auto1111's VRAM first so the LLM has room when it is reloaded.
        _post_sd_vram_request('unload-checkpoint', "Requesting Auto1111 to vacate VRAM...")
        reload_model()
    elif actor == 'set':
        _post_sd_vram_request('unload-checkpoint', "VRAM management activated -- requesting Auto1111 to vacate VRAM...")
    elif actor == 'reset':
        _post_sd_vram_request('reload-checkpoint', "VRAM management deactivated -- requesting Auto1111 to reload checkpoint")
    else:
        raise RuntimeError(f'Managing VRAM: "{actor}" is not a known state!')


def _post_sd_vram_request(endpoint, message):
    """Print *message*, POST to /sdapi/v1/<endpoint> on the configured server, raise on HTTP error."""
    print(message)
    response = requests.post(url=f'{params["address"]}/sdapi/v1/{endpoint}', json='')
    response.raise_for_status()
|
|
|
|
|
|
# If VRAM management is already enabled when this module is imported, ask
# Auto1111 to unload its checkpoint right away (give_VRAM_priority('set')).
if params['manage_VRAM']:
    give_VRAM_priority('set')

SD_models = ['NeverEndingDream']  # TODO: get with http://{address}/sdapi/v1/sd-models and allow user to select

# Flipped by toggle_generation(); read by state_modifier() and output_modifier()
# to decide whether the next model reply is turned into a picture.
picture_response = False  # specifies if the next model response should appear as a picture
|
|
|
|
|
|
def remove_surrounded_chars(string):
    """Strip asterisk-delimited spans (e.g. roleplay actions) from *string*.

    Removes, lazily, everything from an asterisk up to and including the next
    asterisk, or up to the end of the string when no closing asterisk follows.
    """
    # Raw string: the previous '\*' literal relied on an invalid escape sequence
    # (SyntaxWarning on Python 3.12+); the regex itself is unchanged.
    return re.sub(r'\*[^*]*?(\*|$)', '', string)
|
|
|
|
|
|
def triggers_are_in(string):
    """Return True if *string* reads like a request for a picture.

    Asterisk-delimited text is removed first (remove_surrounded_chars), so only
    the non-action part of the message is searched.
    """
    string = remove_surrounded_chars(string)
    # Matches send|mail|message|me at the end of a word (there is no leading \b,
    # so e.g. "some" also qualifies), followed later by the whole word
    # image|pic|picture|photo|snap|snapshot|selfie|meme, optionally plural.
    # Inline flags (?aims): ASCII-only \b, ignore case, multiline, '.' matches newline.
    return bool(re.search(r'(?aims)(send|mail|message|me)\b.+?\b(image|pic(ture)?|photo|snap(shot)?|selfie|meme)s?\b', string))
|
|
|
|
|
|
def state_modifier(state):
    """Adjust generation state before a reply is produced.

    While a picture reply is pending, streaming is switched off in ``state``
    (mutated in place). The same dict is always returned.
    """
    if not picture_response:
        return state

    state['stream'] = False
    return state
|
|
|
|
|
|
def input_modifier(string):
    """
    This function is applied to your text inputs before
    they are fed into the model.

    Only active in mode 1 (Immersive/Interactive). When the message contains a
    picture request (see triggers_are_in), picture mode is switched on and the
    message is replaced by params['textgen_prefix'] with "[subject]" filled in
    from the text after the first standalone word "of" (case-insensitive). A
    default subject is used when there is no such word or nothing follows it.
    """
    if params['mode'] != 1:  # if not in immersive/interactive mode, do nothing
        return string

    if not triggers_are_in(string):
        return string

    toggle_generation(True)

    # Split on the first whole word "of" rather than any "of" substring, so words
    # like "often" or "profile" are not cut apart; the subject keeps its case.
    parts = re.split(r'\bof\b', string, maxsplit=1, flags=re.IGNORECASE)
    subject = parts[1].strip() if len(parts) > 1 else ''
    if not subject:
        subject = "your appearance, your surroundings and what you are doing right now"

    return params['textgen_prefix'].replace("[subject]", subject)
|
|
|
|
def _decode_sd_image(img_str):
    """Return the raw bytes of one base64 image string from a txt2img response.

    An optional data-URI header ("data:image/...;base64,") is dropped first;
    a plain base64 string is decoded as-is.
    """
    return base64.b64decode(img_str.split(",", 1)[-1])


def _thumbnail_data_uri(img_data):
    """Shrink an image to fit within 300x300 and return it as an inline JPEG data URI."""
    image = Image.open(io.BytesIO(img_data))
    # Lower the resolution of received images for the chat, otherwise the log size
    # gets out of control quickly with all the base64 values in visible history.
    image.thumbnail((300, 300))
    if image.mode not in ('RGB', 'L'):
        # JPEG cannot store alpha or palette modes (e.g. RGBA, P).
        image = image.convert('RGB')
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    return "data:image/jpeg;base64," + base64.b64encode(buffered.getvalue()).decode()


# Get and save the Stable Diffusion-generated picture
def get_SD_pictures(description, character):
    """Generate images for *description* via Auto1111's txt2img API; return chat HTML.

    The prompt is params['prompt_prefix'] followed by "(<description>:1)" with any
    <audio>...</audio> fragments removed. Other payload fields come from ``params``.

    If params['save_img'] is set, each image is written to
    extensions/sd_api_pictures/outputs/<YYYY_MM_DD>/<character>_<unix time>[_<n>].png
    and referenced by URL. Otherwise it is embedded as a 300x300 JPEG thumbnail.

    With params['manage_VRAM'] on, VRAM is handed to SD for the request and always
    handed back to the LLM afterwards, even if generation fails.

    Raises:
        requests.HTTPError: if the txt2img request returns an error status.
    """
    import html  # only needed to escape the alt text below

    manage_vram = params['manage_VRAM']
    try:
        if manage_vram:
            give_VRAM_priority('SD')

        description = re.sub('<audio.*?</audio>', ' ', description)
        description = f"({description}:1)"
        # The description is model output; escape it before placing it in an HTML attribute.
        alt_text = html.escape(description, quote=True)

        payload = {
            "prompt": params['prompt_prefix'] + description,
            "seed": params['seed'],
            "sampler_name": params['sampler_name'],
            "enable_hr": params['enable_hr'],
            "hr_scale": params['hr_scale'],
            "hr_upscaler": params['hr_upscaler'],
            "denoising_strength": params['denoising_strength'],
            "steps": params['steps'],
            "cfg_scale": params['cfg_scale'],
            "width": params['width'],
            "height": params['height'],
            "restore_faces": params['restore_faces'],
            "override_settings_restore_afterwards": True,
            "negative_prompt": params['negative_prompt']
        }

        print(f'Prompting the image generator via the API on {params["address"]}...')
        response = requests.post(url=f'{params["address"]}/sdapi/v1/txt2img', json=payload)
        response.raise_for_status()
        r = response.json()

        # One timestamp per call; extra images get an index suffix so they don't overwrite each other.
        base_name = f'{date.today().strftime("%Y_%m_%d")}/{character}_{int(time.time())}'
        pieces = []
        for index, img_str in enumerate(r['images']):
            img_data = _decode_sd_image(img_str)
            if params['save_img']:
                variadic = base_name if index == 0 else f'{base_name}_{index}'
                output_file = Path(f'extensions/sd_api_pictures/outputs/{variadic}.png')
                output_file.parent.mkdir(parents=True, exist_ok=True)
                output_file.write_bytes(img_data)
                pieces.append(f'<img src="/file/extensions/sd_api_pictures/outputs/{variadic}.png" alt="{alt_text}" style="max-width: unset; max-height: unset;">\n')
            else:
                pieces.append(f'<img src="{_thumbnail_data_uri(img_data)}" alt="{alt_text}">\n')

        return "".join(pieces)
    finally:
        if manage_vram:
            give_VRAM_priority('LLM')
|
|
|
|
# TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
# and replace it with 'text' for the purposes of logging?
def output_modifier(string, state):
    """
    This function is applied to the model outputs.

    When a picture response is pending, the reply is cleaned into an image prompt
    (asterisk spans, straight and curly double quotes, and newlines removed). It is
    sent to Stable Diffusion, and the result is the <img> HTML followed by a caption.
    In modes 0 and 1 picture mode is switched off again, so a request is one-shot.
    If nothing usable remains, a hint message is returned and no image is made.
    """
    if not picture_response:
        return string

    string = remove_surrounded_chars(string)
    # Remove straight quotes and both opening and closing curly double quotes;
    # the caption below wraps the text in curly quotes itself.
    for quote in ('"', '“', '”'):
        string = string.replace(quote, '')
    string = string.replace('\n', ' ').strip()

    if string == '':
        return 'no viable description in reply, try regenerating'

    if params['mode'] < 2:
        toggle_generation(False)
        text = f'*Sends a picture which portrays: “{string}”*'
    else:
        text = string

    return get_SD_pictures(string, state['character_menu']) + "\n" + text
|
|
|
|
|
|
def bot_prefix_modifier(string):
    """Chat-mode hook for the bot's prefix text, which can bias its behavior.

    This extension leaves the prefix alone: the argument is returned as-is.
    """
    unchanged_prefix = string
    return unchanged_prefix
|
|
|
|
|
|
def toggle_generation(*args):
    """Set or flip picture mode and update the processing message shown by the UI.

    With no arguments the module-level ``picture_response`` flag is inverted;
    otherwise it is set to the first argument.
    """
    global picture_response

    picture_response = args[0] if args else not picture_response

    # shared is a module, so assigning an attribute on it needs no global declaration.
    shared.processing_message = "*Is sending a picture...*" if picture_response else "*Is typing...*"
|
|
|
|
|
|
def filter_address(address):
    """Normalise a user-entered Auto1111 address into a base URL.

    Surrounding whitespace and all trailing slashes are removed. ``http://`` is
    prepended unless the address already starts with an http:// or https://
    scheme (compared case-insensitively).
    """
    address = address.strip().rstrip('/')
    # Check for a real scheme, not just the letters "http": a host such as
    # "httpbin.local" still needs a prefix.
    if not address.lower().startswith(('http://', 'https://')):
        address = 'http://' + address
    return address
|
|
|
|
|
|
def SD_api_address_update(address):
    """Normalise and store a new Auto1111 address, then check that it answers.

    The filtered address is saved to params['address']. GET /sdapi/v1/sd-models
    is used as the probe. Returns a Textbox update whose label reports success
    or failure.
    """
    address = filter_address(address)
    params.update({"address": address})

    try:
        response = requests.get(url=f'{params["address"]}/sdapi/v1/sd-models')
        response.raise_for_status()
    except (requests.exceptions.RequestException, ValueError):
        # Connection/HTTP errors or a malformed URL; unlike a bare except this
        # does not swallow KeyboardInterrupt/SystemExit.
        msg = "❌ No SD API endpoint on:"
    else:
        msg = "✔️ SD API is found on:"

    return gr.Textbox.update(label=msg)
|
|
|
|
|
|
def custom_css():
    """Return the text of the style.css file located next to this script."""
    path_to_css = Path(__file__).parent.resolve() / 'style.css'
    # read_text opens and closes the file (the previous open().read() leaked the
    # handle); an explicit encoding avoids depending on the platform locale.
    return path_to_css.read_text(encoding='utf-8')
|
|
|
|
|
|
def get_checkpoints():
    """Refresh the active and available SD checkpoints from Auto1111.

    On success, params['sd_checkpoint'] is set from /sdapi/v1/options
    ('sd_model_checkpoint'). params['checkpoint_list'] is set to the 'title' of
    each entry from /sdapi/v1/sd-models. On any request, HTTP, or parse failure
    both are cleared. Returns a gr.update for the checkpoint dropdown.
    """
    try:
        models = requests.get(url=f'{params["address"]}/sdapi/v1/sd-models')
        models.raise_for_status()
        options = requests.get(url=f'{params["address"]}/sdapi/v1/options')
        options.raise_for_status()
        # Parse everything before touching params so a failure can't leave a half-updated state.
        current = options.json()['sd_model_checkpoint']
        available = [result["title"] for result in models.json()]
    except (requests.exceptions.RequestException, ValueError, KeyError, TypeError):
        params['sd_checkpoint'] = ""
        params['checkpoint_list'] = []
    else:
        params['sd_checkpoint'] = current
        params['checkpoint_list'] = available

    return gr.update(choices=params['checkpoint_list'], value=params['sd_checkpoint'])
|
|
|
|
|
|
def load_checkpoint(checkpoint):
    """Ask Auto1111 to switch its active checkpoint (best effort).

    POSTs {"sd_model_checkpoint": checkpoint} to /sdapi/v1/options. Failures are
    printed rather than raised, so an unreachable SD server does not break the UI
    callback. Previously they were silently discarded.
    """
    payload = {
        "sd_model_checkpoint": checkpoint
    }

    try:
        response = requests.post(url=f'{params["address"]}/sdapi/v1/options', json=payload)
        response.raise_for_status()
    except (requests.exceptions.RequestException, ValueError) as err:
        print(f'Could not load SD checkpoint "{checkpoint}": {err}')
|
|
|
|
|
|
def get_samplers():
    """Return the sampler names reported by Auto1111's /sdapi/v1/samplers.

    Returns an empty list if the request fails, the server answers with an
    error status, or the response is not the expected list of {"name": ...}.
    """
    try:
        response = requests.get(url=f'{params["address"]}/sdapi/v1/samplers')
        response.raise_for_status()
        return [x["name"] for x in response.json()]
    except (requests.exceptions.RequestException, ValueError, KeyError, TypeError):
        return []
|
|
|
|
|
|
def ui():
    """Build this extension's Gradio controls and wire them to ``params``.

    Each control writes its value back into the module-level ``params`` dict when
    it changes, so the generation functions read the latest settings.
    """
    # Gradio elements
    # gr.Markdown('### Stable Diffusion API Pictures') # Currently the name of extension is shown as the title
    with gr.Accordion("Parameters", open=True, elem_classes="SDAP"):
        with gr.Row():
            address = gr.Textbox(placeholder=params['address'], value=params['address'], label='Auto1111\'s WebUI address')
            modes_list = ["Manual", "Immersive/Interactive", "Picturebook/Adventure"]
            mode = gr.Dropdown(modes_list, value=modes_list[params['mode']], label="Mode of operation", type="index")
            with gr.Column(scale=1, min_width=300):
                manage_VRAM = gr.Checkbox(value=params['manage_VRAM'], label='Manage VRAM')
                save_img = gr.Checkbox(value=params['save_img'], label='Keep original images and use them in chat')

            force_pic = gr.Button("Force the picture response")
            suppr_pic = gr.Button("Suppress the picture response")
        with gr.Row():
            checkpoint = gr.Dropdown(params['checkpoint_list'], value=params['sd_checkpoint'], label="Checkpoint", type="value")
            update_checkpoints = gr.Button("Get list of checkpoints")

        with gr.Accordion("Generation parameters", open=False):
            prompt_prefix = gr.Textbox(placeholder=params['prompt_prefix'], value=params['prompt_prefix'], label='Prompt Prefix (best used to describe the look of the character)')
            textgen_prefix = gr.Textbox(placeholder=params['textgen_prefix'], value=params['textgen_prefix'], label='textgen prefix (type [subject] where the subject should be placed)')
            negative_prompt = gr.Textbox(placeholder=params['negative_prompt'], value=params['negative_prompt'], label='Negative Prompt')
            with gr.Row():
                with gr.Column():
                    width = gr.Slider(64, 2048, value=params['width'], step=64, label='Width')
                    height = gr.Slider(64, 2048, value=params['height'], step=64, label='Height')
                with gr.Column(variant="compact", elem_id="sampler_col"):
                    with gr.Row(elem_id="sampler_row"):
                        sampler_name = gr.Dropdown(value=params['sampler_name'], allow_custom_value=True, label='Sampling method', elem_id="sampler_box")
                        create_refresh_button(sampler_name, lambda: None, lambda: {'choices': get_samplers()}, 'refresh-button')
                    steps = gr.Slider(1, 150, value=params['steps'], step=1, label="Sampling steps", elem_id="steps_box")
            with gr.Row():
                seed = gr.Number(label="Seed", value=params['seed'], elem_id="seed_box")
                cfg_scale = gr.Number(label="CFG Scale", value=params['cfg_scale'], elem_id="cfg_box")
                # No alias here: the name hr_options belongs to the Hires row below.
                with gr.Column():
                    restore_faces = gr.Checkbox(value=params['restore_faces'], label='Restore faces')
                    enable_hr = gr.Checkbox(value=params['enable_hr'], label='Hires. fix')
            with gr.Row(visible=params['enable_hr'], elem_classes="hires_opts") as hr_options:
                hr_scale = gr.Slider(1, 4, value=params['hr_scale'], step=0.1, label='Upscale by')
                denoising_strength = gr.Slider(0, 1, value=params['denoising_strength'], step=0.01, label='Denoising strength')
                hr_upscaler = gr.Textbox(placeholder=params['hr_upscaler'], value=params['hr_upscaler'], label='Upscaler')

    # Event functions to update the parameters in the backend
    address.change(lambda x: params.update({"address": filter_address(x)}), address, None)
    mode.select(lambda x: params.update({"mode": x}), mode, None)
    mode.select(lambda x: toggle_generation(x > 1), inputs=mode, outputs=None)
    manage_VRAM.change(lambda x: params.update({"manage_VRAM": x}), manage_VRAM, None)
    manage_VRAM.change(lambda x: give_VRAM_priority('set' if x else 'reset'), inputs=manage_VRAM, outputs=None)
    save_img.change(lambda x: params.update({"save_img": x}), save_img, None)

    address.submit(fn=SD_api_address_update, inputs=address, outputs=address)
    prompt_prefix.change(lambda x: params.update({"prompt_prefix": x}), prompt_prefix, None)
    textgen_prefix.change(lambda x: params.update({"textgen_prefix": x}), textgen_prefix, None)
    negative_prompt.change(lambda x: params.update({"negative_prompt": x}), negative_prompt, None)
    width.change(lambda x: params.update({"width": x}), width, None)
    height.change(lambda x: params.update({"height": x}), height, None)
    hr_scale.change(lambda x: params.update({"hr_scale": x}), hr_scale, None)
    denoising_strength.change(lambda x: params.update({"denoising_strength": x}), denoising_strength, None)
    restore_faces.change(lambda x: params.update({"restore_faces": x}), restore_faces, None)
    hr_upscaler.change(lambda x: params.update({"hr_upscaler": x}), hr_upscaler, None)
    enable_hr.change(lambda x: params.update({"enable_hr": x}), enable_hr, None)
    # Use the checkbox value directly instead of params["enable_hr"], which is only
    # correct if the handler above happens to have run first.
    enable_hr.change(lambda x: gr.update(visible=x), enable_hr, hr_options)
    update_checkpoints.click(get_checkpoints, None, checkpoint)
    checkpoint.change(lambda x: params.update({"sd_checkpoint": x}), checkpoint, None)
    checkpoint.change(load_checkpoint, checkpoint, None)

    sampler_name.change(lambda x: params.update({"sampler_name": x}), sampler_name, None)
    steps.change(lambda x: params.update({"steps": x}), steps, None)
    seed.change(lambda x: params.update({"seed": x}), seed, None)
    cfg_scale.change(lambda x: params.update({"cfg_scale": x}), cfg_scale, None)

    # The button is passed as input only so the lambda receives an argument; its value is ignored.
    force_pic.click(lambda x: toggle_generation(True), inputs=force_pic, outputs=None)
    suppr_pic.click(lambda x: toggle_generation(False), inputs=suppr_pic, outputs=None)
|