llama.android : fix build (#9350)

This commit is contained in:
Georgi Gerganov 2024-09-08 00:33:50 +03:00 committed by GitHub
parent f12295b8a9
commit a5b5d9a101
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 39 additions and 15 deletions

View File

@ -269,12 +269,6 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
return env->NewStringUTF(result.str().c_str()); return env->NewStringUTF(result.str().c_str());
} }
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer));
}
extern "C" extern "C"
JNIEXPORT jlong JNICALL JNIEXPORT jlong JNICALL
Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) { Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
@ -311,6 +305,29 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
return reinterpret_cast<jlong>(batch); return reinterpret_cast<jlong>(batch);
} }
// JNI implementation of LLamaAndroid.free_batch(batch: Long).
//
// batch_pointer is a native address carried through Kotlin as a Long; it is
// reinterpreted as a llama_batch* and the pointed-to batch is handed (by value)
// to llama_batch_free. The JNIEnv and receiver object are unused.
//
// NOTE(review): only llama_batch_free is called here. Whether the llama_batch
// object itself must also be released depends on how new_batch allocates it,
// which is not visible in this hunk — confirm against new_batch.
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer));
}
// JNI implementation of LLamaAndroid.new_sampler(): Long.
//
// Creates a sampler chain from the default chain parameters (with no_perf set
// to true), appends a single greedy sampler to it, and returns the chain's
// address as a jlong so the Kotlin side can keep it as an opaque handle.
// The handle is later passed to completion_loop and released via free_sampler.
// The JNIEnv and receiver object are unused.
extern "C"
JNIEXPORT jlong JNICALL
Java_android_llama_cpp_LLamaAndroid_new_1sampler(JNIEnv *, jobject) {
// Start from the library defaults and override only no_perf.
auto sparams = llama_sampler_chain_default_params();
// NOTE(review): presumably disables the sampler's performance counters — confirm in llama.h.
sparams.no_perf = true;
llama_sampler * smpl = llama_sampler_chain_init(sparams);
// Greedy is the only sampler in the chain.
llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
// The Kotlin caller treats a returned 0 as failure.
return reinterpret_cast<jlong>(smpl);
}
// JNI implementation of LLamaAndroid.free_sampler(sampler: Long).
//
// sampler_pointer is a native address carried through Kotlin as a Long (the
// value produced by new_sampler); it is reinterpreted as a llama_sampler* and
// passed to llama_sampler_free. The JNIEnv and receiver object are unused.
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_free_1sampler(JNIEnv *, jobject, jlong sampler_pointer) {
llama_sampler_free(reinterpret_cast<llama_sampler *>(sampler_pointer));
}
extern "C" extern "C"
JNIEXPORT void JNICALL JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_backend_1init(JNIEnv *, jobject) { Java_android_llama_cpp_LLamaAndroid_backend_1init(JNIEnv *, jobject) {
@ -380,14 +397,14 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
JNIEnv * env, JNIEnv * env,
jobject, jobject,
jlong context_pointer, jlong context_pointer,
jlong sampling_pointer,
jlong batch_pointer, jlong batch_pointer,
jlong sampler_pointer,
jint n_len, jint n_len,
jobject intvar_ncur jobject intvar_ncur
) { ) {
const auto context = reinterpret_cast<llama_context *>(context_pointer); const auto context = reinterpret_cast<llama_context *>(context_pointer);
const auto sampling = reinterpret_cast<llama_sampler *>(sampling_pointer); const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer); const auto sampler = reinterpret_cast<llama_sampler *>(sampler_pointer);
const auto model = llama_get_model(context); const auto model = llama_get_model(context);
if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur); if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur);
@ -395,9 +412,9 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V"); if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V");
// sample the most likely token // sample the most likely token
const auto new_token_id = llama_sampler_sample(sampling, context, batch->n_tokens - 1); const auto new_token_id = llama_sampler_sample(sampler, context, -1);
llama_sampler_accept(sampling, new_token_id); llama_sampler_accept(sampler, new_token_id);
const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value); const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {

View File

@ -45,8 +45,10 @@ class LLamaAndroid {
private external fun free_context(context: Long) private external fun free_context(context: Long)
private external fun backend_init(numa: Boolean) private external fun backend_init(numa: Boolean)
private external fun backend_free() private external fun backend_free()
private external fun free_batch(batch: Long)
private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long
private external fun free_batch(batch: Long)
private external fun new_sampler(): Long
private external fun free_sampler(sampler: Long)
private external fun bench_model( private external fun bench_model(
context: Long, context: Long,
model: Long, model: Long,
@ -69,6 +71,7 @@ class LLamaAndroid {
private external fun completion_loop( private external fun completion_loop(
context: Long, context: Long,
batch: Long, batch: Long,
sampler: Long,
nLen: Int, nLen: Int,
ncur: IntVar ncur: IntVar
): String? ): String?
@ -101,8 +104,11 @@ class LLamaAndroid {
val batch = new_batch(512, 0, 1) val batch = new_batch(512, 0, 1)
if (batch == 0L) throw IllegalStateException("new_batch() failed") if (batch == 0L) throw IllegalStateException("new_batch() failed")
val sampler = new_sampler()
if (sampler == 0L) throw IllegalStateException("new_sampler() failed")
Log.i(tag, "Loaded model $pathToModel") Log.i(tag, "Loaded model $pathToModel")
threadLocalState.set(State.Loaded(model, context, batch)) threadLocalState.set(State.Loaded(model, context, batch, sampler))
} }
else -> throw IllegalStateException("Model already loaded") else -> throw IllegalStateException("Model already loaded")
} }
@ -114,7 +120,7 @@ class LLamaAndroid {
is State.Loaded -> { is State.Loaded -> {
val ncur = IntVar(completion_init(state.context, state.batch, message, nlen)) val ncur = IntVar(completion_init(state.context, state.batch, message, nlen))
while (ncur.value <= nlen) { while (ncur.value <= nlen) {
val str = completion_loop(state.context, state.batch, nlen, ncur) val str = completion_loop(state.context, state.batch, state.sampler, nlen, ncur)
if (str == null) { if (str == null) {
break break
} }
@ -138,6 +144,7 @@ class LLamaAndroid {
free_context(state.context) free_context(state.context)
free_model(state.model) free_model(state.model)
free_batch(state.batch) free_batch(state.batch)
free_sampler(state.sampler);
threadLocalState.set(State.Idle) threadLocalState.set(State.Idle)
} }
@ -161,7 +168,7 @@ class LLamaAndroid {
private sealed interface State { private sealed interface State {
data object Idle: State data object Idle: State
data class Loaded(val model: Long, val context: Long, val batch: Long): State data class Loaded(val model: Long, val context: Long, val batch: Long, val sampler: Long): State
} }
// Enforce only one instance of Llm. // Enforce only one instance of Llm.