Enable special token support for exllamav2 (#4314)

Johan 2023-10-21 06:54:06 +02:00 committed by GitHub
parent 8f6405d2fa
commit 1d5a015ce7
2 changed files with 5 additions and 4 deletions


@@ -64,7 +64,7 @@ class Exllamav2Model:
         return result, result
 
     def encode(self, string, **kwargs):
-        return self.tokenizer.encode(string, add_bos=True)
+        return self.tokenizer.encode(string, add_bos=True, encode_special_tokens=True)
 
     def decode(self, ids, **kwargs):
         if isinstance(ids, list):
@@ -72,7 +72,7 @@ class Exllamav2Model:
         elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
             ids = ids.view(1, -1)
 
-        return self.tokenizer.decode(ids)[0]
+        return self.tokenizer.decode(ids, decode_special_tokens=True)[0]
 
     def get_logits(self, token_ids, **kwargs):
         self.cache.current_seq_len = 0
@@ -97,7 +97,7 @@ class Exllamav2Model:
         if len(to_ban) > 0:
             settings.disallow_tokens(self.tokenizer, to_ban)
 
-        ids = self.tokenizer.encode(prompt, add_bos=state['add_bos_token'])
+        ids = self.tokenizer.encode(prompt, add_bos=state['add_bos_token'], encode_special_tokens=True)
         ids = ids[:, -get_max_prompt_length(state):]
         initial_len = ids.shape[-1]
@@ -119,7 +119,7 @@ class Exllamav2Model:
             if i == 0 and self.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
                 has_leading_space = True
 
-            decoded_text = self.tokenizer.decode(ids[:, initial_len:])[0]
+            decoded_text = self.tokenizer.decode(ids[:, initial_len:], decode_special_tokens=not state['skip_special_tokens'])[0]
             if has_leading_space:
                 decoded_text = ' ' + decoded_text


@@ -231,6 +231,7 @@ loaders_samplers = {
         'ban_eos_token',
         'add_bos_token',
         'custom_token_bans',
+        'skip_special_tokens',
         'auto_max_new_tokens',
     },
     'ExLlamav2_HF': {
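In the second file, listing 'skip_special_tokens' in this set is what exposes the new behaviour in the UI: the webui uses the loaders_samplers mapping to decide which generation parameters apply to each loader, and the checkbox value then reaches generate_with_streaming as state['skip_special_tokens']. A minimal sketch of that lookup, assuming the mapping lives in modules/loaders.py and that the set edited above belongs to an 'ExLlamav2' key (both inferred rather than shown in the hunk); the helper function is hypothetical:

from modules.loaders import loaders_samplers  # assumed module path

def sampler_is_visible(loader: str, name: str) -> bool:
    # A generation parameter is only offered for a loader if its name appears
    # in that loader's sampler set.
    return name in loaders_samplers.get(loader, set())

# True after this commit: the "Skip special tokens" option now applies to the
# ExLlamav2 loader and feeds state['skip_special_tokens'] during generation.
print(sampler_is_visible('ExLlamav2', 'skip_special_tokens'))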