mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-23 00:18:20 +01:00
Download models from Model tab (#954 from UsamaKenway/main)
This commit is contained in:
commit
7e70741a4e
@ -20,17 +20,6 @@ import tqdm
|
||||
from tqdm.contrib.concurrent import thread_map
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('MODEL', type=str, default=None, nargs='?')
|
||||
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
|
||||
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
|
||||
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
|
||||
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
|
||||
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
|
||||
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def select_model_from_default_options():
|
||||
models = {
|
||||
"OPT 6.7B": ("facebook", "opt-6.7b", "main"),
|
||||
@ -244,6 +233,17 @@ def check_model_files(model, branch, links, sha256, output_folder):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('MODEL', type=str, default=None, nargs='?')
|
||||
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
|
||||
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
|
||||
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
|
||||
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
|
||||
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
|
||||
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
|
||||
args = parser.parse_args()
|
||||
|
||||
branch = args.branch
|
||||
model = args.MODEL
|
||||
if model is None:
|
||||
|
43
server.py
43
server.py
@ -2,11 +2,14 @@ import os
|
||||
|
||||
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
||||
|
||||
import importlib
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
@ -172,6 +175,34 @@ def create_prompt_menus():
|
||||
shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
|
||||
|
||||
|
||||
def download_model_wrapper(repo_id):
|
||||
try:
|
||||
downloader = importlib.import_module("download-model")
|
||||
|
||||
model = repo_id
|
||||
branch = "main"
|
||||
check = False
|
||||
|
||||
yield("Cleaning up the model/branch names")
|
||||
model, branch = downloader.sanitize_model_and_branch_names(model, branch)
|
||||
|
||||
yield("Getting the download links from Hugging Face")
|
||||
links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False)
|
||||
|
||||
yield("Getting the output folder")
|
||||
output_folder = downloader.get_output_folder(model, branch, is_lora)
|
||||
|
||||
if check:
|
||||
yield("Checking previously downloaded files")
|
||||
downloader.check_model_files(model, branch, links, sha256, output_folder)
|
||||
else:
|
||||
yield("Downloading files")
|
||||
downloader.download_model_files(model, branch, links, sha256, output_folder, threads=1)
|
||||
yield("Done!")
|
||||
except:
|
||||
yield traceback.format_exc()
|
||||
|
||||
|
||||
def create_model_menus():
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
@ -182,9 +213,21 @@ def create_model_menus():
|
||||
with gr.Row():
|
||||
shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
|
||||
ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button')
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
shared.gradio['custom_model_menu'] = gr.Textbox(label="Download Custom Model",
|
||||
info="Enter Hugging Face username/model path e.g: facebook/galactica-125m")
|
||||
with gr.Column():
|
||||
shared.gradio['download_button'] = gr.Button("Download", show_progress=True)
|
||||
shared.gradio['download_status'] = gr.Markdown()
|
||||
with gr.Column():
|
||||
pass
|
||||
|
||||
shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
|
||||
shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
|
||||
shared.gradio['download_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['download_status'], show_progress=False)
|
||||
|
||||
|
||||
def create_settings_menus(default_preset):
|
||||
|
Loading…
Reference in New Issue
Block a user