Mirror of https://github.com/ggerganov/llama.cpp.git
llama : refactor internal quantization functions (#5830)
parent 802da0091b
commit 6c32d8c7ad

Changed file: llama.cpp (81 changed lines)
@@ -10836,7 +10836,7 @@ struct quantize_state_internal {
         {}
 };
 
-static void llama_convert_tensor_internal(
+static void llama_tensor_dequantize_internal(
     struct ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
     const size_t nelements, const int nthread
 ) {
@@ -11177,6 +11177,46 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
     return new_type;
 }
 
+static int32_t llama_tensor_quantize_internal(enum ggml_type new_type, const float * f32_data, void * new_data, const int chunk_size, int nrows, int n_per_row, int64_t * hist_cur, const float * imatrix, std::vector<std::thread> & workers, const int nthread) {
+    std::mutex mutex;
+    int counter = 0;
+    size_t new_size = 0;
+    if (nthread < 2) {
+        // single-thread
+        return ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, hist_cur, imatrix);
+    }
+    auto compute = [&mutex, &counter, &hist_cur, &new_size, new_type, f32_data, new_data, chunk_size,
+            nrows, n_per_row, imatrix]() {
+        std::array<int64_t, 1 << 4> local_hist = {};
+        const int nrows_per_chunk = chunk_size / n_per_row;
+        size_t local_size = 0;
+        while (true) {
+            std::unique_lock<std::mutex> lock(mutex);
+            int first_row = counter; counter += nrows_per_chunk;
+            if (first_row >= nrows) {
+                if (local_size > 0) {
+                    for (int j=0; j<int(local_hist.size()); ++j) {
+                        hist_cur[j] += local_hist[j];
+                    }
+                    new_size += local_size;
+                }
+                break;
+            }
+            lock.unlock();
+            const int this_nrow = std::min(nrows - first_row, nrows_per_chunk);
+            local_size += ggml_quantize_chunk(new_type, f32_data, new_data,
+                    first_row * n_per_row, this_nrow, n_per_row, local_hist.data(), imatrix);
+        }
+    };
+    for (int it = 0; it < nthread - 1; ++it) {
+        workers.emplace_back(compute);
+    }
+    compute();
+    for (auto & w : workers) { w.join(); }
+    workers.clear();
+    return new_size;
+}
+
 static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
     ggml_type quantized_type;
     llama_ftype ftype = params->ftype;
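The new helper lifts the existing worker-thread logic out of llama_model_quantize_internal without changing its behavior: a mutex-guarded shared counter hands out fixed-size chunks of rows, each worker quantizes its chunks while keeping a private histogram and byte count, and the partial results are folded back under the lock once the counter runs past the last row. The standalone sketch below isolates just that scheduling pattern (shared counter plus per-thread accumulation); process_chunk and all the constants are placeholders invented for the example rather than llama.cpp code, and the per-thread histogram is omitted for brevity.

// Sketch of the chunk-scheduling pattern used by llama_tensor_quantize_internal:
// a shared counter under a mutex hands out row ranges; the work itself runs unlocked.
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <mutex>
#include <thread>
#include <vector>

// Placeholder for ggml_quantize_chunk(): pretend each row shrinks to half its width.
static size_t process_chunk(int first_row, int nrow, int n_per_row) {
    (void) first_row;
    return (size_t) nrow * n_per_row / 2;
}

int main() {
    const int nrows           = 1000;   // made-up tensor shape
    const int n_per_row       = 256;
    const int nrows_per_chunk = 32;
    const int nthread         = 4;

    std::mutex mutex;
    int    counter  = 0;  // next row to hand out, shared between workers
    size_t new_size = 0;  // total output size, merged under the lock

    auto compute = [&]() {
        size_t local_size = 0;
        while (true) {
            std::unique_lock<std::mutex> lock(mutex);
            int first_row = counter;
            counter += nrows_per_chunk;
            if (first_row >= nrows) {
                new_size += local_size;  // publish the partial result once, at the end
                break;
            }
            lock.unlock();               // the actual work happens outside the lock
            const int this_nrow = std::min(nrows - first_row, nrows_per_chunk);
            local_size += process_chunk(first_row, this_nrow, n_per_row);
        }
    };

    std::vector<std::thread> workers;
    for (int it = 0; it < nthread - 1; ++it) {
        workers.emplace_back(compute);
    }
    compute();  // the calling thread participates, as in the refactored helper
    for (auto & w : workers) { w.join(); }

    printf("processed %d rows -> %zu bytes\n", nrows, new_size);
    return 0;
}

As in the diff, only nthread - 1 extra threads are spawned because the caller's own thread runs the same lambda, and the helper reuses the caller-owned workers vector by clearing it after the joins.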
@@ -11289,7 +11329,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
 
     std::vector<std::thread> workers;
     workers.reserve(nthread);
-    std::mutex mutex;
 
     int idx = 0;
 
@@ -11403,7 +11442,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) {
             throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type)));
         } else {
-            llama_convert_tensor_internal(tensor, f32_conv_buf, workers, nelements, nthread);
+            llama_tensor_dequantize_internal(tensor, f32_conv_buf, workers, nelements, nthread);
             f32_data = (float *) f32_conv_buf.data();
         }
 
@@ -11424,41 +11463,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
 
             const int nchunk = (nelements + chunk_size - 1)/chunk_size;
             const int nthread_use = nthread > 1 ? std::max(1, std::min(nthread, nchunk)) : 1;
-            if (nthread_use < 2) {
-                new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, hist_cur.data(), imatrix);
-            } else {
-                int counter = 0;
-                new_size = 0;
-                auto compute = [&mutex, &counter, &hist_cur, &new_size, new_type, f32_data, new_data, chunk_size,
-                        nrows, n_per_row, imatrix]() {
-                    std::array<int64_t, 1 << 4> local_hist = {};
-                    const int nrows_per_chunk = chunk_size / n_per_row;
-                    size_t local_size = 0;
-                    while (true) {
-                        std::unique_lock<std::mutex> lock(mutex);
-                        int first_row = counter; counter += nrows_per_chunk;
-                        if (first_row >= nrows) {
-                            if (local_size > 0) {
-                                for (int j=0; j<int(local_hist.size()); ++j) {
-                                    hist_cur[j] += local_hist[j];
-                                }
-                                new_size += local_size;
-                            }
-                            break;
-                        }
-                        lock.unlock();
-                        const int this_nrow = std::min(nrows - first_row, nrows_per_chunk);
-                        local_size += ggml_quantize_chunk(new_type, f32_data, new_data,
-                                first_row * n_per_row, this_nrow, n_per_row, local_hist.data(), imatrix);
-                    }
-                };
-                for (int it = 0; it < nthread_use - 1; ++it) {
-                    workers.emplace_back(compute);
-                }
-                compute();
-                for (auto & w : workers) { w.join(); }
-                workers.clear();
-            }
+            new_size = llama_tensor_quantize_internal(new_type, f32_data, new_data, chunk_size, nrows, n_per_row, hist_cur.data(), imatrix, workers, nthread_use);
 
             LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0);
             int64_t tot_count = 0;
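For context on the surviving call-site lines: nchunk is a ceiling division of the element count by the chunk size, and nthread_use clamps the requested thread count to the number of available chunks before being passed to the new helper. A small numeric sketch with made-up sizes (the real values depend on the tensor being quantized):

#include <algorithm>
#include <cstdio>

int main() {
    // hypothetical tensor: 1,310,720 elements, 16384-element chunks, 16 threads requested
    const int nelements  = 1310720;
    const int chunk_size = 16384;
    const int nthread    = 16;

    const int nchunk      = (nelements + chunk_size - 1)/chunk_size;                   // ceiling division -> 80
    const int nthread_use = nthread > 1 ? std::max(1, std::min(nthread, nchunk)) : 1;  // min(16, 80) -> 16

    printf("nchunk = %d, nthread_use = %d\n", nchunk, nthread_use);
    return 0;
}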