diff --git a/extensions/api/script.py b/extensions/api/script.py index 17878c5d..ab004542 100644 --- a/extensions/api/script.py +++ b/extensions/api/script.py @@ -34,9 +34,7 @@ class Handler(BaseHTTPRequestHandler): prompt = body['prompt'] prompt_lines = [k.strip() for k in prompt.split('\n')] - max_context = body.get('max_context_length', 2048) - while len(prompt_lines) >= 0 and len(encode('\n'.join(prompt_lines))) > max_context: prompt_lines.pop(0) @@ -58,17 +56,13 @@ class Handler(BaseHTTPRequestHandler): 'early_stopping': bool(body.get('early_stopping', False)), 'seed': int(body.get('seed', -1)), 'add_bos_token': int(body.get('add_bos_token', True)), - 'custom_stopping_strings': body.get('custom_stopping_strings', []), 'truncation_length': int(body.get('truncation_length', 2048)), 'ban_eos_token': bool(body.get('ban_eos_token', False)), 'skip_special_tokens': bool(body.get('skip_special_tokens', True)), + 'stopping_strings': body.get('stopping_strings', []), } - - generator = generate_reply( - prompt, - generate_params, - ) - + stopping_strings = generate_params.pop('stopping_strings') + generator = generate_reply(prompt, generate_params, stopping_strings=stopping_strings) answer = '' for a in generator: if isinstance(a, str):