text-generation-webui/modules/llamacpp_model.py

83 lines
2.4 KiB
Python
Raw Normal View History

2023-03-31 21:18:05 -03:00
import multiprocessing
2023-03-18 23:42:10 -07:00
import llamacpp
2023-03-31 21:18:05 -03:00
from modules import shared
2023-03-31 14:27:01 -03:00
from modules.callbacks import Iteratorize
2023-03-18 23:42:10 -07:00
class LlamaCppTokenizer:
    """A thin wrapper over the llamacpp tokenizer.

    Adapts the tokenizer object returned by ``model.get_tokenizer()`` to an
    ``encode``/``decode`` interface and exposes ``eos_token_id`` and
    ``bos_token_id`` attributes.
    """

    def __init__(self, model: "llamacpp.LlamaInference"):
        """Build the wrapper from a loaded llamacpp inference object.

        Args:
            model: The ``llamacpp.LlamaInference`` instance whose tokenizer
                is wrapped.
        """
        self._tokenizer = model.get_tokenizer()
        # Ask the model for its end-of-text id rather than hard-coding it, so
        # this value always matches the id the generation loop compares
        # sampled tokens against (``model.token_eos()``).
        self.eos_token_id = model.token_eos()
        # NOTE(review): hard-coded. In LLaMA's SentencePiece vocabulary id 0 is
        # conventionally <unk> and id 1 is <s>. Confirm which id callers expect
        # before relying on this value.
        self.bos_token_id = 0

    @classmethod
    def from_model(cls, model: "llamacpp.LlamaInference"):
        """Alternate constructor; equivalent to ``cls(model)``."""
        return cls(model)

    def encode(self, prompt: str):
        """Return the token ids produced by the wrapped tokenizer for *prompt*."""
        return self._tokenizer.tokenize(prompt)

    def decode(self, ids):
        """Return the text produced by the wrapped tokenizer for token *ids*."""
        return self._tokenizer.detokenize(ids)
class LlamaCppModel:
    """Text-generation wrapper around a ``llamacpp.LlamaInference`` instance."""

    def __init__(self):
        # Instances are normally created through ``from_pretrained``, which
        # attaches ``model`` and ``params`` after construction.
        self.initialized = False

    @classmethod
    def from_pretrained(cls, path):
        """Load a llama.cpp model from *path*.

        Args:
            path: Model file location; converted with ``str()`` before being
                handed to llamacpp.

        Returns:
            A ``(model_wrapper, tokenizer)`` tuple: a ``LlamaCppModel`` with
            ``model`` and ``params`` attached, and a ``LlamaCppTokenizer`` for
            the same inference object.
        """
        params = llamacpp.InferenceParams()
        params.path_model = str(path)
        # Fall back to half the CPUs, but never request 0 threads: on a
        # single-CPU host ``cpu_count() // 2`` would otherwise be 0.
        params.n_threads = shared.args.threads or max(1, multiprocessing.cpu_count() // 2)

        _model = llamacpp.LlamaInference(params)

        result = cls()
        result.model = _model
        result.params = params

        tokenizer = LlamaCppTokenizer.from_model(_model)
        return result, tokenizer

    def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, callback=None):
        """Generate text continuing *context*.

        Args:
            context: Prompt text fed to the model after a BOS token.
            token_count: Maximum number of tokens to sample. Prompt ingestion
                does not count toward this limit.
            temperature: Stored as ``params.temp``.
            top_p: Stored as ``params.top_p``.
            top_k: Stored as ``params.top_k``.
            repetition_penalty: Stored as ``params.repeat_penalty``.
            callback: Optional callable invoked with the text of each sampled
                token, in order.

        Returns:
            The concatenated text of all sampled tokens. When generation stops
            on the end-of-text token, that token's text is included.
        """
        params = self.params
        params.n_predict = token_count
        params.top_p = top_p
        params.top_k = top_k
        params.temp = temperature
        params.repeat_penalty = repetition_penalty
        # NOTE(review): ``repeat_last_n`` is not exposed. ``params`` is mutated
        # in place but never re-applied to ``self.model``. Whether the sampling
        # settings above take effect depends on whether LlamaInference keeps a
        # reference to, or a copy of, the InferenceParams it was built with.
        # Confirm against the llamacpp bindings.

        self.model.add_bos()
        self.model.update_input(context)

        pieces = []
        sampled = 0
        is_end_of_text = False
        while sampled < token_count and not is_end_of_text:
            if self.model.has_unconsumed_input():
                # Feed the pending prompt. No token is sampled on this step,
                # so it does not count against ``token_count``.
                self.model.ingest_all_pending_input()
                continue

            self.model.eval()
            token = self.model.sample()
            text = self.model.token_to_str(token)
            pieces.append(text)
            is_end_of_text = token == self.model.token_eos()
            if callback:
                callback(text)
            sampled += 1

        return "".join(pieces)

    def generate_with_streaming(self, **kwargs):
        """Yield the cumulative reply as generation progresses.

        *kwargs* are forwarded to ``generate`` through ``Iteratorize``. That
        helper presumably supplies ``generate``'s ``callback`` and turns the
        reported pieces into an iterator; confirm in ``modules.callbacks``.
        Each yielded value is the full text received so far.
        """
        with Iteratorize(self.generate, kwargs, callback=None) as generator:
            reply = ''
            for token in generator:
                reply += token
                yield reply