Update llama_cpp_python_hijack.py, fix llamacpp_hf

This commit is contained in:
oobabooga 2024-09-30 14:04:21 -07:00
parent 9ca0cd7749
commit 4d9ce586d3
2 changed files with 14 additions and 8 deletions

View File

@ -2,12 +2,12 @@ import importlib
import platform
from typing import Sequence
import numpy as np
from tqdm import tqdm
from modules import shared
from modules.cache_utils import process_llamacpp_cache
imported_module = None
@ -57,8 +57,6 @@ def eval_with_progress(self, tokens: Sequence[int]):
with tqdm to show prompt processing progress.
"""
assert self._ctx.ctx is not None
assert self._batch.batch is not None
self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
if len(tokens) > self.n_batch:
@ -80,13 +78,20 @@ def eval_with_progress(self, tokens: Sequence[int]):
if self.context_params.logits_all:
rows = n_tokens
cols = self._n_vocab
logits = self._ctx.get_logits()[: rows * cols]
logits = np.ctypeslib.as_array(
self._ctx.get_logits(), shape=(rows * cols,)
)
self.scores[n_past : n_past + n_tokens, :].reshape(-1)[::] = logits
self.last_updated_index = n_past + n_tokens - 1
else:
rows = 1
cols = self._n_vocab
logits = self._ctx.get_logits()[: rows * cols]
self.scores[n_past + n_tokens - 1, :].reshape(-1)[: :] = logits
logits = np.ctypeslib.as_array(
self._ctx.get_logits(), shape=(rows * cols,)
)
last_token_index = min(n_past + n_tokens - 1, self.scores.shape[0] - 1)
self.scores[last_token_index, :] = logits.reshape(-1)
self.last_updated_index = last_token_index
# Update n_tokens
self.n_tokens += n_tokens

View File

@ -127,7 +127,7 @@ class LlamacppHF(PreTrainedModel):
self.model.reset()
self.model.eval(seq)
logits = torch.tensor(self.model.scores[self.model.n_tokens - 1, :]).view(1, 1, -1).to(input_ids.device)
logits = torch.tensor(self.model.scores[self.model.last_updated_index, :]).view(1, 1, -1).to(input_ids.device)
else:
self.model.reset()
self.model.eval(seq)
@ -205,5 +205,6 @@ class LlamacppHF(PreTrainedModel):
Llama = llama_cpp_lib().Llama
model = Llama(**params)
model.last_updated_index = -1
return LlamacppHF(model, model_file)