diff --git a/download-model.py b/download-model.py index a48a1b8c..fc17e716 100644 --- a/download-model.py +++ b/download-model.py @@ -20,17 +20,6 @@ import tqdm from tqdm.contrib.concurrent import thread_map -parser = argparse.ArgumentParser() -parser.add_argument('MODEL', type=str, default=None, nargs='?') -parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.') -parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.') -parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).') -parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.') -parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.') -parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.') -args = parser.parse_args() - - def select_model_from_default_options(): models = { "OPT 6.7B": ("facebook", "opt-6.7b", "main"), @@ -244,6 +233,17 @@ def check_model_files(model, branch, links, sha256, output_folder): if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('MODEL', type=str, default=None, nargs='?') + parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.') + parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.') + parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).') + parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.') + parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.') + parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.') + args = parser.parse_args() + branch = args.branch model = args.MODEL if model is None: diff --git a/server.py b/server.py index cbfbd241..0108d875 100644 --- a/server.py +++ b/server.py @@ -2,11 +2,14 @@ import os os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False' +import importlib import io import json +import os import re import sys import time +import traceback import zipfile from datetime import datetime from pathlib import Path @@ -172,6 +175,34 @@ def create_prompt_menus(): shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False) +def download_model_wrapper(repo_id): + try: + downloader = importlib.import_module("download-model") + + model = repo_id + branch = "main" + check = False + + yield("Cleaning up the model/branch names") + model, branch = downloader.sanitize_model_and_branch_names(model, branch) + + yield("Getting the download links from Hugging Face") + links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False) + + yield("Getting the output folder") + output_folder = downloader.get_output_folder(model, branch, is_lora) + + if check: + yield("Checking previously downloaded files") + downloader.check_model_files(model, branch, links, sha256, output_folder) + else: + yield("Downloading files") + downloader.download_model_files(model, branch, links, sha256, output_folder, threads=1) + yield("Done!") + except: + yield traceback.format_exc() + + def create_model_menus(): with gr.Row(): with gr.Column(): @@ -182,9 +213,21 @@ def create_model_menus(): with gr.Row(): shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA') ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button') + with gr.Row(): + with gr.Column(): + with gr.Row(): + with gr.Column(): + shared.gradio['custom_model_menu'] = gr.Textbox(label="Download Custom Model", + info="Enter Hugging Face username/model path e.g: facebook/galactica-125m") + with gr.Column(): + shared.gradio['download_button'] = gr.Button("Download", show_progress=True) + shared.gradio['download_status'] = gr.Markdown() + with gr.Column(): + pass shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True) shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True) + shared.gradio['download_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['download_status'], show_progress=False) def create_settings_menus(default_preset):