Style changes

This commit is contained in:
oobabooga 2023-06-06 13:06:05 -03:00
parent f040073ef1
commit 6015616338

View File

@ -25,7 +25,6 @@ class LlamaCppModel:
@classmethod @classmethod
def from_pretrained(self, path): def from_pretrained(self, path):
result = self() result = self()
cache_capacity = 0 cache_capacity = 0
if shared.args.cache_capacity is not None: if shared.args.cache_capacity is not None:
if 'GiB' in shared.args.cache_capacity: if 'GiB' in shared.args.cache_capacity:
@ -36,7 +35,6 @@ class LlamaCppModel:
cache_capacity = int(shared.args.cache_capacity) cache_capacity = int(shared.args.cache_capacity)
logger.info("Cache capacity is " + str(cache_capacity) + " bytes") logger.info("Cache capacity is " + str(cache_capacity) + " bytes")
params = { params = {
'model_path': str(path), 'model_path': str(path),
'n_ctx': shared.args.n_ctx, 'n_ctx': shared.args.n_ctx,
@ -47,6 +45,7 @@ class LlamaCppModel:
'use_mlock': shared.args.mlock, 'use_mlock': shared.args.mlock,
'n_gpu_layers': shared.args.n_gpu_layers 'n_gpu_layers': shared.args.n_gpu_layers
} }
self.model = Llama(**params) self.model = Llama(**params)
if cache_capacity > 0: if cache_capacity > 0:
self.model.set_cache(LlamaCache(capacity_bytes=cache_capacity)) self.model.set_cache(LlamaCache(capacity_bytes=cache_capacity))
@ -57,6 +56,7 @@ class LlamaCppModel:
def encode(self, string): def encode(self, string):
if type(string) is str: if type(string) is str:
string = string.encode() string = string.encode()
return self.model.tokenize(string) return self.model.tokenize(string)
def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, mirostat_mode=0, mirostat_tau=5, mirostat_eta=0.1, callback=None): def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, mirostat_mode=0, mirostat_tau=5, mirostat_eta=0.1, callback=None):
@ -73,12 +73,14 @@ class LlamaCppModel:
mirostat_eta=mirostat_eta, mirostat_eta=mirostat_eta,
stream=True stream=True
) )
output = "" output = ""
for completion_chunk in completion_chunks: for completion_chunk in completion_chunks:
text = completion_chunk['choices'][0]['text'] text = completion_chunk['choices'][0]['text']
output += text output += text
if callback: if callback:
callback(text) callback(text)
return output return output
def generate_with_streaming(self, **kwargs): def generate_with_streaming(self, **kwargs):