Sort some imports

This commit is contained in:
oobabooga 2023-06-25 01:44:36 -03:00
parent 365b672531
commit f0fcd1f697
8 changed files with 60 additions and 33 deletions

View File

@ -14,8 +14,11 @@ import modules.shared as shared
from modules.extensions import apply_extensions from modules.extensions import apply_extensions
from modules.html_generator import chat_html_wrapper, make_thumbnail from modules.html_generator import chat_html_wrapper, make_thumbnail
from modules.logging_colors import logger from modules.logging_colors import logger
from modules.text_generation import (generate_reply, get_encoded_length, from modules.text_generation import (
get_max_prompt_length) generate_reply,
get_encoded_length,
get_max_prompt_length
)
from modules.utils import delete_file, replace_all, save_file from modules.utils import delete_file, replace_all, save_file

View File

@ -8,8 +8,10 @@ from tqdm import tqdm
from modules import shared from modules import shared
from modules.models import load_model, unload_model from modules.models import load_model, unload_model
from modules.models_settings import (get_model_settings_from_yamls, from modules.models_settings import (
update_model_parameters) get_model_settings_from_yamls,
update_model_parameters
)
from modules.text_generation import encode from modules.text_generation import encode

View File

@ -1,9 +1,3 @@
'''
This is a library for formatting text outputs as nice HTML.
'''
import os import os
import re import re
import time import time

View File

@ -7,9 +7,15 @@ from pathlib import Path
import torch import torch
import transformers import transformers
from accelerate import infer_auto_device_map, init_empty_weights from accelerate import infer_auto_device_map, init_empty_weights
from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM, from transformers import (
AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig,
BitsAndBytesConfig, LlamaTokenizer) AutoModel,
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoTokenizer,
BitsAndBytesConfig,
LlamaTokenizer
)
import modules.shared as shared import modules.shared as shared
from modules import llama_attn_hijack, sampler_hijack from modules import llama_attn_hijack, sampler_hijack
@ -21,8 +27,10 @@ transformers.logging.set_verbosity_error()
local_rank = None local_rank = None
if shared.args.deepspeed: if shared.args.deepspeed:
import deepspeed import deepspeed
from transformers.deepspeed import (HfDeepSpeedConfig, from transformers.deepspeed import (
is_deepspeed_zero3_enabled) HfDeepSpeedConfig,
is_deepspeed_zero3_enabled
)
from modules.deepspeed_parameters import generate_ds_config from modules.deepspeed_parameters import generate_ds_config

View File

@ -7,10 +7,14 @@ sys.path.insert(0, str(Path("repositories/alpaca_lora_4bit")))
import autograd_4bit import autograd_4bit
from amp_wrapper import AMPWrapper from amp_wrapper import AMPWrapper
from autograd_4bit import (Autograd4bitQuantLinear, from autograd_4bit import (
load_llama_model_4bit_low_ram) Autograd4bitQuantLinear,
load_llama_model_4bit_low_ram
)
from monkeypatch.peft_tuners_lora_monkey_patch import ( from monkeypatch.peft_tuners_lora_monkey_patch import (
Linear4bitLt, replace_peft_model_with_gptq_lora_model) Linear4bitLt,
replace_peft_model_with_gptq_lora_model
)
from modules import shared from modules import shared
from modules.GPTQ_loader import find_quantized_model_file from modules.GPTQ_loader import find_quantized_model_file

View File

@ -3,9 +3,11 @@ import math
import torch import torch
import transformers import transformers
from transformers import LogitsWarper from transformers import LogitsWarper
from transformers.generation.logits_process import (LogitNormalization, from transformers.generation.logits_process import (
LogitsProcessorList, LogitNormalization,
TemperatureLogitsWarper) LogitsProcessorList,
TemperatureLogitsWarper
)
class TailFreeLogitsWarper(LogitsWarper): class TailFreeLogitsWarper(LogitsWarper):

View File

@ -11,12 +11,19 @@ import gradio as gr
import torch import torch
import transformers import transformers
from datasets import Dataset, load_dataset from datasets import Dataset, load_dataset
from peft import (LoraConfig, get_peft_model, prepare_model_for_kbit_training, from peft import (
set_peft_model_state_dict) LoraConfig,
get_peft_model,
prepare_model_for_kbit_training,
set_peft_model_state_dict
)
from modules import shared, ui, utils from modules import shared, ui, utils
from modules.evaluate import (calculate_perplexity, generate_markdown_table, from modules.evaluate import (
save_past_evaluations) calculate_perplexity,
generate_markdown_table,
save_past_evaluations
)
from modules.logging_colors import logger from modules.logging_colors import logger
# This mapping is from a very recent commit, not yet released. # This mapping is from a very recent commit, not yet released.
@ -25,8 +32,9 @@ try:
from peft.utils.other import \ from peft.utils.other import \
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as \ TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as \
model_to_lora_modules model_to_lora_modules
from transformers.models.auto.modeling_auto import \ from transformers.models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
)
MODEL_CLASSES = {v: k for k, v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES} MODEL_CLASSES = {v: k for k, v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES}
except: except:
standard_modules = ["q_proj", "v_proj"] standard_modules = ["q_proj", "v_proj"]
@ -201,8 +209,9 @@ def clean_path(base_path: str, path: str):
def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str): def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str):
if shared.args.monkey_patch: if shared.args.monkey_patch:
from monkeypatch.peft_tuners_lora_monkey_patch import \ from monkeypatch.peft_tuners_lora_monkey_patch import (
replace_peft_model_with_gptq_lora_model replace_peft_model_with_gptq_lora_model
)
replace_peft_model_with_gptq_lora_model() replace_peft_model_with_gptq_lora_model()
global WANT_INTERRUPT global WANT_INTERRUPT

View File

@ -38,12 +38,17 @@ from modules.github import clone_or_pull_repository
from modules.html_generator import chat_html_wrapper from modules.html_generator import chat_html_wrapper
from modules.LoRA import add_lora_to_model from modules.LoRA import add_lora_to_model
from modules.models import load_model, unload_model from modules.models import load_model, unload_model
from modules.models_settings import (apply_model_settings_to_state, from modules.models_settings import (
get_model_settings_from_yamls, apply_model_settings_to_state,
save_model_settings, get_model_settings_from_yamls,
update_model_parameters) save_model_settings,
from modules.text_generation import (generate_reply_wrapper, update_model_parameters
get_encoded_length, stop_everything_event) )
from modules.text_generation import (
generate_reply_wrapper,
get_encoded_length,
stop_everything_event
)
def load_model_wrapper(selected_model, loader, autoload=False): def load_model_wrapper(selected_model, loader, autoload=False):