Prevent unwanted log messages from modules

oobabooga 2023-05-21 22:42:34 -03:00
parent fb91406e93
commit e116d31180
20 changed files with 120 additions and 111 deletions
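
Every file below receives the same substitution: the stdlib `import logging` is dropped and the module logs through the shared, named logger defined in `modules/logging_colors.py`. The sketch below is illustrative only (it is not one of the changed files); it shows the repeated before/after pattern and why it silences other packages: the old code configured the root logger at INFO in server.py, so INFO messages from every imported library were printed, while the new code raises only the project's named logger to DEBUG and leaves the root logger at its default WARNING level.

```python
# Before: log calls went through the root logger, which server.py had
# configured with logging.basicConfig(level=logging.INFO), so INFO-level
# messages from third-party libraries were printed as well.
import logging

logging.info("Loading the extension...")

# After: every module imports one named logger; the root logger keeps its
# default WARNING level, so other packages' INFO/DEBUG messages stay hidden.
from modules.logging_colors import logger

logger.info("Loading the extension...")
```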

View File

@@ -1,8 +1,8 @@
-import logging
 import gradio as gr
+from modules.logging_colors import logger
 
 def ui():
     gr.Markdown("### This extension is deprecated, use \"multimodal\" extension instead")
-    logging.error("LLaVA extension is deprecated, use \"multimodal\" extension instead")
+    logger.error("LLaVA extension is deprecated, use \"multimodal\" extension instead")

View File

@@ -1,5 +1,4 @@
 import base64
-import logging
 import re
 from dataclasses import dataclass
 from io import BytesIO
@@ -10,6 +9,7 @@ from PIL import Image
 from extensions.multimodal.pipeline_loader import load_pipeline
 from modules import shared
+from modules.logging_colors import logger
 from modules.text_generation import encode, get_max_prompt_length
@@ -26,7 +26,7 @@ class MultimodalEmbedder:
     def __init__(self, params: dict):
         pipeline, source = load_pipeline(params)
         self.pipeline = pipeline
-        logging.info(f'Multimodal: loaded pipeline {self.pipeline.name()} from pipelines/{source} ({self.pipeline.__class__.__name__})')
+        logger.info(f'Multimodal: loaded pipeline {self.pipeline.name()} from pipelines/{source} ({self.pipeline.__class__.__name__})')
 
     def _split_prompt(self, prompt: str, load_images: bool = False) -> List[PromptPart]:
         """Splits a prompt into a list of `PromptParts` to separate image data from text.
@@ -138,7 +138,7 @@ class MultimodalEmbedder:
         # notify user if we truncated an image
         if removed_images > 0:
-            logging.warning(f"Multimodal: removed {removed_images} image(s) from prompt. Try decreasing max_new_tokens if generation is broken")
+            logger.warning(f"Multimodal: removed {removed_images} image(s) from prompt. Try decreasing max_new_tokens if generation is broken")
         return encoded

View File

@@ -1,4 +1,3 @@
-import logging
 import traceback
 from importlib import import_module
 from pathlib import Path
@@ -6,6 +5,7 @@ from typing import Tuple
 from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
 from modules import shared
+from modules.logging_colors import logger
 
 def _get_available_pipeline_modules():
@@ -21,8 +21,8 @@ def load_pipeline(params: dict) -> Tuple[AbstractMultimodalPipeline, str]:
         try:
             pipeline_modules[name] = import_module(f'extensions.multimodal.pipelines.{name}.pipelines')
         except:
-            logging.warning(f'Failed to get multimodal pipelines from {name}')
-            logging.warning(traceback.format_exc())
+            logger.warning(f'Failed to get multimodal pipelines from {name}')
+            logger.warning(traceback.format_exc())
 
     if shared.args.multimodal_pipeline is not None:
         for k in pipeline_modules:
@@ -48,5 +48,5 @@ def load_pipeline(params: dict) -> Tuple[AbstractMultimodalPipeline, str]:
         log = f'Multimodal - ERROR: Failed to load multimodal pipeline "{shared.args.multimodal_pipeline}", available pipelines are: {available}.'
     else:
         log = f'Multimodal - ERROR: Failed to determine multimodal pipeline for model {shared.args.model}, please select one manually using --multimodal-pipeline [PIPELINE]. Available pipelines are: {available}.'
-    logging.critical(f'{log} Please specify a correct pipeline, or disable the extension')
+    logger.critical(f'{log} Please specify a correct pipeline, or disable the extension')
     raise RuntimeError(f'{log} Please specify a correct pipeline, or disable the extension')

View File

@@ -1,16 +1,17 @@
-import logging
 import time
 from abc import abstractmethod
 from typing import List, Tuple
 
 import torch
-from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
 from huggingface_hub import hf_hub_download
-from modules import shared
-from modules.text_generation import encode
 from PIL import Image
 from transformers import CLIPImageProcessor, CLIPVisionModel
+from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
+from modules import shared
+from modules.logging_colors import logger
+from modules.text_generation import encode
 
 class LLaVA_v0_Pipeline(AbstractMultimodalPipeline):
     CLIP_REPO = "openai/clip-vit-large-patch14"
@@ -26,11 +27,11 @@ class LLaVA_v0_Pipeline(AbstractMultimodalPipeline):
     def _load_models(self):
         start_ts = time.time()
-        logging.info(f"LLaVA - Loading CLIP from {LLaVA_v0_Pipeline.CLIP_REPO} as {self.clip_dtype} on {self.clip_device}...")
+        logger.info(f"LLaVA - Loading CLIP from {LLaVA_v0_Pipeline.CLIP_REPO} as {self.clip_dtype} on {self.clip_device}...")
         image_processor = CLIPImageProcessor.from_pretrained(LLaVA_v0_Pipeline.CLIP_REPO, torch_dtype=self.clip_dtype)
         vision_tower = CLIPVisionModel.from_pretrained(LLaVA_v0_Pipeline.CLIP_REPO, torch_dtype=self.clip_dtype).to(self.clip_device)
-        logging.info(f"LLaVA - Loading projector from {self.llava_projector_repo()} as {self.projector_dtype} on {self.projector_device}...")
+        logger.info(f"LLaVA - Loading projector from {self.llava_projector_repo()} as {self.projector_dtype} on {self.projector_device}...")
         projector_path = hf_hub_download(self.llava_projector_repo(), self.llava_projector_filename())
         mm_projector = torch.nn.Linear(*self.llava_projector_shape())
         projector_data = torch.load(projector_path)
@@ -38,7 +39,7 @@ class LLaVA_v0_Pipeline(AbstractMultimodalPipeline):
         mm_projector.bias = torch.nn.Parameter(projector_data['model.mm_projector.bias'].to(dtype=self.projector_dtype), False)
         mm_projector = mm_projector.to(self.projector_device)
-        logging.info(f"LLaVA supporting models loaded, took {time.time() - start_ts:.2f} seconds")
+        logger.info(f"LLaVA supporting models loaded, took {time.time() - start_ts:.2f} seconds")
         return image_processor, vision_tower, mm_projector
 
     @staticmethod

View File

@@ -1,5 +1,4 @@
 import base64
-import logging
 import re
 import time
 from functools import partial
@@ -10,6 +9,7 @@ import torch
 from extensions.multimodal.multimodal_embedder import MultimodalEmbedder
 from modules import shared
+from modules.logging_colors import logger
 
 params = {
     "add_all_images_to_prompt": False,
@@ -78,7 +78,7 @@ def tokenizer_modifier(state, prompt, input_ids, input_embeds):
         return prompt, input_ids, input_embeds
 
     prompt, input_ids, input_embeds, total_embedded = multimodal_embedder.forward(prompt, state, params)
-    logging.info(f'Embedded {total_embedded} image(s) in {time.time()-start_ts:.2f}s')
+    logger.info(f'Embedded {total_embedded} image(s) in {time.time()-start_ts:.2f}s')
     return (prompt,
             input_ids.unsqueeze(0).to(shared.model.device, dtype=torch.int64),
             input_embeds.unsqueeze(0).to(shared.model.device, dtype=shared.model.dtype))

View File

@@ -1,13 +1,12 @@
-import logging
+import chromadb
 import posthog
 import torch
+from chromadb.config import Settings
 from sentence_transformers import SentenceTransformer
-import chromadb
-from chromadb.config import Settings
+from modules.logging_colors import logger
 
-logging.info('Intercepting all calls to posthog :)')
+logger.info('Intercepting all calls to posthog :)')
 posthog.capture = lambda *args, **kwargs: None

View File

@@ -1,4 +1,3 @@
-import logging
 import re
 import textwrap
@@ -6,6 +5,7 @@ import gradio as gr
 from bs4 import BeautifulSoup
 from modules import chat, shared
+from modules.logging_colors import logger
 
 from .chromadb import add_chunks_to_collector, make_collector
 from .download_urls import download_urls
@@ -123,14 +123,14 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
                 if shared.history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>':
                     additional_context += make_single_exchange(id_)
 
-            logging.warning(f'Adding the following new context:\n{additional_context}')
+            logger.warning(f'Adding the following new context:\n{additional_context}')
             state['context'] = state['context'].strip() + '\n' + additional_context
             kwargs['history'] = {
                 'internal': [shared.history['internal'][i] for i in range(hist_size) if i not in best_ids],
                 'visible': ''
             }
         except RuntimeError:
-            logging.error("Couldn't query the database, moving on...")
+            logger.error("Couldn't query the database, moving on...")
 
     return chat.generate_chat_prompt(user_input, state, **kwargs)

View File

@@ -1,9 +1,9 @@
-import logging
 from pathlib import Path
 from auto_gptq import AutoGPTQForCausalLM
 import modules.shared as shared
+from modules.logging_colors import logger
 from modules.models import get_max_memory_dict
@@ -17,13 +17,13 @@ def load_quantized(model_name):
         found = list(path_to_model.glob(f"*{ext}"))
         if len(found) > 0:
             if len(found) > 1:
-                logging.warning(f'More than one {ext} model has been found. The last one will be selected. It could be wrong.')
+                logger.warning(f'More than one {ext} model has been found. The last one will be selected. It could be wrong.')
 
             pt_path = found[-1]
             break
 
     if pt_path is None:
-        logging.error("The model could not be loaded because its checkpoint file in .bin/.pt/.safetensors format could not be located.")
+        logger.error("The model could not be loaded because its checkpoint file in .bin/.pt/.safetensors format could not be located.")
         return
 
     # Define the params for AutoGPTQForCausalLM.from_quantized
@@ -35,6 +35,6 @@ def load_quantized(model_name):
         'max_memory': get_max_memory_dict()
     }
 
-    logging.warning(f"The AutoGPTQ params are: {params}")
+    logger.warning(f"The AutoGPTQ params are: {params}")
     model = AutoGPTQForCausalLM.from_quantized(path_to_model, **params)
     return model

View File

@@ -1,5 +1,4 @@
 import inspect
-import logging
 import re
 import sys
 from pathlib import Path
@@ -10,14 +9,15 @@ import transformers
 from transformers import AutoConfig, AutoModelForCausalLM
 
 import modules.shared as shared
+from modules.logging_colors import logger
 
 sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
 
 try:
     import llama_inference_offload
 except ImportError:
-    logging.error('Failed to load GPTQ-for-LLaMa')
-    logging.error('See https://github.com/oobabooga/text-generation-webui/blob/main/docs/GPTQ-models-(4-bit-mode).md')
+    logger.error('Failed to load GPTQ-for-LLaMa')
+    logger.error('See https://github.com/oobabooga/text-generation-webui/blob/main/docs/GPTQ-models-(4-bit-mode).md')
     sys.exit(-1)
 
 try:
@@ -127,7 +127,7 @@ def find_quantized_model_file(model_name):
         found = list(path_to_model.glob(f"*{ext}"))
         if len(found) > 0:
             if len(found) > 1:
-                logging.warning(f'More than one {ext} model has been found. The last one will be selected. It could be wrong.')
+                logger.warning(f'More than one {ext} model has been found. The last one will be selected. It could be wrong.')
 
             pt_path = found[-1]
             break
@@ -138,8 +138,8 @@ def find_quantized_model_file(model_name):
 # The function that loads the model in modules/models.py
 def load_quantized(model_name):
     if shared.args.model_type is None:
-        logging.error("The model could not be loaded because its type could not be inferred from its name.")
-        logging.error("Please specify the type manually using the --model_type argument.")
+        logger.error("The model could not be loaded because its type could not be inferred from its name.")
+        logger.error("Please specify the type manually using the --model_type argument.")
         return None
 
     # Select the appropriate load_quant function
@@ -148,21 +148,21 @@ def load_quantized(model_name):
         load_quant = llama_inference_offload.load_quant
     elif model_type in ('llama', 'opt', 'gptj'):
         if shared.args.pre_layer:
-            logging.warning("Ignoring --pre_layer because it only works for llama model type.")
+            logger.warning("Ignoring --pre_layer because it only works for llama model type.")
 
         load_quant = _load_quant
     else:
-        logging.error("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
+        logger.error("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
         exit()
 
     # Find the quantized model weights file (.pt/.safetensors)
     path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
    pt_path = find_quantized_model_file(model_name)
     if not pt_path:
-        logging.error("Could not find the quantized model in .pt or .safetensors format, exiting...")
+        logger.error("Could not find the quantized model in .pt or .safetensors format, exiting...")
         exit()
     else:
-        logging.info(f"Found the following quantized model: {pt_path}")
+        logger.info(f"Found the following quantized model: {pt_path}")
 
     # qwopqwop200's offload
     if model_type == 'llama' and shared.args.pre_layer:
@@ -190,7 +190,7 @@ def load_quantized(model_name):
         max_memory = accelerate.utils.get_balanced_memory(model)
         device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"])
-        logging.info("Using the following device map for the quantized model:", device_map)
+        logger.info("Using the following device map for the quantized model:", device_map)
         # https://huggingface.co/docs/accelerate/package_reference/big_modeling#accelerate.dispatch_model
         model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True)

View File

@@ -1,10 +1,10 @@
-import logging
 from pathlib import Path
 
 import torch
 from peft import PeftModel
 
 import modules.shared as shared
+from modules.logging_colors import logger
 
 def add_lora_to_model(lora_names):
@@ -19,7 +19,7 @@ def add_lora_to_model(lora_names):
     # Add a LoRA when another LoRA is already present
     if len(removed_set) == 0 and len(prior_set) > 0:
-        logging.info(f"Adding the LoRA(s) named {added_set} to the model...")
+        logger.info(f"Adding the LoRA(s) named {added_set} to the model...")
         for lora in added_set:
             shared.model.load_adapter(Path(f"{shared.args.lora_dir}/{lora}"), lora)
@@ -31,7 +31,7 @@ def add_lora_to_model(lora_names):
         shared.model = shared.model.base_model.model
 
     if len(lora_names) > 0:
-        logging.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names)))
+        logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names)))
         params = {}
         if not shared.args.cpu:
             params['dtype'] = shared.model.dtype

View File

@@ -3,7 +3,6 @@ import base64
 import copy
 import io
 import json
-import logging
 import re
 from datetime import datetime
 from pathlib import Path
@@ -14,6 +13,7 @@ from PIL import Image
 import modules.shared as shared
 from modules.extensions import apply_extensions
 from modules.html_generator import chat_html_wrapper, make_thumbnail
+from modules.logging_colors import logger
 from modules.text_generation import (generate_reply, get_encoded_length,
                                      get_max_prompt_length)
 from modules.utils import replace_all
@@ -187,7 +187,7 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
     output = copy.deepcopy(history)
     output = apply_extensions('history', output)
     if shared.model_name == 'None' or shared.model is None:
-        logging.error("No model is loaded! Select one in the Model tab.")
+        logger.error("No model is loaded! Select one in the Model tab.")
         yield output
         return
@@ -278,7 +278,7 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
 def impersonate_wrapper(text, state):
     if shared.model_name == 'None' or shared.model is None:
-        logging.error("No model is loaded! Select one in the Model tab.")
+        logger.error("No model is loaded! Select one in the Model tab.")
         yield ''
         return
@@ -584,7 +584,7 @@ def upload_character(json_file, img, tavern=False):
         img = Image.open(io.BytesIO(img))
         img.save(Path(f'characters/{outfile_name}.png'))
 
-    logging.info(f'New character saved to "characters/{outfile_name}.json".')
+    logger.info(f'New character saved to "characters/{outfile_name}.json".')
     return outfile_name
@@ -608,18 +608,18 @@ def upload_your_profile_picture(img):
     else:
         img = make_thumbnail(img)
         img.save(Path('cache/pfp_me.png'))
-        logging.info('Profile picture saved to "cache/pfp_me.png"')
+        logger.info('Profile picture saved to "cache/pfp_me.png"')
 
 def delete_file(path):
     if path.exists():
-        logging.warning(f'Deleting {path}')
+        logger.warning(f'Deleting {path}')
         path.unlink(missing_ok=True)
 
 def save_character(name, greeting, context, picture, filename, instruct=False):
     if filename == "":
-        logging.error("The filename is empty, so the character will not be saved.")
+        logger.error("The filename is empty, so the character will not be saved.")
         return
 
     folder = 'characters' if not instruct else 'characters/instruction-following'
@@ -634,11 +634,11 @@ def save_character(name, greeting, context, picture, filename, instruct=False):
     with filepath.open('w') as f:
         yaml.dump(data, f)
 
-    logging.info(f'Wrote {filepath}')
+    logger.info(f'Wrote {filepath}')
     path_to_img = Path(f'{folder}/{filename}.png')
     if picture and not instruct:
         picture.save(path_to_img)
-        logging.info(f'Wrote {path_to_img}')
+        logger.info(f'Wrote {path_to_img}')
     elif path_to_img.exists():
         delete_file(path_to_img)

View File

@@ -1,4 +1,3 @@
-import logging
 import traceback
 from functools import partial
@@ -6,6 +5,7 @@ import gradio as gr
 import extensions
 import modules.shared as shared
+from modules.logging_colors import logger
 
 state = {}
 available_extensions = []
@@ -29,7 +29,7 @@ def load_extensions():
     for i, name in enumerate(shared.args.extensions):
         if name in available_extensions:
             if name != 'api':
-                logging.info(f'Loading the extension "{name}"...')
+                logger.info(f'Loading the extension "{name}"...')
             try:
                 exec(f"import extensions.{name}.script")
                 extension = getattr(extensions, name).script
@@ -40,7 +40,7 @@ def load_extensions():
                 state[name] = [True, i]
             except:
-                logging.error(f'Failed to load the extension "{name}".')
+                logger.error(f'Failed to load the extension "{name}".')
                 traceback.print_exc()

View File

@@ -1,4 +1,3 @@
-import logging
 import math
 import sys
 from typing import Optional, Tuple
@@ -8,21 +7,22 @@ import torch.nn as nn
 import transformers.models.llama.modeling_llama
 
 import modules.shared as shared
+from modules.logging_colors import logger
 
 if shared.args.xformers:
     try:
         import xformers.ops
     except Exception:
-        logging.error("xformers not found! Please install it before trying to use it.", file=sys.stderr)
+        logger.error("xformers not found! Please install it before trying to use it.", file=sys.stderr)
 
 def hijack_llama_attention():
     if shared.args.xformers:
         transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
-        logging.info("Replaced attention with xformers_attention")
+        logger.info("Replaced attention with xformers_attention")
     elif shared.args.sdp_attention:
         transformers.models.llama.modeling_llama.LlamaAttention.forward = sdp_attention_forward
-        logging.info("Replaced attention with sdp_attention")
+        logger.info("Replaced attention with sdp_attention")
 
 def xformers_forward(

View File

@@ -6,13 +6,13 @@ Documentation:
 https://abetlen.github.io/llama-cpp-python/
 '''
-import logging
 import re
 
 from llama_cpp import Llama, LlamaCache
 
 from modules import shared
 from modules.callbacks import Iteratorize
+from modules.logging_colors import logger
 
 class LlamaCppModel:
@@ -35,7 +35,7 @@ class LlamaCppModel:
         else:
             cache_capacity = int(shared.args.cache_capacity)
 
-        logging.info("Cache capacity is " + str(cache_capacity) + " bytes")
+        logger.info("Cache capacity is " + str(cache_capacity) + " bytes")
         params = {
             'model_path': str(path),

View File

@@ -3,6 +3,8 @@
 import logging
 import platform
+
+logging.basicConfig(format='%(levelname)s:%(message)s')
 
 def add_coloring_to_emit_windows(fn):
     # add methods we need to the class
@@ -107,3 +109,6 @@ else:
 # log.addFilter(log_filter())
 # //hdlr = logging.StreamHandler()
 # //hdlr.setFormatter(formatter())
+
+logger = logging.getLogger('text-generation-webui')
+logger.setLevel(logging.DEBUG)
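
The two hunks above are the heart of the change: `basicConfig()` is now called with a format string but no `level`, so the root logger stays at its default WARNING, and a dedicated `text-generation-webui` logger is created at DEBUG for the web UI's own messages. Below is a minimal standalone sketch of that behavior (not part of the repository; the `some.other.library` logger name is made up for illustration):

```python
import logging

# Same setup as modules/logging_colors.py after this commit: basicConfig()
# attaches a stream handler with this format to the root logger but, with no
# level= argument, leaves the root logger at its default WARNING level.
logging.basicConfig(format='%(levelname)s:%(message)s')

logger = logging.getLogger('text-generation-webui')
logger.setLevel(logging.DEBUG)

# A dependency's logger usually has no level of its own, so it inherits the
# root's WARNING level and this INFO message is filtered out (nothing printed).
logging.getLogger('some.other.library').info('noise from a dependency')

# The project's named logger is at DEBUG, so this is printed:
# INFO:Loading the model...
logger.info('Loading the model...')
```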

View File

@@ -1,6 +1,5 @@
 import gc
 import json
-import logging
 import os
 import re
 import time
@@ -17,6 +16,7 @@ from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
 import modules.shared as shared
 from modules import llama_attn_hijack
+from modules.logging_colors import logger
 
 transformers.logging.set_verbosity_error()
@@ -71,12 +71,12 @@ def find_model_type(model_name):
 def load_model(model_name):
-    logging.info(f"Loading {model_name}...")
+    logger.info(f"Loading {model_name}...")
     t0 = time.time()
 
     shared.model_type = find_model_type(model_name)
     if shared.model_type == 'None':
-        logging.error('The path to the model does not exist. Exiting.')
+        logger.error('The path to the model does not exist. Exiting.')
         return None, None
 
     if shared.args.autogptq:
@@ -106,7 +106,7 @@ def load_model(model_name):
     if any((shared.args.xformers, shared.args.sdp_attention)):
         llama_attn_hijack.hijack_llama_attention()
 
-    logging.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.\n")
+    logger.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.\n")
     return model, tokenizer
@@ -119,7 +119,7 @@ def load_tokenizer(model_name, model):
     if shared.model_type not in ['llava', 'oasst']:
         for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
             if p.exists():
-                logging.info(f"Loading the universal LLaMA tokenizer from {p}...")
+                logger.info(f"Loading the universal LLaMA tokenizer from {p}...")
                 tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
                 return tokenizer
@@ -162,7 +162,7 @@ def huggingface_loader(model_name):
         model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
         model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
         model.module.eval()  # Inference
-        logging.info(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
+        logger.info(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
 
     # Custom
     else:
@@ -172,7 +172,7 @@ def huggingface_loader(model_name):
         }
         if not any((shared.args.cpu, torch.cuda.is_available(), torch.has_mps)):
-            logging.warning("torch.cuda.is_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.")
+            logger.warning("torch.cuda.is_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.")
             shared.args.cpu = True
 
         if shared.args.cpu:
@@ -254,7 +254,7 @@ def llamacpp_loader(model_name):
     else:
         model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))[0]
 
-    logging.info(f"llama.cpp weights detected: {model_file}\n")
+    logger.info(f"llama.cpp weights detected: {model_file}\n")
     model, tokenizer = LlamaCppModel.from_pretrained(model_file)
     return model, tokenizer
@@ -263,7 +263,7 @@ def GPTQ_loader(model_name):
     # Monkey patch
     if shared.args.monkey_patch:
-        logging.warning("Applying the monkey patch for using LoRAs in 4-bit mode. It may cause undefined behavior outside its intended scope.")
+        logger.warning("Applying the monkey patch for using LoRAs in 4-bit mode. It may cause undefined behavior outside its intended scope.")
         from modules.monkey_patch_gptq_lora import load_model_llama
 
         model, _ = load_model_llama(model_name)
@@ -302,7 +302,7 @@ def get_max_memory_dict():
             suggestion -= 1000
 
         suggestion = int(round(suggestion / 1000))
-        logging.warning(f"Auto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors. You can manually set other values.")
+        logger.warning(f"Auto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors. You can manually set other values.")
         max_memory = {0: f'{suggestion}GiB', 'cpu': f'{shared.args.cpu_memory or 99}GiB'}
 
     return max_memory if len(max_memory) > 0 else None
@@ -333,13 +333,13 @@ def load_soft_prompt(name):
         zf.extract('tensor.npy')
         zf.extract('meta.json')
         j = json.loads(open('meta.json', 'r').read())
-        logging.info(f"\nLoading the softprompt \"{name}\".")
+        logger.info(f"\nLoading the softprompt \"{name}\".")
         for field in j:
             if field != 'name':
                 if type(j[field]) is list:
-                    logging.info(f"{field}: {', '.join(j[field])}")
+                    logger.info(f"{field}: {', '.join(j[field])}")
                 else:
-                    logging.info(f"{field}: {j[field]}")
+                    logger.info(f"{field}: {j[field]}")
 
         tensor = np.load('tensor.npy')
         Path('tensor.npy').unlink()

View File

@@ -1,10 +1,11 @@
 import argparse
-import logging
 from collections import OrderedDict
 from pathlib import Path
 
 import yaml
+from modules.logging_colors import logger
 
 model = None
 tokenizer = None
 model_name = "None"
@@ -180,14 +181,14 @@ args_defaults = parser.parse_args([])
 deprecated_dict = {}
 for k in deprecated_dict:
     if getattr(args, k) != deprecated_dict[k][1]:
-        logging.warning(f"--{k} is deprecated and will be removed. Use --{deprecated_dict[k][0]} instead.")
+        logger.warning(f"--{k} is deprecated and will be removed. Use --{deprecated_dict[k][0]} instead.")
         setattr(args, deprecated_dict[k][0], getattr(args, k))
 
 # Security warnings
 if args.trust_remote_code:
-    logging.warning("trust_remote_code is enabled. This is dangerous.")
+    logger.warning("trust_remote_code is enabled. This is dangerous.")
 if args.share:
-    logging.warning("The gradio \"share link\" feature downloads a proprietary and unaudited blob to create a reverse tunnel. This is potentially dangerous.")
+    logger.warning("The gradio \"share link\" feature downloads a proprietary and unaudited blob to create a reverse tunnel. This is potentially dangerous.")
 
 def add_extension(name):

View File

@@ -1,5 +1,4 @@
 import ast
-import logging
 import random
 import re
 import time
@@ -14,6 +13,7 @@ from modules.callbacks import (Iteratorize, Stream,
                                _SentinelTokenStoppingCriteria)
 from modules.extensions import apply_extensions
 from modules.html_generator import generate_4chan_html, generate_basic_html
+from modules.logging_colors import logger
 from modules.models import clear_torch_cache, local_rank
@@ -159,7 +159,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=None, is_ch
     generate_func = apply_extensions('custom_generate_reply')
     if generate_func is None:
         if shared.model_name == 'None' or shared.model is None:
-            logging.error("No model is loaded! Select one in the Model tab.")
+            logger.error("No model is loaded! Select one in the Model tab.")
             yield question
             return

View File

@@ -1,5 +1,4 @@
 import json
-import logging
 import math
 import sys
 import threading
@@ -15,8 +14,9 @@ from peft import (LoraConfig, get_peft_model, prepare_model_for_int8_training,
                   set_peft_model_state_dict)
 
 from modules import shared, ui, utils
-from modules.evaluate import calculate_perplexity, generate_markdown_table, save_past_evaluations
+from modules.evaluate import (calculate_perplexity, generate_markdown_table,
+                              save_past_evaluations)
+from modules.logging_colors import logger
 
 # This mapping is from a very recent commit, not yet released.
 # If not available, default to a backup map for some common model types.
@@ -24,7 +24,8 @@ try:
     from peft.utils.other import \
         TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as \
         model_to_lora_modules
-    from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+    from transformers.models.auto.modeling_auto import \
+        MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
     MODEL_CLASSES = {v: k for k, v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES}
 except:
     standard_modules = ["q_proj", "v_proj"]
@@ -217,13 +218,13 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
     if model_type == "PeftModelForCausalLM":
         if len(shared.args.lora_names) > 0:
             yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
-            logging.warning("Training LoRA over top of another LoRA. May have unexpected effects.")
+            logger.warning("Training LoRA over top of another LoRA. May have unexpected effects.")
         else:
             yield "Model ID not matched due to LoRA loading. Consider reloading base model. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
-            logging.warning("Model ID not matched due to LoRA loading. Consider reloading base model.")
+            logger.warning("Model ID not matched due to LoRA loading. Consider reloading base model.")
     else:
         yield "LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
-        logging.warning(f"LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. (Found model type: {model_type})")
+        logger.warning(f"LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. (Found model type: {model_type})")
 
     time.sleep(5)
@@ -233,7 +234,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
     elif not shared.args.load_in_8bit and shared.args.wbits <= 0:
         yield "It is highly recommended you use `--load-in-8bit` for LoRA training. *(Will continue anyway in 2 seconds, press `Interrupt` to stop.)*"
-        logging.warning("It is highly recommended you use `--load-in-8bit` for LoRA training.")
+        logger.warning("It is highly recommended you use `--load-in-8bit` for LoRA training.")
         time.sleep(2)  # Give it a moment for the message to show in UI before continuing
 
     if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
@@ -253,7 +254,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
     # == Prep the dataset, format, etc ==
     if raw_text_file not in ['None', '']:
-        logging.info("Loading raw text file dataset...")
+        logger.info("Loading raw text file dataset...")
         with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
             raw_text = file.read().replace('\r', '')
@@ -311,7 +312,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
             prompt = generate_prompt(data_point)
             return tokenize(prompt)
 
-        logging.info("Loading JSON datasets...")
+        logger.info("Loading JSON datasets...")
         data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
         train_data = data['train'].map(generate_and_tokenize_prompt)
@@ -323,10 +324,10 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
     # == Start prepping the model itself ==
     if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
-        logging.info("Getting model ready...")
+        logger.info("Getting model ready...")
         prepare_model_for_int8_training(shared.model)
 
-    logging.info("Prepping for training...")
+    logger.info("Prepping for training...")
     config = LoraConfig(
         r=lora_rank,
         lora_alpha=lora_alpha,
@@ -337,10 +338,10 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
     )
 
     try:
-        logging.info("Creating LoRA model...")
+        logger.info("Creating LoRA model...")
         lora_model = get_peft_model(shared.model, config)
         if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file():
-            logging.info("Loading existing LoRA data...")
+            logger.info("Loading existing LoRA data...")
             state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin")
             set_peft_model_state_dict(lora_model, state_dict_peft)
     except:
@@ -418,7 +419,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
         json.dump({x: vars[x] for x in PARAMETERS}, file)
 
     # == Main run and monitor loop ==
-    logging.info("Starting training...")
+    logger.info("Starting training...")
     yield "Starting..."
 
     if WANT_INTERRUPT:
         yield "Interrupted before start."
@@ -428,7 +429,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
         trainer.train()
         # Note: save in the thread in case the gradio thread breaks (eg browser closed)
         lora_model.save_pretrained(lora_file_path)
-        logging.info("LoRA training run is completed and saved.")
+        logger.info("LoRA training run is completed and saved.")
         tracked.did_save = True
 
     thread = threading.Thread(target=threaded_run)
@@ -460,14 +461,14 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
     # Saving in the train thread might fail if an error occurs, so save here if so.
     if not tracked.did_save:
-        logging.info("Training complete, saving...")
+        logger.info("Training complete, saving...")
         lora_model.save_pretrained(lora_file_path)
 
     if WANT_INTERRUPT:
-        logging.info("Training interrupted.")
+        logger.info("Training interrupted.")
         yield f"Interrupted. Incomplete LoRA saved to `{lora_file_path}`"
     else:
-        logging.info("Training complete!")
+        logger.info("Training complete!")
         yield f"Done! LoRA saved to `{lora_file_path}`"

View File

@@ -1,17 +1,18 @@
-import logging
 import os
-import requests
 import warnings
 
-import modules.logging_colors
+import requests
+
+from modules.logging_colors import logger
 
 os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
 os.environ['BITSANDBYTES_NOWELCOME'] = '1'
 warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
-logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.INFO)
 
 # This is a hack to prevent Gradio from phoning home when it gets imported
 def my_get(url, **kwargs):
-    logging.info('Gradio HTTP request redirected to localhost :)')
+    logger.info('Gradio HTTP request redirected to localhost :)')
     kwargs.setdefault('allow_redirects', True)
     return requests.api.request('get', 'http://127.0.0.1/', **kwargs)
@@ -49,7 +50,8 @@ from modules.extensions import apply_extensions
 from modules.html_generator import chat_html_wrapper
 from modules.LoRA import add_lora_to_model
 from modules.models import load_model, load_soft_prompt, unload_model
-from modules.text_generation import generate_reply_wrapper, get_encoded_length, stop_everything_event
+from modules.text_generation import (generate_reply_wrapper,
+                                     get_encoded_length, stop_everything_event)
 
 def load_model_wrapper(selected_model, autoload=False):
@@ -971,7 +973,7 @@ if __name__ == "__main__":
     settings_file = Path('settings.json')
     if settings_file is not None:
-        logging.info(f"Loading settings from {settings_file}...")
+        logger.info(f"Loading settings from {settings_file}...")
         new_settings = json.loads(open(settings_file, 'r').read())
         for item in new_settings:
             shared.settings[item] = new_settings[item]
@@ -1015,7 +1017,7 @@ if __name__ == "__main__":
     # Select the model from a command-line menu
     elif shared.args.model_menu:
         if len(available_models) == 0:
-            logging.error('No models are available! Please download at least one.')
+            logger.error('No models are available! Please download at least one.')
             sys.exit(0)
         else:
             print('The following models are available:\n')