diff --git a/modules/llamacpp_model.py b/modules/llamacpp_model.py index f65ecb4e..6b9b1b52 100644 --- a/modules/llamacpp_model.py +++ b/modules/llamacpp_model.py @@ -1,10 +1,10 @@ -import os from pathlib import Path -import modules.shared as shared -from modules.callbacks import Iteratorize import llamacpp +import modules.shared as shared +from modules.callbacks import Iteratorize + class LlamaCppTokenizer: """A thin wrapper over the llamacpp tokenizer""" @@ -37,19 +37,19 @@ class LlamaCppModel: result = self() result.model = _model + result.params = params tokenizer = LlamaCppTokenizer.from_model(_model) return result, tokenizer - # TODO: Allow passing in params for each inference - def generate(self, context="", num_tokens=10, callback=None): - # params = self.params - # params.n_predict = token_count - # params.top_p = top_p - # params.top_k = top_k - # params.temp = temperature - # params.repeat_penalty = repetition_penalty - # params.repeat_last_n = repeat_last_n + def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, callback=None): + params = self.params + params.n_predict = token_count + params.top_p = top_p + params.top_k = top_k + params.temp = temperature + params.repeat_penalty = repetition_penalty + #params.repeat_last_n = repeat_last_n # model.params = params self.model.add_bos() @@ -58,7 +58,7 @@ class LlamaCppModel: output = "" is_end_of_text = False ctr = 0 - while ctr < num_tokens and not is_end_of_text: + while ctr < token_count and not is_end_of_text: if self.model.has_unconsumed_input(): self.model.ingest_all_pending_input() else: @@ -68,14 +68,13 @@ class LlamaCppModel: is_end_of_text = token == self.model.token_eos() if callback: callback(text) - output += text ctr += 1 return output def generate_with_streaming(self, **kwargs): with Iteratorize(self.generate, kwargs, callback=None) as generator: - reply = kwargs['context'] + reply = '' for token in generator: reply += token yield reply diff --git a/modules/text_generation.py b/modules/text_generation.py index e18a76d7..8d54961e 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -22,7 +22,7 @@ def get_max_prompt_length(tokens): return max_length def encode(prompt, tokens_to_generate=0, add_special_tokens=True): - if shared.is_RWKV or shared.is_llamacpp: + if any((shared.is_RWKV, shared.is_llamacpp)): input_ids = shared.tokenizer.encode(str(prompt)) input_ids = np.array(input_ids).reshape(1, len(input_ids)) return input_ids @@ -116,7 +116,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi # These models are not part of Hugging Face, so we handle them # separately and terminate the function call earlier - if shared.is_RWKV: + if any((shared.is_RWKV, shared.is_llamacpp)): try: if shared.args.no_stream: reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k) @@ -142,24 +142,6 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi input_ids = encode(question) print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(input_ids[0])} tokens)") return - elif shared.is_llamacpp: - try: - if shared.args.no_stream: - reply = shared.model.generate(context=question, num_tokens=max_new_tokens) - yield formatted_outputs(reply, shared.model_name) - else: - if not (shared.args.chat or shared.args.cai_chat): - yield formatted_outputs(question, shared.model_name) - for reply in shared.model.generate_with_streaming(context=question, num_tokens=max_new_tokens): - yield formatted_outputs(reply, shared.model_name) - except Exception as e: - print(e) - finally: - t1 = time.time() - output = encode(reply)[0] - input_ids = encode(question) - print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(input_ids[0])} tokens)") - return input_ids = encode(question, max_new_tokens) original_input_ids = input_ids diff --git a/requirements.txt b/requirements.txt index e92c6889..08ee5d58 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ accelerate==0.18.0 bitsandbytes==0.37.2 flexgen==0.1.7 gradio==3.23.0 +llamacpp==0.1.10 markdown numpy peft==0.2.0 @@ -11,5 +12,4 @@ safetensors==0.3.0 sentencepiece tqdm datasets -llamacpp>=0.1.9 git+https://github.com/huggingface/transformers