diff --git a/.gitignore b/.gitignore
index a9c47a5a..bfb6d027 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,6 +14,7 @@ torch-dumps
 */*/pycache*
 venv/
 .venv/
+.vscode
 repositories
 
 settings.json
diff --git a/modules/RWKV.py b/modules/RWKV.py
index 8c7ea2b9..10c4c366 100644
--- a/modules/RWKV.py
+++ b/modules/RWKV.py
@@ -34,7 +34,7 @@ class RWKVModel:
         result.pipeline = pipeline
         return result
 
-    def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, alpha_frequency=0.1, alpha_presence=0.1, token_ban=[0], token_stop=[], callback=None):
+    def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=None, alpha_frequency=0.1, alpha_presence=0.1, token_ban=[0], token_stop=[], callback=None):
         args = PIPELINE_ARGS(
             temperature = temperature,
             top_p = top_p,
diff --git a/modules/llamacpp_model.py b/modules/llamacpp_model.py
new file mode 100644
index 00000000..6b9b1b52
--- /dev/null
+++ b/modules/llamacpp_model.py
@@ -0,0 +1,81 @@
+from pathlib import Path
+
+import llamacpp
+
+import modules.shared as shared
+from modules.callbacks import Iteratorize
+
+
+class LlamaCppTokenizer:
+    """A thin wrapper over the llamacpp tokenizer"""
+    def __init__(self, model: llamacpp.LlamaInference):
+        self._tokenizer = model.get_tokenizer()
+        self.eos_token_id = 2
+        self.bos_token_id = 0
+
+    @classmethod
+    def from_model(cls, model: llamacpp.LlamaInference):
+        return cls(model)
+
+    def encode(self, prompt: str):
+        return self._tokenizer.tokenize(prompt)
+
+    def decode(self, ids):
+        return self._tokenizer.detokenize(ids)
+
+
+class LlamaCppModel:
+    def __init__(self):
+        self.initialized = False
+
+    @classmethod
+    def from_pretrained(self, path):
+        params = llamacpp.InferenceParams()
+        params.path_model = str(path)
+
+        _model = llamacpp.LlamaInference(params)
+
+        result = self()
+        result.model = _model
+        result.params = params
+
+        tokenizer = LlamaCppTokenizer.from_model(_model)
+        return result, tokenizer
+
+    def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, callback=None):
+        params = self.params
+        params.n_predict = token_count
+        params.top_p = top_p
+        params.top_k = top_k
+        params.temp = temperature
+        params.repeat_penalty = repetition_penalty
+        #params.repeat_last_n = repeat_last_n
+
+        # model.params = params
+        self.model.add_bos()
+        self.model.update_input(context)
+
+        output = ""
+        is_end_of_text = False
+        ctr = 0
+        while ctr < token_count and not is_end_of_text:
+            if self.model.has_unconsumed_input():
+                self.model.ingest_all_pending_input()
+            else:
+                self.model.eval()
+                token = self.model.sample()
+                text = self.model.token_to_str(token)
+                output += text  # accumulate the decoded token so the non-streaming path returns the full reply
+                is_end_of_text = token == self.model.token_eos()
+                if callback:
+                    callback(text)
+                ctr += 1
+
+        return output
+
+    def generate_with_streaming(self, **kwargs):
+        with Iteratorize(self.generate, kwargs, callback=None) as generator:
+            reply = ''
+            for token in generator:
+                reply += token
+                yield reply
diff --git a/modules/models.py b/modules/models.py
index b19507db..80bbcab2 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -42,9 +42,10 @@ def load_model(model_name):
     t0 = time.time()
 
     shared.is_RWKV = 'rwkv-' in model_name.lower()
+    shared.is_llamacpp = model_name.lower().startswith(('llamacpp', 'alpaca-cpp'))
 
     # Default settings
-    if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]):
+    if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV, shared.is_llamacpp]):
         if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
             model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), device_map='auto', load_in_8bit=True)
         else:
@@ -100,6 +101,18 @@ def load_model(model_name):
 
         model = load_quantized(model_name)
 
+    # llamacpp model
+    elif shared.is_llamacpp:
+        from modules.llamacpp_model import LlamaCppModel
+
+        if model_name.lower().startswith('alpaca-cpp'):
+            model_file = f'models/{model_name}/ggml-alpaca-7b-q4.bin'
+        else:
+            model_file = f'models/{model_name}/ggml-model-q4_0.bin'
+
+        model, tokenizer = LlamaCppModel.from_pretrained(Path(model_file))
+        return model, tokenizer
+
     # Custom
     else:
         params = {"low_cpu_mem_usage": True}
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 7b5fcd6a..b8b2f496 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -22,7 +22,7 @@ def get_max_prompt_length(tokens):
     return max_length
 
 def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
-    if shared.is_RWKV:
+    if any((shared.is_RWKV, shared.is_llamacpp)):
         input_ids = shared.tokenizer.encode(str(prompt))
         input_ids = np.array(input_ids).reshape(1, len(input_ids))
         return input_ids
@@ -116,10 +116,10 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
 
     # These models are not part of Hugging Face, so we handle them
     # separately and terminate the function call earlier
-    if shared.is_RWKV:
+    if any((shared.is_RWKV, shared.is_llamacpp)):
         try:
             if shared.args.no_stream:
-                reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
+                reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
                 if not (shared.args.chat or shared.args.cai_chat):
                     reply = original_question + apply_extensions(reply, "output")
                 yield formatted_outputs(reply, shared.model_name)
diff --git a/requirements.txt b/requirements.txt
index 8d22d41f..ffa6b51a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,6 +2,7 @@ accelerate==0.18.0
 bitsandbytes==0.37.2
 flexgen==0.1.7
 gradio==3.24.0
+llamacpp==0.1.11
 markdown
 numpy
 peft==0.2.0
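
Usage note (not part of the patch): a minimal sketch of how the new LlamaCppModel wrapper is exercised, assuming a quantized GGML file at models/llamacpp-7b/ggml-model-q4_0.bin; the model path, prompt, and sampling values below are illustrative, and the web UI reaches this code through load_model() and generate_reply() rather than calling the class directly.

from pathlib import Path

from modules.llamacpp_model import LlamaCppModel

# Load the model and its tokenizer wrapper (hypothetical model location).
model, tokenizer = LlamaCppModel.from_pretrained(Path('models/llamacpp-7b/ggml-model-q4_0.bin'))

# Each iteration yields the reply generated so far, mirroring the streaming
# branch taken when shared.args.no_stream is False.
for partial_reply in model.generate_with_streaming(
        context='Building a website can be done in 10 simple steps:',
        token_count=64, temperature=0.7, top_p=0.9, top_k=40, repetition_penalty=1.1):
    print(partial_reply)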