mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 06:39:25 +01:00
speculative : fix handling of some input params (#9963)
* speculative : fix batch sizes at initialization ggml-ci * speculative : handle params.n_predict == -1 * speculative : limit batch size to llama_n_batch
This commit is contained in:
parent
1db8c84fc6
commit
bc21975084
@ -39,6 +39,11 @@ int main(int argc, char ** argv) {
|
|||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (params.n_predict < -1) {
|
||||||
|
LOG_ERR("%s: --n-predict must be >= -1\n", __func__);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
common_init();
|
common_init();
|
||||||
|
|
||||||
if (params.model_draft.empty()) {
|
if (params.model_draft.empty()) {
|
||||||
@ -190,8 +195,8 @@ int main(int argc, char ** argv) {
|
|||||||
drafts[s].smpl = common_sampler_init(model_dft, params.sparams);
|
drafts[s].smpl = common_sampler_init(model_dft, params.sparams);
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
|
llama_batch batch_dft = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
|
||||||
llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, n_seq_dft);
|
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, n_seq_dft);
|
||||||
|
|
||||||
const auto t_dec_start = ggml_time_us();
|
const auto t_dec_start = ggml_time_us();
|
||||||
|
|
||||||
@ -441,7 +446,7 @@ int main(int argc, char ** argv) {
|
|||||||
++n_past_dft;
|
++n_past_dft;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (n_predict > params.n_predict || has_eos) {
|
if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user