mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-21 15:48:04 +01:00
Fix StreamingLLM when content is removed from the beginning of the prompt
This commit is contained in:
parent
d828844a6f
commit
d890c99b53
@ -19,12 +19,12 @@ def process_llamacpp_cache(model, new_sequence, past_sequence):
|
|||||||
past_sequence = torch.tensor(past_sequence)
|
past_sequence = torch.tensor(past_sequence)
|
||||||
|
|
||||||
prefix_length = find_prefix_length(past_sequence[:i1], new_sequence[:j1])
|
prefix_length = find_prefix_length(past_sequence[:i1], new_sequence[:j1])
|
||||||
sink_length = prefix_length
|
sink_length = max(prefix_length, shared.args.attention_sink_size)
|
||||||
if sink_length < shared.args.attention_sink_size:
|
|
||||||
sink_length = shared.args.attention_sink_size
|
|
||||||
|
|
||||||
removed_length = i1 - sink_length
|
removed_length = i1 - sink_length
|
||||||
|
|
||||||
|
if removed_length <= 0:
|
||||||
|
return past_sequence.tolist()
|
||||||
|
|
||||||
matching_prefix = past_sequence[:prefix_length]
|
matching_prefix = past_sequence[:prefix_length]
|
||||||
removed_chunk = past_sequence[sink_length:i1]
|
removed_chunk = past_sequence[sink_length:i1]
|
||||||
overlapping_sequence = new_sequence[j1:j2 + 1]
|
overlapping_sequence = new_sequence[j1:j2 + 1]
|
||||||
@ -37,10 +37,11 @@ def process_llamacpp_cache(model, new_sequence, past_sequence):
|
|||||||
print('MATCHING PREFIX=', repr(shared.tokenizer.decode(matching_prefix)))
|
print('MATCHING PREFIX=', repr(shared.tokenizer.decode(matching_prefix)))
|
||||||
print('ADDED CHUNK=', repr(shared.tokenizer.decode(added_chunk)))
|
print('ADDED CHUNK=', repr(shared.tokenizer.decode(added_chunk)))
|
||||||
print('REMOVED CHUNK=', repr(shared.tokenizer.decode(removed_chunk)))
|
print('REMOVED CHUNK=', repr(shared.tokenizer.decode(removed_chunk)))
|
||||||
|
print('REMOVED LENGTH=', removed_length)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
# Remove interval [sink_length, sink_length + removed_length) from the context
|
# Remove interval [sink_length, sink_length + removed_length) from the context
|
||||||
# Subtract removed_length from model.n_tokens
|
# Update model.n_tokens
|
||||||
model._ctx.kv_cache_seq_rm(0, sink_length, sink_length + removed_length)
|
model._ctx.kv_cache_seq_rm(0, sink_length, sink_length + removed_length)
|
||||||
model._ctx.kv_cache_seq_shift(0, sink_length + removed_length, -1, -removed_length)
|
model._ctx.kv_cache_seq_shift(0, sink_length + removed_length, -1, -removed_length)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user