mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-16 15:18:26 +01:00
ggml : implement soft_max_ext (CPU)
This commit is contained in:
parent
88519fbf97
commit
6a66f69f9f
44
ggml.c
44
ggml.c
@ -4829,7 +4829,9 @@ static struct ggml_tensor * ggml_soft_max_impl(
|
|||||||
struct ggml_tensor * mask,
|
struct ggml_tensor * mask,
|
||||||
float scale,
|
float scale,
|
||||||
bool inplace) {
|
bool inplace) {
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(a));
|
||||||
if (mask) {
|
if (mask) {
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(mask));
|
||||||
GGML_ASSERT(mask->ne[2] == 1);
|
GGML_ASSERT(mask->ne[2] == 1);
|
||||||
GGML_ASSERT(mask->ne[3] == 1);
|
GGML_ASSERT(mask->ne[3] == 1);
|
||||||
GGML_ASSERT(ggml_can_repeat_rows(mask, a));
|
GGML_ASSERT(ggml_can_repeat_rows(mask, a));
|
||||||
@ -10571,20 +10573,25 @@ static void ggml_compute_forward_diag_mask_zero(
|
|||||||
static void ggml_compute_forward_soft_max_f32(
|
static void ggml_compute_forward_soft_max_f32(
|
||||||
const struct ggml_compute_params * params,
|
const struct ggml_compute_params * params,
|
||||||
const struct ggml_tensor * src0,
|
const struct ggml_tensor * src0,
|
||||||
|
const struct ggml_tensor * src1,
|
||||||
struct ggml_tensor * dst) {
|
struct ggml_tensor * dst) {
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
assert(ggml_is_contiguous(dst));
|
||||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
assert(ggml_are_same_shape(src0, dst));
|
||||||
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
|
||||||
|
|
||||||
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
float scale = 1.0f;
|
||||||
|
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
||||||
|
|
||||||
// TODO: handle transposed/permuted matrices
|
// TODO: handle transposed/permuted matrices
|
||||||
|
|
||||||
const int ith = params->ith;
|
const int ith = params->ith;
|
||||||
const int nth = params->nth;
|
const int nth = params->nth;
|
||||||
|
|
||||||
|
const int64_t ne11 = src1 ? src1->ne[1] : 1;
|
||||||
|
|
||||||
const int nc = src0->ne[0];
|
const int nc = src0->ne[0];
|
||||||
const int nr = ggml_nrows(src0);
|
const int nr = ggml_nrows(src0);
|
||||||
|
|
||||||
@ -10595,29 +10602,39 @@ static void ggml_compute_forward_soft_max_f32(
|
|||||||
const int ir0 = dr*ith;
|
const int ir0 = dr*ith;
|
||||||
const int ir1 = MIN(ir0 + dr, nr);
|
const int ir1 = MIN(ir0 + dr, nr);
|
||||||
|
|
||||||
|
float * wdata = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
|
||||||
|
|
||||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||||
float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
|
float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
|
||||||
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
|
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
|
||||||
|
|
||||||
|
// broadcast the mask across rows
|
||||||
|
float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
|
||||||
|
|
||||||
|
float * wp = wdata;
|
||||||
|
for (int i = 0; i < nc; i++) {
|
||||||
|
wp[i] = sp[i]*scale + (mp ? mp[i] : 0.0f);
|
||||||
|
}
|
||||||
|
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
for (int i = 0; i < nc; ++i) {
|
for (int i = 0; i < nc; ++i) {
|
||||||
//printf("p[%d] = %f\n", i, p[i]);
|
//printf("p[%d] = %f\n", i, p[i]);
|
||||||
assert(!isnan(sp[i]));
|
assert(!isnan(wp[i]));
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
float max = -INFINITY;
|
float max = -INFINITY;
|
||||||
ggml_vec_max_f32(nc, &max, sp);
|
ggml_vec_max_f32(nc, &max, wp);
|
||||||
|
|
||||||
ggml_float sum = 0.0;
|
ggml_float sum = 0.0;
|
||||||
|
|
||||||
uint16_t scvt;
|
uint16_t scvt;
|
||||||
for (int i = 0; i < nc; i++) {
|
for (int i = 0; i < nc; i++) {
|
||||||
if (sp[i] == -INFINITY) {
|
if (wp[i] == -INFINITY) {
|
||||||
dp[i] = 0.0f;
|
dp[i] = 0.0f;
|
||||||
} else {
|
} else {
|
||||||
// const float val = (sp[i] == -INFINITY) ? 0.0 : exp(sp[i] - max);
|
// const float val = (wp[i] == -INFINITY) ? 0.0 : exp(wp[i] - max);
|
||||||
ggml_fp16_t s = GGML_FP32_TO_FP16(sp[i] - max);
|
ggml_fp16_t s = GGML_FP32_TO_FP16(wp[i] - max);
|
||||||
memcpy(&scvt, &s, sizeof(scvt));
|
memcpy(&scvt, &s, sizeof(scvt));
|
||||||
const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
|
const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
|
||||||
sum += (ggml_float)val;
|
sum += (ggml_float)val;
|
||||||
@ -10642,11 +10659,12 @@ static void ggml_compute_forward_soft_max_f32(
|
|||||||
static void ggml_compute_forward_soft_max(
|
static void ggml_compute_forward_soft_max(
|
||||||
const struct ggml_compute_params * params,
|
const struct ggml_compute_params * params,
|
||||||
const struct ggml_tensor * src0,
|
const struct ggml_tensor * src0,
|
||||||
|
const struct ggml_tensor * src1,
|
||||||
struct ggml_tensor * dst) {
|
struct ggml_tensor * dst) {
|
||||||
switch (src0->type) {
|
switch (src0->type) {
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
{
|
{
|
||||||
ggml_compute_forward_soft_max_f32(params, src0, dst);
|
ggml_compute_forward_soft_max_f32(params, src0, src1, dst);
|
||||||
} break;
|
} break;
|
||||||
default:
|
default:
|
||||||
{
|
{
|
||||||
@ -13883,7 +13901,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
{
|
{
|
||||||
ggml_compute_forward_soft_max(params, tensor->src[0], tensor);
|
ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor);
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_SOFT_MAX_BACK:
|
case GGML_OP_SOFT_MAX_BACK:
|
||||||
{
|
{
|
||||||
@ -15919,6 +15937,12 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
|
|||||||
cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
|
cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_SOFT_MAX:
|
||||||
|
{
|
||||||
|
n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));
|
||||||
|
|
||||||
|
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
|
||||||
|
} break;
|
||||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(node->src[0]->ne[3] == 1);
|
GGML_ASSERT(node->src[0]->ne[3] == 1);
|
||||||
|
Loading…
Reference in New Issue
Block a user