diff --git a/download-model.py b/download-model.py index f5d49064..0f650516 100644 --- a/download-model.py +++ b/download-model.py @@ -62,7 +62,7 @@ class ModelDownloader: is_lora = False while True: url = f"{base}{page}" + (f"?cursor={cursor.decode()}" if cursor else "") - r = self.s.get(url, timeout=20) + r = self.s.get(url, timeout=10) r.raise_for_status() content = r.content @@ -136,7 +136,7 @@ class ModelDownloader: if output_path.exists() and not start_from_scratch: # Check if the file has already been downloaded completely - r = self.s.get(url, stream=True, timeout=20) + r = self.s.get(url, stream=True, timeout=10) total_size = int(r.headers.get('content-length', 0)) if output_path.stat().st_size >= total_size: return @@ -145,7 +145,7 @@ class ModelDownloader: headers = {'Range': f'bytes={output_path.stat().st_size}-'} mode = 'ab' - with self.s.get(url, stream=True, headers=headers, timeout=20) as r: + with self.s.get(url, stream=True, headers=headers, timeout=10) as r: r.raise_for_status() # Do not continue the download if the request was unsuccessful total_size = int(r.headers.get('content-length', 0)) block_size = 1024 * 1024 # 1MB diff --git a/extensions/perplexity_colors/script.py b/extensions/perplexity_colors/script.py new file mode 100644 index 00000000..84b62a30 --- /dev/null +++ b/extensions/perplexity_colors/script.py @@ -0,0 +1,215 @@ +import gradio +import torch +from transformers import LogitsProcessor +import numpy as np + +from modules import shared + +params = { + 'color_by_perplexity': False, + 'color_by_probability': False, + 'ppl_scale': 15.0, # No slider for this right now, because I don't think it really needs to be changed. Very large perplexity scores don't show up often. + #'probability_dropdown': False +} + +class PerplexityLogits(LogitsProcessor): + def __init__(self, verbose=False): + self.generated_token_ids = [] + self.selected_probs = [] + self.top_token_ids_list = [] + self.top_probs_list = [] + self.perplexities_list = [] + self.last_probs = None + self.verbose = verbose + + def __call__(self, input_ids, scores): + probs = torch.softmax(scores, dim=-1, dtype=torch.float) + log_probs = torch.nan_to_num(torch.log(probs)) + entropy = -torch.sum(probs*log_probs) + entropy = entropy.cpu().numpy() + perplexity = round(float(np.exp(entropy)), 4) + self.perplexities_list.append(perplexity) + last_token_id = int(input_ids[0][-1].cpu().numpy().item()) + # Store the generated tokens (not sure why this isn't accessible in the output endpoint!) + self.generated_token_ids.append(last_token_id) + # Get last probability, and add to the list if it wasn't there + if len(self.selected_probs) > 0: + # Is the selected token in the top tokens? 
+            if self.verbose:
+                print(shared.tokenizer.decode(last_token_id))
+                print([shared.tokenizer.decode(token_id) for token_id in self.top_token_ids_list[-1]])
+                print(self.top_probs_list[-1])
+            if last_token_id in self.top_token_ids_list[-1]:
+                idx = self.top_token_ids_list[-1].index(last_token_id)
+                self.selected_probs.append(self.top_probs_list[-1][idx])
+            else:
+                self.top_token_ids_list[-1].append(last_token_id)
+                last_prob = round(float(self.last_probs[last_token_id]), 4)
+                self.top_probs_list[-1].append(last_prob)
+                self.selected_probs.append(last_prob)
+        else:
+            self.selected_probs.append(1.0) # Placeholder for the last token of the prompt
+
+        if self.verbose:
+            pplbar = "-"
+            if not np.isnan(perplexity):
+                pplbar = "*"*round(perplexity)
+            print(f"{shared.tokenizer.decode(last_token_id)}\t{perplexity:.2f}\t{pplbar}")
+
+        # Get top 5 probabilities
+        top_tokens_and_probs = torch.topk(probs, 5)
+        top_probs = top_tokens_and_probs.values.cpu().numpy().astype(float).tolist()
+        top_token_ids = top_tokens_and_probs.indices.cpu().numpy().astype(int).tolist()
+
+        self.top_token_ids_list.append(top_token_ids)
+        self.top_probs_list.append(top_probs)
+
+        probs = probs.cpu().numpy().flatten()
+        self.last_probs = probs # Need to keep this as a reference for top probs
+
+        # Doesn't actually modify the logits!
+        return scores
+
+# Stores the perplexity and top probabilities
+ppl_logits_processor = None
+
+def logits_processor_modifier(logits_processor_list, input_ids):
+    global ppl_logits_processor
+    ppl_logits_processor = PerplexityLogits()
+    logits_processor_list.append(ppl_logits_processor)
+
+def output_modifier(text):
+    global ppl_logits_processor
+
+    # TODO: It's probably more efficient to do this above rather than modifying all these lists
+    # Remove last element of perplexities_list, top_token_ids_list, top_tokens_list, top_probs_list since everything is off by one because this extension runs before generation
+    perplexities = ppl_logits_processor.perplexities_list[:-1]
+    top_token_ids_list = ppl_logits_processor.top_token_ids_list[:-1]
+    top_tokens_list = [[shared.tokenizer.decode(token_id) for token_id in top_token_ids] for top_token_ids in top_token_ids_list]
+    top_probs_list = ppl_logits_processor.top_probs_list[:-1]
+    # Remove first element of generated_token_ids, generated_tokens, selected_probs because they are for the last token of the prompt
+    gen_token_ids = ppl_logits_processor.generated_token_ids[1:]
+    gen_tokens = [shared.tokenizer.decode(token_id) for token_id in gen_token_ids]
+    sel_probs = ppl_logits_processor.selected_probs[1:]
+
+    end_part = '</span>' # Helps with finding the index after replacing part of the text.
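+    # The coloring passes below scan the generated text left to right: each matched token is
+    # wrapped with add_color_html(), and the search index i is then advanced past the closing
+    # tag (end_part) so that earlier, already-colored occurrences of the same token are skipped.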
+    in_code = False # Since the tags mess up code blocks, avoid coloring while inside a code block, based on finding tokens with '`' in them
+
+    if params['color_by_probability'] and params['color_by_perplexity']:
+        i = 0
+        for token, prob, ppl, top_tokens, top_probs in zip(gen_tokens, sel_probs, perplexities, top_tokens_list, top_probs_list):
+            if '`' in token:
+                in_code = not in_code
+                continue
+            if in_code:
+                continue
+            color = probability_perplexity_color_scale(prob, ppl)
+            if token in text[i:]:
+                text = text[:i] + text[i:].replace(token, add_color_html(token, color), 1)
+                i += text[i:].find(end_part) + len(end_part)
+    elif params['color_by_perplexity']:
+        i = 0
+        for token, ppl, top_tokens, top_probs in zip(gen_tokens, perplexities, top_tokens_list, top_probs_list):
+            if '`' in token:
+                in_code = not in_code
+                continue
+            if in_code:
+                continue
+            color = perplexity_color_scale(ppl)
+            if token in text[i:]:
+                text = text[:i] + text[i:].replace(token, add_color_html(token, color), 1)
+                i += text[i:].find(end_part) + len(end_part)
+    elif params['color_by_probability']:
+        i = 0
+        for token, prob, top_tokens, top_probs in zip(gen_tokens, sel_probs, top_tokens_list, top_probs_list):
+            if '`' in token:
+                in_code = not in_code
+                continue
+            if in_code:
+                continue
+            color = probability_color_scale(prob)
+            if token in text[i:]:
+                text = text[:i] + text[i:].replace(token, add_color_html(token, color), 1)
+                i += text[i:].find(end_part) + len(end_part)
+
+    print('Average perplexity:', round(np.mean(perplexities), 4))
+    return text
+
+# Green-yellow-red color scale
+def probability_color_scale(prob):
+    rv = 0
+    gv = 0
+    if prob <= 0.5:
+        rv = 'ff'
+        gv = hex(int(255*prob*2))[2:]
+        if len(gv) < 2:
+            gv = '0'*(2 - len(gv)) + gv
+    else:
+        rv = hex(int(255 - 255*(prob - 0.5)*2))[2:]
+        gv = 'ff'
+        if len(rv) < 2:
+            rv = '0'*(2 - len(rv)) + rv
+    return rv + gv + '00'
+
+# Red component only, white for 0 perplexity (sorry if you're not in dark mode)
+def perplexity_color_scale(ppl):
+    value = hex(max(int(255.0 - params['ppl_scale']*(float(ppl)-1.0)), 0))[2:]
+    if len(value) < 2:
+        value = '0'*(2 - len(value)) + value
+    return 'ff' + value + value
+
+# Green-yellow-red for probability and blue component for perplexity
+def probability_perplexity_color_scale(prob, ppl):
+    rv = 0
+    gv = 0
+    bv = hex(min(max(int(params['ppl_scale']*(float(ppl)-1.0)), 0), 255))[2:]
+    if len(bv) < 2:
+        bv = '0'*(2 - len(bv)) + bv
+    if prob <= 0.5:
+        rv = 'ff'
+        gv = hex(int(255*prob*2))[2:]
+        if len(gv) < 2:
+            gv = '0'*(2 - len(gv)) + gv
+    else:
+        rv = hex(int(255 - 255*(prob - 0.5)*2))[2:]
+        gv = 'ff'
+        if len(rv) < 2:
+            rv = '0'*(2 - len(rv)) + rv
+    return rv + gv + bv
+
+def add_color_html(token, color):
+    return f'<span style="color: #{color}">{token}</span>'
+
+"""
+# This is still very broken at the moment, needs CSS too but I'm not very good at CSS (and neither is GPT-4 apparently) so I still need to figure that out.
+def add_dropdown_html(token, color, top_tokens, top_probs):
+    html = f'{token}'
+    return html
+"""
+
+def ui():
+    color_by_ppl_check = gradio.Checkbox(value=False, label="Color by perplexity", info="Higher perplexity is more red. 
If also showing probability, higher perplexity has more blue component.") + def update_color_by_ppl_check(x): + params.update({'color_by_perplexity': x}) + color_by_ppl_check.change(update_color_by_ppl_check, color_by_ppl_check, None) + + color_by_prob_check = gradio.Checkbox(value=False, label="Color by probability", info="Green-yellow-red linear scale, with 100% green, 50% yellow, 0% red.") + def update_color_by_prob_check(x): + params.update({'color_by_probability': x}) + color_by_prob_check.change(update_color_by_prob_check, color_by_prob_check, None) + + # Doesn't work yet... + """ + prob_dropdown_check = gradio.Checkbox(value=False, label="Probability dropdown") + def update_prob_dropdown_check(x): + params.update({'probability_dropdown': x}) + prob_dropdown_check.change(update_prob_dropdown_check, prob_dropdown_check, None) + """ diff --git a/modules/exllama_hf.py b/modules/exllama_hf.py index a25c3f42..fd775b4a 100644 --- a/modules/exllama_hf.py +++ b/modules/exllama_hf.py @@ -29,6 +29,7 @@ class ExllamaHF(PreTrainedModel): super().__init__(PretrainedConfig()) self.ex_config = config self.ex_model = ExLlama(self.ex_config) + self.ex_cache = ExLlamaCache(self.ex_model) self.generation_config = GenerationConfig() self.lora = None @@ -52,11 +53,20 @@ class ExllamaHF(PreTrainedModel): labels = kwargs.get('labels', None) seq = kwargs['input_ids'][0].tolist() cache = kwargs['past_key_values'] if 'past_key_values' in kwargs else None - if cache is None: - cache = ExLlamaCache(self.ex_model) - self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), cache, preprocess_only=True, lora=self.lora) - logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache, lora=self.lora).to(kwargs['input_ids'].device) + if labels is None: + if cache is None: + self.ex_cache.current_seq_len = 0 + cache = self.ex_cache + self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), cache, preprocess_only=True, lora=self.lora) + + logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache, lora=self.lora).to(kwargs['input_ids'].device) + else: + if cache is None: + self.ex_cache.current_seq_len = 0 + cache = self.ex_cache + + logits = self.ex_model.forward(torch.tensor([seq], dtype=torch.long), cache, last_id_only=False, lora=self.lora) loss = None if labels is not None: @@ -71,7 +81,7 @@ class ExllamaHF(PreTrainedModel): shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) - return CausalLMOutputWithPast(logits=logits, past_key_values=cache if use_cache else None) + return CausalLMOutputWithPast(logits=logits, past_key_values=cache if use_cache else None, loss=loss) @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs): diff --git a/modules/extensions.py b/modules/extensions.py index 8705101a..faf6cf6d 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -106,15 +106,23 @@ def _apply_history_modifier_extensions(history): return history -# Extension functions that override the default tokenizer output - currently only the first one will work +# Extension functions that override the default tokenizer output - The order of execution is not defined def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds): for extension, _ in iterator(): if hasattr(extension, function_name): - return getattr(extension, function_name)(state, prompt, input_ids, input_embeds) + prompt, input_ids, input_embeds = 
getattr(extension, function_name)(state, prompt, input_ids, input_embeds) return prompt, input_ids, input_embeds +# Allow extensions to add their own logits processors to the stack being run. +# Each extension would call `processor_list.append({their LogitsProcessor}())`. +def _apply_logits_processor_extensions(function_name, processor_list, input_ids): + for extension, _ in iterator(): + if hasattr(extension, function_name): + getattr(extension, function_name)(processor_list, input_ids) + + # Get prompt length in tokens after applying extension functions which override the default tokenizer output # currently only the first one will work def _apply_custom_tokenized_length(prompt): @@ -183,6 +191,7 @@ EXTENSION_MAP = { "history": _apply_history_modifier_extensions, "bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"), "tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"), + 'logits_processor': partial(_apply_logits_processor_extensions, 'logits_processor_modifier'), "input_hijack": _apply_input_hijack, "custom_generate_chat_prompt": _apply_custom_generate_chat_prompt, "custom_generate_reply": _apply_custom_generate_reply, diff --git a/modules/llamacpp_hf.py b/modules/llamacpp_hf.py new file mode 100644 index 00000000..12212fec --- /dev/null +++ b/modules/llamacpp_hf.py @@ -0,0 +1,103 @@ +import os +from pathlib import Path +from typing import Any, Dict, Optional, Union + +import torch +from llama_cpp import Llama +from torch.nn import CrossEntropyLoss +from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel +from transformers.modeling_outputs import CausalLMOutputWithPast + +from modules import shared +from modules.logging_colors import logger + + +class LlamacppHF(PreTrainedModel): + def __init__(self, model): + super().__init__(PretrainedConfig()) + self.model = model + self.generation_config = GenerationConfig() + self.cache = None + + def _validate_model_class(self): + pass + + def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): + pass + + def prepare_inputs_for_generation(self, input_ids, **kwargs): + return {'input_ids': input_ids, **kwargs} + + @property + def device(self) -> torch.device: + return torch.device(0) + + def __call__(self, *args, **kwargs): + # TODO: Some decoding methods (such as Contrastive Search) may not work at this time + assert len(args) == 0, 'no *args should be passed to forward' + use_cache = kwargs.get('use_cache', True) + labels = kwargs.get('labels', None) + seq = kwargs['input_ids'][0].tolist() + cache = kwargs['past_key_values'] if 'past_key_values' in kwargs else None + + # Make the forward call + seq_tensor = torch.tensor(seq) + self.cache = seq_tensor + if labels is None: + if self.cache is None or not torch.equal(self.cache, seq_tensor[:-1]): + self.model.reset() + self.model.eval(seq) + else: + self.model.eval([seq[-1]]) + + logits = torch.tensor(self.model.eval_logits)[-1].view(1, 1, -1).to(kwargs['input_ids'].device) + else: + self.model.reset() + self.model.eval(seq) + logits = torch.tensor(self.model.eval_logits) + logits = logits.view(1, logits.shape[0], logits.shape[1]).to(kwargs['input_ids'].device) + + # Based on transformers/models/llama/modeling_llama.py + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, logits.shape[-1]) + shift_labels = 
shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + return CausalLMOutputWithPast(logits=logits, past_key_values=cache if use_cache else None, loss=loss) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs): + assert len(model_args) == 0 and len(kwargs) == 0, "extra args is currently not supported" + if isinstance(pretrained_model_name_or_path, str): + pretrained_model_name_or_path = Path(pretrained_model_name_or_path) + + path = Path(f'{shared.args.model_dir}') / Path(pretrained_model_name_or_path) + if path.is_file(): + model_file = path + else: + model_file = list(path.glob('*ggml*.bin'))[0] + + logger.info(f"llama.cpp weights detected: {model_file}\n") + params = { + 'model_path': str(model_file), + 'n_ctx': shared.args.n_ctx, + 'seed': int(shared.args.llama_cpp_seed), + 'n_threads': shared.args.threads or None, + 'n_batch': shared.args.n_batch, + 'use_mmap': not shared.args.no_mmap, + 'use_mlock': shared.args.mlock, + 'low_vram': shared.args.low_vram, + 'n_gpu_layers': shared.args.n_gpu_layers, + 'logits_all': True, + } + + model = Llama(**params) + return LlamacppHF(model) diff --git a/modules/loaders.py b/modules/loaders.py index bb539564..da38c2f5 100644 --- a/modules/loaders.py +++ b/modules/loaders.py @@ -38,6 +38,17 @@ loaders_and_params = { 'mlock', 'llama_cpp_seed', ], + 'llamacpp_HF': [ + 'n_ctx', + 'n_gpu_layers', + 'n_batch', + 'threads', + 'no_mmap', + 'low_vram', + 'mlock', + 'llama_cpp_seed', + 'llamacpp_HF_info', + ], 'Transformers': [ 'cpu_memory', 'gpu_memory', diff --git a/modules/models.py b/modules/models.py index 2029e3a0..9d9ba951 100644 --- a/modules/models.py +++ b/modules/models.py @@ -55,6 +55,7 @@ def load_model(model_name, loader=None): 'AutoGPTQ': AutoGPTQ_loader, 'GPTQ-for-LLaMa': GPTQ_loader, 'llama.cpp': llamacpp_loader, + 'llamacpp_HF': llamacpp_HF_loader, 'FlexGen': flexgen_loader, 'RWKV': RWKV_loader, 'ExLlama': ExLlama_loader, @@ -268,6 +269,27 @@ def llamacpp_loader(model_name): return model, tokenizer +def llamacpp_HF_loader(model_name): + from modules.llamacpp_hf import LlamacppHF + + for fname in ["oobabooga_llama-tokenizer", "llama-tokenizer"]: + path = Path(f'{shared.args.model_dir}/{fname}') + if path.exists(): + break + else: + logger.error("Could not load the model because a tokenizer in transformers format was not found. 
Please download oobabooga/llama-tokenizer.") + return None, None + + tokenizer = AutoTokenizer.from_pretrained( + path, + trust_remote_code=shared.args.trust_remote_code, + use_fast=False + ) + + model = LlamacppHF.from_pretrained(model_name) + return model, tokenizer + + def GPTQ_loader(model_name): # Monkey patch diff --git a/modules/shared.py b/modules/shared.py index b54f9aac..f20b9fcd 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -214,6 +214,8 @@ def fix_loader_name(name): name = name.lower() if name in ['llamacpp', 'llama.cpp', 'llama-cpp', 'llama cpp']: return 'llama.cpp' + if name in ['llamacpp_hf', 'llama.cpp_hf', 'llama-cpp-hf', 'llamacpp-hf', 'llama.cpp-hf']: + return 'llamacpp_HF' elif name in ['transformers', 'huggingface', 'hf', 'hugging_face', 'hugging face']: return 'Transformers' elif name in ['autogptq', 'auto-gptq', 'auto_gptq', 'auto gptq']: diff --git a/modules/text_generation.py b/modules/text_generation.py index b7f6edf3..566c2f55 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -8,6 +8,7 @@ import traceback import numpy as np import torch import transformers +from transformers import LogitsProcessorList import modules.shared as shared from modules.callbacks import ( @@ -264,6 +265,13 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings generate_params['stopping_criteria'] = transformers.StoppingCriteriaList() generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria()) + processor = state.get('logits_processor', LogitsProcessorList([])) + # In case folks just pass in a processor by itself. + if type(processor) != LogitsProcessorList: + processor = LogitsProcessorList([processor]) + apply_extensions('logits_processor', processor, input_ids) + generate_params['logits_processor'] = processor + t0 = time.time() try: if not is_chat and not shared.is_seq2seq: diff --git a/requirements.txt b/requirements.txt index 4e012c1e..24aa6d79 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,10 +20,10 @@ tensorboard wandb transformers==4.30.2 git+https://github.com/huggingface/peft@03eb378eb914fbee709ff7c86ba5b1d033b89524 -bitsandbytes==0.40.0; platform_system != "Windows" -https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.40.0-py3-none-win_amd64.whl; platform_system == "Windows" -llama-cpp-python==0.1.70; platform_system != "Windows" -https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.70/llama_cpp_python-0.1.70-cp310-cp310-win_amd64.whl; platform_system == "Windows" +bitsandbytes==0.40.1.post1; platform_system != "Windows" +https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.40.1.post1-py3-none-win_amd64.whl; platform_system == "Windows" +llama-cpp-python==0.1.72; platform_system != "Windows" +https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.72/llama_cpp_python-0.1.72-cp310-cp310-win_amd64.whl; platform_system == "Windows" https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.2.2/auto_gptq-0.2.2+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows" https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.2.2/auto_gptq-0.2.2+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" https://github.com/jllllll/exllama/releases/download/0.0.6/exllama-0.0.6+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows" diff --git a/server.py b/server.py index 56f94963..727e5894 100644 --- a/server.py +++ b/server.py @@ 
-204,7 +204,7 @@ def create_model_menus(): with gr.Row(): with gr.Column(): - shared.gradio['loader'] = gr.Dropdown(label="Model loader", choices=["Transformers", "AutoGPTQ", "GPTQ-for-LLaMa", "ExLlama", "ExLlama_HF", "llama.cpp"], value=None) + shared.gradio['loader'] = gr.Dropdown(label="Model loader", choices=["Transformers", "ExLlama_HF", "ExLlama", "AutoGPTQ", "GPTQ-for-LLaMa", "llama.cpp", "llamacpp_HF"], value=None) with gr.Box(): with gr.Row(): with gr.Column(): @@ -223,11 +223,11 @@ def create_model_menus(): shared.gradio['groupsize'] = gr.Dropdown(label="groupsize", choices=["None", 32, 64, 128, 1024], value=str(shared.args.groupsize) if shared.args.groupsize > 0 else "None") shared.gradio['model_type'] = gr.Dropdown(label="model_type", choices=["None", "llama", "opt", "gptj"], value=shared.args.model_type or "None") shared.gradio['pre_layer'] = gr.Slider(label="pre_layer", minimum=0, maximum=100, value=shared.args.pre_layer[0] if shared.args.pre_layer is not None else 0) - shared.gradio['autogptq_info'] = gr.Markdown('On some systems, AutoGPTQ can be 2x slower than GPTQ-for-LLaMa. You can manually select the GPTQ-for-LLaMa loader above.') + shared.gradio['autogptq_info'] = gr.Markdown('* ExLlama_HF is recommended over AutoGPTQ for models derived from LLaMA.') shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7') shared.gradio['max_seq_len'] = gr.Slider(label='max_seq_len', minimum=2048, maximum=16384, step=256, info='Maximum sequence length.', value=shared.args.max_seq_len) shared.gradio['compress_pos_emb'] = gr.Slider(label='compress_pos_emb', minimum=1, maximum=8, step=1, info='Positional embeddings compression factor. Should typically be set to max_seq_len / 2048.', value=shared.args.compress_pos_emb) - shared.gradio['alpha_value'] = gr.Slider(label='alpha_value', minimum=1, maximum=8, step=1, info='Positional embeddings alpha factor for NTK RoPE scaling. Same as above. Use either this or compress_pos_emb, not both.', value=shared.args.alpha_value) + shared.gradio['alpha_value'] = gr.Slider(label='alpha_value', minimum=1, maximum=32, step=1, info='Positional embeddings alpha factor for NTK RoPE scaling. Scaling is not identical to embedding compression. Use either this or compress_pos_emb, not both.', value=shared.args.alpha_value) with gr.Column(): shared.gradio['triton'] = gr.Checkbox(label="triton", value=shared.args.triton) @@ -250,6 +250,7 @@ def create_model_menus(): shared.gradio['gptq_for_llama_info'] = gr.Markdown('GPTQ-for-LLaMa is currently 2x faster than AutoGPTQ on some systems. It is installed by default with the one-click installers. Otherwise, it has to be installed manually following the instructions here: [instructions](https://github.com/oobabooga/text-generation-webui/blob/main/docs/GPTQ-models-(4-bit-mode).md#installation-1).') shared.gradio['exllama_info'] = gr.Markdown('For more information, consult the [docs](https://github.com/oobabooga/text-generation-webui/blob/main/docs/ExLlama.md).') shared.gradio['exllama_HF_info'] = gr.Markdown('ExLlama_HF is a wrapper that lets you use ExLlama like a Transformers model, which means it can use the Transformers samplers. It\'s a bit slower than the regular ExLlama.') + shared.gradio['llamacpp_HF_info'] = gr.Markdown('llamacpp_HF is a wrapper that lets you use llama.cpp like a Transformers model, which means it can use the Transformers samplers. It works, but it\'s experimental and slow. 
Contributions are welcome.\n\nTo use it, make sure to first download oobabooga/llama-tokenizer under "Download custom model or LoRA".')
 
         with gr.Column():
             with gr.Row():
@@ -848,7 +849,7 @@ def create_interface():
 
     # Reset interface event
     shared.gradio['reset_interface'].click(
        set_interface_arguments, gradio('interface_modes_menu', 'extensions_menu', 'bool_menu'), None).then(
-        lambda: None, None, None, _js='() => {document.body.innerHTML=\'Reloading...\'; setTimeout(function(){location.reload()},2500); return []}')
+        lambda: None, None, None, _js='() => {document.body.innerHTML=\'Reloading...\'; setTimeout(function(){location.reload()},2500); return []}')
 
     shared.gradio['toggle_dark_mode'].click(lambda: None, None, None, _js='() => {document.getElementsByTagName("body")[0].classList.toggle("dark")}')
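
For extension authors, the new `logits_processor` hook in modules/extensions.py and modules/text_generation.py is the main user-facing addition here. Below is a minimal sketch of a third-party extension that uses it; the extension name and the banned-token list are made up for illustration, and the only assumption is what the diff shows: `logits_processor_modifier(processor_list, input_ids)` is called once per generation and may append standard `transformers` LogitsProcessor objects.

```python
# extensions/ban_tokens_example/script.py  (hypothetical extension name)
from transformers import LogitsProcessor


class BanTokensProcessor(LogitsProcessor):
    def __init__(self, banned_token_ids):
        self.banned_token_ids = banned_token_ids

    def __call__(self, input_ids, scores):
        # Push the banned tokens' logits to -inf so they can never be sampled
        for token_id in self.banned_token_ids:
            scores[..., token_id] = -float('inf')
        return scores


def logits_processor_modifier(processor_list, input_ids):
    # Called by modules/text_generation.py before generation starts
    processor_list.append(BanTokensProcessor(banned_token_ids=[0]))
```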
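Both wrappers (ExllamaHF and LlamacppHF) now return `loss` in `CausalLMOutputWithPast` when `labels` are passed, which is what lets them be scored through the standard Transformers interface. A rough sketch of how that path can be exercised, assuming `model` is one of these wrappers and `input_ids` is a `(1, seq_len)` tensor produced by the matching tokenizer:

```python
import torch


def sequence_perplexity(model, input_ids):
    # Passing labels triggers the full-sequence forward pass added in this diff;
    # the wrapper shifts logits/labels and returns a cross-entropy loss.
    with torch.no_grad():
        output = model(input_ids=input_ids, labels=input_ids)

    # Perplexity is exp of the mean token-level cross-entropy.
    return torch.exp(output.loss).item()
```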