Merge branch 'main' into main

This commit is contained in:
Alexander Hristov Hristov 2023-03-13 19:50:08 +02:00 committed by GitHub
commit 63c5a139a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 158 additions and 81 deletions

1
.github/FUNDING.yml vendored Normal file
View File

@ -0,0 +1 @@
ko_fi: oobabooga

View File

@ -27,7 +27,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
* [FlexGen offload](https://github.com/oobabooga/text-generation-webui/wiki/FlexGen). * [FlexGen offload](https://github.com/oobabooga/text-generation-webui/wiki/FlexGen).
* [DeepSpeed ZeRO-3 offload](https://github.com/oobabooga/text-generation-webui/wiki/DeepSpeed). * [DeepSpeed ZeRO-3 offload](https://github.com/oobabooga/text-generation-webui/wiki/DeepSpeed).
* Get responses via API, [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-streaming.py) or [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming. * Get responses via API, [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-streaming.py) or [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming.
* [Supports the LLaMA model](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model). * [Supports the LLaMA model, including 4-bit mode](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model).
* [Supports the RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model). * [Supports the RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model).
* Supports softprompts. * Supports softprompts.
* [Supports extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions). * [Supports extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions).
@ -60,11 +60,13 @@ pip3 install torch torchvision torchaudio --extra-index-url https://download.pyt
conda install pytorch torchvision torchaudio git -c pytorch conda install pytorch torchvision torchaudio git -c pytorch
``` ```
See also: [Installation instructions for human beings](https://github.com/oobabooga/text-generation-webui/wiki/Installation-instructions-for-human-beings).
## Installation option 2: one-click installers ## Installation option 2: one-click installers
[oobabooga-windows.zip](https://github.com/oobabooga/text-generation-webui/releases/download/installers/oobabooga-windows.zip) [oobabooga-windows.zip](https://github.com/oobabooga/one-click-installers/archive/refs/heads/oobabooga-windows.zip)
[oobabooga-linux.zip](https://github.com/oobabooga/text-generation-webui/releases/download/installers/oobabooga-linux.zip) [oobabooga-linux.zip](https://github.com/oobabooga/one-click-installers/archive/refs/heads/oobabooga-linux.zip)
Just download the zip above, extract it, and double click on "install". The web UI and all its dependencies will be installed in the same folder. Just download the zip above, extract it, and double click on "install". The web UI and all its dependencies will be installed in the same folder.
@ -139,7 +141,7 @@ Optionally, you can use the following command-line flags:
| `--cpu` | Use the CPU to generate text.| | `--cpu` | Use the CPU to generate text.|
| `--load-in-8bit` | Load the model with 8-bit precision.| | `--load-in-8bit` | Load the model with 8-bit precision.|
| `--load-in-4bit` | Load the model with 4-bit precision. Currently only works with LLaMA.| | `--load-in-4bit` | Load the model with 4-bit precision. Currently only works with LLaMA.|
| `--gptq-bits` | Load a pre-quantized model with specified precision. 2, 3, 4 and 8bit are supported. Currently only works with LLaMA. | | `--gptq-bits GPTQ_BITS` | Load a pre-quantized model with specified precision. 2, 3, 4 and 8 (bit) are supported. Currently only works with LLaMA. |
| `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. | | `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
| `--auto-devices` | Automatically split the model across the available GPU(s) and CPU.| | `--auto-devices` | Automatically split the model across the available GPU(s) and CPU.|
| `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. | | `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. |
@ -155,12 +157,13 @@ Optionally, you can use the following command-line flags:
| `--local_rank LOCAL_RANK` | DeepSpeed: Optional argument for distributed setups. | | `--local_rank LOCAL_RANK` | DeepSpeed: Optional argument for distributed setups. |
| `--rwkv-strategy RWKV_STRATEGY` | RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8". | | `--rwkv-strategy RWKV_STRATEGY` | RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8". |
| `--rwkv-cuda-on` | RWKV: Compile the CUDA kernel for better performance. | | `--rwkv-cuda-on` | RWKV: Compile the CUDA kernel for better performance. |
| `--no-stream` | Don't stream the text output in real time. This improves the text generation performance.| | `--no-stream` | Don't stream the text output in real time. |
| `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example. If you create a file called `settings.json`, this file will be loaded by default without the need to use the `--settings` flag.| | `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example. If you create a file called `settings.json`, this file will be loaded by default without the need to use the `--settings` flag.|
| `--extensions EXTENSIONS [EXTENSIONS ...]` | The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. | | `--extensions EXTENSIONS [EXTENSIONS ...]` | The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. |
| `--listen` | Make the web UI reachable from your local network.| | `--listen` | Make the web UI reachable from your local network.|
| `--listen-port LISTEN_PORT` | The listening port that the server will use. | | `--listen-port LISTEN_PORT` | The listening port that the server will use. |
| `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. | | `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. |
| `--auto-launch` | Open the web UI in the default browser upon launch. |
| `--verbose` | Print the prompts to the terminal. | | `--verbose` | Print the prompts to the terminal. |
Out of memory errors? [Check this guide](https://github.com/oobabooga/text-generation-webui/wiki/Low-VRAM-guide). Out of memory errors? [Check this guide](https://github.com/oobabooga/text-generation-webui/wiki/Low-VRAM-guide).
@ -179,14 +182,10 @@ Check the [wiki](https://github.com/oobabooga/text-generation-webui/wiki/System-
Pull requests, suggestions, and issue reports are welcome. Pull requests, suggestions, and issue reports are welcome.
Before reporting a bug, make sure that you have created a conda environment and installed the dependencies exactly as in the *Installation* section above. Before reporting a bug, make sure that you have:
These issues are known: 1. Created a conda environment and installed the dependencies exactly as in the *Installation* section above.
2. [Searched](https://github.com/oobabooga/text-generation-webui/issues) to see if an issue already exists for the issue you encountered.
* 8-bit doesn't work properly on Windows or older GPUs.
* DeepSpeed doesn't work properly on Windows.
For these two, please try commenting on an existing issue instead of creating a new one.
## Credits ## Credits

View File

@ -1,8 +1,12 @@
import time
from pathlib import Path from pathlib import Path
import gradio as gr import gradio as gr
import torch import torch
import modules.chat as chat
import modules.shared as shared
torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_mode(False)
params = { params = {
@ -12,10 +16,28 @@ params = {
'model_id': 'v3_en', 'model_id': 'v3_en',
'sample_rate': 48000, 'sample_rate': 48000,
'device': 'cpu', 'device': 'cpu',
'show_text': False,
'autoplay': True,
'voice_pitch': 'medium',
'voice_speed': 'medium',
} }
current_params = params.copy() current_params = params.copy()
voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115'] voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115']
wav_idx = 0 voice_pitches = ['x-low', 'low', 'medium', 'high', 'x-high']
voice_speeds = ['x-slow', 'slow', 'medium', 'fast', 'x-fast']
# Used for making text xml compatible, needed for voice pitch and speed control
table = str.maketrans({
"<": "&lt;",
">": "&gt;",
"&": "&amp;",
"'": "&apos;",
'"': "&quot;",
})
def xmlesc(txt):
return txt.translate(table)
def load_model(): def load_model():
model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id']) model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
@ -33,12 +55,32 @@ def remove_surrounded_chars(string):
new_string += char new_string += char
return new_string return new_string
def remove_tts_from_history(name1, name2):
for i, entry in enumerate(shared.history['internal']):
shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]]
return chat.generate_chat_output(shared.history['visible'], name1, name2, shared.character)
def toggle_text_in_history(name1, name2):
for i, entry in enumerate(shared.history['visible']):
visible_reply = entry[1]
if visible_reply.startswith('<audio'):
if params['show_text']:
reply = shared.history['internal'][i][1]
shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"]
else:
shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"]
return chat.generate_chat_output(shared.history['visible'], name1, name2, shared.character)
def input_modifier(string): def input_modifier(string):
""" """
This function is applied to your text inputs before This function is applied to your text inputs before
they are fed into the model. they are fed into the model.
""" """
# Remove autoplay from the last reply
if (shared.args.chat or shared.args.cai_chat) and len(shared.history['internal']) > 0:
shared.history['visible'][-1] = [shared.history['visible'][-1][0], shared.history['visible'][-1][1].replace('controls autoplay>','controls>')]
return string return string
def output_modifier(string): def output_modifier(string):
@ -46,7 +88,7 @@ def output_modifier(string):
This function is applied to the model outputs. This function is applied to the model outputs.
""" """
global wav_idx, model, current_params global model, current_params
for i in params: for i in params:
if params[i] != current_params[i]: if params[i] != current_params[i]:
@ -57,6 +99,7 @@ def output_modifier(string):
if params['activate'] == False: if params['activate'] == False:
return string return string
original_string = string
string = remove_surrounded_chars(string) string = remove_surrounded_chars(string)
string = string.replace('"', '') string = string.replace('"', '')
string = string.replace('', '') string = string.replace('', '')
@ -64,13 +107,17 @@ def output_modifier(string):
string = string.strip() string = string.strip()
if string == '': if string == '':
string = 'empty reply, try regenerating' string = '*Empty reply, try regenerating*'
else:
output_file = Path(f'extensions/silero_tts/outputs/{shared.character}_{int(time.time())}.wav')
prosody = '<prosody rate="{}" pitch="{}">'.format(params['voice_speed'], params['voice_pitch'])
silero_input = f'<speak>{prosody}{xmlesc(string)}</prosody></speak>'
model.save_wav(ssml_text=silero_input, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
output_file = Path(f'extensions/silero_tts/outputs/{wav_idx:06d}.wav') autoplay = 'autoplay' if params['autoplay'] else ''
model.save_wav(text=string, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file)) string = f'<audio src="file/{output_file.as_posix()}" controls {autoplay}></audio>'
if params['show_text']:
string = f'<audio src="file/{output_file.as_posix()}" controls></audio>' string += f'\n\n{original_string}'
wav_idx += 1
return string return string
@ -85,9 +132,36 @@ def bot_prefix_modifier(string):
def ui(): def ui():
# Gradio elements # Gradio elements
activate = gr.Checkbox(value=params['activate'], label='Activate TTS') with gr.Accordion("Silero TTS"):
voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice') with gr.Row():
activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
autoplay = gr.Checkbox(value=params['autoplay'], label='Play TTS automatically')
show_text = gr.Checkbox(value=params['show_text'], label='Show message text under audio player')
voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice')
with gr.Row():
v_pitch = gr.Dropdown(value=params['voice_pitch'], choices=voice_pitches, label='Voice pitch')
v_speed = gr.Dropdown(value=params['voice_speed'], choices=voice_speeds, label='Voice speed')
with gr.Row():
convert = gr.Button('Permanently replace audios with the message texts')
convert_cancel = gr.Button('Cancel', visible=False)
convert_confirm = gr.Button('Confirm (cannot be undone)', variant="stop", visible=False)
# Convert history with confirmation
convert_arr = [convert_confirm, convert, convert_cancel]
convert.click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
convert_confirm.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
convert_confirm.click(remove_tts_from_history, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'])
convert_confirm.click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
convert_cancel.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
# Toggle message text in history
show_text.change(lambda x: params.update({"show_text": x}), show_text, None)
show_text.change(toggle_text_in_history, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'])
show_text.change(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
# Event functions to update the parameters in the backend # Event functions to update the parameters in the backend
activate.change(lambda x: params.update({"activate": x}), activate, None) activate.change(lambda x: params.update({"activate": x}), activate, None)
autoplay.change(lambda x: params.update({"autoplay": x}), autoplay, None)
voice.change(lambda x: params.update({"speaker": x}), voice, None) voice.change(lambda x: params.update({"speaker": x}), voice, None)
v_pitch.change(lambda x: params.update({"voice_pitch": x}), v_pitch, None)
v_speed.change(lambda x: params.update({"voice_speed": x}), v_speed, None)

View File

@ -25,10 +25,10 @@ class RWKVModel:
tokenizer_path = Path(f"{path.parent}/20B_tokenizer.json") tokenizer_path = Path(f"{path.parent}/20B_tokenizer.json")
if shared.args.rwkv_strategy is None: if shared.args.rwkv_strategy is None:
model = RWKV(model=os.path.abspath(path), strategy=f'{device} {dtype}') model = RWKV(model=str(path), strategy=f'{device} {dtype}')
else: else:
model = RWKV(model=os.path.abspath(path), strategy=shared.args.rwkv_strategy) model = RWKV(model=str(path), strategy=shared.args.rwkv_strategy)
pipeline = PIPELINE(model, os.path.abspath(tokenizer_path)) pipeline = PIPELINE(model, str(tokenizer_path))
result = self() result = self()
result.pipeline = pipeline result.pipeline = pipeline
@ -61,7 +61,7 @@ class RWKVTokenizer:
@classmethod @classmethod
def from_pretrained(self, path): def from_pretrained(self, path):
tokenizer_path = path / "20B_tokenizer.json" tokenizer_path = path / "20B_tokenizer.json"
tokenizer = Tokenizer.from_file(os.path.abspath(tokenizer_path)) tokenizer = Tokenizer.from_file(str(tokenizer_path))
result = self() result = self()
result.tokenizer = tokenizer result.tokenizer = tokenizer

View File

@ -22,6 +22,12 @@ def clean_chat_message(text):
text = text.strip() text = text.strip()
return text return text
def generate_chat_output(history, name1, name2, character):
if shared.args.cai_chat:
return generate_chat_html(history, name1, name2, character)
else:
return history
def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=False): def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=False):
user_input = clean_chat_message(user_input) user_input = clean_chat_message(user_input)
rows = [f"{context.strip()}\n"] rows = [f"{context.strip()}\n"]
@ -53,7 +59,6 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
def extract_message_from_reply(question, reply, name1, name2, check, impersonate=False): def extract_message_from_reply(question, reply, name1, name2, check, impersonate=False):
next_character_found = False next_character_found = False
substring_found = False
asker = name1 if not impersonate else name2 asker = name1 if not impersonate else name2
replier = name2 if not impersonate else name1 replier = name2 if not impersonate else name1
@ -79,15 +84,15 @@ def extract_message_from_reply(question, reply, name1, name2, check, impersonate
next_character_found = True next_character_found = True
reply = clean_chat_message(reply) reply = clean_chat_message(reply)
# Detect if something like "\nYo" is generated just before # If something like "\nYo" is generated just before "\nYou:"
# "\nYou:" is completed # is completed, trim it
tmp = f"\n{asker}:" next_turn = f"\n{asker}:"
for j in range(1, len(tmp)): for j in range(len(next_turn)-1, 0, -1):
if reply[-j:] == tmp[:j]: if reply[-j:] == next_turn[:j]:
reply = reply[:-j] reply = reply[:-j]
substring_found = True break
return reply, next_character_found, substring_found return reply, next_character_found
def stop_everything_event(): def stop_everything_event():
shared.stop_everything = True shared.stop_everything = True
@ -122,7 +127,6 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size) prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
if not regenerate: if not regenerate:
# Display user input and "*is typing...*" imediately
yield shared.history['visible']+[[visible_text, '*Is typing...*']] yield shared.history['visible']+[[visible_text, '*Is typing...*']]
# Generate # Generate
@ -131,7 +135,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"): for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):
# Extracting the reply # Extracting the reply
reply, next_character_found, substring_found = extract_message_from_reply(prompt, reply, name1, name2, check) reply, next_character_found = extract_message_from_reply(prompt, reply, name1, name2, check)
visible_reply = re.sub("(<USER>|<user>|{{user}})", name1_original, reply) visible_reply = re.sub("(<USER>|<user>|{{user}})", name1_original, reply)
visible_reply = apply_extensions(visible_reply, "output") visible_reply = apply_extensions(visible_reply, "output")
if shared.args.chat: if shared.args.chat:
@ -148,7 +152,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
shared.history['internal'][-1] = [text, reply] shared.history['internal'][-1] = [text, reply]
shared.history['visible'][-1] = [visible_text, visible_reply] shared.history['visible'][-1] = [visible_text, visible_reply]
if not substring_found and not shared.args.no_stream: if not shared.args.no_stream:
yield shared.history['visible'] yield shared.history['visible']
if next_character_found: if next_character_found:
break break
@ -163,15 +167,12 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True) prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True)
# Display "*is typing...*" imediately
yield '*Is typing...*'
reply = '' reply = ''
yield '*Is typing...*'
for i in range(chat_generation_attempts): for i in range(chat_generation_attempts):
for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"): for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
reply, next_character_found, substring_found = extract_message_from_reply(prompt, reply, name1, name2, check, impersonate=True) reply, next_character_found = extract_message_from_reply(prompt, reply, name1, name2, check, impersonate=True)
if not substring_found: yield reply
yield reply
if next_character_found: if next_character_found:
break break
yield reply yield reply
@ -182,21 +183,18 @@ def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1): def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0: if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0:
if shared.args.cai_chat: yield generate_chat_output(shared.history['visible'], name1, name2, shared.character)
yield generate_chat_html(shared.history['visible'], name1, name2, shared.character)
else:
yield shared.history['visible']
else: else:
last_visible = shared.history['visible'].pop() last_visible = shared.history['visible'].pop()
last_internal = shared.history['internal'].pop() last_internal = shared.history['internal'].pop()
yield generate_chat_output(shared.history['visible']+[[last_visible[0], '*Is typing...*']], name1, name2, shared.character)
for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True): for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True):
if shared.args.cai_chat: if shared.args.cai_chat:
shared.history['visible'][-1] = [last_visible[0], _history[-1][1]] shared.history['visible'][-1] = [last_visible[0], _history[-1][1]]
yield generate_chat_html(shared.history['visible'], name1, name2, shared.character)
else: else:
shared.history['visible'][-1] = (last_visible[0], _history[-1][1]) shared.history['visible'][-1] = (last_visible[0], _history[-1][1])
yield shared.history['visible'] yield generate_chat_output(shared.history['visible'], name1, name2, shared.character)
def remove_last_message(name1, name2): def remove_last_message(name1, name2):
if len(shared.history['visible']) > 0 and not shared.history['internal'][-1][0] == '<|BEGIN-VISIBLE-CHAT|>': if len(shared.history['visible']) > 0 and not shared.history['internal'][-1][0] == '<|BEGIN-VISIBLE-CHAT|>':
@ -204,6 +202,7 @@ def remove_last_message(name1, name2):
shared.history['internal'].pop() shared.history['internal'].pop()
else: else:
last = ['', ''] last = ['', '']
if shared.args.cai_chat: if shared.args.cai_chat:
return generate_chat_html(shared.history['visible'], name1, name2, shared.character), last[0] return generate_chat_html(shared.history['visible'], name1, name2, shared.character), last[0]
else: else:
@ -223,10 +222,7 @@ def replace_last_reply(text, name1, name2):
shared.history['visible'][-1] = (shared.history['visible'][-1][0], text) shared.history['visible'][-1] = (shared.history['visible'][-1][0], text)
shared.history['internal'][-1][1] = apply_extensions(text, "input") shared.history['internal'][-1][1] = apply_extensions(text, "input")
if shared.args.cai_chat: return generate_chat_output(shared.history['visible'], name1, name2, shared.character)
return generate_chat_html(shared.history['visible'], name1, name2, shared.character)
else:
return shared.history['visible']
def clear_html(): def clear_html():
return generate_chat_html([], "", "", shared.character) return generate_chat_html([], "", "", shared.character)
@ -246,10 +242,8 @@ def clear_chat_log(name1, name2):
else: else:
shared.history['internal'] = [] shared.history['internal'] = []
shared.history['visible'] = [] shared.history['visible'] = []
if shared.args.cai_chat:
return generate_chat_html(shared.history['visible'], name1, name2, shared.character) return generate_chat_output(shared.history['visible'], name1, name2, shared.character)
else:
return shared.history['visible']
def redraw_html(name1, name2): def redraw_html(name1, name2):
return generate_chat_html(shared.history['visible'], name1, name2, shared.character) return generate_chat_html(shared.history['visible'], name1, name2, shared.character)

View File

@ -1,4 +1,3 @@
import os
import sys import sys
from pathlib import Path from pathlib import Path
@ -7,7 +6,7 @@ import torch
import modules.shared as shared import modules.shared as shared
sys.path.insert(0, os.path.abspath(Path("repositories/GPTQ-for-LLaMa"))) sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
from llama import load_quant from llama import load_quant
@ -41,9 +40,9 @@ def load_quantized_LLaMA(model_name):
print(f"Could not find {pt_model}, exiting...") print(f"Could not find {pt_model}, exiting...")
exit() exit()
model = load_quant(path_to_model, os.path.abspath(pt_path), bits) model = load_quant(str(path_to_model), str(pt_path), bits)
# Multi-GPU setup # Multiple GPUs or GPU+CPU
if shared.args.gpu_memory: if shared.args.gpu_memory:
max_memory = {} max_memory = {}
for i in range(len(shared.args.gpu_memory)): for i in range(len(shared.args.gpu_memory)):

View File

@ -85,12 +85,12 @@ parser.add_argument('--nvme-offload-dir', type=str, help='DeepSpeed: Directory t
parser.add_argument('--local_rank', type=int, default=0, help='DeepSpeed: Optional argument for distributed setups.') parser.add_argument('--local_rank', type=int, default=0, help='DeepSpeed: Optional argument for distributed setups.')
parser.add_argument('--rwkv-strategy', type=str, default=None, help='RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8".') parser.add_argument('--rwkv-strategy', type=str, default=None, help='RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8".')
parser.add_argument('--rwkv-cuda-on', action='store_true', help='RWKV: Compile the CUDA kernel for better performance.') parser.add_argument('--rwkv-cuda-on', action='store_true', help='RWKV: Compile the CUDA kernel for better performance.')
parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This improves the text generation performance.') parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time.')
parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example. If you create a file called settings.json, this file will be loaded by default without the need to use the --settings flag.') parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example. If you create a file called settings.json, this file will be loaded by default without the need to use the --settings flag.')
parser.add_argument('--extensions', type=str, nargs="+", help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.') parser.add_argument('--extensions', type=str, nargs="+", help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.')
parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.') parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.')
parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.') parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.')
parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.') parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')
parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch.')
parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.') parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch')
args = parser.parse_args() args = parser.parse_args()

View File

@ -37,9 +37,13 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
return input_ids.cuda() return input_ids.cuda()
def decode(output_ids): def decode(output_ids):
reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True) # Open Assistant relies on special tokens like <|endoftext|>
reply = reply.replace(r'<|endoftext|>', '') if re.match('oasst-*', shared.model_name.lower()):
return reply return shared.tokenizer.decode(output_ids, skip_special_tokens=False)
else:
reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
reply = reply.replace(r'<|endoftext|>', '')
return reply
def generate_softprompt_input_tensors(input_ids): def generate_softprompt_input_tensors(input_ids):
inputs_embeds = shared.model.transformer.wte(input_ids) inputs_embeds = shared.model.transformer.wte(input_ids)
@ -119,7 +123,9 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
original_input_ids = input_ids original_input_ids = input_ids
output = input_ids[0] output = input_ids[0]
cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()" cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()"
n = shared.tokenizer.eos_token_id if eos_token is None else int(encode(eos_token)[0][-1]) eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
if eos_token is not None:
eos_token_ids.append(int(encode(eos_token)[0][-1]))
stopping_criteria_list = transformers.StoppingCriteriaList() stopping_criteria_list = transformers.StoppingCriteriaList()
if stopping_string is not None: if stopping_string is not None:
# Copied from https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py # Copied from https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
@ -129,7 +135,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
if not shared.args.flexgen: if not shared.args.flexgen:
generate_params = [ generate_params = [
f"max_new_tokens=max_new_tokens", f"max_new_tokens=max_new_tokens",
f"eos_token_id={n}", f"eos_token_id={eos_token_ids}",
f"stopping_criteria=stopping_criteria_list", f"stopping_criteria=stopping_criteria_list",
f"do_sample={do_sample}", f"do_sample={do_sample}",
f"temperature={temperature}", f"temperature={temperature}",
@ -149,7 +155,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
f"max_new_tokens={max_new_tokens if shared.args.no_stream else 8}", f"max_new_tokens={max_new_tokens if shared.args.no_stream else 8}",
f"do_sample={do_sample}", f"do_sample={do_sample}",
f"temperature={temperature}", f"temperature={temperature}",
f"stop={n}", f"stop={eos_token_ids[-1]}",
] ]
if shared.args.deepspeed: if shared.args.deepspeed:
generate_params.append("synced_gpus=True") generate_params.append("synced_gpus=True")
@ -196,10 +202,12 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
if not (shared.args.chat or shared.args.cai_chat): if not (shared.args.chat or shared.args.cai_chat):
reply = original_question + apply_extensions(reply[len(question):], "output") reply = original_question + apply_extensions(reply[len(question):], "output")
if output[-1] in eos_token_ids:
break
yield formatted_outputs(reply, shared.model_name) yield formatted_outputs(reply, shared.model_name)
if output[-1] == n: yield formatted_outputs(reply, shared.model_name)
break
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria' # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
else: else:
@ -213,15 +221,17 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
if not (shared.args.chat or shared.args.cai_chat): if not (shared.args.chat or shared.args.cai_chat):
reply = original_question + apply_extensions(reply[len(question):], "output") reply = original_question + apply_extensions(reply[len(question):], "output")
yield formatted_outputs(reply, shared.model_name)
if np.count_nonzero(input_ids[0] == n) < np.count_nonzero(output == n): if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
break break
yield formatted_outputs(reply, shared.model_name)
input_ids = np.reshape(output, (1, output.shape[0])) input_ids = np.reshape(output, (1, output.shape[0]))
if shared.soft_prompt: if shared.soft_prompt:
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids) inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
yield formatted_outputs(reply, shared.model_name)
finally: finally:
t1 = time.time() t1 = time.time()
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(original_input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(original_input_ids[0])} tokens)") print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(original_input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(original_input_ids[0])} tokens)")

View File

@ -1,12 +1,12 @@
accelerate==0.16.0 accelerate==0.17.0
bitsandbytes==0.37.0 bitsandbytes==0.37.0
flexgen==0.1.7 flexgen==0.1.7
gradio==3.18.0 gradio==3.18.0
numpy numpy
requests requests
rwkv==0.1.0 rwkv==0.3.1
safetensors==0.2.8 safetensors==0.3.0
sentencepiece sentencepiece
tqdm tqdm
markdown markdown
git+https://github.com/zphang/transformers@llama_push git+https://github.com/zphang/transformers.git@68d640f7c368bcaaaecfc678f11908ebbd3d6176

View File

@ -269,10 +269,10 @@ if shared.args.chat or shared.args.cai_chat:
function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper' function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=False, api_name='textgen')) gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream, api_name='textgen'))
gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=False)) gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=False)) gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=False)) gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events) shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events)
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream) shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)