mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-24 08:56:52 +01:00
Optimize StreamingLLM by over 10x
This commit is contained in:
parent
afb51bd5d6
commit
cf0697936a
@ -1,10 +1,13 @@
|
||||
import torch
|
||||
from numba import njit
|
||||
|
||||
from modules import shared
|
||||
from modules.logging_colors import logger
|
||||
|
||||
|
||||
def process_llamacpp_cache(model, new_sequence, past_sequence):
|
||||
if len(past_sequence) == 0 or len(new_sequence) == 0:
|
||||
return past_sequence
|
||||
|
||||
i1, i2, j1, j2 = find_longest_common_substring_indices(past_sequence, new_sequence)
|
||||
overlap_length = i2 - i1 + 1
|
||||
|
||||
@ -65,6 +68,7 @@ def find_prefix_length(past_seq, seq_tensor):
|
||||
return prefix_length
|
||||
|
||||
|
||||
@njit
|
||||
def find_longest_common_substring_indices(list1, list2):
|
||||
'''
|
||||
Given two lists, solves the Longest Common Substring problem.
|
||||
@ -86,11 +90,13 @@ def find_longest_common_substring_indices(list1, list2):
|
||||
start_index_list1, end_index_list1 = 0, -1
|
||||
start_index_list2, end_index_list2 = 0, -1
|
||||
|
||||
for index1 in range(len_list1):
|
||||
# for index1 in tqdm(range(0, len_list1), desc="StreamingLLM prompt comparison", leave=False):
|
||||
for index1 in range(0, len_list1):
|
||||
try:
|
||||
index2 = list2.index(list1[index1])
|
||||
except ValueError:
|
||||
except:
|
||||
continue
|
||||
|
||||
while index2 >= 0:
|
||||
temp_index1, temp_index2 = index1, index2
|
||||
while temp_index1 < len_list1 and temp_index2 < len_list2 and list2[temp_index2] == list1[temp_index1]:
|
||||
@ -102,7 +108,7 @@ def find_longest_common_substring_indices(list1, list2):
|
||||
temp_index2 += 1
|
||||
try:
|
||||
index2 = list2.index(list1[index1], index2 + 1)
|
||||
except ValueError:
|
||||
except:
|
||||
break
|
||||
|
||||
return start_index_list1, end_index_list1, start_index_list2, end_index_list2
|
||||
|
@ -367,7 +367,7 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
||||
|
||||
# Handle StreamingLLM for llamacpp_HF
|
||||
if shared.model.__class__.__name__ == 'LlamacppHF' and shared.args.streaming_llm:
|
||||
tmp = process_llamacpp_cache(shared.model.model, input_ids[-1].tolist(), shared.model.model._input_ids)
|
||||
tmp = process_llamacpp_cache(shared.model.model, input_ids[-1].tolist(), shared.model.model._input_ids.tolist())
|
||||
shared.model.past_seq = torch.tensor(tmp)
|
||||
shared.model.save_cache()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user