From 13ac55fa1805d7f4b87a43eb04a47d0d8b5ee50d Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Tue, 19 Sep 2023 13:13:03 -0700
Subject: [PATCH] Reorder some functions

---
 modules/exllama.py       | 32 ++++++++++++++++----------------
 modules/exllamav2.py     | 32 ++++++++++++++++----------------
 modules/metadata_gguf.py |  2 +-
 modules/models.py        |  1 -
 4 files changed, 33 insertions(+), 34 deletions(-)

diff --git a/modules/exllama.py b/modules/exllama.py
index 4253e6ca..cb92344e 100644
--- a/modules/exllama.py
+++ b/modules/exllama.py
@@ -85,6 +85,22 @@ class ExllamaModel:
         result.generator = generator
         return result, result
 
+    def encode(self, string, **kwargs):
+        return self.tokenizer.encode(string, max_seq_len=self.model.config.max_seq_len, add_bos=True)
+
+    def decode(self, ids, **kwargs):
+        if isinstance(ids, list):
+            ids = torch.tensor([ids])
+        elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
+            ids = ids.view(1, -1)
+
+        return self.tokenizer.decode(ids)[0]
+
+    def get_logits(self, token_ids, **kwargs):
+        self.cache.current_seq_len = 0
+        self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
+        return self.model.forward(token_ids[:, -1:], self.cache, **kwargs).float().cpu()
+
     def generate_with_streaming(self, prompt, state):
 
         # The cache batch size must be 2 for CFG and 1 otherwise
@@ -200,19 +216,3 @@ class ExllamaModel:
             pass
 
         return output
-
-    def encode(self, string, **kwargs):
-        return self.tokenizer.encode(string, max_seq_len=self.model.config.max_seq_len, add_bos=True)
-
-    def decode(self, ids, **kwargs):
-        if isinstance(ids, list):
-            ids = torch.tensor([ids])
-        elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
-            ids = ids.view(1, -1)
-
-        return self.tokenizer.decode(ids)[0]
-
-    def get_logits(self, token_ids, **kwargs):
-        self.cache.current_seq_len = 0
-        self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
-        return self.model.forward(token_ids[:, -1:], self.cache, **kwargs).float().cpu()
diff --git a/modules/exllamav2.py b/modules/exllamav2.py
index 605a0927..be5f47e4 100644
--- a/modules/exllamav2.py
+++ b/modules/exllamav2.py
@@ -62,6 +62,22 @@ class Exllamav2Model:
         result.generator = generator
         return result, result
 
+    def encode(self, string, **kwargs):
+        return self.tokenizer.encode(string, add_bos=True)
+
+    def decode(self, ids, **kwargs):
+        if isinstance(ids, list):
+            ids = torch.tensor([ids])
+        elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
+            ids = ids.view(1, -1)
+
+        return self.tokenizer.decode(ids)[0]
+
+    def get_logits(self, token_ids, **kwargs):
+        self.cache.current_seq_len = 0
+        self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
+        return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, **kwargs).float().cpu()
+
     def generate_with_streaming(self, prompt, state):
         settings = ExLlamaV2Sampler.Settings()
         settings.temperature = state['temperature']
@@ -114,19 +130,3 @@ class Exllamav2Model:
             pass
 
         return output
-
-    def encode(self, string, **kwargs):
-        return self.tokenizer.encode(string, add_bos=True)
-
-    def decode(self, ids, **kwargs):
-        if isinstance(ids, list):
-            ids = torch.tensor([ids])
-        elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
-            ids = ids.view(1, -1)
-
-        return self.tokenizer.decode(ids)[0]
-
-    def get_logits(self, token_ids, **kwargs):
-        self.cache.current_seq_len = 0
-        self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
-        return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, **kwargs).float().cpu()
diff --git a/modules/metadata_gguf.py b/modules/metadata_gguf.py
index f633d70c..f5fa3ce2 100644
--- a/modules/metadata_gguf.py
+++ b/modules/metadata_gguf.py
@@ -70,7 +70,7 @@ def load_metadata(fname):
         GGUF_VERSION = struct.unpack("
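For context only (not part of the patch): a minimal sketch of how the relocated encode/decode/get_logits methods are typically exercised. The instance name `model` is hypothetical and assumed to be an already-loaded ExllamaModel or Exllamav2Model.

    import torch

    # 'model' is assumed to be an already-initialized Exllama(v2) wrapper (hypothetical)
    prompt = "Hello, world"
    token_ids = model.encode(prompt)                 # tensor of shape (1, seq_len)
    logits = model.get_logits(token_ids)             # logits for the final position, moved to CPU
    next_token = int(torch.argmax(logits[0, -1]))    # greedy choice of the next token id
    text = model.decode(token_ids)                   # round-trip the ids back to a string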