Add SD checkpoint selection in sd_api_pictures (#2872)

This commit is contained in:
GuizzyQC 2023-06-27 16:29:27 -04:00 committed by GitHub
parent 7611978f7b
commit d7a7f7896b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -34,7 +34,10 @@ params = {
'seed': -1, 'seed': -1,
'sampler_name': 'DDIM', 'sampler_name': 'DDIM',
'steps': 32, 'steps': 32,
'cfg_scale': 7 'cfg_scale': 7,
'sd_checkpoint' : ' ',
'checkpoint_list' : [" "]
} }
@ -265,6 +268,28 @@ def custom_css():
path_to_css = Path(__file__).parent.resolve() / 'style.css' path_to_css = Path(__file__).parent.resolve() / 'style.css'
return open(path_to_css, 'r').read() return open(path_to_css, 'r').read()
def get_checkpoints():
global params
try:
models = requests.get(url=f'{params["address"]}/sdapi/v1/sd-models')
options = requests.get(url=f'{params["address"]}/sdapi/v1/options')
options_json = options.json()
params['sd_checkpoint'] = options_json['sd_model_checkpoint']
params['checkpoint_list'] = [result["title"] for result in models.json()]
except:
params['sd_checkpoint'] = ""
params['checkpoint_list'] = []
return gr.update(choices=params['checkpoint_list'], value=params['sd_checkpoint'])
def load_checkpoint(checkpoint):
payload = {
"sd_model_checkpoint": checkpoint
}
requests.post(url=f'{params["address"]}/sdapi/v1/options', json=payload)
def ui(): def ui():
@ -281,6 +306,9 @@ def ui():
force_pic = gr.Button("Force the picture response") force_pic = gr.Button("Force the picture response")
suppr_pic = gr.Button("Suppress the picture response") suppr_pic = gr.Button("Suppress the picture response")
with gr.Row():
checkpoint = gr.Dropdown(params['checkpoint_list'], value=params['sd_checkpoint'], label="Checkpoint", type="value")
update_checkpoints = gr.Button("Get list of checkpoints")
with gr.Accordion("Generation parameters", open=False): with gr.Accordion("Generation parameters", open=False):
prompt_prefix = gr.Textbox(placeholder=params['prompt_prefix'], value=params['prompt_prefix'], label='Prompt Prefix (best used to describe the look of the character)') prompt_prefix = gr.Textbox(placeholder=params['prompt_prefix'], value=params['prompt_prefix'], label='Prompt Prefix (best used to describe the look of the character)')
@ -322,6 +350,9 @@ def ui():
hr_upscaler.change(lambda x: params.update({"hr_upscaler": x}), hr_upscaler, None) hr_upscaler.change(lambda x: params.update({"hr_upscaler": x}), hr_upscaler, None)
enable_hr.change(lambda x: params.update({"enable_hr": x}), enable_hr, None) enable_hr.change(lambda x: params.update({"enable_hr": x}), enable_hr, None)
enable_hr.change(lambda x: hr_options.update(visible=params["enable_hr"]), enable_hr, hr_options) enable_hr.change(lambda x: hr_options.update(visible=params["enable_hr"]), enable_hr, hr_options)
update_checkpoints.click(get_checkpoints, None, checkpoint)
checkpoint.change(lambda x: params.update({"sd_checkpoint": x}), checkpoint, None)
checkpoint.change(load_checkpoint, checkpoint, None)
sampler_name.change(lambda x: params.update({"sampler_name": x}), sampler_name, None) sampler_name.change(lambda x: params.update({"sampler_name": x}), sampler_name, None)
steps.change(lambda x: params.update({"steps": x}), steps, None) steps.change(lambda x: params.update({"steps": x}), steps, None)