From 7436dd5b4aa48d990cefc02bca02d7ebcfafd26b Mon Sep 17 00:00:00 2001
From: Usama Kenway
Date: Sun, 9 Apr 2023 16:11:43 +0500
Subject: [PATCH 1/4] download custom model menu (from hugging face) added in model tab

---
 server.py | 69 ++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 68 insertions(+), 1 deletion(-)

diff --git a/server.py b/server.py
index 740020ea..36cf57b4 100644
--- a/server.py
+++ b/server.py
@@ -10,7 +10,8 @@ import time
 import zipfile
 from datetime import datetime
 from pathlib import Path
-
+import os
+import requests
 import gradio as gr
 from PIL import Image
 
@@ -20,6 +21,7 @@ from modules.html_generator import chat_html_wrapper
 from modules.LoRA import add_lora_to_model
 from modules.models import load_model, load_soft_prompt, unload_model
 from modules.text_generation import generate_reply, stop_everything_event
+from huggingface_hub import HfApi
 
 # Loading custom settings
 settings_file = None
@@ -172,6 +174,62 @@ def create_prompt_menus():
     shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
 
 
+def download_model_wrapper(repo_id):
+    print(repo_id)
+    if repo_id == '':
+        print("Please enter a valid repo ID. This field cant be empty")
+    else:
+        try:
+            print('Downloading repo')
+            hf_api = HfApi()
+            # Get repo info
+            repo_info = hf_api.repo_info(
+                repo_id=repo_id,
+                repo_type="model",
+                revision="main"
+            )
+            # create model and repo folder and check for lora
+            is_lora = False
+            for file in repo_info.siblings:
+                if 'adapter_model.bin' in file.rfilename:
+                    is_lora = True
+            repo_dir_name = repo_id.replace("/", "--")
+            if is_lora is True:
+                models_dir = ".loras"
+            else:
+                models_dir = ".models"
+            if not os.path.exists(models_dir):
+                os.makedirs(models_dir)
+            repo_dir = os.path.join(models_dir, repo_dir_name)
+            if not os.path.exists(repo_dir):
+                os.makedirs(repo_dir)
+
+            for sibling in repo_info.siblings:
+                filename = sibling.rfilename
+                url = f"https://huggingface.co/{repo_id}/resolve/main/{filename}"
+                download_path = os.path.join(repo_dir, filename)
+                response = requests.get(url, stream=True)
+                # Get the total file size from the content-length header
+                total_size = int(response.headers.get('content-length', 0))
+
+                # Download the file in chunks and print progress
+                with open(download_path, 'wb') as f:
+                    downloaded_size = 0
+                    for data in response.iter_content(chunk_size=10000000):
+                        downloaded_size += len(data)
+                        f.write(data)
+                        progress = downloaded_size * 100 // total_size
+                        downloaded_size_mb = downloaded_size / (1024 * 1024)
+                        total_size_mb = total_size / (1024 * 1024)
+                        print(f"\rDownloading {filename}... {progress}% complete "
+                              f"({downloaded_size_mb:.2f}/{total_size_mb:.2f} MB)", end="", flush=True)
+                print(f"\rDownloading {filename}... Complete!")
+
+            print('Repo Downloaded')
+        except ValueError as e:
+            raise ValueError("Please enter a valid repo ID. Error: {}".format(e))
+
+
 def create_model_menus():
     with gr.Row():
         with gr.Column():
@@ -182,6 +240,15 @@ with gr.Row():
                 shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
                 ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button')
+    with gr.Row():
+        with gr.Column(scale=0.5):
+            shared.gradio['custom_model_menu'] = gr.Textbox(label="Download Custom Model",
+                                                            info="Enter hugging face username/model path e.g: 'decapoda-research/llama-7b-hf'")
+    with gr.Row():
+        with gr.Column(scale=0.5):
+            shared.gradio['download_button'] = gr.Button("Download", show_progress=True)
+            shared.gradio['download_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'],
+                                                   show_progress=True)
 
     shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
     shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
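
Note: the patch above streams each file to disk with requests rather than reading it into memory: stream=True defers the body, and iter_content() hands it over in chunks while a content-length-based counter reports progress. A minimal standalone sketch of the same pattern follows (the URL is illustrative; unlike the code above, it guards against a missing content-length header, which would otherwise raise ZeroDivisionError in the progress arithmetic):

    import requests

    def stream_download(url, path, chunk_size=1024 * 1024):
        response = requests.get(url, stream=True)    # defer the body; fetch lazily
        total = int(response.headers.get('content-length', 0))
        done = 0
        with open(path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=chunk_size):
                f.write(chunk)
                done += len(chunk)
                if total > 0:                        # servers may omit the header
                    print(f"\r{100 * done // total}% complete", end="", flush=True)
        print(f"\rDownloaded {done / (1024 * 1024):.2f} MB")

    stream_download("https://huggingface.co/facebook/galactica-125m/resolve/main/config.json", "config.json")
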
Error: {}".format(e)) + + def create_model_menus(): with gr.Row(): with gr.Column(): @@ -182,6 +240,15 @@ def create_model_menus(): with gr.Row(): shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA') ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button') + with gr.Row(): + with gr.Column(scale=0.5): + shared.gradio['custom_model_menu'] = gr.Textbox(label="Download Custom Model", + info="Enter hugging face username/model path e.g: 'decapoda-research/llama-7b-hf'") + with gr.Row(): + with gr.Column(scale=0.5): + shared.gradio['download_button'] = gr.Button("Download", show_progress=True) + shared.gradio['download_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], + show_progress=True) shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True) shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True) From ebdf4c8c12ba8e284efca5f42e94389d09828f35 Mon Sep 17 00:00:00 2001 From: Usama Kenway Date: Sun, 9 Apr 2023 16:53:21 +0500 Subject: [PATCH 2/4] path fixed --- server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server.py b/server.py index 36cf57b4..676c715a 100644 --- a/server.py +++ b/server.py @@ -195,9 +195,9 @@ def download_model_wrapper(repo_id): is_lora = True repo_dir_name = repo_id.replace("/", "--") if is_lora is True: - models_dir = ".loras" + models_dir = "loras" else: - models_dir = ".models" + models_dir = "models" if not os.path.exists(models_dir): os.makedirs(models_dir) repo_dir = os.path.join(models_dir, repo_dir_name) From 2c14df81a82bfbdaee5662812c80cc95a00cffdb Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 10 Apr 2023 11:36:39 -0300 Subject: [PATCH 3/4] Use download-model.py to download the model --- download-model.py | 22 +++++----- server.py | 100 ++++++++++++++++++---------------------------- 2 files changed, 50 insertions(+), 72 deletions(-) diff --git a/download-model.py b/download-model.py index a48a1b8c..fc17e716 100644 --- a/download-model.py +++ b/download-model.py @@ -20,17 +20,6 @@ import tqdm from tqdm.contrib.concurrent import thread_map -parser = argparse.ArgumentParser() -parser.add_argument('MODEL', type=str, default=None, nargs='?') -parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.') -parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.') -parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).') -parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.') -parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.') -parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.') -args = parser.parse_args() - - def select_model_from_default_options(): models = { "OPT 6.7B": ("facebook", "opt-6.7b", "main"), @@ -244,6 +233,17 @@ def check_model_files(model, branch, links, sha256, output_folder): if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('MODEL', type=str, default=None, nargs='?') + parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to 
From 2c14df81a82bfbdaee5662812c80cc95a00cffdb Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 10 Apr 2023 11:36:39 -0300
Subject: [PATCH 3/4] Use download-model.py to download the model

---
 download-model.py |  22 +++++-----
 server.py         | 100 ++++++++++++++++++----------------------------
 2 files changed, 50 insertions(+), 72 deletions(-)

diff --git a/download-model.py b/download-model.py
index a48a1b8c..fc17e716 100644
--- a/download-model.py
+++ b/download-model.py
@@ -20,17 +20,6 @@ import tqdm
 from tqdm.contrib.concurrent import thread_map
 
 
-parser = argparse.ArgumentParser()
-parser.add_argument('MODEL', type=str, default=None, nargs='?')
-parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
-parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
-parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
-parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
-parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
-parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
-args = parser.parse_args()
-
-
 def select_model_from_default_options():
     models = {
         "OPT 6.7B": ("facebook", "opt-6.7b", "main"),
@@ -244,6 +233,17 @@ def check_model_files(model, branch, links, sha256, output_folder):
 
 
 if __name__ == '__main__':
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('MODEL', type=str, default=None, nargs='?')
+    parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
+    parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
+    parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
+    parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
+    parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
+    parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
+    args = parser.parse_args()
+
     branch = args.branch
     model = args.MODEL
     if model is None:
diff --git a/server.py b/server.py
index ae5a905f..5c0142c8 100644
--- a/server.py
+++ b/server.py
@@ -2,17 +2,21 @@ import os
 
 os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
 
+import importlib
 import io
 import json
+import os
 import re
 import sys
 import time
+import traceback
 import zipfile
 from datetime import datetime
 from pathlib import Path
-import os
-import requests
+
 import gradio as gr
+import requests
+from huggingface_hub import HfApi
 from PIL import Image
 
 import modules.extensions as extensions_module
@@ -21,7 +25,6 @@ from modules.html_generator import chat_html_wrapper
 from modules.LoRA import add_lora_to_model
 from modules.models import load_model, load_soft_prompt, unload_model
 from modules.text_generation import generate_reply, stop_everything_event
-from huggingface_hub import HfApi
 
 # Loading custom settings
 settings_file = None
@@ -175,59 +178,31 @@ def create_prompt_menus():
 
 
 def download_model_wrapper(repo_id):
-    print(repo_id)
-    if repo_id == '':
-        print("Please enter a valid repo ID. This field cant be empty")
-    else:
-        try:
-            print('Downloading repo')
-            hf_api = HfApi()
-            # Get repo info
-            repo_info = hf_api.repo_info(
-                repo_id=repo_id,
-                repo_type="model",
-                revision="main"
-            )
-            # create model and repo folder and check for lora
-            is_lora = False
-            for file in repo_info.siblings:
-                if 'adapter_model.bin' in file.rfilename:
-                    is_lora = True
-            repo_dir_name = repo_id.replace("/", "--")
-            if is_lora is True:
-                models_dir = "loras"
-            else:
-                models_dir = "models"
-            if not os.path.exists(models_dir):
-                os.makedirs(models_dir)
-            repo_dir = os.path.join(models_dir, repo_dir_name)
-            if not os.path.exists(repo_dir):
-                os.makedirs(repo_dir)
+    try:
+        downloader = importlib.import_module("download-model")
 
-            for sibling in repo_info.siblings:
-                filename = sibling.rfilename
-                url = f"https://huggingface.co/{repo_id}/resolve/main/{filename}"
-                download_path = os.path.join(repo_dir, filename)
-                response = requests.get(url, stream=True)
-                # Get the total file size from the content-length header
-                total_size = int(response.headers.get('content-length', 0))
+        model = repo_id
+        branch = "main"
+        check = False
 
-                # Download the file in chunks and print progress
-                with open(download_path, 'wb') as f:
-                    downloaded_size = 0
-                    for data in response.iter_content(chunk_size=10000000):
-                        downloaded_size += len(data)
-                        f.write(data)
-                        progress = downloaded_size * 100 // total_size
-                        downloaded_size_mb = downloaded_size / (1024 * 1024)
-                        total_size_mb = total_size / (1024 * 1024)
-                        print(f"\rDownloading {filename}... {progress}% complete "
-                              f"({downloaded_size_mb:.2f}/{total_size_mb:.2f} MB)", end="", flush=True)
+        yield("Cleaning up the model/branch names")
+        model, branch = downloader.sanitize_model_and_branch_names(model, branch)
 
-                print(f"\rDownloading {filename}... Complete!")
+        yield("Getting the download links from Hugging Face")
+        links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False)
 
-            print('Repo Downloaded')
-        except ValueError as e:
-            raise ValueError("Please enter a valid repo ID. Error: {}".format(e))
+        yield("Getting the output folder")
+        output_folder = downloader.get_output_folder(model, branch, is_lora)
+
+        if check:
+            yield("Checking previously downloaded files")
+            downloader.check_model_files(model, branch, links, sha256, output_folder)
+        else:
+            yield("Downloading files")
+            downloader.download_model_files(model, branch, links, sha256, output_folder, threads=1)
+        yield("Done!")
+    except:
+        yield traceback.format_exc()
 
 
 def create_model_menus():
@@ -241,17 +216,20 @@ with gr.Row():
                 shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
                 ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button')
     with gr.Row():
-        with gr.Column(scale=0.5):
-            shared.gradio['custom_model_menu'] = gr.Textbox(label="Download Custom Model",
-                                                            info="Enter hugging face username/model path e.g: 'decapoda-research/llama-7b-hf'")
-    with gr.Row():
-        with gr.Column(scale=0.5):
-            shared.gradio['download_button'] = gr.Button("Download", show_progress=True)
-            shared.gradio['download_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'],
-                                                   show_progress=True)
+        with gr.Column():
+            with gr.Row():
+                with gr.Column():
+                    shared.gradio['custom_model_menu'] = gr.Textbox(label="Download Custom Model",
+                                                                    info="Enter Hugging Face username/model path e.g: facebook/galactica-125m")
+                with gr.Column():
+                    shared.gradio['download_button'] = gr.Button("Download", show_progress=True)
+            shared.gradio['download_status'] = gr.Markdown()
+        with gr.Column():
+            pass
 
     shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
     shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
+    shared.gradio['download_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['download_status'], show_progress=False)
 
 
 def create_settings_menus(default_preset):
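
Note: two techniques in the patch above are easy to miss. First, because the filename download-model.py contains a hyphen, a literal "import download-model" is a syntax error; importlib.import_module("download-model") loads it by its filename-derived name instead, which is also why the argparse block had to move under if __name__ == '__main__': so that importing the module does not try to parse the server's command line. Second, download_model_wrapper is now a generator: each yield is streamed by Gradio into the download_status Markdown component, and the bare except routes the traceback into the same box instead of crashing the handler. A minimal sketch of that streaming pattern, with a stand-in task in place of the real downloader (assumes Gradio 3.x, where generator handlers need the queue):

    import traceback

    import gradio as gr

    def slow_task(steps):
        # Each yield replaces the Markdown output, like the status lines above
        try:
            for i in range(int(steps)):
                yield f"Step {i + 1} of {int(steps)}"
            yield "Done!"
        except Exception:
            yield traceback.format_exc()   # surface errors in the UI

    with gr.Blocks() as demo:
        steps = gr.Slider(1, 10, value=3, step=1, label="Steps")
        status = gr.Markdown()
        gr.Button("Run").click(slow_task, steps, status, show_progress=False)

    demo.queue().launch()  # generator handlers require the queue in Gradio 3.x
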
Complete!") + yield("Cleaning up the model/branch names") + model, branch = downloader.sanitize_model_and_branch_names(model, branch) - print('Repo Downloaded') - except ValueError as e: - raise ValueError("Please enter a valid repo ID. Error: {}".format(e)) + yield("Getting the download links from Hugging Face") + links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False) + + yield("Getting the output folder") + output_folder = downloader.get_output_folder(model, branch, is_lora) + + if check: + yield("Checking previously downloaded files") + downloader.check_model_files(model, branch, links, sha256, output_folder) + else: + yield("Downloading files") + downloader.download_model_files(model, branch, links, sha256, output_folder, threads=1) + yield("Done!") + except: + yield traceback.format_exc() def create_model_menus(): @@ -241,17 +216,20 @@ def create_model_menus(): shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA') ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button') with gr.Row(): - with gr.Column(scale=0.5): - shared.gradio['custom_model_menu'] = gr.Textbox(label="Download Custom Model", - info="Enter hugging face username/model path e.g: 'decapoda-research/llama-7b-hf'") - with gr.Row(): - with gr.Column(scale=0.5): - shared.gradio['download_button'] = gr.Button("Download", show_progress=True) - shared.gradio['download_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], - show_progress=True) + with gr.Column(): + with gr.Row(): + with gr.Column(): + shared.gradio['custom_model_menu'] = gr.Textbox(label="Download Custom Model", + info="Enter Hugging Face username/model path e.g: facebook/galactica-125m") + with gr.Column(): + shared.gradio['download_button'] = gr.Button("Download", show_progress=True) + shared.gradio['download_status'] = gr.Markdown() + with gr.Column(): + pass shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True) shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True) + shared.gradio['download_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['download_status'], show_progress=False) def create_settings_menus(default_preset): From 11b23db8d45519fc3959555b79a01d21e7f99a84 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 10 Apr 2023 11:37:42 -0300 Subject: [PATCH 4/4] Remove unused imports --- server.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/server.py b/server.py index 5c0142c8..0108d875 100644 --- a/server.py +++ b/server.py @@ -15,8 +15,6 @@ from datetime import datetime from pathlib import Path import gradio as gr -import requests -from huggingface_hub import HfApi from PIL import Image import modules.extensions as extensions_module