mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 05:48:47 +01:00
llava : fix tokenization to not add bos between image embeddings and user prompt (#3645)
* llava : fix tokenization to not add bos after system prompt * set seed --------- Co-authored-by: M. Yusuf Sarıgöz <yusufsarigoz@gmail.com>
This commit is contained in:
parent
11bff29045
commit
940efa95fe
@ -49,9 +49,9 @@ inline bool eval_id(struct llama_context * ctx_llama, int id, int * n_past) {
|
|||||||
return eval_tokens(ctx_llama, tokens, 1, n_past);
|
return eval_tokens(ctx_llama, tokens, 1, n_past);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past){
|
inline bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past, bool add_bos){
|
||||||
std::string str2 = str;
|
std::string str2 = str;
|
||||||
std::vector<llama_token> embd_inp = ::llama_tokenize(ctx_llama, str2, true);
|
std::vector<llama_token> embd_inp = ::llama_tokenize(ctx_llama, str2, add_bos);
|
||||||
eval_tokens(ctx_llama, embd_inp, n_batch, n_past);
|
eval_tokens(ctx_llama, embd_inp, n_batch, n_past);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -97,6 +97,7 @@ int main(int argc, char ** argv) {
|
|||||||
ctx_params.n_ctx = params.n_ctx < 2048 ? 2048 : params.n_ctx; // we need a longer context size to process image embeddings
|
ctx_params.n_ctx = params.n_ctx < 2048 ? 2048 : params.n_ctx; // we need a longer context size to process image embeddings
|
||||||
ctx_params.n_threads = params.n_threads;
|
ctx_params.n_threads = params.n_threads;
|
||||||
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
|
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
|
||||||
|
ctx_params.seed = params.seed;
|
||||||
|
|
||||||
llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params);
|
llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params);
|
||||||
|
|
||||||
@ -106,7 +107,8 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// make sure that the correct mmproj was used, i.e., compare apples to apples
|
// make sure that the correct mmproj was used, i.e., compare apples to apples
|
||||||
int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama));
|
const int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama));
|
||||||
|
|
||||||
if (n_img_embd != n_llama_embd) {
|
if (n_img_embd != n_llama_embd) {
|
||||||
printf("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_img_embd, n_llama_embd);
|
printf("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_img_embd, n_llama_embd);
|
||||||
|
|
||||||
@ -125,14 +127,14 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
|
const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
|
||||||
|
|
||||||
// GG: are we sure that the should be a trailing whitespace at the end of this string?
|
eval_string(ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:", params.n_batch, &n_past, true);
|
||||||
eval_string(ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER: ", params.n_batch, &n_past);
|
|
||||||
eval_image_embd(ctx_llama, image_embd, n_img_pos, params.n_batch, &n_past);
|
eval_image_embd(ctx_llama, image_embd, n_img_pos, params.n_batch, &n_past);
|
||||||
eval_string(ctx_llama, params.prompt.c_str(), params.n_batch, &n_past);
|
eval_string(ctx_llama, (params.prompt + "\nASSISTANT:").c_str(), params.n_batch, &n_past, false);
|
||||||
eval_string(ctx_llama, "\nASSISTANT:", params.n_batch, &n_past);
|
|
||||||
|
|
||||||
// generate the response
|
// generate the response
|
||||||
|
|
||||||
|
printf("\n");
|
||||||
|
printf("prompt: '%s'\n", params.prompt.c_str());
|
||||||
printf("\n");
|
printf("\n");
|
||||||
|
|
||||||
for (int i = 0; i < max_tgt_len; i++) {
|
for (int i = 0; i < max_tgt_len; i++) {
|
||||||
|
Loading…
Reference in New Issue
Block a user