Give API extension access to all generate_reply parameters (#744)

* Make every parameter of the generate_reply function parameterizable

* Add stopping strings as parameterizable
This commit is contained in:
Niels Mündler 2023-04-03 18:31:12 +02:00 committed by GitHub
parent 9318e16ed5
commit 7aab88bcc6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -44,20 +44,21 @@ class Handler(BaseHTTPRequestHandler):
generator = generate_reply( generator = generate_reply(
question = prompt, question = prompt,
max_new_tokens = int(body.get('max_length', 200)), max_new_tokens = int(body.get('max_length', 200)),
do_sample=True, do_sample=bool(body.get('do_sample', True)),
temperature=float(body.get('temperature', 0.5)), temperature=float(body.get('temperature', 0.5)),
top_p=float(body.get('top_p', 1)), top_p=float(body.get('top_p', 1)),
typical_p=float(body.get('typical', 1)), typical_p=float(body.get('typical', 1)),
repetition_penalty=float(body.get('rep_pen', 1.1)), repetition_penalty=float(body.get('rep_pen', 1.1)),
encoder_repetition_penalty=1, encoder_repetition_penalty=1,
top_k=int(body.get('top_k', 0)), top_k=int(body.get('top_k', 0)),
min_length=0, min_length=int(body.get('min_length', 0)),
no_repeat_ngram_size=0, no_repeat_ngram_size=int(body.get('no_repeat_ngram_size',0)),
num_beams=1, num_beams=int(body.get('num_beams',1)),
penalty_alpha=0, penalty_alpha=float(body.get('penalty_alpha', 0)),
length_penalty=1, length_penalty=float(body.get('length_penalty', 1)),
early_stopping=False, early_stopping=bool(body.get('early_stopping', False)),
seed=-1, seed=int(body.get('seed', -1)),
stopping_strings=body.get('stopping_strings', []),
) )
answer = '' answer = ''