diff --git a/extensions/api/script.py b/extensions/api/script.py index dd48f58f..20562cc6 100644 --- a/extensions/api/script.py +++ b/extensions/api/script.py @@ -44,20 +44,21 @@ class Handler(BaseHTTPRequestHandler): generator = generate_reply( question = prompt, max_new_tokens = int(body.get('max_length', 200)), - do_sample=True, + do_sample=bool(body.get('do_sample', True)), temperature=float(body.get('temperature', 0.5)), top_p=float(body.get('top_p', 1)), typical_p=float(body.get('typical', 1)), repetition_penalty=float(body.get('rep_pen', 1.1)), encoder_repetition_penalty=1, top_k=int(body.get('top_k', 0)), - min_length=0, - no_repeat_ngram_size=0, - num_beams=1, - penalty_alpha=0, - length_penalty=1, - early_stopping=False, - seed=-1, + min_length=int(body.get('min_length', 0)), + no_repeat_ngram_size=int(body.get('no_repeat_ngram_size',0)), + num_beams=int(body.get('num_beams',1)), + penalty_alpha=float(body.get('penalty_alpha', 0)), + length_penalty=float(body.get('length_penalty', 1)), + early_stopping=bool(body.get('early_stopping', False)), + seed=int(body.get('seed', -1)), + stopping_strings=body.get('stopping_strings', []), ) answer = ''