From f91d3a3ff4a8d758e17fe419e9a12c3c1e54065d Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 9 Apr 2023 14:46:32 -0300 Subject: [PATCH 01/15] server.py readability --- server.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/server.py b/server.py index 740020ea..50305ec0 100644 --- a/server.py +++ b/server.py @@ -400,53 +400,53 @@ def create_interface(): gen_events.append(shared.gradio['Generate'].click( lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then( chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then( - lambda: chat.save_history(timestamp=False), [], [], show_progress=False) + lambda: chat.save_history(timestamp=False), None, None, show_progress=False) ) gen_events.append(shared.gradio['textbox'].submit( lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then( chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then( - lambda: chat.save_history(timestamp=False), [], [], show_progress=False) + lambda: chat.save_history(timestamp=False), None, None, show_progress=False) ) gen_events.append(shared.gradio['Regenerate'].click( chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then( lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then( - lambda: chat.save_history(timestamp=False), [], [], show_progress=False) + lambda: chat.save_history(timestamp=False), None, None, show_progress=False) ) shared.gradio['Replace last reply'].click( chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'Chat mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then( lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then( - lambda: chat.save_history(timestamp=False), [], [], show_progress=False) + lambda: chat.save_history(timestamp=False), None, None, show_progress=False) shared.gradio['Clear history-confirm'].click( lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr).then( chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'Chat mode']], shared.gradio['display']).then( - lambda: chat.save_history(timestamp=False), [], [], show_progress=False) + lambda: chat.save_history(timestamp=False), None, None, show_progress=False) shared.gradio['Stop'].click( - stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None).then( - chat.redraw_html, reload_inputs, [shared.gradio['display']]) + stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None).then( + chat.redraw_html, reload_inputs, shared.gradio['display']) shared.gradio['Chat mode'].change( lambda x: gr.update(visible=x == 'instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates']).then( - chat.redraw_html, reload_inputs, [shared.gradio['display']]) + chat.redraw_html, reload_inputs, shared.gradio['display']) shared.gradio['Instruction templates'].change( lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'Chat mode']], 
[shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]).then( - chat.redraw_html, reload_inputs, [shared.gradio['display']]) + chat.redraw_html, reload_inputs, shared.gradio['display']) shared.gradio['upload_chat_history'].upload( - chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], []).then( - chat.redraw_html, reload_inputs, [shared.gradio['display']]) + chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], None).then( + chat.redraw_html, reload_inputs, shared.gradio['display']) gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream)) - shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream) + shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, None, shared.gradio['textbox'], show_progress=shared.args.no_stream) shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr) shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr) shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False) - shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']]) + shared.gradio['download_button'].click(chat.save_history, inputs=None, outputs=[shared.gradio['download']]) shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']]) shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]) shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']]) @@ -454,7 +454,7 @@ def create_interface(): shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}") shared.gradio['interface'].load(chat.load_default_history, [shared.gradio[k] for k in ['name1', 'name2']], None) - shared.gradio['interface'].load(chat.redraw_html, reload_inputs, [shared.gradio['display']], show_progress=True) + shared.gradio['interface'].load(chat.redraw_html, reload_inputs, shared.gradio['display'], show_progress=True) elif shared.args.notebook: with gr.Tab("Text generation", elem_id="main"): @@ -488,7 +488,7 @@ def create_interface(): output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']] gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) - shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None) + shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, 
cancels=gen_events if shared.args.no_stream else None) shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}") else: @@ -522,7 +522,7 @@ def create_interface(): gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream)) - shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None) + shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None) shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}") with gr.Tab("Model", elem_id="model-tab"): From 34ec02d41d6b68aa04cffed1a70c30c977a27578 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 9 Apr 2023 16:59:59 -0300 Subject: [PATCH 02/15] Make download-model.py importable --- download-model.py | 224 ++++++++++++++++++++++++---------------------- 1 file changed, 119 insertions(+), 105 deletions(-) diff --git a/download-model.py b/download-model.py index 880eeb40..cd5a3f31 100644 --- a/download-model.py +++ b/download-model.py @@ -2,7 +2,7 @@ Downloads models from Hugging Face to models/model-name. Example: -python download-model.py facebook/opt-1.3b +python download_model.py facebook/opt-1.3b ''' @@ -19,6 +19,7 @@ import requests import tqdm from tqdm.contrib.concurrent import thread_map + parser = argparse.ArgumentParser() parser.add_argument('MODEL', type=str, default=None, nargs='?') parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.') @@ -30,40 +31,6 @@ parser.add_argument('--check', action='store_true', help='Validates the checksum args = parser.parse_args() -def get_file(url, output_folder): - filename = Path(url.rsplit('/', 1)[1]) - output_path = output_folder / filename - if output_path.exists() and not args.clean: - # Check if the file has already been downloaded completely - r = requests.get(url, stream=True) - total_size = int(r.headers.get('content-length', 0)) - if output_path.stat().st_size >= total_size: - return - # Otherwise, resume the download from where it left off - headers = {'Range': f'bytes={output_path.stat().st_size}-'} - mode = 'ab' - else: - headers = {} - mode = 'wb' - - r = requests.get(url, stream=True, headers=headers) - with open(output_path, mode) as f: - total_size = int(r.headers.get('content-length', 0)) - block_size = 1024 - with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t: - for data in r.iter_content(block_size): - t.update(len(data)) - f.write(data) - - -def sanitize_branch_name(branch_name): - pattern = re.compile(r"^[a-zA-Z0-9._-]+$") - if pattern.match(branch_name): - return branch_name - else: - raise ValueError("Invalid branch name. 
Only alphanumeric characters, period, underscore and dash are allowed.") - - def select_model_from_default_options(): models = { "OPT 6.7B": ("facebook", "opt-6.7b", "main"), @@ -110,7 +77,20 @@ EleutherAI/pythia-1.4b-deduped return model, branch -def get_download_links_from_huggingface(model, branch): +def sanitize_model_and_branch_names(model, branch): + if model[-1] == '/': + model = model[:-1] + if branch is None: + branch = "main" + else: + pattern = re.compile(r"^[a-zA-Z0-9._-]+$") + if not pattern.match(branch): + raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.") + + return model, branch + + +def get_download_links_from_huggingface(model, branch, text_only=False): base = "https://huggingface.co" page = f"/api/models/{model}/tree/{branch}?cursor=" cursor = b"" @@ -149,7 +129,7 @@ def get_download_links_from_huggingface(model, branch): links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}") classifications.append('text') continue - if not args.text_only: + if not text_only: links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}") if is_safetensors: has_safetensors = True @@ -177,80 +157,114 @@ def get_download_links_from_huggingface(model, branch): return links, sha256, is_lora -def download_files(file_list, output_folder, num_threads=8): - thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, disable=True) - - -if __name__ == '__main__': - model = args.MODEL - branch = args.branch - if model is None: - model, branch = select_model_from_default_options() - else: - if model[-1] == '/': - model = model[:-1] - branch = args.branch - if branch is None: - branch = "main" - else: - try: - branch = sanitize_branch_name(branch) - except ValueError as err_branch: - print(f"Error: {err_branch}") - sys.exit() - - links, sha256, is_lora = get_download_links_from_huggingface(model, branch) - - if args.output is not None: - base_folder = args.output - else: +def get_output_folder(model, branch, is_lora, base_folder=None): + if base_folder is None: base_folder = 'models' if not is_lora else 'loras' output_folder = f"{'_'.join(model.split('/')[-2:])}" if branch != 'main': output_folder += f'_{branch}' output_folder = Path(base_folder) / output_folder + return output_folder + + +def get_single_file(url, output_folder, start_from_scratch=False): + filename = Path(url.rsplit('/', 1)[1]) + output_path = output_folder / filename + if output_path.exists() and not start_from_scratch: + # Check if the file has already been downloaded completely + r = requests.get(url, stream=True) + total_size = int(r.headers.get('content-length', 0)) + if output_path.stat().st_size >= total_size: + return + # Otherwise, resume the download from where it left off + headers = {'Range': f'bytes={output_path.stat().st_size}-'} + mode = 'ab' + else: + headers = {} + mode = 'wb' + + r = requests.get(url, stream=True, headers=headers) + with open(output_path, mode) as f: + total_size = int(r.headers.get('content-length', 0)) + block_size = 1024 + with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t: + for data in r.iter_content(block_size): + t.update(len(data)) + f.write(data) + + +def start_download_threads(file_list, output_folder, start_from_scratch=False, threads=1): + thread_map(lambda url: get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True) + + +def 
download_model_files(model, branch, links, sha256, output_folder, start_from_scratch=False, threads=1): + # Creating the folder and writing the metadata + if not output_folder.exists(): + output_folder.mkdir() + with open(output_folder / 'huggingface-metadata.txt', 'w') as f: + f.write(f'url: https://huggingface.co/{model}\n') + f.write(f'branch: {branch}\n') + f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n') + sha256_str = '' + for i in range(len(sha256)): + sha256_str += f' {sha256[i][1]} {sha256[i][0]}\n' + if sha256_str != '': + f.write(f'sha256sum:\n{sha256_str}') + + # Downloading the files + print(f"Downloading the model to {output_folder}") + start_download_threads(links, output_folder, start_from_scratch=start_from_scratch, threads=threads) + + +def check_model_files(model, branch, links, sha256, output_folder): + # Validate the checksums + validated = True + for i in range(len(sha256)): + fpath = (output_folder / sha256[i][0]) + + if not fpath.exists(): + print(f"The following file is missing: {fpath}") + validated = False + continue + + with open(output_folder / sha256[i][0], "rb") as f: + bytes = f.read() + file_hash = hashlib.sha256(bytes).hexdigest() + if file_hash != sha256[i][1]: + print(f'Checksum failed: {sha256[i][0]} {sha256[i][1]}') + validated = False + else: + print(f'Checksum validated: {sha256[i][0]} {sha256[i][1]}') + + if validated: + print('[+] Validated checksums of all model files!') + else: + print('[-] Invalid checksums. Rerun download-model.py with the --clean flag.') + + +if __name__ == '__main__': + branch = args.branch + model = args.MODEL + if model is None: + model, branch = select_model_from_default_options() + + # Cleaning up the model/branch names + try: + model, branch = sanitize_model_and_branch_names(model, branch) + except ValueError as err_branch: + print(f"Error: {err_branch}") + sys.exit() + + # Getting the download links from Hugging Face + links, sha256, is_lora = get_download_links_from_huggingface(model, branch, text_only=args.text_only) + + # Getting the output folder + output_folder = get_output_folder(model, branch, is_lora, base_folder=args.output) if args.check: - # Validate the checksums - validated = True - for i in range(len(sha256)): - fpath = (output_folder / sha256[i][0]) - - if not fpath.exists(): - print(f"The following file is missing: {fpath}") - validated = False - continue - - with open(output_folder / sha256[i][0], "rb") as f: - bytes = f.read() - file_hash = hashlib.sha256(bytes).hexdigest() - if file_hash != sha256[i][1]: - print(f'Checksum failed: {sha256[i][0]} {sha256[i][1]}') - validated = False - else: - print(f'Checksum validated: {sha256[i][0]} {sha256[i][1]}') - - if validated: - print('[+] Validated checksums of all model files!') - else: - print('[-] Invalid checksums. 
Rerun download-model.py with the --clean flag.') - + # Check previously downloaded files + check_model_files(model, branch, links, sha256, output_folder) else: - - # Creating the folder and writing the metadata - if not output_folder.exists(): - output_folder.mkdir() - with open(output_folder / 'huggingface-metadata.txt', 'w') as f: - f.write(f'url: https://huggingface.co/{model}\n') - f.write(f'branch: {branch}\n') - f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n') - sha256_str = '' - for i in range(len(sha256)): - sha256_str += f' {sha256[i][1]} {sha256[i][0]}\n' - if sha256_str != '': - f.write(f'sha256sum:\n{sha256_str}') - - # Downloading the files - print(f"Downloading the model to {output_folder}") - download_files(links, output_folder, args.threads) + # Download files + download_model_files(model, branch, links, sha256, output_folder, threads=args.threads) From 170e0c05c427727382d675bdf72866e5515aa734 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 9 Apr 2023 17:00:59 -0300 Subject: [PATCH 03/15] Typo --- download-model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/download-model.py b/download-model.py index cd5a3f31..a48a1b8c 100644 --- a/download-model.py +++ b/download-model.py @@ -2,7 +2,7 @@ Downloads models from Hugging Face to models/model-name. Example: -python download_model.py facebook/opt-1.3b +python download-model.py facebook/opt-1.3b ''' From d29f4624e957e52c35d07ffc515a56c440de63c1 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 9 Apr 2023 20:04:16 -0300 Subject: [PATCH 04/15] Add a Continue button to chat mode --- modules/chat.py | 47 +++++++++++++++++++++++++++++++++++++---------- server.py | 9 +++++++-- 2 files changed, 44 insertions(+), 12 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index 6c6077a2..6400adcb 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -22,6 +22,7 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else '' impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False + _continue = kwargs['_continue'] if '_continue' in kwargs else False also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False rows = [f"{context.strip()}\n"] @@ -39,7 +40,10 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat i = len(shared.history['internal']) - 1 while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length: - rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n") + if _continue and i == len(shared.history['internal']) - 1: + rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}") + else: + rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n") string = shared.history['internal'][i][0] if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']: rows.insert(1, f"{prefix1}{string.strip()}{end_of_turn}\n") @@ -48,6 +52,8 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat if impersonate: rows.append(f"{prefix1.strip() if not is_instruct else prefix1}") limit = 2 + elif _continue: + limit = 3 else: # Adding the user message user_input = fix_newlines(user_input) @@ -56,12 +62,12 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, 
name2, context, chat # Adding the Character prefix rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix")) + limit = 3 while len(rows) > limit and len(encode(''.join(rows), max_new_tokens)[0]) >= max_length: rows.pop(1) prompt = ''.join(rows) - if also_return_rows: return prompt, rows else: @@ -99,7 +105,7 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline): return reply, next_character_found -def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False): +def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False, _continue=False): if mode == 'instruct': stopping_strings = [f"\n{name1}", f"\n{name2}"] else: @@ -107,6 +113,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu # Defining some variables cumulative_reply = '' + last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]] if _continue else None just_started = True name1_original = name1 visible_text = custom_generate_chat_prompt = None @@ -124,17 +131,22 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu if visible_text is None: visible_text = text - text = apply_extensions(text, "input") + if not _continue: + text = apply_extensions(text, "input") # Generating the prompt - kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'} + kwargs = { + 'end_of_turn': end_of_turn, + 'is_instruct': mode == 'instruct', + '_continue': _continue + } if custom_generate_chat_prompt is None: prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs) else: prompt = custom_generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs) # Yield *Is typing...* - if not regenerate: + if not any((regenerate, _continue)): yield shared.history['visible'] + [[visible_text, shared.processing_message]] # Generate @@ -154,11 +166,16 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu return shared.history['visible'] if just_started: just_started = False - shared.history['internal'].append(['', '']) - shared.history['visible'].append(['', '']) + if not _continue: + shared.history['internal'].append(['', '']) + shared.history['visible'].append(['', '']) - shared.history['internal'][-1] = [text, reply] - shared.history['visible'][-1] = [visible_text, visible_reply] + if _continue: + shared.history['internal'][-1] = [text, f'{last_reply[0]} {reply}'] + shared.history['visible'][-1] = [visible_text, f'{last_reply[1]} {visible_reply}'] + else: + shared.history['internal'][-1] = [text, reply] + shared.history['visible'][-1] = [visible_text, visible_reply] if not shared.args.no_stream: yield shared.history['visible'] if next_character_found: @@ -220,6 +237,16 @@ def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of yield chat_html_wrapper(shared.history['visible'], name1, name2, mode) +def continue_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): + if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0: + yield chat_html_wrapper(shared.history['visible'], name1, name2, mode) + else: + # Yield ' ...' 
+ yield chat_html_wrapper(shared.history['visible'][:-1] + [[shared.history['visible'][-1][0], shared.history['visible'][-1][1] + ' ...']], name1, name2, mode) + for history in chatbot_wrapper(shared.history['internal'][-1][0], generate_state, name1, name2, context, mode, end_of_turn, _continue=True): + yield chat_html_wrapper(shared.history['visible'], name1, name2, mode) + + def remove_last_message(name1, name2, mode): if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>': last = shared.history['visible'].pop() diff --git a/server.py b/server.py index 50305ec0..cbfbd241 100644 --- a/server.py +++ b/server.py @@ -327,8 +327,9 @@ def create_interface(): shared.gradio['Generate'] = gr.Button('Generate', elem_id='Generate') shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop") with gr.Row(): - shared.gradio['Impersonate'] = gr.Button('Impersonate') shared.gradio['Regenerate'] = gr.Button('Regenerate') + shared.gradio['Continue'] = gr.Button('Continue') + shared.gradio['Impersonate'] = gr.Button('Impersonate') with gr.Row(): shared.gradio['Copy last reply'] = gr.Button('Copy last reply') shared.gradio['Replace last reply'] = gr.Button('Replace last reply') @@ -411,7 +412,11 @@ def create_interface(): gen_events.append(shared.gradio['Regenerate'].click( chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then( - lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then( + lambda: chat.save_history(timestamp=False), None, None, show_progress=False) + ) + + gen_events.append(shared.gradio['Continue'].click( + chat.continue_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then( lambda: chat.save_history(timestamp=False), None, None, show_progress=False) ) From b27d757fd149683bd7f14449c3ad360f9f4fc1e7 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 9 Apr 2023 20:06:20 -0300 Subject: [PATCH 05/15] Minor change --- modules/chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/chat.py b/modules/chat.py index 6400adcb..edd82216 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -62,12 +62,12 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat # Adding the Character prefix rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix")) - limit = 3 while len(rows) > limit and len(encode(''.join(rows), max_new_tokens)[0]) >= max_length: rows.pop(1) prompt = ''.join(rows) + if also_return_rows: return prompt, rows else: From 120f5662cf085d2cdaf387690a156d85ca7d131b Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 9 Apr 2023 20:37:31 -0300 Subject: [PATCH 06/15] Better handle spaces for Continue --- modules/chat.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index edd82216..6a1f7ad1 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -171,8 +171,9 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu shared.history['visible'].append(['', '']) if _continue: - shared.history['internal'][-1] = [text, f'{last_reply[0]} {reply}'] - shared.history['visible'][-1] = [visible_text, f'{last_reply[1]} {visible_reply}'] + sep = list(map(lambda x : ' ' if x[-1] != ' ' else '', last_reply)) + shared.history['internal'][-1] = [text, 
f'{last_reply[0]}{sep[0]}{reply}'] + shared.history['visible'][-1] = [visible_text, f'{last_reply[1]}{sep[1]}{visible_reply}'] else: shared.history['internal'][-1] = [text, reply] shared.history['visible'][-1] = [visible_text, visible_reply] From a3085dba073fe8bdcfb5120729a84560f5d024c3 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 9 Apr 2023 21:19:39 -0300 Subject: [PATCH 07/15] Fix LlamaTokenizer eos_token (attempt) --- modules/models.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/modules/models.py b/modules/models.py index 75dff911..37e8e78e 100644 --- a/modules/models.py +++ b/modules/models.py @@ -174,6 +174,9 @@ def load_model(model_name): tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/")) elif type(model) is transformers.LlamaForCausalLM: tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True) + tokenizer.eos_token_id = 2 + tokenizer.bos_token_id = 1 + tokenizer.pad_token_id = 0 else: tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/")) tokenizer.truncation_side = 'left' From 57f768eaad73db18f09d40545cea4247269b3696 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 9 Apr 2023 22:18:40 -0300 Subject: [PATCH 08/15] Better preset in api-example.py --- api-example.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/api-example.py b/api-example.py index 10be0a88..a3aaa0f0 100644 --- a/api-example.py +++ b/api-example.py @@ -22,10 +22,10 @@ server = "127.0.0.1" params = { 'max_new_tokens': 200, 'do_sample': True, - 'temperature': 0.5, - 'top_p': 0.9, + 'temperature': 0.72, + 'top_p': 0.73, 'typical_p': 1, - 'repetition_penalty': 1.05, + 'repetition_penalty': 1.1, 'encoder_repetition_penalty': 1.0, 'top_k': 0, 'min_length': 0, From 625d81f495e57961c0e63fa8b4c40835e0f33b94 Mon Sep 17 00:00:00 2001 From: Brian O'Connor Date: Sun, 9 Apr 2023 21:20:21 -0400 Subject: [PATCH 09/15] Update character log logic (#977) * When logs are cleared, save the cleared log over the old log files * Generate a log file when a character is loaded the first time --- modules/chat.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index 6a1f7ad1..df39a58b 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -284,6 +284,9 @@ def clear_chat_log(name1, name2, greeting, mode): if greeting != '': shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]] shared.history['visible'] += [['', apply_extensions(greeting, "output")]] + + # Save cleared logs + save_history(timestamp=False) return chat_html_wrapper(shared.history['visible'], name1, name2, mode) @@ -434,9 +437,14 @@ def load_character(character, name1, name2, mode): if Path(f'logs/{shared.character}_persistent.json').exists(): load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2) - elif greeting != "": - shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]] - shared.history['visible'] += [['', apply_extensions(greeting, "output")]] + else: + # Insert greeting if it exists + if greeting != "": + shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]] + shared.history['visible'] += [['', apply_extensions(greeting, "output")]] + + # Create .json log files since they don't already exist + save_history(timestamp=False) return name1, 
name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True) From 992663fa20b0c19f4b9156639a01a9810c992493 Mon Sep 17 00:00:00 2001 From: MarkovInequality <31809330+MarkovInequality@users.noreply.github.com> Date: Sun, 9 Apr 2023 22:08:40 -0400 Subject: [PATCH 10/15] Added xformers support to Llama (#950) --- README.md | 2 + modules/llama_attn_hijack.py | 176 +++++++++++++++++++++++++++++++++++ modules/models.py | 5 + modules/shared.py | 2 + 4 files changed, 185 insertions(+) create mode 100644 modules/llama_attn_hijack.py diff --git a/README.md b/README.md index e7cefef4..136a6d60 100644 --- a/README.md +++ b/README.md @@ -215,6 +215,8 @@ Optionally, you can use the following command-line flags: | `--load-in-8bit` | Load the model with 8-bit precision.| | `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. | | `--no-cache` | Set `use_cache` to False while generating text. This reduces the VRAM usage a bit with a performance cost. | +| `--xformers` | Use xformer's memory efficient attention. This should increase your tokens/s. | +| `--sdp-attention` | Use torch 2.0's sdp attention. | #### llama.cpp diff --git a/modules/llama_attn_hijack.py b/modules/llama_attn_hijack.py new file mode 100644 index 00000000..f5c5c92e --- /dev/null +++ b/modules/llama_attn_hijack.py @@ -0,0 +1,176 @@ +import math +import sys +import torch +import torch.nn as nn +import transformers.models.llama.modeling_llama + +from typing import Optional +from typing import Tuple + +import modules.shared as shared + + +if shared.args.xformers: + try: + import xformers.ops + except Exception: + print("🔴 xformers not found! Please install it before trying to use it.", file=sys.stderr) + + +def hijack_llama_attention(): + if shared.args.xformers: + transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward + print("Replaced attention with xformers_attention") + elif shared.args.sdp_attention: + transformers.models.llama.modeling_llama.LlamaAttention.forward = sdp_attention_forward + print("Replaced attention with sdp_attention") + + +def xformers_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + #We only apply xformers optimizations if we don't need to 
output the whole attention matrix + if not output_attentions: + dtype = query_states.dtype + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + #This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros. + #We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros. + if attention_mask is None or attention_mask[0, 0, 0, 1] == 0: + # input and output should be of form (bsz, q_len, num_heads, head_dim) + attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=None) + else: + # input and output should be of form (bsz, q_len, num_heads, head_dim) + attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=xformers.ops.LowerTriangularMask()) + attn_weights = None + else: + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights, past_key_value + + +def sdp_attention_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, 
value_states) if use_cache else None + + #We only apply sdp attention if we don't need to output the whole attention matrix + if not output_attentions: + attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False) + attn_weights = None + else: + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights, past_key_value diff --git a/modules/models.py b/modules/models.py index 37e8e78e..4e892aa9 100644 --- a/modules/models.py +++ b/modules/models.py @@ -14,6 +14,7 @@ from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, LlamaTokenizer) import modules.shared as shared +from modules import llama_attn_hijack transformers.logging.set_verbosity_error() @@ -169,6 +170,10 @@ def load_model(model_name): model = AutoModelForCausalLM.from_pretrained(checkpoint, **params) + # Hijack attention with xformers + if any((shared.args.xformers, shared.args.sdp_attention)): + llama_attn_hijack.hijack_llama_attention() + # Loading the tokenizer if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists(): tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/")) diff --git a/modules/shared.py b/modules/shared.py index 7ff1ca28..663ed498 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -98,6 +98,8 @@ parser.add_argument('--disk-cache-dir', type=str, default="cache", help='Directo parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.') parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.') parser.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. This reduces the VRAM usage a bit at a performance cost.') +parser.add_argument('--xformers', action='store_true', help="Use xformer's memory efficient attention. 
This should increase your tokens/s.") +parser.add_argument('--sdp-attention', action='store_true', help="Use torch 2.0's sdp attention.") # llama.cpp parser.add_argument('--threads', type=int, default=0, help='Number of threads to use in llama.cpp.') From 8c6155251ae9852bbae1fd4df40934988c86a0b1 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 9 Apr 2023 23:19:28 -0300 Subject: [PATCH 11/15] More robust 4-bit model loading --- modules/GPTQ_loader.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py index 3f42e5c6..aa6aec7a 100644 --- a/modules/GPTQ_loader.py +++ b/modules/GPTQ_loader.py @@ -100,10 +100,10 @@ def load_quantized(model_name): found_safetensors = list(path_to_model.glob("*.safetensors")) pt_path = None - if len(found_pts) == 1: - pt_path = found_pts[0] - elif len(found_safetensors) == 1: - pt_path = found_safetensors[0] + if len(found_pts) > 0: + pt_path = found_pts[-1] + elif len(found_safetensors) > 0: + pt_path = found_safetensors[-1] else: if path_to_model.name.lower().startswith('llama-7b'): pt_model = f'llama-7b-{shared.args.wbits}bit' @@ -119,13 +119,14 @@ def load_quantized(model_name): # Try to find the .safetensors or .pt both in the model dir and in the subfolder for path in [Path(p + ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]: if path.exists(): - print(f"Found {path}") pt_path = path break if not pt_path: print("Could not find the quantized model in .pt or .safetensors format, exiting...") exit() + else: + print(f"Found the following quantized model: {pt_path}") # qwopqwop200's offload if model_type == 'llama' and shared.args.pre_layer: From dba2000d2bf605dc0787e63b1cfb3e6c7722e4bd Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 9 Apr 2023 23:40:17 -0300 Subject: [PATCH 12/15] Do things that I am not proud of --- modules/models.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/modules/models.py b/modules/models.py index 4e892aa9..32c9c348 100644 --- a/modules/models.py +++ b/modules/models.py @@ -179,9 +179,14 @@ def load_model(model_name): tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/")) elif type(model) is transformers.LlamaForCausalLM: tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True) - tokenizer.eos_token_id = 2 - tokenizer.bos_token_id = 1 - tokenizer.pad_token_id = 0 + # Leaving this here until the LLaMA tokenizer gets figured out. + # For some people this fixes things, for others it causes an error. + try: + tokenizer.eos_token_id = 2 + tokenizer.bos_token_id = 1 + tokenizer.pad_token_id = 0 + except: + continue else: tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/")) tokenizer.truncation_side = 'left' From 8178fde2cb91ee0d8705de021021dcb5bbc008b2 Mon Sep 17 00:00:00 2001 From: BlueprintCoding <130100872+BlueprintCoding@users.noreply.github.com> Date: Sun, 9 Apr 2023 20:44:31 -0600 Subject: [PATCH 13/15] Added dropdown to character bias. 
(#986) --- extensions/character_bias/script.py | 48 +++++++++++++++++++++++++---- 1 file changed, 42 insertions(+), 6 deletions(-) diff --git a/extensions/character_bias/script.py b/extensions/character_bias/script.py index a92d0aef..614d9ce3 100644 --- a/extensions/character_bias/script.py +++ b/extensions/character_bias/script.py @@ -1,8 +1,23 @@ import gradio as gr +import os + +# get the current directory of the script +current_dir = os.path.dirname(os.path.abspath(__file__)) + +# check if the bias_options.txt file exists, if not, create it +bias_file = os.path.join(current_dir, "bias_options.txt") +if not os.path.isfile(bias_file): + with open(bias_file, "w") as f: + f.write("*I am so happy*\n*I am so sad*\n*I am so excited*\n*I am so bored*\n*I am so angry*") + +# read bias options from the text file +with open(bias_file, "r") as f: + bias_options = [line.strip() for line in f.readlines()] params = { "activate": True, "bias string": " *I am so happy*", + "use custom string": False, } @@ -11,7 +26,6 @@ def input_modifier(string): This function is applied to your text inputs before they are fed into the model. """ - return string @@ -19,7 +33,6 @@ def output_modifier(string): """ This function is applied to the model outputs. """ - return string @@ -29,9 +42,11 @@ def bot_prefix_modifier(string): the prefix text for the Bot and can be used to bias its behavior. """ - if params['activate']: - return f'{string} {params["bias string"].strip()} ' + if params['use custom string']: + return f'{string} {params["custom string"].strip()} ' + else: + return f'{string} {params["bias string"].strip()} ' else: return string @@ -39,8 +54,29 @@ def bot_prefix_modifier(string): def ui(): # Gradio elements activate = gr.Checkbox(value=params['activate'], label='Activate character bias') - string = gr.Textbox(value=params["bias string"], label='Character bias') + dropdown_string = gr.Dropdown(choices=bias_options, value=params["bias string"], label='Character bias', info='To edit the options in this dropdown edit the "bias_options.txt" file') + use_custom_string = gr.Checkbox(value=False, label='Use custom bias textbox instead of dropdown') + custom_string = gr.Textbox(value="", placeholder="Enter custom bias string", label="Custom Character Bias", info='To use this textbox activate the checkbox above') # Event functions to update the parameters in the backend - string.change(lambda x: params.update({"bias string": x}), string, None) + def update_bias_string(x): + if x: + params.update({"bias string": x}) + else: + params.update({"bias string": dropdown_string.get()}) + return x + + def update_custom_string(x): + params.update({"custom string": x}) + + dropdown_string.change(update_bias_string, dropdown_string, None) + custom_string.change(update_custom_string, custom_string, None) activate.change(lambda x: params.update({"activate": x}), activate, None) + use_custom_string.change(lambda x: params.update({"use custom string": x}), use_custom_string, None) + + # Group elements together depending on the selected option + def bias_string_group(): + if use_custom_string.value: + return gr.Group([use_custom_string, custom_string]) + else: + return dropdown_string From 1911504f8205bfae5203185d89d7860a0f930974 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 9 Apr 2023 23:45:41 -0300 Subject: [PATCH 14/15] Minor bug fix --- modules/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/models.py b/modules/models.py index 
32c9c348..7ec93df8 100644 --- a/modules/models.py +++ b/modules/models.py @@ -186,7 +186,7 @@ def load_model(model_name): tokenizer.bos_token_id = 1 tokenizer.pad_token_id = 0 except: - continue + pass else: tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/")) tokenizer.truncation_side = 'left' From 32d078487e2b7153e5c57e013d41148cf2806daa Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 10 Apr 2023 10:45:51 -0300 Subject: [PATCH 15/15] Add llama-cpp-python to requirements.txt --- requirements.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index ff23f486..4dd753ac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ accelerate==0.18.0 -bitsandbytes==0.37.2 datasets flexgen==0.1.7 gradio==3.24.1 @@ -14,3 +13,6 @@ sentencepiece pyyaml tqdm git+https://github.com/huggingface/transformers +bitsandbytes==0.37.2; platform_system != "Windows" +llama-cpp-python==0.1.30; platform_system != "Windows" +https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.30/llama_cpp_python-0.1.30-cp310-cp310-win_amd64.whl; platform_system == "Windows"
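
Usage note on PATCH 02 ("Make download-model.py importable"): the refactor splits the old __main__ block into reusable functions (sanitize_model_and_branch_names, get_download_links_from_huggingface, get_output_folder, download_model_files), so another script can drive the downloader programmatically. A minimal sketch follows, assuming download-model.py is on sys.path and the host process is started without conflicting CLI arguments (the module still runs its argparse at import time); importlib is used because the hyphen in the filename rules out a plain import statement:

    import importlib

    # Load download-model.py as a module; the hyphen prevents "import download-model".
    downloader = importlib.import_module("download-model")

    # Clean up the user-supplied names, then follow the same pipeline the script's
    # __main__ block uses: resolve links, pick an output folder, download in threads.
    model, branch = downloader.sanitize_model_and_branch_names("facebook/opt-1.3b", "main")
    links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False)
    output_folder = downloader.get_output_folder(model, branch, is_lora)
    downloader.download_model_files(model, branch, links, sha256, output_folder, threads=4)

The function names and signatures above match the diff in PATCH 02; the example model name, thread count, and import mechanism are illustrative assumptions rather than part of the patch series.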