mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 14:20:31 +01:00
feat: remove a sampler from a chain (#9445)
* feat: remove a sampler from a chain * fix: return removed sampler * fix: safer casting
This commit is contained in:
parent
78203641fe
commit
bd35cb0ae3
@ -1056,6 +1056,9 @@ extern "C" {
|
|||||||
LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i);
|
LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i);
|
||||||
LLAMA_API int llama_sampler_chain_n (const struct llama_sampler * chain);
|
LLAMA_API int llama_sampler_chain_n (const struct llama_sampler * chain);
|
||||||
|
|
||||||
|
// after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed
|
||||||
|
LLAMA_API struct llama_sampler * llama_sampler_chain_remove( struct llama_sampler * chain, int32_t i);
|
||||||
|
|
||||||
// available samplers:
|
// available samplers:
|
||||||
|
|
||||||
LLAMA_API struct llama_sampler * llama_sampler_init_greedy (void);
|
LLAMA_API struct llama_sampler * llama_sampler_init_greedy (void);
|
||||||
|
@ -349,13 +349,26 @@ void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler
|
|||||||
struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
|
struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
|
||||||
const auto * p = (const llama_sampler_chain *) chain->ctx;
|
const auto * p = (const llama_sampler_chain *) chain->ctx;
|
||||||
|
|
||||||
if (i < 0 || i >= (int32_t) p->samplers.size()) {
|
if (i < 0 || (size_t) i >= p->samplers.size()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
return p->samplers[i];
|
return p->samplers[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) {
|
||||||
|
auto * p = (llama_sampler_chain *) chain->ctx;
|
||||||
|
|
||||||
|
if (i < 0 || (size_t) i >= p->samplers.size()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto * result = p->samplers[i];
|
||||||
|
p->samplers.erase(p->samplers.begin() + i);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
int llama_sampler_chain_n(const struct llama_sampler * chain) {
|
int llama_sampler_chain_n(const struct llama_sampler * chain) {
|
||||||
const auto * p = (const llama_sampler_chain *) chain->ctx;
|
const auto * p = (const llama_sampler_chain *) chain->ctx;
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user