Add new llama.cpp library (2048 context, temperature, etc now work)

oobabooga 2023-04-06 13:12:14 -03:00
parent 39f3fec913
commit 03cb44fc8c
3 changed files with 67 additions and 1 deletion

modules/llamacpp_model_alternative.py (new file)

@@ -0,0 +1,65 @@
'''
Based on
https://github.com/abetlen/llama-cpp-python

Documentation:
https://abetlen.github.io/llama-cpp-python/
'''
import multiprocessing

from llama_cpp import Llama

from modules import shared
from modules.callbacks import Iteratorize


class LlamaCppModel:
    def __init__(self):
        self.initialized = False

    @classmethod
    def from_pretrained(cls, path):
        result = cls()

        params = {
            'model_path': str(path),
            'n_ctx': 2048,
            'seed': 0,
            'n_threads': shared.args.threads or None
        }
        result.model = Llama(**params)

        # This is ugly, but the model and the tokenizer are the same object in this library.
        return result, result

    def encode(self, string):
        if type(string) is str:
            string = string.encode()

        return self.model.tokenize(string)

    def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, callback=None):
        if type(context) is str:
            context = context.encode()

        tokens = self.model.tokenize(context)

        output = b""
        count = 0
        for token in self.model.generate(tokens, top_k=top_k, top_p=top_p, temp=temperature, repeat_penalty=repetition_penalty):
            # Detokenize one token at a time so partial text can be streamed to the callback.
            text = self.model.detokenize([token])
            output += text
            if callback:
                callback(text.decode())

            # Stop at the requested token count or at the end-of-sequence token.
            count += 1
            if count >= token_count or (token == self.model.token_eos()):
                break

        return output.decode()

    def generate_with_streaming(self, **kwargs):
        with Iteratorize(self.generate, kwargs, callback=None) as generator:
            reply = ''
            for token in generator:
                reply += token
                yield reply
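
For context, a minimal usage sketch of the class above. The model path and prompts are placeholders, not from the commit; note that from_pretrained returns the same object twice because model and tokenizer are one object in this library:

    from pathlib import Path

    from modules.llamacpp_model_alternative import LlamaCppModel

    # Hypothetical model path for illustration
    model, tokenizer = LlamaCppModel.from_pretrained(Path('models/llama-7b/ggml-model-q4_0.bin'))

    # Non-streaming generation
    text = model.generate(context="Q: What is llama.cpp? A:", token_count=64,
                          temperature=0.7, top_p=0.9, top_k=40, repetition_penalty=1.1)
    print(text)

    # Streaming: yields the growing reply string after each token
    for partial in model.generate_with_streaming(context="Hello", token_count=16):
        print(partial)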

modules/models.py

@@ -103,7 +103,7 @@ def load_model(model_name):
     # llamacpp model
     elif shared.is_llamacpp:
-        from modules.llamacpp_model import LlamaCppModel
+        from modules.llamacpp_model_alternative import LlamaCppModel
         model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('ggml*.bin'))[0]
         print(f"llama.cpp weights detected: {model_file}\n")

requirements.txt

@@ -4,6 +4,7 @@ datasets
 flexgen==0.1.7
 gradio==3.24.1
 llamacpp==0.1.11
+llama-cpp-python==0.1.23
 markdown
 numpy
 peft==0.2.0
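
To verify the new dependency independently of the webui, a minimal smoke-test sketch against llama-cpp-python, using only the calls exercised by the module above (the model path is a placeholder):

    from llama_cpp import Llama

    # Placeholder path; any ggml-format llama.cpp model file works
    llm = Llama(model_path='models/ggml-model-q4_0.bin', n_ctx=2048, seed=0)

    tokens = llm.tokenize(b'The llama is')
    output = b''
    for i, token in enumerate(llm.generate(tokens, top_k=40, top_p=0.9, temp=0.7, repeat_penalty=1.1)):
        if i >= 32 or token == llm.token_eos():
            break
        output += llm.detokenize([token])

    print(output.decode())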