mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-24 00:46:53 +01:00
Update llama_cpp_python_hijack.py, fix llamacpp_hf
This commit is contained in:
parent
9ca0cd7749
commit
4d9ce586d3
@ -2,12 +2,12 @@ import importlib
|
|||||||
import platform
|
import platform
|
||||||
from typing import Sequence
|
from typing import Sequence
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.cache_utils import process_llamacpp_cache
|
from modules.cache_utils import process_llamacpp_cache
|
||||||
|
|
||||||
|
|
||||||
imported_module = None
|
imported_module = None
|
||||||
|
|
||||||
|
|
||||||
@ -57,8 +57,6 @@ def eval_with_progress(self, tokens: Sequence[int]):
|
|||||||
|
|
||||||
with tqdm to show prompt processing progress.
|
with tqdm to show prompt processing progress.
|
||||||
"""
|
"""
|
||||||
assert self._ctx.ctx is not None
|
|
||||||
assert self._batch.batch is not None
|
|
||||||
self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
|
self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
|
||||||
|
|
||||||
if len(tokens) > self.n_batch:
|
if len(tokens) > self.n_batch:
|
||||||
@ -80,13 +78,20 @@ def eval_with_progress(self, tokens: Sequence[int]):
|
|||||||
if self.context_params.logits_all:
|
if self.context_params.logits_all:
|
||||||
rows = n_tokens
|
rows = n_tokens
|
||||||
cols = self._n_vocab
|
cols = self._n_vocab
|
||||||
logits = self._ctx.get_logits()[: rows * cols]
|
logits = np.ctypeslib.as_array(
|
||||||
self.scores[n_past : n_past + n_tokens, :].reshape(-1)[: :] = logits
|
self._ctx.get_logits(), shape=(rows * cols,)
|
||||||
|
)
|
||||||
|
self.scores[n_past : n_past + n_tokens, :].reshape(-1)[::] = logits
|
||||||
|
self.last_updated_index = n_past + n_tokens - 1
|
||||||
else:
|
else:
|
||||||
rows = 1
|
rows = 1
|
||||||
cols = self._n_vocab
|
cols = self._n_vocab
|
||||||
logits = self._ctx.get_logits()[: rows * cols]
|
logits = np.ctypeslib.as_array(
|
||||||
self.scores[n_past + n_tokens - 1, :].reshape(-1)[: :] = logits
|
self._ctx.get_logits(), shape=(rows * cols,)
|
||||||
|
)
|
||||||
|
last_token_index = min(n_past + n_tokens - 1, self.scores.shape[0] - 1)
|
||||||
|
self.scores[last_token_index, :] = logits.reshape(-1)
|
||||||
|
self.last_updated_index = last_token_index
|
||||||
# Update n_tokens
|
# Update n_tokens
|
||||||
self.n_tokens += n_tokens
|
self.n_tokens += n_tokens
|
||||||
|
|
||||||
|
@ -127,7 +127,7 @@ class LlamacppHF(PreTrainedModel):
|
|||||||
self.model.reset()
|
self.model.reset()
|
||||||
self.model.eval(seq)
|
self.model.eval(seq)
|
||||||
|
|
||||||
logits = torch.tensor(self.model.scores[self.model.n_tokens - 1, :]).view(1, 1, -1).to(input_ids.device)
|
logits = torch.tensor(self.model.scores[self.model.last_updated_index, :]).view(1, 1, -1).to(input_ids.device)
|
||||||
else:
|
else:
|
||||||
self.model.reset()
|
self.model.reset()
|
||||||
self.model.eval(seq)
|
self.model.eval(seq)
|
||||||
@ -205,5 +205,6 @@ class LlamacppHF(PreTrainedModel):
|
|||||||
|
|
||||||
Llama = llama_cpp_lib().Llama
|
Llama = llama_cpp_lib().Llama
|
||||||
model = Llama(**params)
|
model = Llama(**params)
|
||||||
|
model.last_updated_index = -1
|
||||||
|
|
||||||
return LlamacppHF(model, model_file)
|
return LlamacppHF(model, model_file)
|
||||||
|
Loading…
Reference in New Issue
Block a user