Fix a bug in llama.cpp get_logits() function

oobabooga 2023-11-30 11:21:40 -08:00
parent 000b77a17d
commit 092a2c3516


@@ -105,6 +105,7 @@ class LlamaCppModel:
         return self.model.detokenize(ids).decode('utf-8')
 
     def get_logits(self, tokens):
+        self.model.reset()
         self.model.eval(tokens)
         logits = self.model._scores
         logits = np.expand_dims(logits, 0)  # batch dim is expected
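Why the reset matters: in llama-cpp-python, Llama.eval() appends tokens to the model's internal context rather than starting fresh, so a second call to get_logits() would score the new tokens on top of whatever was evaluated before and return stale logits. Calling Llama.reset() first clears that state. Below is a minimal sketch of the patched method, assuming a Llama instance constructed with logits_all=True so that _scores holds per-token logits; the simplified constructor is illustrative only, as the real class in the repo wraps far more configuration:

```python
import numpy as np
from llama_cpp import Llama

class LlamaCppModel:
    def __init__(self, model_path):
        # Simplified for illustration; logits_all=True makes llama-cpp-python
        # keep logits for every evaluated token in the private _scores array.
        self.model = Llama(model_path=model_path, logits_all=True)

    def get_logits(self, tokens):
        self.model.reset()           # clear context left over from prior calls
        self.model.eval(tokens)      # evaluate the prompt tokens
        logits = self.model._scores  # per-token logits, shape (n_tokens, n_vocab)
        logits = np.expand_dims(logits, 0)  # batch dim is expected
        return logits
```

Here tokens is a sequence of token ids, e.g. the result of model.tokenize(b"hello").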