mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-21 23:57:58 +01:00
Enable special token support for exllamav2 (#4314)
This commit is contained in:
parent
8f6405d2fa
commit
1d5a015ce7
@ -64,7 +64,7 @@ class Exllamav2Model:
|
||||
return result, result
|
||||
|
||||
def encode(self, string, **kwargs):
|
||||
return self.tokenizer.encode(string, add_bos=True)
|
||||
return self.tokenizer.encode(string, add_bos=True, encode_special_tokens=True)
|
||||
|
||||
def decode(self, ids, **kwargs):
|
||||
if isinstance(ids, list):
|
||||
@ -72,7 +72,7 @@ class Exllamav2Model:
|
||||
elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
|
||||
ids = ids.view(1, -1)
|
||||
|
||||
return self.tokenizer.decode(ids)[0]
|
||||
return self.tokenizer.decode(ids, decode_special_tokens=True)[0]
|
||||
|
||||
def get_logits(self, token_ids, **kwargs):
|
||||
self.cache.current_seq_len = 0
|
||||
@ -97,7 +97,7 @@ class Exllamav2Model:
|
||||
if len(to_ban) > 0:
|
||||
settings.disallow_tokens(self.tokenizer, to_ban)
|
||||
|
||||
ids = self.tokenizer.encode(prompt, add_bos=state['add_bos_token'])
|
||||
ids = self.tokenizer.encode(prompt, add_bos=state['add_bos_token'], encode_special_tokens=True)
|
||||
ids = ids[:, -get_max_prompt_length(state):]
|
||||
initial_len = ids.shape[-1]
|
||||
|
||||
@ -119,7 +119,7 @@ class Exllamav2Model:
|
||||
if i == 0 and self.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
|
||||
has_leading_space = True
|
||||
|
||||
decoded_text = self.tokenizer.decode(ids[:, initial_len:])[0]
|
||||
decoded_text = self.tokenizer.decode(ids[:, initial_len:], decode_special_tokens=not state['skip_special_tokens'])[0]
|
||||
if has_leading_space:
|
||||
decoded_text = ' ' + decoded_text
|
||||
|
||||
|
@ -231,6 +231,7 @@ loaders_samplers = {
|
||||
'ban_eos_token',
|
||||
'add_bos_token',
|
||||
'custom_token_bans',
|
||||
'skip_special_tokens',
|
||||
'auto_max_new_tokens',
|
||||
},
|
||||
'ExLlamav2_HF': {
|
||||
|
Loading…
Reference in New Issue
Block a user