Use download-model.py to download the model

This commit is contained in:
oobabooga 2023-04-10 11:36:39 -03:00
parent c6e9ba20a4
commit 2c14df81a8
2 changed files with 50 additions and 72 deletions

View File

@ -20,17 +20,6 @@ import tqdm
from tqdm.contrib.concurrent import thread_map
parser = argparse.ArgumentParser()
parser.add_argument('MODEL', type=str, default=None, nargs='?')
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
args = parser.parse_args()
def select_model_from_default_options():
models = {
"OPT 6.7B": ("facebook", "opt-6.7b", "main"),
@ -244,6 +233,17 @@ def check_model_files(model, branch, links, sha256, output_folder):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('MODEL', type=str, default=None, nargs='?')
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
args = parser.parse_args()
branch = args.branch
model = args.MODEL
if model is None:

View File

@ -2,17 +2,21 @@ import os
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
import importlib
import io
import json
import os
import re
import sys
import time
import traceback
import zipfile
from datetime import datetime
from pathlib import Path
import os
import requests
import gradio as gr
import requests
from huggingface_hub import HfApi
from PIL import Image
import modules.extensions as extensions_module
@ -21,7 +25,6 @@ from modules.html_generator import chat_html_wrapper
from modules.LoRA import add_lora_to_model
from modules.models import load_model, load_soft_prompt, unload_model
from modules.text_generation import generate_reply, stop_everything_event
from huggingface_hub import HfApi
# Loading custom settings
settings_file = None
@ -175,59 +178,31 @@ def create_prompt_menus():
def download_model_wrapper(repo_id):
print(repo_id)
if repo_id == '':
print("Please enter a valid repo ID. This field cant be empty")
else:
try:
print('Downloading repo')
hf_api = HfApi()
# Get repo info
repo_info = hf_api.repo_info(
repo_id=repo_id,
repo_type="model",
revision="main"
)
# create model and repo folder and check for lora
is_lora = False
for file in repo_info.siblings:
if 'adapter_model.bin' in file.rfilename:
is_lora = True
repo_dir_name = repo_id.replace("/", "--")
if is_lora is True:
models_dir = "loras"
downloader = importlib.import_module("download-model")
model = repo_id
branch = "main"
check = False
yield("Cleaning up the model/branch names")
model, branch = downloader.sanitize_model_and_branch_names(model, branch)
yield("Getting the download links from Hugging Face")
links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False)
yield("Getting the output folder")
output_folder = downloader.get_output_folder(model, branch, is_lora)
if check:
yield("Checking previously downloaded files")
downloader.check_model_files(model, branch, links, sha256, output_folder)
else:
models_dir = "models"
if not os.path.exists(models_dir):
os.makedirs(models_dir)
repo_dir = os.path.join(models_dir, repo_dir_name)
if not os.path.exists(repo_dir):
os.makedirs(repo_dir)
for sibling in repo_info.siblings:
filename = sibling.rfilename
url = f"https://huggingface.co/{repo_id}/resolve/main/{filename}"
download_path = os.path.join(repo_dir, filename)
response = requests.get(url, stream=True)
# Get the total file size from the content-length header
total_size = int(response.headers.get('content-length', 0))
# Download the file in chunks and print progress
with open(download_path, 'wb') as f:
downloaded_size = 0
for data in response.iter_content(chunk_size=10000000):
downloaded_size += len(data)
f.write(data)
progress = downloaded_size * 100 // total_size
downloaded_size_mb = downloaded_size / (1024 * 1024)
total_size_mb = total_size / (1024 * 1024)
print(f"\rDownloading {filename}... {progress}% complete "
f"({downloaded_size_mb:.2f}/{total_size_mb:.2f} MB)", end="", flush=True)
print(f"\rDownloading {filename}... Complete!")
print('Repo Downloaded')
except ValueError as e:
raise ValueError("Please enter a valid repo ID. Error: {}".format(e))
yield("Downloading files")
downloader.download_model_files(model, branch, links, sha256, output_folder, threads=1)
yield("Done!")
except:
yield traceback.format_exc()
def create_model_menus():
@ -241,17 +216,20 @@ def create_model_menus():
shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button')
with gr.Row():
with gr.Column(scale=0.5):
shared.gradio['custom_model_menu'] = gr.Textbox(label="Download Custom Model",
info="Enter hugging face username/model path e.g: 'decapoda-research/llama-7b-hf'")
with gr.Column():
with gr.Row():
with gr.Column(scale=0.5):
with gr.Column():
shared.gradio['custom_model_menu'] = gr.Textbox(label="Download Custom Model",
info="Enter Hugging Face username/model path e.g: facebook/galactica-125m")
with gr.Column():
shared.gradio['download_button'] = gr.Button("Download", show_progress=True)
shared.gradio['download_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'],
show_progress=True)
shared.gradio['download_status'] = gr.Markdown()
with gr.Column():
pass
shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
shared.gradio['download_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['download_status'], show_progress=False)
def create_settings_menus(default_preset):