Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2024-11-21 15:48:04 +01:00)
Optimize StreamingLLM by over 10x
This commit is contained in:
Parent: afb51bd5d6
Commit: cf0697936a
@ -1,10 +1,13 @@
|
|||||||
import torch
|
import torch
|
||||||
|
from numba import njit
|
||||||
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.logging_colors import logger
|
|
||||||
|
|
||||||
|
|
||||||
def process_llamacpp_cache(model, new_sequence, past_sequence):
|
def process_llamacpp_cache(model, new_sequence, past_sequence):
|
||||||
|
if len(past_sequence) == 0 or len(new_sequence) == 0:
|
||||||
|
return past_sequence
|
||||||
|
|
||||||
i1, i2, j1, j2 = find_longest_common_substring_indices(past_sequence, new_sequence)
|
i1, i2, j1, j2 = find_longest_common_substring_indices(past_sequence, new_sequence)
|
||||||
overlap_length = i2 - i1 + 1
|
overlap_length = i2 - i1 + 1
|
||||||
|
|
||||||
@ -65,6 +68,7 @@ def find_prefix_length(past_seq, seq_tensor):
|
|||||||
return prefix_length
|
return prefix_length
|
||||||
|
|
||||||
|
|
||||||
|
@njit
|
||||||
def find_longest_common_substring_indices(list1, list2):
|
def find_longest_common_substring_indices(list1, list2):
|
||||||
'''
|
'''
|
||||||
Given two lists, solves the Longest Common Substring problem.
|
Given two lists, solves the Longest Common Substring problem.
|
||||||
@ -86,11 +90,13 @@ def find_longest_common_substring_indices(list1, list2):
|
|||||||
start_index_list1, end_index_list1 = 0, -1
|
start_index_list1, end_index_list1 = 0, -1
|
||||||
start_index_list2, end_index_list2 = 0, -1
|
start_index_list2, end_index_list2 = 0, -1
|
||||||
|
|
||||||
for index1 in range(len_list1):
|
# for index1 in tqdm(range(0, len_list1), desc="StreamingLLM prompt comparison", leave=False):
|
||||||
|
for index1 in range(0, len_list1):
|
||||||
try:
|
try:
|
||||||
index2 = list2.index(list1[index1])
|
index2 = list2.index(list1[index1])
|
||||||
except ValueError:
|
except:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
while index2 >= 0:
|
while index2 >= 0:
|
||||||
temp_index1, temp_index2 = index1, index2
|
temp_index1, temp_index2 = index1, index2
|
||||||
while temp_index1 < len_list1 and temp_index2 < len_list2 and list2[temp_index2] == list1[temp_index1]:
|
while temp_index1 < len_list1 and temp_index2 < len_list2 and list2[temp_index2] == list1[temp_index1]:
|
||||||
@ -102,7 +108,7 @@ def find_longest_common_substring_indices(list1, list2):
|
|||||||
temp_index2 += 1
|
temp_index2 += 1
|
||||||
try:
|
try:
|
||||||
index2 = list2.index(list1[index1], index2 + 1)
|
index2 = list2.index(list1[index1], index2 + 1)
|
||||||
except ValueError:
|
except:
|
||||||
break
|
break
|
||||||
|
|
||||||
return start_index_list1, end_index_list1, start_index_list2, end_index_list2
|
return start_index_list1, end_index_list1, start_index_list2, end_index_list2
|
||||||
|
@ -367,7 +367,7 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
|||||||
|
|
||||||
# Handle StreamingLLM for llamacpp_HF
|
# Handle StreamingLLM for llamacpp_HF
|
||||||
if shared.model.__class__.__name__ == 'LlamacppHF' and shared.args.streaming_llm:
|
if shared.model.__class__.__name__ == 'LlamacppHF' and shared.args.streaming_llm:
|
||||||
tmp = process_llamacpp_cache(shared.model.model, input_ids[-1].tolist(), shared.model.model._input_ids)
|
tmp = process_llamacpp_cache(shared.model.model, input_ids[-1].tolist(), shared.model.model._input_ids.tolist())
|
||||||
shared.model.past_seq = torch.tensor(tmp)
|
shared.model.past_seq = torch.tensor(tmp)
|
||||||
shared.model.save_cache()
|
shared.model.save_cache()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user