Update to use new llamacpp API

Thomas Antony 2023-03-29 21:20:22 +01:00
parent 79fa2b6d7e
commit 7fa5d96c22


@@ -8,16 +8,16 @@ import llamacpp
 class LlamaCppTokenizer:
     """A thin wrapper over the llamacpp tokenizer"""
 
-    def __init__(self, model: llamacpp.PyLLAMA):
+    def __init__(self, model: llamacpp.LlamaInference):
         self._tokenizer = model.get_tokenizer()
         self.eos_token_id = 2
         self.bos_token_id = 0
 
     @classmethod
-    def from_model(cls, model: llamacpp.PyLLAMA):
+    def from_model(cls, model: llamacpp.LlamaInference):
         return cls(model)
 
-    def encode(self, prompt):
+    def encode(self, prompt: str):
         return self._tokenizer.tokenize(prompt)
 
     def decode(self, ids):
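For reference, a minimal usage sketch of the wrapper under the new API. Every call here (InferenceParams, path_model, LlamaInference, get_tokenizer via the wrapper, tokenize) appears in this diff; the model path is a placeholder.

import llamacpp

params = llamacpp.InferenceParams()
params.path_model = "./models/ggml-model-q4_0.bin"  # placeholder path
model = llamacpp.LlamaInference(params)

tokenizer = LlamaCppTokenizer.from_model(model)
ids = tokenizer.encode("Hello, world")  # returns a list of token ids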
@@ -30,21 +30,10 @@ class LlamaCppModel:
     @classmethod
     def from_pretrained(self, path):
-        params = llamacpp.gpt_params(
-            str(path), # model
-            2048, # ctx_size
-            200, # n_predict
-            40, # top_k
-            0.95, # top_p
-            0.80, # temp
-            1.30, # repeat_penalty
-            -1, # seed
-            8, # threads
-            64, # repeat_last_n
-            8, # batch_size
-        )
-        _model = llamacpp.PyLLAMA(params)
+        params = llamacpp.InferenceParams()
+        params.path_model = str(path)
+        _model = llamacpp.LlamaInference(params)
 
         result = self()
         result.model = _model
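The migration pattern in this hunk: the old gpt_params constructor took eleven positional arguments, while InferenceParams starts from library defaults and exposes settings as attributes. Only path_model is confirmed by this diff; the commented-out field names below are assumptions, so check dir(llamacpp.InferenceParams()) in your build before relying on them.

import llamacpp

params = llamacpp.InferenceParams()
params.path_model = "./models/ggml-model-q4_0.bin"  # placeholder path
# params.seed = -1       # assumed field name; old positional value was -1
# params.n_threads = 8   # assumed field name; old positional value was 8
model = llamacpp.LlamaInference(params)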
@@ -63,22 +52,20 @@ class LlamaCppModel:
         # params.repeat_last_n = repeat_last_n
         # model.params = params
 
-        if not self.initialized:
-            self.model.add_bos()
+        self.model.add_bos()
         self.model.update_input(context)
 
-        if not self.initialized:
-            self.model.prepare_context()
-            self.initialized = True
-
         output = ""
         is_end_of_text = False
         ctr = 0
-        while not self.model.is_finished() and ctr < num_tokens and not is_end_of_text:
+        while ctr < num_tokens and not is_end_of_text:
             if self.model.has_unconsumed_input():
-                self.model.ingest_all_pending_input(False)
+                self.model.ingest_all_pending_input()
             else:
-                text, is_end_of_text = self.model.infer_text()
+                self.model.eval()
+                token = self.model.sample()
+                text = self.model.token_to_str(token)
+                is_end_of_text = token == self.model.token_eos()
                 if callback:
                     callback(text)
                 output += text
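Putting the new generation loop together, a standalone sketch outside the LlamaCppModel wrapper. All calls (add_bos, update_input, has_unconsumed_input, ingest_all_pending_input, eval, sample, token_to_str, token_eos) appear in this diff; the prompt, model path, and token budget are illustrative.

import llamacpp

params = llamacpp.InferenceParams()
params.path_model = "./models/ggml-model-q4_0.bin"  # placeholder path
model = llamacpp.LlamaInference(params)

model.add_bos()                     # prepend the beginning-of-sequence token
model.update_input("Hello, world")  # queue the prompt for ingestion

output = ""
for _ in range(64):                 # illustrative token budget
    if model.has_unconsumed_input():
        model.ingest_all_pending_input()  # feed queued prompt tokens
    else:
        model.eval()                      # advance the model one step
        token = model.sample()            # sample the next token id
        if token == model.token_eos():    # stop at end of sequence
            break
        output += model.token_to_str(token)
print(output)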