Merge pull request #110 from oobabooga/refactored

Refactor everything
This commit is contained in:
oobabooga 2023-02-23 15:30:32 -03:00 committed by GitHub
commit 682f7bdbba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 948 additions and 920 deletions

View File

@ -10,7 +10,6 @@ Optionally, you can also add the --share flag to generate a public gradio URL,
allowing you to use the API remotely. allowing you to use the API remotely.
''' '''
import requests import requests
# Server address # Server address

View File

@ -6,14 +6,12 @@ Converts a transformers model to a format compatible with flexgen.
import argparse import argparse
import os import os
import numpy as np
from pathlib import Path from pathlib import Path
from sys import argv
import numpy as np
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from transformers import AutoModelForCausalLM from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoTokenizer
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54)) parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.") parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
@ -33,7 +31,6 @@ def disable_torch_init():
torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
def restore_torch_init(): def restore_torch_init():
"""Rollback the change made by disable_torch_init.""" """Rollback the change made by disable_torch_init."""
import torch import torch

View File

@ -13,11 +13,9 @@ https://gist.github.com/81300/fe5b08bff1cba45296a829b9d6b0f303
import argparse import argparse
from pathlib import Path from pathlib import Path
from sys import argv
import torch import torch
from transformers import AutoModelForCausalLM from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoTokenizer
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54)) parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.") parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")

View File

@ -1,8 +1,5 @@
import requests
import torch import torch
from PIL import Image from transformers import BlipForConditionalGeneration, BlipProcessor
from transformers import BlipForConditionalGeneration
from transformers import BlipProcessor
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base") processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu") model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu")

366
modules/chat.py Normal file
View File

@ -0,0 +1,366 @@
import base64
import copy
import io
import json
import re
from datetime import datetime
from io import BytesIO
from pathlib import Path
from PIL import Image
import modules.shared as shared
from modules.extensions import apply_extensions
from modules.html_generator import generate_chat_html
from modules.text_generation import encode, generate_reply, get_max_prompt_length
if shared.args.picture and (shared.args.cai_chat or shared.args.chat):
import modules.bot_picture as bot_picture
# This gets the new line characters right.
def clean_chat_message(text):
text = text.replace('\n', '\n\n')
text = re.sub(r"\n{3,}", "\n\n", text)
text = text.strip()
return text
def generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size, impersonate=False):
text = clean_chat_message(text)
rows = [f"{context.strip()}\n"]
i = len(shared.history['internal'])-1
count = 0
if shared.soft_prompt:
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
max_length = min(get_max_prompt_length(tokens), chat_prompt_size)
while i >= 0 and len(encode(''.join(rows), tokens)[0]) < max_length:
rows.insert(1, f"{name2}: {shared.history['internal'][i][1].strip()}\n")
count += 1
if not (shared.history['internal'][i][0] == '<|BEGIN-VISIBLE-CHAT|>'):
rows.insert(1, f"{name1}: {shared.history['internal'][i][0].strip()}\n")
count += 1
i -= 1
if not impersonate:
rows.append(f"{name1}: {text}\n")
rows.append(apply_extensions(f"{name2}:", "bot_prefix"))
limit = 3
else:
rows.append(f"{name1}:")
limit = 2
while len(rows) > limit and len(encode(''.join(rows), tokens)[0]) >= max_length:
rows.pop(1)
rows.pop(1)
question = ''.join(rows)
return question
def extract_message_from_reply(question, reply, current, other, check, extensions=False):
next_character_found = False
substring_found = False
previous_idx = [m.start() for m in re.finditer(f"(^|\n){re.escape(current)}:", question)]
idx = [m.start() for m in re.finditer(f"(^|\n){re.escape(current)}:", reply)]
idx = idx[len(previous_idx)-1]
if extensions:
reply = reply[idx + 1 + len(apply_extensions(f"{current}:", "bot_prefix")):]
else:
reply = reply[idx + 1 + len(f"{current}:"):]
if check:
reply = reply.split('\n')[0].strip()
else:
idx = reply.find(f"\n{other}:")
if idx != -1:
reply = reply[:idx]
next_character_found = True
reply = clean_chat_message(reply)
# Detect if something like "\nYo" is generated just before
# "\nYou:" is completed
tmp = f"\n{other}:"
for j in range(1, len(tmp)):
if reply[-j:] == tmp[:j]:
substring_found = True
return reply, next_character_found, substring_found
def generate_chat_picture(picture, name1, name2):
text = f'*{name1} sends {name2} a picture that contains the following: "{bot_picture.caption_image(picture)}"*'
buffer = BytesIO()
picture.save(buffer, format="JPEG")
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
visible_text = f'<img src="data:image/jpeg;base64,{img_str}">'
return text, visible_text
def stop_everything_event():
shared.stop_everything = True
def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
shared.stop_everything = False
if 'pygmalion' in shared.model_name.lower():
name1 = "You"
if shared.args.picture and picture is not None:
text, visible_text = generate_chat_picture(picture, name1, name2)
else:
visible_text = text
if shared.args.chat:
visible_text = visible_text.replace('\n', '<br>')
text = apply_extensions(text, "input")
question = generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size)
eos_token = '\n' if check else None
first = True
for reply in generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):
reply, next_character_found, substring_found = extract_message_from_reply(question, reply, name2, name1, check, extensions=True)
visible_reply = apply_extensions(reply, "output")
if shared.args.chat:
visible_reply = visible_reply.replace('\n', '<br>')
# We need this global variable to handle the Stop event,
# otherwise gradio gets confused
if shared.stop_everything:
return shared.history['visible']
if first:
first = False
shared.history['internal'].append(['', ''])
shared.history['visible'].append(['', ''])
shared.history['internal'][-1] = [text, reply]
shared.history['visible'][-1] = [visible_text, visible_reply]
if not substring_found:
yield shared.history['visible']
if next_character_found:
break
yield shared.history['visible']
def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
if 'pygmalion' in shared.model_name.lower():
name1 = "You"
question = generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size, impersonate=True)
eos_token = '\n' if check else None
for reply in generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
reply, next_character_found, substring_found = extract_message_from_reply(question, reply, name1, name2, check, extensions=False)
if not substring_found:
yield reply
if next_character_found:
break
yield reply
def cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
for _history in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture):
yield generate_chat_html(_history, name1, name2, shared.character)
def regenerate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
if shared.character is not None and len(shared.history['visible']) == 1:
if shared.args.cai_chat:
yield generate_chat_html(shared.history['visible'], name1, name2, shared.character)
else:
yield shared.history['visible']
else:
last_visible = shared.history['visible'].pop()
last_internal = shared.history['internal'].pop()
for _history in chatbot_wrapper(last_internal[0], tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture):
if shared.args.cai_chat:
shared.history['visible'][-1] = [last_visible[0], _history[-1][1]]
yield generate_chat_html(shared.history['visible'], name1, name2, shared.character)
else:
shared.history['visible'][-1] = (last_visible[0], _history[-1][1])
yield shared.history['visible']
def remove_last_message(name1, name2):
if not shared.history['internal'][-1][0] == '<|BEGIN-VISIBLE-CHAT|>':
last = shared.history['visible'].pop()
shared.history['internal'].pop()
else:
last = ['', '']
if shared.args.cai_chat:
return generate_chat_html(shared.history['visible'], name1, name2, shared.character), last[0]
else:
return shared.history['visible'], last[0]
def send_last_reply_to_input():
if len(shared.history['internal']) > 0:
return shared.history['internal'][-1][1]
else:
return ''
def replace_last_reply(text, name1, name2):
if len(shared.history['visible']) > 0:
if shared.args.cai_chat:
shared.history['visible'][-1][1] = text
else:
shared.history['visible'][-1] = (shared.history['visible'][-1][0], text)
shared.history['internal'][-1][1] = apply_extensions(text, "input")
if shared.args.cai_chat:
return generate_chat_html(shared.history['visible'], name1, name2, shared.character)
else:
return shared.history['visible']
def clear_html():
return generate_chat_html([], "", "", shared.character)
def clear_chat_log(name1, name2):
if shared.character != 'None':
for i in range(len(shared.history['internal'])):
if '<|BEGIN-VISIBLE-CHAT|>' in shared.history['internal'][i][0]:
shared.history['visible'] = [['', apply_extensions(shared.history['internal'][i][1], "output")]]
shared.history['internal'] = shared.history['internal'][:i+1]
break
else:
shared.history['internal'] = []
shared.history['visible'] = []
if shared.args.cai_chat:
return generate_chat_html(shared.history['visible'], name1, name2, shared.character)
else:
return shared.history['visible']
def redraw_html(name1, name2):
return generate_chat_html(shared.history['visible'], name1, name2, shared.character)
def tokenize_dialogue(dialogue, name1, name2):
_history = []
dialogue = re.sub('<START>', '', dialogue)
dialogue = re.sub('<start>', '', dialogue)
dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
dialogue = re.sub('(\n|^)\[CHARACTER\]:', f'\\g<1>{name2}:', dialogue)
idx = [m.start() for m in re.finditer(f"(^|\n)({re.escape(name1)}|{re.escape(name2)}):", dialogue)]
if len(idx) == 0:
return _history
messages = []
for i in range(len(idx)-1):
messages.append(dialogue[idx[i]:idx[i+1]].strip())
messages.append(dialogue[idx[-1]:].strip())
entry = ['', '']
for i in messages:
if i.startswith(f'{name1}:'):
entry[0] = i[len(f'{name1}:'):].strip()
elif i.startswith(f'{name2}:'):
entry[1] = i[len(f'{name2}:'):].strip()
if not (len(entry[0]) == 0 and len(entry[1]) == 0):
_history.append(entry)
entry = ['', '']
print(f"\033[1;32;1m\nDialogue tokenized to:\033[0;37;0m\n", end='')
for row in _history:
for column in row:
print("\n")
for line in column.strip().split('\n'):
print("| "+line+"\n")
print("|\n")
print("------------------------------")
return _history
def save_history(timestamp=True):
if timestamp:
fname = f"{shared.character or ''}{'_' if shared.character else ''}{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
else:
fname = f"{shared.character or ''}{'_' if shared.character else ''}persistent.json"
if not Path('logs').exists():
Path('logs').mkdir()
with open(Path(f'logs/{fname}'), 'w') as f:
f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
return Path(f'logs/{fname}')
def load_history(file, name1, name2):
file = file.decode('utf-8')
try:
j = json.loads(file)
if 'data' in j:
shared.history['internal'] = j['data']
if 'data_visible' in j:
shared.history['visible'] = j['data_visible']
else:
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
# Compatibility with Pygmalion AI's official web UI
elif 'chat' in j:
shared.history['internal'] = [':'.join(x.split(':')[1:]).strip() for x in j['chat']]
if len(j['chat']) > 0 and j['chat'][0].startswith(f'{name2}:'):
shared.history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', shared.history['internal'][0]]] + [[shared.history['internal'][i], shared.history['internal'][i+1]] for i in range(1, len(shared.history['internal'])-1, 2)]
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
shared.history['visible'][0][0] = ''
else:
shared.history['internal'] = [[shared.history['internal'][i], shared.history['internal'][i+1]] for i in range(0, len(shared.history['internal'])-1, 2)]
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
except:
shared.history['internal'] = tokenize_dialogue(file, name1, name2)
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
def load_character(_character, name1, name2):
context = ""
shared.history['internal'] = []
shared.history['visible'] = []
if _character != 'None':
shared.character = _character
data = json.loads(open(Path(f'characters/{_character}.json'), 'r').read())
name2 = data['char_name']
if 'char_persona' in data and data['char_persona'] != '':
context += f"{data['char_name']}'s Persona: {data['char_persona']}\n"
if 'world_scenario' in data and data['world_scenario'] != '':
context += f"Scenario: {data['world_scenario']}\n"
context = f"{context.strip()}\n<START>\n"
if 'example_dialogue' in data and data['example_dialogue'] != '':
shared.history['internal'] = tokenize_dialogue(data['example_dialogue'], name1, name2)
if 'char_greeting' in data and len(data['char_greeting'].strip()) > 0:
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', data['char_greeting']]]
shared.history['visible'] += [['', apply_extensions(data['char_greeting'], "output")]]
else:
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', "Hello there!"]]
shared.history['visible'] += [['', "Hello there!"]]
else:
shared.character = None
context = shared.settings['context_pygmalion']
name2 = shared.settings['name2_pygmalion']
if Path(f'logs/{shared.character}_persistent.json').exists():
load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
if shared.args.cai_chat:
return name2, context, generate_chat_html(shared.history['visible'], name1, name2, shared.character)
else:
return name2, context, shared.history['visible']
def upload_character(json_file, img, tavern=False):
json_file = json_file if type(json_file) == str else json_file.decode('utf-8')
data = json.loads(json_file)
outfile_name = data["char_name"]
i = 1
while Path(f'characters/{outfile_name}.json').exists():
outfile_name = f'{data["char_name"]}_{i:03d}'
i += 1
if tavern:
outfile_name = f'TavernAI-{outfile_name}'
with open(Path(f'characters/{outfile_name}.json'), 'w') as f:
f.write(json_file)
if img is not None:
img = Image.open(io.BytesIO(img))
img.save(Path(f'characters/{outfile_name}.png'))
print(f'New character saved to "characters/{outfile_name}.json".')
return outfile_name
def upload_tavern_character(img, name1, name2):
_img = Image.open(io.BytesIO(img))
_img.getexif()
decoded_string = base64.b64decode(_img.info['chara'])
_json = json.loads(decoded_string)
_json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']}
_json['example_dialogue'] = _json['example_dialogue'].replace('{{user}}', name1).replace('{{char}}', _json['char_name'])
return upload_character(json.dumps(_json), img, tavern=True)
def upload_your_profile_picture(img):
img = Image.open(io.BytesIO(img))
img.save(Path(f'img_me.png'))
print(f'Profile picture saved to "img_me.png"')

64
modules/extensions.py Normal file
View File

@ -0,0 +1,64 @@
import extensions
import modules.shared as shared
import gradio as gr
extension_state = {}
available_extensions = []
def load_extensions():
global extension_state
for i,ext in enumerate(shared.args.extensions.split(',')):
if ext in available_extensions:
print(f'Loading the extension "{ext}"... ', end='')
ext_string = f"extensions.{ext}.script"
exec(f"import {ext_string}")
extension_state[ext] = [True, i]
print(f'Ok.')
def apply_extensions(text, typ):
for ext in sorted(extension_state, key=lambda x : extension_state[x][1]):
if extension_state[ext][0] == True:
ext_string = f"extensions.{ext}.script"
if typ == "input" and hasattr(eval(ext_string), "input_modifier"):
text = eval(f"{ext_string}.input_modifier(text)")
elif typ == "output" and hasattr(eval(ext_string), "output_modifier"):
text = eval(f"{ext_string}.output_modifier(text)")
elif typ == "bot_prefix" and hasattr(eval(ext_string), "bot_prefix_modifier"):
text = eval(f"{ext_string}.bot_prefix_modifier(text)")
return text
def update_extensions_parameters(*kwargs):
i = 0
for ext in sorted(extension_state, key=lambda x : extension_state[x][1]):
if extension_state[ext][0] == True:
params = eval(f"extensions.{ext}.script.params")
for param in params:
if len(kwargs) >= i+1:
params[param] = eval(f"kwargs[{i}]")
i += 1
def get_params(name):
return eval(f"extensions.{name}.script.params")
def create_extensions_block():
extensions_ui_elements = []
default_values = []
if not (shared.args.chat or shared.args.cai_chat):
gr.Markdown('## Extensions parameters')
for ext in sorted(extension_state, key=lambda x : extension_state[x][1]):
if extension_state[ext][0] == True:
params = get_params(ext)
for param in params:
_id = f"{ext}-{param}"
default_value = shared.settings[_id] if _id in shared.settings else params[param]
default_values.append(default_value)
if type(params[param]) == str:
extensions_ui_elements.append(gr.Textbox(value=default_value, label=f"{ext}-{param}"))
elif type(params[param]) in [int, float]:
extensions_ui_elements.append(gr.Number(value=default_value, label=f"{ext}-{param}"))
elif type(params[param]) == bool:
extensions_ui_elements.append(gr.Checkbox(value=default_value, label=f"{ext}-{param}"))
update_extensions_parameters(*default_values)
btn_extensions = gr.Button("Apply")
btn_extensions.click(update_extensions_parameters, [*extensions_ui_elements], [])

View File

@ -5,7 +5,6 @@ This is a library for formatting GPT-4chan and chat outputs as nice HTML.
''' '''
import base64 import base64
import copy
import os import os
import re import re
from io import BytesIO from io import BytesIO

150
modules/models.py Normal file
View File

@ -0,0 +1,150 @@
import json
import os
import time
import zipfile
from pathlib import Path
import numpy as np
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
import modules.shared as shared
transformers.logging.set_verbosity_error()
local_rank = None
if shared.args.flexgen:
from flexgen.flex_opt import (CompressionConfig, Env, OptLM, Policy,
TorchDevice, TorchDisk, TorchMixedDevice,
get_opt_config)
if shared.args.deepspeed:
import deepspeed
from transformers.deepspeed import (HfDeepSpeedConfig,
is_deepspeed_zero3_enabled)
from modules.deepspeed_parameters import generate_ds_config
# Distributed setup
local_rank = shared.args.local_rank if shared.args.local_rank is not None else int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
torch.cuda.set_device(local_rank)
deepspeed.init_distributed()
ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration
def load_model(model_name):
print(f"Loading {model_name}...")
t0 = time.time()
# Default settings
if not (shared.args.cpu or shared.args.load_in_8bit or shared.args.auto_devices or shared.args.disk or shared.args.gpu_memory is not None or shared.args.cpu_memory is not None or shared.args.deepspeed or shared.args.flexgen):
if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True)
else:
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16).cuda()
# FlexGen
elif shared.args.flexgen:
gpu = TorchDevice("cuda:0")
cpu = TorchDevice("cpu")
disk = TorchDisk(shared.args.disk_cache_dir)
env = Env(gpu=gpu, cpu=cpu, disk=disk, mixed=TorchMixedDevice([gpu, cpu, disk]))
# Offloading policy
policy = Policy(1, 1,
shared.args.percent[0], shared.args.percent[1],
shared.args.percent[2], shared.args.percent[3],
shared.args.percent[4], shared.args.percent[5],
overlap=True, sep_layer=True, pin_weight=True,
cpu_cache_compute=False, attn_sparsity=1.0,
compress_weight=shared.args.compress_weight,
comp_weight_config=CompressionConfig(
num_bits=4, group_size=64,
group_dim=0, symmetric=False),
compress_cache=False,
comp_cache_config=CompressionConfig(
num_bits=4, group_size=64,
group_dim=2, symmetric=False))
opt_config = get_opt_config(f"facebook/{shared.model_name}")
model = OptLM(opt_config, env, "models", policy)
model.init_all_weights()
# DeepSpeed ZeRO-3
elif shared.args.deepspeed:
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
model.module.eval() # Inference
print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
# Custom
else:
command = "AutoModelForCausalLM.from_pretrained"
params = ["low_cpu_mem_usage=True"]
if not shared.args.cpu and not torch.cuda.is_available():
print("Warning: no GPU has been detected.\nFalling back to CPU mode.\n")
shared.args.cpu = True
if shared.args.cpu:
params.append("low_cpu_mem_usage=True")
params.append("torch_dtype=torch.float32")
else:
params.append("device_map='auto'")
params.append("load_in_8bit=True" if shared.args.load_in_8bit else "torch_dtype=torch.bfloat16" if shared.args.bf16 else "torch_dtype=torch.float16")
if shared.args.gpu_memory:
params.append(f"max_memory={{0: '{shared.args.gpu_memory or '99'}GiB', 'cpu': '{shared.args.cpu_memory or '99'}GiB'}}")
elif not shared.args.load_in_8bit:
total_mem = (torch.cuda.get_device_properties(0).total_memory/(1024*1024))
suggestion = round((total_mem-1000)/1000)*1000
if total_mem-suggestion < 800:
suggestion -= 1000
suggestion = int(round(suggestion/1000))
print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
params.append(f"max_memory={{0: '{suggestion}GiB', 'cpu': '{shared.args.cpu_memory or '99'}GiB'}}")
if shared.args.disk:
params.append(f"offload_folder='{shared.args.disk_cache_dir}'")
command = f"{command}(Path(f'models/{shared.model_name}'), {', '.join(set(params))})"
model = eval(command)
# Loading the tokenizer
if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path(f"models/gpt-j-6B/").exists():
tokenizer = AutoTokenizer.from_pretrained(Path("models/gpt-j-6B/"))
else:
tokenizer = AutoTokenizer.from_pretrained(Path(f"models/{shared.model_name}/"))
tokenizer.truncation_side = 'left'
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
return model, tokenizer
def load_soft_prompt(name):
if name == 'None':
shared.soft_prompt = False
shared.soft_prompt_tensor = None
else:
with zipfile.ZipFile(Path(f'softprompts/{name}.zip')) as zf:
zf.extract('tensor.npy')
zf.extract('meta.json')
j = json.loads(open('meta.json', 'r').read())
print(f"\nLoading the softprompt \"{name}\".")
for field in j:
if field != 'name':
if type(j[field]) is list:
print(f"{field}: {', '.join(j[field])}")
else:
print(f"{field}: {j[field]}")
print()
tensor = np.load('tensor.npy')
Path('tensor.npy').unlink()
Path('meta.json').unlink()
tensor = torch.Tensor(tensor).to(device=shared.model.device, dtype=shared.model.dtype)
tensor = torch.reshape(tensor, (1, tensor.shape[0], tensor.shape[1]))
shared.soft_prompt = True
shared.soft_prompt_tensor = tensor
return name

62
modules/shared.py Normal file
View File

@ -0,0 +1,62 @@
import argparse
model = None
tokenizer = None
model_name = ""
soft_prompt_tensor = None
soft_prompt = False
# Chat variables
history = {'internal': [], 'visible': []}
character = 'None'
stop_everything = False
settings = {
'max_new_tokens': 200,
'max_new_tokens_min': 1,
'max_new_tokens_max': 2000,
'preset': 'NovelAI-Sphinx Moth',
'name1': 'Person 1',
'name2': 'Person 2',
'context': 'This is a conversation between two people.',
'prompt': 'Common sense questions and answers\n\nQuestion: \nFactual answer:',
'prompt_gpt4chan': '-----\n--- 865467536\nInput text\n--- 865467537\n',
'stop_at_newline': True,
'chat_prompt_size': 2048,
'chat_prompt_size_min': 0,
'chat_prompt_size_max': 2048,
'preset_pygmalion': 'Pygmalion',
'name1_pygmalion': 'You',
'name2_pygmalion': 'Kawaii',
'context_pygmalion': "Kawaii's persona: Kawaii is a cheerful person who loves to make others smile. She is an optimist who loves to spread happiness and positivity wherever she goes.\n<START>",
'stop_at_newline_pygmalion': False,
}
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
parser.add_argument('--model', type=str, help='Name of the model to load by default.')
parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode.')
parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.')
parser.add_argument('--picture', action='store_true', help='Adds an ability to send pictures in chat UI modes. Captions are generated by BLIP.')
parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.')
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
parser.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.')
parser.add_argument('--disk-cache-dir', type=str, default="cache", help='Directory to save the disk cache to. Defaults to "cache".')
parser.add_argument('--gpu-memory', type=int, help='Maximum GPU memory in GiB to allocate. This is useful if you get out of memory errors while trying to generate text. Must be an integer number.')
parser.add_argument('--cpu-memory', type=int, help='Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.')
parser.add_argument('--flexgen', action='store_true', help='Enable the use of FlexGen offloading.')
parser.add_argument('--percent', nargs="+", type=int, default=[0, 100, 100, 0, 100, 0], help='FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0).')
parser.add_argument("--compress-weight", action="store_true", help="FlexGen: activate weight compression.")
parser.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.')
parser.add_argument('--nvme-offload-dir', type=str, help='DeepSpeed: Directory to use for ZeRO-3 NVME offloading.')
parser.add_argument('--local_rank', type=int, default=0, help='DeepSpeed: Optional argument for distributed setups.')
parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This improves the text generation performance.')
parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
parser.add_argument('--extensions', type=str, help='The list of extensions to load. If you want to load more than one extension, write the names separated by commas and between quotation marks, "like,this".')
parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.')
parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.')
parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')
parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
args = parser.parse_args()

View File

@ -8,6 +8,7 @@ https://github.com/PygmalionAI/gradio-ui/
import torch import torch
import transformers import transformers
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria): class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
def __init__(self, sentinel_token_ids: torch.LongTensor, def __init__(self, sentinel_token_ids: torch.LongTensor,

178
modules/text_generation.py Normal file
View File

@ -0,0 +1,178 @@
import re
import time
import numpy as np
import torch
import transformers
from tqdm import tqdm
import modules.shared as shared
from modules.extensions import apply_extensions
from modules.html_generator import generate_4chan_html, generate_basic_html
from modules.models import local_rank
from modules.stopping_criteria import _SentinelTokenStoppingCriteria
def get_max_prompt_length(tokens):
max_length = 2048-tokens
if shared.soft_prompt:
max_length -= shared.soft_prompt_tensor.shape[1]
return max_length
def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
if shared.args.cpu or shared.args.flexgen:
return input_ids
elif shared.args.deepspeed:
return input_ids.to(device=local_rank)
else:
return input_ids.cuda()
def decode(output_ids):
reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
reply = reply.replace(r'<|endoftext|>', '')
return reply
def generate_softprompt_input_tensors(input_ids):
inputs_embeds = shared.model.transformer.wte(input_ids)
inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device)
filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
return inputs_embeds, filler_input_ids
# Removes empty replies from gpt4chan outputs
def fix_gpt4chan(s):
for i in range(10):
s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
s = re.sub("--- [0-9]*\n *\n---", "---", s)
s = re.sub("--- [0-9]*\n\n\n---", "---", s)
return s
# Fix the LaTeX equations in galactica
def fix_galactica(s):
s = s.replace(r'\[', r'$')
s = s.replace(r'\]', r'$')
s = s.replace(r'\(', r'$')
s = s.replace(r'\)', r'$')
s = s.replace(r'$$', r'$')
s = re.sub(r'\n', r'\n\n', s)
s = re.sub(r"\n{3,}", "\n\n", s)
return s
def formatted_outputs(reply, model_name):
if not (shared.args.chat or shared.args.cai_chat):
if shared.model_name.lower().startswith('galactica'):
reply = fix_galactica(reply)
return reply, reply, generate_basic_html(reply)
elif shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
reply = fix_gpt4chan(reply)
return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
else:
return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
else:
return reply
def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
original_question = question
if not (shared.args.chat or shared.args.cai_chat):
question = apply_extensions(question, "input")
if shared.args.verbose:
print(f"\n\n{question}\n--------------------\n")
input_ids = encode(question, tokens)
cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()"
if not shared.args.flexgen:
n = shared.tokenizer.eos_token_id if eos_token is None else shared.tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
else:
n = shared.tokenizer(eos_token).input_ids[0] if eos_token else None
if stopping_string is not None:
# The stopping_criteria code below was copied from
# https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
t = encode(stopping_string, 0, add_special_tokens=False)
stopping_criteria_list = transformers.StoppingCriteriaList([
_SentinelTokenStoppingCriteria(
sentinel_token_ids=t,
starting_idx=len(input_ids[0])
)
])
else:
stopping_criteria_list = None
if not shared.args.flexgen:
generate_params = [
f"eos_token_id={n}",
f"stopping_criteria=stopping_criteria_list",
f"do_sample={do_sample}",
f"temperature={temperature}",
f"top_p={top_p}",
f"typical_p={typical_p}",
f"repetition_penalty={repetition_penalty}",
f"top_k={top_k}",
f"min_length={min_length if shared.args.no_stream else 0}",
f"no_repeat_ngram_size={no_repeat_ngram_size}",
f"num_beams={num_beams}",
f"penalty_alpha={penalty_alpha}",
f"length_penalty={length_penalty}",
f"early_stopping={early_stopping}",
]
else:
generate_params = [
f"do_sample={do_sample}",
f"temperature={temperature}",
f"stop={n}",
]
if shared.args.deepspeed:
generate_params.append("synced_gpus=True")
if shared.args.no_stream:
generate_params.append(f"max_new_tokens=tokens")
else:
generate_params.append(f"max_new_tokens=8")
if shared.soft_prompt:
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
generate_params.insert(0, "inputs_embeds=inputs_embeds")
generate_params.insert(0, "filler_input_ids")
else:
generate_params.insert(0, "input_ids")
# Generate the entire reply at once
if shared.args.no_stream:
t0 = time.time()
with torch.no_grad():
output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
reply = decode(output)
if not (shared.args.chat or shared.args.cai_chat):
reply = original_question + apply_extensions(reply[len(question):], "output")
yield formatted_outputs(reply, shared.model_name)
t1 = time.time()
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output)-len(input_ids[0])} tokens)")
# Generate the reply 8 tokens at a time
else:
yield formatted_outputs(original_question, shared.model_name)
for i in tqdm(range(tokens//8+1)):
with torch.no_grad():
output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
reply = decode(output)
if not (shared.args.chat or shared.args.cai_chat):
reply = original_question + apply_extensions(reply[len(question):], "output")
yield formatted_outputs(reply, shared.model_name)
if not shared.args.flexgen:
input_ids = torch.reshape(output, (1, output.shape[0]))
else:
input_ids = np.reshape(output, (1, output.shape[0]))
if shared.soft_prompt:
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
if output[-1] == n:
break

1025
server.py

File diff suppressed because it is too large Load Diff