mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-03 17:51:09 +01:00
metal : fix ggml_get_rows to work with non-cont src1
This commit is contained in:
parent
0710b0f726
commit
016f9bb55a
@ -1584,11 +1584,12 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
|
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
|
||||||
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
|
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
|
||||||
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
|
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
|
||||||
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:7];
|
[encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
|
||||||
|
[encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
|
||||||
|
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
|
||||||
|
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
|
||||||
|
|
||||||
const int64_t n = ggml_nelements(src1);
|
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_RMS_NORM:
|
case GGML_OP_RMS_NORM:
|
||||||
{
|
{
|
||||||
|
@ -3219,69 +3219,89 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
|
|||||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
||||||
kernel void kernel_get_rows(
|
kernel void kernel_get_rows(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const int * src1,
|
device const char * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant uint64_t & nb01,
|
constant uint64_t & nb01,
|
||||||
constant uint64_t & nb02,
|
constant uint64_t & nb02,
|
||||||
constant int64_t & ne10,
|
constant int64_t & ne10,
|
||||||
|
constant uint64_t & nb10,
|
||||||
|
constant uint64_t & nb11,
|
||||||
constant uint64_t & nb1,
|
constant uint64_t & nb1,
|
||||||
uint tgpig[[threadgroup_position_in_grid]],
|
constant uint64_t & nb2,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiitg[[thread_index_in_threadgroup]],
|
uint tiitg[[thread_index_in_threadgroup]],
|
||||||
uint tptg [[threads_per_threadgroup]]) {
|
uint3 tptg [[threads_per_threadgroup]]) {
|
||||||
const int64_t i = tgpig;
|
//const int64_t i = tgpig;
|
||||||
const int64_t r = ((device int32_t *) src1)[i];
|
//const int64_t r = ((device int32_t *) src1)[i];
|
||||||
|
|
||||||
for (int64_t ind = tiitg; ind < ne00/16; ind += tptg) {
|
const int64_t i10 = tgpig.x;
|
||||||
|
const int64_t i11 = tgpig.y;
|
||||||
|
|
||||||
|
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
||||||
|
|
||||||
|
const int64_t i02 = i11;
|
||||||
|
|
||||||
|
for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
|
||||||
float4x4 temp;
|
float4x4 temp;
|
||||||
dequantize_func(
|
dequantize_func(
|
||||||
((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
|
((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
|
||||||
*(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
|
*(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_get_rows_f32(
|
kernel void kernel_get_rows_f32(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const int * src1,
|
device const char * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant uint64_t & nb01,
|
constant uint64_t & nb01,
|
||||||
constant uint64_t & nb02,
|
constant uint64_t & nb02,
|
||||||
constant int64_t & ne10,
|
constant int64_t & ne10,
|
||||||
|
constant uint64_t & nb10,
|
||||||
|
constant uint64_t & nb11,
|
||||||
constant uint64_t & nb1,
|
constant uint64_t & nb1,
|
||||||
uint tgpig[[threadgroup_position_in_grid]],
|
constant uint64_t & nb2,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiitg[[thread_index_in_threadgroup]],
|
uint tiitg[[thread_index_in_threadgroup]],
|
||||||
uint tptg [[threads_per_threadgroup]]) {
|
uint3 tptg [[threads_per_threadgroup]]) {
|
||||||
const int64_t i = tgpig;
|
const int64_t i10 = tgpig.x;
|
||||||
const int64_t r = ((device int32_t *) src1)[i];
|
const int64_t i11 = tgpig.y;
|
||||||
|
|
||||||
const int64_t i02 = i/ne10;
|
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
||||||
|
|
||||||
for (int ind = tiitg; ind < ne00; ind += tptg) {
|
const int64_t i02 = i11;
|
||||||
((device float *) ((device char *) dst + i*nb1))[ind] =
|
|
||||||
|
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
||||||
|
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
||||||
((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_get_rows_f16(
|
kernel void kernel_get_rows_f16(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const int * src1,
|
device const char * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant uint64_t & nb01,
|
constant uint64_t & nb01,
|
||||||
constant uint64_t & nb02,
|
constant uint64_t & nb02,
|
||||||
constant int64_t & ne10,
|
constant int64_t & ne10,
|
||||||
|
constant uint64_t & nb10,
|
||||||
|
constant uint64_t & nb11,
|
||||||
constant uint64_t & nb1,
|
constant uint64_t & nb1,
|
||||||
uint tgpig[[threadgroup_position_in_grid]],
|
constant uint64_t & nb2,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiitg[[thread_index_in_threadgroup]],
|
uint tiitg[[thread_index_in_threadgroup]],
|
||||||
uint tptg [[threads_per_threadgroup]]) {
|
uint3 tptg [[threads_per_threadgroup]]) {
|
||||||
const int64_t i = tgpig;
|
const int64_t i10 = tgpig.x;
|
||||||
const int64_t r = ((device int32_t *) src1)[i];
|
const int64_t i11 = tgpig.y;
|
||||||
|
|
||||||
const int64_t i02 = i/ne10;
|
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
||||||
|
|
||||||
for (int ind = tiitg; ind < ne00; ind += tptg) {
|
const int64_t i02 = i11;
|
||||||
((device float *) ((device char *) dst + i*nb1))[ind] =
|
|
||||||
|
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
||||||
|
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
||||||
((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -3543,14 +3563,17 @@ kernel void kernel_mul_mm_id(
|
|||||||
|
|
||||||
typedef void (get_rows_t)(
|
typedef void (get_rows_t)(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const int * src1,
|
device const char * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant uint64_t & nb01,
|
constant uint64_t & nb01,
|
||||||
constant uint64_t & nb02,
|
constant uint64_t & nb02,
|
||||||
constant int64_t & ne10,
|
constant int64_t & ne10,
|
||||||
|
constant uint64_t & nb10,
|
||||||
|
constant uint64_t & nb11,
|
||||||
constant uint64_t & nb1,
|
constant uint64_t & nb1,
|
||||||
uint, uint, uint);
|
constant uint64_t & nb2,
|
||||||
|
uint3, uint, uint3);
|
||||||
|
|
||||||
//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
|
//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
|
||||||
//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
||||||
|
Loading…
Reference in New Issue
Block a user