import importlib
from typing import Sequence

from tqdm import tqdm

from modules import shared
from modules.cache_utils import process_llamacpp_cache

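
# Tracks which llama-cpp-python variant has already been imported in this
# process, so that a request to switch to a different variant can be detected
# and reported as requiring a server restart.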
imported_module = None


def llama_cpp_lib():
    global imported_module

    def module_to_purpose(module_name):
        if module_name == 'llama_cpp':
            return 'CPU'
        elif module_name == 'llama_cpp_cuda_tensorcores':
            return 'tensorcores'
        elif module_name == 'llama_cpp_cuda':
            return 'default'

        return 'unknown'
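
    # Resolution order: the CPU build when --cpu is set, then the tensorcores
    # build when --tensorcores is set, then the default CUDA build, and
    # finally the CPU build as a last resort.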
    return_lib = None

    if shared.args.cpu:
        if imported_module and imported_module != 'llama_cpp':
            raise Exception(f"The {module_to_purpose(imported_module)} version of llama-cpp-python is already loaded. Switching to the CPU version currently requires a server restart.")

        try:
            return_lib = importlib.import_module('llama_cpp')
            imported_module = 'llama_cpp'
        except:
            pass

    if shared.args.tensorcores and return_lib is None:
        if imported_module and imported_module != 'llama_cpp_cuda_tensorcores':
            raise Exception(f"The {module_to_purpose(imported_module)} version of llama-cpp-python is already loaded. Switching to the tensorcores version currently requires a server restart.")

        try:
            return_lib = importlib.import_module('llama_cpp_cuda_tensorcores')
            imported_module = 'llama_cpp_cuda_tensorcores'
        except:
            pass

    if return_lib is None:
        if imported_module and imported_module != 'llama_cpp_cuda':
            raise Exception(f"The {module_to_purpose(imported_module)} version of llama-cpp-python is already loaded. Switching to the default version currently requires a server restart.")

        try:
            return_lib = importlib.import_module('llama_cpp_cuda')
            imported_module = 'llama_cpp_cuda'
        except:
            pass

    if return_lib is None and not shared.args.cpu:
        if imported_module and imported_module != 'llama_cpp':
            raise Exception(f"The {module_to_purpose(imported_module)} version of llama-cpp-python is already loaded. Switching to the CPU version currently requires a server restart.")

        try:
            return_lib = importlib.import_module('llama_cpp')
            imported_module = 'llama_cpp'
        except:
            pass

    if return_lib is not None:
        monkey_patch_llama_cpp_python(return_lib)

    return return_lib
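
# A hedged usage sketch (not part of this module): loader code elsewhere in
# the project is assumed to resolve the backend lazily, along these lines:
#
#     lib = llama_cpp_lib()
#     if lib is None:
#         raise ImportError("llama-cpp-python is not installed")
#
#     model = lib.Llama(model_path=model_path, n_ctx=4096)
#
# The Llama(...) arguments above are illustrative assumptions, not values
# taken from this file.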


def eval_with_progress(self, tokens: Sequence[int]):
    """
    A copy of Llama.eval from

    https://github.com/abetlen/llama-cpp-python/blob/main/llama_cpp/llama.py

    with tqdm added to show prompt processing progress.
    """
    assert self._ctx.ctx is not None
    assert self._batch.batch is not None
    self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)

    if len(tokens) > 1:
        progress_bar = tqdm(range(0, len(tokens), self.n_batch), desc="Prompt evaluation", leave=False)
    else:
        progress_bar = range(0, len(tokens), self.n_batch)
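
    # Evaluate the prompt in chunks of n_batch tokens, mirroring the loop in
    # the upstream Llama.eval implementation.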
    for i in progress_bar:
        batch = tokens[i: min(len(tokens), i + self.n_batch)]
        n_past = self.n_tokens
        n_tokens = len(batch)
        self._batch.set_batch(
            batch=batch, n_past=n_past, logits_all=self.context_params.logits_all
        )
        self._ctx.decode(self._batch)

        # Save tokens
        self.input_ids[n_past: n_past + n_tokens] = batch

        # Save logits
        if self.context_params.logits_all:
            rows = n_tokens
            cols = self._n_vocab
            logits = self._ctx.get_logits()[: rows * cols]
            self.scores[n_past: n_past + n_tokens, :].reshape(-1)[::] = logits
        else:
            rows = 1
            cols = self._n_vocab
            logits = self._ctx.get_logits()[: rows * cols]
            self.scores[n_past + n_tokens - 1, :].reshape(-1)[::] = logits

        # Update n_tokens
        self.n_tokens += n_tokens


def monkey_patch_llama_cpp_python(lib):
    if getattr(lib.Llama, '_is_patched', False):
        # If the patch is already applied, do nothing
        return

    def my_generate(self, *args, **kwargs):
        if shared.args.streaming_llm:
            new_sequence = args[0]
            past_sequence = self._input_ids

            # Do the cache trimming for StreamingLLM
            process_llamacpp_cache(self, new_sequence, past_sequence)

        for output in self.original_generate(*args, **kwargs):
            yield output
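
    # Install the replacements on the Llama class itself so that every
    # existing and future instance uses the patched methods.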
    lib.Llama.eval = eval_with_progress
    lib.Llama.original_generate = lib.Llama.generate
    lib.Llama.generate = my_generate

    # Set the flag to indicate that the patch has been applied
    lib.Llama._is_patched = True