mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-23 16:38:21 +01:00
Update llama_cpp_python_hijack.py, fix llamacpp_hf
This commit is contained in:
parent
9ca0cd7749
commit
4d9ce586d3
@ -2,12 +2,12 @@ import importlib
|
||||
import platform
|
||||
from typing import Sequence
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from modules import shared
|
||||
from modules.cache_utils import process_llamacpp_cache
|
||||
|
||||
|
||||
imported_module = None
|
||||
|
||||
|
||||
@ -57,8 +57,6 @@ def eval_with_progress(self, tokens: Sequence[int]):
|
||||
|
||||
with tqdm to show prompt processing progress.
|
||||
"""
|
||||
assert self._ctx.ctx is not None
|
||||
assert self._batch.batch is not None
|
||||
self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
|
||||
|
||||
if len(tokens) > self.n_batch:
|
||||
@ -80,13 +78,20 @@ def eval_with_progress(self, tokens: Sequence[int]):
|
||||
if self.context_params.logits_all:
|
||||
rows = n_tokens
|
||||
cols = self._n_vocab
|
||||
logits = self._ctx.get_logits()[: rows * cols]
|
||||
self.scores[n_past : n_past + n_tokens, :].reshape(-1)[: :] = logits
|
||||
logits = np.ctypeslib.as_array(
|
||||
self._ctx.get_logits(), shape=(rows * cols,)
|
||||
)
|
||||
self.scores[n_past : n_past + n_tokens, :].reshape(-1)[::] = logits
|
||||
self.last_updated_index = n_past + n_tokens - 1
|
||||
else:
|
||||
rows = 1
|
||||
cols = self._n_vocab
|
||||
logits = self._ctx.get_logits()[: rows * cols]
|
||||
self.scores[n_past + n_tokens - 1, :].reshape(-1)[: :] = logits
|
||||
logits = np.ctypeslib.as_array(
|
||||
self._ctx.get_logits(), shape=(rows * cols,)
|
||||
)
|
||||
last_token_index = min(n_past + n_tokens - 1, self.scores.shape[0] - 1)
|
||||
self.scores[last_token_index, :] = logits.reshape(-1)
|
||||
self.last_updated_index = last_token_index
|
||||
# Update n_tokens
|
||||
self.n_tokens += n_tokens
|
||||
|
||||
|
@ -127,7 +127,7 @@ class LlamacppHF(PreTrainedModel):
|
||||
self.model.reset()
|
||||
self.model.eval(seq)
|
||||
|
||||
logits = torch.tensor(self.model.scores[self.model.n_tokens - 1, :]).view(1, 1, -1).to(input_ids.device)
|
||||
logits = torch.tensor(self.model.scores[self.model.last_updated_index, :]).view(1, 1, -1).to(input_ids.device)
|
||||
else:
|
||||
self.model.reset()
|
||||
self.model.eval(seq)
|
||||
@ -205,5 +205,6 @@ class LlamacppHF(PreTrainedModel):
|
||||
|
||||
Llama = llama_cpp_lib().Llama
|
||||
model = Llama(**params)
|
||||
model.last_updated_index = -1
|
||||
|
||||
return LlamacppHF(model, model_file)
|
||||
|
Loading…
Reference in New Issue
Block a user