diff --git a/modules/exllama.py b/modules/exllama.py
index f685a445..ecfb10a4 100644
--- a/modules/exllama.py
+++ b/modules/exllama.py
@@ -120,3 +120,6 @@ class ExllamaModel:
 
     def encode(self, string, **kwargs):
         return self.tokenizer.encode(string)
+
+    def decode(self, string, **kwargs):
+        return self.tokenizer.decode(string)[0]
diff --git a/modules/llamacpp_model.py b/modules/llamacpp_model.py
index 10a852db..4899ad99 100644
--- a/modules/llamacpp_model.py
+++ b/modules/llamacpp_model.py
@@ -65,6 +65,9 @@ class LlamaCppModel:
 
         return self.model.tokenize(string)
 
+    def decode(self, tokens):
+        return self.model.detokenize(tokens)
+
     def generate(self, prompt, state, callback=None):
         prompt = prompt if type(prompt) is str else prompt.decode()
         completion_chunks = self.model.create_completion(