cuda : fix LLAMA_CUDA_F16 (#5262)

This commit is contained in:
slaren 2024-02-01 18:30:17 +01:00 committed by GitHub
parent d71ac90985
commit 8ca511cade
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -8657,9 +8657,9 @@ static void ggml_cuda_op_dequantize_mul_mat_vec(
     if (src1_convert_f16) {
         src1_dfloat = src1_dfloat_a.alloc(ne00);
-        ggml_cpy_f32_f16_cuda((const char *) src1_ddf_i, (char *) src1_dfloat, ne00,
-                              ne00, 1, sizeof(float), 0, 0,
-                              ne00, 1, sizeof(half), 0, 0, stream);
+        const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
+        GGML_ASSERT(to_fp16_cuda != nullptr);
+        to_fp16_cuda(src1_ddf_i, src1_dfloat, ne00, stream);
     }
 #else
     const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion