Mirror of https://github.com/oobabooga/text-generation-webui.git, synced 2024-12-25 13:58:56 +01:00
Revert "Fix exllama_hf gibbersh above 2048 context, and works >5000 context. (#2913)"
This reverts commit 37a16d23a7.

Parent: 7b048dcf67
Commit: 20740ab16e
@@ -54,15 +54,7 @@ class ExllamaHF(PreTrainedModel):
         cache = kwargs['past_key_values'] if 'past_key_values' in kwargs else None
         if cache is None:
             cache = ExLlamaCache(self.ex_model)
-            nseq = seq[:-1]
-            for seqs in [nseq[i : i + 2048] for i in range(0, len(nseq), 2048)]:
-                self.ex_model.forward(
-                    torch.tensor([seqs], dtype=torch.long),
-                    cache,
-                    preprocess_only=True,
-                    lora=self.lora,
-                )
+            self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), cache, preprocess_only=True, lora=self.lora)

         logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache, lora=self.lora).to(kwargs['input_ids'].device)
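For context, the change being reverted pre-filled the ExLlama cache by feeding the prompt through the model in 2048-token slices instead of a single forward pass. Below is a minimal standalone sketch of that chunked-prefill pattern, using only the forward(tokens, cache, preprocess_only=..., lora=...) interface visible in the hunk above; the helper name prefill_in_chunks and the free-function packaging are illustrative, not code from the repo.

import torch

CHUNK = 2048  # slice size used by the reverted commit

def prefill_in_chunks(ex_model, cache, seq, lora=None):
    # Feed every token except the last through the model in fixed-size
    # slices, populating `cache` without computing logits.
    nseq = seq[:-1]
    for chunk in [nseq[i : i + CHUNK] for i in range(0, len(nseq), CHUNK)]:
        ex_model.forward(
            torch.tensor([chunk], dtype=torch.long),
            cache,
            preprocess_only=True,  # fill the cache only; skip logits
            lora=lora,
        )
    # One full forward on just the final token yields the next-token logits.
    return ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache, lora=lora)

The revert collapses this loop back into a single preprocess_only call over seq[:-1], which is equivalent when the prompt fits in one pass but was reported to produce gibberish past 2048 tokens in the referenced issue.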