mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-28 18:48:04 +01:00
Prevent unwanted log messages from modules
This commit is contained in:
parent
fb91406e93
commit
e116d31180
@ -1,8 +1,8 @@
|
|||||||
import logging
|
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
|
|
||||||
def ui():
|
def ui():
|
||||||
gr.Markdown("### This extension is deprecated, use \"multimodal\" extension instead")
|
gr.Markdown("### This extension is deprecated, use \"multimodal\" extension instead")
|
||||||
logging.error("LLaVA extension is deprecated, use \"multimodal\" extension instead")
|
logger.error("LLaVA extension is deprecated, use \"multimodal\" extension instead")
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import base64
|
import base64
|
||||||
import logging
|
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
@ -10,6 +9,7 @@ from PIL import Image
|
|||||||
|
|
||||||
from extensions.multimodal.pipeline_loader import load_pipeline
|
from extensions.multimodal.pipeline_loader import load_pipeline
|
||||||
from modules import shared
|
from modules import shared
|
||||||
|
from modules.logging_colors import logger
|
||||||
from modules.text_generation import encode, get_max_prompt_length
|
from modules.text_generation import encode, get_max_prompt_length
|
||||||
|
|
||||||
|
|
||||||
@ -26,7 +26,7 @@ class MultimodalEmbedder:
|
|||||||
def __init__(self, params: dict):
|
def __init__(self, params: dict):
|
||||||
pipeline, source = load_pipeline(params)
|
pipeline, source = load_pipeline(params)
|
||||||
self.pipeline = pipeline
|
self.pipeline = pipeline
|
||||||
logging.info(f'Multimodal: loaded pipeline {self.pipeline.name()} from pipelines/{source} ({self.pipeline.__class__.__name__})')
|
logger.info(f'Multimodal: loaded pipeline {self.pipeline.name()} from pipelines/{source} ({self.pipeline.__class__.__name__})')
|
||||||
|
|
||||||
def _split_prompt(self, prompt: str, load_images: bool = False) -> List[PromptPart]:
|
def _split_prompt(self, prompt: str, load_images: bool = False) -> List[PromptPart]:
|
||||||
"""Splits a prompt into a list of `PromptParts` to separate image data from text.
|
"""Splits a prompt into a list of `PromptParts` to separate image data from text.
|
||||||
@ -138,7 +138,7 @@ class MultimodalEmbedder:
|
|||||||
|
|
||||||
# notify user if we truncated an image
|
# notify user if we truncated an image
|
||||||
if removed_images > 0:
|
if removed_images > 0:
|
||||||
logging.warning(f"Multimodal: removed {removed_images} image(s) from prompt. Try decreasing max_new_tokens if generation is broken")
|
logger.warning(f"Multimodal: removed {removed_images} image(s) from prompt. Try decreasing max_new_tokens if generation is broken")
|
||||||
|
|
||||||
return encoded
|
return encoded
|
||||||
|
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import logging
|
|
||||||
import traceback
|
import traceback
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -6,6 +5,7 @@ from typing import Tuple
|
|||||||
|
|
||||||
from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
|
from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
|
||||||
from modules import shared
|
from modules import shared
|
||||||
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
|
|
||||||
def _get_available_pipeline_modules():
|
def _get_available_pipeline_modules():
|
||||||
@ -21,8 +21,8 @@ def load_pipeline(params: dict) -> Tuple[AbstractMultimodalPipeline, str]:
|
|||||||
try:
|
try:
|
||||||
pipeline_modules[name] = import_module(f'extensions.multimodal.pipelines.{name}.pipelines')
|
pipeline_modules[name] = import_module(f'extensions.multimodal.pipelines.{name}.pipelines')
|
||||||
except:
|
except:
|
||||||
logging.warning(f'Failed to get multimodal pipelines from {name}')
|
logger.warning(f'Failed to get multimodal pipelines from {name}')
|
||||||
logging.warning(traceback.format_exc())
|
logger.warning(traceback.format_exc())
|
||||||
|
|
||||||
if shared.args.multimodal_pipeline is not None:
|
if shared.args.multimodal_pipeline is not None:
|
||||||
for k in pipeline_modules:
|
for k in pipeline_modules:
|
||||||
@ -48,5 +48,5 @@ def load_pipeline(params: dict) -> Tuple[AbstractMultimodalPipeline, str]:
|
|||||||
log = f'Multimodal - ERROR: Failed to load multimodal pipeline "{shared.args.multimodal_pipeline}", available pipelines are: {available}.'
|
log = f'Multimodal - ERROR: Failed to load multimodal pipeline "{shared.args.multimodal_pipeline}", available pipelines are: {available}.'
|
||||||
else:
|
else:
|
||||||
log = f'Multimodal - ERROR: Failed to determine multimodal pipeline for model {shared.args.model}, please select one manually using --multimodal-pipeline [PIPELINE]. Available pipelines are: {available}.'
|
log = f'Multimodal - ERROR: Failed to determine multimodal pipeline for model {shared.args.model}, please select one manually using --multimodal-pipeline [PIPELINE]. Available pipelines are: {available}.'
|
||||||
logging.critical(f'{log} Please specify a correct pipeline, or disable the extension')
|
logger.critical(f'{log} Please specify a correct pipeline, or disable the extension')
|
||||||
raise RuntimeError(f'{log} Please specify a correct pipeline, or disable the extension')
|
raise RuntimeError(f'{log} Please specify a correct pipeline, or disable the extension')
|
||||||
|
@ -1,16 +1,17 @@
|
|||||||
import logging
|
|
||||||
import time
|
import time
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
|
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from modules import shared
|
|
||||||
from modules.text_generation import encode
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import CLIPImageProcessor, CLIPVisionModel
|
from transformers import CLIPImageProcessor, CLIPVisionModel
|
||||||
|
|
||||||
|
from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
|
||||||
|
from modules import shared
|
||||||
|
from modules.logging_colors import logger
|
||||||
|
from modules.text_generation import encode
|
||||||
|
|
||||||
|
|
||||||
class LLaVA_v0_Pipeline(AbstractMultimodalPipeline):
|
class LLaVA_v0_Pipeline(AbstractMultimodalPipeline):
|
||||||
CLIP_REPO = "openai/clip-vit-large-patch14"
|
CLIP_REPO = "openai/clip-vit-large-patch14"
|
||||||
@ -26,11 +27,11 @@ class LLaVA_v0_Pipeline(AbstractMultimodalPipeline):
|
|||||||
def _load_models(self):
|
def _load_models(self):
|
||||||
start_ts = time.time()
|
start_ts = time.time()
|
||||||
|
|
||||||
logging.info(f"LLaVA - Loading CLIP from {LLaVA_v0_Pipeline.CLIP_REPO} as {self.clip_dtype} on {self.clip_device}...")
|
logger.info(f"LLaVA - Loading CLIP from {LLaVA_v0_Pipeline.CLIP_REPO} as {self.clip_dtype} on {self.clip_device}...")
|
||||||
image_processor = CLIPImageProcessor.from_pretrained(LLaVA_v0_Pipeline.CLIP_REPO, torch_dtype=self.clip_dtype)
|
image_processor = CLIPImageProcessor.from_pretrained(LLaVA_v0_Pipeline.CLIP_REPO, torch_dtype=self.clip_dtype)
|
||||||
vision_tower = CLIPVisionModel.from_pretrained(LLaVA_v0_Pipeline.CLIP_REPO, torch_dtype=self.clip_dtype).to(self.clip_device)
|
vision_tower = CLIPVisionModel.from_pretrained(LLaVA_v0_Pipeline.CLIP_REPO, torch_dtype=self.clip_dtype).to(self.clip_device)
|
||||||
|
|
||||||
logging.info(f"LLaVA - Loading projector from {self.llava_projector_repo()} as {self.projector_dtype} on {self.projector_device}...")
|
logger.info(f"LLaVA - Loading projector from {self.llava_projector_repo()} as {self.projector_dtype} on {self.projector_device}...")
|
||||||
projector_path = hf_hub_download(self.llava_projector_repo(), self.llava_projector_filename())
|
projector_path = hf_hub_download(self.llava_projector_repo(), self.llava_projector_filename())
|
||||||
mm_projector = torch.nn.Linear(*self.llava_projector_shape())
|
mm_projector = torch.nn.Linear(*self.llava_projector_shape())
|
||||||
projector_data = torch.load(projector_path)
|
projector_data = torch.load(projector_path)
|
||||||
@ -38,7 +39,7 @@ class LLaVA_v0_Pipeline(AbstractMultimodalPipeline):
|
|||||||
mm_projector.bias = torch.nn.Parameter(projector_data['model.mm_projector.bias'].to(dtype=self.projector_dtype), False)
|
mm_projector.bias = torch.nn.Parameter(projector_data['model.mm_projector.bias'].to(dtype=self.projector_dtype), False)
|
||||||
mm_projector = mm_projector.to(self.projector_device)
|
mm_projector = mm_projector.to(self.projector_device)
|
||||||
|
|
||||||
logging.info(f"LLaVA supporting models loaded, took {time.time() - start_ts:.2f} seconds")
|
logger.info(f"LLaVA supporting models loaded, took {time.time() - start_ts:.2f} seconds")
|
||||||
return image_processor, vision_tower, mm_projector
|
return image_processor, vision_tower, mm_projector
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import base64
|
import base64
|
||||||
import logging
|
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from functools import partial
|
from functools import partial
|
||||||
@ -10,6 +9,7 @@ import torch
|
|||||||
|
|
||||||
from extensions.multimodal.multimodal_embedder import MultimodalEmbedder
|
from extensions.multimodal.multimodal_embedder import MultimodalEmbedder
|
||||||
from modules import shared
|
from modules import shared
|
||||||
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"add_all_images_to_prompt": False,
|
"add_all_images_to_prompt": False,
|
||||||
@ -78,7 +78,7 @@ def tokenizer_modifier(state, prompt, input_ids, input_embeds):
|
|||||||
return prompt, input_ids, input_embeds
|
return prompt, input_ids, input_embeds
|
||||||
|
|
||||||
prompt, input_ids, input_embeds, total_embedded = multimodal_embedder.forward(prompt, state, params)
|
prompt, input_ids, input_embeds, total_embedded = multimodal_embedder.forward(prompt, state, params)
|
||||||
logging.info(f'Embedded {total_embedded} image(s) in {time.time()-start_ts:.2f}s')
|
logger.info(f'Embedded {total_embedded} image(s) in {time.time()-start_ts:.2f}s')
|
||||||
return (prompt,
|
return (prompt,
|
||||||
input_ids.unsqueeze(0).to(shared.model.device, dtype=torch.int64),
|
input_ids.unsqueeze(0).to(shared.model.device, dtype=torch.int64),
|
||||||
input_embeds.unsqueeze(0).to(shared.model.device, dtype=shared.model.dtype))
|
input_embeds.unsqueeze(0).to(shared.model.device, dtype=shared.model.dtype))
|
||||||
|
@ -1,13 +1,12 @@
|
|||||||
import logging
|
import chromadb
|
||||||
|
|
||||||
import posthog
|
import posthog
|
||||||
import torch
|
import torch
|
||||||
|
from chromadb.config import Settings
|
||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
import chromadb
|
from modules.logging_colors import logger
|
||||||
from chromadb.config import Settings
|
|
||||||
|
|
||||||
logging.info('Intercepting all calls to posthog :)')
|
logger.info('Intercepting all calls to posthog :)')
|
||||||
posthog.capture = lambda *args, **kwargs: None
|
posthog.capture = lambda *args, **kwargs: None
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import logging
|
|
||||||
import re
|
import re
|
||||||
import textwrap
|
import textwrap
|
||||||
|
|
||||||
@ -6,6 +5,7 @@ import gradio as gr
|
|||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
|
|
||||||
from modules import chat, shared
|
from modules import chat, shared
|
||||||
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
from .chromadb import add_chunks_to_collector, make_collector
|
from .chromadb import add_chunks_to_collector, make_collector
|
||||||
from .download_urls import download_urls
|
from .download_urls import download_urls
|
||||||
@ -123,14 +123,14 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
|
|||||||
if shared.history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>':
|
if shared.history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>':
|
||||||
additional_context += make_single_exchange(id_)
|
additional_context += make_single_exchange(id_)
|
||||||
|
|
||||||
logging.warning(f'Adding the following new context:\n{additional_context}')
|
logger.warning(f'Adding the following new context:\n{additional_context}')
|
||||||
state['context'] = state['context'].strip() + '\n' + additional_context
|
state['context'] = state['context'].strip() + '\n' + additional_context
|
||||||
kwargs['history'] = {
|
kwargs['history'] = {
|
||||||
'internal': [shared.history['internal'][i] for i in range(hist_size) if i not in best_ids],
|
'internal': [shared.history['internal'][i] for i in range(hist_size) if i not in best_ids],
|
||||||
'visible': ''
|
'visible': ''
|
||||||
}
|
}
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
logging.error("Couldn't query the database, moving on...")
|
logger.error("Couldn't query the database, moving on...")
|
||||||
|
|
||||||
return chat.generate_chat_prompt(user_input, state, **kwargs)
|
return chat.generate_chat_prompt(user_input, state, **kwargs)
|
||||||
|
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
import logging
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from auto_gptq import AutoGPTQForCausalLM
|
from auto_gptq import AutoGPTQForCausalLM
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
from modules.logging_colors import logger
|
||||||
from modules.models import get_max_memory_dict
|
from modules.models import get_max_memory_dict
|
||||||
|
|
||||||
|
|
||||||
@ -17,13 +17,13 @@ def load_quantized(model_name):
|
|||||||
found = list(path_to_model.glob(f"*{ext}"))
|
found = list(path_to_model.glob(f"*{ext}"))
|
||||||
if len(found) > 0:
|
if len(found) > 0:
|
||||||
if len(found) > 1:
|
if len(found) > 1:
|
||||||
logging.warning(f'More than one {ext} model has been found. The last one will be selected. It could be wrong.')
|
logger.warning(f'More than one {ext} model has been found. The last one will be selected. It could be wrong.')
|
||||||
|
|
||||||
pt_path = found[-1]
|
pt_path = found[-1]
|
||||||
break
|
break
|
||||||
|
|
||||||
if pt_path is None:
|
if pt_path is None:
|
||||||
logging.error("The model could not be loaded because its checkpoint file in .bin/.pt/.safetensors format could not be located.")
|
logger.error("The model could not be loaded because its checkpoint file in .bin/.pt/.safetensors format could not be located.")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Define the params for AutoGPTQForCausalLM.from_quantized
|
# Define the params for AutoGPTQForCausalLM.from_quantized
|
||||||
@ -35,6 +35,6 @@ def load_quantized(model_name):
|
|||||||
'max_memory': get_max_memory_dict()
|
'max_memory': get_max_memory_dict()
|
||||||
}
|
}
|
||||||
|
|
||||||
logging.warning(f"The AutoGPTQ params are: {params}")
|
logger.warning(f"The AutoGPTQ params are: {params}")
|
||||||
model = AutoGPTQForCausalLM.from_quantized(path_to_model, **params)
|
model = AutoGPTQForCausalLM.from_quantized(path_to_model, **params)
|
||||||
return model
|
return model
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import inspect
|
import inspect
|
||||||
import logging
|
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -10,14 +9,15 @@ import transformers
|
|||||||
from transformers import AutoConfig, AutoModelForCausalLM
|
from transformers import AutoConfig, AutoModelForCausalLM
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
|
sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import llama_inference_offload
|
import llama_inference_offload
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logging.error('Failed to load GPTQ-for-LLaMa')
|
logger.error('Failed to load GPTQ-for-LLaMa')
|
||||||
logging.error('See https://github.com/oobabooga/text-generation-webui/blob/main/docs/GPTQ-models-(4-bit-mode).md')
|
logger.error('See https://github.com/oobabooga/text-generation-webui/blob/main/docs/GPTQ-models-(4-bit-mode).md')
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -127,7 +127,7 @@ def find_quantized_model_file(model_name):
|
|||||||
found = list(path_to_model.glob(f"*{ext}"))
|
found = list(path_to_model.glob(f"*{ext}"))
|
||||||
if len(found) > 0:
|
if len(found) > 0:
|
||||||
if len(found) > 1:
|
if len(found) > 1:
|
||||||
logging.warning(f'More than one {ext} model has been found. The last one will be selected. It could be wrong.')
|
logger.warning(f'More than one {ext} model has been found. The last one will be selected. It could be wrong.')
|
||||||
|
|
||||||
pt_path = found[-1]
|
pt_path = found[-1]
|
||||||
break
|
break
|
||||||
@ -138,8 +138,8 @@ def find_quantized_model_file(model_name):
|
|||||||
# The function that loads the model in modules/models.py
|
# The function that loads the model in modules/models.py
|
||||||
def load_quantized(model_name):
|
def load_quantized(model_name):
|
||||||
if shared.args.model_type is None:
|
if shared.args.model_type is None:
|
||||||
logging.error("The model could not be loaded because its type could not be inferred from its name.")
|
logger.error("The model could not be loaded because its type could not be inferred from its name.")
|
||||||
logging.error("Please specify the type manually using the --model_type argument.")
|
logger.error("Please specify the type manually using the --model_type argument.")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Select the appropriate load_quant function
|
# Select the appropriate load_quant function
|
||||||
@ -148,21 +148,21 @@ def load_quantized(model_name):
|
|||||||
load_quant = llama_inference_offload.load_quant
|
load_quant = llama_inference_offload.load_quant
|
||||||
elif model_type in ('llama', 'opt', 'gptj'):
|
elif model_type in ('llama', 'opt', 'gptj'):
|
||||||
if shared.args.pre_layer:
|
if shared.args.pre_layer:
|
||||||
logging.warning("Ignoring --pre_layer because it only works for llama model type.")
|
logger.warning("Ignoring --pre_layer because it only works for llama model type.")
|
||||||
|
|
||||||
load_quant = _load_quant
|
load_quant = _load_quant
|
||||||
else:
|
else:
|
||||||
logging.error("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
|
logger.error("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
|
||||||
exit()
|
exit()
|
||||||
|
|
||||||
# Find the quantized model weights file (.pt/.safetensors)
|
# Find the quantized model weights file (.pt/.safetensors)
|
||||||
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
|
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
|
||||||
pt_path = find_quantized_model_file(model_name)
|
pt_path = find_quantized_model_file(model_name)
|
||||||
if not pt_path:
|
if not pt_path:
|
||||||
logging.error("Could not find the quantized model in .pt or .safetensors format, exiting...")
|
logger.error("Could not find the quantized model in .pt or .safetensors format, exiting...")
|
||||||
exit()
|
exit()
|
||||||
else:
|
else:
|
||||||
logging.info(f"Found the following quantized model: {pt_path}")
|
logger.info(f"Found the following quantized model: {pt_path}")
|
||||||
|
|
||||||
# qwopqwop200's offload
|
# qwopqwop200's offload
|
||||||
if model_type == 'llama' and shared.args.pre_layer:
|
if model_type == 'llama' and shared.args.pre_layer:
|
||||||
@ -190,7 +190,7 @@ def load_quantized(model_name):
|
|||||||
max_memory = accelerate.utils.get_balanced_memory(model)
|
max_memory = accelerate.utils.get_balanced_memory(model)
|
||||||
|
|
||||||
device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"])
|
device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"])
|
||||||
logging.info("Using the following device map for the quantized model:", device_map)
|
logger.info("Using the following device map for the quantized model:", device_map)
|
||||||
# https://huggingface.co/docs/accelerate/package_reference/big_modeling#accelerate.dispatch_model
|
# https://huggingface.co/docs/accelerate/package_reference/big_modeling#accelerate.dispatch_model
|
||||||
model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True)
|
model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True)
|
||||||
|
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
import logging
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from peft import PeftModel
|
from peft import PeftModel
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
|
|
||||||
def add_lora_to_model(lora_names):
|
def add_lora_to_model(lora_names):
|
||||||
@ -19,7 +19,7 @@ def add_lora_to_model(lora_names):
|
|||||||
|
|
||||||
# Add a LoRA when another LoRA is already present
|
# Add a LoRA when another LoRA is already present
|
||||||
if len(removed_set) == 0 and len(prior_set) > 0:
|
if len(removed_set) == 0 and len(prior_set) > 0:
|
||||||
logging.info(f"Adding the LoRA(s) named {added_set} to the model...")
|
logger.info(f"Adding the LoRA(s) named {added_set} to the model...")
|
||||||
for lora in added_set:
|
for lora in added_set:
|
||||||
shared.model.load_adapter(Path(f"{shared.args.lora_dir}/{lora}"), lora)
|
shared.model.load_adapter(Path(f"{shared.args.lora_dir}/{lora}"), lora)
|
||||||
|
|
||||||
@ -31,7 +31,7 @@ def add_lora_to_model(lora_names):
|
|||||||
shared.model = shared.model.base_model.model
|
shared.model = shared.model.base_model.model
|
||||||
|
|
||||||
if len(lora_names) > 0:
|
if len(lora_names) > 0:
|
||||||
logging.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names)))
|
logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names)))
|
||||||
params = {}
|
params = {}
|
||||||
if not shared.args.cpu:
|
if not shared.args.cpu:
|
||||||
params['dtype'] = shared.model.dtype
|
params['dtype'] = shared.model.dtype
|
||||||
|
@ -3,7 +3,6 @@ import base64
|
|||||||
import copy
|
import copy
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import re
|
import re
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -14,6 +13,7 @@ from PIL import Image
|
|||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules.extensions import apply_extensions
|
from modules.extensions import apply_extensions
|
||||||
from modules.html_generator import chat_html_wrapper, make_thumbnail
|
from modules.html_generator import chat_html_wrapper, make_thumbnail
|
||||||
|
from modules.logging_colors import logger
|
||||||
from modules.text_generation import (generate_reply, get_encoded_length,
|
from modules.text_generation import (generate_reply, get_encoded_length,
|
||||||
get_max_prompt_length)
|
get_max_prompt_length)
|
||||||
from modules.utils import replace_all
|
from modules.utils import replace_all
|
||||||
@ -187,7 +187,7 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
|
|||||||
output = copy.deepcopy(history)
|
output = copy.deepcopy(history)
|
||||||
output = apply_extensions('history', output)
|
output = apply_extensions('history', output)
|
||||||
if shared.model_name == 'None' or shared.model is None:
|
if shared.model_name == 'None' or shared.model is None:
|
||||||
logging.error("No model is loaded! Select one in the Model tab.")
|
logger.error("No model is loaded! Select one in the Model tab.")
|
||||||
yield output
|
yield output
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -278,7 +278,7 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
|
|||||||
|
|
||||||
def impersonate_wrapper(text, state):
|
def impersonate_wrapper(text, state):
|
||||||
if shared.model_name == 'None' or shared.model is None:
|
if shared.model_name == 'None' or shared.model is None:
|
||||||
logging.error("No model is loaded! Select one in the Model tab.")
|
logger.error("No model is loaded! Select one in the Model tab.")
|
||||||
yield ''
|
yield ''
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -584,7 +584,7 @@ def upload_character(json_file, img, tavern=False):
|
|||||||
img = Image.open(io.BytesIO(img))
|
img = Image.open(io.BytesIO(img))
|
||||||
img.save(Path(f'characters/{outfile_name}.png'))
|
img.save(Path(f'characters/{outfile_name}.png'))
|
||||||
|
|
||||||
logging.info(f'New character saved to "characters/{outfile_name}.json".')
|
logger.info(f'New character saved to "characters/{outfile_name}.json".')
|
||||||
return outfile_name
|
return outfile_name
|
||||||
|
|
||||||
|
|
||||||
@ -608,18 +608,18 @@ def upload_your_profile_picture(img):
|
|||||||
else:
|
else:
|
||||||
img = make_thumbnail(img)
|
img = make_thumbnail(img)
|
||||||
img.save(Path('cache/pfp_me.png'))
|
img.save(Path('cache/pfp_me.png'))
|
||||||
logging.info('Profile picture saved to "cache/pfp_me.png"')
|
logger.info('Profile picture saved to "cache/pfp_me.png"')
|
||||||
|
|
||||||
|
|
||||||
def delete_file(path):
|
def delete_file(path):
|
||||||
if path.exists():
|
if path.exists():
|
||||||
logging.warning(f'Deleting {path}')
|
logger.warning(f'Deleting {path}')
|
||||||
path.unlink(missing_ok=True)
|
path.unlink(missing_ok=True)
|
||||||
|
|
||||||
|
|
||||||
def save_character(name, greeting, context, picture, filename, instruct=False):
|
def save_character(name, greeting, context, picture, filename, instruct=False):
|
||||||
if filename == "":
|
if filename == "":
|
||||||
logging.error("The filename is empty, so the character will not be saved.")
|
logger.error("The filename is empty, so the character will not be saved.")
|
||||||
return
|
return
|
||||||
|
|
||||||
folder = 'characters' if not instruct else 'characters/instruction-following'
|
folder = 'characters' if not instruct else 'characters/instruction-following'
|
||||||
@ -634,11 +634,11 @@ def save_character(name, greeting, context, picture, filename, instruct=False):
|
|||||||
with filepath.open('w') as f:
|
with filepath.open('w') as f:
|
||||||
yaml.dump(data, f)
|
yaml.dump(data, f)
|
||||||
|
|
||||||
logging.info(f'Wrote {filepath}')
|
logger.info(f'Wrote {filepath}')
|
||||||
path_to_img = Path(f'{folder}/{filename}.png')
|
path_to_img = Path(f'{folder}/{filename}.png')
|
||||||
if picture and not instruct:
|
if picture and not instruct:
|
||||||
picture.save(path_to_img)
|
picture.save(path_to_img)
|
||||||
logging.info(f'Wrote {path_to_img}')
|
logger.info(f'Wrote {path_to_img}')
|
||||||
elif path_to_img.exists():
|
elif path_to_img.exists():
|
||||||
delete_file(path_to_img)
|
delete_file(path_to_img)
|
||||||
|
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import logging
|
|
||||||
import traceback
|
import traceback
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
@ -6,6 +5,7 @@ import gradio as gr
|
|||||||
|
|
||||||
import extensions
|
import extensions
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
state = {}
|
state = {}
|
||||||
available_extensions = []
|
available_extensions = []
|
||||||
@ -29,7 +29,7 @@ def load_extensions():
|
|||||||
for i, name in enumerate(shared.args.extensions):
|
for i, name in enumerate(shared.args.extensions):
|
||||||
if name in available_extensions:
|
if name in available_extensions:
|
||||||
if name != 'api':
|
if name != 'api':
|
||||||
logging.info(f'Loading the extension "{name}"...')
|
logger.info(f'Loading the extension "{name}"...')
|
||||||
try:
|
try:
|
||||||
exec(f"import extensions.{name}.script")
|
exec(f"import extensions.{name}.script")
|
||||||
extension = getattr(extensions, name).script
|
extension = getattr(extensions, name).script
|
||||||
@ -40,7 +40,7 @@ def load_extensions():
|
|||||||
|
|
||||||
state[name] = [True, i]
|
state[name] = [True, i]
|
||||||
except:
|
except:
|
||||||
logging.error(f'Failed to load the extension "{name}".')
|
logger.error(f'Failed to load the extension "{name}".')
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import logging
|
|
||||||
import math
|
import math
|
||||||
import sys
|
import sys
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
@ -8,21 +7,22 @@ import torch.nn as nn
|
|||||||
import transformers.models.llama.modeling_llama
|
import transformers.models.llama.modeling_llama
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
if shared.args.xformers:
|
if shared.args.xformers:
|
||||||
try:
|
try:
|
||||||
import xformers.ops
|
import xformers.ops
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.error("xformers not found! Please install it before trying to use it.", file=sys.stderr)
|
logger.error("xformers not found! Please install it before trying to use it.", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
def hijack_llama_attention():
|
def hijack_llama_attention():
|
||||||
if shared.args.xformers:
|
if shared.args.xformers:
|
||||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
|
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
|
||||||
logging.info("Replaced attention with xformers_attention")
|
logger.info("Replaced attention with xformers_attention")
|
||||||
elif shared.args.sdp_attention:
|
elif shared.args.sdp_attention:
|
||||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = sdp_attention_forward
|
transformers.models.llama.modeling_llama.LlamaAttention.forward = sdp_attention_forward
|
||||||
logging.info("Replaced attention with sdp_attention")
|
logger.info("Replaced attention with sdp_attention")
|
||||||
|
|
||||||
|
|
||||||
def xformers_forward(
|
def xformers_forward(
|
||||||
|
@ -6,13 +6,13 @@ Documentation:
|
|||||||
https://abetlen.github.io/llama-cpp-python/
|
https://abetlen.github.io/llama-cpp-python/
|
||||||
'''
|
'''
|
||||||
|
|
||||||
import logging
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from llama_cpp import Llama, LlamaCache
|
from llama_cpp import Llama, LlamaCache
|
||||||
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.callbacks import Iteratorize
|
from modules.callbacks import Iteratorize
|
||||||
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
|
|
||||||
class LlamaCppModel:
|
class LlamaCppModel:
|
||||||
@ -35,7 +35,7 @@ class LlamaCppModel:
|
|||||||
else:
|
else:
|
||||||
cache_capacity = int(shared.args.cache_capacity)
|
cache_capacity = int(shared.args.cache_capacity)
|
||||||
|
|
||||||
logging.info("Cache capacity is " + str(cache_capacity) + " bytes")
|
logger.info("Cache capacity is " + str(cache_capacity) + " bytes")
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'model_path': str(path),
|
'model_path': str(path),
|
||||||
|
@ -3,6 +3,8 @@
|
|||||||
import logging
|
import logging
|
||||||
import platform
|
import platform
|
||||||
|
|
||||||
|
logging.basicConfig(format='%(levelname)s:%(message)s')
|
||||||
|
|
||||||
|
|
||||||
def add_coloring_to_emit_windows(fn):
|
def add_coloring_to_emit_windows(fn):
|
||||||
# add methods we need to the class
|
# add methods we need to the class
|
||||||
@ -107,3 +109,6 @@ else:
|
|||||||
# log.addFilter(log_filter())
|
# log.addFilter(log_filter())
|
||||||
# //hdlr = logging.StreamHandler()
|
# //hdlr = logging.StreamHandler()
|
||||||
# //hdlr.setFormatter(formatter())
|
# //hdlr.setFormatter(formatter())
|
||||||
|
|
||||||
|
logger = logging.getLogger('text-generation-webui')
|
||||||
|
logger.setLevel(logging.DEBUG)
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import gc
|
import gc
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
@ -17,6 +16,7 @@ from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
|
|||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules import llama_attn_hijack
|
from modules import llama_attn_hijack
|
||||||
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
|
|
||||||
@ -71,12 +71,12 @@ def find_model_type(model_name):
|
|||||||
|
|
||||||
|
|
||||||
def load_model(model_name):
|
def load_model(model_name):
|
||||||
logging.info(f"Loading {model_name}...")
|
logger.info(f"Loading {model_name}...")
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
|
|
||||||
shared.model_type = find_model_type(model_name)
|
shared.model_type = find_model_type(model_name)
|
||||||
if shared.model_type == 'None':
|
if shared.model_type == 'None':
|
||||||
logging.error('The path to the model does not exist. Exiting.')
|
logger.error('The path to the model does not exist. Exiting.')
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
if shared.args.autogptq:
|
if shared.args.autogptq:
|
||||||
@ -106,7 +106,7 @@ def load_model(model_name):
|
|||||||
if any((shared.args.xformers, shared.args.sdp_attention)):
|
if any((shared.args.xformers, shared.args.sdp_attention)):
|
||||||
llama_attn_hijack.hijack_llama_attention()
|
llama_attn_hijack.hijack_llama_attention()
|
||||||
|
|
||||||
logging.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.\n")
|
logger.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.\n")
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
@ -119,7 +119,7 @@ def load_tokenizer(model_name, model):
|
|||||||
if shared.model_type not in ['llava', 'oasst']:
|
if shared.model_type not in ['llava', 'oasst']:
|
||||||
for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
|
for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
|
||||||
if p.exists():
|
if p.exists():
|
||||||
logging.info(f"Loading the universal LLaMA tokenizer from {p}...")
|
logger.info(f"Loading the universal LLaMA tokenizer from {p}...")
|
||||||
tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
|
tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
|
||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
@ -162,7 +162,7 @@ def huggingface_loader(model_name):
|
|||||||
model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
|
model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
|
||||||
model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
|
model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
|
||||||
model.module.eval() # Inference
|
model.module.eval() # Inference
|
||||||
logging.info(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
|
logger.info(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
|
||||||
|
|
||||||
# Custom
|
# Custom
|
||||||
else:
|
else:
|
||||||
@ -172,7 +172,7 @@ def huggingface_loader(model_name):
|
|||||||
}
|
}
|
||||||
|
|
||||||
if not any((shared.args.cpu, torch.cuda.is_available(), torch.has_mps)):
|
if not any((shared.args.cpu, torch.cuda.is_available(), torch.has_mps)):
|
||||||
logging.warning("torch.cuda.is_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.")
|
logger.warning("torch.cuda.is_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.")
|
||||||
shared.args.cpu = True
|
shared.args.cpu = True
|
||||||
|
|
||||||
if shared.args.cpu:
|
if shared.args.cpu:
|
||||||
@ -254,7 +254,7 @@ def llamacpp_loader(model_name):
|
|||||||
else:
|
else:
|
||||||
model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))[0]
|
model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))[0]
|
||||||
|
|
||||||
logging.info(f"llama.cpp weights detected: {model_file}\n")
|
logger.info(f"llama.cpp weights detected: {model_file}\n")
|
||||||
model, tokenizer = LlamaCppModel.from_pretrained(model_file)
|
model, tokenizer = LlamaCppModel.from_pretrained(model_file)
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
@ -263,7 +263,7 @@ def GPTQ_loader(model_name):
|
|||||||
|
|
||||||
# Monkey patch
|
# Monkey patch
|
||||||
if shared.args.monkey_patch:
|
if shared.args.monkey_patch:
|
||||||
logging.warning("Applying the monkey patch for using LoRAs in 4-bit mode. It may cause undefined behavior outside its intended scope.")
|
logger.warning("Applying the monkey patch for using LoRAs in 4-bit mode. It may cause undefined behavior outside its intended scope.")
|
||||||
from modules.monkey_patch_gptq_lora import load_model_llama
|
from modules.monkey_patch_gptq_lora import load_model_llama
|
||||||
|
|
||||||
model, _ = load_model_llama(model_name)
|
model, _ = load_model_llama(model_name)
|
||||||
@ -302,7 +302,7 @@ def get_max_memory_dict():
|
|||||||
suggestion -= 1000
|
suggestion -= 1000
|
||||||
|
|
||||||
suggestion = int(round(suggestion / 1000))
|
suggestion = int(round(suggestion / 1000))
|
||||||
logging.warning(f"Auto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors. You can manually set other values.")
|
logger.warning(f"Auto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors. You can manually set other values.")
|
||||||
max_memory = {0: f'{suggestion}GiB', 'cpu': f'{shared.args.cpu_memory or 99}GiB'}
|
max_memory = {0: f'{suggestion}GiB', 'cpu': f'{shared.args.cpu_memory or 99}GiB'}
|
||||||
|
|
||||||
return max_memory if len(max_memory) > 0 else None
|
return max_memory if len(max_memory) > 0 else None
|
||||||
@ -333,13 +333,13 @@ def load_soft_prompt(name):
|
|||||||
zf.extract('tensor.npy')
|
zf.extract('tensor.npy')
|
||||||
zf.extract('meta.json')
|
zf.extract('meta.json')
|
||||||
j = json.loads(open('meta.json', 'r').read())
|
j = json.loads(open('meta.json', 'r').read())
|
||||||
logging.info(f"\nLoading the softprompt \"{name}\".")
|
logger.info(f"\nLoading the softprompt \"{name}\".")
|
||||||
for field in j:
|
for field in j:
|
||||||
if field != 'name':
|
if field != 'name':
|
||||||
if type(j[field]) is list:
|
if type(j[field]) is list:
|
||||||
logging.info(f"{field}: {', '.join(j[field])}")
|
logger.info(f"{field}: {', '.join(j[field])}")
|
||||||
else:
|
else:
|
||||||
logging.info(f"{field}: {j[field]}")
|
logger.info(f"{field}: {j[field]}")
|
||||||
|
|
||||||
tensor = np.load('tensor.npy')
|
tensor = np.load('tensor.npy')
|
||||||
Path('tensor.npy').unlink()
|
Path('tensor.npy').unlink()
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
model = None
|
model = None
|
||||||
tokenizer = None
|
tokenizer = None
|
||||||
model_name = "None"
|
model_name = "None"
|
||||||
@ -180,14 +181,14 @@ args_defaults = parser.parse_args([])
|
|||||||
deprecated_dict = {}
|
deprecated_dict = {}
|
||||||
for k in deprecated_dict:
|
for k in deprecated_dict:
|
||||||
if getattr(args, k) != deprecated_dict[k][1]:
|
if getattr(args, k) != deprecated_dict[k][1]:
|
||||||
logging.warning(f"--{k} is deprecated and will be removed. Use --{deprecated_dict[k][0]} instead.")
|
logger.warning(f"--{k} is deprecated and will be removed. Use --{deprecated_dict[k][0]} instead.")
|
||||||
setattr(args, deprecated_dict[k][0], getattr(args, k))
|
setattr(args, deprecated_dict[k][0], getattr(args, k))
|
||||||
|
|
||||||
# Security warnings
|
# Security warnings
|
||||||
if args.trust_remote_code:
|
if args.trust_remote_code:
|
||||||
logging.warning("trust_remote_code is enabled. This is dangerous.")
|
logger.warning("trust_remote_code is enabled. This is dangerous.")
|
||||||
if args.share:
|
if args.share:
|
||||||
logging.warning("The gradio \"share link\" feature downloads a proprietary and unaudited blob to create a reverse tunnel. This is potentially dangerous.")
|
logger.warning("The gradio \"share link\" feature downloads a proprietary and unaudited blob to create a reverse tunnel. This is potentially dangerous.")
|
||||||
|
|
||||||
|
|
||||||
def add_extension(name):
|
def add_extension(name):
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import ast
|
import ast
|
||||||
import logging
|
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
@ -14,6 +13,7 @@ from modules.callbacks import (Iteratorize, Stream,
|
|||||||
_SentinelTokenStoppingCriteria)
|
_SentinelTokenStoppingCriteria)
|
||||||
from modules.extensions import apply_extensions
|
from modules.extensions import apply_extensions
|
||||||
from modules.html_generator import generate_4chan_html, generate_basic_html
|
from modules.html_generator import generate_4chan_html, generate_basic_html
|
||||||
|
from modules.logging_colors import logger
|
||||||
from modules.models import clear_torch_cache, local_rank
|
from modules.models import clear_torch_cache, local_rank
|
||||||
|
|
||||||
|
|
||||||
@ -159,7 +159,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=None, is_ch
|
|||||||
generate_func = apply_extensions('custom_generate_reply')
|
generate_func = apply_extensions('custom_generate_reply')
|
||||||
if generate_func is None:
|
if generate_func is None:
|
||||||
if shared.model_name == 'None' or shared.model is None:
|
if shared.model_name == 'None' or shared.model is None:
|
||||||
logging.error("No model is loaded! Select one in the Model tab.")
|
logger.error("No model is loaded! Select one in the Model tab.")
|
||||||
yield question
|
yield question
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import math
|
import math
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
@ -15,8 +14,9 @@ from peft import (LoraConfig, get_peft_model, prepare_model_for_int8_training,
|
|||||||
set_peft_model_state_dict)
|
set_peft_model_state_dict)
|
||||||
|
|
||||||
from modules import shared, ui, utils
|
from modules import shared, ui, utils
|
||||||
from modules.evaluate import calculate_perplexity, generate_markdown_table, save_past_evaluations
|
from modules.evaluate import (calculate_perplexity, generate_markdown_table,
|
||||||
|
save_past_evaluations)
|
||||||
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
# This mapping is from a very recent commit, not yet released.
|
# This mapping is from a very recent commit, not yet released.
|
||||||
# If not available, default to a backup map for some common model types.
|
# If not available, default to a backup map for some common model types.
|
||||||
@ -24,7 +24,8 @@ try:
|
|||||||
from peft.utils.other import \
|
from peft.utils.other import \
|
||||||
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as \
|
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as \
|
||||||
model_to_lora_modules
|
model_to_lora_modules
|
||||||
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
from transformers.models.auto.modeling_auto import \
|
||||||
|
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
||||||
MODEL_CLASSES = {v: k for k, v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES}
|
MODEL_CLASSES = {v: k for k, v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES}
|
||||||
except:
|
except:
|
||||||
standard_modules = ["q_proj", "v_proj"]
|
standard_modules = ["q_proj", "v_proj"]
|
||||||
@ -217,13 +218,13 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
if model_type == "PeftModelForCausalLM":
|
if model_type == "PeftModelForCausalLM":
|
||||||
if len(shared.args.lora_names) > 0:
|
if len(shared.args.lora_names) > 0:
|
||||||
yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
||||||
logging.warning("Training LoRA over top of another LoRA. May have unexpected effects.")
|
logger.warning("Training LoRA over top of another LoRA. May have unexpected effects.")
|
||||||
else:
|
else:
|
||||||
yield "Model ID not matched due to LoRA loading. Consider reloading base model. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
yield "Model ID not matched due to LoRA loading. Consider reloading base model. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
||||||
logging.warning("Model ID not matched due to LoRA loading. Consider reloading base model.")
|
logger.warning("Model ID not matched due to LoRA loading. Consider reloading base model.")
|
||||||
else:
|
else:
|
||||||
yield "LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
yield "LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
||||||
logging.warning(f"LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. (Found model type: {model_type})")
|
logger.warning(f"LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. (Found model type: {model_type})")
|
||||||
|
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
|
|
||||||
@ -233,7 +234,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
|
|
||||||
elif not shared.args.load_in_8bit and shared.args.wbits <= 0:
|
elif not shared.args.load_in_8bit and shared.args.wbits <= 0:
|
||||||
yield "It is highly recommended you use `--load-in-8bit` for LoRA training. *(Will continue anyway in 2 seconds, press `Interrupt` to stop.)*"
|
yield "It is highly recommended you use `--load-in-8bit` for LoRA training. *(Will continue anyway in 2 seconds, press `Interrupt` to stop.)*"
|
||||||
logging.warning("It is highly recommended you use `--load-in-8bit` for LoRA training.")
|
logger.warning("It is highly recommended you use `--load-in-8bit` for LoRA training.")
|
||||||
time.sleep(2) # Give it a moment for the message to show in UI before continuing
|
time.sleep(2) # Give it a moment for the message to show in UI before continuing
|
||||||
|
|
||||||
if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
|
if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
|
||||||
@ -253,7 +254,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
|
|
||||||
# == Prep the dataset, format, etc ==
|
# == Prep the dataset, format, etc ==
|
||||||
if raw_text_file not in ['None', '']:
|
if raw_text_file not in ['None', '']:
|
||||||
logging.info("Loading raw text file dataset...")
|
logger.info("Loading raw text file dataset...")
|
||||||
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
|
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
|
||||||
raw_text = file.read().replace('\r', '')
|
raw_text = file.read().replace('\r', '')
|
||||||
|
|
||||||
@ -311,7 +312,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
prompt = generate_prompt(data_point)
|
prompt = generate_prompt(data_point)
|
||||||
return tokenize(prompt)
|
return tokenize(prompt)
|
||||||
|
|
||||||
logging.info("Loading JSON datasets...")
|
logger.info("Loading JSON datasets...")
|
||||||
data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
|
data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
|
||||||
train_data = data['train'].map(generate_and_tokenize_prompt)
|
train_data = data['train'].map(generate_and_tokenize_prompt)
|
||||||
|
|
||||||
@ -323,10 +324,10 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
|
|
||||||
# == Start prepping the model itself ==
|
# == Start prepping the model itself ==
|
||||||
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
|
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
|
||||||
logging.info("Getting model ready...")
|
logger.info("Getting model ready...")
|
||||||
prepare_model_for_int8_training(shared.model)
|
prepare_model_for_int8_training(shared.model)
|
||||||
|
|
||||||
logging.info("Prepping for training...")
|
logger.info("Prepping for training...")
|
||||||
config = LoraConfig(
|
config = LoraConfig(
|
||||||
r=lora_rank,
|
r=lora_rank,
|
||||||
lora_alpha=lora_alpha,
|
lora_alpha=lora_alpha,
|
||||||
@ -337,10 +338,10 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logging.info("Creating LoRA model...")
|
logger.info("Creating LoRA model...")
|
||||||
lora_model = get_peft_model(shared.model, config)
|
lora_model = get_peft_model(shared.model, config)
|
||||||
if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file():
|
if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file():
|
||||||
logging.info("Loading existing LoRA data...")
|
logger.info("Loading existing LoRA data...")
|
||||||
state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin")
|
state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin")
|
||||||
set_peft_model_state_dict(lora_model, state_dict_peft)
|
set_peft_model_state_dict(lora_model, state_dict_peft)
|
||||||
except:
|
except:
|
||||||
@ -418,7 +419,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
json.dump({x: vars[x] for x in PARAMETERS}, file)
|
json.dump({x: vars[x] for x in PARAMETERS}, file)
|
||||||
|
|
||||||
# == Main run and monitor loop ==
|
# == Main run and monitor loop ==
|
||||||
logging.info("Starting training...")
|
logger.info("Starting training...")
|
||||||
yield "Starting..."
|
yield "Starting..."
|
||||||
if WANT_INTERRUPT:
|
if WANT_INTERRUPT:
|
||||||
yield "Interrupted before start."
|
yield "Interrupted before start."
|
||||||
@ -428,7 +429,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
trainer.train()
|
trainer.train()
|
||||||
# Note: save in the thread in case the gradio thread breaks (eg browser closed)
|
# Note: save in the thread in case the gradio thread breaks (eg browser closed)
|
||||||
lora_model.save_pretrained(lora_file_path)
|
lora_model.save_pretrained(lora_file_path)
|
||||||
logging.info("LoRA training run is completed and saved.")
|
logger.info("LoRA training run is completed and saved.")
|
||||||
tracked.did_save = True
|
tracked.did_save = True
|
||||||
|
|
||||||
thread = threading.Thread(target=threaded_run)
|
thread = threading.Thread(target=threaded_run)
|
||||||
@ -460,14 +461,14 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
|
|
||||||
# Saving in the train thread might fail if an error occurs, so save here if so.
|
# Saving in the train thread might fail if an error occurs, so save here if so.
|
||||||
if not tracked.did_save:
|
if not tracked.did_save:
|
||||||
logging.info("Training complete, saving...")
|
logger.info("Training complete, saving...")
|
||||||
lora_model.save_pretrained(lora_file_path)
|
lora_model.save_pretrained(lora_file_path)
|
||||||
|
|
||||||
if WANT_INTERRUPT:
|
if WANT_INTERRUPT:
|
||||||
logging.info("Training interrupted.")
|
logger.info("Training interrupted.")
|
||||||
yield f"Interrupted. Incomplete LoRA saved to `{lora_file_path}`"
|
yield f"Interrupted. Incomplete LoRA saved to `{lora_file_path}`"
|
||||||
else:
|
else:
|
||||||
logging.info("Training complete!")
|
logger.info("Training complete!")
|
||||||
yield f"Done! LoRA saved to `{lora_file_path}`"
|
yield f"Done! LoRA saved to `{lora_file_path}`"
|
||||||
|
|
||||||
|
|
||||||
|
18
server.py
18
server.py
@ -1,17 +1,18 @@
|
|||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import requests
|
|
||||||
import warnings
|
import warnings
|
||||||
import modules.logging_colors
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
||||||
os.environ['BITSANDBYTES_NOWELCOME'] = '1'
|
os.environ['BITSANDBYTES_NOWELCOME'] = '1'
|
||||||
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
||||||
logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.INFO)
|
|
||||||
|
|
||||||
# This is a hack to prevent Gradio from phoning home when it gets imported
|
# This is a hack to prevent Gradio from phoning home when it gets imported
|
||||||
def my_get(url, **kwargs):
|
def my_get(url, **kwargs):
|
||||||
logging.info('Gradio HTTP request redirected to localhost :)')
|
logger.info('Gradio HTTP request redirected to localhost :)')
|
||||||
kwargs.setdefault('allow_redirects', True)
|
kwargs.setdefault('allow_redirects', True)
|
||||||
return requests.api.request('get', 'http://127.0.0.1/', **kwargs)
|
return requests.api.request('get', 'http://127.0.0.1/', **kwargs)
|
||||||
|
|
||||||
@ -49,7 +50,8 @@ from modules.extensions import apply_extensions
|
|||||||
from modules.html_generator import chat_html_wrapper
|
from modules.html_generator import chat_html_wrapper
|
||||||
from modules.LoRA import add_lora_to_model
|
from modules.LoRA import add_lora_to_model
|
||||||
from modules.models import load_model, load_soft_prompt, unload_model
|
from modules.models import load_model, load_soft_prompt, unload_model
|
||||||
from modules.text_generation import generate_reply_wrapper, get_encoded_length, stop_everything_event
|
from modules.text_generation import (generate_reply_wrapper,
|
||||||
|
get_encoded_length, stop_everything_event)
|
||||||
|
|
||||||
|
|
||||||
def load_model_wrapper(selected_model, autoload=False):
|
def load_model_wrapper(selected_model, autoload=False):
|
||||||
@ -971,7 +973,7 @@ if __name__ == "__main__":
|
|||||||
settings_file = Path('settings.json')
|
settings_file = Path('settings.json')
|
||||||
|
|
||||||
if settings_file is not None:
|
if settings_file is not None:
|
||||||
logging.info(f"Loading settings from {settings_file}...")
|
logger.info(f"Loading settings from {settings_file}...")
|
||||||
new_settings = json.loads(open(settings_file, 'r').read())
|
new_settings = json.loads(open(settings_file, 'r').read())
|
||||||
for item in new_settings:
|
for item in new_settings:
|
||||||
shared.settings[item] = new_settings[item]
|
shared.settings[item] = new_settings[item]
|
||||||
@ -1015,7 +1017,7 @@ if __name__ == "__main__":
|
|||||||
# Select the model from a command-line menu
|
# Select the model from a command-line menu
|
||||||
elif shared.args.model_menu:
|
elif shared.args.model_menu:
|
||||||
if len(available_models) == 0:
|
if len(available_models) == 0:
|
||||||
logging.error('No models are available! Please download at least one.')
|
logger.error('No models are available! Please download at least one.')
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
else:
|
else:
|
||||||
print('The following models are available:\n')
|
print('The following models are available:\n')
|
||||||
|
Loading…
Reference in New Issue
Block a user