Mirror of https://github.com/oobabooga/text-generation-webui.git, synced 2025-01-27 20:43:19 +01:00
llama_cpp: add speculative decoding
commit d1ce766616 (parent 2344366c9b)
@@ -1,4 +1,7 @@
 import re
+import io
+import itertools
+from contextlib import redirect_stderr
 from functools import partial

 import numpy as np
@@ -52,6 +55,66 @@ def custom_token_ban_logits_processor(token_ids, input_ids, logits):
     return logits


+class LlamaSmallModelDraft(llama_cpp_lib().llama_speculative.LlamaDraftModel):
+    """
+    Modified from https://gist.github.com/acasto/dce5f559fbe5da5ceed2c62db7afc262
+
+    Optimized draft model for speculative decoding.
+
+    Key changes:
+    - Removed unnecessary prints and I/O overhead.
+    - Using greedy decoding parameters (top_k=1, top_p=1.0) if acceptable.
+    - Using itertools.islice to grab tokens in a single step rather than a loop.
+    - Consider adjusting n_ctx, n_batch, and model quantization to improve performance.
+    """
+
+    def __init__(
+        self,
+        model_path: str,
+        num_pred_tokens: int = 5,
+        temperature: float = 0.0,
+        n_ctx: int = 2048,
+        n_batch: int = 512,
+    ):
+        # Suppress unwanted stderr output during model load
+        f = io.StringIO()
+        with redirect_stderr(f):
+            self.draft_model = llama_cpp_lib().Llama(
+                model_path=model_path,
+                n_ctx=n_ctx,
+                n_batch=n_batch,
+                n_gpu_layers=-1,
+                verbose=False
+            )
+        self.num_pred_tokens = num_pred_tokens
+        self.temperature = temperature
+
+    def __call__(
+        self,
+        input_ids,
+        /,
+        **kwargs
+    ):
+        # Convert numpy array to list for llama_cpp
+        input_tokens = input_ids.tolist()
+
+        # Generate tokens greedily or with minimal sampling complexity for speed
+        generated = itertools.islice(
+            self.draft_model.original_generate(
+                tokens=input_tokens,
+                temp=self.temperature,
+                top_k=1,     # Greedy decoding
+                top_p=1.0,   # Greedy decoding
+                reset=True,  # Reset state for a fresh decode
+            ),
+            self.num_pred_tokens
+        )
+
+        # Collect and convert to a numpy array
+        draft_tokens = np.fromiter(generated, dtype=np.intc, count=self.num_pred_tokens)
+        return draft_tokens
+
+
 class LlamaCppModel:
     def __init__(self):
         self.initialized = False
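For context, llama-cpp-python drives speculative decoding through the LlamaDraftModel interface: the draft model is a callable that receives the current token ids as a numpy array and returns an array of proposed tokens, which the main model then evaluates and accepts or rejects. Below is a minimal sketch of the same idea against stock llama-cpp-python, outside the webui (it calls Llama.generate directly instead of the webui's original_generate wrapper, and the model paths are placeholders):

import itertools

import numpy as np
from llama_cpp import Llama
from llama_cpp.llama_speculative import LlamaDraftModel


class SmallModelDraft(LlamaDraftModel):
    """Propose tokens with a small GGUF model; the main model verifies them."""

    def __init__(self, model_path: str, num_pred_tokens: int = 5):
        self.draft = Llama(model_path=model_path, n_ctx=2048, n_gpu_layers=-1, verbose=False)
        self.num_pred_tokens = num_pred_tokens

    def __call__(self, input_ids: np.ndarray, /, **kwargs) -> np.ndarray:
        # Cheap greedy proposals are enough: tokens the main model disagrees
        # with are simply rejected, so draft quality affects speed, not output.
        stream = self.draft.generate(
            tokens=input_ids.tolist(),
            temp=0.0,
            top_k=1,
            top_p=1.0,
            reset=True,
        )
        return np.fromiter(itertools.islice(stream, self.num_pred_tokens), dtype=np.intc)


# Placeholder filenames; the two models need compatible vocabularies.
llm = Llama(
    model_path="models/big-model.gguf",
    draft_model=SmallModelDraft("models/small-model.gguf"),
    n_gpu_layers=-1,
    verbose=False,
)
print(llm.create_completion("Speculative decoding is", max_tokens=64)["choices"][0]["text"])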
@@ -66,6 +129,7 @@ class LlamaCppModel:

         Llama = llama_cpp_lib().Llama
         LlamaCache = llama_cpp_lib().LlamaCache
+        LlamaPromptLookupDecoding = llama_cpp_lib().llama_speculative.LlamaPromptLookupDecoding

         result = self()
         cache_capacity = 0
@@ -85,6 +149,13 @@ class LlamaCppModel:
         else:
             tensor_split_list = [float(x) for x in shared.args.tensor_split.strip().split(",")]

+        if shared.args.num_pred_tokens > 0 and shared.args.draft_model.endswith(".gguf"):
+            draft_model = LlamaSmallModelDraft(model_path=f"{shared.args.model_dir}/{shared.args.draft_model}", num_pred_tokens=shared.args.num_pred_tokens)
+        elif shared.args.num_pred_tokens > 0:
+            draft_model = LlamaPromptLookupDecoding(num_pred_tokens=shared.args.num_pred_tokens)
+        else:
+            draft_model = None
+
         params = {
             'model_path': str(path),
             'n_ctx': shared.args.n_ctx,
@@ -101,7 +172,8 @@ class LlamaCppModel:
             'rope_freq_scale': 1.0 / shared.args.compress_pos_emb,
             'offload_kqv': not shared.args.no_offload_kqv,
             'split_mode': 1 if not shared.args.row_split else 2,
-            'flash_attn': shared.args.flash_attn
+            'flash_attn': shared.args.flash_attn,
+            'draft_model': draft_model
         }

         if shared.args.cache_type != 'fp16':
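The params dict, including the new draft_model entry, feeds the llama_cpp Llama constructor, so when num_pred_tokens is set but no separate GGUF draft model is configured the loader falls back to prompt lookup decoding, which drafts candidate tokens by matching n-grams in the prompt itself rather than running a second model. A minimal sketch of that fallback path using llama-cpp-python directly (the model path and prompt are placeholders):

from llama_cpp import Llama
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding

llm = Llama(
    model_path="models/main-model.gguf",  # placeholder
    n_ctx=4096,
    n_gpu_layers=-1,
    # Draft tokens come from n-gram matches in the prompt; no extra model is loaded.
    draft_model=LlamaPromptLookupDecoding(num_pred_tokens=5),
    verbose=False,
)

out = llm.create_completion("def remove_non_ascii(s: str) -> str:", max_tokens=128)
print(out["choices"][0]["text"])

Prompt lookup tends to pay off most on input-grounded tasks such as summarization or code editing, where the output repeats long spans of the prompt.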
@@ -48,6 +48,8 @@ loaders_and_params = OrderedDict({
         'no_mmap',
         'mlock',
         'numa',
+        'draft_model',
+        'num_pred_tokens',
     ],
     'llamacpp_HF': [
         'n_gpu_layers',
@@ -133,6 +133,8 @@ group.add_argument('--cache-capacity', type=str, help='Maximum cache capacity (l
 group.add_argument('--row_split', action='store_true', help='Split the model by rows across GPUs. This may improve multi-gpu performance.')
 group.add_argument('--streaming-llm', action='store_true', help='Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.')
 group.add_argument('--attention-sink-size', type=int, default=5, help='StreamingLLM: number of sink tokens. Only used if the trimmed prompt does not share a prefix with the old prompt.')
+group.add_argument('--draft_model', type=str, default="", help='Draft Model for speculative decoding')
+group.add_argument('--num_pred_tokens', type=int, default=0, help='Number of tokens to predict using prompt lookup decoding or speculative decoding. Set to 0 to disable')
 group.add_argument('--tokenizer-dir', type=str, help='Load the tokenizer from this folder. Meant to be used with llamacpp_HF through the command-line.')

 # ExLlamaV2
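As a usage sketch (the server.py entry point and the GGUF filenames are assumptions; both files are expected under the models directory), speculative decoding with a small draft model would then be enabled with something like:

python server.py --model Llama-3-70B-Instruct-Q4_K_M.gguf --draft_model Llama-3-8B-Instruct-Q4_K_M.gguf --num_pred_tokens 5

Leaving --draft_model empty while keeping --num_pred_tokens above 0 selects the prompt-lookup path instead, per the loader logic above.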
@@ -123,6 +123,8 @@ def list_model_elements():
         'compute_dtype',
         'quant_type',
         'attention_sink_size',
+        'draft_model',
+        'num_pred_tokens',
         'num_experts_per_token',
         'tensorcores',
         'load_in_8bit',
@@ -100,6 +100,8 @@ def create_ui():
             shared.gradio['compute_dtype'] = gr.Dropdown(label="compute_dtype", choices=["bfloat16", "float16", "float32"], value=shared.args.compute_dtype, info='Used by load-in-4bit.')
             shared.gradio['quant_type'] = gr.Dropdown(label="quant_type", choices=["nf4", "fp4"], value=shared.args.quant_type, info='Used by load-in-4bit.')
             shared.gradio['attention_sink_size'] = gr.Number(label="attention_sink_size", value=shared.args.attention_sink_size, precision=0, info='StreamingLLM: number of sink tokens. Only used if the trimmed prompt doesn\'t share a prefix with the old prompt.')
+            shared.gradio['draft_model'] = gr.Dropdown(choices=utils.get_available_models(), value=shared.args.draft_model, label="Draft Model", info='Draft Model for speculative decoding')
+            shared.gradio['num_pred_tokens'] = gr.Number(label="num_pred_tokens", value=shared.args.num_pred_tokens, precision=0, info='Number of tokens to predict using prompt lookup decoding or speculative decoding. Set to 0 to disable')
             shared.gradio['num_experts_per_token'] = gr.Number(label="Number of experts per token", value=shared.args.num_experts_per_token, info='Only applies to MoE models like Mixtral.')

             with gr.Column():