mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-01 00:39:00 +01:00
metal : add cpy f16 -> f32 kernel
This commit is contained in:
parent
a742d9f9b7
commit
08eb99179a
10
convert.py
10
convert.py
@ -63,10 +63,10 @@ class UnquantizedDataType(DataType):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
DT_F16 = UnquantizedDataType('F16', dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0'])
|
DT_F16 = UnquantizedDataType('F16', dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0'])
|
||||||
DT_F32 = UnquantizedDataType('F32', dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0'])
|
DT_F32 = UnquantizedDataType('F32', dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0'])
|
||||||
DT_I32 = UnquantizedDataType('I32', dtype = np.dtype(np.int16), valid_conversions = [])
|
DT_I32 = UnquantizedDataType('I32', dtype = np.dtype(np.int16), valid_conversions = [])
|
||||||
DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_conversions = ['F32', 'F16', 'Q8_0'])
|
DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_conversions = ['F32', 'F16', 'Q8_0'])
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@ -996,7 +996,7 @@ class OutputFile:
|
|||||||
|
|
||||||
|
|
||||||
def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType:
|
def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType:
|
||||||
wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) +".weight"].data_type
|
wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) + ".weight"].data_type
|
||||||
|
|
||||||
if output_type_str == "f32" or (output_type_str is None and wq_type == DT_F32):
|
if output_type_str == "f32" or (output_type_str is None and wq_type == DT_F32):
|
||||||
return GGMLFileType.AllF32
|
return GGMLFileType.AllF32
|
||||||
|
36
ggml-metal.m
36
ggml-metal.m
@ -155,6 +155,7 @@ struct ggml_metal_context {
|
|||||||
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
|
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
|
||||||
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
|
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
|
||||||
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
|
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
|
||||||
|
GGML_METAL_DECL_KERNEL(cpy_f16_f32);
|
||||||
GGML_METAL_DECL_KERNEL(concat);
|
GGML_METAL_DECL_KERNEL(concat);
|
||||||
GGML_METAL_DECL_KERNEL(sqr);
|
GGML_METAL_DECL_KERNEL(sqr);
|
||||||
GGML_METAL_DECL_KERNEL(sum_rows);
|
GGML_METAL_DECL_KERNEL(sum_rows);
|
||||||
@ -424,6 +425,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
|
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
|
||||||
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
|
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
|
||||||
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
|
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
|
||||||
|
GGML_METAL_ADD_KERNEL(cpy_f16_f32);
|
||||||
GGML_METAL_ADD_KERNEL(concat);
|
GGML_METAL_ADD_KERNEL(concat);
|
||||||
GGML_METAL_ADD_KERNEL(sqr);
|
GGML_METAL_ADD_KERNEL(sqr);
|
||||||
GGML_METAL_ADD_KERNEL(sum_rows);
|
GGML_METAL_ADD_KERNEL(sum_rows);
|
||||||
@ -539,6 +541,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|||||||
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
|
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
|
||||||
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
|
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
|
||||||
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
|
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
|
||||||
|
GGML_METAL_DEL_KERNEL(cpy_f16_f32);
|
||||||
GGML_METAL_DEL_KERNEL(concat);
|
GGML_METAL_DEL_KERNEL(concat);
|
||||||
GGML_METAL_DEL_KERNEL(sqr);
|
GGML_METAL_DEL_KERNEL(sqr);
|
||||||
GGML_METAL_DEL_KERNEL(sum_rows);
|
GGML_METAL_DEL_KERNEL(sum_rows);
|
||||||
@ -867,12 +870,37 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
|
|||||||
case GGML_OP_ROPE:
|
case GGML_OP_ROPE:
|
||||||
case GGML_OP_IM2COL:
|
case GGML_OP_IM2COL:
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
case GGML_OP_DUP:
|
|
||||||
case GGML_OP_CPY:
|
|
||||||
case GGML_OP_CONT:
|
|
||||||
case GGML_OP_MUL_MAT:
|
case GGML_OP_MUL_MAT:
|
||||||
case GGML_OP_MUL_MAT_ID:
|
case GGML_OP_MUL_MAT_ID:
|
||||||
return true;
|
return true;
|
||||||
|
case GGML_OP_CPY:
|
||||||
|
case GGML_OP_DUP:
|
||||||
|
case GGML_OP_CONT:
|
||||||
|
{
|
||||||
|
switch (op->src[0]->type) {
|
||||||
|
case GGML_TYPE_F32:
|
||||||
|
switch (op->type) {
|
||||||
|
case GGML_TYPE_F16:
|
||||||
|
case GGML_TYPE_F32:
|
||||||
|
case GGML_TYPE_Q8_0:
|
||||||
|
case GGML_TYPE_Q4_0:
|
||||||
|
case GGML_TYPE_Q4_1:
|
||||||
|
return true;
|
||||||
|
default:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
case GGML_TYPE_F16:
|
||||||
|
switch (op->type) {
|
||||||
|
case GGML_TYPE_F16:
|
||||||
|
case GGML_TYPE_F32:
|
||||||
|
return true;
|
||||||
|
default:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
}
|
||||||
case GGML_OP_DIAG_MASK_INF:
|
case GGML_OP_DIAG_MASK_INF:
|
||||||
{
|
{
|
||||||
return op->ne[0] % 4 == 0;
|
return op->ne[0] % 4 == 0;
|
||||||
@ -2021,7 +2049,7 @@ void ggml_metal_graph_compute(
|
|||||||
{
|
{
|
||||||
switch (dstt) {
|
switch (dstt) {
|
||||||
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
|
||||||
case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break;
|
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
|
||||||
default: GGML_ASSERT(false && "not implemented");
|
default: GGML_ASSERT(false && "not implemented");
|
||||||
};
|
};
|
||||||
} break;
|
} break;
|
||||||
|
@ -1698,8 +1698,8 @@ template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_ar
|
|||||||
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
|
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
|
||||||
|
|
||||||
kernel void kernel_cpy_f16_f16(
|
kernel void kernel_cpy_f16_f16(
|
||||||
device const half * src0,
|
device const half * src0,
|
||||||
device half * dst,
|
device half * dst,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
@ -1738,6 +1738,47 @@ kernel void kernel_cpy_f16_f16(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kernel void kernel_cpy_f16_f32(
|
||||||
|
device const half * src0,
|
||||||
|
device float * dst,
|
||||||
|
constant int64_t & ne00,
|
||||||
|
constant int64_t & ne01,
|
||||||
|
constant int64_t & ne02,
|
||||||
|
constant int64_t & ne03,
|
||||||
|
constant uint64_t & nb00,
|
||||||
|
constant uint64_t & nb01,
|
||||||
|
constant uint64_t & nb02,
|
||||||
|
constant uint64_t & nb03,
|
||||||
|
constant int64_t & ne0,
|
||||||
|
constant int64_t & ne1,
|
||||||
|
constant int64_t & ne2,
|
||||||
|
constant int64_t & ne3,
|
||||||
|
constant uint64_t & nb0,
|
||||||
|
constant uint64_t & nb1,
|
||||||
|
constant uint64_t & nb2,
|
||||||
|
constant uint64_t & nb3,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||||
|
uint3 ntg[[threads_per_threadgroup]]) {
|
||||||
|
const int64_t i03 = tgpig[2];
|
||||||
|
const int64_t i02 = tgpig[1];
|
||||||
|
const int64_t i01 = tgpig[0];
|
||||||
|
|
||||||
|
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||||
|
|
||||||
|
const int64_t i3 = n / (ne2*ne1*ne0);
|
||||||
|
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
||||||
|
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
||||||
|
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
||||||
|
|
||||||
|
device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||||
|
|
||||||
|
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
||||||
|
device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
||||||
|
dst_data[i00] = src[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
kernel void kernel_cpy_f32_f16(
|
kernel void kernel_cpy_f32_f16(
|
||||||
device const float * src0,
|
device const float * src0,
|
||||||
device half * dst,
|
device half * dst,
|
||||||
|
@ -4277,23 +4277,23 @@ struct llm_build_context {
|
|||||||
ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
|
ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
|
||||||
cb(logits, "ffn_moe_logits", il);
|
cb(logits, "ffn_moe_logits", il);
|
||||||
|
|
||||||
ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
|
ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
|
||||||
cb(probs, "ffn_moe_probs", il);
|
cb(probs, "ffn_moe_probs", il);
|
||||||
|
|
||||||
// select experts
|
// select experts
|
||||||
ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok]
|
ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok]
|
||||||
cb(selected_experts->src[0], "ffn_moe_argsort", il);
|
cb(selected_experts->src[0], "ffn_moe_argsort", il);
|
||||||
|
|
||||||
ggml_tensor * weights = ggml_get_rows(ctx0,
|
ggml_tensor * weights = ggml_get_rows(ctx0,
|
||||||
ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts);
|
ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts);
|
||||||
cb(weights, "ffn_moe_weights", il);
|
cb(weights, "ffn_moe_weights", il);
|
||||||
|
|
||||||
weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]
|
weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]
|
||||||
|
|
||||||
ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
|
ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
|
||||||
cb(weights_sum, "ffn_moe_weights_sum", il);
|
cb(weights_sum, "ffn_moe_weights_sum", il);
|
||||||
|
|
||||||
weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
|
weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
|
||||||
cb(weights, "ffn_moe_weights_norm", il);
|
cb(weights, "ffn_moe_weights_norm", il);
|
||||||
|
|
||||||
// compute expert outputs
|
// compute expert outputs
|
||||||
|
Loading…
Reference in New Issue
Block a user