mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 17:29:22 +01:00
Fix exllama tokenizers (#3954)
--------- Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
parent
8d85425e09
commit
ed6b6411fb
@ -1,5 +1,6 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import version as torch_version
|
from torch import version as torch_version
|
||||||
|
|
||||||
@ -111,7 +112,7 @@ class ExllamaModel:
|
|||||||
if state['custom_token_bans']:
|
if state['custom_token_bans']:
|
||||||
to_ban = [int(x) for x in state['custom_token_bans'].split(',')]
|
to_ban = [int(x) for x in state['custom_token_bans'].split(',')]
|
||||||
if len(to_ban) > 0:
|
if len(to_ban) > 0:
|
||||||
self.generator.disallow_tokens(self.tokenizer, to_ban)
|
self.generator.disallow_tokens(to_ban)
|
||||||
|
|
||||||
# Case 1: no CFG
|
# Case 1: no CFG
|
||||||
if state['guidance_scale'] == 1:
|
if state['guidance_scale'] == 1:
|
||||||
@ -119,6 +120,11 @@ class ExllamaModel:
|
|||||||
|
|
||||||
# Tokenizing the input
|
# Tokenizing the input
|
||||||
ids = self.generator.tokenizer.encode(prompt, max_seq_len=self.model.config.max_seq_len)
|
ids = self.generator.tokenizer.encode(prompt, max_seq_len=self.model.config.max_seq_len)
|
||||||
|
if state['add_bos_token']:
|
||||||
|
ids = torch.cat(
|
||||||
|
[torch.tensor([[self.tokenizer.bos_token_id]]).to(ids.device),
|
||||||
|
ids], dim=1
|
||||||
|
).to(torch.int64)
|
||||||
ids = ids[:, -get_max_prompt_length(state):]
|
ids = ids[:, -get_max_prompt_length(state):]
|
||||||
if state['auto_max_new_tokens']:
|
if state['auto_max_new_tokens']:
|
||||||
max_new_tokens = state['truncation_length'] - ids.shape[-1]
|
max_new_tokens = state['truncation_length'] - ids.shape[-1]
|
||||||
@ -148,7 +154,12 @@ class ExllamaModel:
|
|||||||
alpha = state['guidance_scale']
|
alpha = state['guidance_scale']
|
||||||
prompts = [prompt, state['negative_prompt'] or '']
|
prompts = [prompt, state['negative_prompt'] or '']
|
||||||
|
|
||||||
ids, mask = self.tokenizer.encode(prompts, return_mask=True, max_seq_len=self.model.config.max_seq_len)
|
ids, mask = self.tokenizer.encode(
|
||||||
|
prompts,
|
||||||
|
return_mask=True,
|
||||||
|
max_seq_len=self.model.config.max_seq_len,
|
||||||
|
add_bos=state['add_bos_token']
|
||||||
|
)
|
||||||
if state['auto_max_new_tokens']:
|
if state['auto_max_new_tokens']:
|
||||||
max_new_tokens = state['truncation_length'] - ids[0].shape[-1]
|
max_new_tokens = state['truncation_length'] - ids[0].shape[-1]
|
||||||
else:
|
else:
|
||||||
@ -188,7 +199,12 @@ class ExllamaModel:
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
def encode(self, string, **kwargs):
|
def encode(self, string, **kwargs):
|
||||||
return self.tokenizer.encode(string, max_seq_len=self.model.config.max_seq_len)
|
return self.tokenizer.encode(string, max_seq_len=self.model.config.max_seq_len, add_bos=True)
|
||||||
|
|
||||||
def decode(self, string, **kwargs):
|
def decode(self, ids, **kwargs):
|
||||||
return self.tokenizer.decode(string)[0]
|
if isinstance(ids, int):
|
||||||
|
ids = torch.tensor([[ids]])
|
||||||
|
elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
|
||||||
|
ids = ids.view(1, -1)
|
||||||
|
|
||||||
|
return self.tokenizer.decode(ids)[0]
|
||||||
|
@ -48,7 +48,7 @@ class Exllamav2Model:
|
|||||||
result.cache = cache
|
result.cache = cache
|
||||||
result.tokenizer = tokenizer
|
result.tokenizer = tokenizer
|
||||||
result.generator = generator
|
result.generator = generator
|
||||||
return result, tokenizer
|
return result, result
|
||||||
|
|
||||||
def generate_with_streaming(self, prompt, state):
|
def generate_with_streaming(self, prompt, state):
|
||||||
settings = ExLlamaV2Sampler.Settings()
|
settings = ExLlamaV2Sampler.Settings()
|
||||||
@ -65,7 +65,7 @@ class Exllamav2Model:
|
|||||||
if len(to_ban) > 0:
|
if len(to_ban) > 0:
|
||||||
settings.disallow_tokens(self.tokenizer, to_ban)
|
settings.disallow_tokens(self.tokenizer, to_ban)
|
||||||
|
|
||||||
ids = self.tokenizer.encode(prompt)
|
ids = self.tokenizer.encode(prompt, add_bos=state['add_bos_token'])
|
||||||
ids = ids[:, -get_max_prompt_length(state):]
|
ids = ids[:, -get_max_prompt_length(state):]
|
||||||
initial_len = ids.shape[-1]
|
initial_len = ids.shape[-1]
|
||||||
|
|
||||||
@ -104,7 +104,12 @@ class Exllamav2Model:
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
def encode(self, string, **kwargs):
|
def encode(self, string, **kwargs):
|
||||||
return self.tokenizer.encode(string)
|
return self.tokenizer.encode(string, add_bos=True)
|
||||||
|
|
||||||
def decode(self, string, **kwargs):
|
def decode(self, ids, **kwargs):
|
||||||
return self.tokenizer.decode(string)[0]
|
if isinstance(ids, int):
|
||||||
|
ids = torch.tensor([[ids]])
|
||||||
|
elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
|
||||||
|
ids = ids.view(1, -1)
|
||||||
|
|
||||||
|
return self.tokenizer.decode(ids)[0]
|
||||||
|
Loading…
Reference in New Issue
Block a user