mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-24 13:28:59 +01:00
Reorder some functions
This commit is contained in:
parent
e2fddd9584
commit
13ac55fa18
@ -85,6 +85,22 @@ class ExllamaModel:
|
||||
result.generator = generator
|
||||
return result, result
|
||||
|
||||
def encode(self, string, **kwargs):
|
||||
return self.tokenizer.encode(string, max_seq_len=self.model.config.max_seq_len, add_bos=True)
|
||||
|
||||
def decode(self, ids, **kwargs):
|
||||
if isinstance(ids, list):
|
||||
ids = torch.tensor([ids])
|
||||
elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
|
||||
ids = ids.view(1, -1)
|
||||
|
||||
return self.tokenizer.decode(ids)[0]
|
||||
|
||||
def get_logits(self, token_ids, **kwargs):
|
||||
self.cache.current_seq_len = 0
|
||||
self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
|
||||
return self.model.forward(token_ids[:, -1:], self.cache, **kwargs).float().cpu()
|
||||
|
||||
def generate_with_streaming(self, prompt, state):
|
||||
|
||||
# The cache batch size must be 2 for CFG and 1 otherwise
|
||||
@ -200,19 +216,3 @@ class ExllamaModel:
|
||||
pass
|
||||
|
||||
return output
|
||||
|
||||
def encode(self, string, **kwargs):
|
||||
return self.tokenizer.encode(string, max_seq_len=self.model.config.max_seq_len, add_bos=True)
|
||||
|
||||
def decode(self, ids, **kwargs):
|
||||
if isinstance(ids, list):
|
||||
ids = torch.tensor([ids])
|
||||
elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
|
||||
ids = ids.view(1, -1)
|
||||
|
||||
return self.tokenizer.decode(ids)[0]
|
||||
|
||||
def get_logits(self, token_ids, **kwargs):
|
||||
self.cache.current_seq_len = 0
|
||||
self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
|
||||
return self.model.forward(token_ids[:, -1:], self.cache, **kwargs).float().cpu()
|
||||
|
@ -62,6 +62,22 @@ class Exllamav2Model:
|
||||
result.generator = generator
|
||||
return result, result
|
||||
|
||||
def encode(self, string, **kwargs):
|
||||
return self.tokenizer.encode(string, add_bos=True)
|
||||
|
||||
def decode(self, ids, **kwargs):
|
||||
if isinstance(ids, list):
|
||||
ids = torch.tensor([ids])
|
||||
elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
|
||||
ids = ids.view(1, -1)
|
||||
|
||||
return self.tokenizer.decode(ids)[0]
|
||||
|
||||
def get_logits(self, token_ids, **kwargs):
|
||||
self.cache.current_seq_len = 0
|
||||
self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
|
||||
return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, **kwargs).float().cpu()
|
||||
|
||||
def generate_with_streaming(self, prompt, state):
|
||||
settings = ExLlamaV2Sampler.Settings()
|
||||
settings.temperature = state['temperature']
|
||||
@ -114,19 +130,3 @@ class Exllamav2Model:
|
||||
pass
|
||||
|
||||
return output
|
||||
|
||||
def encode(self, string, **kwargs):
|
||||
return self.tokenizer.encode(string, add_bos=True)
|
||||
|
||||
def decode(self, ids, **kwargs):
|
||||
if isinstance(ids, list):
|
||||
ids = torch.tensor([ids])
|
||||
elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
|
||||
ids = ids.view(1, -1)
|
||||
|
||||
return self.tokenizer.decode(ids)[0]
|
||||
|
||||
def get_logits(self, token_ids, **kwargs):
|
||||
self.cache.current_seq_len = 0
|
||||
self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
|
||||
return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, **kwargs).float().cpu()
|
||||
|
@ -70,7 +70,7 @@ def load_metadata(fname):
|
||||
GGUF_VERSION = struct.unpack("<I", file.read(4))[0]
|
||||
ti_data_count = struct.unpack("<Q", file.read(8))[0]
|
||||
kv_data_count = struct.unpack("<Q", file.read(8))[0]
|
||||
|
||||
|
||||
if GGUF_VERSION == 1:
|
||||
raise Exception('You are using an outdated GGUF, please download a new one.')
|
||||
|
||||
|
@ -1,5 +1,4 @@
|
||||
import gc
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
|
Loading…
Reference in New Issue
Block a user