mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Print generation parameters with --verbose (HF only)
This commit is contained in:
parent
c4c7fc4ab3
commit
cf820c69c5
@ -1,6 +1,7 @@
|
|||||||
import ast
|
import ast
|
||||||
import copy
|
import copy
|
||||||
import html
|
import html
|
||||||
|
import pprint
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
@ -65,7 +66,8 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
|
|||||||
all_stop_strings += st
|
all_stop_strings += st
|
||||||
|
|
||||||
if shared.args.verbose:
|
if shared.args.verbose:
|
||||||
print(f'\n\n{question}\n--------------------\n')
|
logger.info("PROMPT=")
|
||||||
|
print(question)
|
||||||
|
|
||||||
shared.stop_everything = False
|
shared.stop_everything = False
|
||||||
clear_torch_cache()
|
clear_torch_cache()
|
||||||
@ -283,7 +285,7 @@ def get_reply_from_output_ids(output_ids, state, starting_from=0):
|
|||||||
|
|
||||||
def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
|
def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
|
||||||
generate_params = {}
|
generate_params = {}
|
||||||
for k in ['max_new_tokens', 'do_sample', 'temperature', 'temperature_last', 'dynatemp', 'top_p', 'min_p', 'typical_p', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']:
|
for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynatemp', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'num_beams', 'length_penalty', 'early_stopping']:
|
||||||
generate_params[k] = state[k]
|
generate_params[k] = state[k]
|
||||||
|
|
||||||
if state['negative_prompt'] != '':
|
if state['negative_prompt'] != '':
|
||||||
@ -342,6 +344,11 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
|||||||
apply_extensions('logits_processor', processor, input_ids)
|
apply_extensions('logits_processor', processor, input_ids)
|
||||||
generate_params['logits_processor'] = processor
|
generate_params['logits_processor'] = processor
|
||||||
|
|
||||||
|
if shared.args.verbose:
|
||||||
|
logger.info("GENERATE_PARAMS=")
|
||||||
|
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(generate_params)
|
||||||
|
print()
|
||||||
|
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
try:
|
try:
|
||||||
if not is_chat and not shared.is_seq2seq:
|
if not is_chat and not shared.is_seq2seq:
|
||||||
|
Loading…
Reference in New Issue
Block a user