fixup missing tfs top_a params, defaults reorg (#2443)

This commit is contained in:
matatonic 2023-05-30 20:52:33 -04:00 committed by GitHub
parent 9ab90d8b60
commit df50f077db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -18,6 +18,41 @@ params = {
debug = True if 'OPENEDAI_DEBUG' in os.environ else False debug = True if 'OPENEDAI_DEBUG' in os.environ else False
# Slightly different defaults for OpenAI's API
default_req_params = {
'max_new_tokens': 200,
'temperature': 1.0,
'top_p': 1.0,
'top_k': 1,
'repetition_penalty': 1.18,
'encoder_repetition_penalty': 1.0,
'suffix': None,
'stream': False,
'echo': False,
'seed': -1,
# 'n' : default(body, 'n', 1), # 'n' doesn't have a direct map
'truncation_length': 2048,
'add_bos_token': True,
'do_sample': True,
'typical_p': 1.0,
'epsilon_cutoff': 0, # In units of 1e-4
'eta_cutoff': 0, # In units of 1e-4
'tfs': 1.0,
'top_a': 0.0,
'min_length': 0,
'no_repeat_ngram_size': 0,
'num_beams': 1,
'penalty_alpha': 0.0,
'length_penalty': 1,
'early_stopping': False,
'mirostat_mode': 0,
'mirostat_tau': 5,
'mirostat_eta': 0.1,
'ban_eos_token': False,
'skip_special_tokens': True,
'custom_stopping_strings': [],
}
# Optional, install the module and download the model to enable # Optional, install the module and download the model to enable
# v1/embeddings # v1/embeddings
try: try:
@ -194,46 +229,18 @@ class Handler(BaseHTTPRequestHandler):
max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens)) max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens))
# if the user assumes OpenAI, the max_tokens is way too large - try to ignore it unless it's small enough # if the user assumes OpenAI, the max_tokens is way too large - try to ignore it unless it's small enough
req_params = { req_params = default_req_params.copy()
'max_new_tokens': max_tokens,
'temperature': default(body, 'temperature', 1.0),
'top_p': default(body, 'top_p', 1.0),
'top_k': default(body, 'best_of', 1),
# XXX not sure about this one, seems to be the right mapping, but the range is different (-2..2.0) vs 0..2
# 0 is default in openai, but 1.0 is default in other places. Maybe it's scaled? scale it.
'repetition_penalty': 1.18, # (default(body, 'presence_penalty', 0) + 2.0 ) / 2.0, # 0 the real default, 1.2 is the model default, but 1.18 works better.
# XXX not sure about this one either, same questions. (-2..2.0), 0 is default not 1.0, scale it.
'encoder_repetition_penalty': 1.0, # (default(body, 'frequency_penalty', 0) + 2.0) / 2.0,
'suffix': body.get('suffix', None),
'stream': default(body, 'stream', False),
'echo': default(body, 'echo', False),
#####################################################
'seed': shared.settings.get('seed', -1),
# int(body.get('n', 1)) # perhaps this should be num_beams or chat_generation_attempts? 'n' doesn't have a direct map
# unofficial, but it needs to get set anyways.
'truncation_length': truncation_length,
# no more args.
'add_bos_token': shared.settings.get('add_bos_token', True),
'do_sample': True,
'typical_p': 1.0,
'epsilon_cutoff': 0, # In units of 1e-4
'eta_cutoff': 0, # In units of 1e-4
'min_length': 0,
'no_repeat_ngram_size': 0,
'num_beams': 1,
'penalty_alpha': 0.0,
'length_penalty': 1,
'early_stopping': False,
'mirostat_mode': 0,
'mirostat_tau': 5,
'mirostat_eta': 0.1,
'ban_eos_token': False,
'skip_special_tokens': True,
}
# fixup absolute 0.0's req_params['max_new_tokens'] = max_tokens
for par in ['temperature', 'repetition_penalty', 'encoder_repetition_penalty']: req_params['truncation_length'] = truncation_length
req_params[par] = clamp(req_params[par], 0.001, 1.999) req_params['temperature'] = clamp(default(body, 'temperature', default_req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0
req_params['top_p'] = clamp(default(body, 'top_p', default_req_params['top_p']), 0.001, 1.0)
req_params['top_k'] = default(body, 'best_of', default_req_params['top_k'])
req_params['suffix'] = default(body, 'suffix', default_req_params['suffix'])
req_params['stream'] = default(body, 'stream', default_req_params['stream'])
req_params['echo'] = default(body, 'echo', default_req_params['echo'])
req_params['seed'] = shared.settings.get('seed', default_req_params['seed'])
req_params['add_bos_token'] = shared.settings.get('add_bos_token', default_req_params['add_bos_token'])
self.send_response(200) self.send_response(200)
if req_params['stream']: if req_params['stream']:
@ -550,37 +557,14 @@ class Handler(BaseHTTPRequestHandler):
token_count = len(encode(edit_task)[0]) token_count = len(encode(edit_task)[0])
max_tokens = truncation_length - token_count max_tokens = truncation_length - token_count
req_params = { req_params = default_req_params.copy()
'max_new_tokens': max_tokens,
'temperature': clamp(default(body, 'temperature', 1.0), 0.001, 1.999), req_params['max_new_tokens'] = max_tokens
'top_p': clamp(default(body, 'top_p', 1.0), 0.001, 1.0), req_params['truncation_length'] = truncation_length
'top_k': 1, req_params['temperature'] = clamp(default(body, 'temperature', default_req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0
'repetition_penalty': 1.18, req_params['top_p'] = clamp(default(body, 'top_p', default_req_params['top_p']), 0.001, 1.0)
'encoder_repetition_penalty': 1.0, req_params['seed'] = shared.settings.get('seed', default_req_params['seed'])
'suffix': None, req_params['add_bos_token'] = shared.settings.get('add_bos_token', default_req_params['add_bos_token'])
'stream': False,
'echo': False,
'seed': shared.settings.get('seed', -1),
# 'n' : default(body, 'n', 1), # 'n' doesn't have a direct map
'truncation_length': truncation_length,
'add_bos_token': shared.settings.get('add_bos_token', True),
'do_sample': True,
'typical_p': 1.0,
'epsilon_cutoff': 0, # In units of 1e-4
'eta_cutoff': 0, # In units of 1e-4
'min_length': 0,
'no_repeat_ngram_size': 0,
'num_beams': 1,
'penalty_alpha': 0.0,
'length_penalty': 1,
'early_stopping': False,
'mirostat_mode': 0,
'mirostat_tau': 5,
'mirostat_eta': 0.1,
'ban_eos_token': False,
'skip_special_tokens': True,
'custom_stopping_strings': [],
}
if debug: if debug:
print({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count}) print({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})