Optimize llamacpp_hf a bit

oobabooga 2023-07-16 20:49:48 -07:00
parent 6a3edb0542
commit a199f21799


@@ -42,7 +42,6 @@ class LlamacppHF(PreTrainedModel):
 
         # Make the forward call
         seq_tensor = torch.tensor(seq)
-        self.cache = seq_tensor
         if labels is None:
             if self.cache is None or not torch.equal(self.cache, seq_tensor[:-1]):
                 self.model.reset()
@@ -50,13 +49,15 @@ class LlamacppHF(PreTrainedModel):
             else:
                 self.model.eval([seq[-1]])
 
-            logits = torch.tensor(self.model.eval_logits)[-1].view(1, 1, -1).to(kwargs['input_ids'].device)
+            logits = torch.tensor(self.model.eval_logits[-1]).view(1, 1, -1).to(kwargs['input_ids'].device)
         else:
             self.model.reset()
             self.model.eval(seq)
             logits = torch.tensor(self.model.eval_logits)
             logits = logits.view(1, logits.shape[0], logits.shape[1]).to(kwargs['input_ids'].device)
 
+        self.cache = seq_tensor
+
         # Based on transformers/models/llama/modeling_llama.py
         loss = None
         if labels is not None:
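
The diff makes two changes. Moving `self.cache = seq_tensor` below the eval calls means the cache is no longer overwritten before the `torch.equal` comparison; with the old placement the comparison always failed, so the model was reset and fully re-evaluated on every call instead of taking the single-token path. And `torch.tensor(self.model.eval_logits[-1])` converts only the last logits row to a tensor instead of tensorizing the whole n_tokens x n_vocab list and then slicing it. Below is a rough, self-contained sketch of the logits change; the sizes and the list-of-lists stand-in for `eval_logits` are illustrative assumptions, not taken from the repo.

    import time
    import torch

    # Stand-in for self.model.eval_logits: a plain Python list of
    # per-token logit rows (n_tokens x n_vocab). Sizes are made up.
    n_tokens, n_vocab = 256, 32000
    eval_logits = [[0.0] * n_vocab for _ in range(n_tokens)]

    # Before: tensorize every row, then keep only the last one.
    t0 = time.time()
    last_old = torch.tensor(eval_logits)[-1].view(1, 1, -1)
    t_old = time.time() - t0

    # After: tensorize only the row that is actually needed.
    t0 = time.time()
    last_new = torch.tensor(eval_logits[-1]).view(1, 1, -1)
    t_new = time.time() - t0

    assert torch.equal(last_old, last_new)
    print(f"full copy: {t_old:.3f}s  last row only: {t_new:.3f}s")

Both expressions produce the same 1 x 1 x n_vocab tensor, but the new form only moves one vocab-sized row across the Python-to-tensor boundary per generated token, so its cost no longer grows with the number of tokens already evaluated.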