Improve the imports

This commit is contained in:
oobabooga 2023-02-23 14:41:42 -03:00
parent 364529d0c7
commit 7224343a70
10 changed files with 30 additions and 29 deletions

View File

@ -3,6 +3,7 @@
Converts a transformers model to a format compatible with flexgen. Converts a transformers model to a format compatible with flexgen.
''' '''
import argparse import argparse
import os import os
from pathlib import Path from pathlib import Path
@ -10,8 +11,7 @@ from pathlib import Path
import numpy as np import numpy as np
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from transformers import AutoModelForCausalLM from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoTokenizer
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54)) parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.") parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
@ -31,7 +31,6 @@ def disable_torch_init():
torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
def restore_torch_init(): def restore_torch_init():
"""Rollback the change made by disable_torch_init.""" """Rollback the change made by disable_torch_init."""
import torch import torch

View File

@ -10,12 +10,12 @@ Based on the original script by 81300:
https://gist.github.com/81300/fe5b08bff1cba45296a829b9d6b0f303 https://gist.github.com/81300/fe5b08bff1cba45296a829b9d6b0f303
''' '''
import argparse import argparse
from pathlib import Path from pathlib import Path
import torch import torch
from transformers import AutoModelForCausalLM from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoTokenizer
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54)) parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.") parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")

View File

@ -1,6 +1,5 @@
import torch import torch
from transformers import BlipForConditionalGeneration from transformers import BlipForConditionalGeneration, BlipProcessor
from transformers import BlipProcessor
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base") processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu") model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu")

View File

@ -7,13 +7,12 @@ from datetime import datetime
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
from PIL import Image
import modules.shared as shared import modules.shared as shared
from modules.extensions import apply_extensions from modules.extensions import apply_extensions
from modules.html_generator import generate_chat_html from modules.html_generator import generate_chat_html
from modules.text_generation import encode from modules.text_generation import encode, generate_reply, get_max_prompt_length
from modules.text_generation import generate_reply
from modules.text_generation import get_max_prompt_length
from PIL import Image
if shared.args.picture and (shared.args.cai_chat or shared.args.chat): if shared.args.picture and (shared.args.cai_chat or shared.args.chat):
import modules.bot_picture as bot_picture import modules.bot_picture as bot_picture

View File

@ -1,6 +1,5 @@
import modules.shared as shared
import extensions import extensions
import modules.shared as shared
extension_state = {} extension_state = {}
available_extensions = [] available_extensions = []

View File

@ -3,6 +3,7 @@
This is a library for formatting GPT-4chan and chat outputs as nice HTML. This is a library for formatting GPT-4chan and chat outputs as nice HTML.
''' '''
import base64 import base64
import os import os
import re import re

View File

@ -4,23 +4,27 @@ import time
import zipfile import zipfile
from pathlib import Path from pathlib import Path
import modules.shared as shared
import numpy as np import numpy as np
import torch import torch
import transformers import transformers
from transformers import AutoModelForCausalLM from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoTokenizer
import modules.shared as shared
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
local_rank = None local_rank = None
if shared.args.flexgen: if shared.args.flexgen:
from flexgen.flex_opt import (Policy, OptLM, TorchDevice, TorchDisk, TorchMixedDevice, CompressionConfig, Env, get_opt_config) from flexgen.flex_opt import (CompressionConfig, Env, OptLM, Policy,
TorchDevice, TorchDisk, TorchMixedDevice,
get_opt_config)
if shared.args.deepspeed: if shared.args.deepspeed:
import deepspeed import deepspeed
from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_zero3_enabled from transformers.deepspeed import (HfDeepSpeedConfig,
is_deepspeed_zero3_enabled)
from modules.deepspeed_parameters import generate_ds_config from modules.deepspeed_parameters import generate_ds_config
# Distributed setup # Distributed setup

View File

@ -4,9 +4,11 @@ This code was copied from
https://github.com/PygmalionAI/gradio-ui/ https://github.com/PygmalionAI/gradio-ui/
''' '''
import torch import torch
import transformers import transformers
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria): class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
def __init__(self, sentinel_token_ids: torch.LongTensor, def __init__(self, sentinel_token_ids: torch.LongTensor,

View File

@ -1,16 +1,17 @@
import re import re
import time import time
import modules.shared as shared
import numpy as np import numpy as np
import torch import torch
import transformers import transformers
from tqdm import tqdm
import modules.shared as shared
from modules.extensions import apply_extensions from modules.extensions import apply_extensions
from modules.html_generator import generate_4chan_html from modules.html_generator import generate_4chan_html, generate_basic_html
from modules.html_generator import generate_basic_html
from modules.models import local_rank from modules.models import local_rank
from modules.stopping_criteria import _SentinelTokenStoppingCriteria from modules.stopping_criteria import _SentinelTokenStoppingCriteria
from tqdm import tqdm
def get_max_prompt_length(tokens): def get_max_prompt_length(tokens):
max_length = 2048-tokens max_length = 2048-tokens

View File

@ -14,12 +14,9 @@ import modules.chat as chat
import modules.extensions as extensions_module import modules.extensions as extensions_module
import modules.shared as shared import modules.shared as shared
import modules.ui as ui import modules.ui as ui
from modules.extensions import extension_state from modules.extensions import extension_state, load_extensions, update_extensions_parameters
from modules.extensions import load_extensions
from modules.extensions import update_extensions_parameters
from modules.html_generator import generate_chat_html from modules.html_generator import generate_chat_html
from modules.models import load_model from modules.models import load_model, load_soft_prompt
from modules.models import load_soft_prompt
from modules.text_generation import generate_reply from modules.text_generation import generate_reply
if (shared.args.chat or shared.args.cai_chat) and not shared.args.no_stream: if (shared.args.chat or shared.args.cai_chat) and not shared.args.no_stream: