simple : do not perform tensor data copy if not needed

This commit is contained in:
Georgi Gerganov 2024-01-15 16:42:16 +02:00
parent 83f3d7a83c
commit e1b1db9f09
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -22,12 +22,20 @@ static bool observe_compute(struct ggml_tensor * t, bool ask, void * user_data)
__func__, t->name, ggml_op_name(t->op), (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]); __func__, t->name, ggml_op_name(t->op), (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]);
// this will copy the data to host memory (if needed) // this will copy the data to host memory (if needed)
std::vector<float> t_data(ggml_nelements(t)); static std::vector<float> t_data;
ggml_backend_tensor_get(t, t_data.data(), 0, ggml_nbytes(t));
const bool is_host = ggml_backend_buffer_is_host(t->buffer);
if (!is_host || ggml_is_contiguous(t)) {
t_data.resize(ggml_nelements(t));
ggml_backend_tensor_get(t, t_data.data(), 0, ggml_nbytes(t));
}
const float * data = is_host ? (const float *) t->data : t_data.data();
// print first row // print first row
for (int i = 0; i < t->ne[0]; i++) { for (int i = 0; i < t->ne[0]; i++) {
printf("%8.4f ", t_data[i]); printf("%8.4f ", data[i]);
} }
printf("\n"); printf("\n");