mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Reorder some functions
This commit is contained in:
parent
e2fddd9584
commit
13ac55fa18
@ -85,6 +85,22 @@ class ExllamaModel:
|
|||||||
result.generator = generator
|
result.generator = generator
|
||||||
return result, result
|
return result, result
|
||||||
|
|
||||||
|
def encode(self, string, **kwargs):
|
||||||
|
return self.tokenizer.encode(string, max_seq_len=self.model.config.max_seq_len, add_bos=True)
|
||||||
|
|
||||||
|
def decode(self, ids, **kwargs):
|
||||||
|
if isinstance(ids, list):
|
||||||
|
ids = torch.tensor([ids])
|
||||||
|
elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
|
||||||
|
ids = ids.view(1, -1)
|
||||||
|
|
||||||
|
return self.tokenizer.decode(ids)[0]
|
||||||
|
|
||||||
|
def get_logits(self, token_ids, **kwargs):
|
||||||
|
self.cache.current_seq_len = 0
|
||||||
|
self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
|
||||||
|
return self.model.forward(token_ids[:, -1:], self.cache, **kwargs).float().cpu()
|
||||||
|
|
||||||
def generate_with_streaming(self, prompt, state):
|
def generate_with_streaming(self, prompt, state):
|
||||||
|
|
||||||
# The cache batch size must be 2 for CFG and 1 otherwise
|
# The cache batch size must be 2 for CFG and 1 otherwise
|
||||||
@ -200,19 +216,3 @@ class ExllamaModel:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def encode(self, string, **kwargs):
|
|
||||||
return self.tokenizer.encode(string, max_seq_len=self.model.config.max_seq_len, add_bos=True)
|
|
||||||
|
|
||||||
def decode(self, ids, **kwargs):
|
|
||||||
if isinstance(ids, list):
|
|
||||||
ids = torch.tensor([ids])
|
|
||||||
elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
|
|
||||||
ids = ids.view(1, -1)
|
|
||||||
|
|
||||||
return self.tokenizer.decode(ids)[0]
|
|
||||||
|
|
||||||
def get_logits(self, token_ids, **kwargs):
|
|
||||||
self.cache.current_seq_len = 0
|
|
||||||
self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
|
|
||||||
return self.model.forward(token_ids[:, -1:], self.cache, **kwargs).float().cpu()
|
|
||||||
|
@ -62,6 +62,22 @@ class Exllamav2Model:
|
|||||||
result.generator = generator
|
result.generator = generator
|
||||||
return result, result
|
return result, result
|
||||||
|
|
||||||
|
def encode(self, string, **kwargs):
|
||||||
|
return self.tokenizer.encode(string, add_bos=True)
|
||||||
|
|
||||||
|
def decode(self, ids, **kwargs):
|
||||||
|
if isinstance(ids, list):
|
||||||
|
ids = torch.tensor([ids])
|
||||||
|
elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
|
||||||
|
ids = ids.view(1, -1)
|
||||||
|
|
||||||
|
return self.tokenizer.decode(ids)[0]
|
||||||
|
|
||||||
|
def get_logits(self, token_ids, **kwargs):
|
||||||
|
self.cache.current_seq_len = 0
|
||||||
|
self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
|
||||||
|
return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, **kwargs).float().cpu()
|
||||||
|
|
||||||
def generate_with_streaming(self, prompt, state):
|
def generate_with_streaming(self, prompt, state):
|
||||||
settings = ExLlamaV2Sampler.Settings()
|
settings = ExLlamaV2Sampler.Settings()
|
||||||
settings.temperature = state['temperature']
|
settings.temperature = state['temperature']
|
||||||
@ -114,19 +130,3 @@ class Exllamav2Model:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def encode(self, string, **kwargs):
|
|
||||||
return self.tokenizer.encode(string, add_bos=True)
|
|
||||||
|
|
||||||
def decode(self, ids, **kwargs):
|
|
||||||
if isinstance(ids, list):
|
|
||||||
ids = torch.tensor([ids])
|
|
||||||
elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
|
|
||||||
ids = ids.view(1, -1)
|
|
||||||
|
|
||||||
return self.tokenizer.decode(ids)[0]
|
|
||||||
|
|
||||||
def get_logits(self, token_ids, **kwargs):
|
|
||||||
self.cache.current_seq_len = 0
|
|
||||||
self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
|
|
||||||
return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, **kwargs).float().cpu()
|
|
||||||
|
@ -70,7 +70,7 @@ def load_metadata(fname):
|
|||||||
GGUF_VERSION = struct.unpack("<I", file.read(4))[0]
|
GGUF_VERSION = struct.unpack("<I", file.read(4))[0]
|
||||||
ti_data_count = struct.unpack("<Q", file.read(8))[0]
|
ti_data_count = struct.unpack("<Q", file.read(8))[0]
|
||||||
kv_data_count = struct.unpack("<Q", file.read(8))[0]
|
kv_data_count = struct.unpack("<Q", file.read(8))[0]
|
||||||
|
|
||||||
if GGUF_VERSION == 1:
|
if GGUF_VERSION == 1:
|
||||||
raise Exception('You are using an outdated GGUF, please download a new one.')
|
raise Exception('You are using an outdated GGUF, please download a new one.')
|
||||||
|
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import gc
|
import gc
|
||||||
import hashlib
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
|
Loading…
Reference in New Issue
Block a user