diff --git a/modules/llamacpp_model.py b/modules/llamacpp_model.py
index c79755e4..2de0ad00 100644
--- a/modules/llamacpp_model.py
+++ b/modules/llamacpp_model.py
@@ -1,4 +1,7 @@
 import re
+import io
+import itertools
+from contextlib import redirect_stderr
 from functools import partial
 
 import numpy as np
@@ -52,6 +55,66 @@ def custom_token_ban_logits_processor(token_ids, input_ids, logits):
     return logits
 
 
+class LlamaSmallModelDraft(llama_cpp_lib().llama_speculative.LlamaDraftModel):
+    """
+    modified from https://gist.github.com/acasto/dce5f559fbe5da5ceed2c62db7afc262
+
+    Optimized draft model for speculative decoding.
+
+    Key Changes:
+    - Removed unnecessary prints and I/O overhead.
+    - Using greedy decoding parameters (top_k=1, top_p=1.0) if acceptable.
+    - Using itertools.islice to grab tokens in a single step rather than a loop.
+    - Consider adjusting n_ctx, n_batch, and model quantization to improve performance.
+    """
+
+    def __init__(
+        self,
+        model_path: str,
+        num_pred_tokens: int = 5,
+        temperature: float = 0.0,
+        n_ctx: int = 2048,
+        n_batch: int = 512,
+    ):
+        # Suppress unwanted stderr output during model load
+        f = io.StringIO()
+        with redirect_stderr(f):
+            self.draft_model = llama_cpp_lib().Llama(
+                model_path=model_path,
+                n_ctx=n_ctx,
+                n_batch=n_batch,
+                n_gpu_layers=-1,
+                verbose=False
+            )
+        self.num_pred_tokens = num_pred_tokens
+        self.temperature = temperature
+
+    def __call__(
+        self,
+        input_ids,
+        /,
+        **kwargs
+    ):
+        # Convert numpy array to list for llama_cpp
+        input_tokens = input_ids.tolist()
+
+        # Generate tokens greedily or with minimal sampling complexity for speed
+        generated = itertools.islice(
+            self.draft_model.original_generate(
+                tokens=input_tokens,
+                temp=self.temperature,
+                top_k=1,  # Greedy decoding
+                top_p=1.0,  # Greedy decoding
+                reset=True,  # Reset state for a fresh decode
+            ),
+            self.num_pred_tokens
+        )
+
+        # Collect and convert to a numpy array
+        draft_tokens = np.fromiter(generated, dtype=np.intc, count=self.num_pred_tokens)
+        return draft_tokens
+
+
 class LlamaCppModel:
     def __init__(self):
         self.initialized = False
@@ -66,6 +129,7 @@ class LlamaCppModel:
 
         Llama = llama_cpp_lib().Llama
         LlamaCache = llama_cpp_lib().LlamaCache
+        LlamaPromptLookupDecoding = llama_cpp_lib().llama_speculative.LlamaPromptLookupDecoding
 
         result = self()
         cache_capacity = 0
@@ -85,6 +149,13 @@ class LlamaCppModel:
         else:
             tensor_split_list = [float(x) for x in shared.args.tensor_split.strip().split(",")]
 
+        if shared.args.num_pred_tokens > 0 and shared.args.draft_model.endswith(".gguf"):
+            draft_model = LlamaSmallModelDraft(model_path=f"{shared.args.model_dir}/{shared.args.draft_model}", num_pred_tokens=shared.args.num_pred_tokens)
+        elif shared.args.num_pred_tokens > 0:
+            draft_model = LlamaPromptLookupDecoding(num_pred_tokens=shared.args.num_pred_tokens)
+        else:
+            draft_model = None
+
         params = {
             'model_path': str(path),
             'n_ctx': shared.args.n_ctx,
@@ -101,7 +172,8 @@ class LlamaCppModel:
             'rope_freq_scale': 1.0 / shared.args.compress_pos_emb,
             'offload_kqv': not shared.args.no_offload_kqv,
             'split_mode': 1 if not shared.args.row_split else 2,
-            'flash_attn': shared.args.flash_attn
+            'flash_attn': shared.args.flash_attn,
+            'draft_model': draft_model
         }
 
         if shared.args.cache_type != 'fp16':
diff --git a/modules/loaders.py b/modules/loaders.py
index cd864e40..1801809f 100644
--- a/modules/loaders.py
+++ b/modules/loaders.py
@@ -48,6 +48,8 @@ loaders_and_params = OrderedDict({
         'no_mmap',
         'mlock',
         'numa',
+        'draft_model',
+        'num_pred_tokens',
     ],
     'llamacpp_HF': [
         'n_gpu_layers',
diff --git a/modules/shared.py b/modules/shared.py
index f1e12673..a3b83b59 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -133,6 +133,8 @@ group.add_argument('--cache-capacity', type=str, help='Maximum cache capacity (l
 group.add_argument('--row_split', action='store_true', help='Split the model by rows across GPUs. This may improve multi-gpu performance.')
 group.add_argument('--streaming-llm', action='store_true', help='Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.')
 group.add_argument('--attention-sink-size', type=int, default=5, help='StreamingLLM: number of sink tokens. Only used if the trimmed prompt does not share a prefix with the old prompt.')
+group.add_argument('--draft_model', type=str, default="", help='Draft Model for speculative decoding')
+group.add_argument('--num_pred_tokens', type=int, default=0, help='Number of tokens to predict using prompt lookup decoding or speculative decoding. Set to 0 to disable')
 group.add_argument('--tokenizer-dir', type=str, help='Load the tokenizer from this folder. Meant to be used with llamacpp_HF through the command-line.')
 
 # ExLlamaV2
diff --git a/modules/ui.py b/modules/ui.py
index df948a14..68164066 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -123,6 +123,8 @@ def list_model_elements():
         'compute_dtype',
         'quant_type',
         'attention_sink_size',
+        'draft_model',
+        'num_pred_tokens',
         'num_experts_per_token',
         'tensorcores',
         'load_in_8bit',
diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py
index d5116938..63fe0ad3 100644
--- a/modules/ui_model_menu.py
+++ b/modules/ui_model_menu.py
@@ -100,6 +100,8 @@ def create_ui():
                             shared.gradio['compute_dtype'] = gr.Dropdown(label="compute_dtype", choices=["bfloat16", "float16", "float32"], value=shared.args.compute_dtype, info='Used by load-in-4bit.')
                             shared.gradio['quant_type'] = gr.Dropdown(label="quant_type", choices=["nf4", "fp4"], value=shared.args.quant_type, info='Used by load-in-4bit.')
                             shared.gradio['attention_sink_size'] = gr.Number(label="attention_sink_size", value=shared.args.attention_sink_size, precision=0, info='StreamingLLM: number of sink tokens. Only used if the trimmed prompt doesn\'t share a prefix with the old prompt.')
+                            shared.gradio['draft_model'] = gr.Dropdown(choices=utils.get_available_models(), value=shared.args.draft_model, label="Draft Model", info='Draft Model for speculative decoding')
+                            shared.gradio['num_pred_tokens'] = gr.Number(label="num_pred_tokens", value=shared.args.num_pred_tokens, precision=0, info='Number of tokens to predict using prompt lookup decoding or speculative decoding. Set to 0 to disable')
                             shared.gradio['num_experts_per_token'] = gr.Number(label="Number of experts per token", value=shared.args.num_experts_per_token, info='Only applies to MoE models like Mixtral.')
 
                         with gr.Column():
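
Note on how the new `draft_model` plumbing is consumed downstream: llama-cpp-python's `Llama(...)` constructor accepts a `draft_model` argument, which can be either the built-in `LlamaPromptLookupDecoding` (no second model required) or any `LlamaDraftModel` subclass such as the `LlamaSmallModelDraft` added above. The sketch below is illustrative only and not part of the patch; it assumes a llama-cpp-python build that ships the `llama_cpp.llama_speculative` module, and the model path and token IDs are placeholders.

```python
# Minimal sketch of llama-cpp-python's speculative-decoding hook (assumed API,
# not part of this patch). The model path and token IDs are placeholders.
import numpy as np
from llama_cpp import Llama
from llama_cpp.llama_speculative import LlamaDraftModel, LlamaPromptLookupDecoding


class FixedDraft(LlamaDraftModel):
    """Toy drafter that always proposes the same tokens; it only demonstrates the
    interface that LlamaSmallModelDraft implements with a real small model."""

    def __call__(self, input_ids: np.ndarray, /, **kwargs) -> np.ndarray:
        # A real drafter returns a cheap continuation of input_ids as np.intc tokens
        return np.array([1, 2, 3], dtype=np.intc)


# Prompt lookup decoding drafts tokens by matching n-grams of the prompt, so no
# second model is loaded; this is the path taken when only --num_pred_tokens is set.
llm = Llama(
    model_path="models/placeholder-7b.Q4_K_M.gguf",            # placeholder path
    draft_model=LlamaPromptLookupDecoding(num_pred_tokens=5),   # or FixedDraft(), or a small-GGUF drafter
    n_gpu_layers=-1,
)
print(llm("The quick brown fox", max_tokens=32)["choices"][0]["text"])
```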
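With the patch applied, the selection in `LlamaCppModel.from_pretrained` works as follows: `--num_pred_tokens 0` (the default) disables speculation; `--num_pred_tokens N` alone enables prompt lookup decoding; and `--num_pred_tokens N` combined with `--draft_model <file>.gguf` loads that file from `--model-dir` as a `LlamaSmallModelDraft`. For example (model names are placeholders): `python server.py --loader llama.cpp --model main-model.Q4_K_M.gguf --draft_model small-model.Q4_K_M.gguf --num_pred_tokens 5`.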