Add proper streaming to RWKV

oobabooga 2023-03-07 18:17:56 -03:00
parent 8660227e1b
commit 19a34941ed
2 changed files with 52 additions and 8 deletions

modules/RWKV.py

@@ -1,5 +1,7 @@
 import os
 from pathlib import Path
+from queue import Queue
+from threading import Thread
 
 import numpy as np
 from tokenizers import Tokenizer
@@ -33,7 +35,7 @@ class RWKVModel:
         result.pipeline = pipeline
         return result
 
-    def generate(self, context, token_count=20, temperature=1, top_p=1, top_k=50, alpha_frequency=0.1, alpha_presence=0.1, token_ban=[0], token_stop=[], callback=None):
+    def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, alpha_frequency=0.1, alpha_presence=0.1, token_ban=[0], token_stop=[], callback=None):
         args = PIPELINE_ARGS(
             temperature = temperature,
             top_p = top_p,
@@ -46,6 +48,13 @@ class RWKVModel:
         return context+self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)
 
+    def generate_with_streaming(self, **kwargs):
+        iterable = Iteratorize(self.generate, kwargs, callback=None)
+        reply = kwargs['context']
+        for token in iterable:
+            reply += token
+            yield reply
+
 class RWKVTokenizer:
     def __init__(self):
         pass
@@ -64,3 +73,38 @@ class RWKVTokenizer:
     def decode(self, ids):
         return self.tokenizer.decode(ids)
+
+class Iteratorize:
+
+    """
+    Transforms a function that takes a callback
+    into a lazy iterator (generator).
+    """
+
+    def __init__(self, func, kwargs={}, callback=None):
+        self.mfunc = func
+        self.c_callback = callback
+        self.q = Queue(maxsize=1)
+        self.sentinel = object()
+        self.kwargs = kwargs
+
+        def _callback(val):
+            self.q.put(val)
+
+        def gentask():
+            ret = self.mfunc(callback=_callback, **self.kwargs)
+            self.q.put(self.sentinel)
+            if self.c_callback:
+                self.c_callback(ret)
+
+        Thread(target=gentask).start()
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        obj = self.q.get(True, None)
+        if obj is self.sentinel:
+            raise StopIteration
+        else:
+            return obj
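The Iteratorize class inverts control: pipeline.generate pushes tokens out through a callback, and the queue-plus-thread wrapper turns those pushes into a pull-based generator. Here is a minimal self-contained sketch of the same pattern, with a hypothetical fake_generate standing in for the callback-based producer:

```python
from queue import Queue
from threading import Thread

def fake_generate(callback=None):
    # Stand-in for a callback-based producer such as pipeline.generate
    for token in ["Hello", ",", " ", "world"]:
        callback(token)
    return "Hello, world"

q = Queue(maxsize=1)  # maxsize=1: the producer blocks until each token is consumed
sentinel = object()   # unique end-of-stream marker

def gentask():
    fake_generate(callback=q.put)  # every callback invocation becomes a queue item
    q.put(sentinel)                # signal that the producer has finished

Thread(target=gentask).start()

while (item := q.get()) is not sentinel:
    print(item, end="", flush=True)
print()
```

Queue(maxsize=1) keeps producer and consumer in lockstep, so at most one token is ever buffered. One caveat of this design, visible in the class above as well: if the consumer abandons the iterator early, the producer thread remains blocked on q.put().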

modules/text_generation.py

@@ -92,17 +92,17 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
     # separately and terminate the function call earlier
     if shared.is_RWKV:
         if shared.args.no_stream:
-            reply = shared.model.generate(question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
-            t1 = time.time()
-            print(f"Output generated in {(t1-t0):.2f} seconds.")
+            reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
             yield formatted_outputs(reply, shared.model_name)
         else:
             yield formatted_outputs(question, shared.model_name)
-            for i in tqdm(range(max_new_tokens//8+1)):
-                clear_torch_cache()
-                reply = shared.model.generate(question, token_count=8, temperature=temperature, top_p=top_p, top_k=top_k)
+            # RWKV has proper streaming, which is very nice.
+            # No need to generate 8 tokens at a time.
+            for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
                 yield formatted_outputs(reply, shared.model_name)
-                question = reply
-            t1 = time.time()
-            print(f"Output generated in {(t1-t0):.2f} seconds.")
         return
 
     original_question = question
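Since generate_with_streaming yields the cumulative text (context plus all tokens generated so far) after each new token, a caller that wants token-by-token output prints only the delta. A hypothetical usage sketch against the wrapper above (`model` is an RWKVModel instance; the sampling values are placeholders):

```python
# Hypothetical consumer of generate_with_streaming. Each yielded `reply`
# is the full text so far, so we print only the newly appended tail.
prompt = "Once upon a time"
previous = prompt
for reply in model.generate_with_streaming(context=prompt, token_count=64,
                                           temperature=0.7, top_p=0.9, top_k=40):
    print(reply[len(previous):], end="", flush=True)
    previous = reply
print()
```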