gguf-split : improve --split and --merge logic (#9619)

* make sure params --split and --merge are not specified at same time

* update gguf-split params parse logic

* Update examples/gguf-split/gguf-split.cpp

Co-authored-by: slaren <slarengh@gmail.com>

---------

Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
Zhenwei Jin 2024-10-02 15:21:57 +08:00 committed by GitHub
parent 148844fe97
commit 76b37d1541
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -22,12 +22,20 @@
#endif #endif
enum split_operation : uint8_t { enum split_operation : uint8_t {
SPLIT_OP_SPLIT, OP_NONE,
SPLIT_OP_MERGE, OP_SPLIT,
OP_MERGE,
};
enum split_mode : uint8_t {
MODE_NONE,
MODE_TENSOR,
MODE_SIZE,
}; };
struct split_params { struct split_params {
split_operation operation = SPLIT_OP_SPLIT; split_operation operation = OP_NONE;
split_mode mode = MODE_NONE;
size_t n_bytes_split = 0; size_t n_bytes_split = 0;
int n_split_tensors = 128; int n_split_tensors = 128;
std::string input; std::string input;
@ -87,59 +95,52 @@ static void split_params_parse_ex(int argc, const char ** argv, split_params & p
} }
bool arg_found = false; bool arg_found = false;
bool is_op_set = false;
bool is_mode_set = false;
if (arg == "-h" || arg == "--help") { if (arg == "-h" || arg == "--help") {
split_print_usage(argv[0]); split_print_usage(argv[0]);
exit(0); exit(0);
} } else if (arg == "--version") {
if (arg == "--version") {
fprintf(stderr, "version: %d (%s)\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT); fprintf(stderr, "version: %d (%s)\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
fprintf(stderr, "built with %s for %s\n", LLAMA_COMPILER, LLAMA_BUILD_TARGET); fprintf(stderr, "built with %s for %s\n", LLAMA_COMPILER, LLAMA_BUILD_TARGET);
exit(0); exit(0);
} } else if (arg == "--dry-run") {
if (arg == "--dry-run") {
arg_found = true; arg_found = true;
params.dry_run = true; params.dry_run = true;
} } else if (arg == "--no-tensor-first-split") {
if (arg == "--no-tensor-first-split") {
arg_found = true; arg_found = true;
params.no_tensor_first_split = true; params.no_tensor_first_split = true;
} } else if (arg == "--merge") {
if (is_op_set) {
throw std::invalid_argument("error: either --split or --merge can be specified, but not both");
}
if (arg == "--merge") {
arg_found = true; arg_found = true;
is_op_set = true; if (params.operation != OP_NONE && params.operation != OP_MERGE) {
params.operation = SPLIT_OP_MERGE; throw std::invalid_argument("error: either --split or --merge can be specified, but not both");
} }
if (arg == "--split") { params.operation = OP_MERGE;
} else if (arg == "--split") {
arg_found = true; arg_found = true;
is_op_set = true; if (params.operation != OP_NONE && params.operation != OP_SPLIT) {
params.operation = SPLIT_OP_SPLIT; throw std::invalid_argument("error: either --split or --merge can be specified, but not both");
} }
params.operation = OP_SPLIT;
if (is_mode_set) { } else if (arg == "--split-max-tensors") {
throw std::invalid_argument("error: either --split-max-tensors or --split-max-size can be specified, but not both");
}
if (arg == "--split-max-tensors") {
if (++arg_idx >= argc) { if (++arg_idx >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
arg_found = true; arg_found = true;
is_mode_set = true; if (params.mode != MODE_NONE && params.mode != MODE_TENSOR) {
throw std::invalid_argument("error: either --split-max-tensors or --split-max-size can be specified, but not both");
}
params.mode = MODE_TENSOR;
params.n_split_tensors = atoi(argv[arg_idx]); params.n_split_tensors = atoi(argv[arg_idx]);
} } else if (arg == "--split-max-size") {
if (arg == "--split-max-size") {
if (++arg_idx >= argc) { if (++arg_idx >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
arg_found = true; arg_found = true;
is_mode_set = true; if (params.mode != MODE_NONE && params.mode != MODE_SIZE) {
throw std::invalid_argument("error: either --split-max-tensors or --split-max-size can be specified, but not both");
}
params.mode = MODE_SIZE;
params.n_bytes_split = split_str_to_n_bytes(argv[arg_idx]); params.n_bytes_split = split_str_to_n_bytes(argv[arg_idx]);
} }
@ -148,6 +149,15 @@ static void split_params_parse_ex(int argc, const char ** argv, split_params & p
} }
} }
// the operation is split if not specified
if (params.operation == OP_NONE) {
params.operation = OP_SPLIT;
}
// the split mode is by tensor if not specified
if (params.mode == MODE_NONE) {
params.mode = MODE_TENSOR;
}
if (invalid_param) { if (invalid_param) {
throw std::invalid_argument("error: invalid parameter for argument: " + arg); throw std::invalid_argument("error: invalid parameter for argument: " + arg);
} }
@ -265,13 +275,15 @@ struct split_strategy {
} }
bool should_split(int i_tensor, size_t next_size) { bool should_split(int i_tensor, size_t next_size) {
if (params.n_bytes_split > 0) { if (params.mode == MODE_SIZE) {
// split by max size per file // split by max size per file
return next_size > params.n_bytes_split; return next_size > params.n_bytes_split;
} else { } else if (params.mode == MODE_TENSOR) {
// split by number of tensors per file // split by number of tensors per file
return i_tensor > 0 && i_tensor < n_tensors && i_tensor % params.n_split_tensors == 0; return i_tensor > 0 && i_tensor < n_tensors && i_tensor % params.n_split_tensors == 0;
} }
// should never happen
GGML_ABORT("invalid mode");
} }
void print_info() { void print_info() {
@ -559,9 +571,9 @@ int main(int argc, const char ** argv) {
split_params_parse(argc, argv, params); split_params_parse(argc, argv, params);
switch (params.operation) { switch (params.operation) {
case SPLIT_OP_SPLIT: gguf_split(params); case OP_SPLIT: gguf_split(params);
break; break;
case SPLIT_OP_MERGE: gguf_merge(params); case OP_MERGE: gguf_merge(params);
break; break;
default: split_print_usage(argv[0]); default: split_print_usage(argv[0]);
exit(EXIT_FAILURE); exit(EXIT_FAILURE);