Merge pull request #3163 from oobabooga/dev

v1.2
commit 9f08038864
oobabooga, 2023-07-16 02:43:18 -03:00 (committed by GitHub)
GPG Key ID: 4AEE18F83AFDEB23 (no known key found for this signature in database)
11 changed files with 399 additions and 18 deletions

View File

@@ -62,7 +62,7 @@ class ModelDownloader:
         is_lora = False
         while True:
             url = f"{base}{page}" + (f"?cursor={cursor.decode()}" if cursor else "")
-            r = self.s.get(url, timeout=20)
+            r = self.s.get(url, timeout=10)
             r.raise_for_status()
             content = r.content
@@ -136,7 +136,7 @@ class ModelDownloader:
         if output_path.exists() and not start_from_scratch:

             # Check if the file has already been downloaded completely
-            r = self.s.get(url, stream=True, timeout=20)
+            r = self.s.get(url, stream=True, timeout=10)
             total_size = int(r.headers.get('content-length', 0))
             if output_path.stat().st_size >= total_size:
                 return
@@ -145,7 +145,7 @@ class ModelDownloader:
             headers = {'Range': f'bytes={output_path.stat().st_size}-'}
             mode = 'ab'

-        with self.s.get(url, stream=True, headers=headers, timeout=20) as r:
+        with self.s.get(url, stream=True, headers=headers, timeout=10) as r:
             r.raise_for_status()  # Do not continue the download if the request was unsuccessful
             total_size = int(r.headers.get('content-length', 0))
             block_size = 1024 * 1024  # 1MB
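As a rough standalone sketch of the resume logic these hunks touch (the function name and any URL/filename passed to it are made up, not part of this repo): the existing file size is compared against Content-Length, and a Range header restarts the transfer where it left off, appending in 1MB blocks.

import requests
from pathlib import Path

def resume_download(url, output_path, timeout=10):
    s = requests.Session()
    output_path = Path(output_path)
    headers = {}
    mode = 'wb'
    if output_path.exists():
        # Skip the download entirely if the file is already complete
        r = s.get(url, stream=True, timeout=timeout)
        total_size = int(r.headers.get('content-length', 0))
        if output_path.stat().st_size >= total_size:
            return
        # Otherwise ask the server for the remaining bytes only and append to the partial file
        headers = {'Range': f'bytes={output_path.stat().st_size}-'}
        mode = 'ab'

    with s.get(url, stream=True, headers=headers, timeout=timeout) as r:
        r.raise_for_status()
        with open(output_path, mode) as f:
            for chunk in r.iter_content(chunk_size=1024 * 1024):  # 1MB blocks
                f.write(chunk)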

View File

@@ -0,0 +1,215 @@
import gradio
import torch
from transformers import LogitsProcessor
import numpy as np

from modules import shared

params = {
    'color_by_perplexity': False,
    'color_by_probability': False,
    'ppl_scale': 15.0,  # No slider for this right now, because I don't think it really needs to be changed. Very large perplexity scores don't show up often.
    # 'probability_dropdown': False
}
class PerplexityLogits(LogitsProcessor):
    def __init__(self, verbose=False):
        self.generated_token_ids = []
        self.selected_probs = []
        self.top_token_ids_list = []
        self.top_probs_list = []
        self.perplexities_list = []
        self.last_probs = None
        self.verbose = verbose

    def __call__(self, input_ids, scores):
        probs = torch.softmax(scores, dim=-1, dtype=torch.float)
        log_probs = torch.nan_to_num(torch.log(probs))
        entropy = -torch.sum(probs*log_probs)
        entropy = entropy.cpu().numpy()
        perplexity = round(float(np.exp(entropy)), 4)
        self.perplexities_list.append(perplexity)
        last_token_id = int(input_ids[0][-1].cpu().numpy().item())
        # Store the generated tokens (not sure why this isn't accessible in the output endpoint!)
        self.generated_token_ids.append(last_token_id)
        # Get last probability, and add to the list if it wasn't there
        if len(self.selected_probs) > 0:
            # Is the selected token in the top tokens?
            if self.verbose:
                print(shared.tokenizer.decode(last_token_id))
                print([shared.tokenizer.decode(token_id) for token_id in self.top_token_ids_list[-1]])
                print(self.top_probs_list[-1])
            if last_token_id in self.top_token_ids_list[-1]:
                idx = self.top_token_ids_list[-1].index(last_token_id)
                self.selected_probs.append(self.top_probs_list[-1][idx])
            else:
                self.top_token_ids_list[-1].append(last_token_id)
                last_prob = round(float(self.last_probs[last_token_id]), 4)
                self.top_probs_list[-1].append(last_prob)
                self.selected_probs.append(last_prob)
        else:
            self.selected_probs.append(1.0)  # Placeholder for the last token of the prompt

        if self.verbose:
            pplbar = "-"
            if not np.isnan(perplexity):
                pplbar = "*"*round(perplexity)
            print(f"{shared.tokenizer.decode(last_token_id)}\t{perplexity:.2f}\t{pplbar}")

        # Get top 5 probabilities
        top_tokens_and_probs = torch.topk(probs, 5)
        top_probs = top_tokens_and_probs.values.cpu().numpy().astype(float).tolist()
        top_token_ids = top_tokens_and_probs.indices.cpu().numpy().astype(int).tolist()

        self.top_token_ids_list.append(top_token_ids)
        self.top_probs_list.append(top_probs)

        probs = probs.cpu().numpy().flatten()
        self.last_probs = probs  # Need to keep this as a reference for top probs

        # Doesn't actually modify the logits!
        return scores
# Stores the perplexity and top probabilities
ppl_logits_processor = None


def logits_processor_modifier(logits_processor_list, input_ids):
    global ppl_logits_processor
    ppl_logits_processor = PerplexityLogits()
    logits_processor_list.append(ppl_logits_processor)


def output_modifier(text):
    global ppl_logits_processor

    # TODO: It's probably more efficient to do this above rather than modifying all these lists
    # Remove last element of perplexities_list, top_token_ids_list, top_tokens_list, top_probs_list since everything is off by one because this extension runs before generation
    perplexities = ppl_logits_processor.perplexities_list[:-1]
    top_token_ids_list = ppl_logits_processor.top_token_ids_list[:-1]
    top_tokens_list = [[shared.tokenizer.decode(token_id) for token_id in top_token_ids] for top_token_ids in top_token_ids_list]
    top_probs_list = ppl_logits_processor.top_probs_list[:-1]
    # Remove first element of generated_token_ids, generated_tokens, selected_probs because they are for the last token of the prompt
    gen_token_ids = ppl_logits_processor.generated_token_ids[1:]
    gen_tokens = [shared.tokenizer.decode(token_id) for token_id in gen_token_ids]
    sel_probs = ppl_logits_processor.selected_probs[1:]

    end_part = '</span>'  # Helps with finding the index after replacing part of the text.
    in_code = False  # Since the <span> tags mess up code blocks, avoid coloring while inside a code block, based on finding tokens with '`' in them

    if params['color_by_probability'] and params['color_by_perplexity']:
        i = 0
        for token, prob, ppl, top_tokens, top_probs in zip(gen_tokens, sel_probs, perplexities, top_tokens_list, top_probs_list):
            if '`' in token:
                in_code = not in_code
                continue
            if in_code:
                continue
            color = probability_perplexity_color_scale(prob, ppl)
            if token in text[i:]:
                text = text[:i] + text[i:].replace(token, add_color_html(token, color), 1)
                i += text[i:].find(end_part) + len(end_part)
    elif params['color_by_perplexity']:
        i = 0
        for token, ppl, top_tokens, top_probs in zip(gen_tokens, perplexities, top_tokens_list, top_probs_list):
            if '`' in token:
                in_code = not in_code
                continue
            if in_code:
                continue
            color = perplexity_color_scale(ppl)
            if token in text[i:]:
                text = text[:i] + text[i:].replace(token, add_color_html(token, color), 1)
                i += text[i:].find(end_part) + len(end_part)
    elif params['color_by_probability']:
        i = 0
        for token, prob, top_tokens, top_probs in zip(gen_tokens, sel_probs, top_tokens_list, top_probs_list):
            if '`' in token:
                in_code = not in_code
                continue
            if in_code:
                continue
            color = probability_color_scale(prob)
            if token in text[i:]:
                text = text[:i] + text[i:].replace(token, add_color_html(token, color), 1)
                i += text[i:].find(end_part) + len(end_part)

    print('Average perplexity:', round(np.mean(perplexities), 4))
    return text
# Green-yellow-red color scale
def probability_color_scale(prob):
    rv = 0
    gv = 0
    if prob <= 0.5:
        rv = 'ff'
        gv = hex(int(255*prob*2))[2:]
        if len(gv) < 2:
            gv = '0'*(2 - len(gv)) + gv
    else:
        rv = hex(int(255 - 255*(prob - 0.5)*2))[2:]
        gv = 'ff'
        if len(rv) < 2:
            rv = '0'*(2 - len(rv)) + rv

    return rv + gv + '00'


# Red component only, white for 0 perplexity (sorry if you're not in dark mode)
def perplexity_color_scale(ppl):
    value = hex(max(int(255.0 - params['ppl_scale']*(float(ppl)-1.0)), 0))[2:]
    if len(value) < 2:
        value = '0'*(2 - len(value)) + value

    return 'ff' + value + value


# Green-yellow-red for probability and blue component for perplexity
def probability_perplexity_color_scale(prob, ppl):
    rv = 0
    gv = 0
    bv = hex(min(max(int(params['ppl_scale']*(float(ppl)-1.0)), 0), 255))[2:]
    if len(bv) < 2:
        bv = '0'*(2 - len(bv)) + bv

    if prob <= 0.5:
        rv = 'ff'
        gv = hex(int(255*prob*2))[2:]
        if len(gv) < 2:
            gv = '0'*(2 - len(gv)) + gv
    else:
        rv = hex(int(255 - 255*(prob - 0.5)*2))[2:]
        gv = 'ff'
        if len(rv) < 2:
            rv = '0'*(2 - len(rv)) + rv

    return rv + gv + bv


def add_color_html(token, color):
    return f'<span style="color: #{color}">{token}</span>'
"""
# This is still very broken at the moment, needs CSS too but I'm not very good at CSS (and neither is GPT-4 apparently) so I still need to figure that out.
def add_dropdown_html(token, color, top_tokens, top_probs):
html = f'<span class="hoverable" style="color: #{color}">{token}<div class="dropdown"><table class="dropdown-content">'
for token, prob in zip(top_tokens, top_probs):
# TODO: Background color? Bold for selected token?
# Bigger issue: Why is there a newline after the first token, and the dropdown fails there?
# The HTML ends up like <p><span>word</span></p><div>...</div>,
# even though for all other tokens it shows up correctly.
row_color = probability_color_scale(prob)
html += f'<tr><td style="color: #{row_color}">{token}</td><td style="color: #{row_color}">{prob}</td></tr>'
html += '</table></div></span>'
return html
"""
def ui():
color_by_ppl_check = gradio.Checkbox(value=False, label="Color by perplexity", info="Higher perplexity is more red. If also showing probability, higher perplexity has more blue component.")
def update_color_by_ppl_check(x):
params.update({'color_by_perplexity': x})
color_by_ppl_check.change(update_color_by_ppl_check, color_by_ppl_check, None)
color_by_prob_check = gradio.Checkbox(value=False, label="Color by probability", info="Green-yellow-red linear scale, with 100% green, 50% yellow, 0% red.")
def update_color_by_prob_check(x):
params.update({'color_by_probability': x})
color_by_prob_check.change(update_color_by_prob_check, color_by_prob_check, None)
# Doesn't work yet...
"""
prob_dropdown_check = gradio.Checkbox(value=False, label="Probability dropdown")
def update_prob_dropdown_check(x):
params.update({'probability_dropdown': x})
prob_dropdown_check.change(update_prob_dropdown_check, prob_dropdown_check, None)
"""

View File

@@ -29,6 +29,7 @@ class ExllamaHF(PreTrainedModel):
         super().__init__(PretrainedConfig())
         self.ex_config = config
         self.ex_model = ExLlama(self.ex_config)
+        self.ex_cache = ExLlamaCache(self.ex_model)
         self.generation_config = GenerationConfig()
         self.lora = None
@@ -52,11 +53,20 @@ class ExllamaHF(PreTrainedModel):
         labels = kwargs.get('labels', None)
         seq = kwargs['input_ids'][0].tolist()
         cache = kwargs['past_key_values'] if 'past_key_values' in kwargs else None
-        if cache is None:
-            cache = ExLlamaCache(self.ex_model)
-            self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), cache, preprocess_only=True, lora=self.lora)
-
-        logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache, lora=self.lora).to(kwargs['input_ids'].device)
+        if labels is None:
+            if cache is None:
+                self.ex_cache.current_seq_len = 0
+                cache = self.ex_cache
+                self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), cache, preprocess_only=True, lora=self.lora)
+
+            logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache, lora=self.lora).to(kwargs['input_ids'].device)
+        else:
+            if cache is None:
+                self.ex_cache.current_seq_len = 0
+                cache = self.ex_cache
+
+            logits = self.ex_model.forward(torch.tensor([seq], dtype=torch.long), cache, last_id_only=False, lora=self.lora)

         loss = None
         if labels is not None:
@@ -71,7 +81,7 @@ class ExllamaHF(PreTrainedModel):
             shift_labels = shift_labels.to(shift_logits.device)
             loss = loss_fct(shift_logits, shift_labels)

-        return CausalLMOutputWithPast(logits=logits, past_key_values=cache if use_cache else None)
+        return CausalLMOutputWithPast(logits=logits, past_key_values=cache if use_cache else None, loss=loss)

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):

View File

@@ -106,15 +106,23 @@ def _apply_history_modifier_extensions(history):
     return history


-# Extension functions that override the default tokenizer output - currently only the first one will work
+# Extension functions that override the default tokenizer output - The order of execution is not defined
 def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
     for extension, _ in iterator():
         if hasattr(extension, function_name):
-            return getattr(extension, function_name)(state, prompt, input_ids, input_embeds)
+            prompt, input_ids, input_embeds = getattr(extension, function_name)(state, prompt, input_ids, input_embeds)

     return prompt, input_ids, input_embeds


+# Allow extensions to add their own logits processors to the stack being run.
+# Each extension would call `processor_list.append({their LogitsProcessor}())`.
+def _apply_logits_processor_extensions(function_name, processor_list, input_ids):
+    for extension, _ in iterator():
+        if hasattr(extension, function_name):
+            getattr(extension, function_name)(processor_list, input_ids)
+
+
 # Get prompt length in tokens after applying extension functions which override the default tokenizer output
 # currently only the first one will work
 def _apply_custom_tokenized_length(prompt):
@@ -183,6 +191,7 @@ EXTENSION_MAP = {
     "history": _apply_history_modifier_extensions,
     "bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"),
     "tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"),
+    'logits_processor': partial(_apply_logits_processor_extensions, 'logits_processor_modifier'),
     "input_hijack": _apply_input_hijack,
     "custom_generate_chat_prompt": _apply_custom_generate_chat_prompt,
     "custom_generate_reply": _apply_custom_generate_reply,

modules/llamacpp_hf.py (new file, 103 lines)
View File

@@ -0,0 +1,103 @@
import os
from pathlib import Path
from typing import Any, Dict, Optional, Union

import torch
from llama_cpp import Llama
from torch.nn import CrossEntropyLoss
from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from modules import shared
from modules.logging_colors import logger
class LlamacppHF(PreTrainedModel):
    def __init__(self, model):
        super().__init__(PretrainedConfig())
        self.model = model
        self.generation_config = GenerationConfig()
        self.cache = None

    def _validate_model_class(self):
        pass

    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
        pass

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {'input_ids': input_ids, **kwargs}

    @property
    def device(self) -> torch.device:
        return torch.device(0)

    def __call__(self, *args, **kwargs):
        # TODO: Some decoding methods (such as Contrastive Search) may not work at this time
        assert len(args) == 0, 'no *args should be passed to forward'
        use_cache = kwargs.get('use_cache', True)
        labels = kwargs.get('labels', None)
        seq = kwargs['input_ids'][0].tolist()
        cache = kwargs['past_key_values'] if 'past_key_values' in kwargs else None

        # Make the forward call
        seq_tensor = torch.tensor(seq)
        if labels is None:
            if self.cache is None or not torch.equal(self.cache, seq_tensor[:-1]):
                self.model.reset()
                self.model.eval(seq)
            else:
                self.model.eval([seq[-1]])

            logits = torch.tensor(self.model.eval_logits)[-1].view(1, 1, -1).to(kwargs['input_ids'].device)
        else:
            self.model.reset()
            self.model.eval(seq)
            logits = torch.tensor(self.model.eval_logits)
            logits = logits.view(1, logits.shape[0], logits.shape[1]).to(kwargs['input_ids'].device)

        # Remember the full sequence so the next call can detect a one-token extension and evaluate incrementally
        self.cache = seq_tensor

        # Based on transformers/models/llama/modeling_llama.py
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, logits.shape[-1])
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        return CausalLMOutputWithPast(logits=logits, past_key_values=cache if use_cache else None, loss=loss)
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
        assert len(model_args) == 0 and len(kwargs) == 0, "extra args is currently not supported"

        if isinstance(pretrained_model_name_or_path, str):
            pretrained_model_name_or_path = Path(pretrained_model_name_or_path)

        path = Path(f'{shared.args.model_dir}') / Path(pretrained_model_name_or_path)
        if path.is_file():
            model_file = path
        else:
            model_file = list(path.glob('*ggml*.bin'))[0]

        logger.info(f"llama.cpp weights detected: {model_file}\n")
        params = {
            'model_path': str(model_file),
            'n_ctx': shared.args.n_ctx,
            'seed': int(shared.args.llama_cpp_seed),
            'n_threads': shared.args.threads or None,
            'n_batch': shared.args.n_batch,
            'use_mmap': not shared.args.no_mmap,
            'use_mlock': shared.args.mlock,
            'low_vram': shared.args.low_vram,
            'n_gpu_layers': shared.args.n_gpu_layers,
            'logits_all': True,
        }

        model = Llama(**params)
        return LlamacppHF(model)
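The loss block above follows the usual causal-LM shift ("tokens < n predict n"). A toy check with random tensors, independent of llama.cpp (shapes and sizes chosen arbitrarily for illustration):

import torch
from torch.nn import CrossEntropyLoss

vocab, seq_len = 32, 6
logits = torch.randn(1, seq_len, vocab)          # (batch, seq, vocab), as returned above
labels = torch.randint(0, vocab, (1, seq_len))   # the input ids themselves when evaluating perplexity

shift_logits = logits[..., :-1, :].contiguous()  # predictions at positions 0..n-2
shift_labels = labels[..., 1:].contiguous()      # targets are the next tokens 1..n-1
loss = CrossEntropyLoss()(shift_logits.view(-1, vocab), shift_labels.view(-1))
print(loss.item())                               # roughly log(vocab) for random logits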

View File

@@ -38,6 +38,17 @@ loaders_and_params = {
         'mlock',
         'llama_cpp_seed',
     ],
+    'llamacpp_HF': [
+        'n_ctx',
+        'n_gpu_layers',
+        'n_batch',
+        'threads',
+        'no_mmap',
+        'low_vram',
+        'mlock',
+        'llama_cpp_seed',
+        'llamacpp_HF_info',
+    ],
     'Transformers': [
         'cpu_memory',
         'gpu_memory',

View File

@@ -55,6 +55,7 @@ def load_model(model_name, loader=None):
         'AutoGPTQ': AutoGPTQ_loader,
         'GPTQ-for-LLaMa': GPTQ_loader,
         'llama.cpp': llamacpp_loader,
+        'llamacpp_HF': llamacpp_HF_loader,
         'FlexGen': flexgen_loader,
         'RWKV': RWKV_loader,
         'ExLlama': ExLlama_loader,
@@ -268,6 +269,27 @@ def llamacpp_loader(model_name):
     return model, tokenizer


+def llamacpp_HF_loader(model_name):
+    from modules.llamacpp_hf import LlamacppHF
+
+    for fname in ["oobabooga_llama-tokenizer", "llama-tokenizer"]:
+        path = Path(f'{shared.args.model_dir}/{fname}')
+        if path.exists():
+            break
+    else:
+        logger.error("Could not load the model because a tokenizer in transformers format was not found. Please download oobabooga/llama-tokenizer.")
+        return None, None
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        path,
+        trust_remote_code=shared.args.trust_remote_code,
+        use_fast=False
+    )
+
+    model = LlamacppHF.from_pretrained(model_name)
+    return model, tokenizer
+
+
 def GPTQ_loader(model_name):

     # Monkey patch
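The tokenizer search in llamacpp_HF_loader relies on Python's for/else: the else branch runs only when the loop finishes without hitting break. A small sketch with a placeholder directory name:

from pathlib import Path

candidates = ["oobabooga_llama-tokenizer", "llama-tokenizer"]
for fname in candidates:
    path = Path("models") / fname        # "models" stands in for shared.args.model_dir
    if path.exists():
        break                            # found one; `path` keeps this value after the loop
else:
    path = None                          # no candidate existed; the real loader logs an error and returns None, None
print(path)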

View File

@@ -214,6 +214,8 @@ def fix_loader_name(name):
     name = name.lower()
     if name in ['llamacpp', 'llama.cpp', 'llama-cpp', 'llama cpp']:
         return 'llama.cpp'
+    if name in ['llamacpp_hf', 'llama.cpp_hf', 'llama-cpp-hf', 'llamacpp-hf', 'llama.cpp-hf']:
+        return 'llamacpp_HF'
     elif name in ['transformers', 'huggingface', 'hf', 'hugging_face', 'hugging face']:
         return 'Transformers'
     elif name in ['autogptq', 'auto-gptq', 'auto_gptq', 'auto gptq']:
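With the new branch, the common llama.cpp-HF spellings all collapse to the canonical loader name. A quick hedged check, assuming fix_loader_name is importable from modules.shared:

from modules.shared import fix_loader_name

for alias in ['llamacpp_hf', 'llama.cpp-HF', 'LLaMACPP-HF']:
    assert fix_loader_name(alias) == 'llamacpp_HF'   # all lowercase to one of the listed aliases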

View File

@@ -8,6 +8,7 @@ import traceback
 import numpy as np
 import torch
 import transformers
+from transformers import LogitsProcessorList

 import modules.shared as shared
 from modules.callbacks import (
@@ -264,6 +265,13 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
     generate_params['stopping_criteria'] = transformers.StoppingCriteriaList()
     generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria())

+    processor = state.get('logits_processor', LogitsProcessorList([]))
+    # In case folks just pass in a processor by itself.
+    if type(processor) != LogitsProcessorList:
+        processor = LogitsProcessorList([processor])
+    apply_extensions('logits_processor', processor, input_ids)
+    generate_params['logits_processor'] = processor
+
     t0 = time.time()
     try:
         if not is_chat and not shared.is_seq2seq:
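For reference, a minimal standalone sketch of how a LogitsProcessorList reaches transformers' generate(), which is what generate_params['logits_processor'] does above ('gpt2' is just a placeholder checkpoint):

from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')
input_ids = tokenizer('Hello', return_tensors='pt').input_ids

processor = LogitsProcessorList([])   # extensions would append their processors here
output = model.generate(input_ids, max_new_tokens=8, logits_processor=processor)
print(tokenizer.decode(output[0]))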

View File

@@ -20,10 +20,10 @@ tensorboard
 wandb
 transformers==4.30.2
 git+https://github.com/huggingface/peft@03eb378eb914fbee709ff7c86ba5b1d033b89524
-bitsandbytes==0.40.0; platform_system != "Windows"
-https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.40.0-py3-none-win_amd64.whl; platform_system == "Windows"
-llama-cpp-python==0.1.70; platform_system != "Windows"
-https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.70/llama_cpp_python-0.1.70-cp310-cp310-win_amd64.whl; platform_system == "Windows"
+bitsandbytes==0.40.1.post1; platform_system != "Windows"
+https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.40.1.post1-py3-none-win_amd64.whl; platform_system == "Windows"
+llama-cpp-python==0.1.72; platform_system != "Windows"
+https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.72/llama_cpp_python-0.1.72-cp310-cp310-win_amd64.whl; platform_system == "Windows"
 https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.2.2/auto_gptq-0.2.2+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows"
 https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.2.2/auto_gptq-0.2.2+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
 https://github.com/jllllll/exllama/releases/download/0.0.6/exllama-0.0.6+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows"

View File

@@ -204,7 +204,7 @@ def create_model_menus():
         with gr.Row():
             with gr.Column():
-                shared.gradio['loader'] = gr.Dropdown(label="Model loader", choices=["Transformers", "AutoGPTQ", "GPTQ-for-LLaMa", "ExLlama", "ExLlama_HF", "llama.cpp"], value=None)
+                shared.gradio['loader'] = gr.Dropdown(label="Model loader", choices=["Transformers", "ExLlama_HF", "ExLlama", "AutoGPTQ", "GPTQ-for-LLaMa", "llama.cpp", "llamacpp_HF"], value=None)
                 with gr.Box():
                     with gr.Row():
                         with gr.Column():
@@ -223,11 +223,11 @@ def create_model_menus():
                             shared.gradio['groupsize'] = gr.Dropdown(label="groupsize", choices=["None", 32, 64, 128, 1024], value=str(shared.args.groupsize) if shared.args.groupsize > 0 else "None")
                             shared.gradio['model_type'] = gr.Dropdown(label="model_type", choices=["None", "llama", "opt", "gptj"], value=shared.args.model_type or "None")
                             shared.gradio['pre_layer'] = gr.Slider(label="pre_layer", minimum=0, maximum=100, value=shared.args.pre_layer[0] if shared.args.pre_layer is not None else 0)
-                            shared.gradio['autogptq_info'] = gr.Markdown('On some systems, AutoGPTQ can be 2x slower than GPTQ-for-LLaMa. You can manually select the GPTQ-for-LLaMa loader above.')
+                            shared.gradio['autogptq_info'] = gr.Markdown('* ExLlama_HF is recommended over AutoGPTQ for models derived from LLaMA.')
                             shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7')
                             shared.gradio['max_seq_len'] = gr.Slider(label='max_seq_len', minimum=2048, maximum=16384, step=256, info='Maximum sequence length.', value=shared.args.max_seq_len)
                             shared.gradio['compress_pos_emb'] = gr.Slider(label='compress_pos_emb', minimum=1, maximum=8, step=1, info='Positional embeddings compression factor. Should typically be set to max_seq_len / 2048.', value=shared.args.compress_pos_emb)
-                            shared.gradio['alpha_value'] = gr.Slider(label='alpha_value', minimum=1, maximum=8, step=1, info='Positional embeddings alpha factor for NTK RoPE scaling. Same as above. Use either this or compress_pos_emb, not both.', value=shared.args.alpha_value)
+                            shared.gradio['alpha_value'] = gr.Slider(label='alpha_value', minimum=1, maximum=32, step=1, info='Positional embeddings alpha factor for NTK RoPE scaling. Scaling is not identical to embedding compression. Use either this or compress_pos_emb, not both.', value=shared.args.alpha_value)

                         with gr.Column():
                             shared.gradio['triton'] = gr.Checkbox(label="triton", value=shared.args.triton)
@@ -250,6 +250,7 @@ def create_model_menus():
                             shared.gradio['gptq_for_llama_info'] = gr.Markdown('GPTQ-for-LLaMa is currently 2x faster than AutoGPTQ on some systems. It is installed by default with the one-click installers. Otherwise, it has to be installed manually following the instructions here: [instructions](https://github.com/oobabooga/text-generation-webui/blob/main/docs/GPTQ-models-(4-bit-mode).md#installation-1).')
                             shared.gradio['exllama_info'] = gr.Markdown('For more information, consult the [docs](https://github.com/oobabooga/text-generation-webui/blob/main/docs/ExLlama.md).')
                             shared.gradio['exllama_HF_info'] = gr.Markdown('ExLlama_HF is a wrapper that lets you use ExLlama like a Transformers model, which means it can use the Transformers samplers. It\'s a bit slower than the regular ExLlama.')
+                            shared.gradio['llamacpp_HF_info'] = gr.Markdown('llamacpp_HF is a wrapper that lets you use llama.cpp like a Transformers model, which means it can use the Transformers samplers. It works, but it\'s experimental and slow. Contributions are welcome.\n\nTo use it, make sure to first download oobabooga/llama-tokenizer under "Download custom model or LoRA".')

             with gr.Column():
                 with gr.Row():
@@ -848,7 +849,7 @@ def create_interface():
     # Reset interface event
     shared.gradio['reset_interface'].click(
         set_interface_arguments, gradio('interface_modes_menu', 'extensions_menu', 'bool_menu'), None).then(
-        lambda: None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500); return []}')
+        lambda: None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;padding-top:20%;margin:0;height:100vh;color:lightgray;text-align:center;background:var(--body-background-fill)">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500); return []}')

     shared.gradio['toggle_dark_mode'].click(lambda: None, None, None, _js='() => {document.getElementsByTagName("body")[0].classList.toggle("dark")}')