Fix merge conflict in text_generation

- `shared.still_streaming` needs to be set to `False` before the final `yield formatted_outputs`; the positions of some yields were shifted accordingly.
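The ordering matters because anything that reads `shared.still_streaming` should already see `False` by the time the last chunk arrives. A rough sketch of the pattern the merged code settles on (illustrative only; the real logic is `generate_reply()` in the modules/text_generation.py diff below, and `stream_reply`/`tokens` are made-up names):

import modules.shared as shared
from modules.text_generation import formatted_outputs

def stream_reply(tokens, model_name):
    shared.still_streaming = True
    reply = ''
    for token in tokens:
        reply += token
        yield formatted_outputs(reply, model_name)  # mid-stream yields

    # Flip the flag *before* the final yield so the last chunk a consumer
    # sees already reports that streaming has finished.
    shared.still_streaming = False
    yield formatted_outputs(reply, model_name)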
Xan 2023-03-12 18:56:35 +11:00
commit b3e10e47c0
10 changed files with 298 additions and 155 deletions

.idea/workspace.xml (new file)

@@ -0,0 +1,64 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ChangeListManager">
<list default="true" id="edbf3935-4476-45aa-aea0-f1e7cbcf4b9a" name="Changes" comment="">
<change afterPath="$PROJECT_DIR$/extensions/llama_prompts/script.py" afterDir="false" />
<change afterPath="$PROJECT_DIR$/modules/callbacks.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/modules/RWKV.py" beforeDir="false" afterPath="$PROJECT_DIR$/modules/RWKV.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/modules/chat.py" beforeDir="false" afterPath="$PROJECT_DIR$/modules/chat.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/modules/shared.py" beforeDir="false" afterPath="$PROJECT_DIR$/modules/shared.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/modules/stopping_criteria.py" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/modules/text_generation.py" beforeDir="false" afterPath="$PROJECT_DIR$/modules/text_generation.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/requirements.txt" beforeDir="false" afterPath="$PROJECT_DIR$/requirements.txt" afterDir="false" />
<change beforePath="$PROJECT_DIR$/server.py" beforeDir="false" afterPath="$PROJECT_DIR$/server.py" afterDir="false" />
</list>
<option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" />
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
<option name="LAST_RESOLUTION" value="IGNORE" />
</component>
<component name="Git.Settings">
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
</component>
<component name="MarkdownSettingsMigration">
<option name="stateVersion" value="1" />
</component>
<component name="ProjectId" id="2MtdH03e5QdbSP16WYYfDkhyFUC" />
<component name="ProjectLevelVcsManager" settingsEditedManually="true" />
<component name="ProjectViewState">
<option name="showLibraryContents" value="true" />
</component>
<component name="PropertiesComponent"><![CDATA[{
"keyToString": {
"ASKED_SHARE_PROJECT_CONFIGURATION_FILES": "true",
"RunOnceActivity.OpenProjectViewOnStart": "true",
"RunOnceActivity.ShowReadmeOnStart": "true"
}
}]]></component>
<component name="RunManager">
<configuration default="true" type="JetRunConfigurationType">
<module name="text-generation-webui" />
<method v="2">
<option name="Make" enabled="true" />
</method>
</configuration>
<configuration default="true" type="KotlinStandaloneScriptRunConfigurationType">
<module name="text-generation-webui" />
<option name="filePath" />
<method v="2">
<option name="Make" enabled="true" />
</method>
</configuration>
</component>
<component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
<component name="TaskManager">
<task active="true" id="Default" summary="Default task">
<changelist id="edbf3935-4476-45aa-aea0-f1e7cbcf4b9a" name="Changes" comment="" />
<created>1678590722207</created>
<option name="number" value="Default" />
<option name="presentableId" value="Default" />
<updated>1678590722207</updated>
</task>
<servers />
</component>
</project>

extensions/llama_prompts/script.py (new file)

@@ -0,0 +1,18 @@
import gradio as gr
import modules.shared as shared
import pandas as pd

df = pd.read_csv("https://raw.githubusercontent.com/devbrones/llama-prompts/main/prompts/prompts.csv")

def get_prompt_by_name(name):
    if name == 'None':
        return ''
    else:
        return df[df['Prompt name'] == name].iloc[0]['Prompt'].replace('\\n', '\n')

def ui():
    if not shared.args.chat or shared.args.cai_chat:
        choices = ['None'] + list(df['Prompt name'])

        prompts_menu = gr.Dropdown(value=choices[0], choices=choices, label='Prompt')
        prompts_menu.change(get_prompt_by_name, prompts_menu, shared.gradio['textbox'])

modules/RWKV.py

@@ -7,6 +7,7 @@ import numpy as np
 from tokenizers import Tokenizer

 import modules.shared as shared
+from modules.callbacks import Iteratorize

 np.set_printoptions(precision=4, suppress=True, linewidth=200)
@@ -49,9 +50,9 @@ class RWKVModel:
         return context+self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)

     def generate_with_streaming(self, **kwargs):
-        iterable = Iteratorize(self.generate, kwargs, callback=None)
-        reply = kwargs['context']
-        for token in iterable:
-            reply += token
-            yield reply
+        with Iteratorize(self.generate, kwargs, callback=None) as generator:
+            reply = kwargs['context']
+            for token in generator:
+                reply += token
+                yield reply
@@ -73,38 +74,3 @@ class RWKVTokenizer:
     def decode(self, ids):
         return self.tokenizer.decode(ids)
-
-class Iteratorize:
-
-    """
-    Transforms a function that takes a callback
-    into a lazy iterator (generator).
-    """
-
-    def __init__(self, func, kwargs={}, callback=None):
-        self.mfunc=func
-        self.c_callback=callback
-        self.q = Queue(maxsize=1)
-        self.sentinel = object()
-        self.kwargs = kwargs
-
-        def _callback(val):
-            self.q.put(val)
-
-        def gentask():
-            ret = self.mfunc(callback=_callback, **self.kwargs)
-            self.q.put(self.sentinel)
-            if self.c_callback:
-                self.c_callback(ret)
-
-        Thread(target=gentask).start()
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        obj = self.q.get(True,None)
-        if obj is self.sentinel:
-            raise StopIteration
-        else:
-            return obj

modules/callbacks.py (new file)

@@ -0,0 +1,98 @@
import gc
from queue import Queue
from threading import Thread

import torch
import transformers

import modules.shared as shared


# Copied from https://github.com/PygmalionAI/gradio-ui/
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):

    def __init__(self, sentinel_token_ids: torch.LongTensor,
                 starting_idx: int):
        transformers.StoppingCriteria.__init__(self)
        self.sentinel_token_ids = sentinel_token_ids
        self.starting_idx = starting_idx

    def __call__(self, input_ids: torch.LongTensor,
                 _scores: torch.FloatTensor) -> bool:
        for sample in input_ids:
            trimmed_sample = sample[self.starting_idx:]
            # Can't unfold, output is still too tiny. Skip.
            if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]:
                continue

            for window in trimmed_sample.unfold(
                    0, self.sentinel_token_ids.shape[-1], 1):
                if torch.all(torch.eq(self.sentinel_token_ids, window)):
                    return True
        return False


class Stream(transformers.StoppingCriteria):
    def __init__(self, callback_func=None):
        self.callback_func = callback_func

    def __call__(self, input_ids, scores) -> bool:
        if self.callback_func is not None:
            self.callback_func(input_ids[0])
        return False


class Iteratorize:

    """
    Transforms a function that takes a callback
    into a lazy iterator (generator).
    """

    def __init__(self, func, kwargs={}, callback=None):
        self.mfunc=func
        self.c_callback=callback
        self.q = Queue()
        self.sentinel = object()
        self.kwargs = kwargs
        self.stop_now = False

        def _callback(val):
            if self.stop_now:
                raise ValueError
            self.q.put(val)

        def gentask():
            try:
                ret = self.mfunc(callback=_callback, **self.kwargs)
            except ValueError:
                pass

            clear_torch_cache()
            self.q.put(self.sentinel)
            if self.c_callback:
                self.c_callback(ret)

        self.thread = Thread(target=gentask)
        self.thread.start()

    def __iter__(self):
        return self

    def __next__(self):
        obj = self.q.get(True,None)
        if obj is self.sentinel:
            raise StopIteration
        else:
            return obj

    def __del__(self):
        clear_torch_cache()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop_now = True
        clear_torch_cache()


def clear_torch_cache():
    gc.collect()
    if not shared.args.cpu:
        torch.cuda.empty_cache()
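`Iteratorize` is what turns a callback-style `generate` call into a plain generator; both `RWKVModel.generate_with_streaming` above and the new streaming path in modules/text_generation.py lean on it. A rough usage sketch (the `count_up` helper is invented for illustration, and since `clear_torch_cache` consults `shared.args`, this assumes the webui's shared state is already initialized):

from modules.callbacks import Iteratorize

def count_up(callback=None, limit=3):
    # Stand-in for a callback-based model.generate(): report progress via callback.
    for i in range(limit):
        callback(i)          # each value is queued by Iteratorize...
    return limit

with Iteratorize(count_up, {'limit': 3}, callback=None) as it:
    for value in it:         # ...and consumed lazily here: prints 0, 1, 2
        print(value)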

modules/chat.py

@@ -84,6 +84,7 @@ def extract_message_from_reply(question, reply, name1, name2, check, impersonate
            tmp = f"\n{asker}:"
            for j in range(1, len(tmp)):
                if reply[-j:] == tmp[:j]:
+                    reply = reply[:-j]
                    substring_found = True

    return reply, next_character_found, substring_found
@@ -91,7 +92,7 @@ def extract_message_from_reply(question, reply, name1, name2, check, impersonate
def stop_everything_event():
    shared.stop_everything = True

-def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
+def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
    shared.stop_everything = False
    just_started = True
    eos_token = '\n' if check else None
@@ -120,6 +121,10 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
    else:
        prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)

+    if not regenerate:
+        # Display user input and "*is typing...*" imediately
+        yield shared.history['visible']+[[visible_text, '*Is typing...*']]
+
    # Generate
    reply = ''
    for i in range(chat_generation_attempts):
@@ -158,6 +163,9 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
    prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True)

+    # Display "*is typing...*" imediately
+    yield '*Is typing...*'
+
    reply = ''
    for i in range(chat_generation_attempts):
        for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
@@ -182,7 +190,7 @@ def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typi
        last_visible = shared.history['visible'].pop()
        last_internal = shared.history['internal'].pop()

-        for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts):
+        for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True):
            if shared.args.cai_chat:
                shared.history['visible'][-1] = [last_visible[0], _history[-1][1]]
                yield generate_chat_html(shared.history['visible'], name1, name2, shared.character)
@@ -291,7 +299,7 @@ def save_history(timestamp=True):
        fname = f"{prefix}persistent.json"
    if not Path('logs').exists():
        Path('logs').mkdir()
-    with open(Path(f'logs/{fname}'), 'w') as f:
+    with open(Path(f'logs/{fname}'), 'w', encoding='utf-8') as f:
        f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
    return Path(f'logs/{fname}')
@@ -332,7 +340,7 @@ def load_character(_character, name1, name2):
    shared.history['visible'] = []
    if _character != 'None':
        shared.character = _character
-        data = json.loads(open(Path(f'characters/{_character}.json'), 'r').read())
+        data = json.loads(open(Path(f'characters/{_character}.json'), 'r', encoding='utf-8').read())
        name2 = data['char_name']
        if 'char_persona' in data and data['char_persona'] != '':
            context += f"{data['char_name']}'s Persona: {data['char_persona']}\n"
@@ -372,7 +380,7 @@ def upload_character(json_file, img, tavern=False):
        i += 1
    if tavern:
        outfile_name = f'TavernAI-{outfile_name}'
-    with open(Path(f'characters/{outfile_name}.json'), 'w') as f:
+    with open(Path(f'characters/{outfile_name}.json'), 'w', encoding='utf-8') as f:
        f.write(json_file)
    if img is not None:
        img = Image.open(io.BytesIO(img))

modules/shared.py

@@ -91,4 +91,5 @@ parser.add_argument('--listen', action='store_true', help='Make the web UI reach
 parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.')
 parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')
 parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
+parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch')
 args = parser.parse_args()

modules/stopping_criteria.py (deleted file)

@@ -1,32 +0,0 @@
'''
This code was copied from
https://github.com/PygmalionAI/gradio-ui/
'''

import torch
import transformers


class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):

    def __init__(self, sentinel_token_ids: torch.LongTensor,
                 starting_idx: int):
        transformers.StoppingCriteria.__init__(self)
        self.sentinel_token_ids = sentinel_token_ids
        self.starting_idx = starting_idx

    def __call__(self, input_ids: torch.LongTensor,
                 _scores: torch.FloatTensor) -> bool:
        for sample in input_ids:
            trimmed_sample = sample[self.starting_idx:]
            # Can't unfold, output is still too tiny. Skip.
            if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]:
                continue

            for window in trimmed_sample.unfold(
                    0, self.sentinel_token_ids.shape[-1], 1):
                if torch.all(torch.eq(self.sentinel_token_ids, window)):
                    return True
        return False

modules/text_generation.py

@@ -5,13 +5,13 @@ import time
 import numpy as np
 import torch
 import transformers
-from tqdm import tqdm

 import modules.shared as shared
+from modules.callbacks import (Iteratorize, Stream,
+                               _SentinelTokenStoppingCriteria)
 from modules.extensions import apply_extensions
 from modules.html_generator import generate_4chan_html, generate_basic_html
 from modules.models import local_rank
-from modules.stopping_criteria import _SentinelTokenStoppingCriteria

 def get_max_prompt_length(tokens):
@@ -92,6 +92,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
     # These models are not part of Hugging Face, so we handle them
     # separately and terminate the function call earlier
     if shared.is_RWKV:
+        try:
             if shared.args.no_stream:
                 reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
                 yield formatted_outputs(reply, shared.model_name)
@@ -101,9 +102,11 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
                 # No need to generate 8 tokens at a time.
                 for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
                     yield formatted_outputs(reply, shared.model_name)
+        finally:
             t1 = time.time()
-        print(f"Output generated in {(t1-t0):.2f} seconds.")
+            output = encode(reply)[0]
+            input_ids = encode(question)
+            print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(input_ids[0])} tokens)")
             return

     original_question = question
@@ -113,23 +116,19 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
         print(f"\n\n{question}\n--------------------\n")

     input_ids = encode(question, max_new_tokens)
+    original_input_ids = input_ids
+    output = input_ids[0]
     cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()"
     n = shared.tokenizer.eos_token_id if eos_token is None else int(encode(eos_token)[0][-1])
+    stopping_criteria_list = transformers.StoppingCriteriaList()
     if stopping_string is not None:
-        # The stopping_criteria code below was copied from
-        # https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
+        # Copied from https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
         t = encode(stopping_string, 0, add_special_tokens=False)
-        stopping_criteria_list = transformers.StoppingCriteriaList([
-            _SentinelTokenStoppingCriteria(
-                sentinel_token_ids=t,
-                starting_idx=len(input_ids[0])
-            )
-        ])
-    else:
-        stopping_criteria_list = None
+        stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))

     if not shared.args.flexgen:
         generate_params = [
+            f"max_new_tokens=max_new_tokens",
             f"eos_token_id={n}",
             f"stopping_criteria=stopping_criteria_list",
             f"do_sample={do_sample}",
@@ -147,24 +146,22 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
         ]
     else:
         generate_params = [
+            f"max_new_tokens={max_new_tokens if shared.args.no_stream else 8}",
             f"do_sample={do_sample}",
             f"temperature={temperature}",
             f"stop={n}",
         ]
     if shared.args.deepspeed:
         generate_params.append("synced_gpus=True")
-    if shared.args.no_stream:
-        generate_params.append("max_new_tokens=max_new_tokens")
-    else:
-        generate_params.append("max_new_tokens=8")
     if shared.soft_prompt:
         inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
         generate_params.insert(0, "inputs_embeds=inputs_embeds")
-        generate_params.insert(0, "filler_input_ids")
+        generate_params.insert(0, "inputs=filler_input_ids")
     else:
-        generate_params.insert(0, "input_ids")
+        generate_params.insert(0, "inputs=input_ids")

-    # Generate the entire reply at once
-    if shared.args.no_stream:
-        with torch.no_grad():
-            output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
+    try:
+        # Generate the entire reply at once.
+        if shared.args.no_stream:
+            with torch.no_grad():
+                output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
@@ -175,41 +172,65 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
             if not (shared.args.chat or shared.args.cai_chat):
                 reply = original_question + apply_extensions(reply[len(question):], "output")
-        t1 = time.time()
-        print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output)-len(input_ids[0])} tokens)")
             yield formatted_outputs(reply, shared.model_name)

-    # Generate the reply 8 tokens at a time
-    else:
-        yield formatted_outputs(original_question, shared.model_name)
-        shared.still_streaming = True
-        for i in tqdm(range(max_new_tokens//8+1)):
-            clear_torch_cache()
-            with torch.no_grad():
-                output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
-            if shared.soft_prompt:
-                output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
-            reply = decode(output)
-            if not (shared.args.chat or shared.args.cai_chat):
-                reply = original_question + apply_extensions(reply[len(question):], "output")
-            if not shared.args.flexgen:
-                if output[-1] == n:
-                    break
-                input_ids = torch.reshape(output, (1, output.shape[0]))
-            else:
-                if np.count_nonzero(input_ids[0] == n) < np.count_nonzero(output == n):
-                    break
-                input_ids = np.reshape(output, (1, output.shape[0]))
-
-            #Mid-stream yield, ran if no breaks
-            yield formatted_outputs(reply, shared.model_name)
-            if shared.soft_prompt:
-                inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
-
-        #Stream finished from max tokens or break. Do final yield.
-        shared.still_streaming = False
-        yield formatted_outputs(reply, shared.model_name)

+        # Stream the reply 1 token at a time.
+        # This is based on the trick of using 'stopping_criteria' to create an iterator.
+        elif not shared.args.flexgen:
+            def generate_with_callback(callback=None, **kwargs):
+                kwargs['stopping_criteria'].append(Stream(callback_func=callback))
+                clear_torch_cache()
+                with torch.no_grad():
+                    shared.model.generate(**kwargs)
+
+            def generate_with_streaming(**kwargs):
+                return Iteratorize(generate_with_callback, kwargs, callback=None)
+
+            shared.still_streaming = True
+            yield formatted_outputs(original_question, shared.model_name)
+            with eval(f"generate_with_streaming({', '.join(generate_params)})") as generator:
+                for output in generator:
+                    if shared.soft_prompt:
+                        output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
+                    reply = decode(output)
+                    if not (shared.args.chat or shared.args.cai_chat):
+                        reply = original_question + apply_extensions(reply[len(question):], "output")
+                    if output[-1] == n:
+                        break
+                    yield formatted_outputs(reply, shared.model_name)
+
+            shared.still_streaming = False
+            yield formatted_outputs(reply, shared.model_name)

+        # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
+        else:
+            shared.still_streaming = True
+            for i in range(max_new_tokens//8+1):
+                clear_torch_cache()
+                with torch.no_grad():
+                    output = eval(f"shared.model.generate({', '.join(generate_params)})")[0]
+                if shared.soft_prompt:
+                    output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
+                reply = decode(output)
+                if not (shared.args.chat or shared.args.cai_chat):
+                    reply = original_question + apply_extensions(reply[len(question):], "output")
+                if np.count_nonzero(input_ids[0] == n) < np.count_nonzero(output == n):
+                    break
+                yield formatted_outputs(reply, shared.model_name)
+                input_ids = np.reshape(output, (1, output.shape[0]))
+                if shared.soft_prompt:
+                    inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
+
+            #Stream finished from max tokens or break. Do final yield.
+            shared.still_streaming = False
+            yield formatted_outputs(reply, shared.model_name)
+
+    finally:
+        t1 = time.time()
+        print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(original_input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(original_input_ids[0])} tokens)")
+        return
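The 'stopping_criteria' trick referenced in the comments above amounts to registering a criterion that never stops generation but hands the partial output to a callback after every step; `Iteratorize` then exposes that callback as the generator `generate_reply` iterates over. A standalone sketch under those assumptions (`stream_tokens` is an illustrative name, not a function from the repo):

import torch
import transformers

from modules.callbacks import Iteratorize, Stream

def stream_tokens(model, input_ids, max_new_tokens=200):
    def _generate(callback=None):
        # Stream() always returns False, so it never stops generation;
        # it only forwards the tokens generated so far to the callback.
        criteria = transformers.StoppingCriteriaList([Stream(callback_func=callback)])
        with torch.no_grad():
            model.generate(inputs=input_ids, max_new_tokens=max_new_tokens,
                           stopping_criteria=criteria)

    # Iteratorize runs _generate in a thread and yields each partial output.
    return Iteratorize(_generate, callback=None)

# Usage (shared.model, input_ids and decode() assumed to exist as in the module above):
# with stream_tokens(shared.model, input_ids) as generator:
#     for output in generator:
#         print(decode(output))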

requirements.txt

@@ -3,6 +3,7 @@ bitsandbytes==0.37.0
 flexgen==0.1.7
 gradio==3.18.0
 numpy
+requests
 rwkv==0.1.0
 safetensors==0.2.8
 sentencepiece

server.py

@@ -18,9 +18,6 @@ from modules.html_generator import generate_chat_html
 from modules.models import load_model, load_soft_prompt
 from modules.text_generation import generate_reply

-if (shared.args.chat or shared.args.cai_chat) and not shared.args.no_stream:
-    print('Warning: chat mode currently becomes somewhat slower with text streaming on.\nConsider starting the web UI with the --no-stream option.\n')
-
 # Loading custom settings
 settings_file = None
 if shared.args.settings is not None and Path(shared.args.settings).exists():
@@ -272,10 +269,10 @@ if shared.args.chat or shared.args.cai_chat:
         function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'

-        gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream, api_name='textgen'))
-        gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
-        gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
-        gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
+        gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=False, api_name='textgen'))
+        gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=False))
+        gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=False))
+        gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=False))
         shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events)
         shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
@@ -309,6 +306,7 @@ if shared.args.chat or shared.args.cai_chat:
         reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else []
         shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']])
         shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']])
+        shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']])

         shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None)
         shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
@@ -372,9 +370,9 @@ else:
 shared.gradio['interface'].queue()
 if shared.args.listen:
-    shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port)
+    shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
 else:
-    shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port)
+    shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)

 # I think that I will need this later
 while True: