Add decode functions to llama.cpp/exllama

oobabooga 2023-07-07 09:11:30 -07:00
parent 1ba2e88551
commit b6643e5039
2 changed files with 6 additions and 0 deletions

@@ -120,3 +120,6 @@ class ExllamaModel:
     def encode(self, string, **kwargs):
         return self.tokenizer.encode(string)
+
+    def decode(self, string, **kwargs):
+        return self.tokenizer.decode(string)[0]

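As a usage note (not part of the commit): a minimal round-trip sketch, assuming an already-initialized ExllamaModel instance named model. ExLlama's tokenizer encode() returns a batched tensor, and the [0] in decode() appears to unwrap the batch to a single string:

    # Hypothetical round trip; `model` and the input string are assumed:
    ids = model.encode("Hello world")   # batched tensor of token ids
    text = model.decode(ids)            # the [0] inside decode() takes the
                                        # first decoded sequence from the batch
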

@@ -65,6 +65,9 @@ class LlamaCppModel:
         return self.model.tokenize(string)
 
+    def decode(self, tokens):
+        return self.model.detokenize(tokens)
+
     def generate(self, prompt, state, callback=None):
         prompt = prompt if type(prompt) is str else prompt.decode()
         completion_chunks = self.model.create_completion(
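
Likewise for the llama.cpp wrapper: a minimal sketch assuming an initialized LlamaCppModel instance named model. llama-cpp-python's Llama.tokenize() expects bytes and Llama.detokenize() returns bytes, which matches the prompt.decode() guard visible in generate() above:

    # Hypothetical round trip; `model` is assumed:
    tokens = model.encode(b"Hello world")   # Llama.tokenize() takes bytes
    raw = model.decode(tokens)              # Llama.detokenize() returns bytes
    print(raw.decode())                     # back to str for display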