From 9d9ae6293833ce31bbb5ed5d9a04b033d1e3896d Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Wed, 19 Apr 2023 13:52:21 -0300 Subject: [PATCH] Fix stopping strings in the gradio API --- api-example.py | 2 +- modules/api.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/api-example.py b/api-example.py index 5138eb88..eff610c1 100644 --- a/api-example.py +++ b/api-example.py @@ -36,10 +36,10 @@ params = { 'early_stopping': False, 'seed': -1, 'add_bos_token': True, - 'custom_stopping_strings': [], 'truncation_length': 2048, 'ban_eos_token': False, 'skip_special_tokens': True, + 'stopping_strings': [], } # Input prompt diff --git a/modules/api.py b/modules/api.py index b57cfe88..9de8e25d 100644 --- a/modules/api.py +++ b/modules/api.py @@ -29,14 +29,16 @@ def generate_reply_wrapper(string): 'early_stopping': False, 'seed': -1, 'add_bos_token': True, - 'custom_stopping_strings': [], + 'custom_stopping_strings': '', 'truncation_length': 2048, 'ban_eos_token': False, 'skip_special_tokens': True, + 'stopping_strings': [], } params = json.loads(string) generate_params.update(params[1]) - for i in generate_reply(params[0], generate_params): + stopping_strings = generate_params.pop('stopping_strings') + for i in generate_reply(params[0], generate_params, stopping_strings=stopping_strings): yield i