diff --git a/ci/run.sh b/ci/run.sh
index f3a8ff774..791b17a19 100755
--- a/ci/run.sh
+++ b/ci/run.sh
@@ -36,6 +36,10 @@ if [ ! -z ${GG_BUILD_METAL} ]; then
     CMAKE_EXTRA="${CMAKE_EXTRA} -DLLAMA_METAL_SHADER_DEBUG=ON"
 fi
 
+if [ ! -z ${GG_BUILD_CUDA} ]; then
+    CMAKE_EXTRA="${CMAKE_EXTRA} -DLLAMA_CUBLAS=1"
+fi
+
 ## helpers
 
 # download a file if it does not exist or if it is outdated
@@ -160,8 +164,8 @@ function gg_run_open_llama_3b_v2 {
 
     set -e
 
-    (time cmake -DCMAKE_BUILD_TYPE=Release -DLLAMA_QKK_64=1 .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log
-    (time make -j ) 2>&1 | tee -a $OUT/${ci}-make.log
+    (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} -DLLAMA_QKK_64=1 .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log
+    (time make -j ) 2>&1 | tee -a $OUT/${ci}-make.log
 
     python3 ../convert.py ${path_models}
 
@@ -343,8 +347,8 @@ function gg_run_open_llama_7b_v2 {
 
     set -e
 
-    (time cmake -DCMAKE_BUILD_TYPE=Release -DLLAMA_CUBLAS=1 .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log
-    (time make -j ) 2>&1 | tee -a $OUT/${ci}-make.log
+    (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} -DLLAMA_CUBLAS=1 .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log
+    (time make -j ) 2>&1 | tee -a $OUT/${ci}-make.log
 
     python3 ../convert.py ${path_models}
 
diff --git a/ggml-backend.c b/ggml-backend.c
index 4266250f9..ef518dae0 100644
--- a/ggml-backend.c
+++ b/ggml-backend.c
@@ -692,6 +692,8 @@ GGML_CALL static bool ggml_backend_cpu_graph_compute(ggml_backend_t backend, str
 
 GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
     switch (op->op) {
+        case GGML_OP_CPY:
+            return op->type != GGML_TYPE_IQ2_XXS && op->type != GGML_TYPE_IQ2_XS; // missing type_traits.from_float
         case GGML_OP_MUL_MAT:
             return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type;
         default:
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 568c411af..b2211d858 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -5131,10 +5131,10 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void *
     const block_q_t  * x = (const block_q_t  *) vx;
     const block_q8_1 * y = (const block_q8_1 *) vy;
 
-    for (int i = 0; i < blocks_per_row; i += blocks_per_warp) {
-        const int ibx = row*blocks_per_row + i + threadIdx.x / (qi/vdr); // x block index
+    for (int i = threadIdx.x / (qi/vdr); i < blocks_per_row; i += blocks_per_warp) {
+        const int ibx = row*blocks_per_row + i; // x block index
 
-        const int iby = (i + threadIdx.x / (qi/vdr)) * (qk/QK8_1); // y block index that aligns with ibx
+        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
 
         const int iqs = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int
 
@@ -10918,6 +10918,12 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 if (a->ne[3] != b->ne[3]) {
                     return false;
                 }
+                ggml_type a_type = a->type;
+                if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS) {
+                    if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
+                        return false;
+                    }
+                }
                 return true;
             } break;
         case GGML_OP_GET_ROWS:
diff --git a/ggml-quants.c b/ggml-quants.c
index 31b053e33..7d2f033e9 100644
--- a/ggml-quants.c
+++ b/ggml-quants.c
@@ -1274,7 +1274,12 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
     }
     float sumlx = 0;
     float suml2 = 0;
+#ifdef HAVE_BUGGY_APPLE_LINKER
+    // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
+    for (volatile int i = 0; i < n; ++i) {
+#else
     for (int i = 0; i < n; ++i) {
+#endif
         int l = nearest_int(iscale * x[i]);
         l = MAX(-nmax, MIN(nmax-1, l));
         L[i] = l + nmax;
@@ -1649,7 +1654,12 @@ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const f
     float max = x[0];
     float sum_w = weights ? weights[0] : x[0]*x[0];
     float sum_x = sum_w * x[0];
+#ifdef HAVE_BUGGY_APPLE_LINKER
+    // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
+    for (volatile int i = 1; i < n; ++i) {
+#else
     for (int i = 1; i < n; ++i) {
+#endif
         if (x[i] < min) min = x[i];
         if (x[i] > max) max = x[i];
         float w = weights ? weights[i] : x[i]*x[i];
@@ -1660,7 +1670,7 @@ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const f
         min = 0;
     }
     if (max <= min) {
-        for (int i = 0; i < n; ++i) L[i] = 0;
+        memset(L, 0, n);
         *the_min = -min;
         return 0.f;
     }
@@ -1862,7 +1872,7 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri
 
 size_t quantize_q2_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
     (void)hist;
-    int row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row);
+    size_t row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row);
     if (!quant_weights) {
         quantize_row_q2_K_reference(src, dst, nrow*n_per_row);
     }
@@ -2181,7 +2191,7 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri
 
 size_t quantize_q3_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
     (void)hist;
-    int row_size = ggml_row_size(GGML_TYPE_Q3_K, n_per_row);
+    size_t row_size = ggml_row_size(GGML_TYPE_Q3_K, n_per_row);
     if (!quant_weights) {
         quantize_row_q3_K_reference(src, dst, nrow*n_per_row);
     }
@@ -2448,7 +2458,7 @@ static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restri
 
 size_t quantize_q4_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
     (void)hist;
-    int row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row);
+    size_t row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row);
     if (!quant_weights) {
         quantize_row_q4_K_reference(src, dst, nrow*n_per_row);
     }
@@ -2771,7 +2781,7 @@ static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restri
 
 size_t quantize_q5_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
     (void)hist;
-    int row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row);
+    size_t row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row);
     if (!quant_weights) {
         quantize_row_q5_K_reference(src, dst, nrow*n_per_row);
     }
@@ -3025,7 +3035,7 @@ static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restri
 
 size_t quantize_q6_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
     (void)hist;
-    int row_size = ggml_row_size(GGML_TYPE_Q6_K, n_per_row);
+    size_t row_size = ggml_row_size(GGML_TYPE_Q6_K, n_per_row);
     if (!quant_weights) {
         quantize_row_q6_K_reference(src, dst, nrow*n_per_row);
     }
@@ -3072,7 +3082,7 @@ size_t quantize_q4_0(const float * src, void * dst, int nrow, int n_per_row, int
     if (!quant_weights) {
         return ggml_quantize_q4_0(src, dst, nrow*n_per_row, n_per_row, hist);
     }
-    int row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
+    size_t row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
     char * qrow = (char *)dst;
     for (int row = 0; row < nrow; ++row) {
         quantize_row_q4_0_impl(src, (block_q4_0*)qrow, n_per_row, quant_weights);
@@ -3116,7 +3126,7 @@ size_t quantize_q4_1(const float * src, void * dst, int nrow, int n_per_row, int
     if (!quant_weights) {
         return ggml_quantize_q4_1(src, dst, nrow*n_per_row, n_per_row, hist);
     }
-    int row_size = ggml_row_size(GGML_TYPE_Q4_1, n_per_row);
+    size_t row_size = ggml_row_size(GGML_TYPE_Q4_1, n_per_row);
     char * qrow = (char *)dst;
     for (int row = 0; row < nrow; ++row) {
         quantize_row_q4_1_impl(src, (block_q4_1*)qrow, n_per_row, quant_weights);
@@ -3169,7 +3179,7 @@ size_t quantize_q5_0(const float * src, void * dst, int nrow, int n_per_row, int
     if (!quant_weights) {
         return ggml_quantize_q5_0(src, dst, nrow*n_per_row, n_per_row, hist);
     }
-    int row_size = ggml_row_size(GGML_TYPE_Q5_0, n_per_row);
+    size_t row_size = ggml_row_size(GGML_TYPE_Q5_0, n_per_row);
     char * qrow = (char *)dst;
     for (int row = 0; row < nrow; ++row) {
         quantize_row_q5_0_impl(src, (block_q5_0*)qrow, n_per_row, quant_weights);
@@ -3221,7 +3231,7 @@ size_t quantize_q5_1(const float * src, void * dst, int nrow, int n_per_row, int
     if (!quant_weights) {
         return ggml_quantize_q5_1(src, dst, nrow*n_per_row, n_per_row, hist);
     }
-    int row_size = ggml_row_size(GGML_TYPE_Q5_1, n_per_row);
+    size_t row_size = ggml_row_size(GGML_TYPE_Q5_1, n_per_row);
     char * qrow = (char *)dst;
     for (int row = 0; row < nrow; ++row) {
         quantize_row_q5_1_impl(src, (block_q5_1*)qrow, n_per_row, quant_weights);
@@ -8565,7 +8575,7 @@ static int iq2_compare_func(const void * left, const void * right) {
     return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0;
 }
 
-static void q2xs_init_impl(int grid_size) {
+void iq2xs_init_impl(int grid_size) {
     const int gindex = iq2_data_index(grid_size);
     if (iq2_data[gindex].grid) {
         return;
     }
@@ -8720,19 +8730,7 @@ static void q2xs_init_impl(int grid_size) {
     free(dist2);
 }
 
-void ggml_init_iq2_quantization(enum ggml_type type) {
-    if (type == GGML_TYPE_IQ2_XXS) {
-        q2xs_init_impl(256);
-    }
-    else if (type == GGML_TYPE_IQ2_XS) {
-        q2xs_init_impl(512);
-    }
-    else {
-        fprintf(stderr, "======================== Why are you calling %s with type %d?\n", __func__, (int)type);
-    }
-}
-
-static void q2xs_deinit_impl(int grid_size) {
+void iq2xs_free_impl(int grid_size) {
     GGML_ASSERT(grid_size == 256 || grid_size == 512 || grid_size == 1024);
     const int gindex = iq2_data_index(grid_size);
     if (iq2_data[gindex].grid) {
@@ -8742,18 +8740,6 @@ static void q2xs_deinit_impl(int grid_size) {
     }
 }
 
-void ggml_deinit_iq2_quantization(enum ggml_type type) {
-    if (type == GGML_TYPE_IQ2_XXS) {
-        q2xs_deinit_impl(256);
-    }
-    else if (type == GGML_TYPE_IQ2_XS) {
-        q2xs_deinit_impl(512);
-    }
-    else {
-        fprintf(stderr, "======================== Why are you calling %s with type %d?\n", __func__, (int)type);
-    }
-}
-
 static int iq2_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
         const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L) {
     int num_neighbors = neighbours[0];
@@ -8786,10 +8772,10 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
     const int * kmap_q2xs = iq2_data[gindex].map;
     const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
 
-    GGML_ASSERT(quant_weights);
-    GGML_ASSERT(kgrid_q2xs);
-    GGML_ASSERT(kmap_q2xs);
-    GGML_ASSERT(kneighbors_q2xs);
+    GGML_ASSERT(quant_weights && "missing quantization weights");
+    GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
     GGML_ASSERT(n%QK_K == 0);
 
     const int kMaxQ = 3;
@@ -9005,10 +8991,10 @@ static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict v
     const int * kmap_q2xs = iq2_data[gindex].map;
     const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
 
-    GGML_ASSERT(quant_weights);
-    GGML_ASSERT(kmap_q2xs);
-    GGML_ASSERT(kgrid_q2xs);
-    GGML_ASSERT(kneighbors_q2xs);
+    GGML_ASSERT(quant_weights && "missing quantization weights");
+    GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
     GGML_ASSERT(n%QK_K == 0);
 
     const int kMaxQ = 3;
diff --git a/ggml-quants.h b/ggml-quants.h
index d7fefdb54..7d7cf9178 100644
--- a/ggml-quants.h
+++ b/ggml-quants.h
@@ -257,3 +257,6 @@ size_t quantize_q4_0 (const float * src, void * dst, int nrows, int n_per_row,
 size_t quantize_q4_1 (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_q5_0 (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_q5_1 (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
+
+void iq2xs_init_impl(int grid_size);
+void iq2xs_free_impl(int grid_size);
diff --git a/ggml.c b/ggml.c
index 35fd29a9e..cbf2d4bdd 100644
--- a/ggml.c
+++ b/ggml.c
@@ -18524,6 +18524,28 @@ enum ggml_opt_result ggml_opt_resume_g(
 
 ////////////////////////////////////////////////////////////////////////////////
 
+void ggml_quantize_init(enum ggml_type type) {
+    ggml_critical_section_start();
+
+    switch (type) {
+        case GGML_TYPE_IQ2_XXS: iq2xs_init_impl(256); break;
+        case GGML_TYPE_IQ2_XS:  iq2xs_init_impl(512); break;
+        default: // nothing
+            break;
+    }
+
+    ggml_critical_section_end();
+}
+
+void ggml_quantize_free(void) {
+    ggml_critical_section_start();
+
+    iq2xs_free_impl(256);
+    iq2xs_free_impl(512);
+
+    ggml_critical_section_end();
+}
+
 size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) {
     assert(k % QK4_0 == 0);
     const int nb = k / QK4_0;
@@ -18651,9 +18673,15 @@ size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t *
     return (n/QK8_0*sizeof(block_q8_0));
 }
 
+bool ggml_quantize_requires_imatrix(enum ggml_type type) {
+    return
+        type == GGML_TYPE_IQ2_XXS ||
+        type == GGML_TYPE_IQ2_XS;
+}
+
 size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int nrows, int n_per_row, int64_t * hist, const float * imatrix) {
-    (void)imatrix;
+    ggml_quantize_init(type); // this is noop if already initialized
 
     size_t result = 0;
     int n = nrows * n_per_row;
     switch (type) {
@@ -18766,13 +18794,13 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
             } break;
         case GGML_TYPE_F16:
             {
-                int elemsize = sizeof(ggml_fp16_t);
+                size_t elemsize = sizeof(ggml_fp16_t);
                 ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n);
                 result = n * elemsize;
             } break;
         case GGML_TYPE_F32:
             {
-                int elemsize = sizeof(float);
+                size_t elemsize = sizeof(float);
                 result = n * elemsize;
                 memcpy((uint8_t *)dst + start * elemsize, src + start, result);
             } break;
diff --git a/ggml.h b/ggml.h
index 27daf6fd1..de8162b81 100644
--- a/ggml.h
+++ b/ggml.h
@@ -2065,6 +2065,18 @@ extern "C" {
     // quantization
     //
 
+    // - ggml_quantize_init can be called multiple times with the same type
+    //   it will only initialize the quantization tables for the first call or after ggml_quantize_free
+    //   automatically called by ggml_quantize_chunk for convenience
+    //
+    // - ggml_quantize_free will free any memory allocated by ggml_quantize_init
+    //   call this at the end of the program to avoid memory leaks
+    //
+    // note: these are thread-safe
+    //
+    GGML_API void ggml_quantize_init(enum ggml_type type);
+    GGML_API void ggml_quantize_free(void);
+
     // TODO: these would probably get removed in favor of the more general ggml_quantize_chunk
     GGML_API size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
     GGML_API size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
@@ -2078,13 +2090,13 @@ extern "C" {
     GGML_API size_t ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
     GGML_API size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);
 
+    // some quantization type cannot be used without an importance matrix
+    GGML_API bool ggml_quantize_requires_imatrix(enum ggml_type type);
+
+    // calls ggml_quantize_init internally (i.e. can allocate memory)
     GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 
-    // These are needed for IQ2_XS and IQ2_XXS quantizations
-    GGML_API void ggml_init_iq2_quantization(enum ggml_type type);
-    GGML_API void ggml_deinit_iq2_quantization(enum ggml_type type);
-
     //
     // gguf
     //
diff --git a/llama.cpp b/llama.cpp
index 81829b13e..d28382f7d 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -8747,8 +8747,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     // placeholder for the meta data
     ::zeros(fout, meta_size);
 
-    std::set<ggml_type> used_iq2;
-
     for (int i = 0; i < ml.n_tensors; ++i) {
         struct ggml_tensor * tensor = ml.get_tensor_meta(i);
 
@@ -8801,11 +8799,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         } else {
             const size_t nelements = ggml_nelements(tensor);
 
-            if ((new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_XS) && used_iq2.find(new_type) == used_iq2.end()) {
-                ggml_init_iq2_quantization(new_type);
-                used_iq2.insert(new_type);
-            }
-
             const float * imatrix = nullptr;
             if (imatrix_data) {
                 auto it = imatrix_data->find(tensor->name);
@@ -8931,10 +8924,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
 
     fout.close();
 
-    for (auto type : used_iq2) {
-        ggml_deinit_iq2_quantization(type);
-    }
-
     gguf_free(ctx_out);
 
     LLAMA_LOG_INFO("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0);
@@ -9342,6 +9331,7 @@ void llama_backend_free(void) {
 #ifdef GGML_USE_MPI
     ggml_mpi_backend_free();
 #endif
+    ggml_quantize_free();
 }
 
 int64_t llama_time_us(void) {
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 22a7856d4..55ce14e0d 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -16,39 +16,37 @@
 #include <vector>
 
 static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
+    // static RNG initialization (revisit if n_threads stops being constant)
+    static const size_t n_threads = std::thread::hardware_concurrency();
+    static std::vector<std::default_random_engine> generators = []() {
+        std::random_device rd;
+        std::vector<std::default_random_engine> vec;
+        vec.reserve(n_threads);
+        //for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(1234 + i); } // fixed seed
+        for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(rd()); }
+        return vec;
+    }();
+
     size_t size = ggml_nelements(tensor);
     std::vector<float> data(size);
 
-#if 0
-    static std::default_random_engine generator(1234);
-    std::uniform_real_distribution<float> distribution(min, max);
-
-    for (size_t i = 0; i < size; i++) {
-        data[i] = distribution(generator);
-    }
-#else
-    auto init_thread = [&](size_t start, size_t end) {
-        std::random_device rd;
-        std::default_random_engine generator(rd());
+    auto init_thread = [&](size_t ith, size_t start, size_t end) {
         std::uniform_real_distribution<float> distribution(min, max);
-
         for (size_t i = start; i < end; i++) {
-            data[i] = distribution(generator);
+            data[i] = distribution(generators[ith]);
         }
     };
 
-    size_t n_threads = std::thread::hardware_concurrency();
     std::vector<std::thread> threads;
     threads.reserve(n_threads);
     for (size_t i = 0; i < n_threads; i++) {
         size_t start = i*size/n_threads;
         size_t end = (i+1)*size/n_threads;
-        threads.emplace_back(init_thread, start, end);
+        threads.emplace_back(init_thread, i, start, end);
     }
     for (auto & t : threads) {
         t.join();
     }
-#endif
 
     if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_I32) {
         ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float));
@@ -56,7 +54,16 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
         GGML_ASSERT(size % ggml_blck_size(tensor->type) == 0);
         std::vector<uint8_t> dataq(ggml_row_size(tensor->type, size));
         int64_t hist[16];
-        ggml_quantize_chunk(tensor->type, data.data(), dataq.data(), 0, size/tensor->ne[0], tensor->ne[0], hist, nullptr);
+        std::vector<float> imatrix(tensor->ne[0], 1.0f); // dummy importance matrix
+        const float * im = imatrix.data();
+        if (!ggml_quantize_requires_imatrix(tensor->type)) {
+            // when the imatrix is optional, we want to test both quantization with and without imatrix
+            // use one of the random numbers to decide
+            if (data[0] > 0.5f*(min + max)) {
+                im = nullptr;
+            }
+        }
+        ggml_quantize_chunk(tensor->type, data.data(), dataq.data(), 0, size/tensor->ne[0], tensor->ne[0], hist, im);
         ggml_backend_tensor_set(tensor, dataq.data(), 0, dataq.size());
     } else if (tensor->type == GGML_TYPE_I8 || tensor->type == GGML_TYPE_I16 || tensor->type == GGML_TYPE_I32) {
         // This is going to create some weird integers though.
@@ -1472,7 +1479,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         GGML_TYPE_Q8_0,
         GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
         GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
-        GGML_TYPE_Q6_K
+        GGML_TYPE_Q6_K,
+        GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS,
     };
 
     // unary ops
@@ -1752,6 +1760,8 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    ggml_quantize_free();
+
     printf("\033[1;32mOK\033[0m\n");
     return 0;
 }
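Usage note (not part of the patch): a minimal sketch of how a standalone caller might drive the reworked quantization API after this change, assuming only the post-patch ggml.h declarations. It uses ggml_quantize_requires_imatrix() to detect the IQ2 types, ggml_quantize_chunk() which now initializes the lookup tables on demand, and ggml_quantize_free() at shutdown. The row length, the zero-filled input, the all-ones importance matrix, and the helper name are illustrative placeholders, not code from the repository.

#include <stdlib.h>
#include "ggml.h"

static void quantize_one_row_example(void) {
    const enum ggml_type type = GGML_TYPE_Q4_K;  // any quantized type; swap in GGML_TYPE_IQ2_XS to exercise the imatrix path
    const int n_per_row = 256;                   // must be a multiple of the type's block size
    const int nrows     = 1;

    float * src = calloc((size_t)nrows * n_per_row, sizeof(float));  // dummy input row (all zeros)
    void  * dst = malloc(ggml_row_size(type, n_per_row) * nrows);    // quantized output buffer
    int64_t hist[16] = {0};

    // IQ2_XXS/IQ2_XS refuse to quantize without an importance matrix; other types accept NULL
    const float * imatrix = NULL;
    float * ones = NULL;
    if (ggml_quantize_requires_imatrix(type)) {
        ones = malloc(n_per_row * sizeof(float));  // all-ones placeholder, not a measured imatrix
        for (int i = 0; i < n_per_row; ++i) { ones[i] = 1.0f; }
        imatrix = ones;
    }

    // ggml_quantize_chunk() calls ggml_quantize_init(type) internally, so no explicit init is needed
    size_t written = ggml_quantize_chunk(type, src, dst, 0, nrows, n_per_row, hist, imatrix);
    (void)written;

    free(ones);
    free(dst);
    free(src);

    // releases the IQ2 lookup tables if any were allocated; safe to call even when none were
    ggml_quantize_free();
}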