From d7a7f7896b4903295e4331b61eb7b232bee26f0e Mon Sep 17 00:00:00 2001 From: GuizzyQC <86683381+GuizzyQC@users.noreply.github.com> Date: Tue, 27 Jun 2023 16:29:27 -0400 Subject: [PATCH] Add SD checkpoint selection in sd_api_pictures (#2872) --- extensions/sd_api_pictures/script.py | 33 +++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/extensions/sd_api_pictures/script.py b/extensions/sd_api_pictures/script.py index 949531c9..a3684816 100644 --- a/extensions/sd_api_pictures/script.py +++ b/extensions/sd_api_pictures/script.py @@ -34,7 +34,10 @@ params = { 'seed': -1, 'sampler_name': 'DDIM', 'steps': 32, - 'cfg_scale': 7 + 'cfg_scale': 7, + 'sd_checkpoint' : ' ', + 'checkpoint_list' : [" "] + } @@ -265,6 +268,28 @@ def custom_css(): path_to_css = Path(__file__).parent.resolve() / 'style.css' return open(path_to_css, 'r').read() +def get_checkpoints(): + global params + + try: + models = requests.get(url=f'{params["address"]}/sdapi/v1/sd-models') + options = requests.get(url=f'{params["address"]}/sdapi/v1/options') + options_json = options.json() + params['sd_checkpoint'] = options_json['sd_model_checkpoint'] + params['checkpoint_list'] = [result["title"] for result in models.json()] + except: + params['sd_checkpoint'] = "" + params['checkpoint_list'] = [] + + return gr.update(choices=params['checkpoint_list'], value=params['sd_checkpoint']) + +def load_checkpoint(checkpoint): + + payload = { + "sd_model_checkpoint": checkpoint + } + + requests.post(url=f'{params["address"]}/sdapi/v1/options', json=payload) def ui(): @@ -281,6 +306,9 @@ def ui(): force_pic = gr.Button("Force the picture response") suppr_pic = gr.Button("Suppress the picture response") + with gr.Row(): + checkpoint = gr.Dropdown(params['checkpoint_list'], value=params['sd_checkpoint'], label="Checkpoint", type="value") + update_checkpoints = gr.Button("Get list of checkpoints") with gr.Accordion("Generation parameters", open=False): prompt_prefix = gr.Textbox(placeholder=params['prompt_prefix'], value=params['prompt_prefix'], label='Prompt Prefix (best used to describe the look of the character)') @@ -322,6 +350,9 @@ def ui(): hr_upscaler.change(lambda x: params.update({"hr_upscaler": x}), hr_upscaler, None) enable_hr.change(lambda x: params.update({"enable_hr": x}), enable_hr, None) enable_hr.change(lambda x: hr_options.update(visible=params["enable_hr"]), enable_hr, hr_options) + update_checkpoints.click(get_checkpoints, None, checkpoint) + checkpoint.change(lambda x: params.update({"sd_checkpoint": x}), checkpoint, None) + checkpoint.change(load_checkpoint, checkpoint, None) sampler_name.change(lambda x: params.update({"sampler_name": x}), sampler_name, None) steps.change(lambda x: params.update({"steps": x}), steps, None)