Reorder some functions

This commit is contained in:
oobabooga 2023-09-19 13:13:03 -07:00
parent e2fddd9584
commit 13ac55fa18
4 changed files with 33 additions and 34 deletions

View File

@ -85,6 +85,22 @@ class ExllamaModel:
result.generator = generator
return result, result
def encode(self, string, **kwargs):
return self.tokenizer.encode(string, max_seq_len=self.model.config.max_seq_len, add_bos=True)
def decode(self, ids, **kwargs):
if isinstance(ids, list):
ids = torch.tensor([ids])
elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
ids = ids.view(1, -1)
return self.tokenizer.decode(ids)[0]
def get_logits(self, token_ids, **kwargs):
self.cache.current_seq_len = 0
self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
return self.model.forward(token_ids[:, -1:], self.cache, **kwargs).float().cpu()
def generate_with_streaming(self, prompt, state):
# The cache batch size must be 2 for CFG and 1 otherwise
@ -200,19 +216,3 @@ class ExllamaModel:
pass
return output
def encode(self, string, **kwargs):
return self.tokenizer.encode(string, max_seq_len=self.model.config.max_seq_len, add_bos=True)
def decode(self, ids, **kwargs):
if isinstance(ids, list):
ids = torch.tensor([ids])
elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
ids = ids.view(1, -1)
return self.tokenizer.decode(ids)[0]
def get_logits(self, token_ids, **kwargs):
self.cache.current_seq_len = 0
self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
return self.model.forward(token_ids[:, -1:], self.cache, **kwargs).float().cpu()

View File

@ -62,6 +62,22 @@ class Exllamav2Model:
result.generator = generator
return result, result
def encode(self, string, **kwargs):
return self.tokenizer.encode(string, add_bos=True)
def decode(self, ids, **kwargs):
if isinstance(ids, list):
ids = torch.tensor([ids])
elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
ids = ids.view(1, -1)
return self.tokenizer.decode(ids)[0]
def get_logits(self, token_ids, **kwargs):
self.cache.current_seq_len = 0
self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, **kwargs).float().cpu()
def generate_with_streaming(self, prompt, state):
settings = ExLlamaV2Sampler.Settings()
settings.temperature = state['temperature']
@ -114,19 +130,3 @@ class Exllamav2Model:
pass
return output
def encode(self, string, **kwargs):
return self.tokenizer.encode(string, add_bos=True)
def decode(self, ids, **kwargs):
if isinstance(ids, list):
ids = torch.tensor([ids])
elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
ids = ids.view(1, -1)
return self.tokenizer.decode(ids)[0]
def get_logits(self, token_ids, **kwargs):
self.cache.current_seq_len = 0
self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, **kwargs).float().cpu()

View File

@ -1,5 +1,4 @@
import gc
import hashlib
import os
import re
import time