ExLlama_HF (v1 and v2) prefix matching

This commit is contained in:
oobabooga 2023-09-19 13:12:19 -07:00
parent 5075087461
commit 03dc69edc5
2 changed files with 42 additions and 11 deletions

View File

@ -77,17 +77,33 @@ class ExllamaHF(PreTrainedModel):
seq = past_key_values + seq
seq_tensor = torch.tensor(seq)
reset = True
# Make the forward call
if labels is None:
if past_seq is None or not torch.equal(past_seq, seq_tensor[:-1]):
ex_cache.current_seq_len = 0
self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), ex_cache, preprocess_only=True, lora=self.lora)
if past_seq is not None:
min_length = min(past_seq.shape[0], seq_tensor.shape[0])
indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
if len(indices) > 0:
longest_prefix = indices[0].item()
else:
longest_prefix = min_length
logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), ex_cache, lora=self.lora).to(input_ids.device)
if longest_prefix > 0:
reset = False
ex_cache.current_seq_len = longest_prefix
if len(seq_tensor) - longest_prefix > 1:
self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True, lora=self.lora)
if reset:
ex_cache.current_seq_len = 0
if len(seq_tensor) > 1:
self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True, lora=self.lora)
logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, lora=self.lora).to(input_ids.device)
else:
ex_cache.current_seq_len = 0
logits = self.ex_model.forward(torch.tensor([seq], dtype=torch.long), ex_cache, last_id_only=False, lora=self.lora)
logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, lora=self.lora)
if is_negative:
self.past_seq_negative = seq_tensor

View File

@ -81,18 +81,33 @@ class Exllamav2HF(PreTrainedModel):
seq = past_key_values + seq
seq_tensor = torch.tensor(seq)
reset = True
# Make the forward call
if labels is None:
if past_seq is None or not torch.equal(past_seq, seq_tensor[:-1]):
ex_cache.current_seq_len = 0
self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), ex_cache, preprocess_only=True)
if past_seq is not None:
min_length = min(past_seq.shape[0], seq_tensor.shape[0])
indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
if len(indices) > 0:
longest_prefix = indices[0].item()
else:
longest_prefix = min_length
logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), ex_cache).to(input_ids.device)
if longest_prefix > 0:
reset = False
ex_cache.current_seq_len = longest_prefix
if len(seq_tensor) - longest_prefix > 1:
self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True)
if reset:
ex_cache.current_seq_len = 0
if len(seq_tensor) > 1:
self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True)
logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache).to(input_ids.device)
else:
ex_cache.current_seq_len = 0
# logits = self.ex_model.forward(torch.tensor([seq], dtype=torch.long), ex_cache, last_id_only=False)
logits = self.ex_model.forward(torch.tensor([seq], dtype=torch.long), ex_cache)
logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False)
if is_negative:
self.past_seq_negative = seq_tensor