mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
ExLlamav2_HF: Convert logits to FP32 (#4310)
This commit is contained in:
parent
c0ffb77fd8
commit
ae8cd449ae
@@ -108,10 +108,10 @@ class Exllamav2HF(PreTrainedModel):
|
|||||||
if len(seq_tensor) > 1:
|
if len(seq_tensor) > 1:
|
||||||
self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)
|
self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)
|
||||||
|
|
||||||
logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, loras=self.loras).to(input_ids.device)
|
logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, loras=self.loras).to(input_ids.device).float()
|
||||||
else:
|
else:
|
||||||
ex_cache.current_seq_len = 0
|
ex_cache.current_seq_len = 0
|
||||||
logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, loras=self.loras)
|
logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, loras=self.loras).float()
|
||||||
|
|
||||||
if is_negative:
|
if is_negative:
|
||||||
self.past_seq_negative = seq_tensor
|
self.past_seq_negative = seq_tensor
|
||||||
|
Loading…
Reference in New Issue
Block a user