From 86c320ab5aee12864619837e67d6f06325ca62d9 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Wed, 7 Feb 2024 21:40:58 -0800
Subject: [PATCH] llama.cpp: add a progress bar for prompt evaluation

---
 modules/llama_cpp_python_hijack.py | 63 ++++++++++++++++++++++++++++++
 modules/llamacpp_hf.py             |  2 +-
 modules/llamacpp_model.py          |  2 +-
 3 files changed, 65 insertions(+), 2 deletions(-)
 create mode 100644 modules/llama_cpp_python_hijack.py

diff --git a/modules/llama_cpp_python_hijack.py b/modules/llama_cpp_python_hijack.py
new file mode 100644
index 00000000..e63d9977
--- /dev/null
+++ b/modules/llama_cpp_python_hijack.py
@@ -0,0 +1,63 @@
+from typing import Sequence
+
+from tqdm import tqdm
+
+try:
+    import llama_cpp
+except:
+    llama_cpp = None
+
+try:
+    import llama_cpp_cuda
+except:
+    llama_cpp_cuda = None
+
+try:
+    import llama_cpp_cuda_tensorcores
+except:
+    llama_cpp_cuda_tensorcores = None
+
+
+def eval_with_progress(self, tokens: Sequence[int]):
+    """
+    A copy of
+
+    https://github.com/abetlen/llama-cpp-python/blob/main/llama_cpp/llama.py
+
+    with tqdm to show prompt processing progress.
+    """
+    assert self._ctx.ctx is not None
+    assert self._batch.batch is not None
+    self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
+
+    if len(tokens) > 1:
+        progress_bar = tqdm(range(0, len(tokens), self.n_batch), desc="Prompt evaluation", leave=False)
+    else:
+        progress_bar = range(0, len(tokens), self.n_batch)
+
+    for i in progress_bar:
+        batch = tokens[i : min(len(tokens), i + self.n_batch)]
+        n_past = self.n_tokens
+        n_tokens = len(batch)
+        self._batch.set_batch(
+            batch=batch, n_past=n_past, logits_all=self.context_params.logits_all
+        )
+        self._ctx.decode(self._batch)
+        # Save tokens
+        self.input_ids[n_past : n_past + n_tokens] = batch
+        # Save logits
+        rows = n_tokens
+        cols = self._n_vocab
+        offset = (
+            0 if self.context_params.logits_all else n_tokens - 1
+        )  # NOTE: Only save the last token logits if logits_all is False
+        self.scores[n_past + offset : n_past + n_tokens, :].reshape(-1)[
+            :
+        ] = self._ctx.get_logits()[offset * cols : rows * cols]
+        # Update n_tokens
+        self.n_tokens += n_tokens
+
+
+for lib in [llama_cpp, llama_cpp_cuda, llama_cpp_cuda_tensorcores]:
+    if lib is not None:
+        lib.Llama.eval = eval_with_progress
diff --git a/modules/llamacpp_hf.py b/modules/llamacpp_hf.py
index 4726669b..e7e86e0b 100644
--- a/modules/llamacpp_hf.py
+++ b/modules/llamacpp_hf.py
@@ -7,7 +7,7 @@ from torch.nn import CrossEntropyLoss
 from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithPast
 
-from modules import RoPE, shared
+from modules import RoPE, llama_cpp_python_hijack, shared
 from modules.logging_colors import logger
 
 try:
diff --git a/modules/llamacpp_model.py b/modules/llamacpp_model.py
index 7c405a4b..8bc9b7cb 100644
--- a/modules/llamacpp_model.py
+++ b/modules/llamacpp_model.py
@@ -4,7 +4,7 @@ from functools import partial
 import numpy as np
 import torch
 
-from modules import RoPE, shared
+from modules import RoPE, llama_cpp_python_hijack, shared
 from modules.callbacks import Iteratorize
 from modules.logging_colors import logger
 from modules.text_generation import get_max_prompt_length
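
The patch works entirely by import side effect: the one-line changes to modules/llamacpp_hf.py and modules/llamacpp_model.py pull in llama_cpp_python_hijack, whose module-level loop rebinds Llama.eval on every installed llama-cpp-python variant. The snippet below is a minimal sketch of how that override could be verified, not part of the patch; it assumes the text-generation-webui source tree is on the import path and that at least the CPU llama_cpp wheel is installed.

    # Minimal sketch (hypothetical, not from the patch): importing the hijack
    # module runs its module-level loop, which replaces Llama.eval on each
    # available llama-cpp-python backend with the tqdm-wrapped version.
    import llama_cpp  # assumes the CPU wheel is installed

    from modules import llama_cpp_python_hijack

    # In Python 3, accessing a function through its class returns the function
    # itself, so this identity check confirms the override took effect.
    assert llama_cpp.Llama.eval is llama_cpp_python_hijack.eval_with_progress

Because the override happens at import time, it applies before any Llama instance is constructed by either loader.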