2024-07-05 04:15:37 +02:00
import importlib
2024-07-13 05:04:19 +02:00
import platform
2024-07-23 03:05:11 +02:00
from typing import Sequence
2024-09-30 23:04:21 +02:00
import numpy as np
2024-07-23 03:05:11 +02:00
from tqdm import tqdm
2024-02-08 06:40:58 +01:00
2024-03-09 04:25:33 +01:00
from modules import shared
from modules . cache_utils import process_llamacpp_cache
2024-07-05 04:43:34 +02:00
# Name of the llama-cpp-python package that llama_cpp_lib() has imported
# (e.g. 'llama_cpp' or 'llama_cpp_cuda'), or None until the first successful
# import. Once set, llama_cpp_lib() refuses to import a different variant.
imported_module = None
2024-07-05 04:15:37 +02:00
def llama_cpp_lib():
    """Import and return the llama-cpp-python variant selected by the CLI flags.

    Candidates are tried in priority order. The first one that imports
    successfully is patched with ``monkey_patch_llama_cpp_python`` and
    returned. On macOS only the plain ``llama_cpp`` package is tried.
    Elsewhere the order is:

    1. ``llama_cpp`` if ``shared.args.cpu`` is set,
    2. ``llama_cpp_cuda_tensorcores`` if ``shared.args.tensorcores`` is set,
    3. ``llama_cpp_cuda``,
    4. ``llama_cpp``.

    The name of the imported variant is stored in the module-level
    ``imported_module``. Later calls refuse to load a different variant,
    because switching requires a restart.

    Returns:
        The imported module, or ``None`` if no candidate could be imported.

    Raises:
        Exception: if a selected candidate that is actually installed differs
            from the variant already imported in this process.
    """
    from importlib.util import find_spec

    global imported_module

    if platform.system() == 'Darwin':
        lib_names = [
            (None, 'llama_cpp'),
        ]
    else:
        lib_names = [
            ('cpu', 'llama_cpp'),
            ('tensorcores', 'llama_cpp_cuda_tensorcores'),
            (None, 'llama_cpp_cuda'),
            (None, 'llama_cpp'),
        ]

    for arg, lib_name in lib_names:
        # Candidates gated by a flag are only tried when that flag is set.
        should_import = arg is None or getattr(shared.args, arg)
        if not should_import:
            continue

        if imported_module and imported_module != lib_name:
            # A package that is not installed cannot be the one to load, so
            # skip it instead of reporting a conflict. Otherwise a process
            # that fell back to `llama_cpp` (because e.g. `llama_cpp_cuda` is
            # absent) would fail here on every subsequent call.
            if find_spec(lib_name) is None:
                continue

            # Conflict detected, raise an exception
            raise Exception(f"Cannot import `{lib_name}` because `{imported_module}` is already imported. Switching to a different version of llama-cpp-python currently requires a server restart.")

        try:
            return_lib = importlib.import_module(lib_name)
        except ImportError:
            # Not installed or failed to load: fall through to the next candidate.
            continue

        imported_module = lib_name
        monkey_patch_llama_cpp_python(return_lib)
        return return_lib

    return None
2024-04-30 14:11:31 +02:00
2024-02-08 06:40:58 +01:00
2024-07-23 03:05:11 +02:00
def eval_with_progress(self, tokens: Sequence[int]) -> None:
    """Evaluate ``tokens`` on a ``Llama`` instance, showing a tqdm progress bar.

    This is a copy of ``Llama.eval`` from
    https://github.com/abetlen/llama-cpp-python/blob/main/llama_cpp/llama.py
    with tqdm added to show prompt processing progress. It is meant to be
    assigned as ``Llama.eval``, so ``self`` is a ``Llama`` instance.

    The tokens are fed to the model in chunks of at most ``self.n_batch``.
    They are placed after the ``self.n_tokens`` tokens already evaluated. The
    progress bar is only shown when more than one chunk is needed.

    Args:
        tokens: Token ids to evaluate.

    Side effects:
        Calls ``self._ctx.kv_cache_seq_rm`` starting at ``self.n_tokens``.
        Writes the tokens into ``self.input_ids`` and logits into
        ``self.scores``. Updates ``self.last_updated_index`` and advances
        ``self.n_tokens`` by ``len(tokens)``.
    """
    # Remove cached state from the current position onward so the new tokens
    # replace whatever was evaluated there before. Per llama.cpp's convention,
    # -1 as seq_id means "all sequences" and -1 as the end means "to the end".
    # NOTE(review): confirm against the llama.cpp version in use.
    self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)

    # Wrap the chunk offsets in tqdm only when there is more than one chunk,
    # so prompts that fit in a single batch do not print a bar.
    if len(tokens) > self.n_batch:
        progress_bar = tqdm(range(0, len(tokens), self.n_batch), desc="Prompt evaluation", leave=False)
    else:
        progress_bar = range(0, len(tokens), self.n_batch)

    for i in progress_bar:
        batch = tokens[i : min(len(tokens), i + self.n_batch)]
        n_past = self.n_tokens
        n_tokens = len(batch)
        self._batch.set_batch(
            batch=batch, n_past=n_past, logits_all=self.context_params.logits_all
        )
        self._ctx.decode(self._batch)
        # Save tokens
        self.input_ids[n_past : n_past + n_tokens] = batch
        # Save logits
        if self.context_params.logits_all:
            # One row of n_vocab logits per token in this chunk.
            rows = n_tokens
            cols = self._n_vocab
            # Wrap the pointer returned by get_logits() as a flat 1-D array.
            logits = np.ctypeslib.as_array(
                self._ctx.get_logits(), shape=(rows * cols,)
            )
            # Write through a flattened view of the destination rows. This
            # relies on self.scores being C-contiguous, so that reshape(-1)
            # returns a view rather than a copy.
            # NOTE(review): confirm scores is always C-contiguous.
            self.scores[n_past : n_past + n_tokens, :].reshape(-1)[::] = logits
            self.last_updated_index = n_past + n_tokens - 1
        else:
            # Only a single row of logits is read, and it is stored at the
            # position of the chunk's last token.
            rows = 1
            cols = self._n_vocab
            logits = np.ctypeslib.as_array(
                self._ctx.get_logits(), shape=(rows * cols,)
            )
            # Clamp so the write stays within the rows of self.scores.
            last_token_index = min(n_past + n_tokens - 1, self.scores.shape[0] - 1)
            self.scores[last_token_index, :] = logits.reshape(-1)
            self.last_updated_index = last_token_index
        # Update n_tokens
        self.n_tokens += n_tokens
2024-07-05 04:15:37 +02:00
def monkey_patch_llama_cpp_python(lib):
    """Patch ``lib.Llama`` in place for use by this application.

    - Replaces ``Llama.eval`` with ``eval_with_progress``, which shows a tqdm
      progress bar during prompt evaluation.
    - Keeps the original ``Llama.generate`` as ``Llama.original_generate``.
    - Replaces ``Llama.generate`` with a wrapper that runs
      ``process_llamacpp_cache`` first when ``shared.args.streaming_llm`` is
      enabled.

    The patch is applied at most once per ``Llama`` class. A class-level
    ``_is_patched`` flag makes repeated calls no-ops.

    Args:
        lib: An imported llama-cpp-python module exposing a ``Llama`` class.
    """
    if getattr(lib.Llama, '_is_patched', False):
        # If the patch is already applied, do nothing
        return

    def my_generate(self, *args, **kwargs):
        if shared.args.streaming_llm:
            # `tokens` is the first parameter of Llama.generate. Accept it
            # both positionally and by keyword.
            new_sequence = args[0] if args else kwargs['tokens']
            past_sequence = self._input_ids

            # Do the cache trimming for StreamingLLM
            process_llamacpp_cache(self, new_sequence, past_sequence)

        # `yield from` (rather than re-yielding in a for-loop) forwards
        # send(), throw() and close() to the original generator. Values sent
        # into this wrapper therefore reach the wrapped generator instead of
        # being replaced by None, and closing the wrapper also closes the
        # underlying generator.
        yield from self.original_generate(*args, **kwargs)

    lib.Llama.eval = eval_with_progress
    lib.Llama.original_generate = lib.Llama.generate
    lib.Llama.generate = my_generate

    # Set the flag to indicate that the patch has been applied
    lib.Llama._is_patched = True