mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 17:29:22 +01:00
Enable special token support for exllamav2 (#4314)
This commit is contained in:
parent
8f6405d2fa
commit
1d5a015ce7
@ -64,7 +64,7 @@ class Exllamav2Model:
|
|||||||
return result, result
|
return result, result
|
||||||
|
|
||||||
def encode(self, string, **kwargs):
|
def encode(self, string, **kwargs):
|
||||||
return self.tokenizer.encode(string, add_bos=True)
|
return self.tokenizer.encode(string, add_bos=True, encode_special_tokens=True)
|
||||||
|
|
||||||
def decode(self, ids, **kwargs):
|
def decode(self, ids, **kwargs):
|
||||||
if isinstance(ids, list):
|
if isinstance(ids, list):
|
||||||
@ -72,7 +72,7 @@ class Exllamav2Model:
|
|||||||
elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
|
elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
|
||||||
ids = ids.view(1, -1)
|
ids = ids.view(1, -1)
|
||||||
|
|
||||||
return self.tokenizer.decode(ids)[0]
|
return self.tokenizer.decode(ids, decode_special_tokens=True)[0]
|
||||||
|
|
||||||
def get_logits(self, token_ids, **kwargs):
|
def get_logits(self, token_ids, **kwargs):
|
||||||
self.cache.current_seq_len = 0
|
self.cache.current_seq_len = 0
|
||||||
@ -97,7 +97,7 @@ class Exllamav2Model:
|
|||||||
if len(to_ban) > 0:
|
if len(to_ban) > 0:
|
||||||
settings.disallow_tokens(self.tokenizer, to_ban)
|
settings.disallow_tokens(self.tokenizer, to_ban)
|
||||||
|
|
||||||
ids = self.tokenizer.encode(prompt, add_bos=state['add_bos_token'])
|
ids = self.tokenizer.encode(prompt, add_bos=state['add_bos_token'], encode_special_tokens=True)
|
||||||
ids = ids[:, -get_max_prompt_length(state):]
|
ids = ids[:, -get_max_prompt_length(state):]
|
||||||
initial_len = ids.shape[-1]
|
initial_len = ids.shape[-1]
|
||||||
|
|
||||||
@ -119,7 +119,7 @@ class Exllamav2Model:
|
|||||||
if i == 0 and self.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
|
if i == 0 and self.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
|
||||||
has_leading_space = True
|
has_leading_space = True
|
||||||
|
|
||||||
decoded_text = self.tokenizer.decode(ids[:, initial_len:])[0]
|
decoded_text = self.tokenizer.decode(ids[:, initial_len:], decode_special_tokens=not state['skip_special_tokens'])[0]
|
||||||
if has_leading_space:
|
if has_leading_space:
|
||||||
decoded_text = ' ' + decoded_text
|
decoded_text = ' ' + decoded_text
|
||||||
|
|
||||||
|
@ -231,6 +231,7 @@ loaders_samplers = {
|
|||||||
'ban_eos_token',
|
'ban_eos_token',
|
||||||
'add_bos_token',
|
'add_bos_token',
|
||||||
'custom_token_bans',
|
'custom_token_bans',
|
||||||
|
'skip_special_tokens',
|
||||||
'auto_max_new_tokens',
|
'auto_max_new_tokens',
|
||||||
},
|
},
|
||||||
'ExLlamav2_HF': {
|
'ExLlamav2_HF': {
|
||||||
|
Loading…
Reference in New Issue
Block a user