correct ggml_backend_tensor_copy

ngxson 2024-07-06 15:06:32 +02:00
parent 1b4ffbac47
commit b88ce0f892


@@ -18446,9 +18446,10 @@ int32_t llama_lora_adapter_apply(struct llama_context * lctx, struct llama_lora_
     };
     struct ggml_context * ctx0 = ggml_init(ctx0_params);
 
+    // map "merged.%s" name to model tensor
+    std::map<std::string, struct ggml_tensor *> output_map;
     // apply lora for model tensors
     struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
-    std::vector<std::pair<struct ggml_tensor *, struct ggml_tensor *>> output_nodes;
     auto apply_lora = [&](struct llama_lora_adapter * adapter, struct ggml_tensor * model_tensor) {
         if (model_tensor == nullptr) {
             return;
@@ -18459,9 +18460,9 @@ int32_t llama_lora_adapter_apply(struct llama_context * lctx, struct llama_lora_
             struct ggml_tensor * cur = ggml_mul_mat(ctx0, lora_w.a, lora_w.b);
             cur = ggml_scale_inplace(ctx0, cur, adapter->scale);
             cur = ggml_add(ctx0, cur, model_tensor);
-            ggml_format_name(cur, "%s.merged", name.c_str());
+            ggml_format_name(cur, "merged.%s", name.c_str());
             ggml_build_forward_expand(gf, cur);
-            output_nodes.push_back({model_tensor, cur});
+            output_map[std::string(cur->name)] = model_tensor;
         }
     };
     apply_lora(adapter, model.tok_embd);
@@ -18543,13 +18544,19 @@ int32_t llama_lora_adapter_apply(struct llama_context * lctx, struct llama_lora_
     // merge lora to model weight
    ggml_status res = ggml_backend_sched_graph_compute(lctx->sched, gf);
     if (res == GGML_STATUS_SUCCESS) {
-        for (auto & out : output_nodes) {
-            struct ggml_tensor * model_tensor = out.first;
-            struct ggml_tensor * merged_tensor = out.second;
-            ggml_backend_tensor_copy(merged_tensor, model_tensor);
-            ggml_set_name(model_tensor, merged_tensor->name);
+        // the graph may have been reallocated, so look up the correct gf->nodes by name
+        size_t n_merged = 0;
+        for (int i = 0; i < gf->n_nodes; ++i) {
+            auto node = gf->nodes[i];
+            std::string name(node->name);
+            if (output_map.find(name) != output_map.end()) {
+                struct ggml_tensor * model_tensor = output_map[name];
+                ggml_backend_tensor_copy(node, model_tensor);
+                n_merged++;
+            }
         }
-        LLAMA_LOG_ERROR("%s: merged %ld lora weights to model\n", __func__, output_nodes.size());
+        GGML_ASSERT(n_merged == output_map.size());
+        LLAMA_LOG_ERROR("%s: merged %ld lora weights to model\n", __func__, n_merged);
     } else {
         LLAMA_LOG_ERROR("%s: compute error while merging lora weights to model, result = %d\n", __func__, res);
         return res;