mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-29 10:59:32 +01:00
Merge branch 'main' into UsamaKenway-main
This commit is contained in:
commit
c6e9ba20a4
@ -215,6 +215,8 @@ Optionally, you can use the following command-line flags:
|
|||||||
| `--load-in-8bit` | Load the model with 8-bit precision.|
|
| `--load-in-8bit` | Load the model with 8-bit precision.|
|
||||||
| `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
|
| `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
|
||||||
| `--no-cache` | Set `use_cache` to False while generating text. This reduces the VRAM usage a bit with a performance cost. |
|
| `--no-cache` | Set `use_cache` to False while generating text. This reduces the VRAM usage a bit with a performance cost. |
|
||||||
|
| `--xformers` | Use xformer's memory efficient attention. This should increase your tokens/s. |
|
||||||
|
| `--sdp-attention` | Use torch 2.0's sdp attention. |
|
||||||
|
|
||||||
#### llama.cpp
|
#### llama.cpp
|
||||||
|
|
||||||
|
@ -22,10 +22,10 @@ server = "127.0.0.1"
|
|||||||
params = {
|
params = {
|
||||||
'max_new_tokens': 200,
|
'max_new_tokens': 200,
|
||||||
'do_sample': True,
|
'do_sample': True,
|
||||||
'temperature': 0.5,
|
'temperature': 0.72,
|
||||||
'top_p': 0.9,
|
'top_p': 0.73,
|
||||||
'typical_p': 1,
|
'typical_p': 1,
|
||||||
'repetition_penalty': 1.05,
|
'repetition_penalty': 1.1,
|
||||||
'encoder_repetition_penalty': 1.0,
|
'encoder_repetition_penalty': 1.0,
|
||||||
'top_k': 0,
|
'top_k': 0,
|
||||||
'min_length': 0,
|
'min_length': 0,
|
||||||
|
@ -19,6 +19,7 @@ import requests
|
|||||||
import tqdm
|
import tqdm
|
||||||
from tqdm.contrib.concurrent import thread_map
|
from tqdm.contrib.concurrent import thread_map
|
||||||
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('MODEL', type=str, default=None, nargs='?')
|
parser.add_argument('MODEL', type=str, default=None, nargs='?')
|
||||||
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
|
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
|
||||||
@ -30,40 +31,6 @@ parser.add_argument('--check', action='store_true', help='Validates the checksum
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def get_file(url, output_folder):
|
|
||||||
filename = Path(url.rsplit('/', 1)[1])
|
|
||||||
output_path = output_folder / filename
|
|
||||||
if output_path.exists() and not args.clean:
|
|
||||||
# Check if the file has already been downloaded completely
|
|
||||||
r = requests.get(url, stream=True)
|
|
||||||
total_size = int(r.headers.get('content-length', 0))
|
|
||||||
if output_path.stat().st_size >= total_size:
|
|
||||||
return
|
|
||||||
# Otherwise, resume the download from where it left off
|
|
||||||
headers = {'Range': f'bytes={output_path.stat().st_size}-'}
|
|
||||||
mode = 'ab'
|
|
||||||
else:
|
|
||||||
headers = {}
|
|
||||||
mode = 'wb'
|
|
||||||
|
|
||||||
r = requests.get(url, stream=True, headers=headers)
|
|
||||||
with open(output_path, mode) as f:
|
|
||||||
total_size = int(r.headers.get('content-length', 0))
|
|
||||||
block_size = 1024
|
|
||||||
with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
|
|
||||||
for data in r.iter_content(block_size):
|
|
||||||
t.update(len(data))
|
|
||||||
f.write(data)
|
|
||||||
|
|
||||||
|
|
||||||
def sanitize_branch_name(branch_name):
|
|
||||||
pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
|
|
||||||
if pattern.match(branch_name):
|
|
||||||
return branch_name
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
|
|
||||||
|
|
||||||
|
|
||||||
def select_model_from_default_options():
|
def select_model_from_default_options():
|
||||||
models = {
|
models = {
|
||||||
"OPT 6.7B": ("facebook", "opt-6.7b", "main"),
|
"OPT 6.7B": ("facebook", "opt-6.7b", "main"),
|
||||||
@ -110,7 +77,20 @@ EleutherAI/pythia-1.4b-deduped
|
|||||||
return model, branch
|
return model, branch
|
||||||
|
|
||||||
|
|
||||||
def get_download_links_from_huggingface(model, branch):
|
def sanitize_model_and_branch_names(model, branch):
|
||||||
|
if model[-1] == '/':
|
||||||
|
model = model[:-1]
|
||||||
|
if branch is None:
|
||||||
|
branch = "main"
|
||||||
|
else:
|
||||||
|
pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
|
||||||
|
if not pattern.match(branch):
|
||||||
|
raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
|
||||||
|
|
||||||
|
return model, branch
|
||||||
|
|
||||||
|
|
||||||
|
def get_download_links_from_huggingface(model, branch, text_only=False):
|
||||||
base = "https://huggingface.co"
|
base = "https://huggingface.co"
|
||||||
page = f"/api/models/{model}/tree/{branch}?cursor="
|
page = f"/api/models/{model}/tree/{branch}?cursor="
|
||||||
cursor = b""
|
cursor = b""
|
||||||
@ -149,7 +129,7 @@ def get_download_links_from_huggingface(model, branch):
|
|||||||
links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
|
links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
|
||||||
classifications.append('text')
|
classifications.append('text')
|
||||||
continue
|
continue
|
||||||
if not args.text_only:
|
if not text_only:
|
||||||
links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
|
links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
|
||||||
if is_safetensors:
|
if is_safetensors:
|
||||||
has_safetensors = True
|
has_safetensors = True
|
||||||
@ -177,80 +157,114 @@ def get_download_links_from_huggingface(model, branch):
|
|||||||
return links, sha256, is_lora
|
return links, sha256, is_lora
|
||||||
|
|
||||||
|
|
||||||
def download_files(file_list, output_folder, num_threads=8):
|
def get_output_folder(model, branch, is_lora, base_folder=None):
|
||||||
thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, disable=True)
|
if base_folder is None:
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
model = args.MODEL
|
|
||||||
branch = args.branch
|
|
||||||
if model is None:
|
|
||||||
model, branch = select_model_from_default_options()
|
|
||||||
else:
|
|
||||||
if model[-1] == '/':
|
|
||||||
model = model[:-1]
|
|
||||||
branch = args.branch
|
|
||||||
if branch is None:
|
|
||||||
branch = "main"
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
branch = sanitize_branch_name(branch)
|
|
||||||
except ValueError as err_branch:
|
|
||||||
print(f"Error: {err_branch}")
|
|
||||||
sys.exit()
|
|
||||||
|
|
||||||
links, sha256, is_lora = get_download_links_from_huggingface(model, branch)
|
|
||||||
|
|
||||||
if args.output is not None:
|
|
||||||
base_folder = args.output
|
|
||||||
else:
|
|
||||||
base_folder = 'models' if not is_lora else 'loras'
|
base_folder = 'models' if not is_lora else 'loras'
|
||||||
|
|
||||||
output_folder = f"{'_'.join(model.split('/')[-2:])}"
|
output_folder = f"{'_'.join(model.split('/')[-2:])}"
|
||||||
if branch != 'main':
|
if branch != 'main':
|
||||||
output_folder += f'_{branch}'
|
output_folder += f'_{branch}'
|
||||||
output_folder = Path(base_folder) / output_folder
|
output_folder = Path(base_folder) / output_folder
|
||||||
|
return output_folder
|
||||||
|
|
||||||
|
|
||||||
|
def get_single_file(url, output_folder, start_from_scratch=False):
|
||||||
|
filename = Path(url.rsplit('/', 1)[1])
|
||||||
|
output_path = output_folder / filename
|
||||||
|
if output_path.exists() and not start_from_scratch:
|
||||||
|
# Check if the file has already been downloaded completely
|
||||||
|
r = requests.get(url, stream=True)
|
||||||
|
total_size = int(r.headers.get('content-length', 0))
|
||||||
|
if output_path.stat().st_size >= total_size:
|
||||||
|
return
|
||||||
|
# Otherwise, resume the download from where it left off
|
||||||
|
headers = {'Range': f'bytes={output_path.stat().st_size}-'}
|
||||||
|
mode = 'ab'
|
||||||
|
else:
|
||||||
|
headers = {}
|
||||||
|
mode = 'wb'
|
||||||
|
|
||||||
|
r = requests.get(url, stream=True, headers=headers)
|
||||||
|
with open(output_path, mode) as f:
|
||||||
|
total_size = int(r.headers.get('content-length', 0))
|
||||||
|
block_size = 1024
|
||||||
|
with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
|
||||||
|
for data in r.iter_content(block_size):
|
||||||
|
t.update(len(data))
|
||||||
|
f.write(data)
|
||||||
|
|
||||||
|
|
||||||
|
def start_download_threads(file_list, output_folder, start_from_scratch=False, threads=1):
|
||||||
|
thread_map(lambda url: get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True)
|
||||||
|
|
||||||
|
|
||||||
|
def download_model_files(model, branch, links, sha256, output_folder, start_from_scratch=False, threads=1):
|
||||||
|
# Creating the folder and writing the metadata
|
||||||
|
if not output_folder.exists():
|
||||||
|
output_folder.mkdir()
|
||||||
|
with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
|
||||||
|
f.write(f'url: https://huggingface.co/{model}\n')
|
||||||
|
f.write(f'branch: {branch}\n')
|
||||||
|
f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
|
||||||
|
sha256_str = ''
|
||||||
|
for i in range(len(sha256)):
|
||||||
|
sha256_str += f' {sha256[i][1]} {sha256[i][0]}\n'
|
||||||
|
if sha256_str != '':
|
||||||
|
f.write(f'sha256sum:\n{sha256_str}')
|
||||||
|
|
||||||
|
# Downloading the files
|
||||||
|
print(f"Downloading the model to {output_folder}")
|
||||||
|
start_download_threads(links, output_folder, start_from_scratch=start_from_scratch, threads=threads)
|
||||||
|
|
||||||
|
|
||||||
|
def check_model_files(model, branch, links, sha256, output_folder):
|
||||||
|
# Validate the checksums
|
||||||
|
validated = True
|
||||||
|
for i in range(len(sha256)):
|
||||||
|
fpath = (output_folder / sha256[i][0])
|
||||||
|
|
||||||
|
if not fpath.exists():
|
||||||
|
print(f"The following file is missing: {fpath}")
|
||||||
|
validated = False
|
||||||
|
continue
|
||||||
|
|
||||||
|
with open(output_folder / sha256[i][0], "rb") as f:
|
||||||
|
bytes = f.read()
|
||||||
|
file_hash = hashlib.sha256(bytes).hexdigest()
|
||||||
|
if file_hash != sha256[i][1]:
|
||||||
|
print(f'Checksum failed: {sha256[i][0]} {sha256[i][1]}')
|
||||||
|
validated = False
|
||||||
|
else:
|
||||||
|
print(f'Checksum validated: {sha256[i][0]} {sha256[i][1]}')
|
||||||
|
|
||||||
|
if validated:
|
||||||
|
print('[+] Validated checksums of all model files!')
|
||||||
|
else:
|
||||||
|
print('[-] Invalid checksums. Rerun download-model.py with the --clean flag.')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
branch = args.branch
|
||||||
|
model = args.MODEL
|
||||||
|
if model is None:
|
||||||
|
model, branch = select_model_from_default_options()
|
||||||
|
|
||||||
|
# Cleaning up the model/branch names
|
||||||
|
try:
|
||||||
|
model, branch = sanitize_model_and_branch_names(model, branch)
|
||||||
|
except ValueError as err_branch:
|
||||||
|
print(f"Error: {err_branch}")
|
||||||
|
sys.exit()
|
||||||
|
|
||||||
|
# Getting the download links from Hugging Face
|
||||||
|
links, sha256, is_lora = get_download_links_from_huggingface(model, branch, text_only=args.text_only)
|
||||||
|
|
||||||
|
# Getting the output folder
|
||||||
|
output_folder = get_output_folder(model, branch, is_lora, base_folder=args.output)
|
||||||
|
|
||||||
if args.check:
|
if args.check:
|
||||||
# Validate the checksums
|
# Check previously downloaded files
|
||||||
validated = True
|
check_model_files(model, branch, links, sha256, output_folder)
|
||||||
for i in range(len(sha256)):
|
|
||||||
fpath = (output_folder / sha256[i][0])
|
|
||||||
|
|
||||||
if not fpath.exists():
|
|
||||||
print(f"The following file is missing: {fpath}")
|
|
||||||
validated = False
|
|
||||||
continue
|
|
||||||
|
|
||||||
with open(output_folder / sha256[i][0], "rb") as f:
|
|
||||||
bytes = f.read()
|
|
||||||
file_hash = hashlib.sha256(bytes).hexdigest()
|
|
||||||
if file_hash != sha256[i][1]:
|
|
||||||
print(f'Checksum failed: {sha256[i][0]} {sha256[i][1]}')
|
|
||||||
validated = False
|
|
||||||
else:
|
|
||||||
print(f'Checksum validated: {sha256[i][0]} {sha256[i][1]}')
|
|
||||||
|
|
||||||
if validated:
|
|
||||||
print('[+] Validated checksums of all model files!')
|
|
||||||
else:
|
|
||||||
print('[-] Invalid checksums. Rerun download-model.py with the --clean flag.')
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
# Download files
|
||||||
# Creating the folder and writing the metadata
|
download_model_files(model, branch, links, sha256, output_folder, threads=args.threads)
|
||||||
if not output_folder.exists():
|
|
||||||
output_folder.mkdir()
|
|
||||||
with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
|
|
||||||
f.write(f'url: https://huggingface.co/{model}\n')
|
|
||||||
f.write(f'branch: {branch}\n')
|
|
||||||
f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
|
|
||||||
sha256_str = ''
|
|
||||||
for i in range(len(sha256)):
|
|
||||||
sha256_str += f' {sha256[i][1]} {sha256[i][0]}\n'
|
|
||||||
if sha256_str != '':
|
|
||||||
f.write(f'sha256sum:\n{sha256_str}')
|
|
||||||
|
|
||||||
# Downloading the files
|
|
||||||
print(f"Downloading the model to {output_folder}")
|
|
||||||
download_files(links, output_folder, args.threads)
|
|
||||||
|
@ -1,8 +1,23 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
import os
|
||||||
|
|
||||||
|
# get the current directory of the script
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
# check if the bias_options.txt file exists, if not, create it
|
||||||
|
bias_file = os.path.join(current_dir, "bias_options.txt")
|
||||||
|
if not os.path.isfile(bias_file):
|
||||||
|
with open(bias_file, "w") as f:
|
||||||
|
f.write("*I am so happy*\n*I am so sad*\n*I am so excited*\n*I am so bored*\n*I am so angry*")
|
||||||
|
|
||||||
|
# read bias options from the text file
|
||||||
|
with open(bias_file, "r") as f:
|
||||||
|
bias_options = [line.strip() for line in f.readlines()]
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"activate": True,
|
"activate": True,
|
||||||
"bias string": " *I am so happy*",
|
"bias string": " *I am so happy*",
|
||||||
|
"use custom string": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -11,7 +26,6 @@ def input_modifier(string):
|
|||||||
This function is applied to your text inputs before
|
This function is applied to your text inputs before
|
||||||
they are fed into the model.
|
they are fed into the model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return string
|
return string
|
||||||
|
|
||||||
|
|
||||||
@ -19,7 +33,6 @@ def output_modifier(string):
|
|||||||
"""
|
"""
|
||||||
This function is applied to the model outputs.
|
This function is applied to the model outputs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return string
|
return string
|
||||||
|
|
||||||
|
|
||||||
@ -29,9 +42,11 @@ def bot_prefix_modifier(string):
|
|||||||
the prefix text for the Bot and can be used to bias its
|
the prefix text for the Bot and can be used to bias its
|
||||||
behavior.
|
behavior.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if params['activate']:
|
if params['activate']:
|
||||||
return f'{string} {params["bias string"].strip()} '
|
if params['use custom string']:
|
||||||
|
return f'{string} {params["custom string"].strip()} '
|
||||||
|
else:
|
||||||
|
return f'{string} {params["bias string"].strip()} '
|
||||||
else:
|
else:
|
||||||
return string
|
return string
|
||||||
|
|
||||||
@ -39,8 +54,29 @@ def bot_prefix_modifier(string):
|
|||||||
def ui():
|
def ui():
|
||||||
# Gradio elements
|
# Gradio elements
|
||||||
activate = gr.Checkbox(value=params['activate'], label='Activate character bias')
|
activate = gr.Checkbox(value=params['activate'], label='Activate character bias')
|
||||||
string = gr.Textbox(value=params["bias string"], label='Character bias')
|
dropdown_string = gr.Dropdown(choices=bias_options, value=params["bias string"], label='Character bias', info='To edit the options in this dropdown edit the "bias_options.txt" file')
|
||||||
|
use_custom_string = gr.Checkbox(value=False, label='Use custom bias textbox instead of dropdown')
|
||||||
|
custom_string = gr.Textbox(value="", placeholder="Enter custom bias string", label="Custom Character Bias", info='To use this textbox activate the checkbox above')
|
||||||
|
|
||||||
# Event functions to update the parameters in the backend
|
# Event functions to update the parameters in the backend
|
||||||
string.change(lambda x: params.update({"bias string": x}), string, None)
|
def update_bias_string(x):
|
||||||
|
if x:
|
||||||
|
params.update({"bias string": x})
|
||||||
|
else:
|
||||||
|
params.update({"bias string": dropdown_string.get()})
|
||||||
|
return x
|
||||||
|
|
||||||
|
def update_custom_string(x):
|
||||||
|
params.update({"custom string": x})
|
||||||
|
|
||||||
|
dropdown_string.change(update_bias_string, dropdown_string, None)
|
||||||
|
custom_string.change(update_custom_string, custom_string, None)
|
||||||
activate.change(lambda x: params.update({"activate": x}), activate, None)
|
activate.change(lambda x: params.update({"activate": x}), activate, None)
|
||||||
|
use_custom_string.change(lambda x: params.update({"use custom string": x}), use_custom_string, None)
|
||||||
|
|
||||||
|
# Group elements together depending on the selected option
|
||||||
|
def bias_string_group():
|
||||||
|
if use_custom_string.value:
|
||||||
|
return gr.Group([use_custom_string, custom_string])
|
||||||
|
else:
|
||||||
|
return dropdown_string
|
||||||
|
@ -100,10 +100,10 @@ def load_quantized(model_name):
|
|||||||
found_safetensors = list(path_to_model.glob("*.safetensors"))
|
found_safetensors = list(path_to_model.glob("*.safetensors"))
|
||||||
pt_path = None
|
pt_path = None
|
||||||
|
|
||||||
if len(found_pts) == 1:
|
if len(found_pts) > 0:
|
||||||
pt_path = found_pts[0]
|
pt_path = found_pts[-1]
|
||||||
elif len(found_safetensors) == 1:
|
elif len(found_safetensors) > 0:
|
||||||
pt_path = found_safetensors[0]
|
pt_path = found_safetensors[-1]
|
||||||
else:
|
else:
|
||||||
if path_to_model.name.lower().startswith('llama-7b'):
|
if path_to_model.name.lower().startswith('llama-7b'):
|
||||||
pt_model = f'llama-7b-{shared.args.wbits}bit'
|
pt_model = f'llama-7b-{shared.args.wbits}bit'
|
||||||
@ -119,13 +119,14 @@ def load_quantized(model_name):
|
|||||||
# Try to find the .safetensors or .pt both in the model dir and in the subfolder
|
# Try to find the .safetensors or .pt both in the model dir and in the subfolder
|
||||||
for path in [Path(p + ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]:
|
for path in [Path(p + ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]:
|
||||||
if path.exists():
|
if path.exists():
|
||||||
print(f"Found {path}")
|
|
||||||
pt_path = path
|
pt_path = path
|
||||||
break
|
break
|
||||||
|
|
||||||
if not pt_path:
|
if not pt_path:
|
||||||
print("Could not find the quantized model in .pt or .safetensors format, exiting...")
|
print("Could not find the quantized model in .pt or .safetensors format, exiting...")
|
||||||
exit()
|
exit()
|
||||||
|
else:
|
||||||
|
print(f"Found the following quantized model: {pt_path}")
|
||||||
|
|
||||||
# qwopqwop200's offload
|
# qwopqwop200's offload
|
||||||
if model_type == 'llama' and shared.args.pre_layer:
|
if model_type == 'llama' and shared.args.pre_layer:
|
||||||
|
@ -22,6 +22,7 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
|
|||||||
is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False
|
is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False
|
||||||
end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
|
end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
|
||||||
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
|
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
|
||||||
|
_continue = kwargs['_continue'] if '_continue' in kwargs else False
|
||||||
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
|
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
|
||||||
rows = [f"{context.strip()}\n"]
|
rows = [f"{context.strip()}\n"]
|
||||||
|
|
||||||
@ -39,7 +40,10 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
|
|||||||
|
|
||||||
i = len(shared.history['internal']) - 1
|
i = len(shared.history['internal']) - 1
|
||||||
while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
|
while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
|
||||||
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n")
|
if _continue and i == len(shared.history['internal']) - 1:
|
||||||
|
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}")
|
||||||
|
else:
|
||||||
|
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n")
|
||||||
string = shared.history['internal'][i][0]
|
string = shared.history['internal'][i][0]
|
||||||
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
|
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
|
||||||
rows.insert(1, f"{prefix1}{string.strip()}{end_of_turn}\n")
|
rows.insert(1, f"{prefix1}{string.strip()}{end_of_turn}\n")
|
||||||
@ -48,6 +52,8 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
|
|||||||
if impersonate:
|
if impersonate:
|
||||||
rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
|
rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
|
||||||
limit = 2
|
limit = 2
|
||||||
|
elif _continue:
|
||||||
|
limit = 3
|
||||||
else:
|
else:
|
||||||
# Adding the user message
|
# Adding the user message
|
||||||
user_input = fix_newlines(user_input)
|
user_input = fix_newlines(user_input)
|
||||||
@ -99,7 +105,7 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
|
|||||||
return reply, next_character_found
|
return reply, next_character_found
|
||||||
|
|
||||||
|
|
||||||
def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
|
def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False, _continue=False):
|
||||||
if mode == 'instruct':
|
if mode == 'instruct':
|
||||||
stopping_strings = [f"\n{name1}", f"\n{name2}"]
|
stopping_strings = [f"\n{name1}", f"\n{name2}"]
|
||||||
else:
|
else:
|
||||||
@ -107,6 +113,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
|||||||
|
|
||||||
# Defining some variables
|
# Defining some variables
|
||||||
cumulative_reply = ''
|
cumulative_reply = ''
|
||||||
|
last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]] if _continue else None
|
||||||
just_started = True
|
just_started = True
|
||||||
name1_original = name1
|
name1_original = name1
|
||||||
visible_text = custom_generate_chat_prompt = None
|
visible_text = custom_generate_chat_prompt = None
|
||||||
@ -124,17 +131,22 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
|||||||
|
|
||||||
if visible_text is None:
|
if visible_text is None:
|
||||||
visible_text = text
|
visible_text = text
|
||||||
text = apply_extensions(text, "input")
|
if not _continue:
|
||||||
|
text = apply_extensions(text, "input")
|
||||||
|
|
||||||
# Generating the prompt
|
# Generating the prompt
|
||||||
kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'}
|
kwargs = {
|
||||||
|
'end_of_turn': end_of_turn,
|
||||||
|
'is_instruct': mode == 'instruct',
|
||||||
|
'_continue': _continue
|
||||||
|
}
|
||||||
if custom_generate_chat_prompt is None:
|
if custom_generate_chat_prompt is None:
|
||||||
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
|
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
|
||||||
else:
|
else:
|
||||||
prompt = custom_generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
|
prompt = custom_generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
|
||||||
|
|
||||||
# Yield *Is typing...*
|
# Yield *Is typing...*
|
||||||
if not regenerate:
|
if not any((regenerate, _continue)):
|
||||||
yield shared.history['visible'] + [[visible_text, shared.processing_message]]
|
yield shared.history['visible'] + [[visible_text, shared.processing_message]]
|
||||||
|
|
||||||
# Generate
|
# Generate
|
||||||
@ -154,11 +166,17 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
|||||||
return shared.history['visible']
|
return shared.history['visible']
|
||||||
if just_started:
|
if just_started:
|
||||||
just_started = False
|
just_started = False
|
||||||
shared.history['internal'].append(['', ''])
|
if not _continue:
|
||||||
shared.history['visible'].append(['', ''])
|
shared.history['internal'].append(['', ''])
|
||||||
|
shared.history['visible'].append(['', ''])
|
||||||
|
|
||||||
shared.history['internal'][-1] = [text, reply]
|
if _continue:
|
||||||
shared.history['visible'][-1] = [visible_text, visible_reply]
|
sep = list(map(lambda x : ' ' if x[-1] != ' ' else '', last_reply))
|
||||||
|
shared.history['internal'][-1] = [text, f'{last_reply[0]}{sep[0]}{reply}']
|
||||||
|
shared.history['visible'][-1] = [visible_text, f'{last_reply[1]}{sep[1]}{visible_reply}']
|
||||||
|
else:
|
||||||
|
shared.history['internal'][-1] = [text, reply]
|
||||||
|
shared.history['visible'][-1] = [visible_text, visible_reply]
|
||||||
if not shared.args.no_stream:
|
if not shared.args.no_stream:
|
||||||
yield shared.history['visible']
|
yield shared.history['visible']
|
||||||
if next_character_found:
|
if next_character_found:
|
||||||
@ -220,6 +238,16 @@ def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of
|
|||||||
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||||
|
|
||||||
|
|
||||||
|
def continue_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
||||||
|
if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
|
||||||
|
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||||
|
else:
|
||||||
|
# Yield ' ...'
|
||||||
|
yield chat_html_wrapper(shared.history['visible'][:-1] + [[shared.history['visible'][-1][0], shared.history['visible'][-1][1] + ' ...']], name1, name2, mode)
|
||||||
|
for history in chatbot_wrapper(shared.history['internal'][-1][0], generate_state, name1, name2, context, mode, end_of_turn, _continue=True):
|
||||||
|
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||||
|
|
||||||
|
|
||||||
def remove_last_message(name1, name2, mode):
|
def remove_last_message(name1, name2, mode):
|
||||||
if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
|
if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
|
||||||
last = shared.history['visible'].pop()
|
last = shared.history['visible'].pop()
|
||||||
@ -257,6 +285,9 @@ def clear_chat_log(name1, name2, greeting, mode):
|
|||||||
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
||||||
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
|
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
|
||||||
|
|
||||||
|
# Save cleared logs
|
||||||
|
save_history(timestamp=False)
|
||||||
|
|
||||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||||
|
|
||||||
|
|
||||||
@ -406,9 +437,14 @@ def load_character(character, name1, name2, mode):
|
|||||||
|
|
||||||
if Path(f'logs/{shared.character}_persistent.json').exists():
|
if Path(f'logs/{shared.character}_persistent.json').exists():
|
||||||
load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
|
load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
|
||||||
elif greeting != "":
|
else:
|
||||||
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
# Insert greeting if it exists
|
||||||
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
|
if greeting != "":
|
||||||
|
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
||||||
|
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
|
||||||
|
|
||||||
|
# Create .json log files since they don't already exist
|
||||||
|
save_history(timestamp=False)
|
||||||
|
|
||||||
return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)
|
return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)
|
||||||
|
|
||||||
|
176
modules/llama_attn_hijack.py
Normal file
176
modules/llama_attn_hijack.py
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
import math
|
||||||
|
import sys
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import transformers.models.llama.modeling_llama
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import modules.shared as shared
|
||||||
|
|
||||||
|
|
||||||
|
if shared.args.xformers:
|
||||||
|
try:
|
||||||
|
import xformers.ops
|
||||||
|
except Exception:
|
||||||
|
print("🔴 xformers not found! Please install it before trying to use it.", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
|
def hijack_llama_attention():
|
||||||
|
if shared.args.xformers:
|
||||||
|
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
|
||||||
|
print("Replaced attention with xformers_attention")
|
||||||
|
elif shared.args.sdp_attention:
|
||||||
|
transformers.models.llama.modeling_llama.LlamaAttention.forward = sdp_attention_forward
|
||||||
|
print("Replaced attention with sdp_attention")
|
||||||
|
|
||||||
|
|
||||||
|
def xformers_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
if past_key_value is not None:
|
||||||
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||||
|
# [bsz, nh, t, hd]
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# reuse k, v, self_attention
|
||||||
|
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||||
|
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||||
|
|
||||||
|
past_key_value = (key_states, value_states) if use_cache else None
|
||||||
|
|
||||||
|
#We only apply xformers optimizations if we don't need to output the whole attention matrix
|
||||||
|
if not output_attentions:
|
||||||
|
dtype = query_states.dtype
|
||||||
|
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
key_states = key_states.transpose(1, 2)
|
||||||
|
value_states = value_states.transpose(1, 2)
|
||||||
|
|
||||||
|
#This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
|
||||||
|
#We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
|
||||||
|
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
|
||||||
|
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
||||||
|
attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
|
||||||
|
else:
|
||||||
|
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
||||||
|
attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=xformers.ops.LowerTriangularMask())
|
||||||
|
attn_weights = None
|
||||||
|
else:
|
||||||
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
|
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
|
||||||
|
f" {attn_weights.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||||
|
)
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
|
||||||
|
|
||||||
|
# upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
|
||||||
|
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2)
|
||||||
|
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
def sdp_attention_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
if past_key_value is not None:
|
||||||
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||||
|
# [bsz, nh, t, hd]
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# reuse k, v, self_attention
|
||||||
|
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||||
|
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||||
|
|
||||||
|
past_key_value = (key_states, value_states) if use_cache else None
|
||||||
|
|
||||||
|
#We only apply sdp attention if we don't need to output the whole attention matrix
|
||||||
|
if not output_attentions:
|
||||||
|
attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False)
|
||||||
|
attn_weights = None
|
||||||
|
else:
|
||||||
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
|
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
|
||||||
|
f" {attn_weights.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||||
|
)
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
|
||||||
|
|
||||||
|
# upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
|
||||||
|
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2)
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
return attn_output, attn_weights, past_key_value
|
@ -14,6 +14,7 @@ from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
|
|||||||
BitsAndBytesConfig, LlamaTokenizer)
|
BitsAndBytesConfig, LlamaTokenizer)
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
from modules import llama_attn_hijack
|
||||||
|
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
|
|
||||||
@ -169,11 +170,23 @@ def load_model(model_name):
|
|||||||
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
|
model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
|
||||||
|
|
||||||
|
# Hijack attention with xformers
|
||||||
|
if any((shared.args.xformers, shared.args.sdp_attention)):
|
||||||
|
llama_attn_hijack.hijack_llama_attention()
|
||||||
|
|
||||||
# Loading the tokenizer
|
# Loading the tokenizer
|
||||||
if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
|
if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
|
||||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
|
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
|
||||||
elif type(model) is transformers.LlamaForCausalLM:
|
elif type(model) is transformers.LlamaForCausalLM:
|
||||||
tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True)
|
tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True)
|
||||||
|
# Leaving this here until the LLaMA tokenizer gets figured out.
|
||||||
|
# For some people this fixes things, for others it causes an error.
|
||||||
|
try:
|
||||||
|
tokenizer.eos_token_id = 2
|
||||||
|
tokenizer.bos_token_id = 1
|
||||||
|
tokenizer.pad_token_id = 0
|
||||||
|
except:
|
||||||
|
pass
|
||||||
else:
|
else:
|
||||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))
|
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))
|
||||||
tokenizer.truncation_side = 'left'
|
tokenizer.truncation_side = 'left'
|
||||||
|
@ -98,6 +98,8 @@ parser.add_argument('--disk-cache-dir', type=str, default="cache", help='Directo
|
|||||||
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
|
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
|
||||||
parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
|
parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
|
||||||
parser.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. This reduces the VRAM usage a bit at a performance cost.')
|
parser.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. This reduces the VRAM usage a bit at a performance cost.')
|
||||||
|
parser.add_argument('--xformers', action='store_true', help="Use xformer's memory efficient attention. This should increase your tokens/s.")
|
||||||
|
parser.add_argument('--sdp-attention', action='store_true', help="Use torch 2.0's sdp attention.")
|
||||||
|
|
||||||
# llama.cpp
|
# llama.cpp
|
||||||
parser.add_argument('--threads', type=int, default=0, help='Number of threads to use in llama.cpp.')
|
parser.add_argument('--threads', type=int, default=0, help='Number of threads to use in llama.cpp.')
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
accelerate==0.18.0
|
accelerate==0.18.0
|
||||||
bitsandbytes==0.37.2
|
|
||||||
datasets
|
datasets
|
||||||
flexgen==0.1.7
|
flexgen==0.1.7
|
||||||
gradio==3.24.1
|
gradio==3.24.1
|
||||||
@ -14,3 +13,6 @@ sentencepiece
|
|||||||
pyyaml
|
pyyaml
|
||||||
tqdm
|
tqdm
|
||||||
git+https://github.com/huggingface/transformers
|
git+https://github.com/huggingface/transformers
|
||||||
|
bitsandbytes==0.37.2; platform_system != "Windows"
|
||||||
|
llama-cpp-python==0.1.30; platform_system != "Windows"
|
||||||
|
https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.30/llama_cpp_python-0.1.30-cp310-cp310-win_amd64.whl; platform_system == "Windows"
|
||||||
|
41
server.py
41
server.py
@ -394,8 +394,9 @@ def create_interface():
|
|||||||
shared.gradio['Generate'] = gr.Button('Generate', elem_id='Generate')
|
shared.gradio['Generate'] = gr.Button('Generate', elem_id='Generate')
|
||||||
shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop")
|
shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['Impersonate'] = gr.Button('Impersonate')
|
|
||||||
shared.gradio['Regenerate'] = gr.Button('Regenerate')
|
shared.gradio['Regenerate'] = gr.Button('Regenerate')
|
||||||
|
shared.gradio['Continue'] = gr.Button('Continue')
|
||||||
|
shared.gradio['Impersonate'] = gr.Button('Impersonate')
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
|
shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
|
||||||
shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
|
shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
|
||||||
@ -467,53 +468,57 @@ def create_interface():
|
|||||||
gen_events.append(shared.gradio['Generate'].click(
|
gen_events.append(shared.gradio['Generate'].click(
|
||||||
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
|
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
|
||||||
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||||
lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
|
lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
|
||||||
)
|
)
|
||||||
|
|
||||||
gen_events.append(shared.gradio['textbox'].submit(
|
gen_events.append(shared.gradio['textbox'].submit(
|
||||||
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
|
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
|
||||||
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||||
lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
|
lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
|
||||||
)
|
)
|
||||||
|
|
||||||
gen_events.append(shared.gradio['Regenerate'].click(
|
gen_events.append(shared.gradio['Regenerate'].click(
|
||||||
chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||||
lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
|
lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
|
||||||
lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
|
)
|
||||||
|
|
||||||
|
gen_events.append(shared.gradio['Continue'].click(
|
||||||
|
chat.continue_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||||
|
lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
|
||||||
)
|
)
|
||||||
|
|
||||||
shared.gradio['Replace last reply'].click(
|
shared.gradio['Replace last reply'].click(
|
||||||
chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'Chat mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'Chat mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||||
lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
|
lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
|
||||||
lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
|
lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
|
||||||
|
|
||||||
shared.gradio['Clear history-confirm'].click(
|
shared.gradio['Clear history-confirm'].click(
|
||||||
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr).then(
|
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr).then(
|
||||||
chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'Chat mode']], shared.gradio['display']).then(
|
chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'Chat mode']], shared.gradio['display']).then(
|
||||||
lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
|
lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
|
||||||
|
|
||||||
shared.gradio['Stop'].click(
|
shared.gradio['Stop'].click(
|
||||||
stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None).then(
|
stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None).then(
|
||||||
chat.redraw_html, reload_inputs, [shared.gradio['display']])
|
chat.redraw_html, reload_inputs, shared.gradio['display'])
|
||||||
|
|
||||||
shared.gradio['Chat mode'].change(
|
shared.gradio['Chat mode'].change(
|
||||||
lambda x: gr.update(visible=x == 'instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates']).then(
|
lambda x: gr.update(visible=x == 'instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates']).then(
|
||||||
chat.redraw_html, reload_inputs, [shared.gradio['display']])
|
chat.redraw_html, reload_inputs, shared.gradio['display'])
|
||||||
|
|
||||||
shared.gradio['Instruction templates'].change(
|
shared.gradio['Instruction templates'].change(
|
||||||
lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]).then(
|
lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]).then(
|
||||||
chat.redraw_html, reload_inputs, [shared.gradio['display']])
|
chat.redraw_html, reload_inputs, shared.gradio['display'])
|
||||||
|
|
||||||
shared.gradio['upload_chat_history'].upload(
|
shared.gradio['upload_chat_history'].upload(
|
||||||
chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], []).then(
|
chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], None).then(
|
||||||
chat.redraw_html, reload_inputs, [shared.gradio['display']])
|
chat.redraw_html, reload_inputs, shared.gradio['display'])
|
||||||
|
|
||||||
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
|
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
|
||||||
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
|
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, None, shared.gradio['textbox'], show_progress=shared.args.no_stream)
|
||||||
shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
|
shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
|
||||||
shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
|
shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
|
||||||
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
|
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
|
||||||
shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']])
|
shared.gradio['download_button'].click(chat.save_history, inputs=None, outputs=[shared.gradio['download']])
|
||||||
shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
|
shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
|
||||||
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
|
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
|
||||||
shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
|
shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
|
||||||
@ -521,7 +526,7 @@ def create_interface():
|
|||||||
|
|
||||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
|
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
|
||||||
shared.gradio['interface'].load(chat.load_default_history, [shared.gradio[k] for k in ['name1', 'name2']], None)
|
shared.gradio['interface'].load(chat.load_default_history, [shared.gradio[k] for k in ['name1', 'name2']], None)
|
||||||
shared.gradio['interface'].load(chat.redraw_html, reload_inputs, [shared.gradio['display']], show_progress=True)
|
shared.gradio['interface'].load(chat.redraw_html, reload_inputs, shared.gradio['display'], show_progress=True)
|
||||||
|
|
||||||
elif shared.args.notebook:
|
elif shared.args.notebook:
|
||||||
with gr.Tab("Text generation", elem_id="main"):
|
with gr.Tab("Text generation", elem_id="main"):
|
||||||
@ -555,7 +560,7 @@ def create_interface():
|
|||||||
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
|
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
|
||||||
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
||||||
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
||||||
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
|
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@ -589,7 +594,7 @@ def create_interface():
|
|||||||
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
||||||
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
||||||
gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
|
gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
|
||||||
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
|
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
||||||
|
|
||||||
with gr.Tab("Model", elem_id="model-tab"):
|
with gr.Tab("Model", elem_id="model-tab"):
|
||||||
|
Loading…
Reference in New Issue
Block a user