mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-25 13:58:56 +01:00
Optimize llamacpp_hf a bit
This commit is contained in:
parent
6a3edb0542
commit
a199f21799
@ -42,7 +42,6 @@ class LlamacppHF(PreTrainedModel):
|
||||
|
||||
# Make the forward call
|
||||
seq_tensor = torch.tensor(seq)
|
||||
self.cache = seq_tensor
|
||||
if labels is None:
|
||||
if self.cache is None or not torch.equal(self.cache, seq_tensor[:-1]):
|
||||
self.model.reset()
|
||||
@ -50,13 +49,15 @@ class LlamacppHF(PreTrainedModel):
|
||||
else:
|
||||
self.model.eval([seq[-1]])
|
||||
|
||||
logits = torch.tensor(self.model.eval_logits)[-1].view(1, 1, -1).to(kwargs['input_ids'].device)
|
||||
logits = torch.tensor(self.model.eval_logits[-1]).view(1, 1, -1).to(kwargs['input_ids'].device)
|
||||
else:
|
||||
self.model.reset()
|
||||
self.model.eval(seq)
|
||||
logits = torch.tensor(self.model.eval_logits)
|
||||
logits = logits.view(1, logits.shape[0], logits.shape[1]).to(kwargs['input_ids'].device)
|
||||
|
||||
self.cache = seq_tensor
|
||||
|
||||
# Based on transformers/models/llama/modeling_llama.py
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
Loading…
Reference in New Issue
Block a user