mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-29 07:34:18 +01:00
metal : fix ggml_get_rows to work with non-cont src1
This commit is contained in:
parent
0710b0f726
commit
016f9bb55a
@ -1584,11 +1584,12 @@ void ggml_metal_graph_compute(
|
||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
|
||||
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
|
||||
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
|
||||
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:7];
|
||||
[encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
|
||||
[encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
|
||||
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
|
||||
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
|
||||
|
||||
const int64_t n = ggml_nelements(src1);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_RMS_NORM:
|
||||
{
|
||||
|
@ -3219,69 +3219,89 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
|
||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
||||
kernel void kernel_get_rows(
|
||||
device const void * src0,
|
||||
device const int * src1,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne10,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb1,
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
constant uint64_t & nb2,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint tptg [[threads_per_threadgroup]]) {
|
||||
const int64_t i = tgpig;
|
||||
const int64_t r = ((device int32_t *) src1)[i];
|
||||
uint3 tptg [[threads_per_threadgroup]]) {
|
||||
//const int64_t i = tgpig;
|
||||
//const int64_t r = ((device int32_t *) src1)[i];
|
||||
|
||||
for (int64_t ind = tiitg; ind < ne00/16; ind += tptg) {
|
||||
const int64_t i10 = tgpig.x;
|
||||
const int64_t i11 = tgpig.y;
|
||||
|
||||
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
||||
|
||||
const int64_t i02 = i11;
|
||||
|
||||
for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
|
||||
float4x4 temp;
|
||||
dequantize_func(
|
||||
((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
|
||||
*(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
|
||||
((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
|
||||
*(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_get_rows_f32(
|
||||
device const void * src0,
|
||||
device const int * src1,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne10,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb1,
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
constant uint64_t & nb2,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint tptg [[threads_per_threadgroup]]) {
|
||||
const int64_t i = tgpig;
|
||||
const int64_t r = ((device int32_t *) src1)[i];
|
||||
uint3 tptg [[threads_per_threadgroup]]) {
|
||||
const int64_t i10 = tgpig.x;
|
||||
const int64_t i11 = tgpig.y;
|
||||
|
||||
const int64_t i02 = i/ne10;
|
||||
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
||||
|
||||
for (int ind = tiitg; ind < ne00; ind += tptg) {
|
||||
((device float *) ((device char *) dst + i*nb1))[ind] =
|
||||
const int64_t i02 = i11;
|
||||
|
||||
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
||||
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
||||
((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_get_rows_f16(
|
||||
device const void * src0,
|
||||
device const int * src1,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne10,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb1,
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
constant uint64_t & nb2,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint tptg [[threads_per_threadgroup]]) {
|
||||
const int64_t i = tgpig;
|
||||
const int64_t r = ((device int32_t *) src1)[i];
|
||||
uint3 tptg [[threads_per_threadgroup]]) {
|
||||
const int64_t i10 = tgpig.x;
|
||||
const int64_t i11 = tgpig.y;
|
||||
|
||||
const int64_t i02 = i/ne10;
|
||||
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
||||
|
||||
for (int ind = tiitg; ind < ne00; ind += tptg) {
|
||||
((device float *) ((device char *) dst + i*nb1))[ind] =
|
||||
const int64_t i02 = i11;
|
||||
|
||||
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
||||
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
||||
((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
||||
}
|
||||
}
|
||||
@ -3543,14 +3563,17 @@ kernel void kernel_mul_mm_id(
|
||||
|
||||
typedef void (get_rows_t)(
|
||||
device const void * src0,
|
||||
device const int * src1,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne10,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb1,
|
||||
uint, uint, uint);
|
||||
constant uint64_t & nb2,
|
||||
uint3, uint, uint3);
|
||||
|
||||
//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
|
||||
//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
||||
|
Loading…
Reference in New Issue
Block a user