mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 16:17:57 +01:00
Style changes
This commit is contained in:
parent
f040073ef1
commit
6015616338
@ -25,7 +25,6 @@ class LlamaCppModel:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(self, path):
|
def from_pretrained(self, path):
|
||||||
result = self()
|
result = self()
|
||||||
|
|
||||||
cache_capacity = 0
|
cache_capacity = 0
|
||||||
if shared.args.cache_capacity is not None:
|
if shared.args.cache_capacity is not None:
|
||||||
if 'GiB' in shared.args.cache_capacity:
|
if 'GiB' in shared.args.cache_capacity:
|
||||||
@ -36,7 +35,6 @@ class LlamaCppModel:
|
|||||||
cache_capacity = int(shared.args.cache_capacity)
|
cache_capacity = int(shared.args.cache_capacity)
|
||||||
|
|
||||||
logger.info("Cache capacity is " + str(cache_capacity) + " bytes")
|
logger.info("Cache capacity is " + str(cache_capacity) + " bytes")
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'model_path': str(path),
|
'model_path': str(path),
|
||||||
'n_ctx': shared.args.n_ctx,
|
'n_ctx': shared.args.n_ctx,
|
||||||
@ -47,6 +45,7 @@ class LlamaCppModel:
|
|||||||
'use_mlock': shared.args.mlock,
|
'use_mlock': shared.args.mlock,
|
||||||
'n_gpu_layers': shared.args.n_gpu_layers
|
'n_gpu_layers': shared.args.n_gpu_layers
|
||||||
}
|
}
|
||||||
|
|
||||||
self.model = Llama(**params)
|
self.model = Llama(**params)
|
||||||
if cache_capacity > 0:
|
if cache_capacity > 0:
|
||||||
self.model.set_cache(LlamaCache(capacity_bytes=cache_capacity))
|
self.model.set_cache(LlamaCache(capacity_bytes=cache_capacity))
|
||||||
@ -57,6 +56,7 @@ class LlamaCppModel:
|
|||||||
def encode(self, string):
|
def encode(self, string):
|
||||||
if type(string) is str:
|
if type(string) is str:
|
||||||
string = string.encode()
|
string = string.encode()
|
||||||
|
|
||||||
return self.model.tokenize(string)
|
return self.model.tokenize(string)
|
||||||
|
|
||||||
def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, mirostat_mode=0, mirostat_tau=5, mirostat_eta=0.1, callback=None):
|
def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, mirostat_mode=0, mirostat_tau=5, mirostat_eta=0.1, callback=None):
|
||||||
@ -73,12 +73,14 @@ class LlamaCppModel:
|
|||||||
mirostat_eta=mirostat_eta,
|
mirostat_eta=mirostat_eta,
|
||||||
stream=True
|
stream=True
|
||||||
)
|
)
|
||||||
|
|
||||||
output = ""
|
output = ""
|
||||||
for completion_chunk in completion_chunks:
|
for completion_chunk in completion_chunks:
|
||||||
text = completion_chunk['choices'][0]['text']
|
text = completion_chunk['choices'][0]['text']
|
||||||
output += text
|
output += text
|
||||||
if callback:
|
if callback:
|
||||||
callback(text)
|
callback(text)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def generate_with_streaming(self, **kwargs):
|
def generate_with_streaming(self, **kwargs):
|
||||||
|
Loading…
Reference in New Issue
Block a user