diff --git a/extensions/api/script.py b/extensions/api/script.py index bd7c1900..dd48f58f 100644 --- a/extensions/api/script.py +++ b/extensions/api/script.py @@ -43,14 +43,14 @@ class Handler(BaseHTTPRequestHandler): generator = generate_reply( question = prompt, - max_new_tokens = body.get('max_length', 200), + max_new_tokens = int(body.get('max_length', 200)), do_sample=True, - temperature=body.get('temperature', 0.5), - top_p=body.get('top_p', 1), - typical_p=body.get('typical', 1), - repetition_penalty=body.get('rep_pen', 1.1), + temperature=float(body.get('temperature', 0.5)), + top_p=float(body.get('top_p', 1)), + typical_p=float(body.get('typical', 1)), + repetition_penalty=float(body.get('rep_pen', 1.1)), encoder_repetition_penalty=1, - top_k=body.get('top_k', 0), + top_k=int(body.get('top_k', 0)), min_length=0, no_repeat_ngram_size=0, num_beams=1, @@ -62,7 +62,10 @@ class Handler(BaseHTTPRequestHandler): answer = '' for a in generator: - answer = a[0] + if isinstance(a, str): + answer = a + else: + answer = a[0] response = json.dumps({ 'results': [{