General improvements

oobabooga 2023-03-31 14:27:01 -03:00
parent 7fa5d96c22
commit 9d1dcf880a
3 changed files with 17 additions and 36 deletions

modules/llamacpp_model.py

@@ -1,10 +1,10 @@
 import os
 from pathlib import Path
-import modules.shared as shared
-from modules.callbacks import Iteratorize
 import llamacpp
+import modules.shared as shared
+from modules.callbacks import Iteratorize
 class LlamaCppTokenizer:
     """A thin wrapper over the llamacpp tokenizer"""
@@ -37,19 +37,19 @@ class LlamaCppModel:
         result = self()
         result.model = _model
         result.params = params
         tokenizer = LlamaCppTokenizer.from_model(_model)
         return result, tokenizer
-    # TODO: Allow passing in params for each inference
-    def generate(self, context="", num_tokens=10, callback=None):
-        # params = self.params
-        # params.n_predict = token_count
-        # params.top_p = top_p
-        # params.top_k = top_k
-        # params.temp = temperature
-        # params.repeat_penalty = repetition_penalty
-        # params.repeat_last_n = repeat_last_n
+    def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, callback=None):
+        params = self.params
+        params.n_predict = token_count
+        params.top_p = top_p
+        params.top_k = top_k
+        params.temp = temperature
+        params.repeat_penalty = repetition_penalty
+        #params.repeat_last_n = repeat_last_n
         # model.params = params
         self.model.add_bos()
@@ -58,7 +58,7 @@ class LlamaCppModel:
         output = ""
         is_end_of_text = False
         ctr = 0
-        while ctr < num_tokens and not is_end_of_text:
+        while ctr < token_count and not is_end_of_text:
             if self.model.has_unconsumed_input():
                 self.model.ingest_all_pending_input()
             else:
@@ -68,14 +68,13 @@ class LlamaCppModel:
                 is_end_of_text = token == self.model.token_eos()
                 if callback:
                     callback(text)
                 output += text
             ctr += 1
         return output
     def generate_with_streaming(self, **kwargs):
         with Iteratorize(self.generate, kwargs, callback=None) as generator:
-            reply = kwargs['context']
+            reply = ''
             for token in generator:
                 reply += token
                 yield reply
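Taken together, the changes above give LlamaCppModel the same keyword-argument interface as the other non-Hugging-Face backends, and make generate_with_streaming() accumulate only the generated reply instead of echoing the prompt back. A minimal usage sketch, assuming the module path modules.llamacpp_model and a from_pretrained() classmethod that takes a model path (the file name below is a placeholder, not part of the commit):

from modules.llamacpp_model import LlamaCppModel

# Placeholder path; from_pretrained() is assumed to return a (model, tokenizer) pair.
model, tokenizer = LlamaCppModel.from_pretrained("models/ggml-model-q4_0.bin")

# Blocking call: sampling settings are now plain keyword arguments.
text = model.generate(context="Hello,", token_count=64, temperature=0.7,
                      top_p=0.9, top_k=40, repetition_penalty=1.1)

# Streaming call: each yielded value is the reply accumulated so far,
# which now starts empty rather than starting from the prompt.
for partial in model.generate_with_streaming(context="Hello,", token_count=64):
    print(partial)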

modules/text_generation.py

@@ -22,7 +22,7 @@ def get_max_prompt_length(tokens):
     return max_length
 def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
-    if shared.is_RWKV or shared.is_llamacpp:
+    if any((shared.is_RWKV, shared.is_llamacpp)):
         input_ids = shared.tokenizer.encode(str(prompt))
         input_ids = np.array(input_ids).reshape(1, len(input_ids))
         return input_ids
@@ -116,7 +116,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
     # These models are not part of Hugging Face, so we handle them
     # separately and terminate the function call earlier
-    if shared.is_RWKV:
+    if any((shared.is_RWKV, shared.is_llamacpp)):
         try:
             if shared.args.no_stream:
                 reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
@@ -142,24 +142,6 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
             input_ids = encode(question)
             print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(input_ids[0])} tokens)")
             return
-    elif shared.is_llamacpp:
-        try:
-            if shared.args.no_stream:
-                reply = shared.model.generate(context=question, num_tokens=max_new_tokens)
-                yield formatted_outputs(reply, shared.model_name)
-            else:
-                if not (shared.args.chat or shared.args.cai_chat):
-                    yield formatted_outputs(question, shared.model_name)
-                for reply in shared.model.generate_with_streaming(context=question, num_tokens=max_new_tokens):
-                    yield formatted_outputs(reply, shared.model_name)
-        except Exception as e:
-            print(e)
-        finally:
-            t1 = time.time()
-            output = encode(reply)[0]
-            input_ids = encode(question)
-            print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(input_ids[0])} tokens)")
-            return
     input_ids = encode(question, max_new_tokens)
     original_input_ids = input_ids
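Because both wrappers now accept the same generate()/generate_with_streaming() keyword arguments, generate_reply() can route RWKV and llamacpp through a single branch instead of duplicating the try/except/finally block. A rough sketch of the shared call pattern (run_non_hf_model is a hypothetical name used for illustration, not part of the repository):

def run_non_hf_model(model, prompt, max_new_tokens, stream=True):
    # Duck-typed generation call shared by the RWKV and llamacpp wrappers.
    kwargs = dict(context=prompt, token_count=max_new_tokens,
                  temperature=1, top_p=1, top_k=50)
    if not stream:
        yield model.generate(**kwargs)
    else:
        for reply in model.generate_with_streaming(**kwargs):
            yield reply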

requirements.txt

@@ -2,6 +2,7 @@ accelerate==0.18.0
 bitsandbytes==0.37.2
 flexgen==0.1.7
 gradio==3.23.0
+llamacpp==0.1.10
 markdown
 numpy
 peft==0.2.0
@@ -11,5 +12,4 @@ safetensors==0.3.0
 sentencepiece
 tqdm
 datasets
-llamacpp>=0.1.9
 git+https://github.com/huggingface/transformers
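Finally, the llamacpp requirement moves up into alphabetical order and is pinned to an exact release (==0.1.10) instead of the looser >=0.1.9. A quick way to confirm the installed binding matches the pin (illustrative only, not part of the commit):

import importlib.metadata

assert importlib.metadata.version("llamacpp") == "0.1.10"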