General improvements

oobabooga 2023-03-31 14:27:01 -03:00
parent 7fa5d96c22
commit 9d1dcf880a
3 changed files with 17 additions and 36 deletions


@@ -1,10 +1,10 @@
-import os
 from pathlib import Path
-import modules.shared as shared
-from modules.callbacks import Iteratorize
 import llamacpp
+import modules.shared as shared
+from modules.callbacks import Iteratorize
 
 class LlamaCppTokenizer:
     """A thin wrapper over the llamacpp tokenizer"""
@@ -37,18 +37,18 @@ class LlamaCppModel:
         result = self()
         result.model = _model
+        result.params = params
 
         tokenizer = LlamaCppTokenizer.from_model(_model)
         return result, tokenizer
 
-    # TODO: Allow passing in params for each inference
-    def generate(self, context="", num_tokens=10, callback=None):
-        # params = self.params
-        # params.n_predict = token_count
-        # params.top_p = top_p
-        # params.top_k = top_k
-        # params.temp = temperature
-        # params.repeat_penalty = repetition_penalty
+    def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, callback=None):
+        params = self.params
+        params.n_predict = token_count
+        params.top_p = top_p
+        params.top_k = top_k
+        params.temp = temperature
+        params.repeat_penalty = repetition_penalty
         #params.repeat_last_n = repeat_last_n
         # model.params = params
@@ -58,7 +58,7 @@ class LlamaCppModel:
         output = ""
         is_end_of_text = False
         ctr = 0
-        while ctr < num_tokens and not is_end_of_text:
+        while ctr < token_count and not is_end_of_text:
             if self.model.has_unconsumed_input():
                 self.model.ingest_all_pending_input()
             else:
@ -68,14 +68,13 @@ class LlamaCppModel:
is_end_of_text = token == self.model.token_eos() is_end_of_text = token == self.model.token_eos()
if callback: if callback:
callback(text) callback(text)
output += text
ctr += 1 ctr += 1
return output return output
def generate_with_streaming(self, **kwargs): def generate_with_streaming(self, **kwargs):
with Iteratorize(self.generate, kwargs, callback=None) as generator: with Iteratorize(self.generate, kwargs, callback=None) as generator:
reply = kwargs['context'] reply = ''
for token in generator: for token in generator:
reply += token reply += token
yield reply yield reply
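
For orientation, a usage sketch of the reworked wrapper follows; it is not part of the commit. The import path, the from_pretrained() loader name, and the model path are assumptions based on how this project loads its other backends, while generate(), generate_with_streaming(), and their keyword arguments come straight from the diff above.

# Usage sketch only -- not from this commit.
# Assumed: the class above is importable as modules.llamacpp_model.LlamaCppModel,
# it is loaded through a from_pretrained() classmethod, and the path below is a
# placeholder for a local GGML model file.
from modules.llamacpp_model import LlamaCppModel

model, tokenizer = LlamaCppModel.from_pretrained("models/ggml-model-q4_0.bin")

# Blocking call: the new keyword arguments are copied onto the wrapper's stored
# params object (n_predict, top_p, top_k, temp, repeat_penalty) before decoding.
reply = model.generate(
    context="Hello, my name is",
    token_count=32,
    temperature=0.7,
    top_p=0.9,
    top_k=40,
    repetition_penalty=1.18,
    callback=lambda text: print(text, end="", flush=True),  # per-token hook
)

# Streaming call: Iteratorize turns the callback into a generator, and because
# reply now starts as '' rather than kwargs['context'], each yielded string is
# the accumulated new text without the prompt prepended.
for partial in model.generate_with_streaming(context="Hello, my name is", token_count=32):
    pass  # e.g. push `partial` to the UI as it grows

Accepting the sampling settings per call, instead of the old hard-coded num_tokens=10, is what lets the caller forward the UI's temperature/top_p/top_k values unchanged, which the next file takes advantage of.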


@@ -22,7 +22,7 @@ def get_max_prompt_length(tokens):
     return max_length
 
 def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
-    if shared.is_RWKV or shared.is_llamacpp:
+    if any((shared.is_RWKV, shared.is_llamacpp)):
         input_ids = shared.tokenizer.encode(str(prompt))
         input_ids = np.array(input_ids).reshape(1, len(input_ids))
         return input_ids
@@ -116,7 +116,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
     # These models are not part of Hugging Face, so we handle them
     # separately and terminate the function call earlier
-    if shared.is_RWKV:
+    if any((shared.is_RWKV, shared.is_llamacpp)):
         try:
             if shared.args.no_stream:
                 reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
@@ -142,24 +142,6 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
                 input_ids = encode(question)
                 print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(input_ids[0])} tokens)")
             return
-    elif shared.is_llamacpp:
-        try:
-            if shared.args.no_stream:
-                reply = shared.model.generate(context=question, num_tokens=max_new_tokens)
-                yield formatted_outputs(reply, shared.model_name)
-            else:
-                if not (shared.args.chat or shared.args.cai_chat):
-                    yield formatted_outputs(question, shared.model_name)
-                for reply in shared.model.generate_with_streaming(context=question, num_tokens=max_new_tokens):
-                    yield formatted_outputs(reply, shared.model_name)
-        except Exception as e:
-            print(e)
-        finally:
-            t1 = time.time()
-            output = encode(reply)[0]
-            input_ids = encode(question)
-            print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(input_ids[0])} tokens)")
-            return
 
     input_ids = encode(question, max_new_tokens)
     original_input_ids = input_ids
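
The two condition changes above route llama.cpp models through the same early-exit branch as RWKV, which is why the dedicated elif shared.is_llamacpp block could be deleted outright. Below is a condensed, self-contained illustration of that control flow; Shared, DummyModel, and formatted are stand-ins invented for the example, while the any((...)) test and the generate()/generate_with_streaming() keyword arguments mirror the diff.

# Condensed illustration of the unified non-Hugging-Face branch -- not project code.
class DummyModel:
    # Stand-in for a loaded RWKV or llama.cpp wrapper.
    def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50):
        return context + " ..."

    def generate_with_streaming(self, **kwargs):
        yield kwargs["context"] + " ..."


class Shared:
    # Stand-in for modules.shared: flags set at model-load time.
    is_RWKV = False
    is_llamacpp = True
    no_stream = False
    model = DummyModel()


def formatted(reply):
    # Stand-in for formatted_outputs().
    return reply


def generate_reply(question, max_new_tokens, temperature, top_p, top_k):
    shared = Shared()
    # Non-Hugging-Face backends are handled here and the function exits early;
    # RWKV and llama.cpp now share this single branch.
    if any((shared.is_RWKV, shared.is_llamacpp)):
        if shared.no_stream:
            reply = shared.model.generate(context=question, token_count=max_new_tokens,
                                          temperature=temperature, top_p=top_p, top_k=top_k)
            yield formatted(reply)
        else:
            for reply in shared.model.generate_with_streaming(context=question,
                                                              token_count=max_new_tokens):
                yield formatted(reply)
        return
    # ... the Hugging Face path of the real function continues from here ...


for out in generate_reply("Hello", 16, 0.7, 0.9, 40):
    print(out)

The real generate_reply() additionally wraps this branch in try/except/finally for timing and error reporting, as the surrounding context lines above show; that part is omitted here.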


@@ -2,6 +2,7 @@ accelerate==0.18.0
 bitsandbytes==0.37.2
 flexgen==0.1.7
 gradio==3.23.0
+llamacpp==0.1.10
 markdown
 numpy
 peft==0.2.0
@@ -11,5 +12,4 @@ safetensors==0.3.0
 sentencepiece
 tqdm
 datasets
-llamacpp>=0.1.9
 git+https://github.com/huggingface/transformers