From 6c521ce96787552a9604c344b9949945ef359a59 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 1 Aug 2023 18:47:49 -0700 Subject: [PATCH] Make long_replies ban the eos token as well --- extensions/long_replies/script.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extensions/long_replies/script.py b/extensions/long_replies/script.py index 035e8c9e..a30b05a7 100644 --- a/extensions/long_replies/script.py +++ b/extensions/long_replies/script.py @@ -28,7 +28,7 @@ class MyLogits(LogitsProcessor): def __call__(self, input_ids, scores): if input_ids.shape[-1] - initial_size < params["min_length"]: scores[...,self.newline_id] = -1000 - # scores[...,shared.tokenizer.eos_token_id] = -1000 + scores[...,shared.tokenizer.eos_token_id] = -1000 # probs = torch.softmax(scores, dim=-1, dtype=torch.float) # probs[0] /= probs[0].sum()