Fix convert script, warnings alpaca instructions, default params

This commit is contained in:
Georgi Gerganov 2023-03-21 17:59:16 +02:00
parent 715d292ee0
commit 3bfa3b43b7
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
4 changed files with 23 additions and 17 deletions

View File

@ -193,15 +193,15 @@ First, download the `ggml` Alpaca model into the `./models` folder:
``` ```
# use one of these # use one of these
# TODO: add a script to simplify the download # TODO: add a script to simplify the download
curl -o ggml2-alpaca-7b-q4.bin -C - https://gateway.estuary.tech/gw/ipfs/QmUp1UGeQFDqJKvtjbSYPBiZZKRjLp8shVP9hT8ZB9Ynv1 curl -o ./models/ggml-alpaca-7b-q4.bin -C - https://gateway.estuary.tech/gw/ipfs/QmUp1UGeQFDqJKvtjbSYPBiZZKRjLp8shVP9hT8ZB9Ynv1
curl -o ggml2-alpaca-7b-q4.bin -C - https://ipfs.io/ipfs/QmUp1UGeQFDqJKvtjbSYPBiZZKRjLp8shVP9hT8ZB9Ynv1 curl -o ./models/ggml-alpaca-7b-q4.bin -C - https://ipfs.io/ipfs/QmUp1UGeQFDqJKvtjbSYPBiZZKRjLp8shVP9hT8ZB9Ynv1
curl -o ggml2-alpaca-7b-q4.bin -C - https://cloudflare-ipfs.com/ipfs/QmUp1UGeQFDqJKvtjbSYPBiZZKRjLp8shVP9hT8ZB9Ynv1 curl -o ./models/ggml-alpaca-7b-q4.bin -C - https://cloudflare-ipfs.com/ipfs/QmUp1UGeQFDqJKvtjbSYPBiZZKRjLp8shVP9hT8ZB9Ynv1
``` ```
Now run the `main` tool like this: Now run the `main` tool like this:
``` ```
./main -m ./models/ggml2-alpaca-7b-q4.bin --color -f ./prompts/alpaca.txt -ins ./main -m ./models/ggml-alpaca-7b-q4.bin --color -f ./prompts/alpaca.txt -ins
``` ```
Sample run: Sample run:

View File

@ -3,4 +3,4 @@
# Temporary script - will be removed in the future # Temporary script - will be removed in the future
# #
./main -m ./models/ggml-alpaca-7b-q4.bin --color -f ./prompts/alpaca.txt -ins --top_k 10000 --temp 0.96 --repeat_penalty 1 -t 7 ./main -m ./models/ggml-alpaca-7b-q4.bin --color -f ./prompts/alpaca.txt -ins --top_k 10000 --temp 0.2 --repeat_penalty 1 -t 7

View File

@ -28,8 +28,8 @@ def parse_args():
parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file') parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file')
parser.add_argument('dir_model', help='directory containing the model checkpoint') parser.add_argument('dir_model', help='directory containing the model checkpoint')
parser.add_argument('ftype', type=int, choices=[0, 1], default=1, help='file type (0: float32, 1: float16)') parser.add_argument('ftype', help='file type (0: float32, 1: float16)', type=int, choices=[0, 1], default=1)
parser.add_argument('vocab_only', type=bool, default=False, help='only write vocab to file') parser.add_argument('vocab_only', help='only write vocab to file', type=int, default=0, nargs='?')
return parser.parse_args() return parser.parse_args()
def get_n_parts(dim): def get_n_parts(dim):
@ -135,6 +135,8 @@ def main():
hparams, tokenizer = load_hparams_and_tokenizer(dir_model) hparams, tokenizer = load_hparams_and_tokenizer(dir_model)
print(args)
# if only writing vocab to file # if only writing vocab to file
if args.vocab_only: if args.vocab_only:

View File

@ -165,12 +165,20 @@ bool llama_model_load(const std::string & fname, llama_model & model, llama_voca
// load vocab // load vocab
{ {
std::string word; std::string word;
std::vector<char> tmp(64);
for (int i = 0; i < model.hparams.n_vocab; i++) { for (int i = 0; i < model.hparams.n_vocab; i++) {
uint32_t len; uint32_t len;
fin.read((char *) &len, sizeof(len)); fin.read((char *) &len, sizeof(len));
word.resize(len); word.resize(len);
fin.read((char *) word.data(), len); if (len > 0) {
tmp.resize(len);
fin.read(tmp.data(), len);
word.assign(tmp.data(), len);
} else {
word.clear();
}
float score; float score;
fin.read((char *) &score, sizeof(score)); fin.read((char *) &score, sizeof(score));
@ -178,10 +186,6 @@ bool llama_model_load(const std::string & fname, llama_model & model, llama_voca
vocab.token_to_id[word] = i; vocab.token_to_id[word] = i;
vocab.id_to_token[i] = word; vocab.id_to_token[i] = word;
vocab.score[i] = score; vocab.score[i] = score;
//if (i < 30000) {
// fprintf(stderr, "%s: vocab[%d] = '%s'\n", __func__, i, word.c_str());
//}
} }
} }
@ -974,7 +978,7 @@ int main(int argc, char ** argv) {
n_past += embd.size(); n_past += embd.size();
embd.clear(); embd.clear();
if (embd_inp.size() <= input_consumed) { if ((int) embd_inp.size() <= input_consumed) {
// out of user input, sample next token // out of user input, sample next token
const float top_k = params.top_k; const float top_k = params.top_k;
const float top_p = params.top_p; const float top_p = params.top_p;
@ -1011,7 +1015,7 @@ int main(int argc, char ** argv) {
--remaining_tokens; --remaining_tokens;
} else { } else {
// some user input remains from prompt or interaction, forward it to processing // some user input remains from prompt or interaction, forward it to processing
while (embd_inp.size() > input_consumed) { while ((int) embd_inp.size() > input_consumed) {
embd.push_back(embd_inp[input_consumed]); embd.push_back(embd_inp[input_consumed]);
last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(embd_inp[input_consumed]); last_n_tokens.push_back(embd_inp[input_consumed]);
@ -1036,7 +1040,7 @@ int main(int argc, char ** argv) {
// in interactive mode, and not currently processing queued inputs; // in interactive mode, and not currently processing queued inputs;
// check if we should prompt the user for more // check if we should prompt the user for more
if (params.interactive && embd_inp.size() <= input_consumed) { if (params.interactive && (int) embd_inp.size() <= input_consumed) {
// check for reverse prompt // check for reverse prompt
for (auto antiprompt_inp : antipromptv_inp) { for (auto antiprompt_inp : antipromptv_inp) {
if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) { if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) {