From 020fe7b50b73a82a6124d7d3b19dfca080b20ccc Mon Sep 17 00:00:00 2001
From: IJumpAround <30680324+IJumpAround@users.noreply.github.com>
Date: Mon, 8 May 2023 21:55:41 -0400
Subject: [PATCH] Remove mutable defaults from function signature. (#1663)

---
 modules/GPTQ_loader.py     | 3 ++-
 modules/RWKV.py            | 6 +++---
 modules/callbacks.py       | 4 ++--
 modules/text_generation.py | 8 ++++----
 4 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py
index 8142c34e..87a4f524 100644
--- a/modules/GPTQ_loader.py
+++ b/modules/GPTQ_loader.py
@@ -35,7 +35,8 @@ except ImportError:
 
 # This function is a replacement for the load_quant function in the
 # GPTQ-for_LLaMa repository. It supports more models and branches.
-def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128, eval=True):
+def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=None, kernel_switch_threshold=128, eval=True):
+    exclude_layers = exclude_layers or ['lm_head']
 
     def noop(*args, **kwargs):
         pass
diff --git a/modules/RWKV.py b/modules/RWKV.py
index 0405230e..957bc004 100644
--- a/modules/RWKV.py
+++ b/modules/RWKV.py
@@ -34,15 +34,15 @@ class RWKVModel:
         result.pipeline = pipeline
         return result
 
-    def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=None, alpha_frequency=0.1, alpha_presence=0.1, token_ban=[0], token_stop=[], callback=None):
+    def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=None, alpha_frequency=0.1, alpha_presence=0.1, token_ban=None, token_stop=None, callback=None):
         args = PIPELINE_ARGS(
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
             alpha_frequency=alpha_frequency,  # Frequency Penalty (as in GPT-3)
             alpha_presence=alpha_presence,  # Presence Penalty (as in GPT-3)
-            token_ban=token_ban,  # ban the generation of some tokens
-            token_stop=token_stop
+            token_ban=token_ban or [0],  # ban the generation of some tokens
+            token_stop=token_stop or []
         )
 
         return self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)
diff --git a/modules/callbacks.py b/modules/callbacks.py
index fb87ad56..5996ba4e 100644
--- a/modules/callbacks.py
+++ b/modules/callbacks.py
@@ -55,12 +55,12 @@ class Iteratorize:
     Adapted from: https://stackoverflow.com/a/9969000
     """
 
-    def __init__(self, func, kwargs={}, callback=None):
+    def __init__(self, func, kwargs=None, callback=None):
         self.mfunc = func
         self.c_callback = callback
         self.q = Queue()
         self.sentinel = object()
-        self.kwargs = kwargs
+        self.kwargs = kwargs or {}
         self.stop_now = False
 
         def _callback(val):
diff --git a/modules/text_generation.py b/modules/text_generation.py
index ba3f16b9..8a980425 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -142,7 +142,7 @@ def stop_everything_event():
     shared.stop_everything = True
 
 
-def generate_reply(question, state, eos_token=None, stopping_strings=[]):
+def generate_reply(question, state, eos_token=None, stopping_strings=None):
     state = apply_extensions('state', state)
     generate_func = apply_extensions('custom_generate_reply')
     if generate_func is None:
@@ -173,7 +173,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
         yield formatted_outputs(reply, shared.model_name)
 
 
-def generate_reply_HF(question, original_question, seed, state, eos_token=None, stopping_strings=[]):
+def generate_reply_HF(question, original_question, seed, state, eos_token=None, stopping_strings=None):
     generate_params = {}
     for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']:
         generate_params[k] = state[k]
@@ -272,7 +272,7 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
         return
 
 
-def generate_reply_custom(question, original_question, seed, state, eos_token=None, stopping_strings=[]):
+def generate_reply_custom(question, original_question, seed, state, eos_token=None, stopping_strings=None):
     seed = set_manual_seed(state['seed'])
     generate_params = {'token_count': state['max_new_tokens']}
     for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
@@ -309,7 +309,7 @@ def generate_reply_custom(question, original_question, seed, state, eos_token=No
         return
 
 
-def generate_reply_flexgen(question, original_question, seed, state, eos_token=None, stopping_strings=[]):
+def generate_reply_flexgen(question, original_question, seed, state, eos_token=None, stopping_strings=None):
     generate_params = {}
     for k in ['max_new_tokens', 'do_sample', 'temperature']:
         generate_params[k] = state[k]
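
Background on the pattern applied above, as a minimal standalone sketch rather than code from the repository (the names broken, fixed and bucket are illustrative only). A mutable default such as stopping_strings=[] or kwargs={} is evaluated once, when the function is defined, so every call that relies on the default shares the same object and any mutation leaks into later calls. The patch replaces those defaults with None and rebuilds the container inside the function body:

# Illustrative sketch of the mutable-default pitfall and of the None-sentinel
# idiom used throughout this patch; function and variable names are made up.

def broken(item, bucket=[]):   # the default list is created once, at def time
    bucket.append(item)
    return bucket


def fixed(item, bucket=None):  # sentinel default, resolved on every call
    bucket = bucket or []      # same "x = x or default" idiom as the patch
    bucket.append(item)
    return bucket


print(broken('a'))  # ['a']
print(broken('b'))  # ['a', 'b']  <- state leaked from the first call
print(fixed('a'))   # ['a']
print(fixed('b'))   # ['b']       <- each call gets a fresh list

One trade-off of the "x or default" form: an explicitly passed empty container (or any other falsy value) is also replaced by the default. Where an empty argument must be preserved as given, the stricter "if x is None: x = default" check is the alternative.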