Make responses start faster by removing unnecessary cleanup calls (#6625)

oobabooga 2025-01-01 18:33:38 -03:00 committed by GitHub
parent 64853f8509
commit 7b88724711
3 changed files with 5 additions and 16 deletions

View File

@@ -65,7 +65,6 @@ class Iteratorize:
                 traceback.print_exc()
                 pass
 
-            clear_torch_cache()
             self.q.put(self.sentinel)
             if self.c_callback:
                 self.c_callback(ret)
@@ -84,22 +83,10 @@ class Iteratorize:
             return obj
 
     def __del__(self):
-        clear_torch_cache()
+        pass
 
     def __enter__(self):
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.stop_now = True
-        clear_torch_cache()
-
-
-def clear_torch_cache():
-    gc.collect()
-    if not shared.args.cpu:
-        if is_torch_xpu_available():
-            torch.xpu.empty_cache()
-        elif is_torch_npu_available():
-            torch.npu.empty_cache()
-        else:
-            torch.cuda.empty_cache()

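For reference, the helper removed from this module above pairs a full Python garbage-collection pass with an accelerator cache flush. A self-contained sketch of it follows; the cpu argument stands in for shared.args.cpu, and the availability checks come from transformers:

import gc

import torch
from transformers import is_torch_npu_available, is_torch_xpu_available


def clear_torch_cache(cpu=False):
    # gc.collect() is a synchronous, full garbage-collection pass; paying this
    # cost before every streamed reply is the startup latency this commit removes.
    gc.collect()
    if not cpu:
        if is_torch_xpu_available():
            torch.xpu.empty_cache()
        elif is_torch_npu_available():
            torch.npu.empty_cache()
        else:
            torch.cuda.empty_cache()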
View File

@@ -90,6 +90,7 @@ def load_model(model_name, loader=None):
                 raise ValueError
 
     shared.args.loader = loader
+    clear_torch_cache()
     output = load_func_map[loader](model_name)
     if type(output) is tuple:
         model, tokenizer = output

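The model-loading path keeps a single flush: one call there is cheap next to reading the weights, and it releases VRAM still held by cached allocations from any previously unloaded model. A condensed, runnable sketch of the dispatch shown in the hunk above, using a dummy loader and the clear_torch_cache sketch from earlier as stand-ins for the project's real loaders:

def _load_dummy(model_name):
    # Placeholder for a real entry in load_func_map.
    return 'model:' + model_name, 'tokenizer:' + model_name


load_func_map = {'Transformers': _load_dummy}


def load_model(model_name, loader='Transformers'):
    # A single cache flush before loading frees memory left over from a
    # previously unloaded model before the new weights are allocated.
    clear_torch_cache()
    output = load_func_map[loader](model_name)
    if type(output) is tuple:
        model, tokenizer = output
    else:
        model, tokenizer = output, None

    return model, tokenizer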
View File

@@ -79,7 +79,6 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
            all_stop_strings += st
 
    shared.stop_everything = False
-    clear_torch_cache()
    seed = set_manual_seed(state['seed'])
    last_update = -1
    reply = ''
@@ -288,6 +287,9 @@ def get_reply_from_output_ids(output_ids, state=None, starting_from=0):
 
 
 def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
+    if shared.args.loader == 'Transformers':
+        clear_torch_cache()
+
     generate_params = {}
     for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'smoothing_curve', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'dry_multiplier', 'dry_base', 'dry_allowed_length', 'dry_sequence_breakers', 'xtc_threshold', 'xtc_probability']:
         if k in state:
@@ -393,7 +395,6 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
            def generate_with_callback(callback=None, *args, **kwargs):
                kwargs['stopping_criteria'].append(Stream(callback_func=callback))
-                clear_torch_cache()
                with torch.no_grad():
                    shared.model.generate(**kwargs)
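A quick way to check the rationale behind the commit is to time the helper on its own. This is only a measurement sketch reusing the clear_torch_cache sketch above; no figures come from the commit, and results will vary by system:

import time


def time_clear_torch_cache(runs=5):
    # Each of these passes used to run before every streamed reply,
    # delaying the first token.
    for _ in range(runs):
        start = time.perf_counter()
        clear_torch_cache()
        print('clear_torch_cache took %.3f s' % (time.perf_counter() - start))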