Mirror of https://github.com/oobabooga/text-generation-webui.git, synced 2024-11-22 16:17:57 +01:00
Optimize llamacpp_hf a bit
This commit is contained in: parent 6a3edb0542, commit a199f21799
@@ -42,7 +42,6 @@ class LlamacppHF(PreTrainedModel):
 
         # Make the forward call
         seq_tensor = torch.tensor(seq)
-        self.cache = seq_tensor
         if labels is None:
             if self.cache is None or not torch.equal(self.cache, seq_tensor[:-1]):
                 self.model.reset()
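The assignment removed here reappears after the evaluation branch in the next hunk, so the prefix check above compares against the sequence stored by the previous call rather than the one just built. A minimal sketch of that check, with an illustrative helper name and toy token values that are not part of the commit:

import torch

def needs_full_eval(cache, seq_tensor):
    # Same condition as in the diff: re-evaluate the whole prompt only when the
    # previously cached sequence is not a prefix of the current one.
    return cache is None or not torch.equal(cache, seq_tensor[:-1])

cache = None
seq = torch.tensor([1, 2, 3])
assert needs_full_eval(cache, seq)       # first call: full evaluation
cache = seq                              # stored after evaluating, as in this commit

seq = torch.tensor([1, 2, 3, 4])         # next call appends one generated token
assert not needs_full_eval(cache, seq)   # only the new token needs evaluating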
@@ -50,13 +49,15 @@ class LlamacppHF(PreTrainedModel):
             else:
                 self.model.eval([seq[-1]])
 
-            logits = torch.tensor(self.model.eval_logits)[-1].view(1, 1, -1).to(kwargs['input_ids'].device)
+            logits = torch.tensor(self.model.eval_logits[-1]).view(1, 1, -1).to(kwargs['input_ids'].device)
         else:
             self.model.reset()
             self.model.eval(seq)
             logits = torch.tensor(self.model.eval_logits)
             logits = logits.view(1, logits.shape[0], logits.shape[1]).to(kwargs['input_ids'].device)
 
+        self.cache = seq_tensor
+
         # Based on transformers/models/llama/modeling_llama.py
         loss = None
         if labels is not None:
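The logits change swaps torch.tensor(self.model.eval_logits)[-1] for torch.tensor(self.model.eval_logits[-1]), converting only the last row of logits to a tensor instead of the whole list. A small stand-alone check of the equivalence; the list sizes here are made up, standing in for llama-cpp-python's eval_logits, which holds one vocabulary-sized row per evaluated token:

import torch

# Stand-in for eval_logits: one row of logits per evaluated token.
eval_logits = [[float(i + j) for j in range(8)] for i in range(5)]

old = torch.tensor(eval_logits)[-1].view(1, 1, -1)   # tensorize everything, then slice the last row
new = torch.tensor(eval_logits[-1]).view(1, 1, -1)   # tensorize only the last row

assert torch.equal(old, new)

Both produce the same (1, 1, vocab_size) tensor; the second form just avoids copying every row evaluated so far on each forward call.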