Update to use new llamacpp API

This commit is contained in:
Thomas Antony 2023-03-29 21:20:22 +01:00
parent 79fa2b6d7e
commit 7fa5d96c22

View File

@ -8,16 +8,16 @@ import llamacpp
class LlamaCppTokenizer:
"""A thin wrapper over the llamacpp tokenizer"""
def __init__(self, model: llamacpp.PyLLAMA):
def __init__(self, model: llamacpp.LlamaInference):
self._tokenizer = model.get_tokenizer()
self.eos_token_id = 2
self.bos_token_id = 0
@classmethod
def from_model(cls, model: llamacpp.PyLLAMA):
def from_model(cls, model: llamacpp.LlamaInference):
return cls(model)
def encode(self, prompt):
def encode(self, prompt: str):
return self._tokenizer.tokenize(prompt)
def decode(self, ids):
@ -30,21 +30,10 @@ class LlamaCppModel:
@classmethod
def from_pretrained(self, path):
params = llamacpp.gpt_params(
str(path), # model
2048, # ctx_size
200, # n_predict
40, # top_k
0.95, # top_p
0.80, # temp
1.30, # repeat_penalty
-1, # seed
8, # threads
64, # repeat_last_n
8, # batch_size
)
params = llamacpp.InferenceParams()
params.path_model = str(path)
_model = llamacpp.PyLLAMA(params)
_model = llamacpp.LlamaInference(params)
result = self()
result.model = _model
@ -63,22 +52,20 @@ class LlamaCppModel:
# params.repeat_last_n = repeat_last_n
# model.params = params
if not self.initialized:
self.model.add_bos()
self.model.update_input(context)
if not self.initialized:
self.model.prepare_context()
self.initialized = True
output = ""
is_end_of_text = False
ctr = 0
while not self.model.is_finished() and ctr < num_tokens and not is_end_of_text:
while ctr < num_tokens and not is_end_of_text:
if self.model.has_unconsumed_input():
self.model.ingest_all_pending_input(False)
self.model.ingest_all_pending_input()
else:
text, is_end_of_text = self.model.infer_text()
self.model.eval()
token = self.model.sample()
text = self.model.token_to_str(token)
is_end_of_text = token == self.model.token_eos()
if callback:
callback(text)
output += text