Enable special token support for exllamav2 (#4314)

Johan 2023-10-21 06:54:06 +02:00 committed by GitHub
parent 8f6405d2fa
commit 1d5a015ce7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 4 deletions

@@ -64,7 +64,7 @@ class Exllamav2Model:
         return result, result
 
     def encode(self, string, **kwargs):
-        return self.tokenizer.encode(string, add_bos=True)
+        return self.tokenizer.encode(string, add_bos=True, encode_special_tokens=True)
 
     def decode(self, ids, **kwargs):
         if isinstance(ids, list):
@@ -72,7 +72,7 @@ class Exllamav2Model:
         elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
             ids = ids.view(1, -1)
 
-        return self.tokenizer.decode(ids)[0]
+        return self.tokenizer.decode(ids, decode_special_tokens=True)[0]
 
     def get_logits(self, token_ids, **kwargs):
         self.cache.current_seq_len = 0
@@ -97,7 +97,7 @@ class Exllamav2Model:
             if len(to_ban) > 0:
                 settings.disallow_tokens(self.tokenizer, to_ban)
 
-        ids = self.tokenizer.encode(prompt, add_bos=state['add_bos_token'])
+        ids = self.tokenizer.encode(prompt, add_bos=state['add_bos_token'], encode_special_tokens=True)
         ids = ids[:, -get_max_prompt_length(state):]
         initial_len = ids.shape[-1]
 
@@ -119,7 +119,7 @@ class Exllamav2Model:
             if i == 0 and self.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
                 has_leading_space = True
 
-            decoded_text = self.tokenizer.decode(ids[:, initial_len:])[0]
+            decoded_text = self.tokenizer.decode(ids[:, initial_len:], decode_special_tokens=not state['skip_special_tokens'])[0]
             if has_leading_space:
                 decoded_text = ' ' + decoded_text
 

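The two flags added above are arguments that exllamav2's ExLlamaV2Tokenizer already accepts; the commit simply starts passing them through. A minimal sketch of what they change, outside the webui (the model path and prompt are placeholders, and the config/tokenizer setup mirrors what from_pretrained in this module already does):

from exllamav2 import ExLlamaV2Config, ExLlamaV2Tokenizer

# Placeholder path: point at any local ExLlamaV2 model directory.
config = ExLlamaV2Config()
config.model_dir = '/models/my-exl2-model'
config.prepare()
tokenizer = ExLlamaV2Tokenizer(config)

prompt = '<s>Hello</s>'  # contains special-token text

# Default behaviour: '<s>' and '</s>' are split into literal text pieces.
literal_ids = tokenizer.encode(prompt, add_bos=False)

# With encode_special_tokens=True (what the patched encode() now passes),
# the special-token strings map to their single token IDs instead.
special_ids = tokenizer.encode(prompt, add_bos=False, encode_special_tokens=True)
print(literal_ids.shape, special_ids.shape)  # special_ids is typically shorter

# decode_special_tokens=True keeps the special tokens visible in the text;
# False strips them, which is what the skip_special_tokens setting toggles
# in generate_with_streaming.
print(tokenizer.decode(special_ids, decode_special_tokens=True)[0])
print(tokenizer.decode(special_ids, decode_special_tokens=False)[0])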
@@ -231,6 +231,7 @@ loaders_samplers = {
         'ban_eos_token',
         'add_bos_token',
         'custom_token_bans',
+        'skip_special_tokens',
         'auto_max_new_tokens',
     },
     'ExLlamav2_HF': {
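Adding 'skip_special_tokens' to the ExLlamav2 entry is what lets the webui pass the existing skip_special_tokens option through to this loader, so its value reaches the state dict used in generate_with_streaming above. A rough sketch of the hand-off (the state dict here is only an illustrative stand-in for what the webui assembles from the interface settings):

# Illustrative only: in the webui, `state` is built from the UI settings.
state = {'add_bos_token': True, 'skip_special_tokens': True}

# Mirrors the patched decode call: disabling skip_special_tokens in the UI
# makes decode_special_tokens True, so tokens such as </s> stay in the output.
decode_special_tokens = not state['skip_special_tokens']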