mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-21 23:57:58 +01:00
Add StreamingLLM for llamacpp & llamacpp_HF (2nd attempt) (#5669)
This commit is contained in:
parent
9271e80914
commit
afb51bd5d6
108
modules/cache_utils.py
Normal file
108
modules/cache_utils.py
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from modules import shared
|
||||||
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
|
|
||||||
|
def process_llamacpp_cache(model, new_sequence, past_sequence):
    '''
    StreamingLLM cache trimming for llama.cpp models.

    Looks for the longest run of tokens shared by past_sequence (the
    tokens currently held by the model) and new_sequence (the prompt
    about to be evaluated). When that run does not start at the beginning
    of past_sequence and covers more than 20% of new_sequence, the tokens
    sitting between the attention sink and the start of the run are
    dropped from the KV cache and the remaining entries are shifted back
    to close the gap.

    Side effects when trimming happens: model._ctx, model.input_ids and
    model.n_tokens are updated in place.

    Returns the portion of new_sequence that is considered cached after
    trimming (a list), or past_sequence unchanged when no trimming is
    performed.
    '''
    i1, i2, j1, j2 = find_longest_common_substring_indices(past_sequence, new_sequence)
    overlap_length = i2 - i1 + 1

    # Do StreamingLLM if i1 > 0 (ie the longest common substring is not a prefix)
    # and the overlap length is sufficiently long.
    if i1 > 0 and overlap_length > 0.2 * len(new_sequence):

        new_tensor = torch.tensor(new_sequence)
        past_tensor = torch.tensor(past_sequence)

        # The attention sink is the shared prefix, but never shorter than
        # the configured minimum.
        prefix_length = find_prefix_length(past_tensor[:i1], new_tensor[:j1])
        sink_length = max(prefix_length, shared.args.attention_sink_size)

        removed_length = i1 - sink_length
        if removed_length <= 0:
            # The common run starts inside the attention sink, so there is
            # nothing to evict. Proceeding would call the cache functions
            # with an empty/inverted interval and a non-negative shift, so
            # leave the cache and the model state untouched instead.
            return past_sequence

        matching_prefix = past_tensor[:prefix_length]
        removed_chunk = past_tensor[sink_length:i1]
        added_chunk = new_tensor[j2 + 1:]

        print()
        print('MATCHING PREFIX=', repr(shared.tokenizer.decode(matching_prefix)))
        print('ADDED CHUNK=', repr(shared.tokenizer.decode(added_chunk)))
        print('REMOVED CHUNK=', repr(shared.tokenizer.decode(removed_chunk)))
        print()

        # Remove interval [sink_length, sink_length + removed_length) from the context
        # and move everything after it back by removed_length positions.
        # NOTE(review): the -1 end position presumably means "up to the end
        # of the sequence" -- confirm against the llama.cpp API in use.
        model._ctx.kv_cache_seq_rm(0, sink_length, sink_length + removed_length)
        model._ctx.kv_cache_seq_shift(0, sink_length + removed_length, -1, -removed_length)

        # Make the model's bookkeeping agree with the trimmed cache: the
        # tokens up to the end of the common run are now treated as evaluated.
        cached_tokens = new_tensor[:j2 + 1].tolist()
        model.input_ids[:j2 + 1] = cached_tokens
        model.n_tokens = j2 + 1

        return cached_tokens
    else:
        return past_sequence
|
||||||
|
|
||||||
|
|
||||||
|
def find_prefix_length(past_seq, seq_tensor):
    '''
    Returns how many leading elements two torch tensors have in common.

    Only the first min(len(past_seq), len(seq_tensor)) positions are
    compared; if none of them differ, that minimum length is returned.
    '''
    limit = min(past_seq.shape[0], seq_tensor.shape[0])

    # Positions (within the compared window) where the tensors disagree.
    mismatches = torch.nonzero(torch.ne(past_seq[:limit], seq_tensor[:limit]))

    if len(mismatches) == 0:
        return limit

    # The first mismatch position equals the length of the shared prefix.
    return mismatches[0].item()
|
||||||
|
|
||||||
|
|
||||||
|
def find_longest_common_substring_indices(list1, list2):
    '''
    Solves the Longest Common Substring problem for two lists.

    Returns (start1, end1, start2, end2): the inclusive start and end
    indices of the longest shared contiguous run in list1 and in list2.
    When several runs share the maximum length, the one found last wins
    (largest start in list1, then largest start in list2). When the lists
    have nothing in common the result is (0, -1, 0, -1), so that
    end - start + 1 == 0.

    Example:

        ir, jr, ir2, jr2 = find_longest_common_substring_indices(list1, list2)
        print(list1[ir:jr + 1])
        print(list2[ir2:jr2 + 1])

    Adapted from
    https://rosettacode.org/wiki/Longest_common_substring#Python
    '''

    size1, size2 = len(list1), len(list2)
    best_start1, best_end1 = 0, -1
    best_start2, best_end2 = 0, -1

    for pos1, item in enumerate(list1):
        pos2 = -1
        while True:
            # Visit every occurrence of the current item in list2, left to right.
            try:
                pos2 = list2.index(item, pos2 + 1)
            except ValueError:
                break

            # Measure how far the two lists agree from this pair of positions.
            run = 0
            while pos1 + run < size1 and pos2 + run < size2 and list2[pos2 + run] == list1[pos1 + run]:
                run += 1

            # ">=" (not ">") so that a later run of equal length replaces an earlier one.
            if run > 0 and run - 1 >= best_end1 - best_start1:
                best_start1, best_end1 = pos1, pos1 + run - 1
                best_start2, best_end2 = pos2, pos2 + run - 1

    return best_start1, best_end1, best_start2, best_end2
|
@ -2,6 +2,9 @@ from typing import Sequence
|
|||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from modules import shared
|
||||||
|
from modules.cache_utils import process_llamacpp_cache
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import llama_cpp
|
import llama_cpp
|
||||||
except:
|
except:
|
||||||
@ -58,6 +61,25 @@ def eval_with_progress(self, tokens: Sequence[int]):
|
|||||||
self.n_tokens += n_tokens
|
self.n_tokens += n_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def monkey_patch_generate(lib):
    '''
    Replaces lib.Llama.generate with a wrapper that performs the
    StreamingLLM cache trimming (when shared.args.streaming_llm is set)
    before delegating to the original implementation, which is kept
    available as lib.Llama.original_generate.
    '''

    def my_generate(self, *args, **kwargs):

        if shared.args.streaming_llm:
            # The token sequence is normally the first positional argument.
            # NOTE(review): the keyword name 'tokens' is assumed from
            # llama-cpp-python's generate() signature -- confirm.
            new_sequence = args[0] if args else kwargs.get('tokens')
            if new_sequence is not None:
                past_sequence = self._input_ids

                # Do the cache trimming for StreamingLLM
                process_llamacpp_cache(self, new_sequence, past_sequence)

        # "yield from" also forwards send()/throw()/close() to the wrapped
        # generator, which a plain for/yield loop would swallow.
        yield from self.original_generate(*args, **kwargs)

    # Only remember the original once: patching twice would otherwise make
    # original_generate point at a wrapper and recurse forever.
    if not hasattr(lib.Llama, 'original_generate'):
        lib.Llama.original_generate = lib.Llama.generate

    lib.Llama.generate = my_generate
|
||||||
|
|
||||||
|
|
||||||
for lib in [llama_cpp, llama_cpp_cuda, llama_cpp_cuda_tensorcores]:
|
for lib in [llama_cpp, llama_cpp_cuda, llama_cpp_cuda_tensorcores]:
|
||||||
if lib is not None:
|
if lib is not None:
|
||||||
lib.Llama.eval = eval_with_progress
|
lib.Llama.eval = eval_with_progress
|
||||||
|
monkey_patch_generate(lib)
|
||||||
|
@ -46,6 +46,8 @@ loaders_and_params = OrderedDict({
|
|||||||
'no_offload_kqv',
|
'no_offload_kqv',
|
||||||
'row_split',
|
'row_split',
|
||||||
'tensorcores',
|
'tensorcores',
|
||||||
|
'streaming_llm',
|
||||||
|
'attention_sink_size',
|
||||||
],
|
],
|
||||||
'llamacpp_HF': [
|
'llamacpp_HF': [
|
||||||
'n_ctx',
|
'n_ctx',
|
||||||
@ -69,6 +71,8 @@ loaders_and_params = OrderedDict({
|
|||||||
'no_offload_kqv',
|
'no_offload_kqv',
|
||||||
'row_split',
|
'row_split',
|
||||||
'tensorcores',
|
'tensorcores',
|
||||||
|
'streaming_llm',
|
||||||
|
'attention_sink_size',
|
||||||
'llamacpp_HF_info',
|
'llamacpp_HF_info',
|
||||||
],
|
],
|
||||||
'ExLlamav2_HF': [
|
'ExLlamav2_HF': [
|
||||||
|
@ -130,6 +130,8 @@ group.add_argument('--logits_all', action='store_true', help='Needs to be set fo
|
|||||||
group.add_argument('--no_offload_kqv', action='store_true', help='Do not offload the K, Q, V to the GPU. This saves VRAM but reduces the performance.')
|
group.add_argument('--no_offload_kqv', action='store_true', help='Do not offload the K, Q, V to the GPU. This saves VRAM but reduces the performance.')
|
||||||
group.add_argument('--cache-capacity', type=str, help='Maximum cache capacity (llama-cpp-python). Examples: 2000MiB, 2GiB. When provided without units, bytes will be assumed.')
|
group.add_argument('--cache-capacity', type=str, help='Maximum cache capacity (llama-cpp-python). Examples: 2000MiB, 2GiB. When provided without units, bytes will be assumed.')
|
||||||
group.add_argument('--row_split', action='store_true', help='Split the model by rows across GPUs. This may improve multi-gpu performance.')
|
group.add_argument('--row_split', action='store_true', help='Split the model by rows across GPUs. This may improve multi-gpu performance.')
|
||||||
|
group.add_argument('--streaming-llm', action='store_true', help='Activates StreamingLLM, which prevents the prompt from ever being reevaluated when old chat messages are removed due to the context length for the model being reached.')
|
||||||
|
group.add_argument('--attention-sink-size', type=int, default=5, help='Minimum attention sink length from StreamingLLM.')
|
||||||
|
|
||||||
# ExLlamaV2
|
# ExLlamaV2
|
||||||
group = parser.add_argument_group('ExLlamaV2')
|
group = parser.add_argument_group('ExLlamaV2')
|
||||||
|
@ -13,6 +13,7 @@ import transformers
|
|||||||
from transformers import LogitsProcessorList, is_torch_xpu_available
|
from transformers import LogitsProcessorList, is_torch_xpu_available
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
from modules.cache_utils import process_llamacpp_cache
|
||||||
from modules.callbacks import (
|
from modules.callbacks import (
|
||||||
Iteratorize,
|
Iteratorize,
|
||||||
Stream,
|
Stream,
|
||||||
@ -364,6 +365,12 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
|||||||
print(decode(input_ids[0], skip_special_tokens=False))
|
print(decode(input_ids[0], skip_special_tokens=False))
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
# Handle StreamingLLM for llamacpp_HF
|
||||||
|
if shared.model.__class__.__name__ == 'LlamacppHF' and shared.args.streaming_llm:
|
||||||
|
tmp = process_llamacpp_cache(shared.model.model, input_ids[-1].tolist(), shared.model.model._input_ids)
|
||||||
|
shared.model.past_seq = torch.tensor(tmp)
|
||||||
|
shared.model.save_cache()
|
||||||
|
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
try:
|
try:
|
||||||
if not is_chat and not shared.is_seq2seq:
|
if not is_chat and not shared.is_seq2seq:
|
||||||
|
@ -97,6 +97,8 @@ def list_model_elements():
|
|||||||
'no_offload_kqv',
|
'no_offload_kqv',
|
||||||
'row_split',
|
'row_split',
|
||||||
'tensorcores',
|
'tensorcores',
|
||||||
|
'streaming_llm',
|
||||||
|
'attention_sink_size',
|
||||||
'hqq_backend',
|
'hqq_backend',
|
||||||
]
|
]
|
||||||
if is_torch_xpu_available():
|
if is_torch_xpu_available():
|
||||||
|
@ -117,6 +117,8 @@ def create_ui():
|
|||||||
shared.gradio['use_flash_attention_2'] = gr.Checkbox(label="use_flash_attention_2", value=shared.args.use_flash_attention_2, info='Set use_flash_attention_2=True while loading the model.')
|
shared.gradio['use_flash_attention_2'] = gr.Checkbox(label="use_flash_attention_2", value=shared.args.use_flash_attention_2, info='Set use_flash_attention_2=True while loading the model.')
|
||||||
shared.gradio['auto_devices'] = gr.Checkbox(label="auto-devices", value=shared.args.auto_devices)
|
shared.gradio['auto_devices'] = gr.Checkbox(label="auto-devices", value=shared.args.auto_devices)
|
||||||
shared.gradio['tensorcores'] = gr.Checkbox(label="tensorcores", value=shared.args.tensorcores, info='NVIDIA only: use llama-cpp-python compiled with tensor cores support. This increases performance on RTX cards.')
|
shared.gradio['tensorcores'] = gr.Checkbox(label="tensorcores", value=shared.args.tensorcores, info='NVIDIA only: use llama-cpp-python compiled with tensor cores support. This increases performance on RTX cards.')
|
||||||
|
shared.gradio['streaming_llm'] = gr.Checkbox(label="streaming_llm", value=shared.args.streaming_llm, info='(experimental) Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.')
|
||||||
|
shared.gradio['attention_sink_size'] = gr.Number(label="attention_sink_size", value=shared.args.attention_sink_size)
|
||||||
shared.gradio['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu, info='llama.cpp: Use llama-cpp-python compiled without GPU acceleration. Transformers: use PyTorch in CPU mode.')
|
shared.gradio['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu, info='llama.cpp: Use llama-cpp-python compiled without GPU acceleration. Transformers: use PyTorch in CPU mode.')
|
||||||
shared.gradio['row_split'] = gr.Checkbox(label="row_split", value=shared.args.row_split, info='Split the model by rows across GPUs. This may improve multi-gpu performance.')
|
shared.gradio['row_split'] = gr.Checkbox(label="row_split", value=shared.args.row_split, info='Split the model by rows across GPUs. This may improve multi-gpu performance.')
|
||||||
shared.gradio['no_offload_kqv'] = gr.Checkbox(label="no_offload_kqv", value=shared.args.no_offload_kqv, info='Do not offload the K, Q, V to the GPU. This saves VRAM but reduces the performance.')
|
shared.gradio['no_offload_kqv'] = gr.Checkbox(label="no_offload_kqv", value=shared.args.no_offload_kqv, info='Do not offload the K, Q, V to the GPU. This saves VRAM but reduces the performance.')
|
||||||
|
Loading…
Reference in New Issue
Block a user