mirror of https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-29 10:59:32 +01:00

Add an "Evaluate" tab to calculate the perplexities of models (#1322)

parent: ff0d0ac552
commit: c4f4f41389

modules/evaluate.py | 140 additions (new file)
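Background for the diff that follows: the new evaluation module adapts the fixed-length sliding-window perplexity recipe from the Hugging Face documentation (linked inside the code). Below is a minimal, self-contained sketch of that technique, illustrative only and not part of the commit; the "gpt2" checkpoint, the sample text, and the variable names are assumptions. The loop scores a long text in overlapping windows and exponentiates the mean negative log-likelihood:

# Sketch of sliding-window perplexity (illustrative; "gpt2" and the sample
# text are assumptions, not part of this commit).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()
tokenizer = AutoTokenizer.from_pretrained("gpt2")
text = "The quick brown fox jumps over the lazy dog. " * 200

input_ids_full = tokenizer(text, return_tensors="pt").input_ids
seq_len = input_ids_full.shape[1]
max_length = model.config.n_positions  # context window per forward pass
stride = 512                           # larger stride = fewer passes, less accurate

nlls = []
prev_end_loc = 0
for begin_loc in range(0, seq_len, stride):
    end_loc = min(begin_loc + max_length, seq_len)
    trg_len = end_loc - prev_end_loc   # only score tokens not scored before
    input_ids = input_ids_full[:, begin_loc:end_loc]
    target_ids = input_ids.clone()
    target_ids[:, :-trg_len] = -100    # -100 labels are ignored by the loss

    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)
        nlls.append(outputs.loss)      # mean NLL over the scored tokens

    prev_end_loc = end_loc
    if end_loc == seq_len:
        break

# Perplexity is the exponential of the mean negative log-likelihood
ppl = torch.exp(torch.stack(nlls).mean())
print(f"perplexity: {ppl.item():.2f}")

The stride slider in the new tab controls the step of this loop: with a 10,000-token dataset, stride 512 makes about 20 forward passes, while stride 1 would make about 10,000, which is why smaller strides are slower but closer to the true perplexity.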
modules/evaluate.py (new file)
@@ -0,0 +1,140 @@
import datetime
import traceback
from pathlib import Path

import pandas as pd
import torch
from datasets import load_dataset
from tqdm import tqdm

from modules import shared
from modules.models import load_model, unload_model
from modules.text_generation import encode
from server import get_model_specific_settings, update_model_parameters


def load_past_evaluations():
    if Path('logs/evaluations.csv').exists():
        df = pd.read_csv(Path('logs/evaluations.csv'), dtype=str)
        df['Perplexity'] = pd.to_numeric(df['Perplexity'])
        return df
    else:
        return pd.DataFrame(columns=['Model', 'LoRAs', 'Dataset', 'Perplexity', 'stride', 'max_length', 'Date', 'Comment'])
past_evaluations = load_past_evaluations()


def save_past_evaluations(df):
    df.to_csv(Path('logs/evaluations.csv'), index=False)


def calculate_perplexity(models, input_dataset, stride, _max_length):
    '''
    Based on:
    https://huggingface.co/docs/transformers/perplexity#calculating-ppl-with-fixedlength-models
    '''

    global past_evaluations
    cumulative_log = ''
    cumulative_log += "Loading the input dataset...\n"
    yield cumulative_log

    # Copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/triton/utils/datautils.py
    if input_dataset == 'wikitext':
        data = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
        text = "\n\n".join(data['text'])
    elif input_dataset == 'ptb':
        data = load_dataset('ptb_text_only', 'penn_treebank', split='validation')
        text = "\n\n".join(data['sentence'])
    elif input_dataset == 'ptb_new':
        data = load_dataset('ptb_text_only', 'penn_treebank', split='test')
        text = " ".join(data['sentence'])
    else:
        with open(Path(f'training/datasets/{input_dataset}.txt'), 'r', encoding='utf-8') as f:
            text = f.read()

    for model in models:
        if is_in_past_evaluations(model, input_dataset, stride, _max_length):
            cumulative_log += f"{model} has already been tested. Ignoring.\n"
            yield cumulative_log
            continue

        if model != 'current model':
            try:
                yield cumulative_log + f"Loading {model}...\n"
                model_settings = get_model_specific_settings(model)
                shared.settings.update(model_settings)  # hijacking the interface defaults
                update_model_parameters(model_settings)  # hijacking the command-line arguments
                shared.model_name = model
                unload_model()
                shared.model, shared.tokenizer = load_model(shared.model_name)
            except:
                cumulative_log += f"Failed to load {model}. Moving on.\n"
                yield cumulative_log
                continue

        cumulative_log += f"Processing {model}...\n"
        yield cumulative_log + "Tokenizing the input dataset...\n"
        encodings = encode(text, add_special_tokens=False)
        seq_len = encodings.shape[1]
        max_length = _max_length or shared.model.config.max_position_embeddings
        nlls = []
        prev_end_loc = 0
        for begin_loc in tqdm(range(0, seq_len, stride)):
            yield cumulative_log + f"Evaluating... {100*begin_loc/seq_len:.2f}%"
            end_loc = min(begin_loc + max_length, seq_len)
            trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
            input_ids = encodings[:, begin_loc:end_loc]
            target_ids = input_ids.clone()
            target_ids[:, :-trg_len] = -100

            with torch.no_grad():
                outputs = shared.model(input_ids, labels=target_ids)

                # loss is calculated using CrossEntropyLoss which averages over valid labels
                # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
                # to the left by 1.
                neg_log_likelihood = outputs.loss

            nlls.append(neg_log_likelihood)

            prev_end_loc = end_loc
            if end_loc == seq_len:
                break

        ppl = torch.exp(torch.stack(nlls).mean())
        add_entry_to_past_evaluations(float(ppl), shared.model_name, input_dataset, stride, _max_length)
        save_past_evaluations(past_evaluations)
        cumulative_log += f"Done. The perplexity is: {float(ppl)}\n\n"
        yield cumulative_log


def add_entry_to_past_evaluations(perplexity, model, dataset, stride, max_length):
    global past_evaluations
    entry = {
        'Model': model,
        'LoRAs': ', '.join(shared.lora_names) or '-',
        'Dataset': dataset,
        'Perplexity': perplexity,
        'stride': str(stride),
        'max_length': str(max_length),
        'Date': datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        'Comment': ''
    }
    past_evaluations = pd.concat([past_evaluations, pd.DataFrame([entry])], ignore_index=True)


def is_in_past_evaluations(model, dataset, stride, max_length):
    entries = past_evaluations[(past_evaluations['Model'] == model) &
                               (past_evaluations['Dataset'] == dataset) &
                               (past_evaluations['max_length'] == str(max_length)) &
                               (past_evaluations['stride'] == str(stride))]

    if entries.shape[0] > 0:
        return True
    else:
        return False


def generate_markdown_table():
    sorted_df = past_evaluations.sort_values(by=['Dataset', 'stride', 'Perplexity', 'Date'])
    return sorted_df
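calculate_perplexity is written as a generator: each yield is the cumulative log string, which lets the Gradio UI stream progress while the evaluation runs. Outside the UI it could in principle be driven by hand, as in this sketch (assumes a working text-generation-webui checkout with a model already loaded; the parameter values are only examples):

# Hypothetical manual driver for the generator above (assumes the webui
# environment is importable and a model is already loaded in `shared`).
from modules.evaluate import calculate_perplexity, generate_markdown_table

# 'current model' skips (re)loading and evaluates whatever is in memory.
for log in calculate_perplexity(['current model'], 'wikitext', stride=512, _max_length=0):
    print(log)  # each yield is the full cumulative log so far

print(generate_markdown_table())  # pandas DataFrame of all stored evaluations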
@@ -53,7 +53,7 @@ def load_model(model_name):
     # Load the model in simple 16-bit mode by default
     if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV, shared.is_llamacpp]):
-        model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16, trust_remote_code=trust_remote_code)
+        model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16, trust_remote_code=trust_remote_code)
         if torch.has_mps:
             device = torch.device('mps')
             model = model.to(device)
@@ -81,11 +81,11 @@ def load_model(model_name):
                                           num_bits=4, group_size=64,
                                           group_dim=2, symmetric=False))

-        model = OptLM(f"facebook/{shared.model_name}", env, shared.args.model_dir, policy)
+        model = OptLM(f"facebook/{model_name}", env, shared.args.model_dir, policy)

     # DeepSpeed ZeRO-3
     elif shared.args.deepspeed:
-        model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
+        model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
         model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
         model.module.eval()  # Inference
         print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
@@ -169,7 +169,7 @@ def load_model(model_name):
         if shared.args.disk:
             params["offload_folder"] = shared.args.disk_cache_dir

-        checkpoint = Path(f'{shared.args.model_dir}/{shared.model_name}')
+        checkpoint = Path(f'{shared.args.model_dir}/{model_name}')

         if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
             config = AutoConfig.from_pretrained(checkpoint)
@@ -190,7 +190,7 @@ def load_model(model_name):
            llama_attn_hijack.hijack_llama_attention()

    # Loading the tokenizer
-    if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
+    if any((k in model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
        tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
    elif type(model) is transformers.LlamaForCausalLM:
        tokenizer = None
@@ -205,7 +205,7 @@ def load_model(model_name):
        # Otherwise, load it from the model folder and hope that these
        # are not outdated tokenizer files.
        if tokenizer is None:
-            tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True)
+            tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), clean_up_tokenization_spaces=True)
            try:
                tokenizer.eos_token_id = 2
                tokenizer.bos_token_id = 1
@@ -213,7 +213,7 @@ def load_model(model_name):
            except:
                pass
    else:
-        tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), trust_remote_code=trust_remote_code)
+        tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), trust_remote_code=trust_remote_code)

    print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
    return model, tokenizer
@@ -10,9 +10,12 @@ import gradio as gr
 import torch
 import transformers
 from datasets import Dataset, load_dataset
-from peft import LoraConfig, get_peft_model, set_peft_model_state_dict, prepare_model_for_int8_training
+from peft import (LoraConfig, get_peft_model, prepare_model_for_int8_training,
+                  set_peft_model_state_dict)

 from modules import shared, ui
+from modules.evaluate import calculate_perplexity, generate_markdown_table, save_past_evaluations
+from server import get_available_loras, get_available_models

 # This mapping is from a very recent commit, not yet released.
 # If not available, default to a backup map for the 3 safe model types.
@@ -40,10 +43,6 @@ def get_datasets(path: str, ext: str):
     return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=str.lower)


-def get_available_loras():
-    return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
-
-
 def create_train_interface():
     with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
         with gr.Row():
@@ -82,9 +81,9 @@ def create_train_interface():

                eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.')

-        with gr.Tab(label='Raw Text File'):
+        with gr.Tab(label="Raw text file"):
            with gr.Row():
-                raw_text_file = gr.Dropdown(choices=get_datasets('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.')
+                raw_text_file = gr.Dropdown(choices=get_datasets('training/datasets', 'txt'), value='None', label='Text file', info='The raw text file to use for training.')
                ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': get_datasets('training/datasets', 'txt')}, 'refresh-button')

            with gr.Row():
@@ -106,11 +105,48 @@ def create_train_interface():

        output = gr.Markdown(value="Ready")

-    all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, do_shuffle, higher_rank_limit, warmup_steps, optimizer]
-    copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
-    start_button.click(do_train, all_params, output)
-    stop_button.click(do_interrupt, None, None, queue=False)
-    higher_rank_limit.change(change_rank_limit, [higher_rank_limit], [lora_rank, lora_alpha])
+    with gr.Tab('Perplexity evaluation', elem_id='evaluate-tab'):
+        with gr.Row():
+            with gr.Column():
+                models = gr.Dropdown(get_available_models(), label='Models', multiselect=True)
+                evaluate_text_file = gr.Dropdown(choices=['wikitext', 'ptb', 'ptb_new'] + get_datasets('training/datasets', 'txt')[1:], value='wikitext', label='Input dataset', info='The raw text file on which the model will be evaluated. The first options are automatically downloaded: wikitext, ptb, and ptb_new. The next options are your local text files under training/datasets.')
+                with gr.Row():
+                    stride_length = gr.Slider(label='Stride', minimum=1, maximum=2048, value=512, step=1, info='Used to make the evaluation faster at the cost of accuracy. 1 = slowest but most accurate. 512 is a common value.')
+                    max_length = gr.Slider(label='max_length', minimum=1, maximum=8096, value=0, step=1, info='The context for each evaluation. If set to 0, the maximum context length for the model will be used.')
+
+                with gr.Row():
+                    start_current_evaluation = gr.Button("Evaluate loaded model")
+                    start_evaluation = gr.Button("Evaluate selected models")
+                    stop_evaluation = gr.Button("Interrupt")
+
+            with gr.Column():
+                evaluation_log = gr.Markdown(value = '')
+
+        evaluation_table = gr.Dataframe(value=generate_markdown_table(), interactive=True)
+        save_comments = gr.Button('Save comments')
+
+    # Training events
+    all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, do_shuffle, higher_rank_limit, warmup_steps, optimizer]
+    copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
+    start_button.click(do_train, all_params, output)
+    stop_button.click(do_interrupt, None, None, queue=False)
+    higher_rank_limit.change(change_rank_limit, [higher_rank_limit], [lora_rank, lora_alpha])
+
+    # Evaluation events. For some reason, the interrupt event
+    # doesn't work with the .then() syntax, so I write them one
+    # by one in this ugly but functional way.
+    ev = start_evaluation.click(calculate_perplexity, [models, evaluate_text_file, stride_length, max_length], evaluation_log, show_progress=False)
+    start_evaluation.click(generate_markdown_table, None, evaluation_table, show_progress=False)
+
+    tmp = gr.State('')
+    start_current_evaluation.click(lambda: ['current model'], None, tmp)
+    ev_cur = start_current_evaluation.click(calculate_perplexity, [tmp, evaluate_text_file, stride_length, max_length], evaluation_log, show_progress=False)
+    start_current_evaluation.click(generate_markdown_table, None, evaluation_table, show_progress=False)
+
+    stop_evaluation.click(None, None, None, cancels=[ev, ev_cur], queue=False)
+    save_comments.click(
+        save_past_evaluations, evaluation_table, None).then(
+        lambda: "Comments saved.", None, evaluation_log, show_progress=False)


 def do_interrupt():
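The interrupt wiring above relies on Gradio's cancels= argument: the click events that start the perplexity generator are stored in ev and ev_cur, and the Interrupt button cancels them directly instead of chaining with .then(). A minimal, self-contained sketch of the same pattern follows (not part of the commit; the component names and the toy task are illustrative):

# Sketch of the interrupt pattern: a long-running generator event is stored
# and a second button cancels it via cancels=[...].
import time
import gradio as gr

def slow_task():
    log = ''
    for i in range(100):
        time.sleep(0.1)
        log += f"step {i}\n"
        yield log  # streamed to the Markdown component

with gr.Blocks() as demo:
    output = gr.Markdown()
    start = gr.Button("Start")
    stop = gr.Button("Interrupt")

    ev = start.click(slow_task, None, output, show_progress=False)
    # Passing the stored event to cancels= stops the generator mid-run.
    stop.click(None, None, None, cancels=[ev], queue=False)

demo.queue().launch()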
@@ -133,6 +169,7 @@ def do_copy_params(lora_name: str, *args):
             result.append(params[key])
         else:
             result.append(args[i])

     return result
+
@@ -155,7 +192,8 @@ def clean_path(base_path: str, path: str):
 def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, do_shuffle: bool, higher_rank_limit: bool, warmup_steps: int, optimizer: str):

     if shared.args.monkey_patch:
-        from monkeypatch.peft_tuners_lora_monkey_patch import replace_peft_model_with_gptq_lora_model
+        from monkeypatch.peft_tuners_lora_monkey_patch import \
+            replace_peft_model_with_gptq_lora_model
         replace_peft_model_with_gptq_lora_model()

     global WANT_INTERRUPT
@@ -300,6 +338,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
                 if '4bit' in str(type(m)):
                     if m.is_v1_model:
                         m.zeros = m.zeros.half()
+
                     m.scales = m.scales.half()

     class Tracked():
@@ -20,7 +20,8 @@ theme = gr.themes.Default(
     font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'],
 ).set(
     border_color_primary='#c5c5d2',
-    button_large_padding='6px 12px'
+    button_large_padding='6px 12px',
+    body_text_color_subdued='#484848'
 )

 def list_model_elements():
@@ -5,12 +5,13 @@ flexgen==0.1.7
 gradio==3.25.0
 markdown
 numpy
+pandas
 Pillow>=9.5.0
+pyyaml
 requests
 rwkv==0.7.3
 safetensors==0.3.0
 sentencepiece
-pyyaml
 tqdm
 git+https://github.com/huggingface/peft
 transformers==4.28.1