mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-24 00:46:53 +01:00
Add support for stopping strings in RWKV models
This commit is contained in:
parent
b57ffc2ec9
commit
7ff2cc8316
@@ -160,13 +160,19 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
|||||||
else:
|
else:
|
||||||
if not shared.is_chat():
|
if not shared.is_chat():
|
||||||
yield formatted_outputs(question, shared.model_name)
|
yield formatted_outputs(question, shared.model_name)
|
||||||
|
evaluated_stopping_strings = ast.literal_eval(f"[{state['custom_stopping_strings']}]")
|
||||||
# RWKV has proper streaming, which is very nice.
|
# RWKV has proper streaming, which is very nice.
|
||||||
# No need to generate 8 tokens at a time.
|
# No need to generate 8 tokens at a time.
|
||||||
for reply in shared.model.generate_with_streaming(context=question, **generate_params):
|
for reply in shared.model.generate_with_streaming(context=question, **generate_params):
|
||||||
output = original_question + reply
|
output = original_question + reply
|
||||||
|
original_reply = reply
|
||||||
if not shared.is_chat():
|
if not shared.is_chat():
|
||||||
reply = original_question + apply_extensions(reply, 'output')
|
original_reply = apply_extensions(reply, 'output')
|
||||||
|
reply = original_question + original_reply
|
||||||
|
found_stop_string = next((s for s in stopping_strings + evaluated_stopping_strings if s in original_reply), None)
|
||||||
|
if found_stop_string:
|
||||||
|
yield formatted_outputs(str(reply).rsplit(found_stop_string, 1)[0], shared.model_name)
|
||||||
|
break
|
||||||
yield formatted_outputs(reply, shared.model_name)
|
yield formatted_outputs(reply, shared.model_name)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
|
Loading…
Reference in New Issue
Block a user