diff --git a/modules/callbacks.py b/modules/callbacks.py
new file mode 100644
index 00000000..15674b8a
--- /dev/null
+++ b/modules/callbacks.py
@@ -0,0 +1,75 @@
+from queue import Queue
+from threading import Thread
+
+import torch
+import transformers
+
+import modules.shared as shared
+
+
+# Copied from https://github.com/PygmalionAI/gradio-ui/
+class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
+
+    def __init__(self, sentinel_token_ids: torch.LongTensor,
+                 starting_idx: int):
+        transformers.StoppingCriteria.__init__(self)
+        self.sentinel_token_ids = sentinel_token_ids
+        self.starting_idx = starting_idx
+
+    def __call__(self, input_ids: torch.LongTensor,
+                 _scores: torch.FloatTensor) -> bool:
+        for sample in input_ids:
+            trimmed_sample = sample[self.starting_idx:]
+            # Can't unfold, output is still too tiny. Skip.
+            if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]:
+                continue
+
+            for window in trimmed_sample.unfold(
+                    0, self.sentinel_token_ids.shape[-1], 1):
+                if torch.all(torch.eq(self.sentinel_token_ids, window)):
+                    return True
+        return False
+
+class Stream(transformers.StoppingCriteria):
+    def __init__(self, callback_func=None):
+        self.callback_func = callback_func
+
+    def __call__(self, input_ids, scores) -> bool:
+        if self.callback_func is not None:
+            self.callback_func(input_ids[0])
+        return False
+
+class Iteratorize:
+
+    """
+    Transforms a function that takes a callback
+    into a lazy iterator (generator).
+    """
+
+    def __init__(self, func, kwargs={}, callback=None):
+        self.mfunc=func
+        self.c_callback=callback
+        self.q = Queue(maxsize=1)
+        self.sentinel = object()
+        self.kwargs = kwargs
+
+        def _callback(val):
+            self.q.put(val)
+
+        def gentask():
+            ret = self.mfunc(callback=_callback, **self.kwargs)
+            self.q.put(self.sentinel)
+            if self.c_callback:
+                self.c_callback(ret)
+
+        Thread(target=gentask).start()
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        obj = self.q.get(True,None)
+        if obj is self.sentinel:
+            raise StopIteration
+        else:
+            return obj
diff --git a/modules/stopping_criteria.py b/modules/stopping_criteria.py
deleted file mode 100644
index 44a631b3..00000000
--- a/modules/stopping_criteria.py
+++ /dev/null
@@ -1,32 +0,0 @@
-'''
-This code was copied from
-
-https://github.com/PygmalionAI/gradio-ui/
-
-'''
-
-import torch
-import transformers
-
-
-class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
-
-    def __init__(self, sentinel_token_ids: torch.LongTensor,
-                 starting_idx: int):
-        transformers.StoppingCriteria.__init__(self)
-        self.sentinel_token_ids = sentinel_token_ids
-        self.starting_idx = starting_idx
-
-    def __call__(self, input_ids: torch.LongTensor,
-                 _scores: torch.FloatTensor) -> bool:
-        for sample in input_ids:
-            trimmed_sample = sample[self.starting_idx:]
-            # Can't unfold, output is still too tiny. Skip.
-            if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]:
-                continue
-
-            for window in trimmed_sample.unfold(
-                    0, self.sentinel_token_ids.shape[-1], 1):
-                if torch.all(torch.eq(self.sentinel_token_ids, window)):
-                    return True
-        return False
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 4af53273..436afbeb 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -5,13 +5,13 @@ import time
 import numpy as np
 import torch
 import transformers
-from tqdm import tqdm
 
 import modules.shared as shared
+from modules.callbacks import (Iteratorize, Stream,
+                               _SentinelTokenStoppingCriteria)
 from modules.extensions import apply_extensions
 from modules.html_generator import generate_4chan_html, generate_basic_html
 from modules.models import local_rank
-from modules.stopping_criteria import _SentinelTokenStoppingCriteria
 
 
 def get_max_prompt_length(tokens):
@@ -103,7 +103,9 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
                 yield formatted_outputs(reply, shared.model_name)
 
         t1 = time.time()
-        print(f"Output generated in {(t1-t0):.2f} seconds.")
+        output = encode(reply)[0]
+        input_ids = encode(question)
+        print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(input_ids[0])} tokens)")
         return
 
     original_question = question
@@ -113,6 +115,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
         print(f"\n\n{question}\n--------------------\n")
 
     input_ids = encode(question, max_new_tokens)
+    original_input_ids = input_ids
     cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()"
     n = shared.tokenizer.eos_token_id if eos_token is None else int(encode(eos_token)[0][-1])
     if stopping_string is not None:
@@ -126,10 +129,11 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
             )
         ])
     else:
-        stopping_criteria_list = None
+        stopping_criteria_list = []
 
     if not shared.args.flexgen:
         generate_params = [
+            f"max_new_tokens=max_new_tokens",
             f"eos_token_id={n}",
             f"stopping_criteria=stopping_criteria_list",
             f"do_sample={do_sample}",
@@ -147,24 +151,21 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
         ]
     else:
         generate_params = [
+            f"max_new_tokens={max_new_tokens if shared.args.no_stream else 8}",
             f"do_sample={do_sample}",
             f"temperature={temperature}",
             f"stop={n}",
         ]
     if shared.args.deepspeed:
         generate_params.append("synced_gpus=True")
-    if shared.args.no_stream:
-        generate_params.append("max_new_tokens=max_new_tokens")
-    else:
-        generate_params.append("max_new_tokens=8")
     if shared.soft_prompt:
         inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
         generate_params.insert(0, "inputs_embeds=inputs_embeds")
-        generate_params.insert(0, "filler_input_ids")
+        generate_params.insert(0, "inputs=filler_input_ids")
     else:
-        generate_params.insert(0, "input_ids")
+        generate_params.insert(0, "inputs=input_ids")
 
-    # Generate the entire reply at once
+    # Generate the entire reply at once.
     if shared.args.no_stream:
         with torch.no_grad():
             output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
@@ -175,18 +176,45 @@
         if not (shared.args.chat or shared.args.cai_chat):
             reply = original_question + apply_extensions(reply[len(question):], "output")
 
-        t1 = time.time()
-        print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output)-len(input_ids[0])} tokens)")
         yield formatted_outputs(reply, shared.model_name)
 
-    # Generate the reply 8 tokens at a time
-    else:
+    # Stream the reply 1 token at a time.
+    # This is based on the trick of using 'stopping_criteria' to create an iterator.
+    elif not shared.args.flexgen:
+
+        def generate_with_callback(callback=None, **kwargs):
+            if 'stopping_criteria' not in kwargs:
+                kwargs['stopping_criteria'] = []
+            kwargs['stopping_criteria'].append(Stream(callback_func=callback))
+            shared.model.generate(**kwargs)[0]
+
+        def generate_with_streaming(**kwargs):
+            return Iteratorize(generate_with_callback, kwargs, callback=None)
+
         yield formatted_outputs(original_question, shared.model_name)
 
-        for i in tqdm(range(max_new_tokens//8+1)):
+        for output in eval(f"generate_with_streaming({', '.join(generate_params)})"):
+            if shared.soft_prompt:
+                output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
+
+            reply = decode(output)
+            if not (shared.args.chat or shared.args.cai_chat):
+                reply = original_question + apply_extensions(reply[len(question):], "output")
+            yield formatted_outputs(reply, shared.model_name)
+
+            if not shared.args.flexgen:
+                if output[-1] == n:
+                    break
+            else:
+                if np.count_nonzero(input_ids[0] == n) < np.count_nonzero(output == n):
+                    break
+
+    # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
+    else:
+        for i in range(max_new_tokens//8+1):
             clear_torch_cache()
             with torch.no_grad():
-                output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
+                output = eval(f"shared.model.generate({', '.join(generate_params)})")[0]
             if shared.soft_prompt:
                 output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
 
@@ -206,3 +234,7 @@
 
             if shared.soft_prompt:
                 inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
+
+    t1 = time.time()
+    print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(original_input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(original_input_ids[0])} tokens)")
+    return
diff --git a/server.py b/server.py
index 9f584ba3..42897b0b 100644
--- a/server.py
+++ b/server.py
@@ -18,9 +18,6 @@ from modules.html_generator import generate_chat_html
 from modules.models import load_model, load_soft_prompt
 from modules.text_generation import generate_reply
 
-if (shared.args.chat or shared.args.cai_chat) and not shared.args.no_stream:
-    print('Warning: chat mode currently becomes somewhat slower with text streaming on.\nConsider starting the web UI with the --no-stream option.\n')
-
# Loading custom settings
 settings_file = None
 if shared.args.settings is not None and Path(shared.args.settings).exists():
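
Reviewer note, not part of the patch: the Iteratorize class added above turns a blocking, callback-reporting generate() call into a lazy iterator by running it on a worker thread and handing partial outputs through a bounded queue. Below is a minimal, self-contained sketch of that same pattern under assumed names: slow_generate and CallbackIterator are hypothetical stand-ins for shared.model.generate and Iteratorize, not code from this PR.

# Hedged sketch of the callback-to-iterator trick (stand-in names, same mechanics).
from queue import Queue
from threading import Thread
import time


def slow_generate(callback=None, n_tokens=5):
    # Stand-in for a blocking generate() that reports partial results via a callback.
    tokens = []
    for i in range(n_tokens):
        time.sleep(0.1)  # simulate per-token latency
        tokens.append(i)
        if callback is not None:
            callback(list(tokens))
    return tokens


class CallbackIterator:
    """Run a callback-reporting function on a worker thread and expose its
    partial results as a lazy iterator (the idea behind Iteratorize)."""

    def __init__(self, func, **kwargs):
        self.q = Queue(maxsize=1)  # bounded, so the producer waits for the consumer
        self.sentinel = object()

        def _callback(val):
            self.q.put(val)

        def task():
            func(callback=_callback, **kwargs)
            self.q.put(self.sentinel)

        Thread(target=task, daemon=True).start()

    def __iter__(self):
        return self

    def __next__(self):
        obj = self.q.get()
        if obj is self.sentinel:
            raise StopIteration
        return obj


if __name__ == '__main__':
    for partial in CallbackIterator(slow_generate, n_tokens=5):
        print('partial output:', partial)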
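
Reviewer note, not part of the patch: _SentinelTokenStoppingCriteria (now in modules/callbacks.py) detects a stop string by sliding a window over the newly generated token ids with Tensor.unfold. A tiny standalone check of that windowing idea, using made-up token ids; torch.equal is used here as a shorthand for the torch.all(torch.eq(...)) comparison in the patch.

# Hedged sketch of the sliding-window stop-sequence match (illustrative token ids).
import torch

sentinel = torch.tensor([5, 6])            # ids of the stop sequence
generated = torch.tensor([1, 2, 5, 6, 9])  # tokens produced after starting_idx

windows = generated.unfold(0, sentinel.shape[-1], 1)   # every length-2 window
found = any(torch.equal(sentinel, w) for w in windows)
print(found)  # True: the window [5, 6] occurs in 'generated'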