Update llama_cpp_python_hijack.py, fix llamacpp_hf

This commit is contained in:
oobabooga 2024-09-30 14:04:21 -07:00
parent 9ca0cd7749
commit 4d9ce586d3
2 changed files with 14 additions and 8 deletions

View File

@ -2,12 +2,12 @@ import importlib
import platform import platform
from typing import Sequence from typing import Sequence
import numpy as np
from tqdm import tqdm from tqdm import tqdm
from modules import shared from modules import shared
from modules.cache_utils import process_llamacpp_cache from modules.cache_utils import process_llamacpp_cache
imported_module = None imported_module = None
@ -57,8 +57,6 @@ def eval_with_progress(self, tokens: Sequence[int]):
with tqdm to show prompt processing progress. with tqdm to show prompt processing progress.
""" """
assert self._ctx.ctx is not None
assert self._batch.batch is not None
self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1) self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
if len(tokens) > self.n_batch: if len(tokens) > self.n_batch:
@ -80,13 +78,20 @@ def eval_with_progress(self, tokens: Sequence[int]):
if self.context_params.logits_all: if self.context_params.logits_all:
rows = n_tokens rows = n_tokens
cols = self._n_vocab cols = self._n_vocab
logits = self._ctx.get_logits()[: rows * cols] logits = np.ctypeslib.as_array(
self.scores[n_past : n_past + n_tokens, :].reshape(-1)[: :] = logits self._ctx.get_logits(), shape=(rows * cols,)
)
self.scores[n_past : n_past + n_tokens, :].reshape(-1)[::] = logits
self.last_updated_index = n_past + n_tokens - 1
else: else:
rows = 1 rows = 1
cols = self._n_vocab cols = self._n_vocab
logits = self._ctx.get_logits()[: rows * cols] logits = np.ctypeslib.as_array(
self.scores[n_past + n_tokens - 1, :].reshape(-1)[: :] = logits self._ctx.get_logits(), shape=(rows * cols,)
)
last_token_index = min(n_past + n_tokens - 1, self.scores.shape[0] - 1)
self.scores[last_token_index, :] = logits.reshape(-1)
self.last_updated_index = last_token_index
# Update n_tokens # Update n_tokens
self.n_tokens += n_tokens self.n_tokens += n_tokens

View File

@ -127,7 +127,7 @@ class LlamacppHF(PreTrainedModel):
self.model.reset() self.model.reset()
self.model.eval(seq) self.model.eval(seq)
logits = torch.tensor(self.model.scores[self.model.n_tokens - 1, :]).view(1, 1, -1).to(input_ids.device) logits = torch.tensor(self.model.scores[self.model.last_updated_index, :]).view(1, 1, -1).to(input_ids.device)
else: else:
self.model.reset() self.model.reset()
self.model.eval(seq) self.model.eval(seq)
@ -205,5 +205,6 @@ class LlamacppHF(PreTrainedModel):
Llama = llama_cpp_lib().Llama Llama = llama_cpp_lib().Llama
model = Llama(**params) model = Llama(**params)
model.last_updated_index = -1
return LlamacppHF(model, model_file) return LlamacppHF(model, model_file)