mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 06:39:25 +01:00
llama.android : fix build (#9350)
This commit is contained in:
parent
f12295b8a9
commit
a5b5d9a101
@ -269,12 +269,6 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
|
|||||||
return env->NewStringUTF(result.str().c_str());
|
return env->NewStringUTF(result.str().c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C"
|
|
||||||
JNIEXPORT void JNICALL
|
|
||||||
Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
|
|
||||||
llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer));
|
|
||||||
}
|
|
||||||
|
|
||||||
extern "C"
|
extern "C"
|
||||||
JNIEXPORT jlong JNICALL
|
JNIEXPORT jlong JNICALL
|
||||||
Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
|
Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
|
||||||
@ -311,6 +305,29 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
|
|||||||
return reinterpret_cast<jlong>(batch);
|
return reinterpret_cast<jlong>(batch);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
extern "C"
|
||||||
|
JNIEXPORT void JNICALL
|
||||||
|
Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
|
||||||
|
llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer));
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C"
|
||||||
|
JNIEXPORT jlong JNICALL
|
||||||
|
Java_android_llama_cpp_LLamaAndroid_new_1sampler(JNIEnv *, jobject) {
|
||||||
|
auto sparams = llama_sampler_chain_default_params();
|
||||||
|
sparams.no_perf = true;
|
||||||
|
llama_sampler * smpl = llama_sampler_chain_init(sparams);
|
||||||
|
llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
|
||||||
|
|
||||||
|
return reinterpret_cast<jlong>(smpl);
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C"
|
||||||
|
JNIEXPORT void JNICALL
|
||||||
|
Java_android_llama_cpp_LLamaAndroid_free_1sampler(JNIEnv *, jobject, jlong sampler_pointer) {
|
||||||
|
llama_sampler_free(reinterpret_cast<llama_sampler *>(sampler_pointer));
|
||||||
|
}
|
||||||
|
|
||||||
extern "C"
|
extern "C"
|
||||||
JNIEXPORT void JNICALL
|
JNIEXPORT void JNICALL
|
||||||
Java_android_llama_cpp_LLamaAndroid_backend_1init(JNIEnv *, jobject) {
|
Java_android_llama_cpp_LLamaAndroid_backend_1init(JNIEnv *, jobject) {
|
||||||
@ -380,14 +397,14 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
|
|||||||
JNIEnv * env,
|
JNIEnv * env,
|
||||||
jobject,
|
jobject,
|
||||||
jlong context_pointer,
|
jlong context_pointer,
|
||||||
jlong sampling_pointer,
|
|
||||||
jlong batch_pointer,
|
jlong batch_pointer,
|
||||||
|
jlong sampler_pointer,
|
||||||
jint n_len,
|
jint n_len,
|
||||||
jobject intvar_ncur
|
jobject intvar_ncur
|
||||||
) {
|
) {
|
||||||
const auto context = reinterpret_cast<llama_context *>(context_pointer);
|
const auto context = reinterpret_cast<llama_context *>(context_pointer);
|
||||||
const auto sampling = reinterpret_cast<llama_sampler *>(sampling_pointer);
|
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
|
||||||
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
|
const auto sampler = reinterpret_cast<llama_sampler *>(sampler_pointer);
|
||||||
const auto model = llama_get_model(context);
|
const auto model = llama_get_model(context);
|
||||||
|
|
||||||
if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur);
|
if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur);
|
||||||
@ -395,9 +412,9 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
|
|||||||
if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V");
|
if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V");
|
||||||
|
|
||||||
// sample the most likely token
|
// sample the most likely token
|
||||||
const auto new_token_id = llama_sampler_sample(sampling, context, batch->n_tokens - 1);
|
const auto new_token_id = llama_sampler_sample(sampler, context, -1);
|
||||||
|
|
||||||
llama_sampler_accept(sampling, new_token_id);
|
llama_sampler_accept(sampler, new_token_id);
|
||||||
|
|
||||||
const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
|
const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
|
||||||
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
|
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
|
||||||
|
@ -45,8 +45,10 @@ class LLamaAndroid {
|
|||||||
private external fun free_context(context: Long)
|
private external fun free_context(context: Long)
|
||||||
private external fun backend_init(numa: Boolean)
|
private external fun backend_init(numa: Boolean)
|
||||||
private external fun backend_free()
|
private external fun backend_free()
|
||||||
private external fun free_batch(batch: Long)
|
|
||||||
private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long
|
private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long
|
||||||
|
private external fun free_batch(batch: Long)
|
||||||
|
private external fun new_sampler(): Long
|
||||||
|
private external fun free_sampler(sampler: Long)
|
||||||
private external fun bench_model(
|
private external fun bench_model(
|
||||||
context: Long,
|
context: Long,
|
||||||
model: Long,
|
model: Long,
|
||||||
@ -69,6 +71,7 @@ class LLamaAndroid {
|
|||||||
private external fun completion_loop(
|
private external fun completion_loop(
|
||||||
context: Long,
|
context: Long,
|
||||||
batch: Long,
|
batch: Long,
|
||||||
|
sampler: Long,
|
||||||
nLen: Int,
|
nLen: Int,
|
||||||
ncur: IntVar
|
ncur: IntVar
|
||||||
): String?
|
): String?
|
||||||
@ -101,8 +104,11 @@ class LLamaAndroid {
|
|||||||
val batch = new_batch(512, 0, 1)
|
val batch = new_batch(512, 0, 1)
|
||||||
if (batch == 0L) throw IllegalStateException("new_batch() failed")
|
if (batch == 0L) throw IllegalStateException("new_batch() failed")
|
||||||
|
|
||||||
|
val sampler = new_sampler()
|
||||||
|
if (sampler == 0L) throw IllegalStateException("new_sampler() failed")
|
||||||
|
|
||||||
Log.i(tag, "Loaded model $pathToModel")
|
Log.i(tag, "Loaded model $pathToModel")
|
||||||
threadLocalState.set(State.Loaded(model, context, batch))
|
threadLocalState.set(State.Loaded(model, context, batch, sampler))
|
||||||
}
|
}
|
||||||
else -> throw IllegalStateException("Model already loaded")
|
else -> throw IllegalStateException("Model already loaded")
|
||||||
}
|
}
|
||||||
@ -114,7 +120,7 @@ class LLamaAndroid {
|
|||||||
is State.Loaded -> {
|
is State.Loaded -> {
|
||||||
val ncur = IntVar(completion_init(state.context, state.batch, message, nlen))
|
val ncur = IntVar(completion_init(state.context, state.batch, message, nlen))
|
||||||
while (ncur.value <= nlen) {
|
while (ncur.value <= nlen) {
|
||||||
val str = completion_loop(state.context, state.batch, nlen, ncur)
|
val str = completion_loop(state.context, state.batch, state.sampler, nlen, ncur)
|
||||||
if (str == null) {
|
if (str == null) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -138,6 +144,7 @@ class LLamaAndroid {
|
|||||||
free_context(state.context)
|
free_context(state.context)
|
||||||
free_model(state.model)
|
free_model(state.model)
|
||||||
free_batch(state.batch)
|
free_batch(state.batch)
|
||||||
|
free_sampler(state.sampler);
|
||||||
|
|
||||||
threadLocalState.set(State.Idle)
|
threadLocalState.set(State.Idle)
|
||||||
}
|
}
|
||||||
@ -161,7 +168,7 @@ class LLamaAndroid {
|
|||||||
|
|
||||||
private sealed interface State {
|
private sealed interface State {
|
||||||
data object Idle: State
|
data object Idle: State
|
||||||
data class Loaded(val model: Long, val context: Long, val batch: Long): State
|
data class Loaded(val model: Long, val context: Long, val batch: Long, val sampler: Long): State
|
||||||
}
|
}
|
||||||
|
|
||||||
// Enforce only one instance of Llm.
|
// Enforce only one instance of Llm.
|
||||||
|
Loading…
Reference in New Issue
Block a user