mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-01 00:39:00 +01:00
ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext
This commit is contained in:
parent
2ddc9bbef1
commit
8ad92dc1ec
20
ggml-cuda.cu
20
ggml-cuda.cu
@ -5917,7 +5917,7 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <bool vals_smem, int ncols_template, int block_size_template, bool need_check>
|
template <bool vals_smem, int ncols_template, int block_size_template, bool need_check>
|
||||||
static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
|
static __global__ void soft_max_f16(const float * x, const half * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
|
||||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
|
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
|
||||||
const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
|
const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
|
||||||
const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2;
|
const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2;
|
||||||
@ -5952,12 +5952,12 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds
|
|||||||
if (need_check && col_data + 0 >= ncols_data) {
|
if (need_check && col_data + 0 >= ncols_data) {
|
||||||
val.x = -INFINITY;
|
val.x = -INFINITY;
|
||||||
} else {
|
} else {
|
||||||
val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f);
|
val.x = x[ix + 0]*scale + (y ? __half2float(y[iy + 0]) : 0.0f);
|
||||||
}
|
}
|
||||||
if (need_check && col_data + WARP_SIZE >= ncols_data) {
|
if (need_check && col_data + WARP_SIZE >= ncols_data) {
|
||||||
val.y = -INFINITY;
|
val.y = -INFINITY;
|
||||||
} else {
|
} else {
|
||||||
val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f);
|
val.y = x[ix + WARP_SIZE]*scale + (y ? __half2float(y[iy + WARP_SIZE]) : 0.0f);
|
||||||
}
|
}
|
||||||
if (!need_check || col_smem < (vals_smem ? ncols_smem : ncols_data)) {
|
if (!need_check || col_smem < (vals_smem ? ncols_smem : ncols_data)) {
|
||||||
vals[col_smem] = val;
|
vals[col_smem] = val;
|
||||||
@ -6047,7 +6047,7 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <bool vals_smem, int ncols_template, int block_size_template>
|
template <bool vals_smem, int ncols_template, int block_size_template>
|
||||||
static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
|
static __global__ void soft_max_f32(const float * x, const half * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
|
||||||
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
|
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
|
||||||
|
|
||||||
const int tid = threadIdx.x;
|
const int tid = threadIdx.x;
|
||||||
@ -6077,7 +6077,7 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
|
|||||||
const int ix = rowx*ncols + col;
|
const int ix = rowx*ncols + col;
|
||||||
const int iy = rowy*ncols + col;
|
const int iy = rowy*ncols + col;
|
||||||
|
|
||||||
const float val = x[ix]*scale + (y ? y[iy] : 0.0f);
|
const float val = x[ix]*scale + (y ? __half2float(y[iy]) : 0.0f);
|
||||||
vals[col] = val;
|
vals[col] = val;
|
||||||
max_val = max(max_val, val);
|
max_val = max(max_val, val);
|
||||||
}
|
}
|
||||||
@ -7585,7 +7585,7 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
|
|||||||
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
|
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
|
static void soft_max_f16_cuda(const float * x, const half * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
|
||||||
int nth = WARP_SIZE;
|
int nth = WARP_SIZE;
|
||||||
while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
||||||
const dim3 block_dims(nth, 1, 1);
|
const dim3 block_dims(nth, 1, 1);
|
||||||
@ -7628,7 +7628,7 @@ static void soft_max_f16_cuda(const float * x, const float * y, float * dst, con
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
|
static void soft_max_f32_cuda(const float * x, const half * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
|
||||||
int nth = WARP_SIZE;
|
int nth = WARP_SIZE;
|
||||||
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
||||||
const dim3 block_dims(nth, 1, 1);
|
const dim3 block_dims(nth, 1, 1);
|
||||||
@ -9060,7 +9060,7 @@ static void ggml_cuda_op_soft_max(
|
|||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
|
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16); // src1 contains mask and it is optional
|
||||||
|
|
||||||
const int64_t ne00 = src0->ne[0];
|
const int64_t ne00 = src0->ne[0];
|
||||||
const int64_t nrows_x = ggml_nrows(src0);
|
const int64_t nrows_x = ggml_nrows(src0);
|
||||||
@ -9080,9 +9080,9 @@ static void ggml_cuda_op_soft_max(
|
|||||||
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && CUDART_VERSION >= CUDART_HMAX
|
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && CUDART_VERSION >= CUDART_HMAX
|
||||||
|
|
||||||
if (use_f16_soft_max) {
|
if (use_f16_soft_max) {
|
||||||
soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
|
soft_max_f16_cuda(src0_dd, src1 ? (const half *) src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
|
||||||
} else {
|
} else {
|
||||||
soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
|
soft_max_f32_cuda(src0_dd, src1 ? (const half *) src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
(void) dst;
|
(void) dst;
|
||||||
|
@ -1187,6 +1187,8 @@ static bool ggml_metal_graph_compute(
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
{
|
{
|
||||||
|
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16);
|
||||||
|
|
||||||
int nth = 32; // SIMD width
|
int nth = 32; // SIMD width
|
||||||
|
|
||||||
id<MTLComputePipelineState> pipeline = nil;
|
id<MTLComputePipelineState> pipeline = nil;
|
||||||
@ -2213,6 +2215,10 @@ static bool ggml_metal_graph_compute(
|
|||||||
|
|
||||||
id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
|
id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
|
||||||
|
|
||||||
|
GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
|
||||||
|
GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
|
||||||
|
"the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
|
||||||
|
|
||||||
const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
|
const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
|
||||||
const int64_t ne31 = src3 ? src3->ne[1] : 0;
|
const int64_t ne31 = src3 ? src3->ne[1] : 0;
|
||||||
const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
|
const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
|
||||||
|
@ -349,9 +349,9 @@ kernel void kernel_sum_rows(
|
|||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_soft_max(
|
kernel void kernel_soft_max(
|
||||||
device const float * src0,
|
device const char * src0,
|
||||||
device const float * src1,
|
device const char * src1,
|
||||||
device float * dst,
|
device char * dst,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
@ -366,9 +366,9 @@ kernel void kernel_soft_max(
|
|||||||
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
||||||
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
||||||
|
|
||||||
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||||
device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
|
device const half * pmask = src1 != src0 ? (device const half *) src1 + i01*ne00 : nullptr;
|
||||||
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||||
|
|
||||||
// parallel max
|
// parallel max
|
||||||
float lmax = -INFINITY;
|
float lmax = -INFINITY;
|
||||||
@ -435,9 +435,9 @@ kernel void kernel_soft_max(
|
|||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_soft_max_4(
|
kernel void kernel_soft_max_4(
|
||||||
device const float * src0,
|
device const char * src0,
|
||||||
device const float * src1,
|
device const char * src1,
|
||||||
device float * dst,
|
device char * dst,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
@ -452,15 +452,15 @@ kernel void kernel_soft_max_4(
|
|||||||
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
||||||
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
||||||
|
|
||||||
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
||||||
device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
|
device const half4 * pmask = src1 != src0 ? (device const half4 *) src1 + i01*ne00/4 : nullptr;
|
||||||
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
||||||
|
|
||||||
// parallel max
|
// parallel max
|
||||||
float4 lmax4 = -INFINITY;
|
float4 lmax4 = -INFINITY;
|
||||||
|
|
||||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||||
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4) (pmask ? pmask[i00] : 0.0f));
|
||||||
}
|
}
|
||||||
|
|
||||||
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
||||||
@ -486,7 +486,7 @@ kernel void kernel_soft_max_4(
|
|||||||
// parallel sum
|
// parallel sum
|
||||||
float4 lsum4 = 0.0f;
|
float4 lsum4 = 0.0f;
|
||||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||||
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
|
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4) (pmask ? pmask[i00] : 0.0f)) - max_val);
|
||||||
lsum4 += exp_psrc4;
|
lsum4 += exp_psrc4;
|
||||||
pdst4[i00] = exp_psrc4;
|
pdst4[i00] = exp_psrc4;
|
||||||
}
|
}
|
||||||
@ -2144,13 +2144,11 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
|
|
||||||
|
|
||||||
// pointer to the mask
|
// pointer to the mask
|
||||||
device const float * mp = (device const float *) (mask + (ir%ne31)*nb31);
|
device const half * mp = (device const half *) (mask + iq1*nb31);
|
||||||
|
|
||||||
// prepare diagonal scale matrix
|
// prepare diagonal scale matrix
|
||||||
simdgroup_float8x8 mscale(scale);
|
simdgroup_half8x8 mscale(scale);
|
||||||
|
|
||||||
// loop over the KV cache
|
// loop over the KV cache
|
||||||
// each simdgroup handles blocks of Q rows and C columns
|
// each simdgroup handles blocks of Q rows and C columns
|
||||||
@ -2176,8 +2174,8 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
|
|
||||||
// mqk = mqk*scale + mask
|
// mqk = mqk*scale + mask
|
||||||
for (int64_t j = 0; j < Q8; ++j) {
|
for (int64_t j = 0; j < Q8; ++j) {
|
||||||
simdgroup_float8x8 mm;
|
simdgroup_half8x8 mm;
|
||||||
simdgroup_load(mm, mp + 8*j*(nb31/sizeof(float)) + ic + 8*cc, nb31/sizeof(float), 0, false);
|
simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false);
|
||||||
simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm);
|
simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm);
|
||||||
|
|
||||||
simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false);
|
simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false);
|
||||||
|
13
ggml.c
13
ggml.c
@ -5085,6 +5085,7 @@ static struct ggml_tensor * ggml_soft_max_impl(
|
|||||||
bool inplace) {
|
bool inplace) {
|
||||||
GGML_ASSERT(ggml_is_contiguous(a));
|
GGML_ASSERT(ggml_is_contiguous(a));
|
||||||
if (mask) {
|
if (mask) {
|
||||||
|
GGML_ASSERT(mask->type == GGML_TYPE_F16);
|
||||||
GGML_ASSERT(ggml_is_contiguous(mask));
|
GGML_ASSERT(ggml_is_contiguous(mask));
|
||||||
GGML_ASSERT(mask->ne[2] == 1);
|
GGML_ASSERT(mask->ne[2] == 1);
|
||||||
GGML_ASSERT(mask->ne[3] == 1);
|
GGML_ASSERT(mask->ne[3] == 1);
|
||||||
@ -5854,6 +5855,8 @@ struct ggml_tensor * ggml_flash_attn_ext(
|
|||||||
GGML_ASSERT(ggml_is_contiguous(mask));
|
GGML_ASSERT(ggml_is_contiguous(mask));
|
||||||
GGML_ASSERT(mask->ne[2] == 1);
|
GGML_ASSERT(mask->ne[2] == 1);
|
||||||
GGML_ASSERT(mask->ne[3] == 1);
|
GGML_ASSERT(mask->ne[3] == 1);
|
||||||
|
GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
|
||||||
|
"the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
|
||||||
//GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
|
//GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -11552,12 +11555,14 @@ static void ggml_compute_forward_soft_max_f32(
|
|||||||
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
|
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
|
||||||
|
|
||||||
// broadcast the mask across rows
|
// broadcast the mask across rows
|
||||||
float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
|
ggml_fp16_t * mp = src1 ? (ggml_fp16_t *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
|
||||||
|
|
||||||
ggml_vec_cpy_f32 (nc, wp, sp);
|
ggml_vec_cpy_f32 (nc, wp, sp);
|
||||||
ggml_vec_scale_f32(nc, wp, scale);
|
ggml_vec_scale_f32(nc, wp, scale);
|
||||||
if (mp) {
|
if (mp) {
|
||||||
ggml_vec_acc_f32(nc, wp, mp);
|
for (int i = 0; i < nc; ++i) {
|
||||||
|
wp[i] += GGML_FP16_TO_FP32(mp[i]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
@ -13760,7 +13765,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|||||||
|
|
||||||
memset(V16, 0, D*sizeof(ggml_fp16_t));
|
memset(V16, 0, D*sizeof(ggml_fp16_t));
|
||||||
|
|
||||||
const float * mp = mask ? (float *)((char *) mask->data + (ir%mask->ne[1])*mask->nb[1]) : NULL;
|
const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
|
||||||
|
|
||||||
// k indices
|
// k indices
|
||||||
const int ik3 = iq3 / rk3;
|
const int ik3 = iq3 / rk3;
|
||||||
@ -13774,7 +13779,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|||||||
// loop over n_kv and n_head_kv
|
// loop over n_kv and n_head_kv
|
||||||
// ref: https://arxiv.org/pdf/2112.05682.pdf
|
// ref: https://arxiv.org/pdf/2112.05682.pdf
|
||||||
for (int64_t ic = 0; ic < nek1; ++ic) {
|
for (int64_t ic = 0; ic < nek1; ++ic) {
|
||||||
const float mv = mp ? mp[ic] : 0.0f;
|
const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
|
||||||
if (mv == -INFINITY) {
|
if (mv == -INFINITY) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
4
ggml.h
4
ggml.h
@ -1646,10 +1646,12 @@ extern "C" {
|
|||||||
struct ggml_tensor * v,
|
struct ggml_tensor * v,
|
||||||
bool masked);
|
bool masked);
|
||||||
|
|
||||||
|
#define GGML_KQ_MASK_PAD 32
|
||||||
|
|
||||||
// q: [n_embd, n_batch, n_head, 1]
|
// q: [n_embd, n_batch, n_head, 1]
|
||||||
// k: [n_embd, n_kv, n_head_kv, 1]
|
// k: [n_embd, n_kv, n_head_kv, 1]
|
||||||
// v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !!
|
// v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !!
|
||||||
// mask: [n_kv, n_batch, 1, 1]
|
// mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
|
||||||
// res: [n_embd, n_head, n_batch, 1] !! permuted !!
|
// res: [n_embd, n_head, n_batch, 1] !! permuted !!
|
||||||
GGML_API struct ggml_tensor * ggml_flash_attn_ext(
|
GGML_API struct ggml_tensor * ggml_flash_attn_ext(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
|
40
llama.cpp
40
llama.cpp
@ -4721,7 +4721,7 @@ struct llm_build_context {
|
|||||||
cb(inp_pos, "inp_pos", -1);
|
cb(inp_pos, "inp_pos", -1);
|
||||||
|
|
||||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
|
struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16);
|
||||||
cb(KQ_mask, "KQ_mask", -1);
|
cb(KQ_mask, "KQ_mask", -1);
|
||||||
|
|
||||||
// shift the entire K-cache if needed
|
// shift the entire K-cache if needed
|
||||||
@ -4905,7 +4905,7 @@ struct llm_build_context {
|
|||||||
cb(inp_pos, "inp_pos", -1);
|
cb(inp_pos, "inp_pos", -1);
|
||||||
|
|
||||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
|
struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16);
|
||||||
cb(KQ_mask, "KQ_mask", -1);
|
cb(KQ_mask, "KQ_mask", -1);
|
||||||
|
|
||||||
// shift the entire K-cache if needed
|
// shift the entire K-cache if needed
|
||||||
@ -5026,7 +5026,7 @@ struct llm_build_context {
|
|||||||
cb(inp_pos, "inp_pos", -1);
|
cb(inp_pos, "inp_pos", -1);
|
||||||
|
|
||||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
|
struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16);
|
||||||
cb(KQ_mask, "KQ_mask", -1);
|
cb(KQ_mask, "KQ_mask", -1);
|
||||||
|
|
||||||
// shift the entire K-cache if needed
|
// shift the entire K-cache if needed
|
||||||
@ -5148,7 +5148,7 @@ struct llm_build_context {
|
|||||||
cb(inp_pos, "inp_pos", -1);
|
cb(inp_pos, "inp_pos", -1);
|
||||||
|
|
||||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
|
struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16);
|
||||||
cb(KQ_mask, "KQ_mask", -1);
|
cb(KQ_mask, "KQ_mask", -1);
|
||||||
|
|
||||||
pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
|
pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
|
||||||
@ -5245,7 +5245,7 @@ struct llm_build_context {
|
|||||||
cb(inp_pos, "inp_pos", -1);
|
cb(inp_pos, "inp_pos", -1);
|
||||||
|
|
||||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
|
struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16);
|
||||||
cb(KQ_mask, "KQ_mask", -1);
|
cb(KQ_mask, "KQ_mask", -1);
|
||||||
|
|
||||||
if (do_rope_shift) {
|
if (do_rope_shift) {
|
||||||
@ -5448,7 +5448,7 @@ struct llm_build_context {
|
|||||||
cb(inpL, "inp_embd", -1);
|
cb(inpL, "inp_embd", -1);
|
||||||
|
|
||||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
|
struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16);
|
||||||
cb(KQ_mask, "KQ_mask", -1);
|
cb(KQ_mask, "KQ_mask", -1);
|
||||||
|
|
||||||
for (int il = 0; il < n_layer; ++il) {
|
for (int il = 0; il < n_layer; ++il) {
|
||||||
@ -5538,7 +5538,7 @@ struct llm_build_context {
|
|||||||
cb(inpL, "inp_embd", -1);
|
cb(inpL, "inp_embd", -1);
|
||||||
|
|
||||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
|
struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16);
|
||||||
cb(KQ_mask, "KQ_mask", -1);
|
cb(KQ_mask, "KQ_mask", -1);
|
||||||
|
|
||||||
inpL = llm_build_norm(ctx0, inpL, hparams,
|
inpL = llm_build_norm(ctx0, inpL, hparams,
|
||||||
@ -5631,7 +5631,7 @@ struct llm_build_context {
|
|||||||
cb(inpL, "inp_embd", -1);
|
cb(inpL, "inp_embd", -1);
|
||||||
|
|
||||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
|
struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16);
|
||||||
cb(KQ_mask, "KQ_mask", -1);
|
cb(KQ_mask, "KQ_mask", -1);
|
||||||
|
|
||||||
for (int il = 0; il < n_layer; ++il) {
|
for (int il = 0; il < n_layer; ++il) {
|
||||||
@ -5731,7 +5731,7 @@ struct llm_build_context {
|
|||||||
cb(inp_pos, "inp_pos", -1);
|
cb(inp_pos, "inp_pos", -1);
|
||||||
|
|
||||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
|
struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16);
|
||||||
cb(KQ_mask, "KQ_mask", -1);
|
cb(KQ_mask, "KQ_mask", -1);
|
||||||
|
|
||||||
// shift the entire K-cache if needed
|
// shift the entire K-cache if needed
|
||||||
@ -5854,7 +5854,7 @@ struct llm_build_context {
|
|||||||
cb(inp_pos, "inp_pos", -1);
|
cb(inp_pos, "inp_pos", -1);
|
||||||
|
|
||||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
|
struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16);
|
||||||
cb(KQ_mask, "KQ_mask", -1);
|
cb(KQ_mask, "KQ_mask", -1);
|
||||||
|
|
||||||
// shift the entire K-cache if needed
|
// shift the entire K-cache if needed
|
||||||
@ -5968,7 +5968,7 @@ struct llm_build_context {
|
|||||||
cb(inp_pos, "inp_pos", -1);
|
cb(inp_pos, "inp_pos", -1);
|
||||||
|
|
||||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
|
struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16);
|
||||||
cb(KQ_mask, "KQ_mask", -1);
|
cb(KQ_mask, "KQ_mask", -1);
|
||||||
|
|
||||||
// shift the entire K-cache if needed
|
// shift the entire K-cache if needed
|
||||||
@ -6089,7 +6089,7 @@ struct llm_build_context {
|
|||||||
cb(inp_pos, "inp_pos", -1);
|
cb(inp_pos, "inp_pos", -1);
|
||||||
|
|
||||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
|
struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16);
|
||||||
cb(KQ_mask, "KQ_mask", -1);
|
cb(KQ_mask, "KQ_mask", -1);
|
||||||
|
|
||||||
// shift the entire K-cache if needed
|
// shift the entire K-cache if needed
|
||||||
@ -6211,7 +6211,7 @@ struct llm_build_context {
|
|||||||
cb(inp_pos, "inp_pos", -1);
|
cb(inp_pos, "inp_pos", -1);
|
||||||
|
|
||||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
|
struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16);
|
||||||
cb(KQ_mask, "KQ_mask", -1);
|
cb(KQ_mask, "KQ_mask", -1);
|
||||||
|
|
||||||
// shift the entire K-cache if needed
|
// shift the entire K-cache if needed
|
||||||
@ -6318,7 +6318,7 @@ struct llm_build_context {
|
|||||||
cb(inp_pos, "inp_pos", -1);
|
cb(inp_pos, "inp_pos", -1);
|
||||||
|
|
||||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
|
struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16);
|
||||||
cb(KQ_mask, "KQ_mask", -1);
|
cb(KQ_mask, "KQ_mask", -1);
|
||||||
|
|
||||||
pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
|
pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
|
||||||
@ -6416,7 +6416,7 @@ struct llm_build_context {
|
|||||||
cb(inp_pos, "inp_pos", -1);
|
cb(inp_pos, "inp_pos", -1);
|
||||||
|
|
||||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
|
struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16);
|
||||||
cb(KQ_mask, "KQ_mask", -1);
|
cb(KQ_mask, "KQ_mask", -1);
|
||||||
|
|
||||||
// shift the entire K-cache if needed
|
// shift the entire K-cache if needed
|
||||||
@ -6524,7 +6524,7 @@ struct llm_build_context {
|
|||||||
cb(inp_pos, "inp_pos", -1);
|
cb(inp_pos, "inp_pos", -1);
|
||||||
|
|
||||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
|
struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16);
|
||||||
cb(KQ_mask, "KQ_mask", -1);
|
cb(KQ_mask, "KQ_mask", -1);
|
||||||
|
|
||||||
// shift the entire K-cache if needed
|
// shift the entire K-cache if needed
|
||||||
@ -10250,7 +10250,10 @@ struct llama_context * llama_new_context_with_model(
|
|||||||
const auto & hparams = model->hparams;
|
const auto & hparams = model->hparams;
|
||||||
auto & cparams = ctx->cparams;
|
auto & cparams = ctx->cparams;
|
||||||
|
|
||||||
cparams.n_batch = params.n_batch;
|
// the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
|
||||||
|
// this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
|
||||||
|
cparams.n_batch = std::max((uint32_t) GGML_KQ_MASK_PAD, params.n_batch);
|
||||||
|
|
||||||
cparams.n_threads = params.n_threads;
|
cparams.n_threads = params.n_threads;
|
||||||
cparams.n_threads_batch = params.n_threads_batch;
|
cparams.n_threads_batch = params.n_threads_batch;
|
||||||
cparams.yarn_ext_factor = params.yarn_ext_factor;
|
cparams.yarn_ext_factor = params.yarn_ext_factor;
|
||||||
@ -10430,6 +10433,9 @@ struct llama_context * llama_new_context_with_model(
|
|||||||
|
|
||||||
ctx->buf_input = ggml_backend_alloc_ctx_tensors_from_buft(ctx->ctx_input, llama_default_buffer_type_cpu(true));
|
ctx->buf_input = ggml_backend_alloc_ctx_tensors_from_buft(ctx->ctx_input, llama_default_buffer_type_cpu(true));
|
||||||
|
|
||||||
|
// zero-out the input buffer to prevent NaNs in padded tensors
|
||||||
|
ggml_backend_buffer_clear(ctx->buf_input, 0);
|
||||||
|
|
||||||
LLAMA_LOG_INFO("%s: %10s input buffer size = %8.2f MiB\n", __func__,
|
LLAMA_LOG_INFO("%s: %10s input buffer size = %8.2f MiB\n", __func__,
|
||||||
ggml_backend_buffer_name(ctx->buf_input),
|
ggml_backend_buffer_name(ctx->buf_input),
|
||||||
ggml_backend_buffer_get_size(ctx->buf_input) / 1024.0 / 1024.0);
|
ggml_backend_buffer_get_size(ctx->buf_input) / 1024.0 / 1024.0);
|
||||||
|
@ -1101,7 +1101,7 @@ struct test_soft_max : public test_case {
|
|||||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||||
ggml_tensor * b = nullptr;
|
ggml_tensor * b = nullptr;
|
||||||
if (mask) { b = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]); }
|
if (mask) { b = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, ne[0], ne[1]); }
|
||||||
ggml_tensor * out = ggml_soft_max_ext(ctx, a, b, scale);
|
ggml_tensor * out = ggml_soft_max_ext(ctx, a, b, scale);
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
@ -1472,7 +1472,7 @@ struct test_flash_attn_ext : public test_case {
|
|||||||
ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
|
ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
|
||||||
ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
|
ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
|
||||||
ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
|
ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
|
||||||
ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1);
|
ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1);
|
||||||
ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs));
|
ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs));
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
@ -1506,7 +1506,7 @@ struct test_attn : public test_case {
|
|||||||
ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
|
ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
|
||||||
ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
|
ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
|
||||||
ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1); // transposed
|
ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1); // transposed
|
||||||
ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1);
|
ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, nb, 1, 1);
|
||||||
|
|
||||||
struct ggml_tensor * cur;
|
struct ggml_tensor * cur;
|
||||||
|
|
||||||
@ -1793,7 +1793,7 @@ struct test_llama : public test_llm {
|
|||||||
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens);
|
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens);
|
||||||
|
|
||||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1);
|
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, hp.n_kv, hp.n_tokens, 1);
|
||||||
|
|
||||||
ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
|
ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
|
||||||
ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
|
ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
|
||||||
@ -1915,7 +1915,7 @@ struct test_falcon : public test_llm {
|
|||||||
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens);
|
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens);
|
||||||
|
|
||||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1);
|
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, hp.n_kv, hp.n_tokens, 1);
|
||||||
|
|
||||||
ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
|
ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
|
||||||
ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
|
ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
|
||||||
|
Loading…
Reference in New Issue
Block a user