mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Merge branch 'oobabooga:main' into stt-extension
This commit is contained in:
commit
3b4145966d
17
README.md
17
README.md
@ -1,6 +1,6 @@
|
||||
# Text generation web UI
|
||||
|
||||
A gradio web UI for running Large Language Models like GPT-J 6B, OPT, GALACTICA, GPT-Neo, and Pygmalion.
|
||||
A gradio web UI for running Large Language Models like GPT-J 6B, OPT, GALACTICA, LLaMA, and Pygmalion.
|
||||
|
||||
Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) of text generation.
|
||||
|
||||
@ -27,6 +27,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
|
||||
* [FlexGen offload](https://github.com/oobabooga/text-generation-webui/wiki/FlexGen).
|
||||
* [DeepSpeed ZeRO-3 offload](https://github.com/oobabooga/text-generation-webui/wiki/DeepSpeed).
|
||||
* Get responses via API, [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-streaming.py) or [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming.
|
||||
* [Supports the LLaMA model, including 4-bit mode](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model).
|
||||
* [Supports the RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model).
|
||||
* Supports softprompts.
|
||||
* [Supports extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions).
|
||||
@ -53,7 +54,7 @@ The third line assumes that you have an NVIDIA GPU.
|
||||
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.2
|
||||
```
|
||||
|
||||
* If you are running in CPU mode, replace the third command with this one:
|
||||
* If you are running it in CPU mode, replace the third command with this one:
|
||||
|
||||
```
|
||||
conda install pytorch torchvision torchaudio git -c pytorch
|
||||
@ -137,6 +138,8 @@ Optionally, you can use the following command-line flags:
|
||||
| `--cai-chat` | Launch the web UI in chat mode with a style similar to Character.AI's. If the file `img_bot.png` or `img_bot.jpg` exists in the same folder as server.py, this image will be used as the bot's profile picture. Similarly, `img_me.png` or `img_me.jpg` will be used as your profile picture. |
|
||||
| `--cpu` | Use the CPU to generate text.|
|
||||
| `--load-in-8bit` | Load the model with 8-bit precision.|
|
||||
| `--load-in-4bit` | Load the model with 4-bit precision. Currently only works with LLaMA.|
|
||||
| `--gptq-bits GPTQ_BITS` | Load a pre-quantized model with specified precision. 2, 3, 4 and 8 (bit) are supported. Currently only works with LLaMA. |
|
||||
| `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
|
||||
| `--auto-devices` | Automatically split the model across the available GPU(s) and CPU.|
|
||||
| `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. |
|
||||
@ -176,14 +179,10 @@ Check the [wiki](https://github.com/oobabooga/text-generation-webui/wiki/System-
|
||||
|
||||
Pull requests, suggestions, and issue reports are welcome.
|
||||
|
||||
Before reporting a bug, make sure that you have created a conda environment and installed the dependencies exactly as in the *Installation* section above.
|
||||
Before reporting a bug, make sure that you have:
|
||||
|
||||
These issues are known:
|
||||
|
||||
* 8-bit doesn't work properly on Windows or older GPUs.
|
||||
* DeepSpeed doesn't work properly on Windows.
|
||||
|
||||
For these two, please try commenting on an existing issue instead of creating a new one.
|
||||
1. Created a conda environment and installed the dependencies exactly as in the *Installation* section above.
|
||||
2. [Searched](https://github.com/oobabooga/text-generation-webui/issues) to see if an issue already exists for the issue you encountered.
|
||||
|
||||
## Credits
|
||||
|
||||
|
@ -5,7 +5,9 @@ Example:
|
||||
python download-model.py facebook/opt-1.3b
|
||||
|
||||
'''
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import json
|
||||
import multiprocessing
|
||||
import re
|
||||
@ -93,23 +95,28 @@ facebook/opt-1.3b
|
||||
def get_download_links_from_huggingface(model, branch):
|
||||
base = "https://huggingface.co"
|
||||
page = f"/api/models/{model}/tree/{branch}?cursor="
|
||||
cursor = b""
|
||||
|
||||
links = []
|
||||
classifications = []
|
||||
has_pytorch = False
|
||||
has_safetensors = False
|
||||
while page is not None:
|
||||
content = requests.get(f"{base}{page}").content
|
||||
while True:
|
||||
content = requests.get(f"{base}{page}{cursor.decode()}").content
|
||||
|
||||
dict = json.loads(content)
|
||||
if len(dict) == 0:
|
||||
break
|
||||
|
||||
for i in range(len(dict)):
|
||||
fname = dict[i]['path']
|
||||
|
||||
is_pytorch = re.match("pytorch_model.*\.bin", fname)
|
||||
is_safetensors = re.match("model.*\.safetensors", fname)
|
||||
is_text = re.match(".*\.(txt|json)", fname)
|
||||
is_tokenizer = re.match("tokenizer.*\.model", fname)
|
||||
is_text = re.match(".*\.(txt|json)", fname) or is_tokenizer
|
||||
|
||||
if is_text or is_safetensors or is_pytorch:
|
||||
if any((is_pytorch, is_safetensors, is_text, is_tokenizer)):
|
||||
if is_text:
|
||||
links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
|
||||
classifications.append('text')
|
||||
@ -123,8 +130,9 @@ def get_download_links_from_huggingface(model, branch):
|
||||
has_pytorch = True
|
||||
classifications.append('pytorch')
|
||||
|
||||
#page = dict['nextUrl']
|
||||
page = None
|
||||
cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
|
||||
cursor = base64.b64encode(cursor)
|
||||
cursor = cursor.replace(b'=', b'%3D')
|
||||
|
||||
# If both pytorch and safetensors are available, download safetensors only
|
||||
if has_pytorch and has_safetensors:
|
||||
|
18
extensions/llama_prompts/script.py
Normal file
18
extensions/llama_prompts/script.py
Normal file
@ -0,0 +1,18 @@
|
||||
import gradio as gr
|
||||
import modules.shared as shared
|
||||
import pandas as pd
|
||||
|
||||
df = pd.read_csv("https://raw.githubusercontent.com/devbrones/llama-prompts/main/prompts/prompts.csv")
|
||||
|
||||
def get_prompt_by_name(name):
|
||||
if name == 'None':
|
||||
return ''
|
||||
else:
|
||||
return df[df['Prompt name'] == name].iloc[0]['Prompt'].replace('\\n', '\n')
|
||||
|
||||
def ui():
|
||||
if not shared.args.chat or shared.args.cai_chat:
|
||||
choices = ['None'] + list(df['Prompt name'])
|
||||
|
||||
prompts_menu = gr.Dropdown(value=choices[0], choices=choices, label='Prompt')
|
||||
prompts_menu.change(get_prompt_by_name, prompts_menu, shared.gradio['textbox'])
|
@ -1,21 +1,45 @@
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
|
||||
import modules.chat as chat
|
||||
import modules.shared as shared
|
||||
|
||||
torch._C._jit_set_profiling_mode(False)
|
||||
|
||||
params = {
|
||||
'activate': True,
|
||||
'speaker': 'en_56',
|
||||
'speaker': 'en_5',
|
||||
'language': 'en',
|
||||
'model_id': 'v3_en',
|
||||
'sample_rate': 48000,
|
||||
'device': 'cpu',
|
||||
'show_text': False,
|
||||
'autoplay': True,
|
||||
'voice_pitch': 'medium',
|
||||
'voice_speed': 'medium',
|
||||
}
|
||||
|
||||
current_params = params.copy()
|
||||
voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115']
|
||||
wav_idx = 0
|
||||
voice_pitches = ['x-low', 'low', 'medium', 'high', 'x-high']
|
||||
voice_speeds = ['x-slow', 'slow', 'medium', 'fast', 'x-fast']
|
||||
last_msg_id = 0
|
||||
|
||||
# Used for making text xml compatible, needed for voice pitch and speed control
|
||||
table = str.maketrans({
|
||||
"<": "<",
|
||||
">": ">",
|
||||
"&": "&",
|
||||
"'": "'",
|
||||
'"': """,
|
||||
})
|
||||
|
||||
def xmlesc(txt):
|
||||
return txt.translate(table)
|
||||
|
||||
def load_model():
|
||||
model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
|
||||
@ -33,12 +57,59 @@ def remove_surrounded_chars(string):
|
||||
new_string += char
|
||||
return new_string
|
||||
|
||||
def remove_tts_from_history():
|
||||
suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
|
||||
for i, entry in enumerate(shared.history['internal']):
|
||||
reply = entry[1]
|
||||
reply = re.sub("(<USER>|<user>|{{user}})", shared.settings[f'name1{suffix}'], reply)
|
||||
if shared.args.chat:
|
||||
reply = reply.replace('\n', '<br>')
|
||||
shared.history['visible'][i][1] = reply
|
||||
|
||||
if shared.args.cai_chat:
|
||||
return chat.generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name1{suffix}'], shared.character)
|
||||
else:
|
||||
return shared.history['visible']
|
||||
|
||||
def toggle_text_in_history():
|
||||
suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
|
||||
audio_str='\n\n' # The '\n\n' used after </audio>
|
||||
if shared.args.chat:
|
||||
audio_str='<br><br>'
|
||||
|
||||
if params['show_text']==True:
|
||||
#for i, entry in enumerate(shared.history['internal']):
|
||||
for i, entry in enumerate(shared.history['visible']):
|
||||
vis_reply = entry[1]
|
||||
if vis_reply.startswith('<audio'):
|
||||
reply = shared.history['internal'][i][1]
|
||||
reply = re.sub("(<USER>|<user>|{{user}})", shared.settings[f'name1{suffix}'], reply)
|
||||
if shared.args.chat:
|
||||
reply = reply.replace('\n', '<br>')
|
||||
shared.history['visible'][i][1] = vis_reply.split(audio_str,1)[0]+audio_str+reply
|
||||
else:
|
||||
for i, entry in enumerate(shared.history['visible']):
|
||||
vis_reply = entry[1]
|
||||
if vis_reply.startswith('<audio'):
|
||||
shared.history['visible'][i][1] = vis_reply.split(audio_str,1)[0]+audio_str
|
||||
|
||||
if shared.args.cai_chat:
|
||||
return chat.generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name1{suffix}'], shared.character)
|
||||
else:
|
||||
return shared.history['visible']
|
||||
|
||||
def input_modifier(string):
|
||||
"""
|
||||
This function is applied to your text inputs before
|
||||
they are fed into the model.
|
||||
"""
|
||||
|
||||
# Remove autoplay from previous chat history
|
||||
if (shared.args.chat or shared.args.cai_chat)and len(shared.history['internal'])>0:
|
||||
[visible_text, visible_reply] = shared.history['visible'][-1]
|
||||
vis_rep_clean = visible_reply.replace('controls autoplay>','controls>')
|
||||
shared.history['visible'][-1] = [visible_text, vis_rep_clean]
|
||||
|
||||
return string
|
||||
|
||||
def output_modifier(string):
|
||||
@ -46,7 +117,7 @@ def output_modifier(string):
|
||||
This function is applied to the model outputs.
|
||||
"""
|
||||
|
||||
global wav_idx, model, current_params
|
||||
global model, current_params
|
||||
|
||||
for i in params:
|
||||
if params[i] != current_params[i]:
|
||||
@ -57,20 +128,34 @@ def output_modifier(string):
|
||||
if params['activate'] == False:
|
||||
return string
|
||||
|
||||
orig_string = string
|
||||
string = remove_surrounded_chars(string)
|
||||
string = string.replace('"', '')
|
||||
string = string.replace('“', '')
|
||||
string = string.replace('\n', ' ')
|
||||
string = string.strip()
|
||||
|
||||
silent_string = False # Used to prevent unnecessary audio file generation
|
||||
if string == '':
|
||||
string = 'empty reply, try regenerating'
|
||||
silent_string = True
|
||||
|
||||
output_file = Path(f'extensions/silero_tts/outputs/{wav_idx:06d}.wav')
|
||||
model.save_wav(text=string, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
|
||||
pitch = params['voice_pitch']
|
||||
speed = params['voice_speed']
|
||||
prosody=f'<prosody rate="{speed}" pitch="{pitch}">'
|
||||
string = '<speak>'+prosody+xmlesc(string)+'</prosody></speak>'
|
||||
|
||||
string = f'<audio src="file/{output_file.as_posix()}" controls></audio>'
|
||||
wav_idx += 1
|
||||
if not shared.still_streaming and not silent_string:
|
||||
output_file = Path(f'extensions/silero_tts/outputs/{shared.character}_{int(time.time())}.wav')
|
||||
model.save_wav(ssml_text=string, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
|
||||
autoplay_str = ' autoplay' if params['autoplay'] else ''
|
||||
string = f'<audio src="file/{output_file.as_posix()}" controls{autoplay_str}></audio>\n\n'
|
||||
else:
|
||||
# Placeholder so text doesn't shift around so much
|
||||
string = '<audio controls></audio>\n\n'
|
||||
|
||||
if params['show_text']:
|
||||
string += orig_string
|
||||
|
||||
return string
|
||||
|
||||
@ -85,9 +170,36 @@ def bot_prefix_modifier(string):
|
||||
|
||||
def ui():
|
||||
# Gradio elements
|
||||
activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
|
||||
voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice')
|
||||
with gr.Accordion("Silero TTS"):
|
||||
with gr.Row():
|
||||
activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
|
||||
autoplay = gr.Checkbox(value=params['autoplay'], label='Play TTS automatically')
|
||||
show_text = gr.Checkbox(value=params['show_text'], label='Show message text under audio player')
|
||||
voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice')
|
||||
with gr.Row():
|
||||
v_pitch = gr.Dropdown(value=params['voice_pitch'], choices=voice_pitches, label='Voice pitch')
|
||||
v_speed = gr.Dropdown(value=params['voice_speed'], choices=voice_speeds, label='Voice speed')
|
||||
with gr.Row():
|
||||
convert = gr.Button('Permanently replace chat history audio with message text')
|
||||
convert_confirm = gr.Button('Confirm (cannot be undone)', variant="stop", visible=False)
|
||||
convert_cancel = gr.Button('Cancel', visible=False)
|
||||
|
||||
# Convert history with confirmation
|
||||
convert_arr = [convert_confirm, convert, convert_cancel]
|
||||
convert.click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
|
||||
convert_confirm.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
|
||||
convert_confirm.click(remove_tts_from_history, [], shared.gradio['display'])
|
||||
convert_confirm.click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||
convert_cancel.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
|
||||
|
||||
# Toggle message text in history
|
||||
show_text.change(lambda x: params.update({"show_text": x}), show_text, None)
|
||||
show_text.change(toggle_text_in_history, [], shared.gradio['display'])
|
||||
show_text.change(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||
|
||||
# Event functions to update the parameters in the backend
|
||||
activate.change(lambda x: params.update({"activate": x}), activate, None)
|
||||
autoplay.change(lambda x: params.update({"autoplay": x}), autoplay, None)
|
||||
voice.change(lambda x: params.update({"speaker": x}), voice, None)
|
||||
v_pitch.change(lambda x: params.update({"voice_pitch": x}), v_pitch, None)
|
||||
v_speed.change(lambda x: params.update({"voice_speed": x}), v_speed, None)
|
||||
|
@ -1,12 +1,11 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from threading import Thread
|
||||
|
||||
import numpy as np
|
||||
from tokenizers import Tokenizer
|
||||
|
||||
import modules.shared as shared
|
||||
from modules.callbacks import Iteratorize
|
||||
|
||||
np.set_printoptions(precision=4, suppress=True, linewidth=200)
|
||||
|
||||
@ -49,11 +48,11 @@ class RWKVModel:
|
||||
return context+self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)
|
||||
|
||||
def generate_with_streaming(self, **kwargs):
|
||||
iterable = Iteratorize(self.generate, kwargs, callback=None)
|
||||
reply = kwargs['context']
|
||||
for token in iterable:
|
||||
reply += token
|
||||
yield reply
|
||||
with Iteratorize(self.generate, kwargs, callback=None) as generator:
|
||||
reply = kwargs['context']
|
||||
for token in generator:
|
||||
reply += token
|
||||
yield reply
|
||||
|
||||
class RWKVTokenizer:
|
||||
def __init__(self):
|
||||
@ -73,38 +72,3 @@ class RWKVTokenizer:
|
||||
|
||||
def decode(self, ids):
|
||||
return self.tokenizer.decode(ids)
|
||||
|
||||
class Iteratorize:
|
||||
|
||||
"""
|
||||
Transforms a function that takes a callback
|
||||
into a lazy iterator (generator).
|
||||
"""
|
||||
|
||||
def __init__(self, func, kwargs={}, callback=None):
|
||||
self.mfunc=func
|
||||
self.c_callback=callback
|
||||
self.q = Queue(maxsize=1)
|
||||
self.sentinel = object()
|
||||
self.kwargs = kwargs
|
||||
|
||||
def _callback(val):
|
||||
self.q.put(val)
|
||||
|
||||
def gentask():
|
||||
ret = self.mfunc(callback=_callback, **self.kwargs)
|
||||
self.q.put(self.sentinel)
|
||||
if self.c_callback:
|
||||
self.c_callback(ret)
|
||||
|
||||
Thread(target=gentask).start()
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
obj = self.q.get(True,None)
|
||||
if obj is self.sentinel:
|
||||
raise StopIteration
|
||||
else:
|
||||
return obj
|
||||
|
98
modules/callbacks.py
Normal file
98
modules/callbacks.py
Normal file
@ -0,0 +1,98 @@
|
||||
import gc
|
||||
from queue import Queue
|
||||
from threading import Thread
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
import modules.shared as shared
|
||||
|
||||
# Copied from https://github.com/PygmalionAI/gradio-ui/
|
||||
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
|
||||
|
||||
def __init__(self, sentinel_token_ids: torch.LongTensor,
|
||||
starting_idx: int):
|
||||
transformers.StoppingCriteria.__init__(self)
|
||||
self.sentinel_token_ids = sentinel_token_ids
|
||||
self.starting_idx = starting_idx
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor,
|
||||
_scores: torch.FloatTensor) -> bool:
|
||||
for sample in input_ids:
|
||||
trimmed_sample = sample[self.starting_idx:]
|
||||
# Can't unfold, output is still too tiny. Skip.
|
||||
if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]:
|
||||
continue
|
||||
|
||||
for window in trimmed_sample.unfold(
|
||||
0, self.sentinel_token_ids.shape[-1], 1):
|
||||
if torch.all(torch.eq(self.sentinel_token_ids, window)):
|
||||
return True
|
||||
return False
|
||||
|
||||
class Stream(transformers.StoppingCriteria):
|
||||
def __init__(self, callback_func=None):
|
||||
self.callback_func = callback_func
|
||||
|
||||
def __call__(self, input_ids, scores) -> bool:
|
||||
if self.callback_func is not None:
|
||||
self.callback_func(input_ids[0])
|
||||
return False
|
||||
|
||||
class Iteratorize:
|
||||
|
||||
"""
|
||||
Transforms a function that takes a callback
|
||||
into a lazy iterator (generator).
|
||||
"""
|
||||
|
||||
def __init__(self, func, kwargs={}, callback=None):
|
||||
self.mfunc=func
|
||||
self.c_callback=callback
|
||||
self.q = Queue()
|
||||
self.sentinel = object()
|
||||
self.kwargs = kwargs
|
||||
self.stop_now = False
|
||||
|
||||
def _callback(val):
|
||||
if self.stop_now:
|
||||
raise ValueError
|
||||
self.q.put(val)
|
||||
|
||||
def gentask():
|
||||
try:
|
||||
ret = self.mfunc(callback=_callback, **self.kwargs)
|
||||
except ValueError:
|
||||
pass
|
||||
clear_torch_cache()
|
||||
self.q.put(self.sentinel)
|
||||
if self.c_callback:
|
||||
self.c_callback(ret)
|
||||
|
||||
self.thread = Thread(target=gentask)
|
||||
self.thread.start()
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
obj = self.q.get(True,None)
|
||||
if obj is self.sentinel:
|
||||
raise StopIteration
|
||||
else:
|
||||
return obj
|
||||
|
||||
def __del__(self):
|
||||
clear_torch_cache()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.stop_now = True
|
||||
clear_torch_cache()
|
||||
|
||||
def clear_torch_cache():
|
||||
gc.collect()
|
||||
if not shared.args.cpu:
|
||||
torch.cuda.empty_cache()
|
@ -84,6 +84,7 @@ def extract_message_from_reply(question, reply, name1, name2, check, impersonate
|
||||
tmp = f"\n{asker}:"
|
||||
for j in range(1, len(tmp)):
|
||||
if reply[-j:] == tmp[:j]:
|
||||
reply = reply[:-j]
|
||||
substring_found = True
|
||||
|
||||
return reply, next_character_found, substring_found
|
||||
@ -91,7 +92,7 @@ def extract_message_from_reply(question, reply, name1, name2, check, impersonate
|
||||
def stop_everything_event():
|
||||
shared.stop_everything = True
|
||||
|
||||
def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
|
||||
def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
|
||||
shared.stop_everything = False
|
||||
just_started = True
|
||||
eos_token = '\n' if check else None
|
||||
@ -120,6 +121,10 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
|
||||
else:
|
||||
prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
|
||||
|
||||
if not regenerate:
|
||||
# Display user input and "*is typing...*" imediately
|
||||
yield shared.history['visible']+[[visible_text, '*Is typing...*']]
|
||||
|
||||
# Generate
|
||||
reply = ''
|
||||
for i in range(chat_generation_attempts):
|
||||
@ -158,6 +163,9 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
|
||||
|
||||
prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True)
|
||||
|
||||
# Display "*is typing...*" imediately
|
||||
yield '*Is typing...*'
|
||||
|
||||
reply = ''
|
||||
for i in range(chat_generation_attempts):
|
||||
for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
|
||||
@ -182,7 +190,7 @@ def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typi
|
||||
last_visible = shared.history['visible'].pop()
|
||||
last_internal = shared.history['internal'].pop()
|
||||
|
||||
for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts):
|
||||
for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True):
|
||||
if shared.args.cai_chat:
|
||||
shared.history['visible'][-1] = [last_visible[0], _history[-1][1]]
|
||||
yield generate_chat_html(shared.history['visible'], name1, name2, shared.character)
|
||||
@ -291,7 +299,7 @@ def save_history(timestamp=True):
|
||||
fname = f"{prefix}persistent.json"
|
||||
if not Path('logs').exists():
|
||||
Path('logs').mkdir()
|
||||
with open(Path(f'logs/{fname}'), 'w') as f:
|
||||
with open(Path(f'logs/{fname}'), 'w', encoding='utf-8') as f:
|
||||
f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
|
||||
return Path(f'logs/{fname}')
|
||||
|
||||
@ -332,7 +340,7 @@ def load_character(_character, name1, name2):
|
||||
shared.history['visible'] = []
|
||||
if _character != 'None':
|
||||
shared.character = _character
|
||||
data = json.loads(open(Path(f'characters/{_character}.json'), 'r').read())
|
||||
data = json.loads(open(Path(f'characters/{_character}.json'), 'r', encoding='utf-8').read())
|
||||
name2 = data['char_name']
|
||||
if 'char_persona' in data and data['char_persona'] != '':
|
||||
context += f"{data['char_name']}'s Persona: {data['char_persona']}\n"
|
||||
@ -372,7 +380,7 @@ def upload_character(json_file, img, tavern=False):
|
||||
i += 1
|
||||
if tavern:
|
||||
outfile_name = f'TavernAI-{outfile_name}'
|
||||
with open(Path(f'characters/{outfile_name}.json'), 'w') as f:
|
||||
with open(Path(f'characters/{outfile_name}.json'), 'w', encoding='utf-8') as f:
|
||||
f.write(json_file)
|
||||
if img is not None:
|
||||
img = Image.open(io.BytesIO(img))
|
||||
|
@ -1,5 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
@ -41,7 +42,7 @@ def load_model(model_name):
|
||||
shared.is_RWKV = model_name.lower().startswith('rwkv-')
|
||||
|
||||
# Default settings
|
||||
if not (shared.args.cpu or shared.args.load_in_8bit or shared.args.auto_devices or shared.args.disk or shared.args.gpu_memory is not None or shared.args.cpu_memory is not None or shared.args.deepspeed or shared.args.flexgen or shared.is_RWKV):
|
||||
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.gptq_bits > 0, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]):
|
||||
if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
|
||||
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True)
|
||||
else:
|
||||
@ -86,6 +87,12 @@ def load_model(model_name):
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
# 4-bit LLaMA
|
||||
elif shared.args.gptq_bits > 0 or shared.args.load_in_4bit:
|
||||
from modules.quantized_LLaMA import load_quantized_LLaMA
|
||||
|
||||
model = load_quantized_LLaMA(model_name)
|
||||
|
||||
# Custom
|
||||
else:
|
||||
command = "AutoModelForCausalLM.from_pretrained"
|
||||
|
60
modules/quantized_LLaMA.py
Normal file
60
modules/quantized_LLaMA.py
Normal file
@ -0,0 +1,60 @@
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import accelerate
|
||||
import torch
|
||||
|
||||
import modules.shared as shared
|
||||
|
||||
sys.path.insert(0, os.path.abspath(Path("repositories/GPTQ-for-LLaMa")))
|
||||
from llama import load_quant
|
||||
|
||||
|
||||
# 4-bit LLaMA
|
||||
def load_quantized_LLaMA(model_name):
|
||||
if shared.args.load_in_4bit:
|
||||
bits = 4
|
||||
else:
|
||||
bits = shared.args.gptq_bits
|
||||
|
||||
path_to_model = Path(f'models/{model_name}')
|
||||
pt_model = ''
|
||||
if path_to_model.name.lower().startswith('llama-7b'):
|
||||
pt_model = f'llama-7b-{bits}bit.pt'
|
||||
elif path_to_model.name.lower().startswith('llama-13b'):
|
||||
pt_model = f'llama-13b-{bits}bit.pt'
|
||||
elif path_to_model.name.lower().startswith('llama-30b'):
|
||||
pt_model = f'llama-30b-{bits}bit.pt'
|
||||
elif path_to_model.name.lower().startswith('llama-65b'):
|
||||
pt_model = f'llama-65b-{bits}bit.pt'
|
||||
else:
|
||||
pt_model = f'{model_name}-{bits}bit.pt'
|
||||
|
||||
# Try to find the .pt both in models/ and in the subfolder
|
||||
pt_path = None
|
||||
for path in [Path(p) for p in [f"models/{pt_model}", f"{path_to_model}/{pt_model}"]]:
|
||||
if path.exists():
|
||||
pt_path = path
|
||||
|
||||
if not pt_path:
|
||||
print(f"Could not find {pt_model}, exiting...")
|
||||
exit()
|
||||
|
||||
model = load_quant(path_to_model, os.path.abspath(pt_path), bits)
|
||||
|
||||
# Multi-GPU setup
|
||||
if shared.args.gpu_memory:
|
||||
max_memory = {}
|
||||
for i in range(len(shared.args.gpu_memory)):
|
||||
max_memory[i] = f"{shared.args.gpu_memory[i]}GiB"
|
||||
max_memory['cpu'] = f"{shared.args.cpu_memory or '99'}GiB"
|
||||
|
||||
device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LLaMADecoderLayer"])
|
||||
model = accelerate.dispatch_model(model, device_map=device_map)
|
||||
|
||||
# Single GPU
|
||||
else:
|
||||
model = model.to(torch.device('cuda:0'))
|
||||
|
||||
return model
|
@ -11,6 +11,7 @@ is_RWKV = False
|
||||
history = {'internal': [], 'visible': []}
|
||||
character = 'None'
|
||||
stop_everything = False
|
||||
still_streaming = False
|
||||
|
||||
# UI elements (buttons, sliders, HTML, etc)
|
||||
gradio = {}
|
||||
@ -42,12 +43,12 @@ settings = {
|
||||
'default': 'NovelAI-Sphinx Moth',
|
||||
'pygmalion-*': 'Pygmalion',
|
||||
'RWKV-*': 'Naive',
|
||||
'(rosey|chip|joi)_.*_instruct.*': 'Instruct Joi (Contrastive Search)'
|
||||
},
|
||||
'prompts': {
|
||||
'default': 'Common sense questions and answers\n\nQuestion: \nFactual answer:',
|
||||
'^(gpt4chan|gpt-4chan|4chan)': '-----\n--- 865467536\nInput text\n--- 865467537\n',
|
||||
'(rosey|chip|joi)_.*_instruct.*': 'User: \n'
|
||||
'(rosey|chip|joi)_.*_instruct.*': 'User: \n',
|
||||
'oasst-*': '<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>'
|
||||
}
|
||||
}
|
||||
|
||||
@ -68,6 +69,8 @@ parser.add_argument('--chat', action='store_true', help='Launch the web UI in ch
|
||||
parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.')
|
||||
parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.')
|
||||
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
|
||||
parser.add_argument('--load-in-4bit', action='store_true', help='Load the model with 4-bit precision. Currently only works with LLaMA.')
|
||||
parser.add_argument('--gptq-bits', type=int, default=0, help='Load a pre-quantized model with specified precision. 2, 3, 4 and 8bit are supported. Currently only works with LLaMA.')
|
||||
parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
|
||||
parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
|
||||
parser.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.')
|
||||
@ -90,4 +93,5 @@ parser.add_argument('--listen', action='store_true', help='Make the web UI reach
|
||||
parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.')
|
||||
parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')
|
||||
parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
|
||||
parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch')
|
||||
args = parser.parse_args()
|
||||
|
@ -1,32 +0,0 @@
|
||||
'''
|
||||
This code was copied from
|
||||
|
||||
https://github.com/PygmalionAI/gradio-ui/
|
||||
|
||||
'''
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
|
||||
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
|
||||
|
||||
def __init__(self, sentinel_token_ids: torch.LongTensor,
|
||||
starting_idx: int):
|
||||
transformers.StoppingCriteria.__init__(self)
|
||||
self.sentinel_token_ids = sentinel_token_ids
|
||||
self.starting_idx = starting_idx
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor,
|
||||
_scores: torch.FloatTensor) -> bool:
|
||||
for sample in input_ids:
|
||||
trimmed_sample = sample[self.starting_idx:]
|
||||
# Can't unfold, output is still too tiny. Skip.
|
||||
if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]:
|
||||
continue
|
||||
|
||||
for window in trimmed_sample.unfold(
|
||||
0, self.sentinel_token_ids.shape[-1], 1):
|
||||
if torch.all(torch.eq(self.sentinel_token_ids, window)):
|
||||
return True
|
||||
return False
|
@ -5,13 +5,13 @@ import time
|
||||
import numpy as np
|
||||
import torch
|
||||
import transformers
|
||||
from tqdm import tqdm
|
||||
|
||||
import modules.shared as shared
|
||||
from modules.callbacks import (Iteratorize, Stream,
|
||||
_SentinelTokenStoppingCriteria)
|
||||
from modules.extensions import apply_extensions
|
||||
from modules.html_generator import generate_4chan_html, generate_basic_html
|
||||
from modules.models import local_rank
|
||||
from modules.stopping_criteria import _SentinelTokenStoppingCriteria
|
||||
|
||||
|
||||
def get_max_prompt_length(tokens):
|
||||
@ -92,19 +92,22 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
|
||||
# These models are not part of Hugging Face, so we handle them
|
||||
# separately and terminate the function call earlier
|
||||
if shared.is_RWKV:
|
||||
if shared.args.no_stream:
|
||||
reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
else:
|
||||
yield formatted_outputs(question, shared.model_name)
|
||||
# RWKV has proper streaming, which is very nice.
|
||||
# No need to generate 8 tokens at a time.
|
||||
for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
|
||||
try:
|
||||
if shared.args.no_stream:
|
||||
reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
t1 = time.time()
|
||||
print(f"Output generated in {(t1-t0):.2f} seconds.")
|
||||
return
|
||||
else:
|
||||
yield formatted_outputs(question, shared.model_name)
|
||||
# RWKV has proper streaming, which is very nice.
|
||||
# No need to generate 8 tokens at a time.
|
||||
for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
finally:
|
||||
t1 = time.time()
|
||||
output = encode(reply)[0]
|
||||
input_ids = encode(question)
|
||||
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(input_ids[0])} tokens)")
|
||||
return
|
||||
|
||||
original_question = question
|
||||
if not (shared.args.chat or shared.args.cai_chat):
|
||||
@ -113,24 +116,22 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
|
||||
print(f"\n\n{question}\n--------------------\n")
|
||||
|
||||
input_ids = encode(question, max_new_tokens)
|
||||
original_input_ids = input_ids
|
||||
output = input_ids[0]
|
||||
cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()"
|
||||
n = shared.tokenizer.eos_token_id if eos_token is None else int(encode(eos_token)[0][-1])
|
||||
eos_token_ids = [shared.tokenizer.eos_token_id]
|
||||
if eos_token is not None:
|
||||
eos_token_ids.append(int(encode(eos_token)[0][-1]))
|
||||
stopping_criteria_list = transformers.StoppingCriteriaList()
|
||||
if stopping_string is not None:
|
||||
# The stopping_criteria code below was copied from
|
||||
# https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
|
||||
# Copied from https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
|
||||
t = encode(stopping_string, 0, add_special_tokens=False)
|
||||
stopping_criteria_list = transformers.StoppingCriteriaList([
|
||||
_SentinelTokenStoppingCriteria(
|
||||
sentinel_token_ids=t,
|
||||
starting_idx=len(input_ids[0])
|
||||
)
|
||||
])
|
||||
else:
|
||||
stopping_criteria_list = None
|
||||
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
|
||||
|
||||
if not shared.args.flexgen:
|
||||
generate_params = [
|
||||
f"eos_token_id={n}",
|
||||
f"max_new_tokens=max_new_tokens",
|
||||
f"eos_token_id={eos_token_ids}",
|
||||
f"stopping_criteria=stopping_criteria_list",
|
||||
f"do_sample={do_sample}",
|
||||
f"temperature={temperature}",
|
||||
@ -147,44 +148,23 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
|
||||
]
|
||||
else:
|
||||
generate_params = [
|
||||
f"max_new_tokens={max_new_tokens if shared.args.no_stream else 8}",
|
||||
f"do_sample={do_sample}",
|
||||
f"temperature={temperature}",
|
||||
f"stop={n}",
|
||||
f"stop={eos_token_ids[-1]}",
|
||||
]
|
||||
if shared.args.deepspeed:
|
||||
generate_params.append("synced_gpus=True")
|
||||
if shared.args.no_stream:
|
||||
generate_params.append("max_new_tokens=max_new_tokens")
|
||||
else:
|
||||
generate_params.append("max_new_tokens=8")
|
||||
if shared.soft_prompt:
|
||||
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
||||
generate_params.insert(0, "inputs_embeds=inputs_embeds")
|
||||
generate_params.insert(0, "filler_input_ids")
|
||||
generate_params.insert(0, "inputs=filler_input_ids")
|
||||
else:
|
||||
generate_params.insert(0, "input_ids")
|
||||
|
||||
# Generate the entire reply at once
|
||||
if shared.args.no_stream:
|
||||
with torch.no_grad():
|
||||
output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
|
||||
if shared.soft_prompt:
|
||||
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||
|
||||
reply = decode(output)
|
||||
if not (shared.args.chat or shared.args.cai_chat):
|
||||
reply = original_question + apply_extensions(reply[len(question):], "output")
|
||||
|
||||
t1 = time.time()
|
||||
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output)-len(input_ids[0])} tokens)")
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
# Generate the reply 8 tokens at a time
|
||||
else:
|
||||
yield formatted_outputs(original_question, shared.model_name)
|
||||
for i in tqdm(range(max_new_tokens//8+1)):
|
||||
clear_torch_cache()
|
||||
generate_params.insert(0, "inputs=input_ids")
|
||||
|
||||
try:
|
||||
# Generate the entire reply at once.
|
||||
if shared.args.no_stream:
|
||||
with torch.no_grad():
|
||||
output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
|
||||
if shared.soft_prompt:
|
||||
@ -193,16 +173,66 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
|
||||
reply = decode(output)
|
||||
if not (shared.args.chat or shared.args.cai_chat):
|
||||
reply = original_question + apply_extensions(reply[len(question):], "output")
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
if not shared.args.flexgen:
|
||||
if output[-1] == n:
|
||||
break
|
||||
input_ids = torch.reshape(output, (1, output.shape[0]))
|
||||
else:
|
||||
if np.count_nonzero(input_ids[0] == n) < np.count_nonzero(output == n):
|
||||
break
|
||||
input_ids = np.reshape(output, (1, output.shape[0]))
|
||||
# Stream the reply 1 token at a time.
|
||||
# This is based on the trick of using 'stopping_criteria' to create an iterator.
|
||||
elif not shared.args.flexgen:
|
||||
|
||||
if shared.soft_prompt:
|
||||
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
||||
def generate_with_callback(callback=None, **kwargs):
|
||||
kwargs['stopping_criteria'].append(Stream(callback_func=callback))
|
||||
clear_torch_cache()
|
||||
with torch.no_grad():
|
||||
shared.model.generate(**kwargs)
|
||||
|
||||
def generate_with_streaming(**kwargs):
|
||||
return Iteratorize(generate_with_callback, kwargs, callback=None)
|
||||
|
||||
shared.still_streaming = True
|
||||
yield formatted_outputs(original_question, shared.model_name)
|
||||
with eval(f"generate_with_streaming({', '.join(generate_params)})") as generator:
|
||||
for output in generator:
|
||||
if shared.soft_prompt:
|
||||
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||
reply = decode(output)
|
||||
|
||||
if not (shared.args.chat or shared.args.cai_chat):
|
||||
reply = original_question + apply_extensions(reply[len(question):], "output")
|
||||
|
||||
if output[-1] in eos_token_ids:
|
||||
break
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
shared.still_streaming = False
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
|
||||
else:
|
||||
shared.still_streaming = True
|
||||
for i in range(max_new_tokens//8+1):
|
||||
clear_torch_cache()
|
||||
with torch.no_grad():
|
||||
output = eval(f"shared.model.generate({', '.join(generate_params)})")[0]
|
||||
if shared.soft_prompt:
|
||||
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||
reply = decode(output)
|
||||
|
||||
if not (shared.args.chat or shared.args.cai_chat):
|
||||
reply = original_question + apply_extensions(reply[len(question):], "output")
|
||||
|
||||
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
|
||||
break
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
input_ids = np.reshape(output, (1, output.shape[0]))
|
||||
if shared.soft_prompt:
|
||||
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
||||
|
||||
shared.still_streaming = False
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
finally:
|
||||
t1 = time.time()
|
||||
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(original_input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(original_input_ids[0])} tokens)")
|
||||
return
|
||||
|
@ -1,9 +1,11 @@
|
||||
accelerate==0.16.0
|
||||
accelerate==0.17.0
|
||||
bitsandbytes==0.37.0
|
||||
flexgen==0.1.7
|
||||
gradio==3.18.0
|
||||
numpy
|
||||
rwkv==0.1.0
|
||||
safetensors==0.2.8
|
||||
requests
|
||||
rwkv==0.3.1
|
||||
safetensors==0.3.0
|
||||
sentencepiece
|
||||
git+https://github.com/oobabooga/transformers@llama_push
|
||||
tqdm
|
||||
git+https://github.com/zphang/transformers@llama_push
|
||||
|
18
server.py
18
server.py
@ -18,9 +18,6 @@ from modules.html_generator import generate_chat_html
|
||||
from modules.models import load_model, load_soft_prompt
|
||||
from modules.text_generation import generate_reply
|
||||
|
||||
if (shared.args.chat or shared.args.cai_chat) and not shared.args.no_stream:
|
||||
print('Warning: chat mode currently becomes somewhat slower with text streaming on.\nConsider starting the web UI with the --no-stream option.\n')
|
||||
|
||||
# Loading custom settings
|
||||
settings_file = None
|
||||
if shared.args.settings is not None and Path(shared.args.settings).exists():
|
||||
@ -37,7 +34,7 @@ def get_available_models():
|
||||
if shared.args.flexgen:
|
||||
return sorted([re.sub('-np$', '', item.name) for item in list(Path('models/').glob('*')) if item.name.endswith('-np')], key=str.lower)
|
||||
else:
|
||||
return sorted([item.name for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np'))], key=str.lower)
|
||||
return sorted([item.name for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt'))], key=str.lower)
|
||||
|
||||
def get_available_presets():
|
||||
return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
|
||||
@ -272,10 +269,10 @@ if shared.args.chat or shared.args.cai_chat:
|
||||
|
||||
function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
|
||||
|
||||
gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream, api_name='textgen'))
|
||||
gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=False, api_name='textgen'))
|
||||
gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=False))
|
||||
gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=False))
|
||||
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=False))
|
||||
shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events)
|
||||
|
||||
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
|
||||
@ -309,6 +306,7 @@ if shared.args.chat or shared.args.cai_chat:
|
||||
reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else []
|
||||
shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']])
|
||||
shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']])
|
||||
shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']])
|
||||
|
||||
shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None)
|
||||
shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
|
||||
@ -372,9 +370,9 @@ else:
|
||||
|
||||
shared.gradio['interface'].queue()
|
||||
if shared.args.listen:
|
||||
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port)
|
||||
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
|
||||
else:
|
||||
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port)
|
||||
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
|
||||
|
||||
# I think that I will need this later
|
||||
while True:
|
||||
|
@ -29,6 +29,7 @@
|
||||
"prompts": {
|
||||
"default": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
|
||||
"^(gpt4chan|gpt-4chan|4chan)": "-----\n--- 865467536\nInput text\n--- 865467537\n",
|
||||
"(rosey|chip|joi)_.*_instruct.*": "User: \n"
|
||||
"(rosey|chip|joi)_.*_instruct.*": "User: \n",
|
||||
"oasst-*": "<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>"
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user