From 092a2c3516409a0d639dfb79c384ebc7ae3a4434 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 30 Nov 2023 11:21:40 -0800
Subject: [PATCH] Fix a bug in llama.cpp get_logits() function

---
 modules/llamacpp_model.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/modules/llamacpp_model.py b/modules/llamacpp_model.py
index aa0fedbf..8b133e98 100644
--- a/modules/llamacpp_model.py
+++ b/modules/llamacpp_model.py
@@ -105,6 +105,7 @@ class LlamaCppModel:
         return self.model.detokenize(ids).decode('utf-8')
 
     def get_logits(self, tokens):
+        self.model.reset()
         self.model.eval(tokens)
         logits = self.model._scores
         logits = np.expand_dims(logits, 0) # batch dim is expected