mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-01-09 20:19:06 +01:00
Make responses start faster by removing unnecessary cleanup calls (#6625)
This commit is contained in:
parent
64853f8509
commit
7b88724711
@ -65,7 +65,6 @@ class Iteratorize:
|
|||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
pass
|
pass
|
||||||
|
|
||||||
clear_torch_cache()
|
|
||||||
self.q.put(self.sentinel)
|
self.q.put(self.sentinel)
|
||||||
if self.c_callback:
|
if self.c_callback:
|
||||||
self.c_callback(ret)
|
self.c_callback(ret)
|
||||||
@ -84,22 +83,10 @@ class Iteratorize:
|
|||||||
return obj
|
return obj
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
clear_torch_cache()
|
pass
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
self.stop_now = True
|
self.stop_now = True
|
||||||
clear_torch_cache()
|
|
||||||
|
|
||||||
|
|
||||||
def clear_torch_cache():
|
|
||||||
gc.collect()
|
|
||||||
if not shared.args.cpu:
|
|
||||||
if is_torch_xpu_available():
|
|
||||||
torch.xpu.empty_cache()
|
|
||||||
elif is_torch_npu_available():
|
|
||||||
torch.npu.empty_cache()
|
|
||||||
else:
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
@ -90,6 +90,7 @@ def load_model(model_name, loader=None):
|
|||||||
raise ValueError
|
raise ValueError
|
||||||
|
|
||||||
shared.args.loader = loader
|
shared.args.loader = loader
|
||||||
|
clear_torch_cache()
|
||||||
output = load_func_map[loader](model_name)
|
output = load_func_map[loader](model_name)
|
||||||
if type(output) is tuple:
|
if type(output) is tuple:
|
||||||
model, tokenizer = output
|
model, tokenizer = output
|
||||||
|
@ -79,7 +79,6 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
|
|||||||
all_stop_strings += st
|
all_stop_strings += st
|
||||||
|
|
||||||
shared.stop_everything = False
|
shared.stop_everything = False
|
||||||
clear_torch_cache()
|
|
||||||
seed = set_manual_seed(state['seed'])
|
seed = set_manual_seed(state['seed'])
|
||||||
last_update = -1
|
last_update = -1
|
||||||
reply = ''
|
reply = ''
|
||||||
@ -288,6 +287,9 @@ def get_reply_from_output_ids(output_ids, state=None, starting_from=0):
|
|||||||
|
|
||||||
|
|
||||||
def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
|
def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
|
||||||
|
if shared.args.loader == 'Transformers':
|
||||||
|
clear_torch_cache()
|
||||||
|
|
||||||
generate_params = {}
|
generate_params = {}
|
||||||
for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'smoothing_curve', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'dry_multiplier', 'dry_base', 'dry_allowed_length', 'dry_sequence_breakers', 'xtc_threshold', 'xtc_probability']:
|
for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'smoothing_curve', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'dry_multiplier', 'dry_base', 'dry_allowed_length', 'dry_sequence_breakers', 'xtc_threshold', 'xtc_probability']:
|
||||||
if k in state:
|
if k in state:
|
||||||
@ -393,7 +395,6 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
|||||||
|
|
||||||
def generate_with_callback(callback=None, *args, **kwargs):
|
def generate_with_callback(callback=None, *args, **kwargs):
|
||||||
kwargs['stopping_criteria'].append(Stream(callback_func=callback))
|
kwargs['stopping_criteria'].append(Stream(callback_func=callback))
|
||||||
clear_torch_cache()
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
shared.model.generate(**kwargs)
|
shared.model.generate(**kwargs)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user