minor : spacing

This commit is contained in:
Georgi Gerganov 2024-03-22 15:24:57 +02:00 committed by GitHub
parent 2605c139a6
commit 12aa74ba7d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -101,7 +101,7 @@ struct TransformerWeights {
std::vector<float> wcls; std::vector<float> wcls;
}; };
static void alloc_weights(TransformerWeights* w, const Config* p, bool shared_weights) { static void alloc_weights(TransformerWeights * w, const Config * p, bool shared_weights) {
const int n_multiqueries = p->n_kv_heads <= 0 || p->n_kv_heads >= p->n_heads ? 1 : p->n_heads / p->n_kv_heads; const int n_multiqueries = p->n_kv_heads <= 0 || p->n_kv_heads >= p->n_heads ? 1 : p->n_heads / p->n_kv_heads;
try { try {
w->token_embedding_table.resize(p->vocab_size * p->dim); w->token_embedding_table.resize(p->vocab_size * p->dim);
@ -144,12 +144,12 @@ static void alloc_weights(TransformerWeights* w, const Config* p, bool shared_we
LOG("%s: Allocating [%d] x [%d] = [%d] float space for w->wcls\n",__func__,p->vocab_size , p->dim, p->vocab_size * p->dim); LOG("%s: Allocating [%d] x [%d] = [%d] float space for w->wcls\n",__func__,p->vocab_size , p->dim, p->vocab_size * p->dim);
} }
} }
catch (std::length_error&) { catch (std::length_error &) {
die("Invalid configuration. Failed to allocate memory for weights"); die("Invalid configuration. Failed to allocate memory for weights");
} }
} }
static int checkpoint_init_weights(TransformerWeights *w, const Config* p, FILE* f, bool shared_weights) { static int checkpoint_init_weights(TransformerWeights * w, const Config * p, FILE * f, bool shared_weights) {
if (fread(w->token_embedding_table.data(), sizeof(float), w->token_embedding_table.size(), f) != w->token_embedding_table.size()) return 1; if (fread(w->token_embedding_table.data(), sizeof(float), w->token_embedding_table.size(), f) != w->token_embedding_table.size()) return 1;
if (fread(w->rms_att_weight.data(), sizeof(float), w->rms_att_weight.size(), f) != w->rms_att_weight.size()) return 1; if (fread(w->rms_att_weight.data(), sizeof(float), w->rms_att_weight.size(), f) != w->rms_att_weight.size()) return 1;
if (fread(w->wq.data(), sizeof(float), w->wq.size(), f) != w->wq.size()) return 1; if (fread(w->wq.data(), sizeof(float), w->wq.size(), f) != w->wq.size()) return 1;
@ -173,7 +173,7 @@ static int checkpoint_init_weights(TransformerWeights *w, const Config* p, FILE*
fseek(f, 0, SEEK_END); fseek(f, 0, SEEK_END);
auto end = ftell(f); auto end = ftell(f);
if (curr != end) { if (curr != end) {
LOG("%s: Error: failed to read the checkpoint file to the end (curr = %ld, end = %ld)\n",__func__, curr, end); LOG("%s: Error: failed to read the checkpoint file to the end (curr = %ld, end = %ld)\n", __func__, curr, end);
return 1; return 1;
} }
@ -225,6 +225,7 @@ struct my_llama_hparams {
uint32_t n_head_kv = 32; uint32_t n_head_kv = 32;
uint32_t n_layer = 32; uint32_t n_layer = 32;
uint32_t n_rot = 64; uint32_t n_rot = 64;
bool operator!=(const my_llama_hparams& other) const { bool operator!=(const my_llama_hparams& other) const {
return memcmp(this, &other, sizeof(my_llama_hparams)); return memcmp(this, &other, sizeof(my_llama_hparams));
} }
@ -523,9 +524,9 @@ static std::string llama_escape_whitespaces(const std::string & text) {
return out.str(); return out.str();
} }
static void load_vocab(const char *filename, const Config *config, struct llama_vocab *vocab) { static void load_vocab(const char * filename, const Config * config, struct llama_vocab * vocab) {
if (is_ggml_file(filename)) { if (is_ggml_file(filename)) {
LOG("%s: Loading vocabulary from gguf file %s\n",__func__,filename); LOG("%s: Loading vocabulary from gguf file %s\n", __func__, filename);
struct ggml_context * ctx_data = NULL; struct ggml_context * ctx_data = NULL;
struct gguf_init_params params = { struct gguf_init_params params = {
@ -573,7 +574,7 @@ static void load_vocab(const char *filename, const Config *config, struct llama_
gguf_free(ctx); gguf_free(ctx);
} else { } else {
// assume llama2.c vocabulary // assume llama2.c vocabulary
LOG("%s: Assuming llama2.c vocabulary since %s is not a gguf file\n",__func__,filename); LOG("%s: Assuming llama2.c vocabulary since %s is not a gguf file\n", __func__, filename);
llama_file file(filename, "rb"); llama_file file(filename, "rb");
if (!file.fp) { if (!file.fp) {
die_fmt("%s: %s", strerror(errno), filename); die_fmt("%s: %s", strerror(errno), filename);
@ -643,6 +644,7 @@ static void save_as_llama_model(
// for rms-att-weight // for rms-att-weight
int row_length = model->hparams.n_embd; int row_length = model->hparams.n_embd;
int n_ff = model->hparams.n_ff; int n_ff = model->hparams.n_ff;
const uint32_t n_multiqueries = model->hparams.n_head_kv <= 0 || model->hparams.n_head_kv >= model->hparams.n_head ? 1 : model->hparams.n_head / model->hparams.n_head_kv; const uint32_t n_multiqueries = model->hparams.n_head_kv <= 0 || model->hparams.n_head_kv >= model->hparams.n_head ? 1 : model->hparams.n_head / model->hparams.n_head_kv;
for (uint32_t i = 0; i < model->hparams.n_layer; ++i){ for (uint32_t i = 0; i < model->hparams.n_layer; ++i){
@ -877,17 +879,26 @@ int main(int argc, char ** argv) {
Config config; Config config;
TransformerWeights weights = {}; TransformerWeights weights = {};
{ {
LOG("%s: Loading llama2c model from %s\n",__func__,params.fn_llama2c_model); LOG("%s: Loading llama2c model from %s\n", __func__, params.fn_llama2c_model);
FILE *file = fopen(params.fn_llama2c_model, "r"); FILE *file = fopen(params.fn_llama2c_model, "r");
if (!file) { LOG("%s: Unable to open the checkpoint file %s!\n",__func__,params.fn_llama2c_model); return 1; } if (!file) {
LOG("%s: Unable to open the checkpoint file %s!\n", __func__, params.fn_llama2c_model);
return 1;
}
// read in the config header // read in the config header
if (fread(&config, sizeof(Config), 1, file) != 1) { LOG("%s: Unable to read llama2c config from %s!\n",__func__,params.fn_llama2c_model); return 1; } if (fread(&config, sizeof(Config), 1, file) != 1) {
LOG("%s: Unable to read llama2c config from %s!\n",__func__,params.fn_llama2c_model);
return 1;
}
auto shared_weights = config.vocab_size > 0; auto shared_weights = config.vocab_size > 0;
config.vocab_size = abs(config.vocab_size); config.vocab_size = abs(config.vocab_size);
// read in the Transformer weights // read in the Transformer weights
alloc_weights(&weights, &config, shared_weights); alloc_weights(&weights, &config, shared_weights);
if(checkpoint_init_weights(&weights, &config, file, shared_weights)) { LOG("%s: Unable to initialize transformer weights from %s!",__func__,params.fn_llama2c_model); return 1; } if (checkpoint_init_weights(&weights, &config, file, shared_weights)) {
LOG("%s: Unable to initialize transformer weights from %s!",__func__,params.fn_llama2c_model);
return 1;
}
fclose(file); fclose(file);
} }
@ -904,7 +915,9 @@ int main(int argc, char ** argv) {
model.hparams.n_head_kv = config.n_kv_heads; model.hparams.n_head_kv = config.n_kv_heads;
model.hparams.n_layer = config.n_layers; //params.n_layer; model.hparams.n_layer = config.n_layers; //params.n_layer;
model.hparams.n_rot = std::min((uint32_t)params.n_rotmax, model.hparams.n_embd / model.hparams.n_head); model.hparams.n_rot = std::min((uint32_t)params.n_rotmax, model.hparams.n_embd / model.hparams.n_head);
print_params(&model.hparams); print_params(&model.hparams);
struct ggml_init_params lcparams; struct ggml_init_params lcparams;
lcparams.mem_size = 1024ll*1024ll*1024ll*((size_t) params.mem_model_gb); lcparams.mem_size = 1024ll*1024ll*1024ll*((size_t) params.mem_model_gb);
lcparams.mem_buffer = NULL; lcparams.mem_buffer = NULL;
@ -916,7 +929,7 @@ int main(int argc, char ** argv) {
model.name = basename(params.fn_llama2c_model); model.name = basename(params.fn_llama2c_model);
save_as_llama_model(&vocab, &model, &weights, params.fn_llama2c_output_model); save_as_llama_model(&vocab, &model, &weights, params.fn_llama2c_output_model);
LOG("%s: Saving llama.c model file %s in ggml format at %s\n",__func__, params.fn_llama2c_model, params.fn_llama2c_output_model); LOG("%s: Saving llama.c model file %s in ggml format at %s\n", __func__, params.fn_llama2c_model, params.fn_llama2c_output_model);
ggml_free(model.ctx); ggml_free(model.ctx);
return 0; return 0;