imatrix: optionally keep numbered intermediate snapshots (--keep-imatrix N)

What this patch changes
  * StatParams gains `keep_every` (default 0).
  * main() accepts `--keep-imatrix <N>` and stores std::stoi() of the next
    argument in sparams.keep_every.
  * collect_imatrix(): at both places where save_imatrix() is triggered every
    n_output_frequency calls, keep_imatrix(m_last_call) is additionally called
    when keep_every > 0 and m_last_call is a multiple of keep_every.
  * keep_imatrix(ncall) saves to "<ofile>.at_<ncall>", using "imatrix.dat" as
    the base name when ofile is empty.
  * save_imatrix() becomes a thin forwarder to a new private overload
    save_imatrix(const char *), which takes over the former function body.

Notes on this copy of the patch
  * The patch had been flattened onto two physical lines; one-diff-line-per-line
    layout and the blank context lines required by the hunk ranges are restored.
    Added/removed lines are otherwise reproduced unchanged.
  * Leading indentation inside the hunks was not recoverable and has been
    re-created from the brace nesting. If context fails to match, try
    `git apply --ignore-whitespace` -- TODO confirm against the pre-image.
  * The element types of the two std::vector context lines were missing and are
    restored here as <float> / <int> -- TODO confirm against the pre-image.

Review notes (left as in the original change)
  * The private declaration names its parameter `file_name`, the definition
    uses `fname`.
  * The bare `//` added above the two declarations carries no text.

diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp
index 5687476cd..ea06fcdbf 100644
--- a/examples/imatrix/imatrix.cpp
+++ b/examples/imatrix/imatrix.cpp
@@ -26,6 +26,7 @@ struct StatParams {
     std::string ofile = "imatrix.dat";
     int n_output_frequency = 10;
     int verbosity = 1;
+    int keep_every = 0;
     bool collect_output_weight = false;
 };
 
@@ -42,6 +43,9 @@ private:
     int m_last_call = 0;
     std::vector<float> m_src1_data;
     std::vector<int> m_ids; // the expert ids from ggml_mul_mat_id
+    //
+    void save_imatrix(const char * file_name) const;
+    void keep_imatrix(int ncall) const;
 };
 
 bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data) {
@@ -117,6 +121,9 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
                 if (m_last_call % m_params.n_output_frequency == 0) {
                     save_imatrix();
                 }
+                if (m_params.keep_every > 0 && m_last_call%m_params.keep_every == 0) {
+                    keep_imatrix(m_last_call);
+                }
             }
         }
     } else {
@@ -143,6 +150,9 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
             if (m_last_call % m_params.n_output_frequency == 0) {
                 save_imatrix();
             }
+            if (m_params.keep_every > 0 && m_last_call%m_params.keep_every == 0) {
+                keep_imatrix(m_last_call);
+            }
         }
     }
 
@@ -150,7 +160,18 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
 }
 
 void IMatrixCollector::save_imatrix() const {
-    const char * fname = m_params.ofile.empty() ? "imatrix.dat" : m_params.ofile.c_str();
+    save_imatrix(m_params.ofile.empty() ? "imatrix.dat" : m_params.ofile.c_str());
+}
+
+void IMatrixCollector::keep_imatrix(int ncall) const {
+    auto file_name = m_params.ofile;
+    if (file_name.empty()) file_name = "imatrix.dat";
+    file_name += ".at_";
+    file_name += std::to_string(ncall);
+    save_imatrix(file_name.c_str());
+}
+
+void IMatrixCollector::save_imatrix(const char * fname) const {
     std::ofstream out(fname, std::ios::binary);
     int n_entries = m_stats.size();
     out.write((const char*)&n_entries, sizeof(n_entries));
@@ -400,6 +421,8 @@ int main(int argc, char ** argv) {
             sparams.verbosity = std::stoi(argv[++iarg]);
         } else if (arg == "--no-ppl") {
             compute_ppl = false;
+        } else if (arg == "--keep-imatrix") {
+            sparams.keep_every = std::stoi(argv[++iarg]);
         } else {
             args.push_back(argv[iarg]);
         }