mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-26 03:12:23 +01:00
ggml : reuse ggml_get_n_tasks() in ggml_graph_plan() (#4308)
* ggml : fix soft max out-of-bounds access ggml-ci * ggml : reuse ggml_get_n_tasks() in ggml_graph_plan() ggml-ci
This commit is contained in:
parent
adf3de4f69
commit
fbbc42827b
23
ggml.c
23
ggml.c
@ -15879,18 +15879,16 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
|
|||||||
|
|
||||||
// thread scheduling for the different operations + work buffer size estimation
|
// thread scheduling for the different operations + work buffer size estimation
|
||||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||||
int n_tasks = 1;
|
|
||||||
|
|
||||||
struct ggml_tensor * node = cgraph->nodes[i];
|
struct ggml_tensor * node = cgraph->nodes[i];
|
||||||
|
|
||||||
|
const int n_tasks = ggml_get_n_tasks(node, n_threads);
|
||||||
|
|
||||||
size_t cur = 0;
|
size_t cur = 0;
|
||||||
|
|
||||||
switch (node->op) {
|
switch (node->op) {
|
||||||
case GGML_OP_CPY:
|
case GGML_OP_CPY:
|
||||||
case GGML_OP_DUP:
|
case GGML_OP_DUP:
|
||||||
{
|
{
|
||||||
n_tasks = n_threads;
|
|
||||||
|
|
||||||
if (ggml_is_quantized(node->type)) {
|
if (ggml_is_quantized(node->type)) {
|
||||||
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
|
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
|
||||||
}
|
}
|
||||||
@ -15898,16 +15896,12 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
|
|||||||
case GGML_OP_ADD:
|
case GGML_OP_ADD:
|
||||||
case GGML_OP_ADD1:
|
case GGML_OP_ADD1:
|
||||||
{
|
{
|
||||||
n_tasks = n_threads;
|
|
||||||
|
|
||||||
if (ggml_is_quantized(node->src[0]->type)) {
|
if (ggml_is_quantized(node->src[0]->type)) {
|
||||||
cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
|
cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_ACC:
|
case GGML_OP_ACC:
|
||||||
{
|
{
|
||||||
n_tasks = n_threads;
|
|
||||||
|
|
||||||
if (ggml_is_quantized(node->src[0]->type)) {
|
if (ggml_is_quantized(node->src[0]->type)) {
|
||||||
cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
|
cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
|
||||||
}
|
}
|
||||||
@ -15935,16 +15929,12 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_OUT_PROD:
|
case GGML_OP_OUT_PROD:
|
||||||
{
|
{
|
||||||
n_tasks = n_threads;
|
|
||||||
|
|
||||||
if (ggml_is_quantized(node->src[0]->type)) {
|
if (ggml_is_quantized(node->src[0]->type)) {
|
||||||
cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
|
cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
{
|
{
|
||||||
n_tasks = MIN(MIN(4, n_threads), ggml_nrows(node->src[0]));
|
|
||||||
|
|
||||||
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
|
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||||
@ -15974,7 +15964,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_IM2COL:
|
case GGML_OP_IM2COL:
|
||||||
{
|
{
|
||||||
n_tasks = n_threads;
|
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||||
{
|
{
|
||||||
@ -15992,8 +15981,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_FLASH_ATTN:
|
case GGML_OP_FLASH_ATTN:
|
||||||
{
|
{
|
||||||
n_tasks = n_threads;
|
|
||||||
|
|
||||||
const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
|
const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
|
||||||
|
|
||||||
if (node->src[1]->type == GGML_TYPE_F32) {
|
if (node->src[1]->type == GGML_TYPE_F32) {
|
||||||
@ -16006,8 +15993,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_FLASH_FF:
|
case GGML_OP_FLASH_FF:
|
||||||
{
|
{
|
||||||
n_tasks = n_threads;
|
|
||||||
|
|
||||||
if (node->src[1]->type == GGML_TYPE_F32) {
|
if (node->src[1]->type == GGML_TYPE_F32) {
|
||||||
cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
|
cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
|
||||||
cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
|
cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
|
||||||
@ -16018,8 +16003,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_FLASH_ATTN_BACK:
|
case GGML_OP_FLASH_ATTN_BACK:
|
||||||
{
|
{
|
||||||
n_tasks = n_threads;
|
|
||||||
|
|
||||||
const int64_t D = node->src[0]->ne[0];
|
const int64_t D = node->src[0]->ne[0];
|
||||||
const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
|
const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
|
||||||
const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
|
const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
|
||||||
@ -16034,8 +16017,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
|
|||||||
|
|
||||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||||
{
|
{
|
||||||
n_tasks = n_threads;
|
|
||||||
|
|
||||||
cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
|
cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_COUNT:
|
case GGML_OP_COUNT:
|
||||||
|
Loading…
Reference in New Issue
Block a user