mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-24 00:46:53 +01:00
Add support for stopping strings in RWKV models
This commit is contained in:
parent
b57ffc2ec9
commit
7ff2cc8316
@ -160,13 +160,19 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
else:
|
||||
if not shared.is_chat():
|
||||
yield formatted_outputs(question, shared.model_name)
|
||||
|
||||
evaluated_stopping_strings = ast.literal_eval(f"[{state['custom_stopping_strings']}]")
|
||||
# RWKV has proper streaming, which is very nice.
|
||||
# No need to generate 8 tokens at a time.
|
||||
for reply in shared.model.generate_with_streaming(context=question, **generate_params):
|
||||
output = original_question + reply
|
||||
original_reply = reply
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, 'output')
|
||||
original_reply = apply_extensions(reply, 'output')
|
||||
reply = original_question + original_reply
|
||||
found_stop_string = next((s for s in stopping_strings + evaluated_stopping_strings if s in original_reply), None)
|
||||
if found_stop_string:
|
||||
yield formatted_outputs(str(reply).rsplit(found_stop_string, 1)[0], shared.model_name)
|
||||
break
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
except Exception:
|
||||
|
Loading…
Reference in New Issue
Block a user