Merge branch 'main' into UsamaKenway-main

This commit is contained in:
oobabooga 2023-04-10 11:14:03 -03:00
commit c6e9ba20a4
11 changed files with 436 additions and 149 deletions

View File

@ -215,6 +215,8 @@ Optionally, you can use the following command-line flags:
| `--load-in-8bit` | Load the model with 8-bit precision.|
| `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
| `--no-cache` | Set `use_cache` to False while generating text. This reduces the VRAM usage a bit with a performance cost. |
| `--xformers` | Use xformer's memory efficient attention. This should increase your tokens/s. |
| `--sdp-attention` | Use torch 2.0's sdp attention. |
#### llama.cpp

View File

@ -22,10 +22,10 @@ server = "127.0.0.1"
params = {
'max_new_tokens': 200,
'do_sample': True,
'temperature': 0.5,
'top_p': 0.9,
'temperature': 0.72,
'top_p': 0.73,
'typical_p': 1,
'repetition_penalty': 1.05,
'repetition_penalty': 1.1,
'encoder_repetition_penalty': 1.0,
'top_k': 0,
'min_length': 0,

View File

@ -19,6 +19,7 @@ import requests
import tqdm
from tqdm.contrib.concurrent import thread_map
parser = argparse.ArgumentParser()
parser.add_argument('MODEL', type=str, default=None, nargs='?')
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
@ -30,40 +31,6 @@ parser.add_argument('--check', action='store_true', help='Validates the checksum
args = parser.parse_args()
def get_file(url, output_folder):
filename = Path(url.rsplit('/', 1)[1])
output_path = output_folder / filename
if output_path.exists() and not args.clean:
# Check if the file has already been downloaded completely
r = requests.get(url, stream=True)
total_size = int(r.headers.get('content-length', 0))
if output_path.stat().st_size >= total_size:
return
# Otherwise, resume the download from where it left off
headers = {'Range': f'bytes={output_path.stat().st_size}-'}
mode = 'ab'
else:
headers = {}
mode = 'wb'
r = requests.get(url, stream=True, headers=headers)
with open(output_path, mode) as f:
total_size = int(r.headers.get('content-length', 0))
block_size = 1024
with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
for data in r.iter_content(block_size):
t.update(len(data))
f.write(data)
def sanitize_branch_name(branch_name):
pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
if pattern.match(branch_name):
return branch_name
else:
raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
def select_model_from_default_options():
models = {
"OPT 6.7B": ("facebook", "opt-6.7b", "main"),
@ -110,7 +77,20 @@ EleutherAI/pythia-1.4b-deduped
return model, branch
def get_download_links_from_huggingface(model, branch):
def sanitize_model_and_branch_names(model, branch):
if model[-1] == '/':
model = model[:-1]
if branch is None:
branch = "main"
else:
pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
if not pattern.match(branch):
raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
return model, branch
def get_download_links_from_huggingface(model, branch, text_only=False):
base = "https://huggingface.co"
page = f"/api/models/{model}/tree/{branch}?cursor="
cursor = b""
@ -149,7 +129,7 @@ def get_download_links_from_huggingface(model, branch):
links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
classifications.append('text')
continue
if not args.text_only:
if not text_only:
links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
if is_safetensors:
has_safetensors = True
@ -177,41 +157,67 @@ def get_download_links_from_huggingface(model, branch):
return links, sha256, is_lora
def download_files(file_list, output_folder, num_threads=8):
thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, disable=True)
if __name__ == '__main__':
model = args.MODEL
branch = args.branch
if model is None:
model, branch = select_model_from_default_options()
else:
if model[-1] == '/':
model = model[:-1]
branch = args.branch
if branch is None:
branch = "main"
else:
try:
branch = sanitize_branch_name(branch)
except ValueError as err_branch:
print(f"Error: {err_branch}")
sys.exit()
links, sha256, is_lora = get_download_links_from_huggingface(model, branch)
if args.output is not None:
base_folder = args.output
else:
def get_output_folder(model, branch, is_lora, base_folder=None):
if base_folder is None:
base_folder = 'models' if not is_lora else 'loras'
output_folder = f"{'_'.join(model.split('/')[-2:])}"
if branch != 'main':
output_folder += f'_{branch}'
output_folder = Path(base_folder) / output_folder
return output_folder
if args.check:
def get_single_file(url, output_folder, start_from_scratch=False):
filename = Path(url.rsplit('/', 1)[1])
output_path = output_folder / filename
if output_path.exists() and not start_from_scratch:
# Check if the file has already been downloaded completely
r = requests.get(url, stream=True)
total_size = int(r.headers.get('content-length', 0))
if output_path.stat().st_size >= total_size:
return
# Otherwise, resume the download from where it left off
headers = {'Range': f'bytes={output_path.stat().st_size}-'}
mode = 'ab'
else:
headers = {}
mode = 'wb'
r = requests.get(url, stream=True, headers=headers)
with open(output_path, mode) as f:
total_size = int(r.headers.get('content-length', 0))
block_size = 1024
with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
for data in r.iter_content(block_size):
t.update(len(data))
f.write(data)
def start_download_threads(file_list, output_folder, start_from_scratch=False, threads=1):
thread_map(lambda url: get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True)
def download_model_files(model, branch, links, sha256, output_folder, start_from_scratch=False, threads=1):
# Creating the folder and writing the metadata
if not output_folder.exists():
output_folder.mkdir()
with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
f.write(f'url: https://huggingface.co/{model}\n')
f.write(f'branch: {branch}\n')
f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
sha256_str = ''
for i in range(len(sha256)):
sha256_str += f' {sha256[i][1]} {sha256[i][0]}\n'
if sha256_str != '':
f.write(f'sha256sum:\n{sha256_str}')
# Downloading the files
print(f"Downloading the model to {output_folder}")
start_download_threads(links, output_folder, start_from_scratch=start_from_scratch, threads=threads)
def check_model_files(model, branch, links, sha256, output_folder):
# Validate the checksums
validated = True
for i in range(len(sha256)):
@ -236,21 +242,29 @@ if __name__ == '__main__':
else:
print('[-] Invalid checksums. Rerun download-model.py with the --clean flag.')
if __name__ == '__main__':
branch = args.branch
model = args.MODEL
if model is None:
model, branch = select_model_from_default_options()
# Cleaning up the model/branch names
try:
model, branch = sanitize_model_and_branch_names(model, branch)
except ValueError as err_branch:
print(f"Error: {err_branch}")
sys.exit()
# Getting the download links from Hugging Face
links, sha256, is_lora = get_download_links_from_huggingface(model, branch, text_only=args.text_only)
# Getting the output folder
output_folder = get_output_folder(model, branch, is_lora, base_folder=args.output)
if args.check:
# Check previously downloaded files
check_model_files(model, branch, links, sha256, output_folder)
else:
# Creating the folder and writing the metadata
if not output_folder.exists():
output_folder.mkdir()
with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
f.write(f'url: https://huggingface.co/{model}\n')
f.write(f'branch: {branch}\n')
f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
sha256_str = ''
for i in range(len(sha256)):
sha256_str += f' {sha256[i][1]} {sha256[i][0]}\n'
if sha256_str != '':
f.write(f'sha256sum:\n{sha256_str}')
# Downloading the files
print(f"Downloading the model to {output_folder}")
download_files(links, output_folder, args.threads)
# Download files
download_model_files(model, branch, links, sha256, output_folder, threads=args.threads)

View File

@ -1,8 +1,23 @@
import gradio as gr
import os
# get the current directory of the script
current_dir = os.path.dirname(os.path.abspath(__file__))
# check if the bias_options.txt file exists, if not, create it
bias_file = os.path.join(current_dir, "bias_options.txt")
if not os.path.isfile(bias_file):
with open(bias_file, "w") as f:
f.write("*I am so happy*\n*I am so sad*\n*I am so excited*\n*I am so bored*\n*I am so angry*")
# read bias options from the text file
with open(bias_file, "r") as f:
bias_options = [line.strip() for line in f.readlines()]
params = {
"activate": True,
"bias string": " *I am so happy*",
"use custom string": False,
}
@ -11,7 +26,6 @@ def input_modifier(string):
This function is applied to your text inputs before
they are fed into the model.
"""
return string
@ -19,7 +33,6 @@ def output_modifier(string):
"""
This function is applied to the model outputs.
"""
return string
@ -29,8 +42,10 @@ def bot_prefix_modifier(string):
the prefix text for the Bot and can be used to bias its
behavior.
"""
if params['activate']:
if params['use custom string']:
return f'{string} {params["custom string"].strip()} '
else:
return f'{string} {params["bias string"].strip()} '
else:
return string
@ -39,8 +54,29 @@ def bot_prefix_modifier(string):
def ui():
# Gradio elements
activate = gr.Checkbox(value=params['activate'], label='Activate character bias')
string = gr.Textbox(value=params["bias string"], label='Character bias')
dropdown_string = gr.Dropdown(choices=bias_options, value=params["bias string"], label='Character bias', info='To edit the options in this dropdown edit the "bias_options.txt" file')
use_custom_string = gr.Checkbox(value=False, label='Use custom bias textbox instead of dropdown')
custom_string = gr.Textbox(value="", placeholder="Enter custom bias string", label="Custom Character Bias", info='To use this textbox activate the checkbox above')
# Event functions to update the parameters in the backend
string.change(lambda x: params.update({"bias string": x}), string, None)
def update_bias_string(x):
if x:
params.update({"bias string": x})
else:
params.update({"bias string": dropdown_string.get()})
return x
def update_custom_string(x):
params.update({"custom string": x})
dropdown_string.change(update_bias_string, dropdown_string, None)
custom_string.change(update_custom_string, custom_string, None)
activate.change(lambda x: params.update({"activate": x}), activate, None)
use_custom_string.change(lambda x: params.update({"use custom string": x}), use_custom_string, None)
# Group elements together depending on the selected option
def bias_string_group():
if use_custom_string.value:
return gr.Group([use_custom_string, custom_string])
else:
return dropdown_string

View File

@ -100,10 +100,10 @@ def load_quantized(model_name):
found_safetensors = list(path_to_model.glob("*.safetensors"))
pt_path = None
if len(found_pts) == 1:
pt_path = found_pts[0]
elif len(found_safetensors) == 1:
pt_path = found_safetensors[0]
if len(found_pts) > 0:
pt_path = found_pts[-1]
elif len(found_safetensors) > 0:
pt_path = found_safetensors[-1]
else:
if path_to_model.name.lower().startswith('llama-7b'):
pt_model = f'llama-7b-{shared.args.wbits}bit'
@ -119,13 +119,14 @@ def load_quantized(model_name):
# Try to find the .safetensors or .pt both in the model dir and in the subfolder
for path in [Path(p + ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]:
if path.exists():
print(f"Found {path}")
pt_path = path
break
if not pt_path:
print("Could not find the quantized model in .pt or .safetensors format, exiting...")
exit()
else:
print(f"Found the following quantized model: {pt_path}")
# qwopqwop200's offload
if model_type == 'llama' and shared.args.pre_layer:

View File

@ -22,6 +22,7 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False
end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
_continue = kwargs['_continue'] if '_continue' in kwargs else False
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
rows = [f"{context.strip()}\n"]
@ -39,6 +40,9 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
i = len(shared.history['internal']) - 1
while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
if _continue and i == len(shared.history['internal']) - 1:
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}")
else:
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n")
string = shared.history['internal'][i][0]
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
@ -48,6 +52,8 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
if impersonate:
rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
limit = 2
elif _continue:
limit = 3
else:
# Adding the user message
user_input = fix_newlines(user_input)
@ -99,7 +105,7 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
return reply, next_character_found
def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False, _continue=False):
if mode == 'instruct':
stopping_strings = [f"\n{name1}", f"\n{name2}"]
else:
@ -107,6 +113,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
# Defining some variables
cumulative_reply = ''
last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]] if _continue else None
just_started = True
name1_original = name1
visible_text = custom_generate_chat_prompt = None
@ -124,17 +131,22 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
if visible_text is None:
visible_text = text
if not _continue:
text = apply_extensions(text, "input")
# Generating the prompt
kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'}
kwargs = {
'end_of_turn': end_of_turn,
'is_instruct': mode == 'instruct',
'_continue': _continue
}
if custom_generate_chat_prompt is None:
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
else:
prompt = custom_generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
# Yield *Is typing...*
if not regenerate:
if not any((regenerate, _continue)):
yield shared.history['visible'] + [[visible_text, shared.processing_message]]
# Generate
@ -154,9 +166,15 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
return shared.history['visible']
if just_started:
just_started = False
if not _continue:
shared.history['internal'].append(['', ''])
shared.history['visible'].append(['', ''])
if _continue:
sep = list(map(lambda x : ' ' if x[-1] != ' ' else '', last_reply))
shared.history['internal'][-1] = [text, f'{last_reply[0]}{sep[0]}{reply}']
shared.history['visible'][-1] = [visible_text, f'{last_reply[1]}{sep[1]}{visible_reply}']
else:
shared.history['internal'][-1] = [text, reply]
shared.history['visible'][-1] = [visible_text, visible_reply]
if not shared.args.no_stream:
@ -220,6 +238,16 @@ def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
def continue_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
else:
# Yield ' ...'
yield chat_html_wrapper(shared.history['visible'][:-1] + [[shared.history['visible'][-1][0], shared.history['visible'][-1][1] + ' ...']], name1, name2, mode)
for history in chatbot_wrapper(shared.history['internal'][-1][0], generate_state, name1, name2, context, mode, end_of_turn, _continue=True):
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
def remove_last_message(name1, name2, mode):
if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
last = shared.history['visible'].pop()
@ -257,6 +285,9 @@ def clear_chat_log(name1, name2, greeting, mode):
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
# Save cleared logs
save_history(timestamp=False)
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
@ -406,10 +437,15 @@ def load_character(character, name1, name2, mode):
if Path(f'logs/{shared.character}_persistent.json').exists():
load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
elif greeting != "":
else:
# Insert greeting if it exists
if greeting != "":
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
# Create .json log files since they don't already exist
save_history(timestamp=False)
return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)

View File

@ -0,0 +1,176 @@
import math
import sys
import torch
import torch.nn as nn
import transformers.models.llama.modeling_llama
from typing import Optional
from typing import Tuple
import modules.shared as shared
if shared.args.xformers:
try:
import xformers.ops
except Exception:
print("🔴 xformers not found! Please install it before trying to use it.", file=sys.stderr)
def hijack_llama_attention():
if shared.args.xformers:
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
print("Replaced attention with xformers_attention")
elif shared.args.sdp_attention:
transformers.models.llama.modeling_llama.LlamaAttention.forward = sdp_attention_forward
print("Replaced attention with sdp_attention")
def xformers_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
#We only apply xformers optimizations if we don't need to output the whole attention matrix
if not output_attentions:
dtype = query_states.dtype
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
#This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
#We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
# input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
else:
# input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=xformers.ops.LowerTriangularMask())
attn_weights = None
else:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights, past_key_value
def sdp_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
#We only apply sdp attention if we don't need to output the whole attention matrix
if not output_attentions:
attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False)
attn_weights = None
else:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights, past_key_value

View File

@ -14,6 +14,7 @@ from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
BitsAndBytesConfig, LlamaTokenizer)
import modules.shared as shared
from modules import llama_attn_hijack
transformers.logging.set_verbosity_error()
@ -169,11 +170,23 @@ def load_model(model_name):
model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
# Hijack attention with xformers
if any((shared.args.xformers, shared.args.sdp_attention)):
llama_attn_hijack.hijack_llama_attention()
# Loading the tokenizer
if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
elif type(model) is transformers.LlamaForCausalLM:
tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True)
# Leaving this here until the LLaMA tokenizer gets figured out.
# For some people this fixes things, for others it causes an error.
try:
tokenizer.eos_token_id = 2
tokenizer.bos_token_id = 1
tokenizer.pad_token_id = 0
except:
pass
else:
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))
tokenizer.truncation_side = 'left'

View File

@ -98,6 +98,8 @@ parser.add_argument('--disk-cache-dir', type=str, default="cache", help='Directo
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
parser.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. This reduces the VRAM usage a bit at a performance cost.')
parser.add_argument('--xformers', action='store_true', help="Use xformer's memory efficient attention. This should increase your tokens/s.")
parser.add_argument('--sdp-attention', action='store_true', help="Use torch 2.0's sdp attention.")
# llama.cpp
parser.add_argument('--threads', type=int, default=0, help='Number of threads to use in llama.cpp.')

View File

@ -1,5 +1,4 @@
accelerate==0.18.0
bitsandbytes==0.37.2
datasets
flexgen==0.1.7
gradio==3.24.1
@ -14,3 +13,6 @@ sentencepiece
pyyaml
tqdm
git+https://github.com/huggingface/transformers
bitsandbytes==0.37.2; platform_system != "Windows"
llama-cpp-python==0.1.30; platform_system != "Windows"
https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.30/llama_cpp_python-0.1.30-cp310-cp310-win_amd64.whl; platform_system == "Windows"

View File

@ -394,8 +394,9 @@ def create_interface():
shared.gradio['Generate'] = gr.Button('Generate', elem_id='Generate')
shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop")
with gr.Row():
shared.gradio['Impersonate'] = gr.Button('Impersonate')
shared.gradio['Regenerate'] = gr.Button('Regenerate')
shared.gradio['Continue'] = gr.Button('Continue')
shared.gradio['Impersonate'] = gr.Button('Impersonate')
with gr.Row():
shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
@ -467,53 +468,57 @@ def create_interface():
gen_events.append(shared.gradio['Generate'].click(
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
)
gen_events.append(shared.gradio['textbox'].submit(
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
)
gen_events.append(shared.gradio['Regenerate'].click(
chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
)
gen_events.append(shared.gradio['Continue'].click(
chat.continue_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
)
shared.gradio['Replace last reply'].click(
chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'Chat mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
shared.gradio['Clear history-confirm'].click(
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr).then(
chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'Chat mode']], shared.gradio['display']).then(
lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
shared.gradio['Stop'].click(
stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None).then(
chat.redraw_html, reload_inputs, [shared.gradio['display']])
stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None).then(
chat.redraw_html, reload_inputs, shared.gradio['display'])
shared.gradio['Chat mode'].change(
lambda x: gr.update(visible=x == 'instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates']).then(
chat.redraw_html, reload_inputs, [shared.gradio['display']])
chat.redraw_html, reload_inputs, shared.gradio['display'])
shared.gradio['Instruction templates'].change(
lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]).then(
chat.redraw_html, reload_inputs, [shared.gradio['display']])
chat.redraw_html, reload_inputs, shared.gradio['display'])
shared.gradio['upload_chat_history'].upload(
chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], []).then(
chat.redraw_html, reload_inputs, [shared.gradio['display']])
chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], None).then(
chat.redraw_html, reload_inputs, shared.gradio['display'])
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, None, shared.gradio['textbox'], show_progress=shared.args.no_stream)
shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']])
shared.gradio['download_button'].click(chat.save_history, inputs=None, outputs=[shared.gradio['download']])
shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
@ -521,7 +526,7 @@ def create_interface():
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
shared.gradio['interface'].load(chat.load_default_history, [shared.gradio[k] for k in ['name1', 'name2']], None)
shared.gradio['interface'].load(chat.redraw_html, reload_inputs, [shared.gradio['display']], show_progress=True)
shared.gradio['interface'].load(chat.redraw_html, reload_inputs, shared.gradio['display'], show_progress=True)
elif shared.args.notebook:
with gr.Tab("Text generation", elem_id="main"):
@ -555,7 +560,7 @@ def create_interface():
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
else:
@ -589,7 +594,7 @@ def create_interface():
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
with gr.Tab("Model", elem_id="model-tab"):