Add LoRA support

commit 104293f411 (parent ee164d1821)
Author: oobabooga
Date:   2023-03-16 21:31:39 -03:00

6 changed files with 51 additions and 8 deletions

@@ -1,12 +1,15 @@
 .tabs.svelte-710i53 {
     margin-top: 0
 }
 .py-6 {
     padding-top: 2.5rem
 }
 .dark #refresh-button {
     background-color: #ffffff1f;
 }
 #refresh-button {
     flex: none;
     margin: 0;
@@ -17,22 +20,28 @@
     border-radius: 10px;
     background-color: #0000000d;
 }
 #download-label, #upload-label {
     min-height: 0
 }
 #accordion {
 }
 .dark svg {
     fill: white;
 }
 svg {
     display: unset !important;
     vertical-align: middle !important;
     margin: 5px;
 }
 ol li p, ul li p {
     display: inline-block;
 }
-#main, #parameters, #chat-settings, #interface-mode {
+#main, #parameters, #chat-settings, #interface-mode, #lora {
     border: 0;
 }

@@ -101,6 +101,7 @@ def get_download_links_from_huggingface(model, branch):
     classifications = []
     has_pytorch = False
     has_safetensors = False
+    is_lora = False

     while True:
         content = requests.get(f"{base}{page}{cursor.decode()}").content
@@ -110,8 +111,10 @@ def get_download_links_from_huggingface(model, branch):
         for i in range(len(dict)):
             fname = dict[i]['path']

+            if not is_lora and fname.endswith(('adapter_config.json', 'adapter_model.bin')):
+                is_lora = True
+
-            is_pytorch = re.match("pytorch_model.*\.bin", fname)
+            is_pytorch = re.match("(pytorch|adapter)_model.*\.bin", fname)
             is_safetensors = re.match("model.*\.safetensors", fname)
             is_tokenizer = re.match("tokenizer.*\.model", fname)
             is_text = re.match(".*\.(txt|json)", fname) or is_tokenizer
@@ -130,6 +133,7 @@ def get_download_links_from_huggingface(model, branch):
                 has_pytorch = True
                 classifications.append('pytorch')

         cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
         cursor = base64.b64encode(cursor)
         cursor = cursor.replace(b'=', b'%3D')
@@ -140,7 +144,7 @@ def get_download_links_from_huggingface(model, branch):
             if classifications[i] == 'pytorch':
                 links.pop(i)

-    return links
+    return links, is_lora

 if __name__ == '__main__':
     model = args.MODEL
@@ -159,15 +163,16 @@ if __name__ == '__main__':
     except ValueError as err_branch:
         print(f"Error: {err_branch}")
         sys.exit()

+    links, is_lora = get_download_links_from_huggingface(model, branch)
+    base_folder = 'models' if not is_lora else 'loras'
     if branch != 'main':
-        output_folder = Path("models") / (model.split('/')[-1] + f'_{branch}')
+        output_folder = Path(base_folder) / (model.split('/')[-1] + f'_{branch}')
     else:
-        output_folder = Path("models") / model.split('/')[-1]
+        output_folder = Path(base_folder) / model.split('/')[-1]

     if not output_folder.exists():
         output_folder.mkdir()

-    links = get_download_links_from_huggingface(model, branch)
-
     # Downloading the files
     print(f"Downloading the model to {output_folder}")
     pool = multiprocessing.Pool(processes=args.threads)
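
Net effect of the downloader changes: a repo is flagged as a LoRA as soon as adapter_config.json or adapter_model.bin shows up in its Hugging Face file listing, and its files are routed to loras/ instead of models/. A minimal sketch of the new calling convention (the repo name is illustrative, not taken from the commit):

    # Sketch: the classifier now returns a tuple, and callers pick the
    # destination folder from the is_lora flag. The repo below is a
    # hypothetical example, not part of this diff.
    links, is_lora = get_download_links_from_huggingface('tloen/alpaca-lora-7b', 'main')
    base_folder = 'loras' if is_lora else 'models'
    print(f"saving {len(links)} files under {base_folder}/")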

@@ -11,6 +11,8 @@ from accelerate import infer_auto_device_map, init_empty_weights
 from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                           BitsAndBytesConfig)
+from peft import PeftModel

 import modules.shared as shared

 transformers.logging.set_verbosity_error()
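
This hunk only adds the PeftModel import; the add_lora_to_model helper that server.py later imports from modules.LoRA is not shown anywhere in this diff. As an assumption about what such a helper plausibly does with peft 0.2.0, a minimal sketch:

    # Hypothetical sketch of modules/LoRA.py -- that file is NOT part of the
    # hunks shown here; only its import in server.py is. Everything except
    # PeftModel, load_model and the shared.* names is an assumption.
    from pathlib import Path

    from peft import PeftModel

    import modules.shared as shared
    from modules.models import load_model

    def add_lora_to_model(lora_name):
        # Reload the base model first so switching adapters never stacks
        # one LoRA on top of another (assumed intended behavior).
        shared.model, shared.tokenizer = load_model(shared.model_name)
        if lora_name != 'None':
            # PeftModel.from_pretrained wraps the base model with the
            # adapter weights stored under loras/<name>.
            shared.model = PeftModel.from_pretrained(shared.model, Path(f"loras/{lora_name}"))
        shared.lora_name = lora_name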

@@ -2,7 +2,8 @@ import argparse

 model = None
 tokenizer = None
-model_name = ""
+model_name = "None"
+lora_name = "None"
 soft_prompt_tensor = None
 soft_prompt = False
 is_RWKV = False

@@ -4,6 +4,7 @@ flexgen==0.1.7
 gradio==3.18.0
 markdown
 numpy
+peft==0.2.0
 requests
 rwkv==0.4.2
 safetensors==0.3.0

@@ -17,6 +17,7 @@ import modules.ui as ui
 from modules.html_generator import generate_chat_html
 from modules.models import load_model, load_soft_prompt
 from modules.text_generation import generate_reply
+from modules.LoRA import add_lora_to_model

 # Loading custom settings
 settings_file = None
@@ -48,6 +49,9 @@ def get_available_extensions():
 def get_available_softprompts():
     return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)

+def get_available_loras():
+    return ['None'] + sorted([item.name for item in list(Path('loras/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
+
 def load_model_wrapper(selected_model):
     if selected_model != shared.model_name:
         shared.model_name = selected_model
@@ -59,6 +63,13 @@ def load_model_wrapper(selected_model):

     return selected_model

+def load_lora_wrapper(selected_lora):
+    if not shared.args.cpu:
+        gc.collect()
+        torch.cuda.empty_cache()
+    add_lora_to_model(selected_lora)
+    return selected_lora
+
 def load_preset_values(preset_menu, return_dict=False):
     generate_params = {
         'do_sample': True,
@@ -181,6 +192,7 @@ available_models = get_available_models()
 available_presets = get_available_presets()
 available_characters = get_available_characters()
 available_softprompts = get_available_softprompts()
+available_loras = get_available_loras()

 # Default extensions
 extensions_module.available_extensions = get_available_extensions()
@@ -401,6 +413,19 @@ def create_interface():
         shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
         shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")

+        with gr.Tab("LoRA", elem_id="lora"):
+            with gr.Row():
+                with gr.Column():
+                    gr.Markdown("Load")
+                    with gr.Row():
+                        shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
+                        ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button')
+                with gr.Column():
+                    gr.Markdown("Train (TODO)")
+                    gr.Button("Practice your button clicking skills")
+
+            shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu']], show_progress=True)
+
         with gr.Tab("Interface mode", elem_id="interface-mode"):
             modes = ["default", "notebook", "chat", "cai_chat"]
             current_mode = "default"
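
With the event wiring above, picking an entry in the LoRA dropdown flows through load_lora_wrapper, which clears cached VRAM and then applies the adapter via add_lora_to_model. The same path can be exercised without the UI; a minimal sketch, assuming an adapter folder already sits under loras/ (the folder name below is hypothetical):

    # Sketch: driving the new code path directly. 'alpaca-lora-7b' stands in
    # for any adapter folder you have placed under loras/.
    available = get_available_loras()   # e.g. ['None', 'alpaca-lora-7b']
    load_lora_wrapper(available[-1])    # frees cached VRAM, then calls add_lora_to_model
    load_lora_wrapper('None')           # assumption: 'None' restores the plain base model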