mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-29 21:50:16 +01:00
Improve the imports
This commit is contained in:
parent
364529d0c7
commit
7224343a70
@ -3,6 +3,7 @@
|
|||||||
Converts a transformers model to a format compatible with flexgen.
|
Converts a transformers model to a format compatible with flexgen.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -10,8 +11,7 @@ from pathlib import Path
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
from transformers import AutoTokenizer
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
|
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
|
||||||
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
|
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
|
||||||
@ -31,7 +31,6 @@ def disable_torch_init():
|
|||||||
torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters
|
torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters
|
||||||
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
||||||
|
|
||||||
|
|
||||||
def restore_torch_init():
|
def restore_torch_init():
|
||||||
"""Rollback the change made by disable_torch_init."""
|
"""Rollback the change made by disable_torch_init."""
|
||||||
import torch
|
import torch
|
||||||
|
@ -10,12 +10,12 @@ Based on the original script by 81300:
|
|||||||
https://gist.github.com/81300/fe5b08bff1cba45296a829b9d6b0f303
|
https://gist.github.com/81300/fe5b08bff1cba45296a829b9d6b0f303
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
from transformers import AutoTokenizer
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
|
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
|
||||||
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
|
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
from transformers import BlipForConditionalGeneration
|
from transformers import BlipForConditionalGeneration, BlipProcessor
|
||||||
from transformers import BlipProcessor
|
|
||||||
|
|
||||||
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
|
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
|
||||||
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu")
|
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu")
|
||||||
|
@ -7,13 +7,12 @@ from datetime import datetime
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules.extensions import apply_extensions
|
from modules.extensions import apply_extensions
|
||||||
from modules.html_generator import generate_chat_html
|
from modules.html_generator import generate_chat_html
|
||||||
from modules.text_generation import encode
|
from modules.text_generation import encode, generate_reply, get_max_prompt_length
|
||||||
from modules.text_generation import generate_reply
|
|
||||||
from modules.text_generation import get_max_prompt_length
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
if shared.args.picture and (shared.args.cai_chat or shared.args.chat):
|
if shared.args.picture and (shared.args.cai_chat or shared.args.chat):
|
||||||
import modules.bot_picture as bot_picture
|
import modules.bot_picture as bot_picture
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import modules.shared as shared
|
|
||||||
|
|
||||||
import extensions
|
import extensions
|
||||||
|
import modules.shared as shared
|
||||||
|
|
||||||
extension_state = {}
|
extension_state = {}
|
||||||
available_extensions = []
|
available_extensions = []
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
This is a library for formatting GPT-4chan and chat outputs as nice HTML.
|
This is a library for formatting GPT-4chan and chat outputs as nice HTML.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
@ -4,23 +4,27 @@ import time
|
|||||||
import zipfile
|
import zipfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import modules.shared as shared
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
from transformers import AutoTokenizer
|
|
||||||
|
import modules.shared as shared
|
||||||
|
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
|
|
||||||
local_rank = None
|
local_rank = None
|
||||||
|
|
||||||
if shared.args.flexgen:
|
if shared.args.flexgen:
|
||||||
from flexgen.flex_opt import (Policy, OptLM, TorchDevice, TorchDisk, TorchMixedDevice, CompressionConfig, Env, get_opt_config)
|
from flexgen.flex_opt import (CompressionConfig, Env, OptLM, Policy,
|
||||||
|
TorchDevice, TorchDisk, TorchMixedDevice,
|
||||||
|
get_opt_config)
|
||||||
|
|
||||||
if shared.args.deepspeed:
|
if shared.args.deepspeed:
|
||||||
import deepspeed
|
import deepspeed
|
||||||
from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_zero3_enabled
|
from transformers.deepspeed import (HfDeepSpeedConfig,
|
||||||
|
is_deepspeed_zero3_enabled)
|
||||||
|
|
||||||
from modules.deepspeed_parameters import generate_ds_config
|
from modules.deepspeed_parameters import generate_ds_config
|
||||||
|
|
||||||
# Distributed setup
|
# Distributed setup
|
||||||
|
@ -4,9 +4,11 @@ This code was copied from
|
|||||||
https://github.com/PygmalionAI/gradio-ui/
|
https://github.com/PygmalionAI/gradio-ui/
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
|
|
||||||
|
|
||||||
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
|
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
|
||||||
|
|
||||||
def __init__(self, sentinel_token_ids: torch.LongTensor,
|
def __init__(self, sentinel_token_ids: torch.LongTensor,
|
||||||
|
@ -1,16 +1,17 @@
|
|||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import modules.shared as shared
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import modules.shared as shared
|
||||||
from modules.extensions import apply_extensions
|
from modules.extensions import apply_extensions
|
||||||
from modules.html_generator import generate_4chan_html
|
from modules.html_generator import generate_4chan_html, generate_basic_html
|
||||||
from modules.html_generator import generate_basic_html
|
|
||||||
from modules.models import local_rank
|
from modules.models import local_rank
|
||||||
from modules.stopping_criteria import _SentinelTokenStoppingCriteria
|
from modules.stopping_criteria import _SentinelTokenStoppingCriteria
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
def get_max_prompt_length(tokens):
|
def get_max_prompt_length(tokens):
|
||||||
max_length = 2048-tokens
|
max_length = 2048-tokens
|
||||||
|
@ -14,12 +14,9 @@ import modules.chat as chat
|
|||||||
import modules.extensions as extensions_module
|
import modules.extensions as extensions_module
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
import modules.ui as ui
|
import modules.ui as ui
|
||||||
from modules.extensions import extension_state
|
from modules.extensions import extension_state, load_extensions, update_extensions_parameters
|
||||||
from modules.extensions import load_extensions
|
|
||||||
from modules.extensions import update_extensions_parameters
|
|
||||||
from modules.html_generator import generate_chat_html
|
from modules.html_generator import generate_chat_html
|
||||||
from modules.models import load_model
|
from modules.models import load_model, load_soft_prompt
|
||||||
from modules.models import load_soft_prompt
|
|
||||||
from modules.text_generation import generate_reply
|
from modules.text_generation import generate_reply
|
||||||
|
|
||||||
if (shared.args.chat or shared.args.cai_chat) and not shared.args.no_stream:
|
if (shared.args.chat or shared.args.cai_chat) and not shared.args.no_stream:
|
||||||
|
Loading…
Reference in New Issue
Block a user