mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
download custom model menu (from hugging face) added in model tab
This commit is contained in:
parent
bce1b7fbb2
commit
7436dd5b4a
69
server.py
69
server.py
@ -10,7 +10,8 @@ import time
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import os
|
||||
import requests
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
|
||||
@ -20,6 +21,7 @@ from modules.html_generator import chat_html_wrapper
|
||||
from modules.LoRA import add_lora_to_model
|
||||
from modules.models import load_model, load_soft_prompt, unload_model
|
||||
from modules.text_generation import generate_reply, stop_everything_event
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
# Loading custom settings
|
||||
settings_file = None
|
||||
@ -172,6 +174,62 @@ def create_prompt_menus():
|
||||
shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
|
||||
|
||||
|
||||
def download_model_wrapper(repo_id):
|
||||
print(repo_id)
|
||||
if repo_id == '':
|
||||
print("Please enter a valid repo ID. This field cant be empty")
|
||||
else:
|
||||
try:
|
||||
print('Downloading repo')
|
||||
hf_api = HfApi()
|
||||
# Get repo info
|
||||
repo_info = hf_api.repo_info(
|
||||
repo_id=repo_id,
|
||||
repo_type="model",
|
||||
revision="main"
|
||||
)
|
||||
# create model and repo folder and check for lora
|
||||
is_lora = False
|
||||
for file in repo_info.siblings:
|
||||
if 'adapter_model.bin' in file.rfilename:
|
||||
is_lora = True
|
||||
repo_dir_name = repo_id.replace("/", "--")
|
||||
if is_lora is True:
|
||||
models_dir = ".loras"
|
||||
else:
|
||||
models_dir = ".models"
|
||||
if not os.path.exists(models_dir):
|
||||
os.makedirs(models_dir)
|
||||
repo_dir = os.path.join(models_dir, repo_dir_name)
|
||||
if not os.path.exists(repo_dir):
|
||||
os.makedirs(repo_dir)
|
||||
|
||||
for sibling in repo_info.siblings:
|
||||
filename = sibling.rfilename
|
||||
url = f"https://huggingface.co/{repo_id}/resolve/main/{filename}"
|
||||
download_path = os.path.join(repo_dir, filename)
|
||||
response = requests.get(url, stream=True)
|
||||
# Get the total file size from the content-length header
|
||||
total_size = int(response.headers.get('content-length', 0))
|
||||
|
||||
# Download the file in chunks and print progress
|
||||
with open(download_path, 'wb') as f:
|
||||
downloaded_size = 0
|
||||
for data in response.iter_content(chunk_size=10000000):
|
||||
downloaded_size += len(data)
|
||||
f.write(data)
|
||||
progress = downloaded_size * 100 // total_size
|
||||
downloaded_size_mb = downloaded_size / (1024 * 1024)
|
||||
total_size_mb = total_size / (1024 * 1024)
|
||||
print(f"\rDownloading {filename}... {progress}% complete "
|
||||
f"({downloaded_size_mb:.2f}/{total_size_mb:.2f} MB)", end="", flush=True)
|
||||
print(f"\rDownloading {filename}... Complete!")
|
||||
|
||||
print('Repo Downloaded')
|
||||
except ValueError as e:
|
||||
raise ValueError("Please enter a valid repo ID. Error: {}".format(e))
|
||||
|
||||
|
||||
def create_model_menus():
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
@ -182,6 +240,15 @@ def create_model_menus():
|
||||
with gr.Row():
|
||||
shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
|
||||
ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button')
|
||||
with gr.Row():
|
||||
with gr.Column(scale=0.5):
|
||||
shared.gradio['custom_model_menu'] = gr.Textbox(label="Download Custom Model",
|
||||
info="Enter hugging face username/model path e.g: 'decapoda-research/llama-7b-hf'")
|
||||
with gr.Row():
|
||||
with gr.Column(scale=0.5):
|
||||
shared.gradio['download_button'] = gr.Button("Download", show_progress=True)
|
||||
shared.gradio['download_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'],
|
||||
show_progress=True)
|
||||
|
||||
shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
|
||||
shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
|
||||
|
Loading…
Reference in New Issue
Block a user