From 6015616338f4a0c37002fb6a81e1bc555f477284 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 6 Jun 2023 13:06:05 -0300 Subject: [PATCH] Style changes --- modules/llamacpp_model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/modules/llamacpp_model.py b/modules/llamacpp_model.py index ea42dafb..4f2de155 100644 --- a/modules/llamacpp_model.py +++ b/modules/llamacpp_model.py @@ -25,7 +25,6 @@ class LlamaCppModel: @classmethod def from_pretrained(self, path): result = self() - cache_capacity = 0 if shared.args.cache_capacity is not None: if 'GiB' in shared.args.cache_capacity: @@ -36,7 +35,6 @@ class LlamaCppModel: cache_capacity = int(shared.args.cache_capacity) logger.info("Cache capacity is " + str(cache_capacity) + " bytes") - params = { 'model_path': str(path), 'n_ctx': shared.args.n_ctx, @@ -47,6 +45,7 @@ class LlamaCppModel: 'use_mlock': shared.args.mlock, 'n_gpu_layers': shared.args.n_gpu_layers } + self.model = Llama(**params) if cache_capacity > 0: self.model.set_cache(LlamaCache(capacity_bytes=cache_capacity)) @@ -57,6 +56,7 @@ class LlamaCppModel: def encode(self, string): if type(string) is str: string = string.encode() + return self.model.tokenize(string) def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, mirostat_mode=0, mirostat_tau=5, mirostat_eta=0.1, callback=None): @@ -73,12 +73,14 @@ class LlamaCppModel: mirostat_eta=mirostat_eta, stream=True ) + output = "" for completion_chunk in completion_chunks: text = completion_chunk['choices'][0]['text'] output += text if callback: callback(text) + return output def generate_with_streaming(self, **kwargs):