mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-23 21:18:00 +01:00
fixup missing tfs top_a params, defaults reorg (#2443)
This commit is contained in:
parent
9ab90d8b60
commit
df50f077db
@ -18,6 +18,41 @@ params = {
|
||||
|
||||
debug = True if 'OPENEDAI_DEBUG' in os.environ else False
|
||||
|
||||
# Slightly different defaults for OpenAI's API
|
||||
default_req_params = {
|
||||
'max_new_tokens': 200,
|
||||
'temperature': 1.0,
|
||||
'top_p': 1.0,
|
||||
'top_k': 1,
|
||||
'repetition_penalty': 1.18,
|
||||
'encoder_repetition_penalty': 1.0,
|
||||
'suffix': None,
|
||||
'stream': False,
|
||||
'echo': False,
|
||||
'seed': -1,
|
||||
# 'n' : default(body, 'n', 1), # 'n' doesn't have a direct map
|
||||
'truncation_length': 2048,
|
||||
'add_bos_token': True,
|
||||
'do_sample': True,
|
||||
'typical_p': 1.0,
|
||||
'epsilon_cutoff': 0, # In units of 1e-4
|
||||
'eta_cutoff': 0, # In units of 1e-4
|
||||
'tfs': 1.0,
|
||||
'top_a': 0.0,
|
||||
'min_length': 0,
|
||||
'no_repeat_ngram_size': 0,
|
||||
'num_beams': 1,
|
||||
'penalty_alpha': 0.0,
|
||||
'length_penalty': 1,
|
||||
'early_stopping': False,
|
||||
'mirostat_mode': 0,
|
||||
'mirostat_tau': 5,
|
||||
'mirostat_eta': 0.1,
|
||||
'ban_eos_token': False,
|
||||
'skip_special_tokens': True,
|
||||
'custom_stopping_strings': [],
|
||||
}
|
||||
|
||||
# Optional, install the module and download the model to enable
|
||||
# v1/embeddings
|
||||
try:
|
||||
@ -194,46 +229,18 @@ class Handler(BaseHTTPRequestHandler):
|
||||
max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens))
|
||||
# if the user assumes OpenAI, the max_tokens is way too large - try to ignore it unless it's small enough
|
||||
|
||||
req_params = {
|
||||
'max_new_tokens': max_tokens,
|
||||
'temperature': default(body, 'temperature', 1.0),
|
||||
'top_p': default(body, 'top_p', 1.0),
|
||||
'top_k': default(body, 'best_of', 1),
|
||||
# XXX not sure about this one, seems to be the right mapping, but the range is different (-2..2.0) vs 0..2
|
||||
# 0 is default in openai, but 1.0 is default in other places. Maybe it's scaled? scale it.
|
||||
'repetition_penalty': 1.18, # (default(body, 'presence_penalty', 0) + 2.0 ) / 2.0, # 0 the real default, 1.2 is the model default, but 1.18 works better.
|
||||
# XXX not sure about this one either, same questions. (-2..2.0), 0 is default not 1.0, scale it.
|
||||
'encoder_repetition_penalty': 1.0, # (default(body, 'frequency_penalty', 0) + 2.0) / 2.0,
|
||||
'suffix': body.get('suffix', None),
|
||||
'stream': default(body, 'stream', False),
|
||||
'echo': default(body, 'echo', False),
|
||||
#####################################################
|
||||
'seed': shared.settings.get('seed', -1),
|
||||
# int(body.get('n', 1)) # perhaps this should be num_beams or chat_generation_attempts? 'n' doesn't have a direct map
|
||||
# unofficial, but it needs to get set anyways.
|
||||
'truncation_length': truncation_length,
|
||||
# no more args.
|
||||
'add_bos_token': shared.settings.get('add_bos_token', True),
|
||||
'do_sample': True,
|
||||
'typical_p': 1.0,
|
||||
'epsilon_cutoff': 0, # In units of 1e-4
|
||||
'eta_cutoff': 0, # In units of 1e-4
|
||||
'min_length': 0,
|
||||
'no_repeat_ngram_size': 0,
|
||||
'num_beams': 1,
|
||||
'penalty_alpha': 0.0,
|
||||
'length_penalty': 1,
|
||||
'early_stopping': False,
|
||||
'mirostat_mode': 0,
|
||||
'mirostat_tau': 5,
|
||||
'mirostat_eta': 0.1,
|
||||
'ban_eos_token': False,
|
||||
'skip_special_tokens': True,
|
||||
}
|
||||
req_params = default_req_params.copy()
|
||||
|
||||
# fixup absolute 0.0's
|
||||
for par in ['temperature', 'repetition_penalty', 'encoder_repetition_penalty']:
|
||||
req_params[par] = clamp(req_params[par], 0.001, 1.999)
|
||||
req_params['max_new_tokens'] = max_tokens
|
||||
req_params['truncation_length'] = truncation_length
|
||||
req_params['temperature'] = clamp(default(body, 'temperature', default_req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0
|
||||
req_params['top_p'] = clamp(default(body, 'top_p', default_req_params['top_p']), 0.001, 1.0)
|
||||
req_params['top_k'] = default(body, 'best_of', default_req_params['top_k'])
|
||||
req_params['suffix'] = default(body, 'suffix', default_req_params['suffix'])
|
||||
req_params['stream'] = default(body, 'stream', default_req_params['stream'])
|
||||
req_params['echo'] = default(body, 'echo', default_req_params['echo'])
|
||||
req_params['seed'] = shared.settings.get('seed', default_req_params['seed'])
|
||||
req_params['add_bos_token'] = shared.settings.get('add_bos_token', default_req_params['add_bos_token'])
|
||||
|
||||
self.send_response(200)
|
||||
if req_params['stream']:
|
||||
@ -550,37 +557,14 @@ class Handler(BaseHTTPRequestHandler):
|
||||
token_count = len(encode(edit_task)[0])
|
||||
max_tokens = truncation_length - token_count
|
||||
|
||||
req_params = {
|
||||
'max_new_tokens': max_tokens,
|
||||
'temperature': clamp(default(body, 'temperature', 1.0), 0.001, 1.999),
|
||||
'top_p': clamp(default(body, 'top_p', 1.0), 0.001, 1.0),
|
||||
'top_k': 1,
|
||||
'repetition_penalty': 1.18,
|
||||
'encoder_repetition_penalty': 1.0,
|
||||
'suffix': None,
|
||||
'stream': False,
|
||||
'echo': False,
|
||||
'seed': shared.settings.get('seed', -1),
|
||||
# 'n' : default(body, 'n', 1), # 'n' doesn't have a direct map
|
||||
'truncation_length': truncation_length,
|
||||
'add_bos_token': shared.settings.get('add_bos_token', True),
|
||||
'do_sample': True,
|
||||
'typical_p': 1.0,
|
||||
'epsilon_cutoff': 0, # In units of 1e-4
|
||||
'eta_cutoff': 0, # In units of 1e-4
|
||||
'min_length': 0,
|
||||
'no_repeat_ngram_size': 0,
|
||||
'num_beams': 1,
|
||||
'penalty_alpha': 0.0,
|
||||
'length_penalty': 1,
|
||||
'early_stopping': False,
|
||||
'mirostat_mode': 0,
|
||||
'mirostat_tau': 5,
|
||||
'mirostat_eta': 0.1,
|
||||
'ban_eos_token': False,
|
||||
'skip_special_tokens': True,
|
||||
'custom_stopping_strings': [],
|
||||
}
|
||||
req_params = default_req_params.copy()
|
||||
|
||||
req_params['max_new_tokens'] = max_tokens
|
||||
req_params['truncation_length'] = truncation_length
|
||||
req_params['temperature'] = clamp(default(body, 'temperature', default_req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0
|
||||
req_params['top_p'] = clamp(default(body, 'top_p', default_req_params['top_p']), 0.001, 1.0)
|
||||
req_params['seed'] = shared.settings.get('seed', default_req_params['seed'])
|
||||
req_params['add_bos_token'] = shared.settings.get('add_bos_token', default_req_params['add_bos_token'])
|
||||
|
||||
if debug:
|
||||
print({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})
|
||||
|
Loading…
Reference in New Issue
Block a user