From ab50f80542788dd7fa21b20ac91fddc8c9766c23 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Wed, 8 Mar 2023 02:46:35 -0300
Subject: [PATCH 01/11] New text streaming method (much faster)
---
modules/callbacks.py | 75 ++++++++++++++++++++++++++++++++++++
modules/stopping_criteria.py | 32 ---------------
modules/text_generation.py | 66 +++++++++++++++++++++++--------
server.py | 3 --
4 files changed, 124 insertions(+), 52 deletions(-)
create mode 100644 modules/callbacks.py
delete mode 100644 modules/stopping_criteria.py
diff --git a/modules/callbacks.py b/modules/callbacks.py
new file mode 100644
index 00000000..15674b8a
--- /dev/null
+++ b/modules/callbacks.py
@@ -0,0 +1,75 @@
+from queue import Queue
+from threading import Thread
+
+import torch
+import transformers
+
+import modules.shared as shared
+
+
+# Copied from https://github.com/PygmalionAI/gradio-ui/
+class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
+
+ def __init__(self, sentinel_token_ids: torch.LongTensor,
+ starting_idx: int):
+ transformers.StoppingCriteria.__init__(self)
+ self.sentinel_token_ids = sentinel_token_ids
+ self.starting_idx = starting_idx
+
+ def __call__(self, input_ids: torch.LongTensor,
+ _scores: torch.FloatTensor) -> bool:
+ for sample in input_ids:
+ trimmed_sample = sample[self.starting_idx:]
+ # Can't unfold, output is still too tiny. Skip.
+ if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]:
+ continue
+
+ for window in trimmed_sample.unfold(
+ 0, self.sentinel_token_ids.shape[-1], 1):
+ if torch.all(torch.eq(self.sentinel_token_ids, window)):
+ return True
+ return False
+
+class Stream(transformers.StoppingCriteria):
+ def __init__(self, callback_func=None):
+ self.callback_func = callback_func
+
+ def __call__(self, input_ids, scores) -> bool:
+ if self.callback_func is not None:
+ self.callback_func(input_ids[0])
+ return False
+
+class Iteratorize:
+
+ """
+ Transforms a function that takes a callback
+ into a lazy iterator (generator).
+ """
+
+ def __init__(self, func, kwargs={}, callback=None):
+ self.mfunc=func
+ self.c_callback=callback
+ self.q = Queue(maxsize=1)
+ self.sentinel = object()
+ self.kwargs = kwargs
+
+ def _callback(val):
+ self.q.put(val)
+
+ def gentask():
+ ret = self.mfunc(callback=_callback, **self.kwargs)
+ self.q.put(self.sentinel)
+ if self.c_callback:
+ self.c_callback(ret)
+
+ Thread(target=gentask).start()
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ obj = self.q.get(True,None)
+ if obj is self.sentinel:
+ raise StopIteration
+ else:
+ return obj
diff --git a/modules/stopping_criteria.py b/modules/stopping_criteria.py
deleted file mode 100644
index 44a631b3..00000000
--- a/modules/stopping_criteria.py
+++ /dev/null
@@ -1,32 +0,0 @@
-'''
-This code was copied from
-
-https://github.com/PygmalionAI/gradio-ui/
-
-'''
-
-import torch
-import transformers
-
-
-class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
-
- def __init__(self, sentinel_token_ids: torch.LongTensor,
- starting_idx: int):
- transformers.StoppingCriteria.__init__(self)
- self.sentinel_token_ids = sentinel_token_ids
- self.starting_idx = starting_idx
-
- def __call__(self, input_ids: torch.LongTensor,
- _scores: torch.FloatTensor) -> bool:
- for sample in input_ids:
- trimmed_sample = sample[self.starting_idx:]
- # Can't unfold, output is still too tiny. Skip.
- if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]:
- continue
-
- for window in trimmed_sample.unfold(
- 0, self.sentinel_token_ids.shape[-1], 1):
- if torch.all(torch.eq(self.sentinel_token_ids, window)):
- return True
- return False
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 4af53273..436afbeb 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -5,13 +5,13 @@ import time
import numpy as np
import torch
import transformers
-from tqdm import tqdm
import modules.shared as shared
+from modules.callbacks import (Iteratorize, Stream,
+ _SentinelTokenStoppingCriteria)
from modules.extensions import apply_extensions
from modules.html_generator import generate_4chan_html, generate_basic_html
from modules.models import local_rank
-from modules.stopping_criteria import _SentinelTokenStoppingCriteria
def get_max_prompt_length(tokens):
@@ -103,7 +103,9 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
yield formatted_outputs(reply, shared.model_name)
t1 = time.time()
- print(f"Output generated in {(t1-t0):.2f} seconds.")
+ output = encode(reply)[0]
+ input_ids = encode(question)
+ print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(input_ids[0])} tokens)")
return
original_question = question
@@ -113,6 +115,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
print(f"\n\n{question}\n--------------------\n")
input_ids = encode(question, max_new_tokens)
+ original_input_ids = input_ids
cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()"
n = shared.tokenizer.eos_token_id if eos_token is None else int(encode(eos_token)[0][-1])
if stopping_string is not None:
@@ -126,10 +129,11 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
)
])
else:
- stopping_criteria_list = None
+ stopping_criteria_list = []
if not shared.args.flexgen:
generate_params = [
+ f"max_new_tokens=max_new_tokens",
f"eos_token_id={n}",
f"stopping_criteria=stopping_criteria_list",
f"do_sample={do_sample}",
@@ -147,24 +151,21 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
]
else:
generate_params = [
+ f"max_new_tokens={max_new_tokens if shared.args.no_stream else 8}",
f"do_sample={do_sample}",
f"temperature={temperature}",
f"stop={n}",
]
if shared.args.deepspeed:
generate_params.append("synced_gpus=True")
- if shared.args.no_stream:
- generate_params.append("max_new_tokens=max_new_tokens")
- else:
- generate_params.append("max_new_tokens=8")
if shared.soft_prompt:
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
generate_params.insert(0, "inputs_embeds=inputs_embeds")
- generate_params.insert(0, "filler_input_ids")
+ generate_params.insert(0, "inputs=filler_input_ids")
else:
- generate_params.insert(0, "input_ids")
+ generate_params.insert(0, "inputs=input_ids")
- # Generate the entire reply at once
+ # Generate the entire reply at once.
if shared.args.no_stream:
with torch.no_grad():
output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
@@ -175,18 +176,45 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
if not (shared.args.chat or shared.args.cai_chat):
reply = original_question + apply_extensions(reply[len(question):], "output")
- t1 = time.time()
- print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output)-len(input_ids[0])} tokens)")
yield formatted_outputs(reply, shared.model_name)
- # Generate the reply 8 tokens at a time
- else:
+ # Stream the reply 1 token at a time.
+ # This is based on the trick of using 'stopping_criteria' to create an iterator.
+ elif not shared.args.flexgen:
+
+ def generate_with_callback(callback=None, **kwargs):
+ if 'stopping_criteria' not in kwargs:
+ kwargs['stopping_criteria'] = []
+ kwargs['stopping_criteria'].append(Stream(callback_func=callback))
+ shared.model.generate(**kwargs)[0]
+
+ def generate_with_streaming(**kwargs):
+ return Iteratorize(generate_with_callback, kwargs, callback=None)
+
yield formatted_outputs(original_question, shared.model_name)
- for i in tqdm(range(max_new_tokens//8+1)):
+ for output in eval(f"generate_with_streaming({', '.join(generate_params)})"):
+ if shared.soft_prompt:
+ output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
+
+ reply = decode(output)
+ if not (shared.args.chat or shared.args.cai_chat):
+ reply = original_question + apply_extensions(reply[len(question):], "output")
+ yield formatted_outputs(reply, shared.model_name)
+
+ if not shared.args.flexgen:
+ if output[-1] == n:
+ break
+ else:
+ if np.count_nonzero(input_ids[0] == n) < np.count_nonzero(output == n):
+ break
+
+ # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
+ else:
+ for i in range(max_new_tokens//8+1):
clear_torch_cache()
with torch.no_grad():
- output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
+ output = eval(f"shared.model.generate({', '.join(generate_params)})")[0]
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
@@ -206,3 +234,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
if shared.soft_prompt:
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
+
+ t1 = time.time()
+ print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(original_input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(original_input_ids[0])} tokens)")
+ return
diff --git a/server.py b/server.py
index 9f584ba3..42897b0b 100644
--- a/server.py
+++ b/server.py
@@ -18,9 +18,6 @@ from modules.html_generator import generate_chat_html
from modules.models import load_model, load_soft_prompt
from modules.text_generation import generate_reply
-if (shared.args.chat or shared.args.cai_chat) and not shared.args.no_stream:
- print('Warning: chat mode currently becomes somewhat slower with text streaming on.\nConsider starting the web UI with the --no-stream option.\n')
-
# Loading custom settings
settings_file = None
if shared.args.settings is not None and Path(shared.args.settings).exists():
From 0e16c0bacb88ad0f5420fd2aa2c6cfadf38e2579 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Wed, 8 Mar 2023 02:50:49 -0300
Subject: [PATCH 02/11] Remove redeclaration of a function
---
modules/RWKV.py | 36 +-----------------------------------
1 file changed, 1 insertion(+), 35 deletions(-)
diff --git a/modules/RWKV.py b/modules/RWKV.py
index b226a195..70deab28 100644
--- a/modules/RWKV.py
+++ b/modules/RWKV.py
@@ -7,6 +7,7 @@ import numpy as np
from tokenizers import Tokenizer
import modules.shared as shared
+from modules.callbacks import Iteratorize
np.set_printoptions(precision=4, suppress=True, linewidth=200)
@@ -73,38 +74,3 @@ class RWKVTokenizer:
def decode(self, ids):
return self.tokenizer.decode(ids)
-
-class Iteratorize:
-
- """
- Transforms a function that takes a callback
- into a lazy iterator (generator).
- """
-
- def __init__(self, func, kwargs={}, callback=None):
- self.mfunc=func
- self.c_callback=callback
- self.q = Queue(maxsize=1)
- self.sentinel = object()
- self.kwargs = kwargs
-
- def _callback(val):
- self.q.put(val)
-
- def gentask():
- ret = self.mfunc(callback=_callback, **self.kwargs)
- self.q.put(self.sentinel)
- if self.c_callback:
- self.c_callback(ret)
-
- Thread(target=gentask).start()
-
- def __iter__(self):
- return self
-
- def __next__(self):
- obj = self.q.get(True,None)
- if obj is self.sentinel:
- raise StopIteration
- else:
- return obj
From 72d539dbff6f946fbbd1d8806361dccbc241f8ec Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Wed, 8 Mar 2023 02:54:47 -0300
Subject: [PATCH 03/11] Better separate the FlexGen case
---
modules/text_generation.py | 19 +++++--------------
1 file changed, 5 insertions(+), 14 deletions(-)
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 436afbeb..a8157a76 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -201,12 +201,8 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
reply = original_question + apply_extensions(reply[len(question):], "output")
yield formatted_outputs(reply, shared.model_name)
- if not shared.args.flexgen:
- if output[-1] == n:
- break
- else:
- if np.count_nonzero(input_ids[0] == n) < np.count_nonzero(output == n):
- break
+ if output[-1] == n:
+ break
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
else:
@@ -223,14 +219,9 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
reply = original_question + apply_extensions(reply[len(question):], "output")
yield formatted_outputs(reply, shared.model_name)
- if not shared.args.flexgen:
- if output[-1] == n:
- break
- input_ids = torch.reshape(output, (1, output.shape[0]))
- else:
- if np.count_nonzero(input_ids[0] == n) < np.count_nonzero(output == n):
- break
- input_ids = np.reshape(output, (1, output.shape[0]))
+ if np.count_nonzero(input_ids[0] == n) < np.count_nonzero(output == n):
+ break
+ input_ids = np.reshape(output, (1, output.shape[0]))
if shared.soft_prompt:
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
From ad2970374adeb58aec1d7748b02a8c82cc524c0a Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Wed, 8 Mar 2023 03:00:06 -0300
Subject: [PATCH 04/11] Readability improvements
---
modules/text_generation.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/modules/text_generation.py b/modules/text_generation.py
index a8157a76..9477fe41 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -195,8 +195,8 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
for output in eval(f"generate_with_streaming({', '.join(generate_params)})"):
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
-
reply = decode(output)
+
if not (shared.args.chat or shared.args.cai_chat):
reply = original_question + apply_extensions(reply[len(question):], "output")
yield formatted_outputs(reply, shared.model_name)
@@ -213,16 +213,16 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
output = eval(f"shared.model.generate({', '.join(generate_params)})")[0]
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
-
reply = decode(output)
+
if not (shared.args.chat or shared.args.cai_chat):
reply = original_question + apply_extensions(reply[len(question):], "output")
yield formatted_outputs(reply, shared.model_name)
if np.count_nonzero(input_ids[0] == n) < np.count_nonzero(output == n):
break
- input_ids = np.reshape(output, (1, output.shape[0]))
+ input_ids = np.reshape(output, (1, output.shape[0]))
if shared.soft_prompt:
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
From 33fb6aed74ebfd50f12373fcbe2f7c0d285022d3 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Wed, 8 Mar 2023 03:08:16 -0300
Subject: [PATCH 05/11] Minor bug fix
---
modules/text_generation.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 9477fe41..35617314 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -115,7 +115,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
print(f"\n\n{question}\n--------------------\n")
input_ids = encode(question, max_new_tokens)
- original_input_ids = input_ids
+ original_input_ids = output = input_ids
cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()"
n = shared.tokenizer.eos_token_id if eos_token is None else int(encode(eos_token)[0][-1])
if stopping_string is not None:
From add9330e5e90e33f3f8bbe0ea42290475deb9998 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Wed, 8 Mar 2023 11:26:29 -0300
Subject: [PATCH 06/11] Bug fixes
---
modules/text_generation.py | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 35617314..8f5ea798 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -115,7 +115,8 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
print(f"\n\n{question}\n--------------------\n")
input_ids = encode(question, max_new_tokens)
- original_input_ids = output = input_ids
+ original_input_ids = input_ids
+ output = input_ids[0]
cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()"
n = shared.tokenizer.eos_token_id if eos_token is None else int(encode(eos_token)[0][-1])
if stopping_string is not None:
@@ -186,7 +187,8 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
if 'stopping_criteria' not in kwargs:
kwargs['stopping_criteria'] = []
kwargs['stopping_criteria'].append(Stream(callback_func=callback))
- shared.model.generate(**kwargs)[0]
+ clear_torch_cache()
+ shared.model.generate(**kwargs)
def generate_with_streaming(**kwargs):
return Iteratorize(generate_with_callback, kwargs, callback=None)
@@ -208,7 +210,6 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
else:
for i in range(max_new_tokens//8+1):
clear_torch_cache()
-
with torch.no_grad():
output = eval(f"shared.model.generate({', '.join(generate_params)})")[0]
if shared.soft_prompt:
From 59b5f7a4b731c528f0fa53d70eb3318d3a1727df Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Wed, 8 Mar 2023 12:13:40 -0300
Subject: [PATCH 07/11] Improve usage of stopping_criteria
---
modules/text_generation.py | 19 ++++++-------------
1 file changed, 6 insertions(+), 13 deletions(-)
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 8f5ea798..6a59f9a7 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -119,18 +119,11 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
output = input_ids[0]
cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()"
n = shared.tokenizer.eos_token_id if eos_token is None else int(encode(eos_token)[0][-1])
+ stopping_criteria_list = transformers.StoppingCriteriaList()
if stopping_string is not None:
- # The stopping_criteria code below was copied from
- # https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
+ # Copied from https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
t = encode(stopping_string, 0, add_special_tokens=False)
- stopping_criteria_list = transformers.StoppingCriteriaList([
- _SentinelTokenStoppingCriteria(
- sentinel_token_ids=t,
- starting_idx=len(input_ids[0])
- )
- ])
- else:
- stopping_criteria_list = []
+ stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
if not shared.args.flexgen:
generate_params = [
@@ -184,17 +177,17 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
elif not shared.args.flexgen:
def generate_with_callback(callback=None, **kwargs):
- if 'stopping_criteria' not in kwargs:
- kwargs['stopping_criteria'] = []
kwargs['stopping_criteria'].append(Stream(callback_func=callback))
clear_torch_cache()
- shared.model.generate(**kwargs)
+ with torch.no_grad():
+ shared.model.generate(**kwargs)
def generate_with_streaming(**kwargs):
return Iteratorize(generate_with_callback, kwargs, callback=None)
yield formatted_outputs(original_question, shared.model_name)
for output in eval(f"generate_with_streaming({', '.join(generate_params)})"):
+ print(print('Used vram in gib:', torch.cuda.memory_allocated() / 1024**3))
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
reply = decode(output)
From 37f0166b2d6b0f2938a5a4c1762479829de1c5be Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sat, 11 Mar 2023 23:14:49 -0300
Subject: [PATCH 08/11] Fix memory leak in new streaming (second attempt)
---
modules/callbacks.py | 5 ++++-
modules/text_generation.py | 1 -
2 files changed, 4 insertions(+), 2 deletions(-)
diff --git a/modules/callbacks.py b/modules/callbacks.py
index 15674b8a..05e8fafa 100644
--- a/modules/callbacks.py
+++ b/modules/callbacks.py
@@ -49,7 +49,7 @@ class Iteratorize:
def __init__(self, func, kwargs={}, callback=None):
self.mfunc=func
self.c_callback=callback
- self.q = Queue(maxsize=1)
+ self.q = Queue()
self.sentinel = object()
self.kwargs = kwargs
@@ -73,3 +73,6 @@ class Iteratorize:
raise StopIteration
else:
return obj
+
+ def __del__(self):
+ pass
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 6a59f9a7..5d01c8cb 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -187,7 +187,6 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
yield formatted_outputs(original_question, shared.model_name)
for output in eval(f"generate_with_streaming({', '.join(generate_params)})"):
- print(print('Used vram in gib:', torch.cuda.memory_allocated() / 1024**3))
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
reply = decode(output)
From 0bd54309887f6e7adc7e59d4f8675ed6f322bb81 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 12 Mar 2023 02:04:28 -0300
Subject: [PATCH 09/11] Use 'with' statement to better handle streaming memory
---
modules/RWKV.py | 10 +++++-----
modules/callbacks.py | 27 +++++++++++++++++++++++----
modules/text_generation.py | 19 ++++++++++---------
3 files changed, 38 insertions(+), 18 deletions(-)
diff --git a/modules/RWKV.py b/modules/RWKV.py
index 70deab28..836d31dc 100644
--- a/modules/RWKV.py
+++ b/modules/RWKV.py
@@ -50,11 +50,11 @@ class RWKVModel:
return context+self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)
def generate_with_streaming(self, **kwargs):
- iterable = Iteratorize(self.generate, kwargs, callback=None)
- reply = kwargs['context']
- for token in iterable:
- reply += token
- yield reply
+ with Iteratorize(self.generate, kwargs, callback=None) as generator:
+ reply = kwargs['context']
+ for token in generator:
+ reply += token
+ yield reply
class RWKVTokenizer:
def __init__(self):
diff --git a/modules/callbacks.py b/modules/callbacks.py
index 05e8fafa..e0d1c988 100644
--- a/modules/callbacks.py
+++ b/modules/callbacks.py
@@ -1,3 +1,4 @@
+import gc
from queue import Queue
from threading import Thread
@@ -6,7 +7,6 @@ import transformers
import modules.shared as shared
-
# Copied from https://github.com/PygmalionAI/gradio-ui/
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
@@ -52,17 +52,24 @@ class Iteratorize:
self.q = Queue()
self.sentinel = object()
self.kwargs = kwargs
+ self.stop_now = False
def _callback(val):
+ if self.stop_now:
+ raise ValueError
self.q.put(val)
def gentask():
- ret = self.mfunc(callback=_callback, **self.kwargs)
+ try:
+ ret = self.mfunc(callback=_callback, **self.kwargs)
+ except ValueError:
+ pass
self.q.put(self.sentinel)
if self.c_callback:
self.c_callback(ret)
- Thread(target=gentask).start()
+ self.thread = Thread(target=gentask)
+ self.thread.start()
def __iter__(self):
return self
@@ -75,4 +82,16 @@ class Iteratorize:
return obj
def __del__(self):
- pass
+ clear_torch_cache()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.stop_now = True
+ clear_torch_cache()
+
+def clear_torch_cache():
+ gc.collect()
+ if not shared.args.cpu:
+ torch.cuda.empty_cache()
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 5d01c8cb..7f5aad5e 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -186,17 +186,18 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
return Iteratorize(generate_with_callback, kwargs, callback=None)
yield formatted_outputs(original_question, shared.model_name)
- for output in eval(f"generate_with_streaming({', '.join(generate_params)})"):
- if shared.soft_prompt:
- output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
- reply = decode(output)
+ with eval(f"generate_with_streaming({', '.join(generate_params)})") as generator:
+ for output in generator:
+ if shared.soft_prompt:
+ output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
+ reply = decode(output)
- if not (shared.args.chat or shared.args.cai_chat):
- reply = original_question + apply_extensions(reply[len(question):], "output")
- yield formatted_outputs(reply, shared.model_name)
+ if not (shared.args.chat or shared.args.cai_chat):
+ reply = original_question + apply_extensions(reply[len(question):], "output")
+ yield formatted_outputs(reply, shared.model_name)
- if output[-1] == n:
- break
+ if output[-1] == n:
+ break
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
else:
From b0e8cb8c889cdadd9779517ba8055114b39357cd Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 12 Mar 2023 02:31:45 -0300
Subject: [PATCH 10/11] Various fixes in chat mode
---
modules/chat.py | 16 +++---
modules/text_generation.py | 102 +++++++++++++++++++------------------
2 files changed, 62 insertions(+), 56 deletions(-)
diff --git a/modules/chat.py b/modules/chat.py
index f40f8299..69d81e94 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -115,14 +115,18 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
visible_text = visible_text.replace('\n', '
')
text = apply_extensions(text, "input")
- if custom_generate_chat_prompt is None:
- prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
- else:
- prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
-
# Generate
reply = ''
for i in range(chat_generation_attempts):
+
+ # The prompt needs to be generated here because, as the reply
+ # grows, it may become necessary to remove more old messages to
+ # fit into the 2048 tokens window.
+ if custom_generate_chat_prompt is None:
+ prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size-len(encode(' '+reply)[0]))
+ else:
+ prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size-len(encode(' '+reply)[0]))
+
for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):
# Extracting the reply
@@ -156,10 +160,10 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
if 'pygmalion' in shared.model_name.lower():
name1 = "You"
- prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True)
reply = ''
for i in range(chat_generation_attempts):
+ prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size-len(encode(' '+reply)[0]), impersonate=True)
for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
reply, next_character_found, substring_found = extract_message_from_reply(prompt, reply, name1, name2, check, impersonate=True)
if not substring_found:
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 7f5aad5e..2460df4f 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -159,35 +159,53 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
else:
generate_params.insert(0, "inputs=input_ids")
- # Generate the entire reply at once.
- if shared.args.no_stream:
- with torch.no_grad():
- output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
- if shared.soft_prompt:
- output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
-
- reply = decode(output)
- if not (shared.args.chat or shared.args.cai_chat):
- reply = original_question + apply_extensions(reply[len(question):], "output")
-
- yield formatted_outputs(reply, shared.model_name)
-
- # Stream the reply 1 token at a time.
- # This is based on the trick of using 'stopping_criteria' to create an iterator.
- elif not shared.args.flexgen:
-
- def generate_with_callback(callback=None, **kwargs):
- kwargs['stopping_criteria'].append(Stream(callback_func=callback))
- clear_torch_cache()
+ try:
+ # Generate the entire reply at once.
+ if shared.args.no_stream:
with torch.no_grad():
- shared.model.generate(**kwargs)
+ output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
+ if shared.soft_prompt:
+ output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
- def generate_with_streaming(**kwargs):
- return Iteratorize(generate_with_callback, kwargs, callback=None)
+ reply = decode(output)
+ if not (shared.args.chat or shared.args.cai_chat):
+ reply = original_question + apply_extensions(reply[len(question):], "output")
- yield formatted_outputs(original_question, shared.model_name)
- with eval(f"generate_with_streaming({', '.join(generate_params)})") as generator:
- for output in generator:
+ yield formatted_outputs(reply, shared.model_name)
+
+ # Stream the reply 1 token at a time.
+ # This is based on the trick of using 'stopping_criteria' to create an iterator.
+ elif not shared.args.flexgen:
+
+ def generate_with_callback(callback=None, **kwargs):
+ kwargs['stopping_criteria'].append(Stream(callback_func=callback))
+ clear_torch_cache()
+ with torch.no_grad():
+ shared.model.generate(**kwargs)
+
+ def generate_with_streaming(**kwargs):
+ return Iteratorize(generate_with_callback, kwargs, callback=None)
+
+ yield formatted_outputs(original_question, shared.model_name)
+ with eval(f"generate_with_streaming({', '.join(generate_params)})") as generator:
+ for output in generator:
+ if shared.soft_prompt:
+ output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
+ reply = decode(output)
+
+ if not (shared.args.chat or shared.args.cai_chat):
+ reply = original_question + apply_extensions(reply[len(question):], "output")
+ yield formatted_outputs(reply, shared.model_name)
+
+ if output[-1] == n:
+ break
+
+ # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
+ else:
+ for i in range(max_new_tokens//8+1):
+ clear_torch_cache()
+ with torch.no_grad():
+ output = eval(f"shared.model.generate({', '.join(generate_params)})")[0]
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
reply = decode(output)
@@ -196,30 +214,14 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
reply = original_question + apply_extensions(reply[len(question):], "output")
yield formatted_outputs(reply, shared.model_name)
- if output[-1] == n:
+ if np.count_nonzero(input_ids[0] == n) < np.count_nonzero(output == n):
break
- # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
- else:
- for i in range(max_new_tokens//8+1):
- clear_torch_cache()
- with torch.no_grad():
- output = eval(f"shared.model.generate({', '.join(generate_params)})")[0]
- if shared.soft_prompt:
- output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
- reply = decode(output)
+ input_ids = np.reshape(output, (1, output.shape[0]))
+ if shared.soft_prompt:
+ inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
- if not (shared.args.chat or shared.args.cai_chat):
- reply = original_question + apply_extensions(reply[len(question):], "output")
- yield formatted_outputs(reply, shared.model_name)
-
- if np.count_nonzero(input_ids[0] == n) < np.count_nonzero(output == n):
- break
-
- input_ids = np.reshape(output, (1, output.shape[0]))
- if shared.soft_prompt:
- inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
-
- t1 = time.time()
- print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(original_input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(original_input_ids[0])} tokens)")
- return
+ finally:
+ t1 = time.time()
+ print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(original_input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(original_input_ids[0])} tokens)")
+ return
From 341e13503634a0debb684105f055e09772d16c6e Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 12 Mar 2023 02:53:08 -0300
Subject: [PATCH 11/11] Various fixes in chat mode
---
modules/callbacks.py | 1 +
modules/chat.py | 16 ++++++----------
modules/text_generation.py | 29 +++++++++++++++--------------
3 files changed, 22 insertions(+), 24 deletions(-)
diff --git a/modules/callbacks.py b/modules/callbacks.py
index e0d1c988..faa4a5e9 100644
--- a/modules/callbacks.py
+++ b/modules/callbacks.py
@@ -64,6 +64,7 @@ class Iteratorize:
ret = self.mfunc(callback=_callback, **self.kwargs)
except ValueError:
pass
+ clear_torch_cache()
self.q.put(self.sentinel)
if self.c_callback:
self.c_callback(ret)
diff --git a/modules/chat.py b/modules/chat.py
index 69d81e94..f40f8299 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -115,18 +115,14 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
visible_text = visible_text.replace('\n', '
')
text = apply_extensions(text, "input")
+ if custom_generate_chat_prompt is None:
+ prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
+ else:
+ prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
+
# Generate
reply = ''
for i in range(chat_generation_attempts):
-
- # The prompt needs to be generated here because, as the reply
- # grows, it may become necessary to remove more old messages to
- # fit into the 2048 tokens window.
- if custom_generate_chat_prompt is None:
- prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size-len(encode(' '+reply)[0]))
- else:
- prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size-len(encode(' '+reply)[0]))
-
for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):
# Extracting the reply
@@ -160,10 +156,10 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
if 'pygmalion' in shared.model_name.lower():
name1 = "You"
+ prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True)
reply = ''
for i in range(chat_generation_attempts):
- prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size-len(encode(' '+reply)[0]), impersonate=True)
for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
reply, next_character_found, substring_found = extract_message_from_reply(prompt, reply, name1, name2, check, impersonate=True)
if not substring_found:
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 2460df4f..7966e126 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -92,21 +92,22 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
# These models are not part of Hugging Face, so we handle them
# separately and terminate the function call earlier
if shared.is_RWKV:
- if shared.args.no_stream:
- reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
- yield formatted_outputs(reply, shared.model_name)
- else:
- yield formatted_outputs(question, shared.model_name)
- # RWKV has proper streaming, which is very nice.
- # No need to generate 8 tokens at a time.
- for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
+ try:
+ if shared.args.no_stream:
+ reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
yield formatted_outputs(reply, shared.model_name)
-
- t1 = time.time()
- output = encode(reply)[0]
- input_ids = encode(question)
- print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(input_ids[0])} tokens)")
- return
+ else:
+ yield formatted_outputs(question, shared.model_name)
+ # RWKV has proper streaming, which is very nice.
+ # No need to generate 8 tokens at a time.
+ for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
+ yield formatted_outputs(reply, shared.model_name)
+ finally:
+ t1 = time.time()
+ output = encode(reply)[0]
+ input_ids = encode(question)
+ print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(input_ids[0])} tokens)")
+ return
original_question = question
if not (shared.args.chat or shared.args.cai_chat):