Make long_replies ban the eos token as well

oobabooga 2023-08-01 18:47:49 -07:00
parent 9ae0eab989
commit 6c521ce967


@@ -28,7 +28,7 @@ class MyLogits(LogitsProcessor):
     def __call__(self, input_ids, scores):
         if input_ids.shape[-1] - initial_size < params["min_length"]:
             scores[...,self.newline_id] = -1000
-            # scores[...,shared.tokenizer.eos_token_id] = -1000
+            scores[...,shared.tokenizer.eos_token_id] = -1000
         # probs = torch.softmax(scores, dim=-1, dtype=torch.float)
         # probs[0] /= probs[0].sum()
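
For reference, here is a minimal self-contained sketch of the same technique outside the extension: a LogitsProcessor that pushes the newline and EOS logits to a large negative value until a minimum number of new tokens has been generated. The class name, model name and MIN_LENGTH value below are illustrative placeholders; the extension itself reads params["min_length"] and shared.tokenizer from the webui rather than constructing them like this.

```python
# Minimal sketch (not the extension's script): suppress the newline and EOS
# tokens until the reply reaches a minimum length, so generation cannot stop
# early. Model name and MIN_LENGTH are placeholder choices.
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LogitsProcessor,
    LogitsProcessorList,
)

MIN_LENGTH = 64  # minimum number of newly generated tokens


class BanEosUntilMinLength(LogitsProcessor):
    def __init__(self, tokenizer, initial_size, min_length):
        # Token id of "\n" for this tokenizer, plus the EOS id.
        self.newline_id = tokenizer.encode("\n", add_special_tokens=False)[-1]
        self.eos_id = tokenizer.eos_token_id
        self.initial_size = initial_size  # prompt length in tokens
        self.min_length = min_length

    def __call__(self, input_ids, scores):
        # While the generated part is still shorter than min_length,
        # make newline and EOS effectively unselectable.
        if input_ids.shape[-1] - self.initial_size < self.min_length:
            scores[..., self.newline_id] = -1000
            scores[..., self.eos_id] = -1000
        return scores


tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Explain what a logits processor does.", return_tensors="pt")
processor = BanEosUntilMinLength(tokenizer, inputs["input_ids"].shape[-1], MIN_LENGTH)

output = model.generate(
    **inputs,
    max_new_tokens=128,
    do_sample=True,
    logits_processor=LogitsProcessorList([processor]),
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```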