llama_cpp: add speculative decoding

jk 2025-01-16 03:13:40 +01:00
parent 2344366c9b
commit d1ce766616
5 changed files with 81 additions and 1 deletion

View File

@@ -1,4 +1,7 @@
import re
import io
import itertools
from contextlib import redirect_stderr
from functools import partial
import numpy as np
@@ -52,6 +55,66 @@ def custom_token_ban_logits_processor(token_ids, input_ids, logits):
    return logits


class LlamaSmallModelDraft(llama_cpp_lib().llama_speculative.LlamaDraftModel):
    """
    Optimized draft model for speculative decoding.
    Modified from https://gist.github.com/acasto/dce5f559fbe5da5ceed2c62db7afc262

    Key changes relative to the original gist:
    - Removed unnecessary prints and I/O overhead.
    - Uses greedy decoding parameters (top_k=1, top_p=1.0) where acceptable.
    - Uses itertools.islice to grab the draft tokens in a single step rather than a loop.

    Adjusting n_ctx, n_batch, and the draft model quantization can further improve performance.
    """

    def __init__(
        self,
        model_path: str,
        num_pred_tokens: int = 5,
        temperature: float = 0.0,
        n_ctx: int = 2048,
        n_batch: int = 512,
    ):
        # Suppress unwanted stderr output during model load
        f = io.StringIO()
        with redirect_stderr(f):
            self.draft_model = llama_cpp_lib().Llama(
                model_path=model_path,
                n_ctx=n_ctx,
                n_batch=n_batch,
                n_gpu_layers=-1,
                verbose=False
            )

        self.num_pred_tokens = num_pred_tokens
        self.temperature = temperature

    def __call__(
        self,
        input_ids,
        /,
        **kwargs
    ):
        # Convert numpy array to list for llama_cpp
        input_tokens = input_ids.tolist()

        # Generate tokens greedily or with minimal sampling complexity for speed
        generated = itertools.islice(
            self.draft_model.original_generate(
                tokens=input_tokens,
                temp=self.temperature,
                top_k=1,  # Greedy decoding
                top_p=1.0,  # Greedy decoding
                reset=True,  # Reset state for a fresh decode
            ),
            self.num_pred_tokens
        )

        # Collect and convert to a numpy array
        draft_tokens = np.fromiter(generated, dtype=np.intc, count=self.num_pred_tokens)
        return draft_tokens


class LlamaCppModel:
    def __init__(self):
        self.initialized = False
@@ -66,6 +129,7 @@ class LlamaCppModel:
        Llama = llama_cpp_lib().Llama
        LlamaCache = llama_cpp_lib().LlamaCache
        LlamaPromptLookupDecoding = llama_cpp_lib().llama_speculative.LlamaPromptLookupDecoding

        result = self()
        cache_capacity = 0
@@ -85,6 +149,13 @@ class LlamaCppModel:
        else:
            tensor_split_list = [float(x) for x in shared.args.tensor_split.strip().split(",")]

        if shared.args.num_pred_tokens > 0 and shared.args.draft_model.endswith(".gguf"):
            draft_model = LlamaSmallModelDraft(model_path=f"{shared.args.model_dir}/{shared.args.draft_model}", num_pred_tokens=shared.args.num_pred_tokens)
        elif shared.args.num_pred_tokens > 0:
            draft_model = LlamaPromptLookupDecoding(num_pred_tokens=shared.args.num_pred_tokens)
        else:
            draft_model = None

        params = {
            'model_path': str(path),
            'n_ctx': shared.args.n_ctx,
@@ -101,7 +172,8 @@ class LlamaCppModel:
            'rope_freq_scale': 1.0 / shared.args.compress_pos_emb,
            'offload_kqv': not shared.args.no_offload_kqv,
            'split_mode': 1 if not shared.args.row_split else 2,
            'flash_attn': shared.args.flash_attn
            'flash_attn': shared.args.flash_attn,
            'draft_model': draft_model
        }

        if shared.args.cache_type != 'fp16':
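
For reference, the sketch below (not part of this diff) shows roughly how a draft model plugs into llama-cpp-python's speculative decoding outside the webui; the model path and prompt are placeholders.

from llama_cpp import Llama
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding

llm = Llama(
    model_path="models/main-model.gguf",  # placeholder path
    n_ctx=2048,
    n_gpu_layers=-1,
    # draft_model accepts any llama_cpp.llama_speculative.LlamaDraftModel,
    # e.g. the LlamaSmallModelDraft defined above or this built-in
    # prompt-lookup draft that the loader falls back to.
    draft_model=LlamaPromptLookupDecoding(num_pred_tokens=5),
)

completion = llm("Q: What is speculative decoding? A:", max_tokens=64)
print(completion["choices"][0]["text"])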

View File

@@ -48,6 +48,8 @@ loaders_and_params = OrderedDict({
        'no_mmap',
        'mlock',
        'numa',
        'draft_model',
        'num_pred_tokens',
    ],
    'llamacpp_HF': [
        'n_gpu_layers',

View File

@@ -133,6 +133,8 @@ group.add_argument('--cache-capacity', type=str, help='Maximum cache capacity (l
group.add_argument('--row_split', action='store_true', help='Split the model by rows across GPUs. This may improve multi-gpu performance.')
group.add_argument('--streaming-llm', action='store_true', help='Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.')
group.add_argument('--attention-sink-size', type=int, default=5, help='StreamingLLM: number of sink tokens. Only used if the trimmed prompt does not share a prefix with the old prompt.')
group.add_argument('--draft_model', type=str, default="", help='Draft model (a GGUF file in the models folder) to use for speculative decoding.')
group.add_argument('--num_pred_tokens', type=int, default=0, help='Number of tokens to predict with prompt lookup decoding or speculative decoding. Set to 0 to disable.')
group.add_argument('--tokenizer-dir', type=str, help='Load the tokenizer from this folder. Meant to be used with llamacpp_HF through the command-line.')
# ExLlamaV2
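
In practice the two flags are meant to be used together with the llama.cpp loader: for example, python server.py --model main-model.gguf --draft_model draft-model.gguf --num_pred_tokens 5 would enable speculative decoding with a separate draft model, while setting only --num_pred_tokens falls back to prompt lookup decoding (server.py and both filenames here are assumptions for illustration, not taken from this diff).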

View File

@@ -123,6 +123,8 @@ def list_model_elements():
        'compute_dtype',
        'quant_type',
        'attention_sink_size',
        'draft_model',
        'num_pred_tokens',
        'num_experts_per_token',
        'tensorcores',
        'load_in_8bit',

View File

@@ -100,6 +100,8 @@ def create_ui():
shared.gradio['compute_dtype'] = gr.Dropdown(label="compute_dtype", choices=["bfloat16", "float16", "float32"], value=shared.args.compute_dtype, info='Used by load-in-4bit.')
shared.gradio['quant_type'] = gr.Dropdown(label="quant_type", choices=["nf4", "fp4"], value=shared.args.quant_type, info='Used by load-in-4bit.')
shared.gradio['attention_sink_size'] = gr.Number(label="attention_sink_size", value=shared.args.attention_sink_size, precision=0, info='StreamingLLM: number of sink tokens. Only used if the trimmed prompt doesn\'t share a prefix with the old prompt.')
shared.gradio['draft_model'] = gr.Dropdown(choices=utils.get_available_models(), value=shared.args.draft_model, label="Draft Model", info='Draft model (a GGUF file in the models folder) to use for speculative decoding.')
shared.gradio['num_pred_tokens'] = gr.Number(label="num_pred_tokens", value=shared.args.num_pred_tokens, precision=0, info='Number of tokens to predict with prompt lookup decoding or speculative decoding. Set to 0 to disable.')
shared.gradio['num_experts_per_token'] = gr.Number(label="Number of experts per token", value=shared.args.num_experts_per_token, info='Only applies to MoE models like Mixtral.')
with gr.Column():
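
As a closing note, the interface that LlamaSmallModelDraft implements is small: llama-cpp-python passes the current 1-D array of token ids and expects a 1-D array of drafted token ids back, which it then verifies with the main model. The toy subclass below is a minimal sketch of that contract only; it drafts by repeating the last token, which is not useful in practice and is not part of this commit.

import numpy as np
import numpy.typing as npt

from llama_cpp.llama_speculative import LlamaDraftModel


class RepeatLastTokenDraft(LlamaDraftModel):
    """Toy draft model that proposes the last seen token num_pred_tokens times."""

    def __init__(self, num_pred_tokens: int = 5):
        self.num_pred_tokens = num_pred_tokens

    def __call__(self, input_ids: npt.NDArray[np.intc], /, **kwargs) -> npt.NDArray[np.intc]:
        # input_ids holds the tokens seen so far; return the drafted continuation.
        return np.full(self.num_pred_tokens, input_ids[-1], dtype=np.intc)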