This commit is contained in:
oobabooga 2024-01-07 09:30:55 -08:00
parent d93db3b486
commit c4c7fc4ab3
3 changed files with 5 additions and 6 deletions

View File

@ -6,6 +6,7 @@ params = {
"maximum_temperature": 2, "maximum_temperature": 2,
} }
def convert_to_dynatemp(): def convert_to_dynatemp():
temperature = 0.5 * (params["minimum_temperature"] + params["maximum_temperature"]) temperature = 0.5 * (params["minimum_temperature"] + params["maximum_temperature"])
dynatemp = params["maximum_temperature"] - temperature dynatemp = params["maximum_temperature"] - temperature
@ -22,7 +23,7 @@ def state_modifier(state):
temperature, dynatemp = convert_to_dynatemp() temperature, dynatemp = convert_to_dynatemp()
state["temperature"] = temperature state["temperature"] = temperature
state["dynatemp"] = dynatemp state["dynatemp"] = dynatemp
return state return state

View File

@ -16,7 +16,7 @@ global_scores = None
class TemperatureLogitsWarperWithDynatemp(LogitsWarper): class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
def __init__(self, temperature: float, dynatemp: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): def __init__(self, temperature: float, dynatemp: float):
if not isinstance(temperature, float) or not (temperature > 0): if not isinstance(temperature, float) or not (temperature > 0):
except_msg = ( except_msg = (
f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token " f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token "
@ -29,8 +29,6 @@ class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
self.temperature = temperature self.temperature = temperature
self.dynatemp = dynatemp self.dynatemp = dynatemp
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:

View File

@ -247,9 +247,9 @@ def install_webui():
if selected_gpu == "INTEL": if selected_gpu == "INTEL":
# Install oneAPI dependencies via conda # Install oneAPI dependencies via conda
print_big_message("Installing Intel oneAPI runtime libraries.") print_big_message("Installing Intel oneAPI runtime libraries.")
run_cmd(f"conda install -y -c intel dpcpp-cpp-rt=2024.0 mkl-dpcpp=2024.0") run_cmd("conda install -y -c intel dpcpp-cpp-rt=2024.0 mkl-dpcpp=2024.0")
# Install libuv required by Intel-patched torch # Install libuv required by Intel-patched torch
run_cmd(f"conda install -y libuv") run_cmd("conda install -y libuv")
# Install the webui requirements # Install the webui requirements
update_requirements(initial_installation=True) update_requirements(initial_installation=True)