Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2024-11-22 08:07:56 +01:00)
Sort some imports
Commit: f0fcd1f697
Parent: 365b672531
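The commit reformats multi-name import statements across several modules: the wrapped, parenthesis-aligned style (several names per line) becomes one imported name per line, with the closing parenthesis on its own line. The snippet below is a minimal sketch of producing that layout with a formatter such as isort; the tool choice, the sample source string, and the settings are illustrative assumptions and are not stated anywhere in this commit.

# Minimal sketch (assumption: isort >= 5 is installed); the commit does not say
# which tool, if any, produced the new layout.
import isort
from isort.wrap_modes import WrapModes

# VERTICAL_HANGING_INDENT (mode 3) puts one imported name per line and the
# closing parenthesis on its own line; trailing commas stay off by default.
sample = (
    "from modules.text_generation import (generate_reply, get_encoded_length,\n"
    "                                     get_max_prompt_length)\n"
)
print(isort.code(sample, multi_line_output=WrapModes.VERTICAL_HANGING_INDENT))
# Should print roughly:
# from modules.text_generation import (
#     generate_reply,
#     get_encoded_length,
#     get_max_prompt_length
# )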
@@ -14,8 +14,11 @@ import modules.shared as shared
 from modules.extensions import apply_extensions
 from modules.html_generator import chat_html_wrapper, make_thumbnail
 from modules.logging_colors import logger
-from modules.text_generation import (generate_reply, get_encoded_length,
-                                     get_max_prompt_length)
+from modules.text_generation import (
+    generate_reply,
+    get_encoded_length,
+    get_max_prompt_length
+)
 from modules.utils import delete_file, replace_all, save_file
 
 
@@ -8,8 +8,10 @@ from tqdm import tqdm
 
 from modules import shared
 from modules.models import load_model, unload_model
-from modules.models_settings import (get_model_settings_from_yamls,
-                                     update_model_parameters)
+from modules.models_settings import (
+    get_model_settings_from_yamls,
+    update_model_parameters
+)
 from modules.text_generation import encode
 
 
@@ -1,9 +1,3 @@
-'''
-
-This is a library for formatting text outputs as nice HTML.
-
-'''
-
 import os
 import re
 import time
@@ -7,9 +7,15 @@ from pathlib import Path
 import torch
 import transformers
 from accelerate import infer_auto_device_map, init_empty_weights
-from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
-                          AutoModelForSeq2SeqLM, AutoTokenizer,
-                          BitsAndBytesConfig, LlamaTokenizer)
+from transformers import (
+    AutoConfig,
+    AutoModel,
+    AutoModelForCausalLM,
+    AutoModelForSeq2SeqLM,
+    AutoTokenizer,
+    BitsAndBytesConfig,
+    LlamaTokenizer
+)
 
 import modules.shared as shared
 from modules import llama_attn_hijack, sampler_hijack
@@ -21,8 +27,10 @@ transformers.logging.set_verbosity_error()
 local_rank = None
 if shared.args.deepspeed:
     import deepspeed
-    from transformers.deepspeed import (HfDeepSpeedConfig,
-                                        is_deepspeed_zero3_enabled)
+    from transformers.deepspeed import (
+        HfDeepSpeedConfig,
+        is_deepspeed_zero3_enabled
+    )
     from modules.deepspeed_parameters import generate_ds_config
 
 
@@ -7,10 +7,14 @@ sys.path.insert(0, str(Path("repositories/alpaca_lora_4bit")))
 
 import autograd_4bit
 from amp_wrapper import AMPWrapper
-from autograd_4bit import (Autograd4bitQuantLinear,
-                           load_llama_model_4bit_low_ram)
+from autograd_4bit import (
+    Autograd4bitQuantLinear,
+    load_llama_model_4bit_low_ram
+)
 from monkeypatch.peft_tuners_lora_monkey_patch import (
-    Linear4bitLt, replace_peft_model_with_gptq_lora_model)
+    Linear4bitLt,
+    replace_peft_model_with_gptq_lora_model
+)
 
 from modules import shared
 from modules.GPTQ_loader import find_quantized_model_file
@@ -3,9 +3,11 @@ import math
 import torch
 import transformers
 from transformers import LogitsWarper
-from transformers.generation.logits_process import (LogitNormalization,
-                                                     LogitsProcessorList,
-                                                     TemperatureLogitsWarper)
+from transformers.generation.logits_process import (
+    LogitNormalization,
+    LogitsProcessorList,
+    TemperatureLogitsWarper
+)
 
 
 class TailFreeLogitsWarper(LogitsWarper):
@@ -11,12 +11,19 @@ import gradio as gr
 import torch
 import transformers
 from datasets import Dataset, load_dataset
-from peft import (LoraConfig, get_peft_model, prepare_model_for_kbit_training,
-                  set_peft_model_state_dict)
+from peft import (
+    LoraConfig,
+    get_peft_model,
+    prepare_model_for_kbit_training,
+    set_peft_model_state_dict
+)
 
 from modules import shared, ui, utils
-from modules.evaluate import (calculate_perplexity, generate_markdown_table,
-                              save_past_evaluations)
+from modules.evaluate import (
+    calculate_perplexity,
+    generate_markdown_table,
+    save_past_evaluations
+)
 from modules.logging_colors import logger
 
 # This mapping is from a very recent commit, not yet released.
@@ -25,8 +32,9 @@ try:
     from peft.utils.other import \
         TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as \
         model_to_lora_modules
-    from transformers.models.auto.modeling_auto import \
-        MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+    from transformers.models.auto.modeling_auto import (
+        MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+    )
     MODEL_CLASSES = {v: k for k, v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES}
 except:
     standard_modules = ["q_proj", "v_proj"]
@@ -201,8 +209,9 @@ def clean_path(base_path: str, path: str):
 def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str):
 
     if shared.args.monkey_patch:
-        from monkeypatch.peft_tuners_lora_monkey_patch import \
-            replace_peft_model_with_gptq_lora_model
+        from monkeypatch.peft_tuners_lora_monkey_patch import (
+            replace_peft_model_with_gptq_lora_model
+        )
         replace_peft_model_with_gptq_lora_model()
 
     global WANT_INTERRUPT
server.py (17 changed lines)
@@ -38,12 +38,17 @@ from modules.github import clone_or_pull_repository
 from modules.html_generator import chat_html_wrapper
 from modules.LoRA import add_lora_to_model
 from modules.models import load_model, unload_model
-from modules.models_settings import (apply_model_settings_to_state,
-                                     get_model_settings_from_yamls,
-                                     save_model_settings,
-                                     update_model_parameters)
-from modules.text_generation import (generate_reply_wrapper,
-                                     get_encoded_length, stop_everything_event)
+from modules.models_settings import (
+    apply_model_settings_to_state,
+    get_model_settings_from_yamls,
+    save_model_settings,
+    update_model_parameters
+)
+from modules.text_generation import (
+    generate_reply_wrapper,
+    get_encoded_length,
+    stop_everything_event
+)
 
 
 def load_model_wrapper(selected_model, loader, autoload=False):