diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp index 921793751..06ec160c2 100644 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp @@ -269,12 +269,6 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( return env->NewStringUTF(result.str().c_str()); } -extern "C" -JNIEXPORT void JNICALL -Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) { - llama_batch_free(*reinterpret_cast(batch_pointer)); -} - extern "C" JNIEXPORT jlong JNICALL Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) { @@ -311,6 +305,29 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, return reinterpret_cast(batch); } +extern "C" +JNIEXPORT void JNICALL +Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) { + llama_batch_free(*reinterpret_cast(batch_pointer)); +} + +extern "C" +JNIEXPORT jlong JNICALL +Java_android_llama_cpp_LLamaAndroid_new_1sampler(JNIEnv *, jobject) { + auto sparams = llama_sampler_chain_default_params(); + sparams.no_perf = true; + llama_sampler * smpl = llama_sampler_chain_init(sparams); + llama_sampler_chain_add(smpl, llama_sampler_init_greedy()); + + return reinterpret_cast(smpl); +} + +extern "C" +JNIEXPORT void JNICALL +Java_android_llama_cpp_LLamaAndroid_free_1sampler(JNIEnv *, jobject, jlong sampler_pointer) { + llama_sampler_free(reinterpret_cast(sampler_pointer)); +} + extern "C" JNIEXPORT void JNICALL Java_android_llama_cpp_LLamaAndroid_backend_1init(JNIEnv *, jobject) { @@ -380,14 +397,14 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( JNIEnv * env, jobject, jlong context_pointer, - jlong sampling_pointer, jlong batch_pointer, + jlong sampler_pointer, jint n_len, jobject intvar_ncur ) { const auto context = reinterpret_cast(context_pointer); - const auto sampling = reinterpret_cast(sampling_pointer); - const auto batch = reinterpret_cast(batch_pointer); + const auto batch = reinterpret_cast(batch_pointer); + const auto sampler = reinterpret_cast(sampler_pointer); const auto model = llama_get_model(context); if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur); @@ -395,9 +412,9 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V"); // sample the most likely token - const auto new_token_id = llama_sampler_sample(sampling, context, batch->n_tokens - 1); + const auto new_token_id = llama_sampler_sample(sampler, context, -1); - llama_sampler_accept(sampling, new_token_id); + llama_sampler_accept(sampler, new_token_id); const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value); if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt index 6c63e54e0..cf520e459 100644 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt @@ -45,8 +45,10 @@ class LLamaAndroid { private external fun free_context(context: Long) private external fun backend_init(numa: Boolean) private external fun backend_free() - private external fun free_batch(batch: Long) private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long + private external fun free_batch(batch: Long) + private external fun new_sampler(): Long + private external fun free_sampler(sampler: Long) private external fun bench_model( context: Long, model: Long, @@ -69,6 +71,7 @@ class LLamaAndroid { private external fun completion_loop( context: Long, batch: Long, + sampler: Long, nLen: Int, ncur: IntVar ): String? @@ -101,8 +104,11 @@ class LLamaAndroid { val batch = new_batch(512, 0, 1) if (batch == 0L) throw IllegalStateException("new_batch() failed") + val sampler = new_sampler() + if (sampler == 0L) throw IllegalStateException("new_sampler() failed") + Log.i(tag, "Loaded model $pathToModel") - threadLocalState.set(State.Loaded(model, context, batch)) + threadLocalState.set(State.Loaded(model, context, batch, sampler)) } else -> throw IllegalStateException("Model already loaded") } @@ -114,7 +120,7 @@ class LLamaAndroid { is State.Loaded -> { val ncur = IntVar(completion_init(state.context, state.batch, message, nlen)) while (ncur.value <= nlen) { - val str = completion_loop(state.context, state.batch, nlen, ncur) + val str = completion_loop(state.context, state.batch, state.sampler, nlen, ncur) if (str == null) { break } @@ -138,6 +144,7 @@ class LLamaAndroid { free_context(state.context) free_model(state.model) free_batch(state.batch) + free_sampler(state.sampler); threadLocalState.set(State.Idle) } @@ -161,7 +168,7 @@ class LLamaAndroid { private sealed interface State { data object Idle: State - data class Loaded(val model: Long, val context: Long, val batch: Long): State + data class Loaded(val model: Long, val context: Long, val batch: Long, val sampler: Long): State } // Enforce only one instance of Llm.