Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2024-10-29 21:50:16 +01:00)

Commit: Make the code more like PEP8 for readability (#862)

This commit is contained in:
parent 848c4edfd5
commit ea6e77df72
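The diff below is mostly mechanical PEP8 cleanup: comparisons against True/False/None are replaced with truthiness checks and is None tests, and top-level functions and classes are separated by two blank lines. A minimal sketch of the before/after pattern, using hypothetical names rather than code taken verbatim from any one file in this commit:

params = {'activate': True}
user_info = None


def status_before(params, user_info):
    # Old style seen on the left-hand side of the diff: explicit comparisons.
    if params['activate'] == True:
        return 'on'
    elif user_info == None:
        return 'unknown'
    return 'off'


def status_after(params, user_info):
    # New style introduced by this commit: truthiness check and `is None`.
    if params['activate']:
        return 'on'
    elif user_info is None:
        return 'unknown'
    return 'off'


print(status_before(params, user_info), status_after(params, user_info))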
@@ -17,6 +17,7 @@ def random_hash():
    letters = string.ascii_lowercase + string.digits
    return ''.join(random.choice(letters) for i in range(9))


async def run(context):
    server = "127.0.0.1"
    params = {
@@ -69,6 +70,7 @@ async def run(context):

prompt = "What I would like to say is the following: "


async def get_result():
    async for response in run(prompt):
        # Print intermediate steps
@@ -17,6 +17,7 @@ parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpForma
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
args = parser.parse_args()


def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
@@ -31,12 +32,14 @@ def disable_torch_init():
    torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)


def restore_torch_init():
    """Rollback the change made by disable_torch_init."""
    import torch
    setattr(torch.nn.Linear, "reset_parameters", torch_linear_init_backup)
    setattr(torch.nn.LayerNorm, "reset_parameters", torch_layer_norm_init_backup)


if __name__ == '__main__':
    path = Path(args.MODEL)
    model_name = path.name
@@ -29,6 +29,7 @@ parser.add_argument('--clean', action='store_true', help='Does not resume the pr
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
args = parser.parse_args()


def get_file(url, output_folder):
    filename = Path(url.rsplit('/', 1)[1])
    output_path = output_folder / filename
@@ -54,6 +55,7 @@ def get_file(url, output_folder):
                t.update(len(data))
                f.write(data)


def sanitize_branch_name(branch_name):
    pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
    if pattern.match(branch_name):
@@ -61,6 +63,7 @@ def sanitize_branch_name(branch_name):
    else:
        raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")


def select_model_from_default_options():
    models = {
        "OPT 6.7B": ("facebook", "opt-6.7b", "main"),
@@ -106,6 +109,7 @@ EleutherAI/pythia-1.4b-deduped

    return model, branch


def get_download_links_from_huggingface(model, branch):
    base = "https://huggingface.co"
    page = f"/api/models/{model}/tree/{branch}?cursor="
@@ -172,9 +176,11 @@ def get_download_links_from_huggingface(model, branch):

    return links, sha256, is_lora


def download_files(file_list, output_folder, num_threads=8):
    thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, disable=True)


if __name__ == '__main__':
    model = args.MODEL
    branch = args.branch
@@ -9,6 +9,7 @@ params = {
    'port': 5000,
}


class Handler(BaseHTTPRequestHandler):
    def do_GET(self):
        if self.path == '/api/v1/model':
@@ -32,7 +33,7 @@ class Handler(BaseHTTPRequestHandler):
        self.end_headers()

        prompt = body['prompt']
-        prompt_lines = [l.strip() for l in prompt.split('\n')]
+        prompt_lines = [k.strip() for k in prompt.split('\n')]

        max_context = body.get('max_context_length', 2048)

@@ -95,5 +96,6 @@ def run_server():
    print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api')
    server.serve_forever()


def setup():
    Thread(target=run_server, daemon=True).start()
@@ -5,6 +5,7 @@ params = {
    "bias string": " *I am so happy*",
}


def input_modifier(string):
    """
    This function is applied to your text inputs before
@@ -13,6 +14,7 @@ def input_modifier(string):

    return string


def output_modifier(string):
    """
    This function is applied to the model outputs.
@@ -20,6 +22,7 @@ def output_modifier(string):

    return string


def bot_prefix_modifier(string):
    """
    This function is only applied in chat mode. It modifies
@@ -27,11 +30,12 @@ def bot_prefix_modifier(string):
    behavior.
    """

-    if params['activate'] == True:
+    if params['activate']:
        return f'{string} {params["bias string"].strip()} '
    else:
        return string


def ui():
    # Gradio elements
    activate = gr.Checkbox(value=params['activate'], label='Activate character bias')
@@ -22,6 +22,8 @@ if not shared.args.no_stream:
    raise ValueError

# Check if the API is valid and refresh the UI accordingly.


def check_valid_api():

    global user, user_info, params
@@ -29,7 +31,7 @@ def check_valid_api():
    user = ElevenLabsUser(params['api_key'])
    user_info = user._get_subscription_data()
    print('checking api')
-    if params['activate'] == False:
+    if not params['activate']:
        return gr.update(value='Disconnected')
    elif user_info is None:
        print('Incorrect API Key')
@@ -39,6 +41,8 @@ def check_valid_api():
        return gr.update(value='Connected')

# Once the API is verified, get the available voices and update the dropdown list


def refresh_voices():

    global user, user_info
@@ -51,11 +55,13 @@ def refresh_voices():
    else:
        return


def remove_surrounded_chars(string):
    # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
    # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
    return re.sub('\*[^\*]*?(\*|$)', '', string)


def input_modifier(string):
    """
    This function is applied to your text inputs before
@@ -64,6 +70,7 @@ def input_modifier(string):

    return string


def output_modifier(string):
    """
    This function is applied to the model outputs.
@@ -71,9 +78,9 @@ def output_modifier(string):

    global params, wav_idx, user, user_info

-    if params['activate'] == False:
+    if not params['activate']:
        return string
-    elif user_info == None:
+    elif user_info is None:
        return string

    string = remove_surrounded_chars(string)
@@ -94,6 +101,7 @@ def output_modifier(string):
    wav_idx += 1
    return string


def ui():

    # Gradio elements
@@ -7,6 +7,7 @@ params = {

language_codes = {'Afrikaans': 'af', 'Albanian': 'sq', 'Amharic': 'am', 'Arabic': 'ar', 'Armenian': 'hy', 'Azerbaijani': 'az', 'Basque': 'eu', 'Belarusian': 'be', 'Bengali': 'bn', 'Bosnian': 'bs', 'Bulgarian': 'bg', 'Catalan': 'ca', 'Cebuano': 'ceb', 'Chinese (Simplified)': 'zh-CN', 'Chinese (Traditional)': 'zh-TW', 'Corsican': 'co', 'Croatian': 'hr', 'Czech': 'cs', 'Danish': 'da', 'Dutch': 'nl', 'English': 'en', 'Esperanto': 'eo', 'Estonian': 'et', 'Finnish': 'fi', 'French': 'fr', 'Frisian': 'fy', 'Galician': 'gl', 'Georgian': 'ka', 'German': 'de', 'Greek': 'el', 'Gujarati': 'gu', 'Haitian Creole': 'ht', 'Hausa': 'ha', 'Hawaiian': 'haw', 'Hebrew': 'iw', 'Hindi': 'hi', 'Hmong': 'hmn', 'Hungarian': 'hu', 'Icelandic': 'is', 'Igbo': 'ig', 'Indonesian': 'id', 'Irish': 'ga', 'Italian': 'it', 'Japanese': 'ja', 'Javanese': 'jw', 'Kannada': 'kn', 'Kazakh': 'kk', 'Khmer': 'km', 'Korean': 'ko', 'Kurdish': 'ku', 'Kyrgyz': 'ky', 'Lao': 'lo', 'Latin': 'la', 'Latvian': 'lv', 'Lithuanian': 'lt', 'Luxembourgish': 'lb', 'Macedonian': 'mk', 'Malagasy': 'mg', 'Malay': 'ms', 'Malayalam': 'ml', 'Maltese': 'mt', 'Maori': 'mi', 'Marathi': 'mr', 'Mongolian': 'mn', 'Myanmar (Burmese)': 'my', 'Nepali': 'ne', 'Norwegian': 'no', 'Nyanja (Chichewa)': 'ny', 'Pashto': 'ps', 'Persian': 'fa', 'Polish': 'pl', 'Portuguese (Portugal, Brazil)': 'pt', 'Punjabi': 'pa', 'Romanian': 'ro', 'Russian': 'ru', 'Samoan': 'sm', 'Scots Gaelic': 'gd', 'Serbian': 'sr', 'Sesotho': 'st', 'Shona': 'sn', 'Sindhi': 'sd', 'Sinhala (Sinhalese)': 'si', 'Slovak': 'sk', 'Slovenian': 'sl', 'Somali': 'so', 'Spanish': 'es', 'Sundanese': 'su', 'Swahili': 'sw', 'Swedish': 'sv', 'Tagalog (Filipino)': 'tl', 'Tajik': 'tg', 'Tamil': 'ta', 'Telugu': 'te', 'Thai': 'th', 'Turkish': 'tr', 'Ukrainian': 'uk', 'Urdu': 'ur', 'Uzbek': 'uz', 'Vietnamese': 'vi', 'Welsh': 'cy', 'Xhosa': 'xh', 'Yiddish': 'yi', 'Yoruba': 'yo', 'Zulu': 'zu'}


def input_modifier(string):
    """
    This function is applied to your text inputs before
@@ -15,6 +16,7 @@ def input_modifier(string):

    return GoogleTranslator(source=params['language string'], target='en').translate(string)


def output_modifier(string):
    """
    This function is applied to the model outputs.
@@ -22,6 +24,7 @@ def output_modifier(string):

    return GoogleTranslator(source='en', target=params['language string']).translate(string)


def bot_prefix_modifier(string):
    """
    This function is only applied in chat mode. It modifies
@@ -31,6 +34,7 @@ def bot_prefix_modifier(string):

    return string


def ui():
    # Finding the language name from the language code to use as the default value
    language_name = list(language_codes.keys())[list(language_codes.values()).index(params['language string'])]
@@ -4,12 +4,14 @@ import pandas as pd

df = pd.read_csv("https://raw.githubusercontent.com/devbrones/llama-prompts/main/prompts/prompts.csv")


def get_prompt_by_name(name):
    if name == 'None':
        return ''
    else:
        return df[df['Prompt name'] == name].iloc[0]['Prompt'].replace('\\n', '\n')


def ui():
    if not shared.is_chat():
        choices = ['None'] + list(df['Prompt name'])
@@ -30,12 +30,15 @@ streaming_state = shared.args.no_stream # remember if chat streaming was enabled
picture_response = False # specifies if the next model response should appear as a picture
pic_id = 0


def remove_surrounded_chars(string):
    # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
    # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
    return re.sub('\*[^\*]*?(\*|$)', '', string)

# I don't even need input_hijack for this as visible text will be commited to history as the unmodified string


def input_modifier(string):
    """
    This function is applied to your text inputs before
@@ -62,6 +65,8 @@ def input_modifier(string):
    return string

# Get and save the Stable Diffusion-generated picture


def get_SD_pictures(description):

    global params, pic_id
@@ -101,6 +106,8 @@ def get_SD_pictures(description):

# TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
# and replace it with 'text' for the purposes of logging?


def output_modifier(string):
    """
    This function is applied to the model outputs.
@@ -130,6 +137,7 @@ def output_modifier(string):
    shared.args.no_stream = streaming_state
    return image + "\n" + text


def bot_prefix_modifier(string):
    """
    This function is only applied in chat mode. It modifies
@@ -139,10 +147,12 @@ def bot_prefix_modifier(string):

    return string


def force_pic():
    global picture_response
    picture_response = True


def ui():

    # Gradio elements
@@ -17,11 +17,13 @@ input_hijack = {
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu")


def caption_image(raw_image):
    inputs = processor(raw_image.convert('RGB'), return_tensors="pt").to("cpu", torch.float32)
    out = model.generate(**inputs, max_new_tokens=100)
    return processor.decode(out[0], skip_special_tokens=True)


def generate_chat_picture(picture, name1, name2):
    text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
    # lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
@@ -32,6 +34,7 @@ def generate_chat_picture(picture, name1, name2):
    visible_text = f'<img src="data:image/jpeg;base64,{img_str}" alt="{text}">'
    return text, visible_text


def ui():
    picture_select = gr.Image(label='Send a picture', type='pil')

@@ -17,9 +17,11 @@ from quant import make_quant


def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128):
-    config = AutoConfig.from_pretrained(model)
    def noop(*args, **kwargs):
        pass

+    config = AutoConfig.from_pretrained(model)
    torch.nn.init.kaiming_uniform_ = noop
    torch.nn.init.uniform_ = noop
    torch.nn.init.normal_ = noop
@@ -64,6 +66,7 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc

    return model


def load_quantized(model_name):
    if not shared.args.model_type:
        # Try to determine model type from model name
@@ -13,6 +13,7 @@ def reload_model():
    clear_torch_cache()
    shared.model, shared.tokenizer = load_model(shared.model_name)


def add_lora_to_model(lora_name):

    # If a LoRA had been previously loaded, or if we want
@@ -54,6 +54,7 @@ class RWKVModel:
                reply += token
                yield reply


class RWKVTokenizer:
    def __init__(self):
        pass
@@ -28,6 +28,7 @@ def generate_reply_wrapper(string):
    for i in generate_reply(params[0], generate_params):
        yield i


def create_apis():
    t1 = gr.Textbox(visible=False)
    t2 = gr.Textbox(visible=False)
@@ -30,6 +30,7 @@ class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
                    return True
        return False


class Stream(transformers.StoppingCriteria):
    def __init__(self, callback_func=None):
        self.callback_func = callback_func
@@ -39,6 +40,7 @@ class Stream(transformers.StoppingCriteria):
            self.callback_func(input_ids[0])
        return False


class Iteratorize:

    """
@@ -96,6 +98,7 @@ class Iteratorize:
        self.stop_now = True
        clear_torch_cache()


def clear_torch_cache():
    gc.collect()
    if not shared.args.cpu:
@@ -23,7 +23,6 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
    end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
    impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
    also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False

    rows = [f"{context.strip()}\n"]

    # Finding the maximum prompt size
@@ -68,6 +67,7 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
    else:
        return prompt


def extract_message_from_reply(reply, name1, name2, stop_at_newline):
    next_character_found = False

@@ -98,6 +98,7 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
    reply = fix_newlines(reply)
    return reply, next_character_found


def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
    if mode == 'instruct':
        stopping_strings = [f"\n{name1}", f"\n{name2}"]
@@ -113,7 +114,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
    visible_text = None
    custom_generate_chat_prompt = None
    for extension, _ in extensions_module.iterator():
-        if hasattr(extension, 'input_hijack') and extension.input_hijack['state'] == True:
+        if hasattr(extension, 'input_hijack') and extension.input_hijack['state']:
            extension.input_hijack['state'] = False
            text, visible_text = extension.input_hijack['value']
        if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'):
@@ -167,6 +168,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu

        yield shared.history['visible']


def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
    if mode == 'instruct':
        stopping_strings = [f"\n{name1}", f"\n{name2}"]
@@ -197,10 +199,12 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o

        yield reply


def cai_chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
    for history in chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
        yield chat_html_wrapper(history, name1, name2, mode)


def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
    if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0:
        yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
@@ -213,6 +217,7 @@ def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of
            shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
            yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)


def remove_last_message(name1, name2, mode):
    if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
        last = shared.history['visible'].pop()
@@ -222,12 +227,14 @@ def remove_last_message(name1, name2, mode):

    return chat_html_wrapper(shared.history['visible'], name1, name2, mode), last[0]


def send_last_reply_to_input():
    if len(shared.history['internal']) > 0:
        return shared.history['internal'][-1][1]
    else:
        return ''


def replace_last_reply(text, name1, name2, mode):
    if len(shared.history['visible']) > 0:
        shared.history['visible'][-1][1] = text
@@ -235,9 +242,11 @@ def replace_last_reply(text, name1, name2, mode):

    return chat_html_wrapper(shared.history['visible'], name1, name2, mode)


def clear_html():
    return chat_html_wrapper([], "", "")


def clear_chat_log(name1, name2, greeting, mode):
    shared.history['visible'] = []
    shared.history['internal'] = []
@@ -248,9 +257,11 @@ def clear_chat_log(name1, name2, greeting, mode):

    return chat_html_wrapper(shared.history['visible'], name1, name2, mode)


def redraw_html(name1, name2, mode):
    return chat_html_wrapper(shared.history['visible'], name1, name2, mode)


def tokenize_dialogue(dialogue, name1, name2, mode):
    history = []

@@ -288,6 +299,7 @@ def tokenize_dialogue(dialogue, name1, name2, mode):

    return history


def save_history(timestamp=True):
    if timestamp:
        fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
@@ -299,6 +311,7 @@ def save_history(timestamp=True):
        f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
    return Path(f'logs/{fname}')


def load_history(file, name1, name2):
    file = file.decode('utf-8')
    try:
@@ -323,10 +336,12 @@ def load_history(file, name1, name2):
        shared.history['internal'] = tokenize_dialogue(file, name1, name2)
        shared.history['visible'] = copy.deepcopy(shared.history['internal'])


def replace_character_names(text, name1, name2):
    text = text.replace('{{user}}', name1).replace('{{char}}', name2)
    return text.replace('<USER>', name1).replace('<BOT>', name2)


def build_pygmalion_style_context(data):
    context = ""
    if 'char_persona' in data and data['char_persona'] != '':
@@ -336,6 +351,7 @@ def build_pygmalion_style_context(data):
    context = f"{context.strip()}\n<START>\n"
    return context


def generate_pfp_cache(character):
    cache_folder = Path("cache")
    if not cache_folder.exists():
@@ -348,6 +364,7 @@ def generate_pfp_cache(character):
            return img
    return None


def load_character(character, name1, name2, mode):
    shared.character = character
    shared.history['internal'] = []
@@ -404,9 +421,11 @@ def load_character(character, name1, name2, mode):

    return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)


def load_default_history(name1, name2):
    load_character("None", name1, name2, "chat")


def upload_character(json_file, img, tavern=False):
    json_file = json_file if type(json_file) == str else json_file.decode('utf-8')
    data = json.loads(json_file)
@@ -425,6 +444,7 @@ def upload_character(json_file, img, tavern=False):
    print(f'New character saved to "characters/{outfile_name}.json".')
    return outfile_name


def upload_tavern_character(img, name1, name2):
    _img = Image.open(io.BytesIO(img))
    _img.getexif()
@@ -433,12 +453,13 @@ def upload_tavern_character(img, name1, name2):
    _json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']}
    return upload_character(json.dumps(_json), img, tavern=True)


def upload_your_profile_picture(img, name1, name2, mode):
    cache_folder = Path("cache")
    if not cache_folder.exists():
        cache_folder.mkdir()

-    if img == None:
+    if img is None:
        if Path("cache/pfp_me.png").exists():
            Path("cache/pfp_me.png").unlink()
    else:
@@ -9,6 +9,7 @@ state = {}
available_extensions = []
setup_called = set()


def load_extensions():
    global state
    for i, name in enumerate(shared.args.extensions):
@@ -23,12 +24,16 @@ def load_extensions():
                traceback.print_exc()

# This iterator returns the extensions in the order specified in the command-line


def iterator():
    for name in sorted(state, key=lambda x: state[x][1]):
        if state[name][0] == True:
            yield eval(f"extensions.{name}.script"), name

# Extension functions that map string -> string


def apply_extensions(text, typ):
    for extension, _ in iterator():
        if typ == "input" and hasattr(extension, "input_modifier"):
@@ -39,6 +44,7 @@ def apply_extensions(text, typ):
            text = extension.bot_prefix_modifier(text)
    return text


def create_extensions_block():
    global setup_called

@@ -24,6 +24,7 @@ with open(Path(__file__).resolve().parent / '../css/html_cai_style.css', 'r') as
with open(Path(__file__).resolve().parent / '../css/html_instruct_style.css', 'r') as f:
    instruct_css = f.read()


def fix_newlines(string):
    string = string.replace('\n', '\n\n')
    string = re.sub(r"\n{3,}", "\n\n", string)
@@ -31,6 +32,8 @@ def fix_newlines(string):
    return string

# This could probably be generalized and improved


def convert_to_markdown(string):
    string = string.replace('\\begin{code}', '```')
    string = string.replace('\\end{code}', '```')
@@ -40,11 +43,13 @@ def convert_to_markdown(string):
    string = fix_newlines(string)
    return markdown.markdown(string, extensions=['fenced_code'])


def generate_basic_html(string):
    string = convert_to_markdown(string)
    string = f'<style>{readable_css}</style><div class="container">{string}</div>'
    return string


def process_post(post, c):
    t = post.split('\n')
    number = t[0].split(' ')[1]
@@ -59,6 +64,7 @@ def process_post(post, c):
    src = f'<span class="name">Anonymous </span> <span class="number">No.{number}</span>\n{src}'
    return src


def generate_4chan_html(f):
    posts = []
    post = ''
@@ -98,6 +104,7 @@ def generate_4chan_html(f):

    return output


def make_thumbnail(image):
    image = image.resize((350, round(image.size[1] / image.size[0] * 350)), Image.Resampling.LANCZOS)
    if image.size[1] > 470:
@@ -105,6 +112,7 @@ def make_thumbnail(image):

    return image


def get_image_cache(path):
    cache_folder = Path("cache")
    if not cache_folder.exists():
@@ -119,6 +127,7 @@ def get_image_cache(path):

    return image_cache[path][1]


def generate_instruct_html(history):
    output = f'<style>{instruct_css}</style><div class="chat" id="chat">'
    for i, _row in enumerate(history[::-1]):
@@ -151,6 +160,7 @@ def generate_instruct_html(history):

    return output


def generate_cai_chat_html(history, name1, name2, reset_cache=False):
    output = f'<style>{cai_css}</style><div class="chat" id="chat">'

@@ -200,9 +210,11 @@ def generate_cai_chat_html(history, name1, name2, reset_cache=False):
    output += "</div>"
    return output


def generate_chat_html(history, name1, name2):
    return generate_cai_chat_html(history, name1, name2)


def chat_html_wrapper(history, name1, name2, mode, reset_cache=False):
    if mode == "cai-chat":
        return generate_cai_chat_html(history, name1, name2, reset_cache)
@@ -6,8 +6,6 @@ Documentation:
https://abetlen.github.io/llama-cpp-python/
'''

-import multiprocessing

from llama_cpp import Llama

from modules import shared
@@ -181,6 +181,7 @@ def load_model(model_name):
    print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
    return model, tokenizer


def load_soft_prompt(name):
    if name == 'None':
        shared.soft_prompt = False
@@ -61,6 +61,7 @@ settings = {
    }
}


def str2bool(v):
    if isinstance(v, bool):
        return v
@@ -71,6 +72,7 @@ def str2bool(v):
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))

# Basic settings
@@ -145,5 +147,6 @@ if args.cai_chat:
    print("Warning: --cai-chat is deprecated. Use --chat instead.")
    args.chat = True


def is_chat():
    return args.chat
@@ -21,6 +21,7 @@ def get_max_prompt_length(tokens):
        max_length -= shared.soft_prompt_tensor.shape[1]
    return max_length


def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
    if any((shared.is_RWKV, shared.is_llamacpp)):
        input_ids = shared.tokenizer.encode(str(prompt))
@@ -44,6 +45,7 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
    else:
        return input_ids.cuda()


def decode(output_ids):
    # Open Assistant relies on special tokens like <|endoftext|>
    if re.match('.*(oasst|galactica)-*', shared.model_name.lower()):
@@ -53,6 +55,7 @@ def decode(output_ids):
        reply = reply.replace(r'<|endoftext|>', '')
        return reply


def generate_softprompt_input_tensors(input_ids):
    inputs_embeds = shared.model.transformer.wte(input_ids)
    inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
@@ -61,6 +64,8 @@ def generate_softprompt_input_tensors(input_ids):
    return inputs_embeds, filler_input_ids

# Removes empty replies from gpt4chan outputs


def fix_gpt4chan(s):
    for i in range(10):
        s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
@@ -69,6 +74,8 @@ def fix_gpt4chan(s):
    return s

# Fix the LaTeX equations in galactica


def fix_galactica(s):
    s = s.replace(r'\[', r'$')
    s = s.replace(r'\]', r'$')
@@ -79,6 +86,7 @@ def fix_galactica(s):
    s = re.sub(r"\n{3,}", "\n\n", s)
    return s


def formatted_outputs(reply, model_name):
    if not shared.is_chat():
        if 'galactica' in model_name.lower():
@@ -92,20 +100,24 @@ def formatted_outputs(reply, model_name):
    else:
        return reply


def clear_torch_cache():
    gc.collect()
    if not shared.args.cpu:
        torch.cuda.empty_cache()


def set_manual_seed(seed):
    if seed != -1:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)


def stop_everything_event():
    shared.stop_everything = True


def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]):
    clear_torch_cache()
    set_manual_seed(generate_state['seed'])
@ -19,9 +19,11 @@ CURRENT_STEPS = 0
|
|||||||
MAX_STEPS = 0
|
MAX_STEPS = 0
|
||||||
CURRENT_GRADIENT_ACCUM = 1
|
CURRENT_GRADIENT_ACCUM = 1
|
||||||
|
|
||||||
|
|
||||||
def get_dataset(path: str, ext: str):
|
def get_dataset(path: str, ext: str):
|
||||||
return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=str.lower)
|
return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=str.lower)
|
||||||
|
|
||||||
|
|
||||||
def create_train_interface():
|
def create_train_interface():
|
||||||
with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
|
with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
|
||||||
lora_name = gr.Textbox(label="Name", info="The name of your new LoRA file")
|
lora_name = gr.Textbox(label="Name", info="The name of your new LoRA file")
|
||||||
@ -67,10 +69,12 @@ def create_train_interface():
|
|||||||
cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len, newline_favor_len], [output])
|
cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len, newline_favor_len], [output])
|
||||||
stop_button.click(do_interrupt, [], [], cancels=[], queue=False)
|
stop_button.click(do_interrupt, [], [], cancels=[], queue=False)
|
||||||
|
|
||||||
|
|
||||||
def do_interrupt():
|
def do_interrupt():
|
||||||
global WANT_INTERRUPT
|
global WANT_INTERRUPT
|
||||||
WANT_INTERRUPT = True
|
WANT_INTERRUPT = True
|
||||||
|
|
||||||
|
|
||||||
class Callbacks(transformers.TrainerCallback):
|
class Callbacks(transformers.TrainerCallback):
|
||||||
def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
|
def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
|
||||||
global CURRENT_STEPS, MAX_STEPS
|
global CURRENT_STEPS, MAX_STEPS
|
||||||
@ -79,6 +83,7 @@ class Callbacks(transformers.TrainerCallback):
|
|||||||
if WANT_INTERRUPT:
|
if WANT_INTERRUPT:
|
||||||
control.should_epoch_stop = True
|
control.should_epoch_stop = True
|
||||||
control.should_training_stop = True
|
control.should_training_stop = True
|
||||||
|
|
||||||
def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
|
def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
|
||||||
global CURRENT_STEPS
|
global CURRENT_STEPS
|
||||||
CURRENT_STEPS += 1
|
CURRENT_STEPS += 1
|
||||||
@@ -86,6 +91,7 @@ class Callbacks(transformers.TrainerCallback):
            control.should_epoch_stop = True
            control.should_training_stop = True


def clean_path(base_path: str, path: str):
    """"Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
    # TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path.
@@ -95,6 +101,7 @@ def clean_path(base_path: str, path: str):
        return path
    return f'{Path(base_path).absolute()}/{path}'


def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int, lora_alpha: int, lora_dropout: float,
             cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int, newline_favor_len: int):
    global WANT_INTERRUPT, CURRENT_STEPS, MAX_STEPS, CURRENT_GRADIENT_ACCUM
@@ -302,10 +309,12 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
    print("Training complete!")
    yield f"Done! LoRA saved to `{lora_name}`"


def split_chunks(arr, step):
    for i in range(0, len(arr), step):
        yield arr[i:i + step]


def cut_chunk_for_newline(chunk: str, max_length: int):
    if '\n' not in chunk:
        return chunk
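A quick check of the generator shown above (illustrative only): it slices a sequence into consecutive, step-sized pieces, with the final piece possibly shorter:

def split_chunks(arr, step):
    for i in range(0, len(arr), step):
        yield arr[i:i + step]

print(list(split_chunks(list(range(10)), 4)))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]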
@@ -319,6 +328,7 @@ def cut_chunk_for_newline(chunk: str, max_length: int):
        chunk = chunk[:last_newline]
    return chunk


def format_time(seconds: float):
    if seconds < 120:
        return f"`{seconds:.0f}` seconds"
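Only the first branch of format_time is visible in this hunk; a hedged sketch of how such a helper behaves (the minutes branch below is an assumption for illustration, not the project's code):

def format_time(seconds: float):
    if seconds < 120:
        return f"`{seconds:.0f}` seconds"
    # Assumed continuation: longer durations reported in minutes.
    return f"`{seconds / 60:.1f}` minutes"

print(format_time(90))   # `90` seconds
print(format_time(600))  # `10.0` minutes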
@@ -13,6 +13,7 @@ with open(Path(__file__).resolve().parent / '../css/main.js', 'r') as f:
with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f:
    chat_js = f.read()


class ToolButton(gr.Button, gr.components.FormComponent):
    """Small button with single emoji as text, fits inside gradio forms"""
@@ -22,6 +23,7 @@ class ToolButton(gr.Button, gr.components.FormComponent):
    def get_block_name(self):
        return "button"


def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
    def refresh():
        refresh_method()
server.py
@@ -34,15 +34,18 @@ if settings_file is not None:
    for item in new_settings:
        shared.settings[item] = new_settings[item]


def get_available_models():
    if shared.args.flexgen:
        return sorted([re.sub('-np$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if item.name.endswith('-np')], key=str.lower)
    else:
        return sorted([re.sub('.pth$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)


def get_available_presets():
    return sorted(set((k.stem for k in Path('presets').glob('*.txt'))), key=str.lower)


def get_available_prompts():
    prompts = []
    prompts += sorted(set((k.stem for k in Path('prompts').glob('[0-9]*.txt'))), key=str.lower, reverse=True)
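The FlexGen branch of get_available_models strips the '-np' suffix with a regex before sorting; a small illustration with made-up folder names:

import re

names = ['opt-1.3b-np', 'OPT-6.7b-np']
print(sorted((re.sub('-np$', '', n) for n in names), key=str.lower))  # ['opt-1.3b', 'OPT-6.7b']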
@@ -50,10 +53,12 @@ def get_available_prompts():
    prompts += ['None']
    return prompts


def get_available_characters():
    paths = (x for x in Path('characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
    return ['None'] + sorted(set((k.stem for k in paths if k.stem != "instruction-following")), key=str.lower)


def get_available_instruction_templates():
    path = "characters/instruction-following"
    paths = []
@@ -61,19 +66,24 @@ def get_available_instruction_templates():
        paths = (x for x in Path(path).iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
    return ['None'] + sorted(set((k.stem for k in paths)), key=str.lower)


def get_available_extensions():
    return sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)


def get_available_softprompts():
    return ['None'] + sorted(set((k.stem for k in Path('softprompts').glob('*.zip'))), key=str.lower)


def get_available_loras():
    return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)


def unload_model():
    shared.model = shared.tokenizer = None
    clear_torch_cache()


def load_model_wrapper(selected_model):
    if selected_model != shared.model_name:
        shared.model_name = selected_model
@@ -84,10 +94,12 @@ def load_model_wrapper(selected_model):

    return selected_model


def load_lora_wrapper(selected_lora):
    add_lora_to_model(selected_lora)
    return selected_lora


def load_preset_values(preset_menu, state, return_dict=False):
    generate_params = {
        'do_sample': True,
@@ -118,6 +130,7 @@ def load_preset_values(preset_menu, state, return_dict=False):
        state.update(generate_params)
        return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]


def upload_soft_prompt(file):
    with zipfile.ZipFile(io.BytesIO(file)) as zf:
        zf.extract('meta.json')
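The starred return in load_preset_values hands Gradio one value per output component; a tiny standalone illustration of that pattern (the as_outputs helper and its values are hypothetical):

def as_outputs(state, params, keys):
    state.update(params)
    return state, *[params[k] for k in keys]

print(as_outputs({}, {'do_sample': True, 'temperature': 0.7}, ['do_sample', 'temperature']))
# ({'do_sample': True, 'temperature': 0.7}, True, 0.7)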
@@ -130,12 +143,14 @@ def upload_soft_prompt(file):

    return name


def save_prompt(text):
    fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}.txt"
    with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f:
        f.write(text)
    return f"Saved to prompts/{fname}"


def load_prompt(fname):
    if fname in ['None', '']:
        return ''
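save_prompt names files by timestamp, which is what lets get_available_prompts list them newest-first; an illustrative one-liner:

from datetime import datetime

print(f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}.txt")  # e.g. 2023-04-10-153045.txt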
@@ -146,6 +161,7 @@ def load_prompt(fname):
                text = text[:-1]
            return text


def create_prompt_menus():
    with gr.Row():
        with gr.Column():
@@ -161,6 +177,7 @@ def create_prompt_menus():
    shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=False)
    shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)


def create_model_menus():
    with gr.Row():
        with gr.Column():
@@ -175,6 +192,7 @@ def create_model_menus():
    shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
    shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)


def create_settings_menus(default_preset):
    generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
    for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']:
@@ -209,7 +227,6 @@ def create_settings_menus(default_preset):
                with gr.Box():
                    gr.Markdown('Contrastive search')
                    shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')

                with gr.Box():
                    gr.Markdown('Beam search (uses a lot of VRAM)')
                    with gr.Row():
@@ -219,7 +236,6 @@ def create_settings_menus(default_preset):
                        shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
                    shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')

    with gr.Accordion('Soft prompt', open=False):
        with gr.Row():
            shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt')
@@ -233,6 +249,7 @@ def create_settings_menus(default_preset):
    shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True)
    shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu'])


def set_interface_arguments(interface_mode, extensions, bool_active):
    modes = ["default", "notebook", "chat", "cai_chat"]
    cmd_list = vars(shared.args)
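set_interface_arguments grabs the argument namespace with vars(shared.args); as a Python aside, vars() returns the namespace's live attribute dict, so edits to it are visible on the namespace too (the --chat flag below is only an example):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--chat', action='store_true')
args = parser.parse_args([])
cmd_list = vars(args)   # the namespace's own __dict__
cmd_list['chat'] = True
print(args.chat)        # True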
@@ -251,6 +268,7 @@ def set_interface_arguments(interface_mode, extensions, bool_active):

    shared.need_restart = True


available_models = get_available_models()
available_presets = get_available_presets()
available_characters = get_available_characters()
@@ -299,8 +317,8 @@ else:
    default_text = load_prompt(shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')])
title = 'Text generation web UI'


def create_interface():
    gen_events = []
    if shared.args.extensions is not None and len(shared.args.extensions) > 0:
        extensions_module.load_extensions()
@@ -562,6 +580,7 @@ def create_interface():
    else:
        shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth)


create_interface()

while True: