2023-09-18 19:30:05 +02:00
// A basic application simulating a server with multiple clients.
// The clients submit requests to the server and they are processed in parallel.
# include "build-info.h"
# include "common.h"
# include "llama.h"

# include <algorithm>
# include <cctype>
# include <cmath>
# include <cstdio>
# include <cstdlib>
# include <string>
# include <tuple>
# include <vector>
// trim whitespace from the beginning and end of a string
//
// Each character is converted to unsigned char before it is passed to std::isspace: calling it
// with a negative char value (e.g. a byte of a UTF-8 multi-byte sequence on platforms where char
// is signed) is undefined behavior.
static std::string trim(const std::string & str) {
    const auto is_space = [](char c) {
        return std::isspace(static_cast<unsigned char>(c)) != 0;
    };

    size_t start = 0;
    size_t end   = str.size();

    while (start < end && is_space(str[start])) {
        start += 1;
    }

    while (end > start && is_space(str[end - 1])) {
        end -= 1;
    }

    return str.substr(start, end - start);
}
2023-09-19 12:54:41 +02:00
// System prompt shared by every simulated client. main() tokenizes it once (with BOS), evaluates
// it as sequence 0, and then copies its KV cache entries to all other client sequences, so each
// request only has to process its own prompt. It ends with an open "User" turn, to which each
// request is appended (followed by "\nAssistant:").
// NOTE(review): the timestamp-looking lines inside the raw literal appear to be blame/extraction
// residue; as written they would become part of the prompt text - confirm and remove.
static std : : string k_system =
R " (Transcript of a never ending dialog, where the User interacts with an Assistant.
2023-09-18 19:30:05 +02:00
The Assistant is helpful , kind , honest , good at writing , and never fails to answer the User ' s requests immediately and with precision .
2023-09-19 22:34:30 +02:00
User : Recommend a nice restaurant in the area .
Assistant : I recommend the restaurant " The Golden Duck " . It is a 5 star restaurant with a great view of the city . The food is delicious and the service is excellent . The prices are reasonable and the portions are generous . The restaurant is located at 123 Main Street , New York , NY 10001. The phone number is ( 212 ) 555 - 1234. The hours are Monday through Friday from 11 : 00 am to 10 : 00 pm . The restaurant is closed on Saturdays and Sundays .
User : Who is Richard Feynman ?
Assistant : Richard Feynman was an American physicist who is best known for his work in quantum mechanics and particle physics . He was awarded the Nobel Prize in Physics in 1965 for his contributions to the development of quantum electrodynamics . He was a popular lecturer and author , and he wrote several books , including " Surely You're Joking, Mr. Feynman! " and " What Do You Care What Other People Think? " .
2023-09-19 16:00:42 +02:00
User : ) " ;
2023-09-18 19:30:05 +02:00
// Pool of user requests. Each time a client starts a new request, main() picks one entry with
// rand() % k_prompts.size(); rand() is seeded with a fixed value, so the request order is
// reproducible between runs.
static std : : vector < std : : string > k_prompts = {
" What is the meaning of life? " ,
" Tell me an interesting fact about llamas. " ,
" What is the best way to cook a steak? " ,
" Are you familiar with the Special Theory of Relativity and can you explain it to me? " ,
" Recommend some interesting books to read. " ,
" What is the best way to learn a new language? " ,
" How to get a job at Google? " ,
" If you could have any superpower, what would it be? " ,
" I want to learn how to play the piano. " ,
} ;
struct client {
int32_t id = 0 ;
llama_seq_id seq_id = - 1 ;
llama_token sampled ;
2023-09-18 20:34:20 +02:00
int64_t t_start_prompt ;
int64_t t_start_gen ;
2023-09-18 19:30:05 +02:00
int32_t n_prompt = 0 ;
int32_t n_decoded = 0 ;
int32_t i_batch = - 1 ;
std : : string input ;
std : : string prompt ;
std : : string response ;
2023-09-19 16:00:42 +02:00
std : : vector < llama_token > tokens_prev ;
2023-09-18 19:30:05 +02:00
} ;
int main ( int argc , char * * argv ) {
2023-09-19 22:34:30 +02:00
srand ( 1234 ) ;
2023-09-18 19:30:05 +02:00
gpt_params params ;
if ( gpt_params_parse ( argc , argv , params ) = = false ) {
return 1 ;
}
2023-09-19 16:00:42 +02:00
// number of simultaneous "clients" to simulate
const int32_t n_clients = params . n_parallel ;
2023-09-19 11:29:37 +02:00
// requests to simulate
2023-09-19 16:00:42 +02:00
const int32_t n_seq = params . n_sequences ;
// insert new requests as soon as the previous one is done
2023-09-20 08:24:02 +02:00
const bool cont_batching = params . cont_batching ;
2023-09-18 23:24:13 +02:00
2023-09-18 19:30:05 +02:00
# ifndef LOG_DISABLE_LOGS
log_set_target ( log_filename_generator ( " parallel " , " log " ) ) ;
LOG_TEE ( " Log start \n " ) ;
log_dump_cmdline ( argc , argv ) ;
# endif // LOG_DISABLE_LOGS
// init llama.cpp
llama_backend_init ( params . numa ) ;
llama_model * model = NULL ;
llama_context * ctx = NULL ;
// load the target model
params . logits_all = true ;
std : : tie ( model , ctx ) = llama_init_from_gpt_params ( params ) ;
fprintf ( stderr , " \n \n " ) ;
fflush ( stderr ) ;
const int n_ctx = llama_n_ctx ( ctx ) ;
const int n_vocab = llama_n_vocab ( ctx ) ;
std : : vector < client > clients ( n_clients ) ;
for ( size_t i = 0 ; i < clients . size ( ) ; + + i ) {
auto & client = clients [ i ] ;
client . id = i ;
2023-09-20 16:32:21 +02:00
client . tokens_prev . resize ( params . n_predict ) ;
2023-09-19 16:00:42 +02:00
std : : fill ( client . tokens_prev . begin ( ) , client . tokens_prev . end ( ) , 0 ) ;
2023-09-18 19:30:05 +02:00
}
std : : vector < llama_token_data > candidates ;
candidates . reserve ( n_vocab ) ;
2023-09-19 16:00:42 +02:00
std : : vector < llama_token > tokens_system ;
tokens_system = : : llama_tokenize ( ctx , k_system , true ) ;
2023-09-20 12:06:34 +02:00
const int32_t n_tokens_system = tokens_system . size ( ) ;
2023-09-19 16:00:42 +02:00
2023-09-18 19:30:05 +02:00
llama_seq_id g_seq_id = 0 ;
2023-09-20 18:09:25 +02:00
// the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
// users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
llama_batch batch = llama_batch_init ( params . n_ctx , 0 ) ;
2023-09-18 19:30:05 +02:00
2023-09-18 23:24:13 +02:00
int32_t n_total_prompt = 0 ;
int32_t n_total_gen = 0 ;
2023-09-19 22:47:47 +02:00
int32_t n_cache_miss = 0 ;
2023-09-18 23:24:13 +02:00
2023-09-19 11:29:37 +02:00
const auto t_main_start = ggml_time_us ( ) ;
2023-09-18 23:24:13 +02:00
2023-09-19 16:00:42 +02:00
LOG_TEE ( " %s: Simulating parallel requests from clients: \n " , __func__ ) ;
2023-09-20 08:24:02 +02:00
LOG_TEE ( " %s: n_parallel = %d, n_sequences = %d, cont_batching = %d, system tokens = %d \n " , __func__ , n_clients , n_seq , cont_batching , n_tokens_system ) ;
2023-09-19 16:00:42 +02:00
LOG_TEE ( " \n " ) ;
{
LOG_TEE ( " %s: Evaluating the system prompt ... \n " , __func__ ) ;
2023-09-20 09:46:18 +02:00
batch . n_tokens = n_tokens_system ;
2023-09-19 16:00:42 +02:00
2023-09-20 12:06:34 +02:00
for ( int32_t i = 0 ; i < batch . n_tokens ; + + i ) {
2023-09-20 09:46:18 +02:00
batch . token [ i ] = tokens_system [ i ] ;
batch . pos [ i ] = i ;
batch . seq_id [ i ] = 0 ;
batch . logits [ i ] = false ;
2023-09-19 16:00:42 +02:00
}
if ( llama_decode ( ctx , batch , params . n_threads ) ! = 0 ) {
LOG_TEE ( " %s: llama_decode() failed \n " , __func__ ) ;
return 1 ;
}
2023-09-20 12:06:34 +02:00
// assign the system KV cache to all parallel sequences
2023-09-19 16:00:42 +02:00
for ( int32_t i = 1 ; i < n_clients ; + + i ) {
llama_kv_cache_seq_cp ( ctx , 0 , i , 0 , n_tokens_system ) ;
}
LOG_TEE ( " \n " ) ;
}
2023-09-19 22:34:30 +02:00
LOG_TEE ( " Processing requests ... \n \n " ) ;
2023-09-19 11:29:37 +02:00
while ( true ) {
2023-09-20 09:46:18 +02:00
batch . n_tokens = 0 ;
2023-09-18 19:30:05 +02:00
2023-09-20 09:46:18 +02:00
// decode any currently ongoing sequences
2023-09-18 19:30:05 +02:00
for ( auto & client : clients ) {
if ( client . seq_id = = - 1 ) {
2023-09-18 20:34:20 +02:00
continue ;
}
2023-09-20 09:46:18 +02:00
batch . token [ batch . n_tokens ] = client . sampled ;
batch . pos [ batch . n_tokens ] = n_tokens_system + client . n_prompt + client . n_decoded ;
batch . seq_id [ batch . n_tokens ] = client . id ;
batch . logits [ batch . n_tokens ] = true ;
2023-09-18 20:34:20 +02:00
client . n_decoded + = 1 ;
2023-09-20 09:46:18 +02:00
client . i_batch = batch . n_tokens ;
batch . n_tokens + = 1 ;
2023-09-18 20:34:20 +02:00
}
2023-09-20 09:46:18 +02:00
if ( batch . n_tokens = = 0 ) {
2023-09-18 20:34:20 +02:00
// all sequences have ended - clear the entire KV cache
2023-09-19 16:00:42 +02:00
for ( int i = 0 ; i < n_clients ; + + i ) {
llama_kv_cache_seq_rm ( ctx , i , n_tokens_system , - 1 ) ;
}
2023-09-20 16:32:21 +02:00
LOG_TEE ( " %s: clearing the KV cache \n " , __func__ ) ;
2023-09-18 23:24:13 +02:00
}
2023-09-18 20:34:20 +02:00
2023-09-20 09:46:18 +02:00
// insert new sequences for decoding
if ( cont_batching | | batch . n_tokens = = 0 ) {
2023-09-18 20:34:20 +02:00
for ( auto & client : clients ) {
2023-09-19 11:29:37 +02:00
if ( client . seq_id = = - 1 & & g_seq_id < n_seq ) {
2023-09-19 22:34:30 +02:00
client . seq_id = g_seq_id ;
2023-09-20 09:46:18 +02:00
2023-09-18 20:34:20 +02:00
client . t_start_prompt = ggml_time_us ( ) ;
2023-09-19 11:29:37 +02:00
client . t_start_gen = 0 ;
2023-09-18 20:34:20 +02:00
2023-09-20 09:46:18 +02:00
client . input = k_prompts [ rand ( ) % k_prompts . size ( ) ] ;
client . prompt = client . input + " \n Assistant: " ;
2023-09-18 20:34:20 +02:00
client . response = " " ;
2023-09-20 09:46:18 +02:00
2023-09-19 16:00:42 +02:00
std : : fill ( client . tokens_prev . begin ( ) , client . tokens_prev . end ( ) , 0 ) ;
2023-09-18 20:34:20 +02:00
2023-09-20 16:32:21 +02:00
// do not prepend BOS because we have a system prompt!
2023-09-19 16:00:42 +02:00
std : : vector < llama_token > tokens_prompt ;
2023-09-20 16:32:21 +02:00
tokens_prompt = : : llama_tokenize ( ctx , client . prompt , false ) ;
2023-09-18 20:34:20 +02:00
2023-09-19 16:00:42 +02:00
for ( size_t i = 0 ; i < tokens_prompt . size ( ) ; + + i ) {
2023-09-20 09:46:18 +02:00
batch . token [ batch . n_tokens ] = tokens_prompt [ i ] ;
batch . pos [ batch . n_tokens ] = i + n_tokens_system ;
batch . seq_id [ batch . n_tokens ] = client . id ;
batch . logits [ batch . n_tokens ] = false ;
batch . n_tokens + = 1 ;
}
// extract the logits only for the last token
if ( batch . n_tokens > 0 ) {
batch . logits [ batch . n_tokens - 1 ] = true ;
2023-09-18 20:34:20 +02:00
}
2023-09-18 23:24:13 +02:00
2023-09-19 16:00:42 +02:00
client . n_prompt = tokens_prompt . size ( ) ;
2023-09-19 11:29:37 +02:00
client . n_decoded = 0 ;
2023-09-20 09:46:18 +02:00
client . i_batch = batch . n_tokens - 1 ;
LOG_TEE ( " \033 [1mClient %3d, seq %4d, started decoding ... \033 [0m \n " , client . id , client . seq_id ) ;
2023-09-18 20:34:20 +02:00
g_seq_id + = 1 ;
2023-09-20 09:46:18 +02:00
// insert new requests one-by-one
2023-09-20 08:24:02 +02:00
//if (cont_batching) {
// break;
//}
2023-09-18 19:30:05 +02:00
}
}
}
2023-09-20 09:46:18 +02:00
if ( batch . n_tokens = = 0 ) {
2023-09-19 11:29:37 +02:00
break ;
}
2023-09-18 19:30:05 +02:00
// process in chunks of params.n_batch
2023-09-19 12:21:36 +02:00
int32_t n_batch = params . n_batch ;
2023-09-20 09:46:18 +02:00
for ( int32_t i = 0 ; i < ( int32_t ) batch . n_tokens ; i + = n_batch ) {
2023-09-20 12:06:34 +02:00
const int32_t n_tokens = std : : min ( n_batch , ( int32_t ) ( batch . n_tokens - i ) ) ;
2023-09-18 19:30:05 +02:00
2023-09-20 09:46:18 +02:00
llama_batch batch_view = {
2023-09-18 19:30:05 +02:00
n_tokens ,
2023-09-20 09:46:18 +02:00
batch . token + i ,
2023-09-18 19:30:05 +02:00
nullptr ,
2023-09-20 09:46:18 +02:00
batch . pos + i ,
batch . seq_id + i ,
batch . logits + i ,
2023-09-18 19:30:05 +02:00
0 , 0 , 0 , // unused
} ;
2023-09-20 09:46:18 +02:00
const int ret = llama_decode ( ctx , batch_view , params . n_threads ) ;
2023-09-19 16:00:42 +02:00
if ( ret ! = 0 ) {
if ( n_batch = = 1 | | ret < 0 ) {
2023-09-20 09:46:18 +02:00
// if you get here, it means the KV cache is full - try increasing it via the context size
LOG_TEE ( " %s : failed to decode the batch, n_batch = %d, ret = %d \n " , __func__ , n_batch , ret ) ;
2023-09-19 12:21:36 +02:00
return 1 ;
}
2023-09-20 09:46:18 +02:00
LOG ( " %s : failed to decode the batch, retrying with n_batch = %d \n " , __func__ , n_batch / 2 ) ;
2023-09-19 12:21:36 +02:00
2023-09-19 22:47:47 +02:00
n_cache_miss + = 1 ;
2023-09-19 12:21:36 +02:00
// retry with half the batch size to try to find a free slot in the KV cache
n_batch / = 2 ;
i - = n_batch ;
continue ;
2023-09-18 19:30:05 +02:00
}
2023-09-19 12:29:29 +02:00
LOG ( " %s : decoded batch of %d tokens \n " , __func__ , n_tokens ) ;
2023-09-19 12:21:36 +02:00
2023-09-18 19:30:05 +02:00
for ( auto & client : clients ) {
if ( client . i_batch < ( int ) i | | client . i_batch > = ( int ) ( i + n_tokens ) ) {
continue ;
}
2023-09-18 21:00:02 +02:00
//printf("client %d, seq %d, token %d, pos %d, batch %d\n",
// client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
2023-09-19 16:00:42 +02:00
const llama_token id = llama_sample_token ( ctx , NULL , NULL , params , client . tokens_prev , candidates , client . i_batch - i ) ;
2023-09-18 19:30:05 +02:00
2023-09-19 11:29:37 +02:00
if ( client . n_decoded = = 1 ) {
// start measuring generation time after the first token to make sure all concurrent clients
// have their prompt already processed
2023-09-18 20:34:20 +02:00
client . t_start_gen = ggml_time_us ( ) ;
}
2023-09-18 19:30:05 +02:00
// remember which tokens were sampled - used for repetition penalties during sampling
2023-09-19 16:00:42 +02:00
client . tokens_prev . erase ( client . tokens_prev . begin ( ) ) ;
client . tokens_prev . push_back ( id ) ;
2023-09-18 19:30:05 +02:00
const std : : string token_str = llama_token_to_piece ( ctx , id ) ;
client . response + = token_str ;
client . sampled = id ;
//printf("client %d, seq %d, token %d, pos %d, batch %d: %s\n",
// client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str());
2023-09-19 11:29:37 +02:00
if ( client . n_decoded > 2 & &
2023-09-19 12:29:29 +02:00
( id = = llama_token_eos ( ctx ) | | client . n_decoded + client . n_prompt > = params . n_predict | |
2023-09-19 11:29:37 +02:00
client . response . find ( " User: " ) ! = std : : string : : npos | |
client . response . find ( ' \n ' ) ! = std : : string : : npos ) ) {
2023-09-18 20:34:20 +02:00
// basic reverse prompt
2023-09-18 19:30:05 +02:00
const size_t pos = client . response . find ( " User: " ) ;
if ( pos ! = std : : string : : npos ) {
client . response = client . response . substr ( 0 , pos ) ;
}
2023-09-19 16:00:42 +02:00
// delete only the generated part of the sequence, i.e. keep the system prompt in the cache
2023-09-19 22:34:30 +02:00
llama_kv_cache_seq_rm ( ctx , client . id , n_tokens_system , n_ctx ) ;
2023-09-18 19:30:05 +02:00
const auto t_main_end = ggml_time_us ( ) ;
2023-09-19 23:35:10 +02:00
LOG_TEE ( " \033 [1mClient %3d, seq %4d, prompt %4d t, response %4d t, time %5.2f s, speed %5.2f t/s, cache miss %d \033 [0m \n \n Input: %s \n Response: %s \n \n " ,
2023-09-19 11:29:37 +02:00
client . id , client . seq_id , client . n_prompt , client . n_decoded ,
2023-09-18 23:24:13 +02:00
( t_main_end - client . t_start_prompt ) / 1e6 ,
2023-09-19 23:35:10 +02:00
( double ) ( client . n_prompt + client . n_decoded ) / ( t_main_end - client . t_start_prompt ) * 1e6 ,
2023-09-19 22:50:05 +02:00
n_cache_miss ,
2023-09-18 20:34:20 +02:00
: : trim ( client . input ) . c_str ( ) ,
: : trim ( client . response ) . c_str ( ) ) ;
2023-09-18 19:30:05 +02:00
2023-09-18 23:24:13 +02:00
n_total_prompt + = client . n_prompt ;
2023-09-19 11:29:37 +02:00
n_total_gen + = client . n_decoded ;
2023-09-18 23:24:13 +02:00
2023-09-18 19:30:05 +02:00
client . seq_id = - 1 ;
}
2023-09-18 20:34:20 +02:00
client . i_batch = - 1 ;
2023-09-18 19:30:05 +02:00
}
}
}
2023-09-19 11:29:37 +02:00
const auto t_main_end = ggml_time_us ( ) ;
2023-09-18 23:24:13 +02:00
LOG_TEE ( " \n \n " ) ;
2023-09-19 11:29:37 +02:00
LOG_TEE ( " Total prompt tokens: %6d, speed: %5.2f t/s \n " , n_total_prompt , ( double ) ( n_total_prompt ) / ( t_main_end - t_main_start ) * 1e6 ) ;
LOG_TEE ( " Total gen tokens: %6d, speed: %5.2f t/s \n " , n_total_gen , ( double ) ( n_total_gen ) / ( t_main_end - t_main_start ) * 1e6 ) ;
LOG_TEE ( " Total speed (AVG): %6s speed: %5.2f t/s \n " , " " , ( double ) ( n_total_prompt + n_total_gen ) / ( t_main_end - t_main_start ) * 1e6 ) ;
2023-09-19 22:47:47 +02:00
LOG_TEE ( " Cache misses: %6d \n " , n_cache_miss ) ;
2023-09-18 23:24:13 +02:00
2023-09-18 19:30:05 +02:00
LOG_TEE ( " \n \n " ) ;
llama_print_timings ( ctx ) ;
2023-09-20 09:46:18 +02:00
llama_batch_free ( batch ) ;
2023-09-18 19:30:05 +02:00
llama_free ( ctx ) ;
llama_free_model ( model ) ;
llama_backend_free ( ) ;
fprintf ( stderr , " \n \n " ) ;
return 0 ;
}