mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-01-09 12:09:04 +01:00
Make responses start faster by removing unnecessary cleanup calls (#6625)
This commit is contained in:
parent
64853f8509
commit
7b88724711
@ -65,7 +65,6 @@ class Iteratorize:
|
||||
traceback.print_exc()
|
||||
pass
|
||||
|
||||
clear_torch_cache()
|
||||
self.q.put(self.sentinel)
|
||||
if self.c_callback:
|
||||
self.c_callback(ret)
|
||||
@ -84,22 +83,10 @@ class Iteratorize:
|
||||
return obj
|
||||
|
||||
def __del__(self):
|
||||
clear_torch_cache()
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.stop_now = True
|
||||
clear_torch_cache()
|
||||
|
||||
|
||||
def clear_torch_cache():
|
||||
gc.collect()
|
||||
if not shared.args.cpu:
|
||||
if is_torch_xpu_available():
|
||||
torch.xpu.empty_cache()
|
||||
elif is_torch_npu_available():
|
||||
torch.npu.empty_cache()
|
||||
else:
|
||||
torch.cuda.empty_cache()
|
||||
|
@ -90,6 +90,7 @@ def load_model(model_name, loader=None):
|
||||
raise ValueError
|
||||
|
||||
shared.args.loader = loader
|
||||
clear_torch_cache()
|
||||
output = load_func_map[loader](model_name)
|
||||
if type(output) is tuple:
|
||||
model, tokenizer = output
|
||||
|
@ -79,7 +79,6 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
|
||||
all_stop_strings += st
|
||||
|
||||
shared.stop_everything = False
|
||||
clear_torch_cache()
|
||||
seed = set_manual_seed(state['seed'])
|
||||
last_update = -1
|
||||
reply = ''
|
||||
@ -288,6 +287,9 @@ def get_reply_from_output_ids(output_ids, state=None, starting_from=0):
|
||||
|
||||
|
||||
def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
|
||||
if shared.args.loader == 'Transformers':
|
||||
clear_torch_cache()
|
||||
|
||||
generate_params = {}
|
||||
for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'smoothing_curve', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'dry_multiplier', 'dry_base', 'dry_allowed_length', 'dry_sequence_breakers', 'xtc_threshold', 'xtc_probability']:
|
||||
if k in state:
|
||||
@ -393,7 +395,6 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
||||
|
||||
def generate_with_callback(callback=None, *args, **kwargs):
|
||||
kwargs['stopping_criteria'].append(Stream(callback_func=callback))
|
||||
clear_torch_cache()
|
||||
with torch.no_grad():
|
||||
shared.model.generate(**kwargs)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user