Make paths cross-platform (should work on Windows now)

This commit is contained in:
oobabooga 2023-01-07 16:33:43 -03:00
parent 89fb0a13c5
commit 5345685ead
3 changed files with 28 additions and 31 deletions

View File

@ -10,11 +10,10 @@ Output will be written to torch-dumps/name-of-the-model.pt
from transformers import AutoModelForCausalLM, T5ForConditionalGeneration from transformers import AutoModelForCausalLM, T5ForConditionalGeneration
import torch import torch
from sys import argv from sys import argv
from pathlib import Path
path = argv[1] path = Path(argv[1])
if path[-1] != '/': model_name = path.name
path = path+'/'
model_name = path.split('/')[-2]
print(f"Loading {model_name}...") print(f"Loading {model_name}...")
if model_name in ['flan-t5', 't5-large']: if model_name in ['flan-t5', 't5-large']:
@ -24,4 +23,4 @@ else:
print("Model loaded.") print("Model loaded.")
print(f"Saving to torch-dumps/{model_name}.pt") print(f"Saving to torch-dumps/{model_name}.pt")
torch.save(model, f"torch-dumps/{model_name}.pt") torch.save(model, Path(f"torch-dumps/{model_name}.pt"))

View File

@ -9,16 +9,16 @@ python download-model.py facebook/opt-1.3b
import requests import requests
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
import multiprocessing import multiprocessing
import os
import tqdm import tqdm
from sys import argv from sys import argv
from pathlib import Path
def get_file(args): def get_file(args):
url = args[0] url = args[0]
output_folder = args[1] output_folder = args[1]
r = requests.get(url, stream=True) r = requests.get(url, stream=True)
with open(f"{output_folder}/{url.split('/')[-1]}", 'wb') as f: with open(output_folder / Path(url.split('/')[-1]), 'wb') as f:
total_size = int(r.headers.get('content-length', 0)) total_size = int(r.headers.get('content-length', 0))
block_size = 1024 block_size = 1024
t = tqdm.tqdm(total=total_size, unit='iB', unit_scale=True) t = tqdm.tqdm(total=total_size, unit='iB', unit_scale=True)
@ -27,13 +27,11 @@ def get_file(args):
f.write(data) f.write(data)
t.close() t.close()
model = argv[1] model = Path(argv[1])
if model.endswith('/'):
model = model[:-1]
url = f'https://huggingface.co/{model}/tree/main' url = f'https://huggingface.co/{model}/tree/main'
output_folder = f"models/{model.split('/')[-1]}" output_folder = Path("models") / model.name
if not os.path.exists(output_folder): if not output_folder.exists():
os.mkdir(output_folder) output_folder.mkdir()
# Finding the relevant files to download # Finding the relevant files to download
page = requests.get(url) page = requests.get(url)

View File

@ -1,15 +1,15 @@
import os
import re import re
import time import time
import glob import glob
from sys import exit from sys import exit
import torch import torch
import argparse import argparse
from pathlib import Path
import gradio as gr import gradio as gr
import transformers import transformers
from html_generator import * from html_generator import *
from transformers import AutoTokenizer from transformers import AutoTokenizer, T5Tokenizer
from transformers import GPTJForCausalLM, AutoModelForCausalLM, AutoModelForSeq2SeqLM, OPTForCausalLM, T5Tokenizer, T5ForConditionalGeneration, GPTJModel, AutoModel from transformers import AutoModelForCausalLM, T5ForConditionalGeneration
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -17,37 +17,37 @@ parser.add_argument('--model', type=str, help='Name of the model to load by defa
parser.add_argument('--notebook', action='store_true', help='Launch the webui in notebook mode, where the output is written to the same text box as the input.') parser.add_argument('--notebook', action='store_true', help='Launch the webui in notebook mode, where the output is written to the same text box as the input.')
args = parser.parse_args() args = parser.parse_args()
loaded_preset = None loaded_preset = None
available_models = sorted(set(map(lambda x : x.split('/')[-1].replace('.pt', ''), glob.glob("models/*")+ glob.glob("torch-dumps/*")))) available_models = sorted(set(map(lambda x : str(x.name).replace('.pt', ''), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*')))))
available_models = [item for item in available_models if not item.endswith('.txt')] available_models = [item for item in available_models if not item.endswith('.txt')]
#available_models = sorted(set(map(lambda x : x.split('/')[-1].replace('.pt', ''), glob.glob("models/*[!\.][!t][!x][!t]")+ glob.glob("torch-dumps/*[!\.][!t][!x][!t]")))) available_presets = sorted(set(map(lambda x : str(x.name).split('.')[0], list(Path('presets').glob('*.txt')))))
def load_model(model_name): def load_model(model_name):
print(f"Loading {model_name}...") print(f"Loading {model_name}...")
t0 = time.time() t0 = time.time()
# Loading the model # Loading the model
if os.path.exists(f"torch-dumps/{model_name}.pt"): if Path(f"torch-dumps/{model_name}.pt").exists():
print("Loading in .pt format...") print("Loading in .pt format...")
model = torch.load(f"torch-dumps/{model_name}.pt").cuda() model = torch.load(Path(f"torch-dumps/{model_name}.pt")).cuda()
elif model_name.lower().startswith(('gpt-neo', 'opt-', 'galactica')): elif model_name.lower().startswith(('gpt-neo', 'opt-', 'galactica')):
if any(size in model_name.lower() for size in ('13b', '20b', '30b')): if any(size in model_name.lower() for size in ('13b', '20b', '30b')):
model = AutoModelForCausalLM.from_pretrained(f"models/{model_name}", device_map='auto', load_in_8bit=True) model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), device_map='auto', load_in_8bit=True)
else: else:
model = AutoModelForCausalLM.from_pretrained(f"models/{model_name}", low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda() model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
elif model_name in ['gpt-j-6B']: elif model_name in ['gpt-j-6B']:
model = AutoModelForCausalLM.from_pretrained(f"models/{model_name}", low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda() model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
elif model_name in ['flan-t5', 't5-large']: elif model_name in ['flan-t5', 't5-large']:
model = T5ForConditionalGeneration.from_pretrained(f"models/{model_name}").cuda() model = T5ForConditionalGeneration.from_pretrained(Path(f"models/{model_name}")).cuda()
else: else:
model = AutoModelForCausalLM.from_pretrained(f"models/{model_name}", low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda() model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
# Loading the tokenizer # Loading the tokenizer
if model_name.startswith('gpt4chan'): if model_name.startswith('gpt4chan'):
tokenizer = AutoTokenizer.from_pretrained("models/gpt-j-6B/") tokenizer = AutoTokenizer.from_pretrained(Path("models/gpt-j-6B/"))
elif model_name in ['flan-t5']: elif model_name in ['flan-t5']:
tokenizer = T5Tokenizer.from_pretrained(f"models/{model_name}/") tokenizer = T5Tokenizer.from_pretrained(Path(f"models/{model_name}/"))
else: else:
tokenizer = AutoTokenizer.from_pretrained(f"models/{model_name}/") tokenizer = AutoTokenizer.from_pretrained(Path(f"models/{model_name}/"))
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.") print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
return model, tokenizer return model, tokenizer
@ -78,7 +78,7 @@ def generate_reply(question, temperature, max_length, inference_settings, select
torch.cuda.empty_cache() torch.cuda.empty_cache()
model, tokenizer = load_model(model_name) model, tokenizer = load_model(model_name)
if inference_settings != loaded_preset: if inference_settings != loaded_preset:
with open(f'presets/{inference_settings}.txt', 'r') as infile: with open(Path(f'presets/{inference_settings}.txt'), 'r') as infile:
preset = infile.read() preset = infile.read()
loaded_preset = inference_settings loaded_preset = inference_settings
@ -143,7 +143,7 @@ if args.notebook:
temp_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Temperature', value=0.7) temp_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Temperature', value=0.7)
length_slider = gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=200) length_slider = gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=200)
with gr.Column(): with gr.Column():
preset_menu = gr.Dropdown(choices=list(map(lambda x : x.split('/')[-1].split('.')[0], glob.glob("presets/*.txt"))), value="NovelAI-Sphinx Moth", label='Preset') preset_menu = gr.Dropdown(choices=available_presets, value="NovelAI-Sphinx Moth", label='Preset')
model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model') model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
btn.click(generate_reply, [textbox, temp_slider, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=False) btn.click(generate_reply, [textbox, temp_slider, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=False)
@ -161,7 +161,7 @@ else:
textbox = gr.Textbox(value=default_text, lines=15, label='Input') textbox = gr.Textbox(value=default_text, lines=15, label='Input')
temp_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Temperature', value=0.7) temp_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Temperature', value=0.7)
length_slider = gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=200) length_slider = gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=200)
preset_menu = gr.Dropdown(choices=list(map(lambda x : x.split('/')[-1].split('.')[0], glob.glob("presets/*.txt"))), value="NovelAI-Sphinx Moth", label='Preset') preset_menu = gr.Dropdown(choices=available_presets, value="NovelAI-Sphinx Moth", label='Preset')
model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model') model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
btn = gr.Button("Generate") btn = gr.Button("Generate")
with gr.Column(): with gr.Column():