Fix stopping strings in the gradio API

This commit is contained in:
oobabooga 2023-04-19 13:52:21 -03:00
parent 69d50e2e86
commit 9d9ae62938
2 changed files with 5 additions and 3 deletions

View File

@ -36,10 +36,10 @@ params = {
'early_stopping': False, 'early_stopping': False,
'seed': -1, 'seed': -1,
'add_bos_token': True, 'add_bos_token': True,
'custom_stopping_strings': [],
'truncation_length': 2048, 'truncation_length': 2048,
'ban_eos_token': False, 'ban_eos_token': False,
'skip_special_tokens': True, 'skip_special_tokens': True,
'stopping_strings': [],
} }
# Input prompt # Input prompt

View File

@ -29,14 +29,16 @@ def generate_reply_wrapper(string):
'early_stopping': False, 'early_stopping': False,
'seed': -1, 'seed': -1,
'add_bos_token': True, 'add_bos_token': True,
'custom_stopping_strings': [], 'custom_stopping_strings': '',
'truncation_length': 2048, 'truncation_length': 2048,
'ban_eos_token': False, 'ban_eos_token': False,
'skip_special_tokens': True, 'skip_special_tokens': True,
'stopping_strings': [],
} }
params = json.loads(string) params = json.loads(string)
generate_params.update(params[1]) generate_params.update(params[1])
for i in generate_reply(params[0], generate_params): stopping_strings = generate_params.pop('stopping_strings')
for i in generate_reply(params[0], generate_params, stopping_strings=stopping_strings):
yield i yield i