2023-05-16 01:19:55 +02:00
|
|
|
import re
|
2023-06-20 02:31:19 +02:00
|
|
|
from functools import partial
|
2023-05-16 01:19:55 +02:00
|
|
|
|
2023-09-17 15:42:32 +02:00
|
|
|
import numpy as np
|
2023-07-20 04:31:19 +02:00
|
|
|
import torch
|
2023-03-19 07:42:10 +01:00
|
|
|
|
2024-07-05 04:15:37 +02:00
|
|
|
from modules import shared
|
2023-03-31 19:27:01 +02:00
|
|
|
from modules.callbacks import Iteratorize
|
2024-07-05 04:15:37 +02:00
|
|
|
from modules.llama_cpp_python_hijack import llama_cpp_lib
|
2023-05-22 03:42:34 +02:00
|
|
|
from modules.logging_colors import logger
|
2023-08-04 01:01:15 +02:00
|
|
|
from modules.text_generation import get_max_prompt_length
|
2023-03-31 19:27:01 +02:00
|
|
|
|
2024-12-17 21:43:48 +01:00
|
|
|
# Short quantization names mapped to llama.cpp's numeric type ids
# (presumably the `ggml_type` enum values -- confirm against ggml.h).
llamacpp_quant_mapping = dict(
    f32=0,
    fp16=1,
    q4_0=2,
    q4_1=3,
    q5_0=6,
    q5_1=7,
    q8_0=8,
    q8_1=9,
    q2_k=10,
    q3_k=11,
    q4_k=12,
    q5_k=13,
    q6_k=14,
    q8_k=15,
    iq4_nl=20,
    bf16=30,
)

# The only names from llamacpp_quant_mapping accepted as a cache type.
llamacpp_valid_cache_types = set(['fp16', 'q8_0', 'q4_0'])


def get_llamacpp_cache_type_for_string(quant_type: str):
    """Translate a cache type name into llama.cpp's numeric type id.

    The lookup is case-insensitive. Only names listed in
    ``llamacpp_valid_cache_types`` are accepted.

    Raises:
        ValueError: if the (lower-cased) name is not a valid cache type.
    """
    normalized = quant_type.lower()
    if normalized not in llamacpp_valid_cache_types:
        raise ValueError(f"Invalid cache type for llama.cpp: {normalized}. Valid options are: fp16, q8_0, q4_0.")
    return llamacpp_quant_mapping[normalized]
|
|
|
|
|
2023-03-19 07:42:10 +01:00
|
|
|
|
2023-06-20 02:31:19 +02:00
|
|
|
def ban_eos_logits_processor(eos_token, input_ids, logits):
    """Make the end-of-sequence token unselectable by setting its score to -inf.

    ``input_ids`` is not used; it is accepted so the function (with
    ``eos_token`` bound via ``functools.partial``) matches the
    ``(input_ids, logits)`` call shape presumably used by llama-cpp-python's
    ``LogitsProcessorList``. ``logits`` is modified in place and returned.
    """
    banned_score = float('-inf')
    logits[eos_token] = banned_score
    return logits
|
|
|
|
|
|
|
|
|
2023-09-15 23:27:27 +02:00
|
|
|
def custom_token_ban_logits_processor(token_ids, input_ids, logits):
    """Set the score of every token id in ``token_ids`` to -inf.

    ``input_ids`` is not used; it is accepted to match the logits-processor
    call shape. ``logits`` is modified in place and returned.
    """
    negative_infinity = -float('inf')
    for banned in token_ids:
        logits[banned] = negative_infinity
    return logits
|
|
|
|
|
|
|
|
|
2023-03-19 07:42:10 +01:00
|
|
|
class LlamaCppModel:
    """Wrapper around a llama-cpp-python ``Llama`` instance.

    The same object acts as both model and tokenizer: ``from_pretrained``
    returns it twice, and ``encode``/``decode`` delegate to the model.
    """

    def __init__(self):
        self.initialized = False
        # Source text of the active grammar; lets load_grammar skip recompiling an unchanged grammar.
        self.grammar_string = ''
        # Compiled LlamaGrammar, or None when no grammar is active.
        self.grammar = None

    def __del__(self):
        # `model` is only assigned after Llama(...) succeeds in from_pretrained. If loading
        # failed, the half-built instance has no `model`, and an unguarded `del` would raise
        # AttributeError during garbage collection.
        if hasattr(self, 'model'):
            del self.model

    @staticmethod
    def _parse_cache_capacity(value):
        """Convert a cache-capacity setting into a byte count.

        Accepts ``None`` or a blank string (-> 0), a value with a ``GiB`` or
        ``MiB`` suffix (binary units, fractional values allowed, e.g.
        ``'1.5GiB'``), or a plain integer number of bytes.
        """
        if value is None:
            return 0
        text = str(value).strip()
        if not text:
            return 0
        if 'GiB' in text:
            multiplier = 1024 ** 3
        elif 'MiB' in text:
            multiplier = 1024 ** 2
        else:
            return int(text)
        return int(float(re.sub('[a-zA-Z]', '', text)) * multiplier)

    @staticmethod
    def _parse_token_bans(spec):
        """Parse a comma-separated list of token ids, ignoring blank entries (e.g. a trailing comma)."""
        if not spec:
            return []
        return [int(item) for item in spec.split(',') if item.strip()]

    @classmethod
    def from_pretrained(cls, path):
        """Load the model at ``path`` using settings from ``shared.args``.

        Returns:
            ``(model, tokenizer)``, which are the same LlamaCppModel instance.
        """
        Llama = llama_cpp_lib().Llama
        LlamaCache = llama_cpp_lib().LlamaCache

        result = cls()
        cache_capacity = cls._parse_cache_capacity(shared.args.cache_capacity)
        if cache_capacity > 0:
            logger.info("Cache capacity is " + str(cache_capacity) + " bytes")

        if shared.args.tensor_split is None or shared.args.tensor_split.strip() == '':
            tensor_split_list = None
        else:
            tensor_split_list = [float(x) for x in shared.args.tensor_split.strip().split(",")]

        params = {
            'model_path': str(path),
            'n_ctx': shared.args.n_ctx,
            # A falsy thread count (e.g. 0) is passed as None.
            'n_threads': shared.args.threads or None,
            'n_threads_batch': shared.args.threads_batch or None,
            'n_batch': shared.args.n_batch,
            'use_mmap': not shared.args.no_mmap,
            'use_mlock': shared.args.mlock,
            'mul_mat_q': not shared.args.no_mul_mat_q,
            'numa': shared.args.numa,
            'n_gpu_layers': shared.args.n_gpu_layers,
            'rope_freq_base': shared.args.rope_freq_base,
            'tensor_split': tensor_split_list,
            'rope_freq_scale': 1.0 / shared.args.compress_pos_emb,
            'offload_kqv': not shared.args.no_offload_kqv,
            # NOTE(review): presumably llama.cpp's layer (1) / row (2) split modes -- confirm.
            'split_mode': 1 if not shared.args.row_split else 2,
            'flash_attn': shared.args.flash_attn
        }

        if shared.args.cache_type:
            cache_type_id = get_llamacpp_cache_type_for_string(shared.args.cache_type)
            params["type_k"] = cache_type_id
            params["type_v"] = cache_type_id

        result.model = Llama(**params)
        if cache_capacity > 0:
            result.model.set_cache(LlamaCache(capacity_bytes=cache_capacity))

        # This is ugly, but the model and the tokenizer are the same object in this library.
        return result, result

    def encode(self, string):
        """Tokenize ``string`` (str, encoded to bytes first, or bytes) with the model's tokenizer."""
        if isinstance(string, str):
            string = string.encode()

        return self.model.tokenize(string)

    def decode(self, ids, **kwargs):
        """Detokenize ``ids`` into a str. Extra keyword arguments are accepted and ignored."""
        # A token slice (such as the truncated prompt in generate) may start or end inside a
        # multi-byte UTF-8 character; drop those partial bytes instead of raising UnicodeDecodeError.
        return self.model.detokenize(ids).decode('utf-8', errors='ignore')

    def get_logits(self, tokens):
        """Reset the model, evaluate ``tokens`` and return the resulting scores.

        The scores are returned as a float32 torch tensor with a leading batch
        dimension of size 1 added.
        """
        self.model.reset()
        self.model.eval(tokens)
        logits = self.model._scores
        logits = np.expand_dims(logits, 0)  # batch dim is expected
        return torch.tensor(logits, dtype=torch.float32)

    def load_grammar(self, string):
        """Compile ``string`` into ``self.grammar`` if it differs from the active grammar.

        A blank string clears the grammar (``self.grammar = None``).
        """
        if string != self.grammar_string:
            self.grammar_string = string
            if string.strip() != '':
                self.grammar = llama_cpp_lib().LlamaGrammar.from_string(string)
            else:
                self.grammar = None

    def generate(self, prompt, state, callback=None):
        """Stream a completion for ``prompt`` and return the full generated text.

        ``prompt`` may be str or bytes. ``state`` supplies the generation
        settings read below (max_new_tokens, sampling parameters, seed,
        grammar_string, ban_eos_token, custom_token_bans, ...). If ``callback``
        is given, it is called with each text chunk as it arrives. Generation
        stops early when ``shared.stop_everything`` is set.
        """
        LogitsProcessorList = llama_cpp_lib().LogitsProcessorList
        prompt = prompt if isinstance(prompt, str) else prompt.decode()

        # Handle truncation: keep only the last get_max_prompt_length(state) tokens.
        prompt = self.encode(prompt)
        prompt = prompt[-get_max_prompt_length(state):]
        prompt = self.decode(prompt)

        self.load_grammar(state['grammar_string'])
        logit_processors = LogitsProcessorList()
        if state['ban_eos_token']:
            logit_processors.append(partial(ban_eos_logits_processor, self.model.token_eos()))

        to_ban = self._parse_token_bans(state['custom_token_bans'])
        if to_ban:
            logit_processors.append(partial(custom_token_ban_logits_processor, to_ban))

        completion_chunks = self.model.create_completion(
            prompt=prompt,
            max_tokens=state['max_new_tokens'],
            temperature=state['temperature'],
            # A top_p of 1 or more is capped at 0.999 rather than passed through.
            top_p=state['top_p'] if state['top_p'] < 1 else 0.999,
            min_p=state['min_p'],
            typical_p=state['typical_p'],
            frequency_penalty=state['frequency_penalty'],
            presence_penalty=state['presence_penalty'],
            repeat_penalty=state['repetition_penalty'],
            top_k=state['top_k'],
            stream=True,
            # A seed of -1 is passed as None.
            seed=int(state['seed']) if state['seed'] != -1 else None,
            tfs_z=state['tfs'],
            mirostat_mode=int(state['mirostat_mode']),
            mirostat_tau=state['mirostat_tau'],
            mirostat_eta=state['mirostat_eta'],
            logits_processor=logit_processors,
            grammar=self.grammar
        )

        output = ""
        for completion_chunk in completion_chunks:
            if shared.stop_everything:
                break

            text = completion_chunk['choices'][0]['text']
            output += text
            if callback:
                callback(text)

        return output

    def generate_with_streaming(self, *args, **kwargs):
        """Run ``generate`` through ``Iteratorize`` and yield the accumulated reply after each chunk."""
        with Iteratorize(self.generate, args, kwargs, callback=None) as generator:
            reply = ''
            for token in generator:
                reply += token
                yield reply
|