diff --git a/modules/callbacks.py b/modules/callbacks.py index 2b039ef1..2c04cc53 100644 --- a/modules/callbacks.py +++ b/modules/callbacks.py @@ -65,7 +65,6 @@ class Iteratorize: traceback.print_exc() pass - clear_torch_cache() self.q.put(self.sentinel) if self.c_callback: self.c_callback(ret) @@ -84,22 +83,10 @@ class Iteratorize: return obj def __del__(self): - clear_torch_cache() + pass def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self.stop_now = True - clear_torch_cache() - - -def clear_torch_cache(): - gc.collect() - if not shared.args.cpu: - if is_torch_xpu_available(): - torch.xpu.empty_cache() - elif is_torch_npu_available(): - torch.npu.empty_cache() - else: - torch.cuda.empty_cache() diff --git a/modules/models.py b/modules/models.py index 3e8ae3f2..7a52c07c 100644 --- a/modules/models.py +++ b/modules/models.py @@ -90,6 +90,7 @@ def load_model(model_name, loader=None): raise ValueError shared.args.loader = loader + clear_torch_cache() output = load_func_map[loader](model_name) if type(output) is tuple: model, tokenizer = output diff --git a/modules/text_generation.py b/modules/text_generation.py index 86245098..c999fa81 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -79,7 +79,6 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap all_stop_strings += st shared.stop_everything = False - clear_torch_cache() seed = set_manual_seed(state['seed']) last_update = -1 reply = '' @@ -288,6 +287,9 @@ def get_reply_from_output_ids(output_ids, state=None, starting_from=0): def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False): + if shared.args.loader == 'Transformers': + clear_torch_cache() + generate_params = {} for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'smoothing_curve', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'dry_multiplier', 'dry_base', 'dry_allowed_length', 'dry_sequence_breakers', 'xtc_threshold', 'xtc_probability']: if k in state: @@ -393,7 +395,6 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings def generate_with_callback(callback=None, *args, **kwargs): kwargs['stopping_criteria'].append(Stream(callback_func=callback)) - clear_torch_cache() with torch.no_grad(): shared.model.generate(**kwargs)