mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 17:29:22 +01:00
Fix merge conflict in text_generation
- Need to update `shared.still_streaming = False` before the final `yield formatted_outputs`, shifted the position of some yields.
This commit is contained in:
commit
b3e10e47c0
64
.idea/workspace.xml
Normal file
64
.idea/workspace.xml
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ChangeListManager">
|
||||||
|
<list default="true" id="edbf3935-4476-45aa-aea0-f1e7cbcf4b9a" name="Changes" comment="">
|
||||||
|
<change afterPath="$PROJECT_DIR$/extensions/llama_prompts/script.py" afterDir="false" />
|
||||||
|
<change afterPath="$PROJECT_DIR$/modules/callbacks.py" afterDir="false" />
|
||||||
|
<change beforePath="$PROJECT_DIR$/modules/RWKV.py" beforeDir="false" afterPath="$PROJECT_DIR$/modules/RWKV.py" afterDir="false" />
|
||||||
|
<change beforePath="$PROJECT_DIR$/modules/chat.py" beforeDir="false" afterPath="$PROJECT_DIR$/modules/chat.py" afterDir="false" />
|
||||||
|
<change beforePath="$PROJECT_DIR$/modules/shared.py" beforeDir="false" afterPath="$PROJECT_DIR$/modules/shared.py" afterDir="false" />
|
||||||
|
<change beforePath="$PROJECT_DIR$/modules/stopping_criteria.py" beforeDir="false" />
|
||||||
|
<change beforePath="$PROJECT_DIR$/modules/text_generation.py" beforeDir="false" afterPath="$PROJECT_DIR$/modules/text_generation.py" afterDir="false" />
|
||||||
|
<change beforePath="$PROJECT_DIR$/requirements.txt" beforeDir="false" afterPath="$PROJECT_DIR$/requirements.txt" afterDir="false" />
|
||||||
|
<change beforePath="$PROJECT_DIR$/server.py" beforeDir="false" afterPath="$PROJECT_DIR$/server.py" afterDir="false" />
|
||||||
|
</list>
|
||||||
|
<option name="SHOW_DIALOG" value="false" />
|
||||||
|
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
||||||
|
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
|
||||||
|
<option name="LAST_RESOLUTION" value="IGNORE" />
|
||||||
|
</component>
|
||||||
|
<component name="Git.Settings">
|
||||||
|
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
|
||||||
|
</component>
|
||||||
|
<component name="MarkdownSettingsMigration">
|
||||||
|
<option name="stateVersion" value="1" />
|
||||||
|
</component>
|
||||||
|
<component name="ProjectId" id="2MtdH03e5QdbSP16WYYfDkhyFUC" />
|
||||||
|
<component name="ProjectLevelVcsManager" settingsEditedManually="true" />
|
||||||
|
<component name="ProjectViewState">
|
||||||
|
<option name="showLibraryContents" value="true" />
|
||||||
|
</component>
|
||||||
|
<component name="PropertiesComponent"><![CDATA[{
|
||||||
|
"keyToString": {
|
||||||
|
"ASKED_SHARE_PROJECT_CONFIGURATION_FILES": "true",
|
||||||
|
"RunOnceActivity.OpenProjectViewOnStart": "true",
|
||||||
|
"RunOnceActivity.ShowReadmeOnStart": "true"
|
||||||
|
}
|
||||||
|
}]]></component>
|
||||||
|
<component name="RunManager">
|
||||||
|
<configuration default="true" type="JetRunConfigurationType">
|
||||||
|
<module name="text-generation-webui" />
|
||||||
|
<method v="2">
|
||||||
|
<option name="Make" enabled="true" />
|
||||||
|
</method>
|
||||||
|
</configuration>
|
||||||
|
<configuration default="true" type="KotlinStandaloneScriptRunConfigurationType">
|
||||||
|
<module name="text-generation-webui" />
|
||||||
|
<option name="filePath" />
|
||||||
|
<method v="2">
|
||||||
|
<option name="Make" enabled="true" />
|
||||||
|
</method>
|
||||||
|
</configuration>
|
||||||
|
</component>
|
||||||
|
<component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
|
||||||
|
<component name="TaskManager">
|
||||||
|
<task active="true" id="Default" summary="Default task">
|
||||||
|
<changelist id="edbf3935-4476-45aa-aea0-f1e7cbcf4b9a" name="Changes" comment="" />
|
||||||
|
<created>1678590722207</created>
|
||||||
|
<option name="number" value="Default" />
|
||||||
|
<option name="presentableId" value="Default" />
|
||||||
|
<updated>1678590722207</updated>
|
||||||
|
</task>
|
||||||
|
<servers />
|
||||||
|
</component>
|
||||||
|
</project>
|
18
extensions/llama_prompts/script.py
Normal file
18
extensions/llama_prompts/script.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
import gradio as gr
|
||||||
|
import modules.shared as shared
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
df = pd.read_csv("https://raw.githubusercontent.com/devbrones/llama-prompts/main/prompts/prompts.csv")
|
||||||
|
|
||||||
|
def get_prompt_by_name(name):
|
||||||
|
if name == 'None':
|
||||||
|
return ''
|
||||||
|
else:
|
||||||
|
return df[df['Prompt name'] == name].iloc[0]['Prompt'].replace('\\n', '\n')
|
||||||
|
|
||||||
|
def ui():
|
||||||
|
if not shared.args.chat or share.args.cai_chat:
|
||||||
|
choices = ['None'] + list(df['Prompt name'])
|
||||||
|
|
||||||
|
prompts_menu = gr.Dropdown(value=choices[0], choices=choices, label='Prompt')
|
||||||
|
prompts_menu.change(get_prompt_by_name, prompts_menu, shared.gradio['textbox'])
|
@ -7,6 +7,7 @@ import numpy as np
|
|||||||
from tokenizers import Tokenizer
|
from tokenizers import Tokenizer
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
from modules.callbacks import Iteratorize
|
||||||
|
|
||||||
np.set_printoptions(precision=4, suppress=True, linewidth=200)
|
np.set_printoptions(precision=4, suppress=True, linewidth=200)
|
||||||
|
|
||||||
@ -49,11 +50,11 @@ class RWKVModel:
|
|||||||
return context+self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)
|
return context+self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)
|
||||||
|
|
||||||
def generate_with_streaming(self, **kwargs):
|
def generate_with_streaming(self, **kwargs):
|
||||||
iterable = Iteratorize(self.generate, kwargs, callback=None)
|
with Iteratorize(self.generate, kwargs, callback=None) as generator:
|
||||||
reply = kwargs['context']
|
reply = kwargs['context']
|
||||||
for token in iterable:
|
for token in generator:
|
||||||
reply += token
|
reply += token
|
||||||
yield reply
|
yield reply
|
||||||
|
|
||||||
class RWKVTokenizer:
|
class RWKVTokenizer:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -73,38 +74,3 @@ class RWKVTokenizer:
|
|||||||
|
|
||||||
def decode(self, ids):
|
def decode(self, ids):
|
||||||
return self.tokenizer.decode(ids)
|
return self.tokenizer.decode(ids)
|
||||||
|
|
||||||
class Iteratorize:
|
|
||||||
|
|
||||||
"""
|
|
||||||
Transforms a function that takes a callback
|
|
||||||
into a lazy iterator (generator).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, func, kwargs={}, callback=None):
|
|
||||||
self.mfunc=func
|
|
||||||
self.c_callback=callback
|
|
||||||
self.q = Queue(maxsize=1)
|
|
||||||
self.sentinel = object()
|
|
||||||
self.kwargs = kwargs
|
|
||||||
|
|
||||||
def _callback(val):
|
|
||||||
self.q.put(val)
|
|
||||||
|
|
||||||
def gentask():
|
|
||||||
ret = self.mfunc(callback=_callback, **self.kwargs)
|
|
||||||
self.q.put(self.sentinel)
|
|
||||||
if self.c_callback:
|
|
||||||
self.c_callback(ret)
|
|
||||||
|
|
||||||
Thread(target=gentask).start()
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __next__(self):
|
|
||||||
obj = self.q.get(True,None)
|
|
||||||
if obj is self.sentinel:
|
|
||||||
raise StopIteration
|
|
||||||
else:
|
|
||||||
return obj
|
|
||||||
|
98
modules/callbacks.py
Normal file
98
modules/callbacks.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
import gc
|
||||||
|
from queue import Queue
|
||||||
|
from threading import Thread
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
|
||||||
|
import modules.shared as shared
|
||||||
|
|
||||||
|
# Copied from https://github.com/PygmalionAI/gradio-ui/
|
||||||
|
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
|
||||||
|
|
||||||
|
def __init__(self, sentinel_token_ids: torch.LongTensor,
|
||||||
|
starting_idx: int):
|
||||||
|
transformers.StoppingCriteria.__init__(self)
|
||||||
|
self.sentinel_token_ids = sentinel_token_ids
|
||||||
|
self.starting_idx = starting_idx
|
||||||
|
|
||||||
|
def __call__(self, input_ids: torch.LongTensor,
|
||||||
|
_scores: torch.FloatTensor) -> bool:
|
||||||
|
for sample in input_ids:
|
||||||
|
trimmed_sample = sample[self.starting_idx:]
|
||||||
|
# Can't unfold, output is still too tiny. Skip.
|
||||||
|
if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for window in trimmed_sample.unfold(
|
||||||
|
0, self.sentinel_token_ids.shape[-1], 1):
|
||||||
|
if torch.all(torch.eq(self.sentinel_token_ids, window)):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
class Stream(transformers.StoppingCriteria):
|
||||||
|
def __init__(self, callback_func=None):
|
||||||
|
self.callback_func = callback_func
|
||||||
|
|
||||||
|
def __call__(self, input_ids, scores) -> bool:
|
||||||
|
if self.callback_func is not None:
|
||||||
|
self.callback_func(input_ids[0])
|
||||||
|
return False
|
||||||
|
|
||||||
|
class Iteratorize:
|
||||||
|
|
||||||
|
"""
|
||||||
|
Transforms a function that takes a callback
|
||||||
|
into a lazy iterator (generator).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, func, kwargs={}, callback=None):
|
||||||
|
self.mfunc=func
|
||||||
|
self.c_callback=callback
|
||||||
|
self.q = Queue()
|
||||||
|
self.sentinel = object()
|
||||||
|
self.kwargs = kwargs
|
||||||
|
self.stop_now = False
|
||||||
|
|
||||||
|
def _callback(val):
|
||||||
|
if self.stop_now:
|
||||||
|
raise ValueError
|
||||||
|
self.q.put(val)
|
||||||
|
|
||||||
|
def gentask():
|
||||||
|
try:
|
||||||
|
ret = self.mfunc(callback=_callback, **self.kwargs)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
clear_torch_cache()
|
||||||
|
self.q.put(self.sentinel)
|
||||||
|
if self.c_callback:
|
||||||
|
self.c_callback(ret)
|
||||||
|
|
||||||
|
self.thread = Thread(target=gentask)
|
||||||
|
self.thread.start()
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
obj = self.q.get(True,None)
|
||||||
|
if obj is self.sentinel:
|
||||||
|
raise StopIteration
|
||||||
|
else:
|
||||||
|
return obj
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
clear_torch_cache()
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self.stop_now = True
|
||||||
|
clear_torch_cache()
|
||||||
|
|
||||||
|
def clear_torch_cache():
|
||||||
|
gc.collect()
|
||||||
|
if not shared.args.cpu:
|
||||||
|
torch.cuda.empty_cache()
|
@ -84,6 +84,7 @@ def extract_message_from_reply(question, reply, name1, name2, check, impersonate
|
|||||||
tmp = f"\n{asker}:"
|
tmp = f"\n{asker}:"
|
||||||
for j in range(1, len(tmp)):
|
for j in range(1, len(tmp)):
|
||||||
if reply[-j:] == tmp[:j]:
|
if reply[-j:] == tmp[:j]:
|
||||||
|
reply = reply[:-j]
|
||||||
substring_found = True
|
substring_found = True
|
||||||
|
|
||||||
return reply, next_character_found, substring_found
|
return reply, next_character_found, substring_found
|
||||||
@ -91,7 +92,7 @@ def extract_message_from_reply(question, reply, name1, name2, check, impersonate
|
|||||||
def stop_everything_event():
|
def stop_everything_event():
|
||||||
shared.stop_everything = True
|
shared.stop_everything = True
|
||||||
|
|
||||||
def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
|
def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
|
||||||
shared.stop_everything = False
|
shared.stop_everything = False
|
||||||
just_started = True
|
just_started = True
|
||||||
eos_token = '\n' if check else None
|
eos_token = '\n' if check else None
|
||||||
@ -120,6 +121,10 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
|
|||||||
else:
|
else:
|
||||||
prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
|
prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
|
||||||
|
|
||||||
|
if not regenerate:
|
||||||
|
# Display user input and "*is typing...*" imediately
|
||||||
|
yield shared.history['visible']+[[visible_text, '*Is typing...*']]
|
||||||
|
|
||||||
# Generate
|
# Generate
|
||||||
reply = ''
|
reply = ''
|
||||||
for i in range(chat_generation_attempts):
|
for i in range(chat_generation_attempts):
|
||||||
@ -158,6 +163,9 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
|
|||||||
|
|
||||||
prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True)
|
prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True)
|
||||||
|
|
||||||
|
# Display "*is typing...*" imediately
|
||||||
|
yield '*Is typing...*'
|
||||||
|
|
||||||
reply = ''
|
reply = ''
|
||||||
for i in range(chat_generation_attempts):
|
for i in range(chat_generation_attempts):
|
||||||
for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
|
for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
|
||||||
@ -182,7 +190,7 @@ def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typi
|
|||||||
last_visible = shared.history['visible'].pop()
|
last_visible = shared.history['visible'].pop()
|
||||||
last_internal = shared.history['internal'].pop()
|
last_internal = shared.history['internal'].pop()
|
||||||
|
|
||||||
for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts):
|
for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True):
|
||||||
if shared.args.cai_chat:
|
if shared.args.cai_chat:
|
||||||
shared.history['visible'][-1] = [last_visible[0], _history[-1][1]]
|
shared.history['visible'][-1] = [last_visible[0], _history[-1][1]]
|
||||||
yield generate_chat_html(shared.history['visible'], name1, name2, shared.character)
|
yield generate_chat_html(shared.history['visible'], name1, name2, shared.character)
|
||||||
@ -291,7 +299,7 @@ def save_history(timestamp=True):
|
|||||||
fname = f"{prefix}persistent.json"
|
fname = f"{prefix}persistent.json"
|
||||||
if not Path('logs').exists():
|
if not Path('logs').exists():
|
||||||
Path('logs').mkdir()
|
Path('logs').mkdir()
|
||||||
with open(Path(f'logs/{fname}'), 'w') as f:
|
with open(Path(f'logs/{fname}'), 'w', encoding='utf-8') as f:
|
||||||
f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
|
f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
|
||||||
return Path(f'logs/{fname}')
|
return Path(f'logs/{fname}')
|
||||||
|
|
||||||
@ -332,7 +340,7 @@ def load_character(_character, name1, name2):
|
|||||||
shared.history['visible'] = []
|
shared.history['visible'] = []
|
||||||
if _character != 'None':
|
if _character != 'None':
|
||||||
shared.character = _character
|
shared.character = _character
|
||||||
data = json.loads(open(Path(f'characters/{_character}.json'), 'r').read())
|
data = json.loads(open(Path(f'characters/{_character}.json'), 'r', encoding='utf-8').read())
|
||||||
name2 = data['char_name']
|
name2 = data['char_name']
|
||||||
if 'char_persona' in data and data['char_persona'] != '':
|
if 'char_persona' in data and data['char_persona'] != '':
|
||||||
context += f"{data['char_name']}'s Persona: {data['char_persona']}\n"
|
context += f"{data['char_name']}'s Persona: {data['char_persona']}\n"
|
||||||
@ -372,7 +380,7 @@ def upload_character(json_file, img, tavern=False):
|
|||||||
i += 1
|
i += 1
|
||||||
if tavern:
|
if tavern:
|
||||||
outfile_name = f'TavernAI-{outfile_name}'
|
outfile_name = f'TavernAI-{outfile_name}'
|
||||||
with open(Path(f'characters/{outfile_name}.json'), 'w') as f:
|
with open(Path(f'characters/{outfile_name}.json'), 'w', encoding='utf-8') as f:
|
||||||
f.write(json_file)
|
f.write(json_file)
|
||||||
if img is not None:
|
if img is not None:
|
||||||
img = Image.open(io.BytesIO(img))
|
img = Image.open(io.BytesIO(img))
|
||||||
|
@ -91,4 +91,5 @@ parser.add_argument('--listen', action='store_true', help='Make the web UI reach
|
|||||||
parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.')
|
parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.')
|
||||||
parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')
|
parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')
|
||||||
parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
|
parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
|
||||||
|
parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
@ -1,32 +0,0 @@
|
|||||||
'''
|
|
||||||
This code was copied from
|
|
||||||
|
|
||||||
https://github.com/PygmalionAI/gradio-ui/
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
|
|
||||||
|
|
||||||
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
|
|
||||||
|
|
||||||
def __init__(self, sentinel_token_ids: torch.LongTensor,
|
|
||||||
starting_idx: int):
|
|
||||||
transformers.StoppingCriteria.__init__(self)
|
|
||||||
self.sentinel_token_ids = sentinel_token_ids
|
|
||||||
self.starting_idx = starting_idx
|
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.LongTensor,
|
|
||||||
_scores: torch.FloatTensor) -> bool:
|
|
||||||
for sample in input_ids:
|
|
||||||
trimmed_sample = sample[self.starting_idx:]
|
|
||||||
# Can't unfold, output is still too tiny. Skip.
|
|
||||||
if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]:
|
|
||||||
continue
|
|
||||||
|
|
||||||
for window in trimmed_sample.unfold(
|
|
||||||
0, self.sentinel_token_ids.shape[-1], 1):
|
|
||||||
if torch.all(torch.eq(self.sentinel_token_ids, window)):
|
|
||||||
return True
|
|
||||||
return False
|
|
@ -5,13 +5,13 @@ import time
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
from modules.callbacks import (Iteratorize, Stream,
|
||||||
|
_SentinelTokenStoppingCriteria)
|
||||||
from modules.extensions import apply_extensions
|
from modules.extensions import apply_extensions
|
||||||
from modules.html_generator import generate_4chan_html, generate_basic_html
|
from modules.html_generator import generate_4chan_html, generate_basic_html
|
||||||
from modules.models import local_rank
|
from modules.models import local_rank
|
||||||
from modules.stopping_criteria import _SentinelTokenStoppingCriteria
|
|
||||||
|
|
||||||
|
|
||||||
def get_max_prompt_length(tokens):
|
def get_max_prompt_length(tokens):
|
||||||
@ -92,19 +92,22 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
|
|||||||
# These models are not part of Hugging Face, so we handle them
|
# These models are not part of Hugging Face, so we handle them
|
||||||
# separately and terminate the function call earlier
|
# separately and terminate the function call earlier
|
||||||
if shared.is_RWKV:
|
if shared.is_RWKV:
|
||||||
if shared.args.no_stream:
|
try:
|
||||||
reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
|
if shared.args.no_stream:
|
||||||
yield formatted_outputs(reply, shared.model_name)
|
reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
|
||||||
else:
|
|
||||||
yield formatted_outputs(question, shared.model_name)
|
|
||||||
# RWKV has proper streaming, which is very nice.
|
|
||||||
# No need to generate 8 tokens at a time.
|
|
||||||
for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
|
|
||||||
yield formatted_outputs(reply, shared.model_name)
|
yield formatted_outputs(reply, shared.model_name)
|
||||||
|
else:
|
||||||
t1 = time.time()
|
yield formatted_outputs(question, shared.model_name)
|
||||||
print(f"Output generated in {(t1-t0):.2f} seconds.")
|
# RWKV has proper streaming, which is very nice.
|
||||||
return
|
# No need to generate 8 tokens at a time.
|
||||||
|
for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
|
||||||
|
yield formatted_outputs(reply, shared.model_name)
|
||||||
|
finally:
|
||||||
|
t1 = time.time()
|
||||||
|
output = encode(reply)[0]
|
||||||
|
input_ids = encode(question)
|
||||||
|
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(input_ids[0])} tokens)")
|
||||||
|
return
|
||||||
|
|
||||||
original_question = question
|
original_question = question
|
||||||
if not (shared.args.chat or shared.args.cai_chat):
|
if not (shared.args.chat or shared.args.cai_chat):
|
||||||
@ -113,23 +116,19 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
|
|||||||
print(f"\n\n{question}\n--------------------\n")
|
print(f"\n\n{question}\n--------------------\n")
|
||||||
|
|
||||||
input_ids = encode(question, max_new_tokens)
|
input_ids = encode(question, max_new_tokens)
|
||||||
|
original_input_ids = input_ids
|
||||||
|
output = input_ids[0]
|
||||||
cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()"
|
cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()"
|
||||||
n = shared.tokenizer.eos_token_id if eos_token is None else int(encode(eos_token)[0][-1])
|
n = shared.tokenizer.eos_token_id if eos_token is None else int(encode(eos_token)[0][-1])
|
||||||
|
stopping_criteria_list = transformers.StoppingCriteriaList()
|
||||||
if stopping_string is not None:
|
if stopping_string is not None:
|
||||||
# The stopping_criteria code below was copied from
|
# Copied from https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
|
||||||
# https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
|
|
||||||
t = encode(stopping_string, 0, add_special_tokens=False)
|
t = encode(stopping_string, 0, add_special_tokens=False)
|
||||||
stopping_criteria_list = transformers.StoppingCriteriaList([
|
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
|
||||||
_SentinelTokenStoppingCriteria(
|
|
||||||
sentinel_token_ids=t,
|
|
||||||
starting_idx=len(input_ids[0])
|
|
||||||
)
|
|
||||||
])
|
|
||||||
else:
|
|
||||||
stopping_criteria_list = None
|
|
||||||
|
|
||||||
if not shared.args.flexgen:
|
if not shared.args.flexgen:
|
||||||
generate_params = [
|
generate_params = [
|
||||||
|
f"max_new_tokens=max_new_tokens",
|
||||||
f"eos_token_id={n}",
|
f"eos_token_id={n}",
|
||||||
f"stopping_criteria=stopping_criteria_list",
|
f"stopping_criteria=stopping_criteria_list",
|
||||||
f"do_sample={do_sample}",
|
f"do_sample={do_sample}",
|
||||||
@ -147,45 +146,23 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
|
|||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
generate_params = [
|
generate_params = [
|
||||||
|
f"max_new_tokens={max_new_tokens if shared.args.no_stream else 8}",
|
||||||
f"do_sample={do_sample}",
|
f"do_sample={do_sample}",
|
||||||
f"temperature={temperature}",
|
f"temperature={temperature}",
|
||||||
f"stop={n}",
|
f"stop={n}",
|
||||||
]
|
]
|
||||||
if shared.args.deepspeed:
|
if shared.args.deepspeed:
|
||||||
generate_params.append("synced_gpus=True")
|
generate_params.append("synced_gpus=True")
|
||||||
if shared.args.no_stream:
|
|
||||||
generate_params.append("max_new_tokens=max_new_tokens")
|
|
||||||
else:
|
|
||||||
generate_params.append("max_new_tokens=8")
|
|
||||||
if shared.soft_prompt:
|
if shared.soft_prompt:
|
||||||
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
||||||
generate_params.insert(0, "inputs_embeds=inputs_embeds")
|
generate_params.insert(0, "inputs_embeds=inputs_embeds")
|
||||||
generate_params.insert(0, "filler_input_ids")
|
generate_params.insert(0, "inputs=filler_input_ids")
|
||||||
else:
|
else:
|
||||||
generate_params.insert(0, "input_ids")
|
generate_params.insert(0, "inputs=input_ids")
|
||||||
|
|
||||||
# Generate the entire reply at once
|
|
||||||
if shared.args.no_stream:
|
|
||||||
with torch.no_grad():
|
|
||||||
output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
|
|
||||||
if shared.soft_prompt:
|
|
||||||
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
|
||||||
|
|
||||||
reply = decode(output)
|
|
||||||
if not (shared.args.chat or shared.args.cai_chat):
|
|
||||||
reply = original_question + apply_extensions(reply[len(question):], "output")
|
|
||||||
|
|
||||||
t1 = time.time()
|
|
||||||
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output)-len(input_ids[0])} tokens)")
|
|
||||||
yield formatted_outputs(reply, shared.model_name)
|
|
||||||
|
|
||||||
# Generate the reply 8 tokens at a time
|
|
||||||
else:
|
|
||||||
yield formatted_outputs(original_question, shared.model_name)
|
|
||||||
shared.still_streaming = True
|
|
||||||
for i in tqdm(range(max_new_tokens//8+1)):
|
|
||||||
clear_torch_cache()
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Generate the entire reply at once.
|
||||||
|
if shared.args.no_stream:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
|
output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
|
||||||
if shared.soft_prompt:
|
if shared.soft_prompt:
|
||||||
@ -194,22 +171,66 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
|
|||||||
reply = decode(output)
|
reply = decode(output)
|
||||||
if not (shared.args.chat or shared.args.cai_chat):
|
if not (shared.args.chat or shared.args.cai_chat):
|
||||||
reply = original_question + apply_extensions(reply[len(question):], "output")
|
reply = original_question + apply_extensions(reply[len(question):], "output")
|
||||||
|
|
||||||
if not shared.args.flexgen:
|
|
||||||
if output[-1] == n:
|
|
||||||
break
|
|
||||||
input_ids = torch.reshape(output, (1, output.shape[0]))
|
|
||||||
else:
|
|
||||||
if np.count_nonzero(input_ids[0] == n) < np.count_nonzero(output == n):
|
|
||||||
break
|
|
||||||
input_ids = np.reshape(output, (1, output.shape[0]))
|
|
||||||
|
|
||||||
#Mid-stream yield, ran if no breaks
|
|
||||||
yield formatted_outputs(reply, shared.model_name)
|
yield formatted_outputs(reply, shared.model_name)
|
||||||
|
|
||||||
if shared.soft_prompt:
|
# Stream the reply 1 token at a time.
|
||||||
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
# This is based on the trick of using 'stopping_criteria' to create an iterator.
|
||||||
|
elif not shared.args.flexgen:
|
||||||
#Stream finished from max tokens or break. Do final yield.
|
|
||||||
shared.still_streaming = False
|
def generate_with_callback(callback=None, **kwargs):
|
||||||
yield formatted_outputs(reply, shared.model_name)
|
kwargs['stopping_criteria'].append(Stream(callback_func=callback))
|
||||||
|
clear_torch_cache()
|
||||||
|
with torch.no_grad():
|
||||||
|
shared.model.generate(**kwargs)
|
||||||
|
|
||||||
|
def generate_with_streaming(**kwargs):
|
||||||
|
return Iteratorize(generate_with_callback, kwargs, callback=None)
|
||||||
|
|
||||||
|
shared.still_streaming = True
|
||||||
|
yield formatted_outputs(original_question, shared.model_name)
|
||||||
|
with eval(f"generate_with_streaming({', '.join(generate_params)})") as generator:
|
||||||
|
for output in generator:
|
||||||
|
if shared.soft_prompt:
|
||||||
|
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||||
|
reply = decode(output)
|
||||||
|
|
||||||
|
if not (shared.args.chat or shared.args.cai_chat):
|
||||||
|
reply = original_question + apply_extensions(reply[len(question):], "output")
|
||||||
|
|
||||||
|
if output[-1] == n:
|
||||||
|
break
|
||||||
|
yield formatted_outputs(reply, shared.model_name)
|
||||||
|
|
||||||
|
shared.still_streaming = False
|
||||||
|
yield formatted_outputs(reply, shared.model_name)
|
||||||
|
|
||||||
|
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
|
||||||
|
else:
|
||||||
|
shared.still_streaming = True
|
||||||
|
for i in range(max_new_tokens//8+1):
|
||||||
|
clear_torch_cache()
|
||||||
|
with torch.no_grad():
|
||||||
|
output = eval(f"shared.model.generate({', '.join(generate_params)})")[0]
|
||||||
|
if shared.soft_prompt:
|
||||||
|
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||||
|
reply = decode(output)
|
||||||
|
|
||||||
|
if not (shared.args.chat or shared.args.cai_chat):
|
||||||
|
reply = original_question + apply_extensions(reply[len(question):], "output")
|
||||||
|
|
||||||
|
if np.count_nonzero(input_ids[0] == n) < np.count_nonzero(output == n):
|
||||||
|
break
|
||||||
|
yield formatted_outputs(reply, shared.model_name)
|
||||||
|
|
||||||
|
input_ids = np.reshape(output, (1, output.shape[0]))
|
||||||
|
if shared.soft_prompt:
|
||||||
|
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
||||||
|
|
||||||
|
shared.still_streaming = False
|
||||||
|
yield formatted_outputs(reply, shared.model_name)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
t1 = time.time()
|
||||||
|
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(original_input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(original_input_ids[0])} tokens)")
|
||||||
|
return
|
||||||
|
@ -3,6 +3,7 @@ bitsandbytes==0.37.0
|
|||||||
flexgen==0.1.7
|
flexgen==0.1.7
|
||||||
gradio==3.18.0
|
gradio==3.18.0
|
||||||
numpy
|
numpy
|
||||||
|
requests
|
||||||
rwkv==0.1.0
|
rwkv==0.1.0
|
||||||
safetensors==0.2.8
|
safetensors==0.2.8
|
||||||
sentencepiece
|
sentencepiece
|
||||||
|
16
server.py
16
server.py
@ -18,9 +18,6 @@ from modules.html_generator import generate_chat_html
|
|||||||
from modules.models import load_model, load_soft_prompt
|
from modules.models import load_model, load_soft_prompt
|
||||||
from modules.text_generation import generate_reply
|
from modules.text_generation import generate_reply
|
||||||
|
|
||||||
if (shared.args.chat or shared.args.cai_chat) and not shared.args.no_stream:
|
|
||||||
print('Warning: chat mode currently becomes somewhat slower with text streaming on.\nConsider starting the web UI with the --no-stream option.\n')
|
|
||||||
|
|
||||||
# Loading custom settings
|
# Loading custom settings
|
||||||
settings_file = None
|
settings_file = None
|
||||||
if shared.args.settings is not None and Path(shared.args.settings).exists():
|
if shared.args.settings is not None and Path(shared.args.settings).exists():
|
||||||
@ -272,10 +269,10 @@ if shared.args.chat or shared.args.cai_chat:
|
|||||||
|
|
||||||
function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
|
function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
|
||||||
|
|
||||||
gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream, api_name='textgen'))
|
gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=False, api_name='textgen'))
|
||||||
gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=False))
|
||||||
gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=False))
|
||||||
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
|
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=False))
|
||||||
shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events)
|
shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events)
|
||||||
|
|
||||||
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
|
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
|
||||||
@ -309,6 +306,7 @@ if shared.args.chat or shared.args.cai_chat:
|
|||||||
reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else []
|
reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else []
|
||||||
shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']])
|
shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']])
|
||||||
shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']])
|
shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']])
|
||||||
|
shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']])
|
||||||
|
|
||||||
shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None)
|
shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None)
|
||||||
shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
|
shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
|
||||||
@ -372,9 +370,9 @@ else:
|
|||||||
|
|
||||||
shared.gradio['interface'].queue()
|
shared.gradio['interface'].queue()
|
||||||
if shared.args.listen:
|
if shared.args.listen:
|
||||||
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port)
|
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
|
||||||
else:
|
else:
|
||||||
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port)
|
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
|
||||||
|
|
||||||
# I think that I will need this later
|
# I think that I will need this later
|
||||||
while True:
|
while True:
|
||||||
|
Loading…
Reference in New Issue
Block a user