From 4d9ce586d31450d2ed692b1e7a72dbcdfd56e670 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 30 Sep 2024 14:04:21 -0700 Subject: [PATCH] Update llama_cpp_python_hijack.py, fix llamacpp_hf --- modules/llama_cpp_python_hijack.py | 19 ++++++++++++------- modules/llamacpp_hf.py | 3 ++- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/modules/llama_cpp_python_hijack.py b/modules/llama_cpp_python_hijack.py index 3d42b2d7..2a9c10da 100644 --- a/modules/llama_cpp_python_hijack.py +++ b/modules/llama_cpp_python_hijack.py @@ -2,12 +2,12 @@ import importlib import platform from typing import Sequence +import numpy as np from tqdm import tqdm from modules import shared from modules.cache_utils import process_llamacpp_cache - imported_module = None @@ -57,8 +57,6 @@ def eval_with_progress(self, tokens: Sequence[int]): with tqdm to show prompt processing progress. """ - assert self._ctx.ctx is not None - assert self._batch.batch is not None self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1) if len(tokens) > self.n_batch: @@ -80,13 +78,20 @@ def eval_with_progress(self, tokens: Sequence[int]): if self.context_params.logits_all: rows = n_tokens cols = self._n_vocab - logits = self._ctx.get_logits()[: rows * cols] - self.scores[n_past : n_past + n_tokens, :].reshape(-1)[: :] = logits + logits = np.ctypeslib.as_array( + self._ctx.get_logits(), shape=(rows * cols,) + ) + self.scores[n_past : n_past + n_tokens, :].reshape(-1)[::] = logits + self.last_updated_index = n_past + n_tokens - 1 else: rows = 1 cols = self._n_vocab - logits = self._ctx.get_logits()[: rows * cols] - self.scores[n_past + n_tokens - 1, :].reshape(-1)[: :] = logits + logits = np.ctypeslib.as_array( + self._ctx.get_logits(), shape=(rows * cols,) + ) + last_token_index = min(n_past + n_tokens - 1, self.scores.shape[0] - 1) + self.scores[last_token_index, :] = logits.reshape(-1) + self.last_updated_index = last_token_index # Update n_tokens self.n_tokens += n_tokens diff --git a/modules/llamacpp_hf.py b/modules/llamacpp_hf.py index 327e3a7b..6611a7c1 100644 --- a/modules/llamacpp_hf.py +++ b/modules/llamacpp_hf.py @@ -127,7 +127,7 @@ class LlamacppHF(PreTrainedModel): self.model.reset() self.model.eval(seq) - logits = torch.tensor(self.model.scores[self.model.n_tokens - 1, :]).view(1, 1, -1).to(input_ids.device) + logits = torch.tensor(self.model.scores[self.model.last_updated_index, :]).view(1, 1, -1).to(input_ids.device) else: self.model.reset() self.model.eval(seq) @@ -205,5 +205,6 @@ class LlamacppHF(PreTrainedModel): Llama = llama_cpp_lib().Llama model = Llama(**params) + model.last_updated_index = -1 return LlamacppHF(model, model_file)