mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
llama.cpp: add a progress bar for prompt evaluation
This commit is contained in:
parent
acea6a6669
commit
86c320ab5a
63
modules/llama_cpp_python_hijack.py
Normal file
63
modules/llama_cpp_python_hijack.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
from typing import Sequence
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# Each llama-cpp-python backend is optional: only the wheel matching the
# user's hardware (CPU, CUDA, CUDA+tensor cores) will be installed, so a
# failed import here is expected and must not abort startup.
#
# We catch Exception rather than a bare `except:` so that
# KeyboardInterrupt/SystemExit still propagate, while still tolerating
# non-ImportError failures (e.g. OSError from a missing native library
# when the wheel loads its shared objects).
try:
    import llama_cpp
except Exception:
    llama_cpp = None

try:
    import llama_cpp_cuda
except Exception:
    llama_cpp_cuda = None

try:
    import llama_cpp_cuda_tensorcores
except Exception:
    llama_cpp_cuda_tensorcores = None
|
||||||
|
|
||||||
|
|
||||||
|
def eval_with_progress(self, tokens: Sequence[int]):
    """
    Drop-in replacement for ``llama_cpp.Llama.eval`` (copied from
    https://github.com/abetlen/llama-cpp-python/blob/main/llama_cpp/llama.py)
    that wraps the batch loop in a tqdm progress bar during prompt
    processing.

    Installed as a monkey-patch over ``Llama.eval``, so ``self`` is a
    ``Llama`` instance and ``tokens`` is the token sequence to evaluate.
    """
    assert self._ctx.ctx is not None
    assert self._batch.batch is not None

    # Invalidate any cached KV entries past the tokens already evaluated.
    self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)

    batch_starts = range(0, len(tokens), self.n_batch)
    # Only show a bar for real prompts; single-token steps (generation)
    # would just produce flicker.
    if len(tokens) > 1:
        batch_starts = tqdm(batch_starts, desc="Prompt evaluation", leave=False)

    for start in batch_starts:
        chunk = tokens[start : min(len(tokens), start + self.n_batch)]
        base = self.n_tokens  # position of this chunk within the context
        count = len(chunk)

        self._batch.set_batch(
            batch=chunk, n_past=base, logits_all=self.context_params.logits_all
        )
        self._ctx.decode(self._batch)

        # Save the evaluated tokens.
        self.input_ids[base : base + count] = chunk

        # Save logits. NOTE: only the last token's logits are kept when
        # logits_all is False, hence the `count - 1` offset.
        vocab = self._n_vocab
        skip = 0 if self.context_params.logits_all else count - 1
        self.scores[base + skip : base + count, :].reshape(-1)[:] = (
            self._ctx.get_logits()[skip * vocab : count * vocab]
        )

        # Advance the evaluated-token counter.
        self.n_tokens += count
|
||||||
|
|
||||||
|
|
||||||
|
# Install the progress-bar eval on every backend that imported successfully.
for backend in (llama_cpp, llama_cpp_cuda, llama_cpp_cuda_tensorcores):
    if backend is None:
        continue
    backend.Llama.eval = eval_with_progress
|
@ -7,7 +7,7 @@ from torch.nn import CrossEntropyLoss
|
|||||||
from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
|
from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
|
|
||||||
from modules import RoPE, shared
|
from modules import RoPE, llama_cpp_python_hijack, shared
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -4,7 +4,7 @@ from functools import partial
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from modules import RoPE, shared
|
from modules import RoPE, llama_cpp_python_hijack, shared
|
||||||
from modules.callbacks import Iteratorize
|
from modules.callbacks import Iteratorize
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
from modules.text_generation import get_max_prompt_length
|
from modules.text_generation import get_max_prompt_length
|
||||||
|
Loading…
Reference in New Issue
Block a user