mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-01-27 04:23:21 +01:00
llama_cpp: add speculative decoding
This commit is contained in:
parent
2344366c9b
commit
d1ce766616
@ -1,4 +1,7 @@
|
||||
import re
|
||||
import io
|
||||
import itertools
|
||||
from contextlib import redirect_stderr
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
@ -52,6 +55,66 @@ def custom_token_ban_logits_processor(token_ids, input_ids, logits):
|
||||
return logits
|
||||
|
||||
|
||||
class LlamaSmallModelDraft(llama_cpp_lib().llama_speculative.LlamaDraftModel):
|
||||
"""
|
||||
modified from https://gist.github.com/acasto/dce5f559fbe5da5ceed2c62db7afc262
|
||||
|
||||
Optimized draft model for speculative decoding.
|
||||
|
||||
Key Changes:
|
||||
- Removed unnecessary prints and I/O overhead.
|
||||
- Using greedy decoding parameters (top_k=1, top_p=1.0) if acceptable.
|
||||
- Using itertools.islice to grab tokens in a single step rather than a loop.
|
||||
- Consider adjusting n_ctx, n_batch, and model quantization to improve performance.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str,
|
||||
num_pred_tokens: int = 5,
|
||||
temperature: float = 0.0,
|
||||
n_ctx: int = 2048,
|
||||
n_batch: int = 512,
|
||||
):
|
||||
# Suppress unwanted stderr output during model load
|
||||
f = io.StringIO()
|
||||
with redirect_stderr(f):
|
||||
self.draft_model = llama_cpp_lib().Llama(
|
||||
model_path=model_path,
|
||||
n_ctx=n_ctx,
|
||||
n_batch=n_batch,
|
||||
n_gpu_layers=-1,
|
||||
verbose=False
|
||||
)
|
||||
self.num_pred_tokens = num_pred_tokens
|
||||
self.temperature = temperature
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
/,
|
||||
**kwargs
|
||||
):
|
||||
# Convert numpy array to list for llama_cpp
|
||||
input_tokens = input_ids.tolist()
|
||||
|
||||
# Generate tokens greedily or with minimal sampling complexity for speed
|
||||
generated = itertools.islice(
|
||||
self.draft_model.original_generate(
|
||||
tokens=input_tokens,
|
||||
temp=self.temperature,
|
||||
top_k=1, # Greedy decoding
|
||||
top_p=1.0, # Greedy decoding
|
||||
reset=True, # Reset state for a fresh decode
|
||||
),
|
||||
self.num_pred_tokens
|
||||
)
|
||||
|
||||
# Collect and convert to a numpy array
|
||||
draft_tokens = np.fromiter(generated, dtype=np.intc, count=self.num_pred_tokens)
|
||||
return draft_tokens
|
||||
|
||||
|
||||
class LlamaCppModel:
|
||||
def __init__(self):
|
||||
self.initialized = False
|
||||
@ -66,6 +129,7 @@ class LlamaCppModel:
|
||||
|
||||
Llama = llama_cpp_lib().Llama
|
||||
LlamaCache = llama_cpp_lib().LlamaCache
|
||||
LlamaPromptLookupDecoding = llama_cpp_lib().llama_speculative.LlamaPromptLookupDecoding
|
||||
|
||||
result = self()
|
||||
cache_capacity = 0
|
||||
@ -85,6 +149,13 @@ class LlamaCppModel:
|
||||
else:
|
||||
tensor_split_list = [float(x) for x in shared.args.tensor_split.strip().split(",")]
|
||||
|
||||
if shared.args.num_pred_tokens > 0 and shared.args.draft_model.endswith(".gguf"):
|
||||
draft_model = LlamaSmallModelDraft(model_path=f"{shared.args.model_dir}/{shared.args.draft_model}", num_pred_tokens=shared.args.num_pred_tokens)
|
||||
elif shared.args.num_pred_tokens > 0:
|
||||
draft_model = LlamaPromptLookupDecoding(num_pred_tokens=shared.args.num_pred_tokens)
|
||||
else:
|
||||
draft_model = None
|
||||
|
||||
params = {
|
||||
'model_path': str(path),
|
||||
'n_ctx': shared.args.n_ctx,
|
||||
@ -101,7 +172,8 @@ class LlamaCppModel:
|
||||
'rope_freq_scale': 1.0 / shared.args.compress_pos_emb,
|
||||
'offload_kqv': not shared.args.no_offload_kqv,
|
||||
'split_mode': 1 if not shared.args.row_split else 2,
|
||||
'flash_attn': shared.args.flash_attn
|
||||
'flash_attn': shared.args.flash_attn,
|
||||
'draft_model': draft_model
|
||||
}
|
||||
|
||||
if shared.args.cache_type != 'fp16':
|
||||
|
@ -48,6 +48,8 @@ loaders_and_params = OrderedDict({
|
||||
'no_mmap',
|
||||
'mlock',
|
||||
'numa',
|
||||
'draft_model',
|
||||
'num_pred_tokens',
|
||||
],
|
||||
'llamacpp_HF': [
|
||||
'n_gpu_layers',
|
||||
|
@ -133,6 +133,8 @@ group.add_argument('--cache-capacity', type=str, help='Maximum cache capacity (l
|
||||
group.add_argument('--row_split', action='store_true', help='Split the model by rows across GPUs. This may improve multi-gpu performance.')
|
||||
group.add_argument('--streaming-llm', action='store_true', help='Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.')
|
||||
group.add_argument('--attention-sink-size', type=int, default=5, help='StreamingLLM: number of sink tokens. Only used if the trimmed prompt does not share a prefix with the old prompt.')
|
||||
group.add_argument('--draft_model', type=str, default="", help='Draft Model for speculative decoding')
|
||||
group.add_argument('--num_pred_tokens', type=int, default=0, help='Number of tokens to predict using prompt lookup decoding or speculative decoding. Set to 0 to disable')
|
||||
group.add_argument('--tokenizer-dir', type=str, help='Load the tokenizer from this folder. Meant to be used with llamacpp_HF through the command-line.')
|
||||
|
||||
# ExLlamaV2
|
||||
|
@ -123,6 +123,8 @@ def list_model_elements():
|
||||
'compute_dtype',
|
||||
'quant_type',
|
||||
'attention_sink_size',
|
||||
'draft_model',
|
||||
'num_pred_tokens',
|
||||
'num_experts_per_token',
|
||||
'tensorcores',
|
||||
'load_in_8bit',
|
||||
|
@ -100,6 +100,8 @@ def create_ui():
|
||||
shared.gradio['compute_dtype'] = gr.Dropdown(label="compute_dtype", choices=["bfloat16", "float16", "float32"], value=shared.args.compute_dtype, info='Used by load-in-4bit.')
|
||||
shared.gradio['quant_type'] = gr.Dropdown(label="quant_type", choices=["nf4", "fp4"], value=shared.args.quant_type, info='Used by load-in-4bit.')
|
||||
shared.gradio['attention_sink_size'] = gr.Number(label="attention_sink_size", value=shared.args.attention_sink_size, precision=0, info='StreamingLLM: number of sink tokens. Only used if the trimmed prompt doesn\'t share a prefix with the old prompt.')
|
||||
shared.gradio['draft_model'] = gr.Dropdown(choices=utils.get_available_models(), value=shared.args.draft_model, label="Draft Model", info='Draft Model for speculative decoding')
|
||||
shared.gradio['num_pred_tokens'] = gr.Number(label="num_pred_tokens", value=shared.args.num_pred_tokens, precision=0, info='Number of tokens to predict using prompt lookup decoding or speculative decoding. Set to 0 to disable')
|
||||
shared.gradio['num_experts_per_token'] = gr.Number(label="Number of experts per token", value=shared.args.num_experts_per_token, info='Only applies to MoE models like Mixtral.')
|
||||
|
||||
with gr.Column():
|
||||
|
Loading…
Reference in New Issue
Block a user