mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 16:17:57 +01:00
Use download-model.py to download the model
This commit is contained in:
parent
c6e9ba20a4
commit
2c14df81a8
@ -20,17 +20,6 @@ import tqdm
|
|||||||
from tqdm.contrib.concurrent import thread_map
|
from tqdm.contrib.concurrent import thread_map
|
||||||
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument('MODEL', type=str, default=None, nargs='?')
|
|
||||||
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
|
|
||||||
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
|
|
||||||
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
|
|
||||||
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
|
|
||||||
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
|
|
||||||
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
|
|
||||||
def select_model_from_default_options():
|
def select_model_from_default_options():
|
||||||
models = {
|
models = {
|
||||||
"OPT 6.7B": ("facebook", "opt-6.7b", "main"),
|
"OPT 6.7B": ("facebook", "opt-6.7b", "main"),
|
||||||
@ -244,6 +233,17 @@ def check_model_files(model, branch, links, sha256, output_folder):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('MODEL', type=str, default=None, nargs='?')
|
||||||
|
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
|
||||||
|
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
|
||||||
|
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
|
||||||
|
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
|
||||||
|
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
|
||||||
|
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
branch = args.branch
|
branch = args.branch
|
||||||
model = args.MODEL
|
model = args.MODEL
|
||||||
if model is None:
|
if model is None:
|
||||||
|
98
server.py
98
server.py
@ -2,17 +2,21 @@ import os
|
|||||||
|
|
||||||
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
||||||
|
|
||||||
|
import importlib
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
import traceback
|
||||||
import zipfile
|
import zipfile
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import os
|
|
||||||
import requests
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
import requests
|
||||||
|
from huggingface_hub import HfApi
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
import modules.extensions as extensions_module
|
import modules.extensions as extensions_module
|
||||||
@ -21,7 +25,6 @@ from modules.html_generator import chat_html_wrapper
|
|||||||
from modules.LoRA import add_lora_to_model
|
from modules.LoRA import add_lora_to_model
|
||||||
from modules.models import load_model, load_soft_prompt, unload_model
|
from modules.models import load_model, load_soft_prompt, unload_model
|
||||||
from modules.text_generation import generate_reply, stop_everything_event
|
from modules.text_generation import generate_reply, stop_everything_event
|
||||||
from huggingface_hub import HfApi
|
|
||||||
|
|
||||||
# Loading custom settings
|
# Loading custom settings
|
||||||
settings_file = None
|
settings_file = None
|
||||||
@ -175,59 +178,31 @@ def create_prompt_menus():
|
|||||||
|
|
||||||
|
|
||||||
def download_model_wrapper(repo_id):
|
def download_model_wrapper(repo_id):
|
||||||
print(repo_id)
|
|
||||||
if repo_id == '':
|
|
||||||
print("Please enter a valid repo ID. This field cant be empty")
|
|
||||||
else:
|
|
||||||
try:
|
try:
|
||||||
print('Downloading repo')
|
downloader = importlib.import_module("download-model")
|
||||||
hf_api = HfApi()
|
|
||||||
# Get repo info
|
model = repo_id
|
||||||
repo_info = hf_api.repo_info(
|
branch = "main"
|
||||||
repo_id=repo_id,
|
check = False
|
||||||
repo_type="model",
|
|
||||||
revision="main"
|
yield("Cleaning up the model/branch names")
|
||||||
)
|
model, branch = downloader.sanitize_model_and_branch_names(model, branch)
|
||||||
# create model and repo folder and check for lora
|
|
||||||
is_lora = False
|
yield("Getting the download links from Hugging Face")
|
||||||
for file in repo_info.siblings:
|
links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False)
|
||||||
if 'adapter_model.bin' in file.rfilename:
|
|
||||||
is_lora = True
|
yield("Getting the output folder")
|
||||||
repo_dir_name = repo_id.replace("/", "--")
|
output_folder = downloader.get_output_folder(model, branch, is_lora)
|
||||||
if is_lora is True:
|
|
||||||
models_dir = "loras"
|
if check:
|
||||||
|
yield("Checking previously downloaded files")
|
||||||
|
downloader.check_model_files(model, branch, links, sha256, output_folder)
|
||||||
else:
|
else:
|
||||||
models_dir = "models"
|
yield("Downloading files")
|
||||||
if not os.path.exists(models_dir):
|
downloader.download_model_files(model, branch, links, sha256, output_folder, threads=1)
|
||||||
os.makedirs(models_dir)
|
yield("Done!")
|
||||||
repo_dir = os.path.join(models_dir, repo_dir_name)
|
except:
|
||||||
if not os.path.exists(repo_dir):
|
yield traceback.format_exc()
|
||||||
os.makedirs(repo_dir)
|
|
||||||
|
|
||||||
for sibling in repo_info.siblings:
|
|
||||||
filename = sibling.rfilename
|
|
||||||
url = f"https://huggingface.co/{repo_id}/resolve/main/{filename}"
|
|
||||||
download_path = os.path.join(repo_dir, filename)
|
|
||||||
response = requests.get(url, stream=True)
|
|
||||||
# Get the total file size from the content-length header
|
|
||||||
total_size = int(response.headers.get('content-length', 0))
|
|
||||||
|
|
||||||
# Download the file in chunks and print progress
|
|
||||||
with open(download_path, 'wb') as f:
|
|
||||||
downloaded_size = 0
|
|
||||||
for data in response.iter_content(chunk_size=10000000):
|
|
||||||
downloaded_size += len(data)
|
|
||||||
f.write(data)
|
|
||||||
progress = downloaded_size * 100 // total_size
|
|
||||||
downloaded_size_mb = downloaded_size / (1024 * 1024)
|
|
||||||
total_size_mb = total_size / (1024 * 1024)
|
|
||||||
print(f"\rDownloading {filename}... {progress}% complete "
|
|
||||||
f"({downloaded_size_mb:.2f}/{total_size_mb:.2f} MB)", end="", flush=True)
|
|
||||||
print(f"\rDownloading {filename}... Complete!")
|
|
||||||
|
|
||||||
print('Repo Downloaded')
|
|
||||||
except ValueError as e:
|
|
||||||
raise ValueError("Please enter a valid repo ID. Error: {}".format(e))
|
|
||||||
|
|
||||||
|
|
||||||
def create_model_menus():
|
def create_model_menus():
|
||||||
@ -241,17 +216,20 @@ def create_model_menus():
|
|||||||
shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
|
shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
|
||||||
ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button')
|
ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button')
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=0.5):
|
with gr.Column():
|
||||||
shared.gradio['custom_model_menu'] = gr.Textbox(label="Download Custom Model",
|
|
||||||
info="Enter hugging face username/model path e.g: 'decapoda-research/llama-7b-hf'")
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=0.5):
|
with gr.Column():
|
||||||
|
shared.gradio['custom_model_menu'] = gr.Textbox(label="Download Custom Model",
|
||||||
|
info="Enter Hugging Face username/model path e.g: facebook/galactica-125m")
|
||||||
|
with gr.Column():
|
||||||
shared.gradio['download_button'] = gr.Button("Download", show_progress=True)
|
shared.gradio['download_button'] = gr.Button("Download", show_progress=True)
|
||||||
shared.gradio['download_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'],
|
shared.gradio['download_status'] = gr.Markdown()
|
||||||
show_progress=True)
|
with gr.Column():
|
||||||
|
pass
|
||||||
|
|
||||||
shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
|
shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
|
||||||
shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
|
shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
|
||||||
|
shared.gradio['download_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['download_status'], show_progress=False)
|
||||||
|
|
||||||
|
|
||||||
def create_settings_menus(default_preset):
|
def create_settings_menus(default_preset):
|
||||||
|
Loading…
Reference in New Issue
Block a user