mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 22:59:24 +01:00
Merge branch 'master' into xsn/fix_lora
This commit is contained in:
commit
a1666aaaca
18
Makefile
18
Makefile
@ -14,6 +14,7 @@ BUILD_TARGETS = \
|
||||
llama-finetune \
|
||||
llama-gbnf-validator \
|
||||
llama-gguf \
|
||||
llama-gguf-hash \
|
||||
llama-gguf-split \
|
||||
llama-gritlm \
|
||||
llama-imatrix \
|
||||
@ -1178,6 +1179,23 @@ llama-gguf: examples/gguf/gguf.cpp \
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||
|
||||
examples/gguf-hash/deps/sha1/sha1.o: \
|
||||
examples/gguf-hash/deps/sha1/sha1.c
|
||||
$(CC) $(CFLAGS) -Iexamples/gguf-hash/deps -c $< -o $@
|
||||
|
||||
examples/gguf-hash/deps/xxhash/xxhash.o: \
|
||||
examples/gguf-hash/deps/xxhash/xxhash.c
|
||||
$(CC) $(CFLAGS) -Iexamples/gguf-hash/deps -c $< -o $@
|
||||
|
||||
examples/gguf-hash/deps/sha256/sha256.o: \
|
||||
examples/gguf-hash/deps/sha256/sha256.c
|
||||
$(CC) $(CFLAGS) -Iexamples/gguf-hash/deps -c $< -o $@
|
||||
|
||||
llama-gguf-hash: examples/gguf-hash/gguf-hash.cpp examples/gguf-hash/deps/sha1/sha1.o examples/gguf-hash/deps/xxhash/xxhash.o examples/gguf-hash/deps/sha256/sha256.o\
|
||||
$(OBJ_ALL)
|
||||
$(CXX) $(CXXFLAGS) -Iexamples/gguf-hash/deps -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||
|
||||
llama-gguf-split: examples/gguf-split/gguf-split.cpp \
|
||||
$(OBJ_ALL)
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
|
44
README.md
44
README.md
@ -131,6 +131,7 @@ Typically finetunes of the base models below are supported as well.
|
||||
- Zig: [deins/llama.cpp.zig](https://github.com/Deins/llama.cpp.zig)
|
||||
- Flutter/Dart: [netdur/llama_cpp_dart](https://github.com/netdur/llama_cpp_dart)
|
||||
- PHP (API bindings and features built on top of llama.cpp): [distantmagic/resonance](https://github.com/distantmagic/resonance) [(more info)](https://github.com/ggerganov/llama.cpp/pull/6326)
|
||||
- Guile Scheme: [guile_llama_cpp](https://savannah.nongnu.org/projects/guile-llama-cpp)
|
||||
|
||||
**UI:**
|
||||
|
||||
@ -391,28 +392,21 @@ The `grammars/` folder contains a handful of sample grammars. To write your own,
|
||||
|
||||
For authoring more complex JSON grammars, you can also check out https://grammar.intrinsiclabs.ai/, a browser app that lets you write TypeScript interfaces which it compiles to GBNF grammars that you can save for local use. Note that the app is built and maintained by members of the community, please file any issues or FRs on [its repo](http://github.com/intrinsiclabsai/gbnfgen) and not this one.
|
||||
|
||||
### Obtaining and using the Facebook LLaMA 2 model
|
||||
## Build
|
||||
|
||||
- Refer to [Facebook's LLaMA download page](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) if you want to access the model data.
|
||||
- Alternatively, if you want to save time and space, you can download already converted and quantized models from [TheBloke](https://huggingface.co/TheBloke), including:
|
||||
- [LLaMA 2 7B base](https://huggingface.co/TheBloke/Llama-2-7B-GGUF)
|
||||
- [LLaMA 2 13B base](https://huggingface.co/TheBloke/Llama-2-13B-GGUF)
|
||||
- [LLaMA 2 70B base](https://huggingface.co/TheBloke/Llama-2-70B-GGUF)
|
||||
- [LLaMA 2 7B chat](https://huggingface.co/TheBloke/Llama-2-7B-chat-GGUF)
|
||||
- [LLaMA 2 13B chat](https://huggingface.co/TheBloke/Llama-2-13B-chat-GGUF)
|
||||
- [LLaMA 2 70B chat](https://huggingface.co/TheBloke/Llama-2-70B-chat-GGUF)
|
||||
Please refer to [Build llama.cpp locally](./docs/build.md)
|
||||
|
||||
### Seminal papers and background on the models
|
||||
## Supported backends
|
||||
|
||||
If your issue is with model generation quality, then please at least scan the following links and papers to understand the limitations of LLaMA models. This is especially important when choosing an appropriate model size and appreciating both the significant and subtle differences between LLaMA models and ChatGPT:
|
||||
- LLaMA:
|
||||
- [Introducing LLaMA: A foundational, 65-billion-parameter large language model](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/)
|
||||
- [LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971)
|
||||
- GPT-3
|
||||
- [Language Models are Few-Shot Learners](https://arxiv.org/abs/2005.14165)
|
||||
- GPT-3.5 / InstructGPT / ChatGPT:
|
||||
- [Aligning language models to follow instructions](https://openai.com/research/instruction-following)
|
||||
- [Training language models to follow instructions with human feedback](https://arxiv.org/abs/2203.02155)
|
||||
| Backend | Target devices |
|
||||
| --- | --- |
|
||||
| [Metal](./docs/build.md#metal-build) | Apple Silicon |
|
||||
| [BLAS](./docs/build.md#blas-build) | All |
|
||||
| [BLIS](./docs/backend/BLIS.md) | All |
|
||||
| [SYCL](./docs/backend/SYCL.md) | Intel and Nvidia GPU |
|
||||
| [CUDA](./docs/build.md#cuda) | Nvidia GPU |
|
||||
| [hipBLAS](./docs/build.md#hipblas) | AMD GPU |
|
||||
| [Vulkan](./docs/build.md#vulkan) | GPU |
|
||||
|
||||
## Tools
|
||||
|
||||
@ -460,3 +454,15 @@ To learn more how to measure perplexity using llama.cpp, [read this documentatio
|
||||
- [Build on Android](./docs/android.md)
|
||||
- [Performance troubleshooting](./docs/token_generation_performance_tips.md)
|
||||
- [GGML tips & tricks](https://github.com/ggerganov/llama.cpp/wiki/GGML-Tips-&-Tricks)
|
||||
|
||||
**Seminal papers and background on the models**
|
||||
|
||||
If your issue is with model generation quality, then please at least scan the following links and papers to understand the limitations of LLaMA models. This is especially important when choosing an appropriate model size and appreciating both the significant and subtle differences between LLaMA models and ChatGPT:
|
||||
- LLaMA:
|
||||
- [Introducing LLaMA: A foundational, 65-billion-parameter large language model](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/)
|
||||
- [LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971)
|
||||
- GPT-3
|
||||
- [Language Models are Few-Shot Learners](https://arxiv.org/abs/2005.14165)
|
||||
- GPT-3.5 / InstructGPT / ChatGPT:
|
||||
- [Aligning language models to follow instructions](https://openai.com/research/instruction-following)
|
||||
- [Training language models to follow instructions with human feedback](https://arxiv.org/abs/2203.02155)
|
||||
|
@ -190,6 +190,12 @@ int32_t cpu_get_num_math() {
|
||||
// CLI argument parsing
|
||||
//
|
||||
|
||||
void gpt_params_handle_hf_token(gpt_params & params) {
|
||||
if (params.hf_token.empty() && std::getenv("HF_TOKEN")) {
|
||||
params.hf_token = std::getenv("HF_TOKEN");
|
||||
}
|
||||
}
|
||||
|
||||
void gpt_params_handle_model_default(gpt_params & params) {
|
||||
if (!params.hf_repo.empty()) {
|
||||
// short-hand to avoid specifying --hf-file -> default it to --model
|
||||
@ -237,6 +243,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
|
||||
|
||||
gpt_params_handle_model_default(params);
|
||||
|
||||
gpt_params_handle_hf_token(params);
|
||||
|
||||
if (params.escape) {
|
||||
string_process_escapes(params.prompt);
|
||||
string_process_escapes(params.input_prefix);
|
||||
@ -652,6 +660,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||
params.model_url = argv[i];
|
||||
return true;
|
||||
}
|
||||
if (arg == "-hft" || arg == "--hf-token") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
return true;
|
||||
}
|
||||
params.hf_token = argv[i];
|
||||
return true;
|
||||
}
|
||||
if (arg == "-hfr" || arg == "--hf-repo") {
|
||||
CHECK_ARG
|
||||
params.hf_repo = argv[i];
|
||||
@ -1576,6 +1592,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
|
||||
options.push_back({ "*", "-mu, --model-url MODEL_URL", "model download url (default: unused)" });
|
||||
options.push_back({ "*", "-hfr, --hf-repo REPO", "Hugging Face model repository (default: unused)" });
|
||||
options.push_back({ "*", "-hff, --hf-file FILE", "Hugging Face model file (default: unused)" });
|
||||
options.push_back({ "*", "-hft, --hf-token TOKEN", "Hugging Face access token (default: value from HF_TOKEN environment variable)" });
|
||||
|
||||
options.push_back({ "retrieval" });
|
||||
options.push_back({ "retrieval", " --context-file FNAME", "file to load context from (repeat to specify multiple files)" });
|
||||
@ -2015,9 +2032,9 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
|
||||
llama_model * model = nullptr;
|
||||
|
||||
if (!params.hf_repo.empty() && !params.hf_file.empty()) {
|
||||
model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), mparams);
|
||||
model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams);
|
||||
} else if (!params.model_url.empty()) {
|
||||
model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), mparams);
|
||||
model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams);
|
||||
} else {
|
||||
model = llama_load_model_from_file(params.model.c_str(), mparams);
|
||||
}
|
||||
@ -2200,7 +2217,7 @@ static bool starts_with(const std::string & str, const std::string & prefix) {
|
||||
return str.rfind(prefix, 0) == 0;
|
||||
}
|
||||
|
||||
static bool llama_download_file(const std::string & url, const std::string & path) {
|
||||
static bool llama_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
|
||||
|
||||
// Initialize libcurl
|
||||
std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
|
||||
@ -2215,6 +2232,15 @@ static bool llama_download_file(const std::string & url, const std::string & pat
|
||||
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
|
||||
curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
|
||||
|
||||
// Check if hf-token or bearer-token was specified
|
||||
if (!hf_token.empty()) {
|
||||
std::string auth_header = "Authorization: Bearer ";
|
||||
auth_header += hf_token.c_str();
|
||||
struct curl_slist *http_headers = NULL;
|
||||
http_headers = curl_slist_append(http_headers, auth_header.c_str());
|
||||
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers);
|
||||
}
|
||||
|
||||
#if defined(_WIN32)
|
||||
// CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
|
||||
// operating system. Currently implemented under MS-Windows.
|
||||
@ -2410,6 +2436,7 @@ static bool llama_download_file(const std::string & url, const std::string & pat
|
||||
struct llama_model * llama_load_model_from_url(
|
||||
const char * model_url,
|
||||
const char * path_model,
|
||||
const char * hf_token,
|
||||
const struct llama_model_params & params) {
|
||||
// Basic validation of the model_url
|
||||
if (!model_url || strlen(model_url) == 0) {
|
||||
@ -2417,7 +2444,7 @@ struct llama_model * llama_load_model_from_url(
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (!llama_download_file(model_url, path_model)) {
|
||||
if (!llama_download_file(model_url, path_model, hf_token)) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
@ -2465,14 +2492,14 @@ struct llama_model * llama_load_model_from_url(
|
||||
// Prepare download in parallel
|
||||
std::vector<std::future<bool>> futures_download;
|
||||
for (int idx = 1; idx < n_split; idx++) {
|
||||
futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split](int download_idx) -> bool {
|
||||
futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split, hf_token](int download_idx) -> bool {
|
||||
char split_path[PATH_MAX] = {0};
|
||||
llama_split_path(split_path, sizeof(split_path), split_prefix, download_idx, n_split);
|
||||
|
||||
char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0};
|
||||
llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split);
|
||||
|
||||
return llama_download_file(split_url, split_path);
|
||||
return llama_download_file(split_url, split_path, hf_token);
|
||||
}, idx));
|
||||
}
|
||||
|
||||
@ -2491,6 +2518,7 @@ struct llama_model * llama_load_model_from_hf(
|
||||
const char * repo,
|
||||
const char * model,
|
||||
const char * path_model,
|
||||
const char * hf_token,
|
||||
const struct llama_model_params & params) {
|
||||
// construct hugging face model url:
|
||||
//
|
||||
@ -2506,7 +2534,7 @@ struct llama_model * llama_load_model_from_hf(
|
||||
model_url += "/resolve/main/";
|
||||
model_url += model;
|
||||
|
||||
return llama_load_model_from_url(model_url.c_str(), path_model, params);
|
||||
return llama_load_model_from_url(model_url.c_str(), path_model, hf_token, params);
|
||||
}
|
||||
|
||||
#else
|
||||
@ -2514,6 +2542,7 @@ struct llama_model * llama_load_model_from_hf(
|
||||
struct llama_model * llama_load_model_from_url(
|
||||
const char * /*model_url*/,
|
||||
const char * /*path_model*/,
|
||||
const char * /*hf_token*/,
|
||||
const struct llama_model_params & /*params*/) {
|
||||
fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
|
||||
return nullptr;
|
||||
@ -2523,6 +2552,7 @@ struct llama_model * llama_load_model_from_hf(
|
||||
const char * /*repo*/,
|
||||
const char * /*model*/,
|
||||
const char * /*path_model*/,
|
||||
const char * /*hf_token*/,
|
||||
const struct llama_model_params & /*params*/) {
|
||||
fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
|
||||
return nullptr;
|
||||
|
@ -108,6 +108,7 @@ struct gpt_params {
|
||||
std::string model_draft = ""; // draft model for speculative decoding
|
||||
std::string model_alias = "unknown"; // model alias
|
||||
std::string model_url = ""; // model url to download
|
||||
std::string hf_token = ""; // HF token
|
||||
std::string hf_repo = ""; // HF repo
|
||||
std::string hf_file = ""; // HF file
|
||||
std::string prompt = "";
|
||||
@ -256,6 +257,7 @@ struct gpt_params {
|
||||
bool spm_infill = false; // suffix/prefix/middle pattern for infill
|
||||
};
|
||||
|
||||
void gpt_params_handle_hf_token(gpt_params & params);
|
||||
void gpt_params_handle_model_default(gpt_params & params);
|
||||
|
||||
bool gpt_params_parse_ex (int argc, char ** argv, gpt_params & params);
|
||||
@ -311,8 +313,8 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
|
||||
struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params);
|
||||
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
|
||||
|
||||
struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const struct llama_model_params & params);
|
||||
struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const struct llama_model_params & params);
|
||||
struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const char * hf_token, const struct llama_model_params & params);
|
||||
struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const char * hf_token, const struct llama_model_params & params);
|
||||
|
||||
// Batch utils
|
||||
|
||||
|
@ -487,6 +487,9 @@ class Model:
|
||||
if chkhsh == "7967bfa498ade6b757b064f31e964dddbb80f8f9a4d68d4ba7998fcf281c531a":
|
||||
# ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-code
|
||||
res = "jina-v2-code"
|
||||
if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b":
|
||||
# ref: https://huggingface.co/THUDM/glm-4-9b-chat
|
||||
res = "chatglm-bpe"
|
||||
if chkhsh == "7fc505bd3104ca1083b150b17d088b59534ede9bde81f0dd2090967d7fe52cee":
|
||||
# ref: https://huggingface.co/LumiOpen/Viking-7B
|
||||
res = "viking"
|
||||
@ -3176,6 +3179,190 @@ class JaisModel(Model):
|
||||
self.gguf_writer.add_max_alibi_bias(self.max_alibi_bias)
|
||||
|
||||
|
||||
@Model.register("ChatGLMModel", "ChatGLMForConditionalGeneration")
|
||||
class ChatGLMModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.CHATGLM
|
||||
|
||||
def set_vocab_chatglm3(self):
|
||||
dir_model = self.dir_model
|
||||
hparams = self.hparams
|
||||
tokens: list[bytearray] = []
|
||||
toktypes: list[int] = []
|
||||
scores: list[float] = []
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
|
||||
vocab_size = hparams.get("padded_vocab_size", len(tokenizer.get_vocab()))
|
||||
assert max(tokenizer.get_vocab().values()) < vocab_size
|
||||
role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
|
||||
special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
|
||||
for token_id in range(vocab_size):
|
||||
piece = tokenizer._convert_id_to_token(token_id)
|
||||
if token_id == 0:
|
||||
piece = "<unk>"
|
||||
elif token_id == 1:
|
||||
piece = "<bos>"
|
||||
elif token_id == 2:
|
||||
piece = "<eos>"
|
||||
|
||||
text = piece.encode("utf-8")
|
||||
score = 0.0
|
||||
# Referencing the tokenizer Python implementation(https://huggingface.co/THUDM/chatglm3-6b/blob/main/tokenization_chatglm.py),
|
||||
# it is only valid if it is less than tokenizer.tokenizer.sp_model.vocab_size()
|
||||
if len(piece) != 0 and token_id < tokenizer.tokenizer.sp_model.vocab_size():
|
||||
score = tokenizer.tokenizer.sp_model.get_score(token_id)
|
||||
|
||||
if len(piece) == 0:
|
||||
text = f"[PAD{token_id}]".encode("utf-8")
|
||||
|
||||
if token_id >= tokenizer.tokenizer.sp_model.vocab_size():
|
||||
if piece in special_tokens:
|
||||
# show special tokens in prompt
|
||||
toktype = SentencePieceTokenTypes.USER_DEFINED
|
||||
else:
|
||||
toktype = SentencePieceTokenTypes.UNKNOWN
|
||||
tokens.append(text)
|
||||
scores.append(score)
|
||||
toktypes.append(toktype)
|
||||
continue
|
||||
|
||||
toktype = SentencePieceTokenTypes.NORMAL
|
||||
if tokenizer.tokenizer.sp_model.is_unknown(token_id):
|
||||
toktype = SentencePieceTokenTypes.UNKNOWN
|
||||
elif tokenizer.tokenizer.sp_model.is_control(token_id):
|
||||
toktype = SentencePieceTokenTypes.CONTROL
|
||||
elif tokenizer.tokenizer.sp_model.is_unused(token_id):
|
||||
toktype = SentencePieceTokenTypes.UNUSED
|
||||
elif tokenizer.tokenizer.sp_model.is_byte(token_id):
|
||||
toktype = SentencePieceTokenTypes.BYTE
|
||||
|
||||
tokens.append(text)
|
||||
scores.append(score)
|
||||
toktypes.append(toktype)
|
||||
|
||||
self.gguf_writer.add_tokenizer_model("llama")
|
||||
# glm3 needs prefix and suffix formatted as:
|
||||
# prompt = "[gMASK]sop<|user|>\n" + prompt + "<|assistant|>"
|
||||
self.gguf_writer.add_tokenizer_pre("chatglm-spm")
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_scores(scores)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
@staticmethod
|
||||
def token_bytes_to_string(b):
|
||||
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
|
||||
byte_encoder = bytes_to_unicode()
|
||||
return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')])
|
||||
|
||||
@staticmethod
|
||||
def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]:
|
||||
parts = [bytes([b]) for b in token]
|
||||
while True:
|
||||
min_idx = None
|
||||
min_rank = None
|
||||
for i, pair in enumerate(zip(parts[:-1], parts[1:])):
|
||||
rank = mergeable_ranks.get(pair[0] + pair[1])
|
||||
if rank is not None and (min_rank is None or rank < min_rank):
|
||||
min_idx = i
|
||||
min_rank = rank
|
||||
if min_rank is None or (max_rank is not None and min_rank >= max_rank):
|
||||
break
|
||||
assert min_idx is not None
|
||||
parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
|
||||
return parts
|
||||
|
||||
def set_vocab(self):
|
||||
if "THUDM/chatglm3-6b" in self.hparams.get("_name_or_path", ""):
|
||||
self.set_vocab_chatglm3()
|
||||
return
|
||||
|
||||
dir_model = self.dir_model
|
||||
hparams = self.hparams
|
||||
tokens: list[str] = []
|
||||
toktypes: list[int] = []
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
|
||||
vocab_size = hparams["padded_vocab_size"]
|
||||
assert max(tokenizer.get_vocab().values()) < vocab_size
|
||||
|
||||
tokpre = self.get_vocab_base_pre(tokenizer)
|
||||
|
||||
merges = []
|
||||
vocab = {}
|
||||
mergeable_ranks = tokenizer.mergeable_ranks
|
||||
for token, rank in mergeable_ranks.items():
|
||||
vocab[ChatGLMModel.token_bytes_to_string(token)] = rank
|
||||
if len(token) == 1:
|
||||
continue
|
||||
merged = ChatGLMModel.bpe(mergeable_ranks, token, max_rank=rank)
|
||||
assert len(merged) >= 2 and len(merged) <= 7
|
||||
merges.append(' '.join(map(ChatGLMModel.token_bytes_to_string, merged)))
|
||||
|
||||
# for this kind of tokenizer, added_vocab is not a subset of vocab, so they need to be combined
|
||||
added_vocab = tokenizer.get_added_vocab()
|
||||
reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **added_vocab}.items()}
|
||||
|
||||
for i in range(vocab_size):
|
||||
if i not in reverse_vocab:
|
||||
tokens.append(f"[PAD{i}]")
|
||||
toktypes.append(gguf.TokenType.USER_DEFINED)
|
||||
elif reverse_vocab[i] in added_vocab:
|
||||
tokens.append(reverse_vocab[i])
|
||||
if tokenizer.added_tokens_decoder[i].special:
|
||||
toktypes.append(gguf.TokenType.CONTROL)
|
||||
else:
|
||||
toktypes.append(gguf.TokenType.USER_DEFINED)
|
||||
else:
|
||||
tokens.append(reverse_vocab[i])
|
||||
toktypes.append(gguf.TokenType.NORMAL)
|
||||
|
||||
self.gguf_writer.add_tokenizer_model("gpt2")
|
||||
self.gguf_writer.add_tokenizer_pre(tokpre)
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
|
||||
special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
|
||||
special_vocab.merges = merges
|
||||
# only add special tokens when they were not already loaded from config.json
|
||||
special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
|
||||
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])
|
||||
# this one is usually not in config.json anyway
|
||||
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
self.gguf_writer.add_name(self.hparams.get("_name_or_path").split("/")[1]) # THUDM/glm4-9b-chat or THUDM/chatglm3-6b
|
||||
n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
|
||||
n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
|
||||
n_head_kv = self.hparams.get("multi_query_group_num", n_head)
|
||||
self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
|
||||
self.gguf_writer.add_embedding_length(n_embed)
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", 4 * n_embed))
|
||||
self.gguf_writer.add_block_count(self.hparams["num_layers"])
|
||||
self.gguf_writer.add_head_count(n_head)
|
||||
self.gguf_writer.add_head_count_kv(n_head_kv)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layernorm_epsilon"])
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
self.gguf_writer.add_rope_dimension_count(64)
|
||||
self.gguf_writer.add_add_bos_token(False)
|
||||
rope_freq = 10000
|
||||
if "rope_ratio" in self.hparams:
|
||||
rope_freq = rope_freq * self.hparams["rope_ratio"]
|
||||
self.gguf_writer.add_rope_freq_base(rope_freq)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
del bid # unused
|
||||
|
||||
if name.endswith(".rotary_pos_emb.inv_freq"):
|
||||
return []
|
||||
|
||||
name = name.removeprefix("transformer.")
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
###### CONVERSION LOGIC ######
|
||||
|
||||
|
||||
|
@ -85,7 +85,7 @@ Building the program with BLAS support may lead to some performance improvements
|
||||
|
||||
### Accelerate Framework:
|
||||
|
||||
This is only available on Mac PCs and it's enabled by default. You can just build using the normal instructions.
|
||||
This is only available on Mac PCs and it's enabled by default. You can just build using the normal instructions.
|
||||
|
||||
### OpenBLAS:
|
||||
|
||||
|
@ -23,6 +23,7 @@ else()
|
||||
add_subdirectory(export-lora)
|
||||
add_subdirectory(finetune)
|
||||
add_subdirectory(gbnf-validator)
|
||||
add_subdirectory(gguf-hash)
|
||||
add_subdirectory(gguf-split)
|
||||
add_subdirectory(gguf)
|
||||
add_subdirectory(gritlm)
|
||||
|
@ -87,4 +87,4 @@ The LORA rank can be configured for each model tensor type separately with these
|
||||
|
||||
The LORA rank of 'norm' tensors should always be 1.
|
||||
|
||||
To see all available options use `finetune --help`.
|
||||
To see all available options use `llama-finetune --help`.
|
||||
|
@ -8,7 +8,7 @@ if [[ ! $LLAMA_MODEL_DIR ]]; then LLAMA_MODEL_DIR="./models"; fi
|
||||
if [[ ! $LLAMA_TRAINING_DIR ]]; then LLAMA_TRAINING_DIR="."; fi
|
||||
|
||||
# MODEL="$LLAMA_MODEL_DIR/openllama-3b-v2-q8_0.gguf" # This is the model the readme uses.
|
||||
MODEL="$LLAMA_MODEL_DIR/openllama-3b-v2.gguf" # An f16 model. Note in this case with "-g", you get an f32-format .BIN file that isn't yet supported if you use it with "main --lora" with GPU inferencing.
|
||||
MODEL="$LLAMA_MODEL_DIR/openllama-3b-v2.gguf" # An f16 model. Note in this case with "-g", you get an f32-format .BIN file that isn't yet supported if you use it with "llama-cli --lora" with GPU inferencing.
|
||||
|
||||
while getopts "dg" opt; do
|
||||
case $opt in
|
||||
|
15
examples/gguf-hash/CMakeLists.txt
Normal file
15
examples/gguf-hash/CMakeLists.txt
Normal file
@ -0,0 +1,15 @@
|
||||
set(TARGET llama-gguf-hash)
|
||||
add_executable(${TARGET} gguf-hash.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
|
||||
# clibs dependencies
|
||||
include_directories(deps/)
|
||||
add_library(xxhash OBJECT deps/xxhash/xxhash.c deps/xxhash/xxhash.h)
|
||||
target_link_libraries(${TARGET} PRIVATE xxhash)
|
||||
add_library(sha1 OBJECT deps/sha1/sha1.c deps/sha1/sha1.h)
|
||||
target_link_libraries(${TARGET} PRIVATE sha1)
|
||||
add_library(sha256 OBJECT deps/sha256/sha256.c deps/sha256/sha256.h)
|
||||
target_link_libraries(${TARGET} PRIVATE sha256)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE ggml ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
206
examples/gguf-hash/README.md
Normal file
206
examples/gguf-hash/README.md
Normal file
@ -0,0 +1,206 @@
|
||||
|
||||
# llama-gguf-hash
|
||||
|
||||
CLI to hash GGUF files to detect difference on a per model and per tensor level.
|
||||
|
||||
**Command line options:**
|
||||
|
||||
- `--help`: display help message
|
||||
- `--xxh64`: use xhash 64bit hash mode (default)
|
||||
- `--sha1`: use sha1
|
||||
- `--uuid`: use uuid
|
||||
- `--sha256`: use sha256
|
||||
- `--all`: use all hash
|
||||
- `--no-layer`: exclude per layer hash
|
||||
- `--uuid`: generate UUIDv5 ID
|
||||
- `-c`, `--check <manifest>`: verify against a manifest
|
||||
|
||||
## About
|
||||
|
||||
While most POSIX systems already have hash checking programs like sha256sum, it
|
||||
is designed to check entire files. This is not ideal for our purpose if we want
|
||||
to check for consistency of the tensor data even if the metadata content of the
|
||||
gguf KV store has been updated.
|
||||
|
||||
This program is designed to hash a gguf tensor payload on a 'per tensor layer'
|
||||
in addition to a 'entire tensor model' hash. The intent is that the entire
|
||||
tensor layer can be checked first but if there is any detected inconsistencies,
|
||||
then the per tensor hash can be used to narrow down the specific tensor layer
|
||||
that has inconsistencies.
|
||||
|
||||
For Maintainers:
|
||||
- Detection of tensor inconsistency during development and automated tests
|
||||
- This is served by xxh64 which is fast
|
||||
- This is also served by having per tensor layer to assist in narrowing down
|
||||
the location of the faulty tensor layer
|
||||
- This is also served by sha1 which is much slower but more widely supported
|
||||
|
||||
For Model Creators:
|
||||
- Optional consistent UUID generation based on model tensor content
|
||||
- This is served by UUIDv5 which is useful for databases keys
|
||||
- llama.cpp UUIDv5 Namespace: `ef001206-dadc-5f6d-a15f-3359e577d4e5`
|
||||
- Made via UUIDv5 URL namespace of `en.wikipedia.org/wiki/Llama.cpp`
|
||||
|
||||
For Model Users:
|
||||
- Assurance of tensor layer integrity even if metadata was updated
|
||||
- This is served by sha256 which is still considered very secure as of 2024
|
||||
|
||||
### Design Note
|
||||
|
||||
- The default behavior of this program if no arguments is provided is to hash
|
||||
using xxhash's xxh32 mode because it is very fast and is primarily targeted
|
||||
towards maintainers who may want to use this in automated tests.
|
||||
- xxhash support xxh32 and xxh128 for 32bit hash and 128bit hash respectively
|
||||
however we picked 64bit xxhash as most computers are 64bit as of 2024 and thus
|
||||
would have a better affinity to calculating hash that is 64bit in size.
|
||||
|
||||
## Compile Example
|
||||
|
||||
```bash
|
||||
cmake -B build -DCMAKE_BUILD_TYPE=Debug -DLLAMA_FATAL_WARNINGS=ON
|
||||
make -C build clean
|
||||
make -C build llama-gguf-hash VERBOSE=1
|
||||
./build/bin/llama-gguf-hash test.gguf
|
||||
./build/bin/llama-gguf-hash --xxh64 test.gguf
|
||||
./build/bin/llama-gguf-hash --sha1 test.gguf
|
||||
./build/bin/llama-gguf-hash --uuid test.gguf
|
||||
./build/bin/llama-gguf-hash --sha256 test.gguf
|
||||
```
|
||||
|
||||
## Generation and Verification Example
|
||||
|
||||
To generate we may use this command
|
||||
|
||||
```bash
|
||||
./llama-gguf-hash --all test.gguf > test.gguf.manifest
|
||||
```
|
||||
|
||||
Which would generate a manifest that looks like below, which contains multiple hash type and per tensor layer hashes as well
|
||||
(This excludes UUID as that is an ID not a hash)
|
||||
|
||||
```bash
|
||||
xxh64 f66e9cd66a4396a0 test.gguf:tensor_0
|
||||
sha1 59f79ecefd8125a996fdf419239051a7e99e5f20 test.gguf:tensor_0
|
||||
sha256 c0510d38fa060c46265e0160a85c7243096b01dd31c2f355bdbb5516b20de1bd test.gguf:tensor_0
|
||||
xxh64 7d3a1f9ac04d0537 test.gguf:tensor_1
|
||||
sha1 4765f592eacf096df4628ba59476af94d767080a test.gguf:tensor_1
|
||||
sha256 8514cbcc73692a2c56bd7a33a022edd5ff819614bd23b19915d7224387f397a7 test.gguf:tensor_1
|
||||
xxh64 a0af5d700049693b test.gguf:tensor_2
|
||||
sha1 25cbfbad4513cc348e2c95ebdee69d6ff2fd8753 test.gguf:tensor_2
|
||||
sha256 947e6b36e20f2cc95e1d2ce1c1669d813d574657ac6b5ac5196158d454d35180 test.gguf:tensor_2
|
||||
xxh64 e83fddf559d7b6a6 test.gguf:tensor_3
|
||||
sha1 a9cba73e2d90f2ee3dae2548caa42bef3fe6a96c test.gguf:tensor_3
|
||||
sha256 423b044e016d8ac73c39f23f60bf01bedef5ecb03c0230accd824c91fe86f1a1 test.gguf:tensor_3
|
||||
xxh64 1257733306b7992d test.gguf:tensor_4
|
||||
sha1 d7bc61db93bb685ce9d598da89717c66729b7543 test.gguf:tensor_4
|
||||
sha256 79737cb3912d4201384cf7f16a1a37ff7823f23ea796cb205b6ca361ab9e3ebf test.gguf:tensor_4
|
||||
xxh64 d238d16ba4711e58 test.gguf:tensor_5
|
||||
sha1 0706566c198fe1072f37e0a5135b4b5f23654c52 test.gguf:tensor_5
|
||||
sha256 60949be8298eced0ecdde64487643d018407bd261691e061d9e9c3dbc9fd358b test.gguf:tensor_5
|
||||
xxh64 3fbc3b65ab8c7f39 test.gguf:tensor_6
|
||||
sha1 73922a0727226a409049f6fc3172a52219ca6f00 test.gguf:tensor_6
|
||||
sha256 574f4c46ff384a3b9a225eb955d2a871847a2e8b3fa59387a8252832e92ef7b0 test.gguf:tensor_6
|
||||
xxh64 c22021c29854f093 test.gguf:tensor_7
|
||||
sha1 efc39cece6a951188fc41e354c73bbfe6813d447 test.gguf:tensor_7
|
||||
sha256 4c0410cd3c500f078ae5b21e8dc9eb79e29112713b2ab58a882f82a3868d4d75 test.gguf:tensor_7
|
||||
xxh64 936df61f5d64261f test.gguf:tensor_8
|
||||
sha1 c2490296d789a4f34398a337fed8377d943d9f06 test.gguf:tensor_8
|
||||
sha256 c4401313feeba0261275c3b25bd2d8fe40ce04e0f440c2980ed0e9674c30ff01 test.gguf:tensor_8
|
||||
xxh64 93fd20c64421c081 test.gguf:tensor_9
|
||||
sha1 7047ce1e78437a6884337a3751c7ee0421918a65 test.gguf:tensor_9
|
||||
sha256 23d57cf0d7a6e90b0b3616b41300e0cd354781e812add854a5f95aa55f2bc514 test.gguf:tensor_9
|
||||
xxh64 5a54d3aad816f302 test.gguf
|
||||
sha1 d15be52c4ff213e823cb6dd13af7ee2f978e7042 test.gguf
|
||||
sha256 7dd641b32f59b60dbd4b5420c4b0f6321ccf48f58f6ae201a3dbc4a58a27c6e4 test.gguf
|
||||
```
|
||||
|
||||
We can then use the normal check command which will by default check for the highest security strength hash and verify against that:
|
||||
|
||||
```bash
|
||||
$ ./llama-gguf-hash --check test.gguf.manifest test.gguf
|
||||
manifest test.gguf.manifest sha256 sha1 xxh64
|
||||
sha256 c0510d38fa060c46265e0160a85c7243096b01dd31c2f355bdbb5516b20de1bd test.gguf:tensor_0 - Ok
|
||||
sha256 8514cbcc73692a2c56bd7a33a022edd5ff819614bd23b19915d7224387f397a7 test.gguf:tensor_1 - Ok
|
||||
sha256 947e6b36e20f2cc95e1d2ce1c1669d813d574657ac6b5ac5196158d454d35180 test.gguf:tensor_2 - Ok
|
||||
sha256 423b044e016d8ac73c39f23f60bf01bedef5ecb03c0230accd824c91fe86f1a1 test.gguf:tensor_3 - Ok
|
||||
sha256 79737cb3912d4201384cf7f16a1a37ff7823f23ea796cb205b6ca361ab9e3ebf test.gguf:tensor_4 - Ok
|
||||
sha256 60949be8298eced0ecdde64487643d018407bd261691e061d9e9c3dbc9fd358b test.gguf:tensor_5 - Ok
|
||||
sha256 574f4c46ff384a3b9a225eb955d2a871847a2e8b3fa59387a8252832e92ef7b0 test.gguf:tensor_6 - Ok
|
||||
sha256 4c0410cd3c500f078ae5b21e8dc9eb79e29112713b2ab58a882f82a3868d4d75 test.gguf:tensor_7 - Ok
|
||||
sha256 c4401313feeba0261275c3b25bd2d8fe40ce04e0f440c2980ed0e9674c30ff01 test.gguf:tensor_8 - Ok
|
||||
sha256 23d57cf0d7a6e90b0b3616b41300e0cd354781e812add854a5f95aa55f2bc514 test.gguf:tensor_9 - Ok
|
||||
sha256 7dd641b32f59b60dbd4b5420c4b0f6321ccf48f58f6ae201a3dbc4a58a27c6e4 test.gguf - Ok
|
||||
|
||||
Verification results for test.gguf.manifest - Success
|
||||
```
|
||||
|
||||
Or we may explicitly ask for a faster hash like:
|
||||
|
||||
```bash
|
||||
$ ./llama-gguf-hash --check test.gguf.manifest --xxh64 test.gguf
|
||||
manifest test.gguf.manifest sha256 sha1 xxh64
|
||||
xxh64 f66e9cd66a4396a0 test.gguf:tensor_0 - Ok
|
||||
xxh64 7d3a1f9ac04d0537 test.gguf:tensor_1 - Ok
|
||||
xxh64 a0af5d700049693b test.gguf:tensor_2 - Ok
|
||||
xxh64 e83fddf559d7b6a6 test.gguf:tensor_3 - Ok
|
||||
xxh64 1257733306b7992d test.gguf:tensor_4 - Ok
|
||||
xxh64 d238d16ba4711e58 test.gguf:tensor_5 - Ok
|
||||
xxh64 3fbc3b65ab8c7f39 test.gguf:tensor_6 - Ok
|
||||
xxh64 c22021c29854f093 test.gguf:tensor_7 - Ok
|
||||
xxh64 936df61f5d64261f test.gguf:tensor_8 - Ok
|
||||
xxh64 93fd20c64421c081 test.gguf:tensor_9 - Ok
|
||||
xxh64 5a54d3aad816f302 test.gguf - Ok
|
||||
|
||||
Verification results for test.gguf.manifest - Success
|
||||
```
|
||||
|
||||
Or maybe we want to just check that all the hash is valid:
|
||||
|
||||
```bash
|
||||
$./llama-gguf-hash --check test.gguf.manifest --all test.gguf.manifest
|
||||
manifest test.gguf.manifest sha256 sha1 xxh64
|
||||
xxh64 f66e9cd66a4396a0 test.gguf:tensor_0 - Ok
|
||||
sha1 59f79ecefd8125a996fdf419239051a7e99e5f20 test.gguf:tensor_0 - Ok
|
||||
sha256 c0510d38fa060c46265e0160a85c7243096b01dd31c2f355bdbb5516b20de1bd test.gguf:tensor_0 - Ok
|
||||
xxh64 7d3a1f9ac04d0537 test.gguf:tensor_1 - Ok
|
||||
sha1 4765f592eacf096df4628ba59476af94d767080a test.gguf:tensor_1 - Ok
|
||||
sha256 8514cbcc73692a2c56bd7a33a022edd5ff819614bd23b19915d7224387f397a7 test.gguf:tensor_1 - Ok
|
||||
xxh64 a0af5d700049693b test.gguf:tensor_2 - Ok
|
||||
sha1 25cbfbad4513cc348e2c95ebdee69d6ff2fd8753 test.gguf:tensor_2 - Ok
|
||||
sha256 947e6b36e20f2cc95e1d2ce1c1669d813d574657ac6b5ac5196158d454d35180 test.gguf:tensor_2 - Ok
|
||||
xxh64 e83fddf559d7b6a6 test.gguf:tensor_3 - Ok
|
||||
sha1 a9cba73e2d90f2ee3dae2548caa42bef3fe6a96c test.gguf:tensor_3 - Ok
|
||||
sha256 423b044e016d8ac73c39f23f60bf01bedef5ecb03c0230accd824c91fe86f1a1 test.gguf:tensor_3 - Ok
|
||||
xxh64 1257733306b7992d test.gguf:tensor_4 - Ok
|
||||
sha1 d7bc61db93bb685ce9d598da89717c66729b7543 test.gguf:tensor_4 - Ok
|
||||
sha256 79737cb3912d4201384cf7f16a1a37ff7823f23ea796cb205b6ca361ab9e3ebf test.gguf:tensor_4 - Ok
|
||||
xxh64 d238d16ba4711e58 test.gguf:tensor_5 - Ok
|
||||
sha1 0706566c198fe1072f37e0a5135b4b5f23654c52 test.gguf:tensor_5 - Ok
|
||||
sha256 60949be8298eced0ecdde64487643d018407bd261691e061d9e9c3dbc9fd358b test.gguf:tensor_5 - Ok
|
||||
xxh64 3fbc3b65ab8c7f39 test.gguf:tensor_6 - Ok
|
||||
sha1 73922a0727226a409049f6fc3172a52219ca6f00 test.gguf:tensor_6 - Ok
|
||||
sha256 574f4c46ff384a3b9a225eb955d2a871847a2e8b3fa59387a8252832e92ef7b0 test.gguf:tensor_6 - Ok
|
||||
xxh64 c22021c29854f093 test.gguf:tensor_7 - Ok
|
||||
sha1 efc39cece6a951188fc41e354c73bbfe6813d447 test.gguf:tensor_7 - Ok
|
||||
sha256 4c0410cd3c500f078ae5b21e8dc9eb79e29112713b2ab58a882f82a3868d4d75 test.gguf:tensor_7 - Ok
|
||||
xxh64 936df61f5d64261f test.gguf:tensor_8 - Ok
|
||||
sha1 c2490296d789a4f34398a337fed8377d943d9f06 test.gguf:tensor_8 - Ok
|
||||
sha256 c4401313feeba0261275c3b25bd2d8fe40ce04e0f440c2980ed0e9674c30ff01 test.gguf:tensor_8 - Ok
|
||||
xxh64 93fd20c64421c081 test.gguf:tensor_9 - Ok
|
||||
sha1 7047ce1e78437a6884337a3751c7ee0421918a65 test.gguf:tensor_9 - Ok
|
||||
sha256 23d57cf0d7a6e90b0b3616b41300e0cd354781e812add854a5f95aa55f2bc514 test.gguf:tensor_9 - Ok
|
||||
xxh64 5a54d3aad816f302 test.gguf - Ok
|
||||
sha1 d15be52c4ff213e823cb6dd13af7ee2f978e7042 test.gguf - Ok
|
||||
sha256 7dd641b32f59b60dbd4b5420c4b0f6321ccf48f58f6ae201a3dbc4a58a27c6e4 test.gguf - Ok
|
||||
|
||||
Verification results for test.gguf.manifest - Success
|
||||
```
|
||||
|
||||
|
||||
## Crypto/Hash Libraries Used
|
||||
|
||||
These micro c libraries dependencies was installed via the [clib c package manager](https://github.com/clibs)
|
||||
|
||||
- https://github.com/mofosyne/xxHash (From: https://github.com/Cyan4973/xxHash)
|
||||
- https://github.com/clibs/sha1/
|
||||
- https://github.com/jb55/sha256.c
|
13
examples/gguf-hash/deps/rotate-bits/package.json
Normal file
13
examples/gguf-hash/deps/rotate-bits/package.json
Normal file
@ -0,0 +1,13 @@
|
||||
{
|
||||
"name": "rotate-bits",
|
||||
"version": "0.1.1",
|
||||
"repo": "jb55/rotate-bits.h",
|
||||
"description": "rotate bits",
|
||||
"keywords": ["rotl", "rotr"],
|
||||
"src": ["rotate-bits.h"],
|
||||
"license": "Public Domain",
|
||||
"development": {
|
||||
"thlorenz/tap.c": "*"
|
||||
}
|
||||
}
|
||||
|
46
examples/gguf-hash/deps/rotate-bits/rotate-bits.h
Normal file
46
examples/gguf-hash/deps/rotate-bits/rotate-bits.h
Normal file
@ -0,0 +1,46 @@
|
||||
|
||||
|
||||
#ifndef __ROTATE_DEFS_H
|
||||
#define __ROTATE_DEFS_H
|
||||
|
||||
#ifdef _MSC_VER
|
||||
|
||||
#include <stdlib.h>
|
||||
|
||||
#define ROTL32(v, n) _rotl((v), (n))
|
||||
#define ROTL64(v, n) _rotl64((v), (n))
|
||||
|
||||
#define ROTR32(v, n) _rotr((v), (n))
|
||||
#define ROTR64(v, n) _rotr64((v), (n))
|
||||
|
||||
#else
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#define U8V(v) ((uint8_t)(v) & 0xFFU)
|
||||
#define U16V(v) ((uint16_t)(v) & 0xFFFFU)
|
||||
#define U32V(v) ((uint32_t)(v) & 0xFFFFFFFFU)
|
||||
#define U64V(v) ((uint64_t)(v) & 0xFFFFFFFFFFFFFFFFU)
|
||||
|
||||
#define ROTL32(v, n) \
|
||||
(U32V((uint32_t)(v) << (n)) | ((uint32_t)(v) >> (32 - (n))))
|
||||
|
||||
// tests fail if we don't have this cast...
|
||||
#define ROTL64(v, n) \
|
||||
(U64V((uint64_t)(v) << (n)) | ((uint64_t)(v) >> (64 - (n))))
|
||||
|
||||
#define ROTR32(v, n) ROTL32(v, 32 - (n))
|
||||
#define ROTR64(v, n) ROTL64(v, 64 - (n))
|
||||
|
||||
#endif
|
||||
|
||||
#define ROTL8(v, n) \
|
||||
(U8V((uint8_t)(v) << (n)) | ((uint8_t)(v) >> (8 - (n))))
|
||||
|
||||
#define ROTL16(v, n) \
|
||||
(U16V((uint16_t)(v) << (n)) | ((uint16_t)(v) >> (16 - (n))))
|
||||
|
||||
#define ROTR8(v, n) ROTL8(v, 8 - (n))
|
||||
#define ROTR16(v, n) ROTL16(v, 16 - (n))
|
||||
|
||||
#endif
|
9
examples/gguf-hash/deps/sha1/package.json
Normal file
9
examples/gguf-hash/deps/sha1/package.json
Normal file
@ -0,0 +1,9 @@
|
||||
{
|
||||
"name": "sha1",
|
||||
"version": "0.0.1",
|
||||
"repo": "clibs/sha1",
|
||||
"description": "sha1 hash algorithm",
|
||||
"keywords": ["sha1", "hash"],
|
||||
"license": "public domain",
|
||||
"src": ["sha1.c", "sha1.h"]
|
||||
}
|
295
examples/gguf-hash/deps/sha1/sha1.c
Normal file
295
examples/gguf-hash/deps/sha1/sha1.c
Normal file
@ -0,0 +1,295 @@
|
||||
/*
|
||||
SHA-1 in C
|
||||
By Steve Reid <steve@edmweb.com>
|
||||
100% Public Domain
|
||||
|
||||
Test Vectors (from FIPS PUB 180-1)
|
||||
"abc"
|
||||
A9993E36 4706816A BA3E2571 7850C26C 9CD0D89D
|
||||
"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq"
|
||||
84983E44 1C3BD26E BAAE4AA1 F95129E5 E54670F1
|
||||
A million repetitions of "a"
|
||||
34AA973C D4C4DAA4 F61EEB2B DBAD2731 6534016F
|
||||
*/
|
||||
|
||||
/* #define LITTLE_ENDIAN * This should be #define'd already, if true. */
|
||||
/* #define SHA1HANDSOFF * Copies data before messing with it. */
|
||||
|
||||
#define SHA1HANDSOFF
|
||||
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
|
||||
/* for uint32_t */
|
||||
#include <stdint.h>
|
||||
|
||||
#include "sha1.h"
|
||||
|
||||
|
||||
#define rol(value, bits) (((value) << (bits)) | ((value) >> (32 - (bits))))
|
||||
|
||||
/* blk0() and blk() perform the initial expand. */
|
||||
/* I got the idea of expanding during the round function from SSLeay */
|
||||
#if BYTE_ORDER == LITTLE_ENDIAN
|
||||
#define blk0(i) (block->l[i] = (rol(block->l[i],24)&0xFF00FF00) \
|
||||
|(rol(block->l[i],8)&0x00FF00FF))
|
||||
#elif BYTE_ORDER == BIG_ENDIAN
|
||||
#define blk0(i) block->l[i]
|
||||
#else
|
||||
#error "Endianness not defined!"
|
||||
#endif
|
||||
#define blk(i) (block->l[i&15] = rol(block->l[(i+13)&15]^block->l[(i+8)&15] \
|
||||
^block->l[(i+2)&15]^block->l[i&15],1))
|
||||
|
||||
/* (R0+R1), R2, R3, R4 are the different operations used in SHA1 */
|
||||
#define R0(v,w,x,y,z,i) z+=((w&(x^y))^y)+blk0(i)+0x5A827999+rol(v,5);w=rol(w,30);
|
||||
#define R1(v,w,x,y,z,i) z+=((w&(x^y))^y)+blk(i)+0x5A827999+rol(v,5);w=rol(w,30);
|
||||
#define R2(v,w,x,y,z,i) z+=(w^x^y)+blk(i)+0x6ED9EBA1+rol(v,5);w=rol(w,30);
|
||||
#define R3(v,w,x,y,z,i) z+=(((w|x)&y)|(w&x))+blk(i)+0x8F1BBCDC+rol(v,5);w=rol(w,30);
|
||||
#define R4(v,w,x,y,z,i) z+=(w^x^y)+blk(i)+0xCA62C1D6+rol(v,5);w=rol(w,30);
|
||||
|
||||
|
||||
/* Hash a single 512-bit block. This is the core of the algorithm. */
|
||||
|
||||
void SHA1Transform(
|
||||
uint32_t state[5],
|
||||
const unsigned char buffer[64]
|
||||
)
|
||||
{
|
||||
uint32_t a, b, c, d, e;
|
||||
|
||||
typedef union
|
||||
{
|
||||
unsigned char c[64];
|
||||
uint32_t l[16];
|
||||
} CHAR64LONG16;
|
||||
|
||||
#ifdef SHA1HANDSOFF
|
||||
CHAR64LONG16 block[1]; /* use array to appear as a pointer */
|
||||
|
||||
memcpy(block, buffer, 64);
|
||||
#else
|
||||
/* The following had better never be used because it causes the
|
||||
* pointer-to-const buffer to be cast into a pointer to non-const.
|
||||
* And the result is written through. I threw a "const" in, hoping
|
||||
* this will cause a diagnostic.
|
||||
*/
|
||||
CHAR64LONG16 *block = (const CHAR64LONG16 *) buffer;
|
||||
#endif
|
||||
/* Copy context->state[] to working vars */
|
||||
a = state[0];
|
||||
b = state[1];
|
||||
c = state[2];
|
||||
d = state[3];
|
||||
e = state[4];
|
||||
/* 4 rounds of 20 operations each. Loop unrolled. */
|
||||
R0(a, b, c, d, e, 0);
|
||||
R0(e, a, b, c, d, 1);
|
||||
R0(d, e, a, b, c, 2);
|
||||
R0(c, d, e, a, b, 3);
|
||||
R0(b, c, d, e, a, 4);
|
||||
R0(a, b, c, d, e, 5);
|
||||
R0(e, a, b, c, d, 6);
|
||||
R0(d, e, a, b, c, 7);
|
||||
R0(c, d, e, a, b, 8);
|
||||
R0(b, c, d, e, a, 9);
|
||||
R0(a, b, c, d, e, 10);
|
||||
R0(e, a, b, c, d, 11);
|
||||
R0(d, e, a, b, c, 12);
|
||||
R0(c, d, e, a, b, 13);
|
||||
R0(b, c, d, e, a, 14);
|
||||
R0(a, b, c, d, e, 15);
|
||||
R1(e, a, b, c, d, 16);
|
||||
R1(d, e, a, b, c, 17);
|
||||
R1(c, d, e, a, b, 18);
|
||||
R1(b, c, d, e, a, 19);
|
||||
R2(a, b, c, d, e, 20);
|
||||
R2(e, a, b, c, d, 21);
|
||||
R2(d, e, a, b, c, 22);
|
||||
R2(c, d, e, a, b, 23);
|
||||
R2(b, c, d, e, a, 24);
|
||||
R2(a, b, c, d, e, 25);
|
||||
R2(e, a, b, c, d, 26);
|
||||
R2(d, e, a, b, c, 27);
|
||||
R2(c, d, e, a, b, 28);
|
||||
R2(b, c, d, e, a, 29);
|
||||
R2(a, b, c, d, e, 30);
|
||||
R2(e, a, b, c, d, 31);
|
||||
R2(d, e, a, b, c, 32);
|
||||
R2(c, d, e, a, b, 33);
|
||||
R2(b, c, d, e, a, 34);
|
||||
R2(a, b, c, d, e, 35);
|
||||
R2(e, a, b, c, d, 36);
|
||||
R2(d, e, a, b, c, 37);
|
||||
R2(c, d, e, a, b, 38);
|
||||
R2(b, c, d, e, a, 39);
|
||||
R3(a, b, c, d, e, 40);
|
||||
R3(e, a, b, c, d, 41);
|
||||
R3(d, e, a, b, c, 42);
|
||||
R3(c, d, e, a, b, 43);
|
||||
R3(b, c, d, e, a, 44);
|
||||
R3(a, b, c, d, e, 45);
|
||||
R3(e, a, b, c, d, 46);
|
||||
R3(d, e, a, b, c, 47);
|
||||
R3(c, d, e, a, b, 48);
|
||||
R3(b, c, d, e, a, 49);
|
||||
R3(a, b, c, d, e, 50);
|
||||
R3(e, a, b, c, d, 51);
|
||||
R3(d, e, a, b, c, 52);
|
||||
R3(c, d, e, a, b, 53);
|
||||
R3(b, c, d, e, a, 54);
|
||||
R3(a, b, c, d, e, 55);
|
||||
R3(e, a, b, c, d, 56);
|
||||
R3(d, e, a, b, c, 57);
|
||||
R3(c, d, e, a, b, 58);
|
||||
R3(b, c, d, e, a, 59);
|
||||
R4(a, b, c, d, e, 60);
|
||||
R4(e, a, b, c, d, 61);
|
||||
R4(d, e, a, b, c, 62);
|
||||
R4(c, d, e, a, b, 63);
|
||||
R4(b, c, d, e, a, 64);
|
||||
R4(a, b, c, d, e, 65);
|
||||
R4(e, a, b, c, d, 66);
|
||||
R4(d, e, a, b, c, 67);
|
||||
R4(c, d, e, a, b, 68);
|
||||
R4(b, c, d, e, a, 69);
|
||||
R4(a, b, c, d, e, 70);
|
||||
R4(e, a, b, c, d, 71);
|
||||
R4(d, e, a, b, c, 72);
|
||||
R4(c, d, e, a, b, 73);
|
||||
R4(b, c, d, e, a, 74);
|
||||
R4(a, b, c, d, e, 75);
|
||||
R4(e, a, b, c, d, 76);
|
||||
R4(d, e, a, b, c, 77);
|
||||
R4(c, d, e, a, b, 78);
|
||||
R4(b, c, d, e, a, 79);
|
||||
/* Add the working vars back into context.state[] */
|
||||
state[0] += a;
|
||||
state[1] += b;
|
||||
state[2] += c;
|
||||
state[3] += d;
|
||||
state[4] += e;
|
||||
/* Wipe variables */
|
||||
a = b = c = d = e = 0;
|
||||
#ifdef SHA1HANDSOFF
|
||||
memset(block, '\0', sizeof(block));
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
/* SHA1Init - Initialize new context */
|
||||
|
||||
void SHA1Init(
|
||||
SHA1_CTX * context
|
||||
)
|
||||
{
|
||||
/* SHA1 initialization constants */
|
||||
context->state[0] = 0x67452301;
|
||||
context->state[1] = 0xEFCDAB89;
|
||||
context->state[2] = 0x98BADCFE;
|
||||
context->state[3] = 0x10325476;
|
||||
context->state[4] = 0xC3D2E1F0;
|
||||
context->count[0] = context->count[1] = 0;
|
||||
}
|
||||
|
||||
|
||||
/* Run your data through this. */
|
||||
|
||||
void SHA1Update(
|
||||
SHA1_CTX * context,
|
||||
const unsigned char *data,
|
||||
uint32_t len
|
||||
)
|
||||
{
|
||||
uint32_t i;
|
||||
|
||||
uint32_t j;
|
||||
|
||||
j = context->count[0];
|
||||
if ((context->count[0] += len << 3) < j)
|
||||
context->count[1]++;
|
||||
context->count[1] += (len >> 29);
|
||||
j = (j >> 3) & 63;
|
||||
if ((j + len) > 63)
|
||||
{
|
||||
memcpy(&context->buffer[j], data, (i = 64 - j));
|
||||
SHA1Transform(context->state, context->buffer);
|
||||
for (; i + 63 < len; i += 64)
|
||||
{
|
||||
SHA1Transform(context->state, &data[i]);
|
||||
}
|
||||
j = 0;
|
||||
}
|
||||
else
|
||||
i = 0;
|
||||
memcpy(&context->buffer[j], &data[i], len - i);
|
||||
}
|
||||
|
||||
|
||||
/* Add padding and return the message digest. */
|
||||
|
||||
void SHA1Final(
|
||||
unsigned char digest[20],
|
||||
SHA1_CTX * context
|
||||
)
|
||||
{
|
||||
unsigned i;
|
||||
|
||||
unsigned char finalcount[8];
|
||||
|
||||
unsigned char c;
|
||||
|
||||
#if 0 /* untested "improvement" by DHR */
|
||||
/* Convert context->count to a sequence of bytes
|
||||
* in finalcount. Second element first, but
|
||||
* big-endian order within element.
|
||||
* But we do it all backwards.
|
||||
*/
|
||||
unsigned char *fcp = &finalcount[8];
|
||||
|
||||
for (i = 0; i < 2; i++)
|
||||
{
|
||||
uint32_t t = context->count[i];
|
||||
|
||||
int j;
|
||||
|
||||
for (j = 0; j < 4; t >>= 8, j++)
|
||||
*--fcp = (unsigned char) t}
|
||||
#else
|
||||
for (i = 0; i < 8; i++)
|
||||
{
|
||||
finalcount[i] = (unsigned char) ((context->count[(i >= 4 ? 0 : 1)] >> ((3 - (i & 3)) * 8)) & 255); /* Endian independent */
|
||||
}
|
||||
#endif
|
||||
c = 0200;
|
||||
SHA1Update(context, &c, 1);
|
||||
while ((context->count[0] & 504) != 448)
|
||||
{
|
||||
c = 0000;
|
||||
SHA1Update(context, &c, 1);
|
||||
}
|
||||
SHA1Update(context, finalcount, 8); /* Should cause a SHA1Transform() */
|
||||
for (i = 0; i < 20; i++)
|
||||
{
|
||||
digest[i] = (unsigned char)
|
||||
((context->state[i >> 2] >> ((3 - (i & 3)) * 8)) & 255);
|
||||
}
|
||||
/* Wipe variables */
|
||||
memset(context, '\0', sizeof(*context));
|
||||
memset(&finalcount, '\0', sizeof(finalcount));
|
||||
}
|
||||
|
||||
void SHA1(
|
||||
char *hash_out,
|
||||
const char *str,
|
||||
uint32_t len)
|
||||
{
|
||||
SHA1_CTX ctx;
|
||||
unsigned int ii;
|
||||
|
||||
SHA1Init(&ctx);
|
||||
for (ii=0; ii<len; ii+=1)
|
||||
SHA1Update(&ctx, (const unsigned char*)str + ii, 1);
|
||||
SHA1Final((unsigned char *)hash_out, &ctx);
|
||||
}
|
||||
|
52
examples/gguf-hash/deps/sha1/sha1.h
Normal file
52
examples/gguf-hash/deps/sha1/sha1.h
Normal file
@ -0,0 +1,52 @@
|
||||
#ifndef SHA1_H
|
||||
#define SHA1_H
|
||||
|
||||
/*
|
||||
SHA-1 in C
|
||||
By Steve Reid <steve@edmweb.com>
|
||||
100% Public Domain
|
||||
*/
|
||||
|
||||
#include "stdint.h"
|
||||
|
||||
#if defined(__cplusplus)
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
typedef struct
|
||||
{
|
||||
uint32_t state[5];
|
||||
uint32_t count[2];
|
||||
unsigned char buffer[64];
|
||||
} SHA1_CTX;
|
||||
|
||||
void SHA1Transform(
|
||||
uint32_t state[5],
|
||||
const unsigned char buffer[64]
|
||||
);
|
||||
|
||||
void SHA1Init(
|
||||
SHA1_CTX * context
|
||||
);
|
||||
|
||||
void SHA1Update(
|
||||
SHA1_CTX * context,
|
||||
const unsigned char *data,
|
||||
uint32_t len
|
||||
);
|
||||
|
||||
void SHA1Final(
|
||||
unsigned char digest[20],
|
||||
SHA1_CTX * context
|
||||
);
|
||||
|
||||
void SHA1(
|
||||
char *hash_out,
|
||||
const char *str,
|
||||
uint32_t len);
|
||||
|
||||
#if defined(__cplusplus)
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif /* SHA1_H */
|
15
examples/gguf-hash/deps/sha256/package.json
Normal file
15
examples/gguf-hash/deps/sha256/package.json
Normal file
@ -0,0 +1,15 @@
|
||||
{
|
||||
"name": "sha256",
|
||||
"version": "0.0.2",
|
||||
"repo": "jb55/sha256.c",
|
||||
"description": "sha256 in c",
|
||||
"keywords": ["sha256", "sha2"],
|
||||
"src": ["sha256.c", "sha256.h"],
|
||||
"dependencies": {
|
||||
"jb55/rotate-bits.h": "0.1.1"
|
||||
},
|
||||
"development": {
|
||||
"thlorenz/tap.c": "*"
|
||||
}
|
||||
}
|
||||
|
221
examples/gguf-hash/deps/sha256/sha256.c
Normal file
221
examples/gguf-hash/deps/sha256/sha256.c
Normal file
@ -0,0 +1,221 @@
|
||||
/* Crypto/Sha256.c -- SHA-256 Hash
|
||||
2010-06-11 : Igor Pavlov : Public domain
|
||||
This code is based on public domain code from Wei Dai's Crypto++ library. */
|
||||
|
||||
#include "rotate-bits/rotate-bits.h"
|
||||
#include "sha256.h"
|
||||
|
||||
/* define it for speed optimization */
|
||||
#define _SHA256_UNROLL
|
||||
#define _SHA256_UNROLL2
|
||||
|
||||
void
|
||||
sha256_init(sha256_t *p)
|
||||
{
|
||||
p->state[0] = 0x6a09e667;
|
||||
p->state[1] = 0xbb67ae85;
|
||||
p->state[2] = 0x3c6ef372;
|
||||
p->state[3] = 0xa54ff53a;
|
||||
p->state[4] = 0x510e527f;
|
||||
p->state[5] = 0x9b05688c;
|
||||
p->state[6] = 0x1f83d9ab;
|
||||
p->state[7] = 0x5be0cd19;
|
||||
p->count = 0;
|
||||
}
|
||||
|
||||
#define S0(x) (ROTR32(x, 2) ^ ROTR32(x,13) ^ ROTR32(x, 22))
|
||||
#define S1(x) (ROTR32(x, 6) ^ ROTR32(x,11) ^ ROTR32(x, 25))
|
||||
#define s0(x) (ROTR32(x, 7) ^ ROTR32(x,18) ^ (x >> 3))
|
||||
#define s1(x) (ROTR32(x,17) ^ ROTR32(x,19) ^ (x >> 10))
|
||||
|
||||
#define blk0(i) (W[i] = data[i])
|
||||
#define blk2(i) (W[i&15] += s1(W[(i-2)&15]) + W[(i-7)&15] + s0(W[(i-15)&15]))
|
||||
|
||||
#define Ch(x,y,z) (z^(x&(y^z)))
|
||||
#define Maj(x,y,z) ((x&y)|(z&(x|y)))
|
||||
|
||||
#define a(i) T[(0-(i))&7]
|
||||
#define b(i) T[(1-(i))&7]
|
||||
#define c(i) T[(2-(i))&7]
|
||||
#define d(i) T[(3-(i))&7]
|
||||
#define e(i) T[(4-(i))&7]
|
||||
#define f(i) T[(5-(i))&7]
|
||||
#define g(i) T[(6-(i))&7]
|
||||
#define h(i) T[(7-(i))&7]
|
||||
|
||||
|
||||
#ifdef _SHA256_UNROLL2
|
||||
|
||||
#define R(a,b,c,d,e,f,g,h, i) h += S1(e) + Ch(e,f,g) + K[i+j] + (j?blk2(i):blk0(i));\
|
||||
d += h; h += S0(a) + Maj(a, b, c)
|
||||
|
||||
#define RX_8(i) \
|
||||
R(a,b,c,d,e,f,g,h, i); \
|
||||
R(h,a,b,c,d,e,f,g, (i+1)); \
|
||||
R(g,h,a,b,c,d,e,f, (i+2)); \
|
||||
R(f,g,h,a,b,c,d,e, (i+3)); \
|
||||
R(e,f,g,h,a,b,c,d, (i+4)); \
|
||||
R(d,e,f,g,h,a,b,c, (i+5)); \
|
||||
R(c,d,e,f,g,h,a,b, (i+6)); \
|
||||
R(b,c,d,e,f,g,h,a, (i+7))
|
||||
|
||||
#else
|
||||
|
||||
#define R(i) h(i) += S1(e(i)) + Ch(e(i),f(i),g(i)) + K[i+j] + (j?blk2(i):blk0(i));\
|
||||
d(i) += h(i); h(i) += S0(a(i)) + Maj(a(i), b(i), c(i))
|
||||
|
||||
#ifdef _SHA256_UNROLL
|
||||
|
||||
#define RX_8(i) R(i+0); R(i+1); R(i+2); R(i+3); R(i+4); R(i+5); R(i+6); R(i+7);
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
static const uint32_t K[64] = {
|
||||
0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5,
|
||||
0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
|
||||
0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3,
|
||||
0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
|
||||
0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc,
|
||||
0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
|
||||
0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
|
||||
0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
|
||||
0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13,
|
||||
0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
|
||||
0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3,
|
||||
0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
|
||||
0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5,
|
||||
0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
|
||||
0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208,
|
||||
0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2
|
||||
};
|
||||
|
||||
static void
|
||||
sha256_transform(uint32_t *state, const uint32_t *data)
|
||||
{
|
||||
uint32_t W[16] = {0};
|
||||
unsigned j;
|
||||
#ifdef _SHA256_UNROLL2
|
||||
uint32_t a,b,c,d,e,f,g,h;
|
||||
a = state[0];
|
||||
b = state[1];
|
||||
c = state[2];
|
||||
d = state[3];
|
||||
e = state[4];
|
||||
f = state[5];
|
||||
g = state[6];
|
||||
h = state[7];
|
||||
#else
|
||||
uint32_t T[8];
|
||||
for (j = 0; j < 8; j++)
|
||||
T[j] = state[j];
|
||||
#endif
|
||||
|
||||
for (j = 0; j < 64; j += 16)
|
||||
{
|
||||
#if defined(_SHA256_UNROLL) || defined(_SHA256_UNROLL2)
|
||||
RX_8(0); RX_8(8);
|
||||
#else
|
||||
unsigned i;
|
||||
for (i = 0; i < 16; i++) { R(i); }
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef _SHA256_UNROLL2
|
||||
state[0] += a;
|
||||
state[1] += b;
|
||||
state[2] += c;
|
||||
state[3] += d;
|
||||
state[4] += e;
|
||||
state[5] += f;
|
||||
state[6] += g;
|
||||
state[7] += h;
|
||||
#else
|
||||
for (j = 0; j < 8; j++)
|
||||
state[j] += T[j];
|
||||
#endif
|
||||
|
||||
/* Wipe variables */
|
||||
/* memset(W, 0, sizeof(W)); */
|
||||
/* memset(T, 0, sizeof(T)); */
|
||||
}
|
||||
|
||||
#undef S0
|
||||
#undef S1
|
||||
#undef s0
|
||||
#undef s1
|
||||
|
||||
static void
|
||||
sha256_write_byte_block(sha256_t *p)
|
||||
{
|
||||
uint32_t data32[16];
|
||||
unsigned i;
|
||||
for (i = 0; i < 16; i++)
|
||||
data32[i] =
|
||||
((uint32_t)(p->buffer[i * 4 ]) << 24) +
|
||||
((uint32_t)(p->buffer[i * 4 + 1]) << 16) +
|
||||
((uint32_t)(p->buffer[i * 4 + 2]) << 8) +
|
||||
((uint32_t)(p->buffer[i * 4 + 3]));
|
||||
sha256_transform(p->state, data32);
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
sha256_hash(unsigned char *buf, const unsigned char *data, size_t size)
|
||||
{
|
||||
sha256_t hash;
|
||||
sha256_init(&hash);
|
||||
sha256_update(&hash, data, size);
|
||||
sha256_final(&hash, buf);
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
sha256_update(sha256_t *p, const unsigned char *data, size_t size)
|
||||
{
|
||||
uint32_t curBufferPos = (uint32_t)p->count & 0x3F;
|
||||
while (size > 0)
|
||||
{
|
||||
p->buffer[curBufferPos++] = *data++;
|
||||
p->count++;
|
||||
size--;
|
||||
if (curBufferPos == 64)
|
||||
{
|
||||
curBufferPos = 0;
|
||||
sha256_write_byte_block(p);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
sha256_final(sha256_t *p, unsigned char *digest)
|
||||
{
|
||||
uint64_t lenInBits = (p->count << 3);
|
||||
uint32_t curBufferPos = (uint32_t)p->count & 0x3F;
|
||||
unsigned i;
|
||||
p->buffer[curBufferPos++] = 0x80;
|
||||
while (curBufferPos != (64 - 8))
|
||||
{
|
||||
curBufferPos &= 0x3F;
|
||||
if (curBufferPos == 0)
|
||||
sha256_write_byte_block(p);
|
||||
p->buffer[curBufferPos++] = 0;
|
||||
}
|
||||
for (i = 0; i < 8; i++)
|
||||
{
|
||||
p->buffer[curBufferPos++] = (unsigned char)(lenInBits >> 56);
|
||||
lenInBits <<= 8;
|
||||
}
|
||||
sha256_write_byte_block(p);
|
||||
|
||||
for (i = 0; i < 8; i++)
|
||||
{
|
||||
*digest++ = (unsigned char)(p->state[i] >> 24);
|
||||
*digest++ = (unsigned char)(p->state[i] >> 16);
|
||||
*digest++ = (unsigned char)(p->state[i] >> 8);
|
||||
*digest++ = (unsigned char)(p->state[i]);
|
||||
}
|
||||
sha256_init(p);
|
||||
}
|
24
examples/gguf-hash/deps/sha256/sha256.h
Normal file
24
examples/gguf-hash/deps/sha256/sha256.h
Normal file
@ -0,0 +1,24 @@
|
||||
/* Sha256.h -- SHA-256 Hash
|
||||
2010-06-11 : Igor Pavlov : Public domain */
|
||||
|
||||
#ifndef __CRYPTO_SHA256_H
|
||||
#define __CRYPTO_SHA256_H
|
||||
|
||||
#include <stdlib.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#define SHA256_DIGEST_SIZE 32
|
||||
|
||||
typedef struct sha256_t
|
||||
{
|
||||
uint32_t state[8];
|
||||
uint64_t count;
|
||||
unsigned char buffer[64];
|
||||
} sha256_t;
|
||||
|
||||
void sha256_init(sha256_t *p);
|
||||
void sha256_update(sha256_t *p, const unsigned char *data, size_t size);
|
||||
void sha256_final(sha256_t *p, unsigned char *digest);
|
||||
void sha256_hash(unsigned char *buf, const unsigned char *data, size_t size);
|
||||
|
||||
#endif
|
12
examples/gguf-hash/deps/xxhash/clib.json
Normal file
12
examples/gguf-hash/deps/xxhash/clib.json
Normal file
@ -0,0 +1,12 @@
|
||||
{
|
||||
"name": "xxhash",
|
||||
"version": "0.8.2",
|
||||
"repo": "mofosyne/xxhash",
|
||||
"description": "Extremely fast non-cryptographic hash algorithm",
|
||||
"keywords": ["xxhash", "hashing"],
|
||||
"license": "BSD-2-Clause",
|
||||
"src": [
|
||||
"xxhash.c",
|
||||
"xxhash.h"
|
||||
]
|
||||
}
|
42
examples/gguf-hash/deps/xxhash/xxhash.c
Normal file
42
examples/gguf-hash/deps/xxhash/xxhash.c
Normal file
@ -0,0 +1,42 @@
|
||||
/*
|
||||
* xxHash - Extremely Fast Hash algorithm
|
||||
* Copyright (C) 2012-2023 Yann Collet
|
||||
*
|
||||
* BSD 2-Clause License (https://www.opensource.org/licenses/bsd-license.php)
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are
|
||||
* met:
|
||||
*
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above
|
||||
* copyright notice, this list of conditions and the following disclaimer
|
||||
* in the documentation and/or other materials provided with the
|
||||
* distribution.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
* You can contact the author at:
|
||||
* - xxHash homepage: https://www.xxhash.com
|
||||
* - xxHash source repository: https://github.com/Cyan4973/xxHash
|
||||
*/
|
||||
|
||||
/*
|
||||
* xxhash.c instantiates functions defined in xxhash.h
|
||||
*/
|
||||
|
||||
#define XXH_STATIC_LINKING_ONLY /* access advanced declarations */
|
||||
#define XXH_IMPLEMENTATION /* access definitions */
|
||||
|
||||
#include "xxhash.h"
|
7093
examples/gguf-hash/deps/xxhash/xxhash.h
Normal file
7093
examples/gguf-hash/deps/xxhash/xxhash.h
Normal file
File diff suppressed because it is too large
Load Diff
693
examples/gguf-hash/gguf-hash.cpp
Normal file
693
examples/gguf-hash/gguf-hash.cpp
Normal file
@ -0,0 +1,693 @@
|
||||
#include "ggml.h"
|
||||
|
||||
#include <cstdlib> /* abort() */
|
||||
#include <cstddef>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <stdexcept>
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#include "xxhash/xxhash.h"
|
||||
#include "sha1/sha1.h"
|
||||
#include "sha256/sha256.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
// uuid.uuid5(uuid.NAMESPACE_URL, 'en.wikipedia.org/wiki/Llama.cpp')
|
||||
#define UUID_NAMESPACE_LLAMA_CPP "ef001206-dadc-5f6d-a15f-3359e577d4e5"
|
||||
#define UUID_NAMESPACE_LLAMA_CPP_HEX 0xef, 0x00, 0x12, 0x06, 0xda, 0xdc, 0x5f, 0x6d, 0xa1, 0x5f, 0x33, 0x59, 0xe5, 0x77, 0xd4, 0xe5
|
||||
|
||||
|
||||
#define HASH_TYPE_SHA256_STR "sha256"
|
||||
#define HASH_TYPE_SHA1_STR "sha1"
|
||||
#define HASH_TYPE_XXH64_STR "xxh64"
|
||||
#define HASH_TYPE_UUID_STR "uuid"
|
||||
|
||||
|
||||
typedef enum {
|
||||
HASH_EXIT_SUCCESS = 0, // All hash has been generated or validated
|
||||
HASH_EXIT_FAILURE = 1, // Generic Failure
|
||||
HASH_EXIT_MISMATCH = 2, // Hash mismatched during validation
|
||||
HASH_EXIT_MANIFEST_MISSING_ENTRY = 3, // Hash attempted validation but missing entry in manifest
|
||||
HASH_EXIT_MANIFEST_UNKNOWN_HASH = 4, // Manifest is present, but we do not know any hash format within it
|
||||
HASH_EXIT_MANIFEST_FILE_ERROR = 5 // Manifest is either missing or not a known format
|
||||
} hash_exit_code_t;
|
||||
|
||||
|
||||
typedef enum {
|
||||
HASH_MANIFEST_NOT_FOUND,
|
||||
HASH_MANIFEST_MISMATCH,
|
||||
HASH_MANIFEST_OK,
|
||||
} hash_manifest_result_t;
|
||||
|
||||
|
||||
struct hash_params {
|
||||
std::string input;
|
||||
bool xxh64 = false;
|
||||
bool sha1 = false;
|
||||
bool sha256 = false;
|
||||
bool uuid = false;
|
||||
|
||||
bool no_layer = false;
|
||||
|
||||
bool manifest_is_usable = false;
|
||||
std::string manifest_file;
|
||||
};
|
||||
|
||||
struct manifest_check_params {
|
||||
bool xxh64 = false;
|
||||
bool sha1 = false;
|
||||
bool sha256 = false;
|
||||
bool uuid = false;
|
||||
};
|
||||
|
||||
static char const * hash_manifest_result_to_str(hash_manifest_result_t value) {
|
||||
switch (value) {
|
||||
case HASH_MANIFEST_NOT_FOUND: return "Not Found";
|
||||
case HASH_MANIFEST_MISMATCH: return "Mismatch";
|
||||
case HASH_MANIFEST_OK: return "Ok";
|
||||
}
|
||||
return "?";
|
||||
}
|
||||
|
||||
static char const * hash_exit_code_to_str(hash_exit_code_t value) {
|
||||
switch (value) {
|
||||
case HASH_EXIT_SUCCESS: return "Success";
|
||||
case HASH_EXIT_FAILURE: return "Failure";
|
||||
case HASH_EXIT_MISMATCH: return "Mismatch";
|
||||
case HASH_EXIT_MANIFEST_MISSING_ENTRY: return "Manifest Missing Entry";
|
||||
case HASH_EXIT_MANIFEST_UNKNOWN_HASH: return "Manifest Unknown Hash";
|
||||
case HASH_EXIT_MANIFEST_FILE_ERROR: return "Manifest File Error";
|
||||
}
|
||||
return "?";
|
||||
}
|
||||
|
||||
static void hash_print_usage(const char * executable) {
|
||||
const hash_params default_params;
|
||||
printf("\n");
|
||||
printf("usage: %s [options] GGUF_IN\n", executable);
|
||||
printf("\n");
|
||||
printf("Hash a GGUF file");
|
||||
printf("\n");
|
||||
printf("options:\n");
|
||||
printf(" -h, --help show this help message and exit\n");
|
||||
printf(" --xxh64 use xxh64 hash\n");
|
||||
printf(" --sha1 use sha1 hash\n");
|
||||
printf(" --sha256 use sha256 hash\n");
|
||||
printf(" --all use all hash\n");
|
||||
printf(" --no-layer exclude per layer hash\n");
|
||||
printf(" --uuid generate UUIDv5 ID\n");
|
||||
printf(" -c, --check <manifest> verify against a manifest\n");
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
static void hash_params_parse_ex(int argc, const char ** argv, hash_params & params) {
|
||||
std::string arg;
|
||||
bool invalid_param = false;
|
||||
const std::string arg_prefix = "--";
|
||||
|
||||
int arg_idx = 1;
|
||||
for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) {
|
||||
arg = argv[arg_idx];
|
||||
if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
|
||||
std::replace(arg.begin(), arg.end(), '_', '-');
|
||||
}
|
||||
|
||||
bool arg_found = false;
|
||||
if (arg == "-h" || arg == "--help") {
|
||||
hash_print_usage(argv[0]);
|
||||
exit(0);
|
||||
}
|
||||
|
||||
if (arg == "--xxh64") {
|
||||
arg_found = true;
|
||||
params.xxh64 = true;
|
||||
}
|
||||
|
||||
if (arg == "--sha1") {
|
||||
arg_found = true;
|
||||
params.sha1 = true;
|
||||
}
|
||||
|
||||
if (arg == "--uuid") {
|
||||
arg_found = true;
|
||||
params.uuid = true;
|
||||
}
|
||||
|
||||
if (arg == "--sha256") {
|
||||
arg_found = true;
|
||||
params.sha256 = true;
|
||||
}
|
||||
|
||||
if (arg == "--all") {
|
||||
arg_found = true;
|
||||
params.sha256 = true;
|
||||
params.sha1 = true;
|
||||
params.xxh64 = true;
|
||||
}
|
||||
|
||||
if (arg == "--no-layer") {
|
||||
arg_found = true;
|
||||
params.no_layer = true;
|
||||
}
|
||||
|
||||
if (arg == "-c" || arg == "--check") {
|
||||
if (++arg_idx >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
arg_found = true;
|
||||
params.manifest_file = argv[arg_idx];
|
||||
}
|
||||
|
||||
if (!arg_found) {
|
||||
throw std::invalid_argument("error: unknown argument: " + arg);
|
||||
}
|
||||
}
|
||||
|
||||
if (invalid_param) {
|
||||
throw std::invalid_argument("error: invalid parameter for argument:" + arg);
|
||||
}
|
||||
|
||||
if (argc - arg_idx < 1) {
|
||||
throw std::invalid_argument("error: bad arguments");
|
||||
}
|
||||
|
||||
params.input = argv[arg_idx++];
|
||||
}
|
||||
|
||||
static bool hash_params_parse(int argc, const char ** argv, hash_params & params) {
|
||||
bool result = true;
|
||||
try {
|
||||
hash_params_parse_ex(argc, argv, params);
|
||||
}
|
||||
catch (const std::invalid_argument & ex) {
|
||||
fprintf(stderr, "%s\n", ex.what());
|
||||
hash_print_usage(argv[0]);
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static bool manifest_type(const std::string & manifest_file, manifest_check_params & manifest_check) {
|
||||
if (manifest_file.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::ifstream file(manifest_file);
|
||||
if (!file.is_open()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string manifest_entry_line;
|
||||
while (getline(file, manifest_entry_line)) {
|
||||
// hash_type_str hash_str tensor_name
|
||||
// e.g. 'xxh64 f66e9cd66a4396a0 test.gguf:tensor_0'
|
||||
std::istringstream line_stream(manifest_entry_line);
|
||||
std::string file_hash_type;
|
||||
if (line_stream >> file_hash_type) {
|
||||
if (file_hash_type == HASH_TYPE_SHA256_STR) {
|
||||
manifest_check.sha256 = true;
|
||||
} else if (file_hash_type == HASH_TYPE_SHA1_STR) {
|
||||
manifest_check.sha1 = true;
|
||||
} else if (file_hash_type == HASH_TYPE_XXH64_STR) {
|
||||
manifest_check.xxh64 = true;
|
||||
} else if (file_hash_type == HASH_TYPE_UUID_STR) {
|
||||
manifest_check.uuid = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static hash_manifest_result_t manifest_verify(const std::string& manifest_file, const std::string& hash_type_str, const std::string& hash_str, const std::string& tensor_name) {
|
||||
if (manifest_file.empty()) {
|
||||
return HASH_MANIFEST_NOT_FOUND;
|
||||
}
|
||||
|
||||
std::ifstream file(manifest_file);
|
||||
if (!file.is_open()) {
|
||||
return HASH_MANIFEST_NOT_FOUND;
|
||||
}
|
||||
|
||||
std::string manifest_entry_line;
|
||||
while (getline(file, manifest_entry_line)) {
|
||||
std::istringstream line_stream(manifest_entry_line);
|
||||
std::string file_hash_type;
|
||||
std::string file_hash;
|
||||
std::string file_tensor_name;
|
||||
if (line_stream >> file_hash_type >> file_hash >> file_tensor_name) {
|
||||
// Line parsed. Check hash validity
|
||||
|
||||
if (file_hash_type != hash_type_str) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (file_tensor_name != tensor_name) {
|
||||
continue;
|
||||
}
|
||||
|
||||
return (file_hash == hash_str) ? HASH_MANIFEST_OK : HASH_MANIFEST_MISMATCH;
|
||||
}
|
||||
}
|
||||
|
||||
return HASH_MANIFEST_NOT_FOUND;
|
||||
}
|
||||
|
||||
static void generate_uuidv5(const unsigned char sha1_digest[20], unsigned char uuid[16]) {
|
||||
// Ref: https://www.rfc-editor.org/rfc/rfc9562.html#section-5.5
|
||||
// Assumes that digest was processed correctly with the expected namespace
|
||||
for (int i = 0; i < 16; i++) {
|
||||
uuid[i] = sha1_digest[i];
|
||||
}
|
||||
|
||||
// Set bits corresponding to UUID ver 5
|
||||
uuid[ 6] &= ~(0xF << 4);
|
||||
uuid[ 6] |= (5 << 4);
|
||||
|
||||
// Set bits corresponding to UUID variant 0b10XX
|
||||
uuid[ 8] &= ~(0xc << 4);
|
||||
uuid[ 8] |= (0x8 << 4);
|
||||
}
|
||||
|
||||
static hash_exit_code_t gguf_hash(const hash_params & hash_params) {
|
||||
const std::string & fname = hash_params.input;
|
||||
struct ggml_context * ctx_data = NULL;
|
||||
|
||||
struct gguf_init_params params = {
|
||||
/*.no_alloc = */ false,
|
||||
/*.ctx = */ &ctx_data,
|
||||
};
|
||||
|
||||
// xxh64 init
|
||||
XXH64_state_t* xxh64_model_hash_state = NULL;
|
||||
if (hash_params.xxh64) {
|
||||
xxh64_model_hash_state = XXH64_createState();
|
||||
if (xxh64_model_hash_state==NULL) {
|
||||
abort();
|
||||
}
|
||||
|
||||
XXH64_hash_t const seed = 0;
|
||||
if (XXH64_reset(xxh64_model_hash_state, seed) == XXH_ERROR) {
|
||||
abort();
|
||||
}
|
||||
}
|
||||
|
||||
// sha1 init
|
||||
SHA1_CTX sha1_model_hash_ctx;
|
||||
if (hash_params.sha1) {
|
||||
SHA1Init(&sha1_model_hash_ctx);
|
||||
}
|
||||
|
||||
// sha256 init
|
||||
sha256_t sha256_model_hash_ctx;
|
||||
if (hash_params.sha256) {
|
||||
sha256_init(&sha256_model_hash_ctx);
|
||||
}
|
||||
|
||||
// sha1 for uuid init
|
||||
SHA1_CTX sha1_for_uuid_ctx;
|
||||
if (hash_params.uuid) {
|
||||
unsigned char const uuidv5_namespace[] = {UUID_NAMESPACE_LLAMA_CPP_HEX};
|
||||
SHA1Init(&sha1_for_uuid_ctx);
|
||||
SHA1Update( &sha1_for_uuid_ctx, (unsigned char const *)uuidv5_namespace, sizeof(uuidv5_namespace));
|
||||
}
|
||||
|
||||
struct gguf_context * ctx = gguf_init_from_file(fname.c_str(), params);
|
||||
const int n_tensors = gguf_get_n_tensors(ctx);
|
||||
bool tensor_layer_in_manifest = false;
|
||||
bool model_in_manifest = false;
|
||||
bool tensor_layer_has_mismatch = false;
|
||||
bool model_has_mismatch = false;
|
||||
for (int i = 0; i < n_tensors; ++i) {
|
||||
const char * name = gguf_get_tensor_name(ctx, i);
|
||||
struct ggml_tensor * cur = ggml_get_tensor(ctx_data, name);
|
||||
auto n_bytes = ggml_nbytes(cur);
|
||||
auto *raw_data = cur->data;
|
||||
const std::string tensor_layer_name = fname + ":" + name;
|
||||
|
||||
if (hash_params.xxh64) {
|
||||
|
||||
if (!hash_params.no_layer) {
|
||||
// Per Layer Hash
|
||||
XXH64_hash_t hash = XXH64(raw_data, n_bytes, 0);
|
||||
|
||||
char hex_result[17];
|
||||
for (int offset = 0; offset < 8; offset++) {
|
||||
unsigned int shift_bits_by = (8 * (8 - offset - 1));
|
||||
sprintf( ( hex_result + (2*offset)), "%02x", (unsigned char) (hash >> shift_bits_by)&0xff);
|
||||
}
|
||||
|
||||
if (hash_params.manifest_is_usable) {
|
||||
hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_XXH64_STR, hex_result, tensor_layer_name);
|
||||
|
||||
switch (verify_result) {
|
||||
case HASH_MANIFEST_NOT_FOUND:
|
||||
break;
|
||||
case HASH_MANIFEST_MISMATCH:
|
||||
tensor_layer_in_manifest = true;
|
||||
tensor_layer_has_mismatch = true;
|
||||
break;
|
||||
case HASH_MANIFEST_OK:
|
||||
tensor_layer_in_manifest = true;
|
||||
break;
|
||||
}
|
||||
|
||||
printf("%-8s %-s %s - %s\n", HASH_TYPE_XXH64_STR, hex_result, tensor_layer_name.c_str(), hash_manifest_result_to_str(verify_result));
|
||||
} else {
|
||||
printf("%-8s %-s %s\n", HASH_TYPE_XXH64_STR, hex_result, tensor_layer_name.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
// Overall Model Hash
|
||||
if (XXH64_update(xxh64_model_hash_state, raw_data, n_bytes) == XXH_ERROR) abort();
|
||||
}
|
||||
|
||||
if (hash_params.sha1) {
|
||||
|
||||
if (!hash_params.no_layer) {
|
||||
// Per Layer Hash
|
||||
char result[21]; // sha1 outputs 20 bytes
|
||||
SHA1( result, (const char *)raw_data, n_bytes);
|
||||
|
||||
char hex_result[41] = {0};
|
||||
for (int offset = 0; offset < 20; offset++) {
|
||||
sprintf( ( hex_result + (2*offset)), "%02x", result[offset]&0xff);
|
||||
}
|
||||
|
||||
if (hash_params.manifest_is_usable) {
|
||||
hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_SHA1_STR, hex_result, tensor_layer_name);
|
||||
|
||||
switch (verify_result) {
|
||||
case HASH_MANIFEST_NOT_FOUND:
|
||||
break;
|
||||
case HASH_MANIFEST_MISMATCH:
|
||||
tensor_layer_in_manifest = true;
|
||||
tensor_layer_has_mismatch = true;
|
||||
break;
|
||||
case HASH_MANIFEST_OK:
|
||||
tensor_layer_in_manifest = true;
|
||||
break;
|
||||
}
|
||||
|
||||
printf("%-8s %-s %s - %s\n", HASH_TYPE_SHA1_STR, hex_result, tensor_layer_name.c_str(), hash_manifest_result_to_str(verify_result));
|
||||
} else {
|
||||
printf("%-8s %-s %s\n", HASH_TYPE_SHA1_STR, hex_result, tensor_layer_name.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
// Overall Model Hash
|
||||
SHA1Update( &sha1_model_hash_ctx, (unsigned char const *)raw_data, n_bytes);
|
||||
}
|
||||
|
||||
if (hash_params.sha256) {
|
||||
|
||||
if (!hash_params.no_layer) {
|
||||
// Per Layer Hash
|
||||
unsigned char result[SHA256_DIGEST_SIZE]; // sha256 outputs 32 bytes
|
||||
sha256_hash((unsigned char*) result, (const unsigned char *)raw_data, n_bytes);
|
||||
|
||||
char hex_result[SHA256_DIGEST_SIZE * 2 + 1] = {0};
|
||||
for (int offset = 0; offset < SHA256_DIGEST_SIZE; offset++) {
|
||||
sprintf( ( hex_result + (2*offset)), "%02x", result[offset]&0xff);
|
||||
}
|
||||
|
||||
if (hash_params.manifest_is_usable) {
|
||||
hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_SHA256_STR, hex_result, tensor_layer_name);
|
||||
|
||||
switch (verify_result) {
|
||||
case HASH_MANIFEST_NOT_FOUND:
|
||||
break;
|
||||
case HASH_MANIFEST_MISMATCH:
|
||||
tensor_layer_in_manifest = true;
|
||||
tensor_layer_has_mismatch = true;
|
||||
break;
|
||||
case HASH_MANIFEST_OK:
|
||||
tensor_layer_in_manifest = true;
|
||||
break;
|
||||
}
|
||||
|
||||
printf("%-8s %-s %s - %s\n", HASH_TYPE_SHA256_STR, hex_result, tensor_layer_name.c_str(), hash_manifest_result_to_str(verify_result));
|
||||
} else {
|
||||
printf("%-8s %-s %s\n", HASH_TYPE_SHA256_STR, hex_result, tensor_layer_name.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
// Overall Model Hash
|
||||
sha256_update( &sha256_model_hash_ctx, (unsigned char const *)raw_data, n_bytes);
|
||||
}
|
||||
|
||||
if (hash_params.uuid) {
|
||||
SHA1Update( &sha1_for_uuid_ctx, (unsigned char const *)raw_data, n_bytes);
|
||||
}
|
||||
}
|
||||
|
||||
if (hash_params.xxh64) {
|
||||
XXH64_hash_t const hash = XXH64_digest(xxh64_model_hash_state);
|
||||
|
||||
char hex_result[17];
|
||||
for (int offset = 0; offset < 8; offset++) {
|
||||
unsigned int shift_bits_by = (8 * (8 - offset - 1));
|
||||
sprintf( ( hex_result + (2*offset)), "%02x", (unsigned char) (hash >> shift_bits_by)&0xff);
|
||||
}
|
||||
|
||||
if (hash_params.manifest_is_usable) {
|
||||
hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_XXH64_STR, hex_result, fname);
|
||||
|
||||
switch (verify_result) {
|
||||
case HASH_MANIFEST_NOT_FOUND:
|
||||
break;
|
||||
case HASH_MANIFEST_MISMATCH:
|
||||
model_in_manifest = true;
|
||||
model_has_mismatch = true;
|
||||
break;
|
||||
case HASH_MANIFEST_OK:
|
||||
model_in_manifest = true;
|
||||
break;
|
||||
}
|
||||
|
||||
printf("%-8s %-s %s - %s\n", HASH_TYPE_XXH64_STR, hex_result, fname.c_str(), hash_manifest_result_to_str(verify_result));
|
||||
} else {
|
||||
printf("%-8s %-s %s\n", HASH_TYPE_XXH64_STR, hex_result, fname.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
if (hash_params.sha1) {
|
||||
unsigned char result[21];
|
||||
SHA1Final(result, &sha1_model_hash_ctx);
|
||||
|
||||
char hex_result[41];
|
||||
for (int offset = 0; offset < 20; offset++) {
|
||||
sprintf( ( hex_result + (2*offset)), "%02x", result[offset]&0xff);
|
||||
}
|
||||
|
||||
if (hash_params.manifest_is_usable) {
|
||||
hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_SHA1_STR, hex_result, fname);
|
||||
|
||||
switch (verify_result) {
|
||||
case HASH_MANIFEST_NOT_FOUND:
|
||||
break;
|
||||
case HASH_MANIFEST_MISMATCH:
|
||||
model_in_manifest = true;
|
||||
model_has_mismatch = true;
|
||||
break;
|
||||
case HASH_MANIFEST_OK:
|
||||
model_in_manifest = true;
|
||||
break;
|
||||
}
|
||||
|
||||
printf("%-8s %-s %s - %s\n", HASH_TYPE_SHA1_STR, hex_result, fname.c_str(), hash_manifest_result_to_str(verify_result));
|
||||
} else {
|
||||
printf("%-8s %-s %s\n", HASH_TYPE_SHA1_STR, hex_result, fname.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
if (hash_params.sha256) {
|
||||
unsigned char result[SHA256_DIGEST_SIZE]; // sha256 outputs 32 bytes
|
||||
sha256_final( &sha256_model_hash_ctx, result);
|
||||
|
||||
char hex_result[SHA256_DIGEST_SIZE * 2 + 1] = {0};
|
||||
for (int offset = 0; offset < SHA256_DIGEST_SIZE; offset++) {
|
||||
sprintf( ( hex_result + (2*offset)), "%02x", result[offset]&0xff);
|
||||
}
|
||||
|
||||
if (hash_params.manifest_is_usable) {
|
||||
hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_SHA256_STR, hex_result, fname);
|
||||
|
||||
switch (verify_result) {
|
||||
case HASH_MANIFEST_NOT_FOUND:
|
||||
break;
|
||||
case HASH_MANIFEST_MISMATCH:
|
||||
model_in_manifest = true;
|
||||
model_has_mismatch = true;
|
||||
break;
|
||||
case HASH_MANIFEST_OK:
|
||||
model_in_manifest = true;
|
||||
break;
|
||||
}
|
||||
|
||||
printf("%-8s %-s %s - %s\n", HASH_TYPE_SHA256_STR, hex_result, fname.c_str(), hash_manifest_result_to_str(verify_result));
|
||||
} else {
|
||||
printf("%-8s %-s %s\n", HASH_TYPE_SHA256_STR, hex_result, fname.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
if (hash_params.uuid) {
|
||||
unsigned char result[21];
|
||||
SHA1Final(result, &sha1_for_uuid_ctx);
|
||||
|
||||
unsigned char uuid[16];
|
||||
generate_uuidv5(result, uuid);
|
||||
|
||||
char string_buffer[37] = {0};
|
||||
sprintf(string_buffer, "%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
|
||||
uuid[0], uuid[1], uuid[2], uuid[3],
|
||||
uuid[4], uuid[5], uuid[6], uuid[7],
|
||||
uuid[8], uuid[9], uuid[10], uuid[11],
|
||||
uuid[12], uuid[13], uuid[14], uuid[15]);
|
||||
|
||||
if (hash_params.manifest_is_usable) {
|
||||
hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_SHA256_STR, string_buffer, fname);
|
||||
|
||||
switch (verify_result) {
|
||||
case HASH_MANIFEST_NOT_FOUND:
|
||||
break;
|
||||
case HASH_MANIFEST_MISMATCH:
|
||||
model_in_manifest = true;
|
||||
model_has_mismatch = true;
|
||||
break;
|
||||
case HASH_MANIFEST_OK:
|
||||
model_in_manifest = true;
|
||||
break;
|
||||
}
|
||||
|
||||
printf("%-8s %-s %s - %s\n", HASH_TYPE_UUID_STR, string_buffer, fname.c_str(), hash_manifest_result_to_str(verify_result));
|
||||
} else {
|
||||
printf("%-8s %-s %s\n", HASH_TYPE_UUID_STR, string_buffer, fname.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
ggml_free(ctx_data);
|
||||
gguf_free(ctx);
|
||||
|
||||
|
||||
if (hash_params.manifest_is_usable) {
|
||||
// In hash verification mode
|
||||
|
||||
if (!model_in_manifest) {
|
||||
// model missing in manifest?
|
||||
|
||||
// Check tensor layer...
|
||||
if (!tensor_layer_in_manifest) {
|
||||
// Still missing? Maybe we are reading the wrong manifest.
|
||||
return HASH_EXIT_MANIFEST_MISSING_ENTRY;
|
||||
}
|
||||
|
||||
if (tensor_layer_has_mismatch) {
|
||||
// Per tensor check found error
|
||||
return HASH_EXIT_FAILURE;
|
||||
}
|
||||
|
||||
// All per tensor layer checks passed? Sounds good enough.
|
||||
return HASH_EXIT_SUCCESS;
|
||||
}
|
||||
|
||||
// Overall model check passed, but let's check per layer just in case
|
||||
// If missing, we don't care too much as the overall model checked
|
||||
if (tensor_layer_in_manifest && tensor_layer_has_mismatch) {
|
||||
return HASH_EXIT_FAILURE;
|
||||
}
|
||||
|
||||
if (model_has_mismatch) {
|
||||
// model has failed hash somewhere in the model
|
||||
return HASH_EXIT_FAILURE;
|
||||
}
|
||||
|
||||
// All checks appears to be fine
|
||||
return HASH_EXIT_SUCCESS;
|
||||
}
|
||||
|
||||
// In hash generation mode
|
||||
return HASH_EXIT_SUCCESS;
|
||||
}
|
||||
|
||||
int main(int argc, const char ** argv) {
|
||||
hash_params params;
|
||||
manifest_check_params manifest_check;
|
||||
hash_params_parse(argc, argv, params);
|
||||
|
||||
if (!params.manifest_file.empty()) {
|
||||
if (!manifest_type(params.manifest_file, manifest_check)) {
|
||||
printf("ERROR cannot open manifest %s", params.manifest_file.c_str());
|
||||
return HASH_EXIT_MANIFEST_FILE_ERROR;
|
||||
}
|
||||
|
||||
if (!manifest_check.sha256 && !manifest_check.sha1 && !manifest_check.xxh64 && !manifest_check.uuid) {
|
||||
printf("ERROR manifest does not have any known hash format in %s", params.manifest_file.c_str());
|
||||
return HASH_EXIT_MANIFEST_UNKNOWN_HASH;
|
||||
}
|
||||
|
||||
printf("manifest %s", params.manifest_file.c_str());
|
||||
|
||||
if (manifest_check.sha256) {
|
||||
printf(" sha256");
|
||||
}
|
||||
|
||||
if (manifest_check.sha1) {
|
||||
printf(" sha1");
|
||||
}
|
||||
|
||||
if (manifest_check.xxh64) {
|
||||
printf(" xxh64");
|
||||
}
|
||||
|
||||
if (manifest_check.uuid) {
|
||||
printf(" uuid");
|
||||
}
|
||||
|
||||
printf("\n");
|
||||
|
||||
// Autoselect the highest security hash if manifest is provided but
|
||||
// the user has not specifically defined the hash they care about
|
||||
if (!params.xxh64 && !params.sha1 && !params.uuid && !params.sha256) {
|
||||
// User has not selected a specific value, pick most secure hash
|
||||
if (manifest_check.sha256) {
|
||||
params.sha256 = true;
|
||||
} else if (manifest_check.sha1) {
|
||||
params.sha1 = true;
|
||||
} else if (manifest_check.xxh64) {
|
||||
params.xxh64 = true;
|
||||
} else if (manifest_check.uuid) {
|
||||
params.uuid = true;
|
||||
}
|
||||
}
|
||||
|
||||
params.manifest_is_usable = true;
|
||||
}
|
||||
|
||||
// By default if no swich argument provided, assume xxh64
|
||||
if (!params.xxh64 && !params.sha1 && !params.uuid && !params.sha256) {
|
||||
params.xxh64 = true;
|
||||
}
|
||||
|
||||
hash_exit_code_t exit_code = gguf_hash(params);
|
||||
|
||||
if (params.manifest_is_usable) {
|
||||
printf("\nVerification results for %s - %s\n", params.manifest_file.c_str(), hash_exit_code_to_str(exit_code));
|
||||
}
|
||||
|
||||
return exit_code;
|
||||
}
|
@ -1,3 +1,4 @@
|
||||
-r ../../requirements/requirements-convert_legacy_llama.txt
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
pillow~=10.2.0
|
||||
torch~=2.2.1
|
||||
|
@ -366,7 +366,8 @@ Notice that each `probs` is an array of length `n_probs`.
|
||||
"assistant_name": "",
|
||||
"user_name": "",
|
||||
"default_generation_settings": { ... },
|
||||
"total_slots": 1
|
||||
"total_slots": 1,
|
||||
"chat_template": ""
|
||||
}
|
||||
```
|
||||
|
||||
@ -374,6 +375,7 @@ Notice that each `probs` is an array of length `n_probs`.
|
||||
- `user_name` - the required anti-prompt to generate the prompt in case you have specified a system prompt for all slots.
|
||||
- `default_generation_settings` - the default generation settings for the `/completion` endpoint, which has the same fields as the `generation_settings` response object from the `/completion` endpoint.
|
||||
- `total_slots` - the total number of slots for process requests (defined by `--parallel` option)
|
||||
- `chat_template` - the model's original Jinja2 prompt template
|
||||
|
||||
- **POST** `/v1/chat/completions`: OpenAI-compatible Chat Completions API. Given a ChatML-formatted json description in `messages`, it returns the predicted completion. Both synchronous and streaming mode are supported, so scripted and interactive applications work fine. While no strong claims of compatibility with OpenAI API spec is being made, in our experience it suffices to support many apps. Only models with a [supported chat template](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) can be used optimally with this endpoint. By default, the ChatML template will be used.
|
||||
|
||||
|
@ -2605,7 +2605,7 @@ int main(int argc, char ** argv) {
|
||||
// if a custom chat template is not supplied, we will use the one that comes with the model (if any)
|
||||
if (params.chat_template.empty()) {
|
||||
if (!ctx_server.validate_model_chat_template()) {
|
||||
LOG_ERROR("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
|
||||
LOG_WARNING("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
|
||||
params.chat_template = "chatml";
|
||||
}
|
||||
}
|
||||
@ -2967,11 +2967,20 @@ int main(int argc, char ** argv) {
|
||||
};
|
||||
|
||||
const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
|
||||
std::string template_key = "tokenizer.chat_template", curr_tmpl;
|
||||
int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0);
|
||||
if (tlen > 0) {
|
||||
std::vector<char> curr_tmpl_buf(tlen + 1, 0);
|
||||
if (llama_model_meta_val_str(ctx_server.model, template_key.c_str(), curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
|
||||
curr_tmpl = std::string(curr_tmpl_buf.data(), tlen);
|
||||
}
|
||||
}
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
json data = {
|
||||
{ "system_prompt", ctx_server.system_prompt.c_str() },
|
||||
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
|
||||
{ "total_slots", ctx_server.params.n_parallel }
|
||||
{ "total_slots", ctx_server.params.n_parallel },
|
||||
{ "chat_template", curr_tmpl.c_str() }
|
||||
};
|
||||
|
||||
res.set_content(data.dump(), "application/json; charset=utf-8");
|
||||
|
@ -120,7 +120,6 @@ class Keys:
|
||||
MIDDLE_ID = "tokenizer.ggml.middle_token_id"
|
||||
EOT_ID = "tokenizer.ggml.eot_token_id"
|
||||
|
||||
|
||||
#
|
||||
# recommended mapping of model tensor names for storage in gguf
|
||||
#
|
||||
@ -163,6 +162,7 @@ class MODEL_ARCH(IntEnum):
|
||||
OPENELM = auto()
|
||||
ARCTIC = auto()
|
||||
DEEPSEEK2 = auto()
|
||||
CHATGLM = auto()
|
||||
BITNET = auto()
|
||||
T5 = auto()
|
||||
JAIS = auto()
|
||||
@ -289,6 +289,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.OPENELM: "openelm",
|
||||
MODEL_ARCH.ARCTIC: "arctic",
|
||||
MODEL_ARCH.DEEPSEEK2: "deepseek2",
|
||||
MODEL_ARCH.CHATGLM: "chatglm",
|
||||
MODEL_ARCH.BITNET: "bitnet",
|
||||
MODEL_ARCH.T5: "t5",
|
||||
MODEL_ARCH.JAIS: "jais",
|
||||
@ -924,6 +925,18 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||
],
|
||||
MODEL_ARCH.CHATGLM : [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_QKV,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.BITNET: [
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
@ -1020,6 +1033,9 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
],
|
||||
MODEL_ARCH.CHATGLM: [
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
],
|
||||
}
|
||||
|
||||
#
|
||||
|
@ -24,6 +24,7 @@ class TensorNameMap:
|
||||
"backbone.embedding", # mamba
|
||||
"backbone.embeddings", # mamba-hf
|
||||
"transformer.in_out_embed", # Grok
|
||||
"embedding.word_embeddings", # chatglm
|
||||
"transformer.token_embeddings", # openelm
|
||||
"shared", # t5
|
||||
),
|
||||
@ -55,6 +56,7 @@ class TensorNameMap:
|
||||
"output", # llama-pth bloom internlm2
|
||||
"word_embeddings_for_head", # persimmon
|
||||
"lm_head.linear", # phi2
|
||||
"output_layer", # chatglm
|
||||
),
|
||||
|
||||
# Output norm
|
||||
@ -71,12 +73,14 @@ class TensorNameMap:
|
||||
"model.norm_f", # mamba-qbert
|
||||
"backbone.norm_f", # mamba
|
||||
"transformer.rms_norm", # Grok
|
||||
"encoder.final_layernorm", # chatglm
|
||||
"transformer.norm", # openelm
|
||||
),
|
||||
|
||||
# Rope frequencies
|
||||
MODEL_TENSOR.ROPE_FREQS: (
|
||||
"rope.freqs", # llama-pth
|
||||
"rotary_pos_emb.inv_freq", # chatglm
|
||||
),
|
||||
}
|
||||
|
||||
@ -101,6 +105,7 @@ class TensorNameMap:
|
||||
"backbone.layers.{bid}.norm", # mamba
|
||||
"transformer.decoder_layer.{bid}.rms_norm", # Grok
|
||||
"transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
|
||||
"encoder.layers.{bid}.input_layernorm", # chatglm
|
||||
"transformer.layers.{bid}.attn_norm", # openelm
|
||||
),
|
||||
|
||||
@ -124,6 +129,7 @@ class TensorNameMap:
|
||||
"transformer.h.{bid}.mixer.Wqkv", # phi2
|
||||
"encoder.layers.{bid}.attn.Wqkv", # nomic-bert
|
||||
"model.layers.{bid}.self_attn.qkv_proj", # phi3
|
||||
"encoder.layers.{bid}.self_attention.query_key_value", # chatglm
|
||||
"transformer.layers.{bid}.attn.qkv_proj", # openelm
|
||||
),
|
||||
|
||||
@ -135,7 +141,7 @@ class TensorNameMap:
|
||||
"transformer.h.{bid}.attn.q_proj", # gpt-j
|
||||
"model.layers.layers.{bid}.self_attn.q_proj", # plamo
|
||||
"model.layers.{bid}.attention.wq", # internlm2
|
||||
"transformer.decoder_layer.{bid}.multi_head_attention.query" # Grok
|
||||
"transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok
|
||||
),
|
||||
|
||||
# Attention key
|
||||
@ -147,7 +153,7 @@ class TensorNameMap:
|
||||
"transformer.h.{bid}.attn.k", # refact
|
||||
"model.layers.layers.{bid}.self_attn.k_proj", # plamo
|
||||
"model.layers.{bid}.attention.wk", # internlm2
|
||||
"transformer.decoder_layer.{bid}.multi_head_attention.key" # Grok
|
||||
"transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok
|
||||
),
|
||||
|
||||
# Attention value
|
||||
@ -182,6 +188,7 @@ class TensorNameMap:
|
||||
"encoder.layers.{bid}.attn.out_proj", # nomic-bert
|
||||
"transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok
|
||||
"transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx
|
||||
"encoder.layers.{bid}.self_attention.dense", # chatglm
|
||||
"transformer.layers.{bid}.attn.out_proj", # openelm
|
||||
),
|
||||
|
||||
@ -218,6 +225,7 @@ class TensorNameMap:
|
||||
"h.{bid}.ln_2", # gpt2
|
||||
"model.layers.{bid}.ffn_norm", # internlm2
|
||||
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok
|
||||
"encoder.layers.{bid}.post_attention_layernorm", # chatglm
|
||||
"transformer.layers.{bid}.ffn_norm", # openelm
|
||||
),
|
||||
|
||||
@ -268,6 +276,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.mlp.c_fc", # starcoder2
|
||||
"encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2
|
||||
"model.layers.{bid}.residual_mlp.w3", # arctic
|
||||
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_UP_EXP: (
|
||||
@ -337,6 +346,7 @@ class TensorNameMap:
|
||||
"transformer.layers.{bid}.ffn.proj_2", # openelm
|
||||
"model.layers.{bid}.residual_mlp.w2", # arctic
|
||||
"encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
|
||||
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_DOWN_EXP: (
|
||||
|
91
gguf-py/scripts/gguf_hash.py
Executable file
91
gguf-py/scripts/gguf_hash.py
Executable file
@ -0,0 +1,91 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
import hashlib
|
||||
|
||||
import logging
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
# Necessary to load the local gguf package
|
||||
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from gguf import GGUFReader # noqa: E402
|
||||
|
||||
|
||||
logger = logging.getLogger("gguf-hash")
|
||||
|
||||
# UUID_NAMESPACE_LLAMA_CPP = uuid.uuid5(uuid.NAMESPACE_URL, 'en.wikipedia.org/wiki/Llama.cpp')
|
||||
UUID_NAMESPACE_LLAMA_CPP = uuid.UUID('ef001206-dadc-5f6d-a15f-3359e577d4e5')
|
||||
|
||||
|
||||
# For more information about what field.parts and field.data represent,
|
||||
# please see the comments in the modify_gguf.py example.
|
||||
def gguf_hash(reader: GGUFReader, filename: str, disable_progress_bar) -> None:
|
||||
sha1 = hashlib.sha1()
|
||||
uuidv5_sha1 = hashlib.sha1()
|
||||
uuidv5_sha1.update(UUID_NAMESPACE_LLAMA_CPP.bytes)
|
||||
|
||||
# Total Weight Calculation For Progress Bar
|
||||
total_weights = 0
|
||||
for n, tensor in enumerate(reader.tensors, 1):
|
||||
|
||||
# We don't need these
|
||||
if tensor.name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
|
||||
continue
|
||||
|
||||
# Calculate Tensor Volume
|
||||
sum_weights_in_tensor = 1
|
||||
for dim in tensor.shape:
|
||||
sum_weights_in_tensor *= dim
|
||||
total_weights += sum_weights_in_tensor
|
||||
|
||||
# Hash Progress Bar
|
||||
bar = tqdm(desc="Hashing", total=total_weights, unit="weights", unit_scale=True, disable=disable_progress_bar)
|
||||
|
||||
# Hashing Process
|
||||
for n, tensor in enumerate(reader.tensors, 1):
|
||||
|
||||
# We don't need these
|
||||
if tensor.name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
|
||||
continue
|
||||
|
||||
# Progressbar
|
||||
sum_weights_in_tensor = 1
|
||||
for dim in tensor.shape:
|
||||
sum_weights_in_tensor *= dim
|
||||
bar.update(sum_weights_in_tensor)
|
||||
|
||||
sha1_layer = hashlib.sha1()
|
||||
sha1_layer.update(tensor.data)
|
||||
sha1.update(tensor.data)
|
||||
uuidv5_sha1.update(tensor.data)
|
||||
print("sha1 {0} {1}:{2}".format(sha1_layer.hexdigest(), filename, tensor.name)) # noqa: NP100
|
||||
|
||||
# Flush Hash Progress Bar
|
||||
bar.close()
|
||||
|
||||
# Display Hash Output
|
||||
print("sha1 {0} {1}".format(sha1.hexdigest(), filename)) # noqa: NP100
|
||||
print("UUIDv5 {0} {1}".format(uuid.UUID(bytes=uuidv5_sha1.digest()[:16], version=5), filename)) # noqa: NP100
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Dump GGUF file metadata")
|
||||
parser.add_argument("model", type=str, help="GGUF format model filename")
|
||||
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
|
||||
parser.add_argument("--progressbar", action="store_true", help="enable progressbar")
|
||||
args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"])
|
||||
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
|
||||
reader = GGUFReader(args.model, 'r')
|
||||
gguf_hash(reader, args.model, not args.progressbar)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -88,8 +88,10 @@ extern "C" {
|
||||
LLAMA_VOCAB_PRE_TYPE_DBRX = 13,
|
||||
LLAMA_VOCAB_PRE_TYPE_SMAUG = 14,
|
||||
LLAMA_VOCAB_PRE_TYPE_PORO = 15,
|
||||
LLAMA_VOCAB_PRE_TYPE_VIKING = 16,
|
||||
LLAMA_VOCAB_PRE_TYPE_JAIS = 17,
|
||||
LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16,
|
||||
LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17,
|
||||
LLAMA_VOCAB_PRE_TYPE_VIKING = 18,
|
||||
LLAMA_VOCAB_PRE_TYPE_JAIS = 19,
|
||||
};
|
||||
|
||||
// note: these values should be synchronized with ggml_rope
|
||||
|
@ -1,2 +1,3 @@
|
||||
-r ./requirements-convert_legacy_llama.txt
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
torch~=2.2.1
|
||||
|
@ -1,2 +1,3 @@
|
||||
-r ./requirements-convert_legacy_llama.txt
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
torch~=2.2.1
|
||||
|
280
src/llama.cpp
280
src/llama.cpp
@ -229,6 +229,7 @@ enum llm_arch {
|
||||
LLM_ARCH_OPENELM,
|
||||
LLM_ARCH_ARCTIC,
|
||||
LLM_ARCH_DEEPSEEK2,
|
||||
LLM_ARCH_CHATGLM,
|
||||
LLM_ARCH_BITNET,
|
||||
LLM_ARCH_T5,
|
||||
LLM_ARCH_JAIS,
|
||||
@ -272,6 +273,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_OPENELM, "openelm" },
|
||||
{ LLM_ARCH_ARCTIC, "arctic" },
|
||||
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
|
||||
{ LLM_ARCH_CHATGLM, "chatglm" },
|
||||
{ LLM_ARCH_BITNET, "bitnet" },
|
||||
{ LLM_ARCH_T5, "t5" },
|
||||
{ LLM_ARCH_JAIS, "jais" },
|
||||
@ -1205,6 +1207,21 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_CHATGLM,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_BITNET,
|
||||
{
|
||||
@ -2087,9 +2104,11 @@ enum e_model {
|
||||
MODEL_2_8B,
|
||||
MODEL_3B,
|
||||
MODEL_4B,
|
||||
MODEL_6B,
|
||||
MODEL_6_9B,
|
||||
MODEL_7B,
|
||||
MODEL_8B,
|
||||
MODEL_9B,
|
||||
MODEL_11B,
|
||||
MODEL_12B,
|
||||
MODEL_13B,
|
||||
@ -2115,7 +2134,6 @@ enum e_model {
|
||||
MODEL_16x12B,
|
||||
MODEL_10B_128x3_66B,
|
||||
MODEL_57B_A14B,
|
||||
MODEL_9B,
|
||||
MODEL_27B,
|
||||
};
|
||||
|
||||
@ -3311,6 +3329,8 @@ static void llama_kv_cache_seq_add(
|
||||
|
||||
if (p0 < 0) p0 = 0;
|
||||
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
|
||||
// If there is no range then return early to avoid looping over the cache.
|
||||
if (p0 == p1) return;
|
||||
|
||||
if (cache.recurrent) {
|
||||
// for Mamba-like models, only the pos needs to be shifted
|
||||
@ -3355,6 +3375,8 @@ static void llama_kv_cache_seq_div(
|
||||
int d) {
|
||||
if (p0 < 0) p0 = 0;
|
||||
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
|
||||
// If there is no range then return early to avoid looping over the cache.
|
||||
if (p0 == p1) return;
|
||||
|
||||
if (cache.recurrent) {
|
||||
// for Mamba-like models, only the pos needs to be changed
|
||||
@ -4537,9 +4559,11 @@ static const char * llama_model_type_name(e_model type) {
|
||||
case MODEL_2_8B: return "2.8B";
|
||||
case MODEL_3B: return "3B";
|
||||
case MODEL_4B: return "4B";
|
||||
case MODEL_6B: return "6B";
|
||||
case MODEL_6_9B: return "6.9B";
|
||||
case MODEL_7B: return "7B";
|
||||
case MODEL_8B: return "8B";
|
||||
case MODEL_9B: return "9B";
|
||||
case MODEL_11B: return "11B";
|
||||
case MODEL_12B: return "12B";
|
||||
case MODEL_13B: return "13B";
|
||||
@ -4565,7 +4589,6 @@ static const char * llama_model_type_name(e_model type) {
|
||||
case MODEL_16x12B: return "16x12B";
|
||||
case MODEL_10B_128x3_66B: return "10B+128x3.66B";
|
||||
case MODEL_57B_A14B: return "57B.A14B";
|
||||
case MODEL_9B: return "9B";
|
||||
case MODEL_27B: return "27B";
|
||||
default: return "?B";
|
||||
}
|
||||
@ -4672,16 +4695,6 @@ static void llm_load_hparams(
|
||||
|
||||
// non-transformer models do not have attention heads
|
||||
if (hparams.n_head() > 0) {
|
||||
// sanity check for n_rot (optional)
|
||||
hparams.n_rot = hparams.n_embd / hparams.n_head();
|
||||
|
||||
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
|
||||
|
||||
if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
|
||||
if (hparams.n_rot != hparams.n_embd / hparams.n_head()) {
|
||||
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head()));
|
||||
}
|
||||
}
|
||||
// gpt-neox n_rot = rotary_pct * (n_embd / n_head)
|
||||
// gpt-j n_rot = rotary_dim
|
||||
|
||||
@ -4690,6 +4703,17 @@ static void llm_load_hparams(
|
||||
|
||||
hparams.n_embd_head_v = hparams.n_embd / hparams.n_head();
|
||||
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
|
||||
|
||||
// sanity check for n_rot (optional)
|
||||
hparams.n_rot = hparams.n_embd_head_k;
|
||||
|
||||
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
|
||||
|
||||
if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
|
||||
if (hparams.n_rot != hparams.n_embd_head_k) {
|
||||
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
hparams.n_rot = 0;
|
||||
hparams.n_embd_head_k = 0;
|
||||
@ -5170,6 +5194,15 @@ static void llm_load_hparams(
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_CHATGLM:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
switch (hparams.n_layer) {
|
||||
case 28: model.type = e_model::MODEL_6B; break;
|
||||
case 40: model.type = e_model::MODEL_9B; break;
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_BITNET:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
@ -5302,9 +5335,7 @@ static void llm_load_vocab(
|
||||
if (merges_keyidx == -1) {
|
||||
throw std::runtime_error("cannot find tokenizer merges in model file\n");
|
||||
}
|
||||
|
||||
const int n_merges = gguf_get_arr_n(ctx, merges_keyidx);
|
||||
|
||||
for (int i = 0; i < n_merges; i++) {
|
||||
const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i);
|
||||
GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
|
||||
@ -5447,6 +5478,10 @@ static void llm_load_vocab(
|
||||
tokenizer_pre == "poro-chat") {
|
||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_PORO;
|
||||
vocab.tokenizer_clean_spaces = false;
|
||||
} else if (
|
||||
tokenizer_pre == "chatglm-bpe") {
|
||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHATGLM4;
|
||||
vocab.special_bos_id = -1;
|
||||
} else if (
|
||||
tokenizer_pre == "viking") {
|
||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_VIKING;
|
||||
@ -5571,7 +5606,6 @@ static void llm_load_vocab(
|
||||
vocab.special_eot_id = 107;
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
vocab.linefeed_id = llama_byte_to_token(vocab, '\n');
|
||||
} catch (const std::exception & e) {
|
||||
@ -7479,6 +7513,36 @@ static bool llm_load_tensors(
|
||||
layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff});
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_CHATGLM:
|
||||
{
|
||||
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
||||
|
||||
// output
|
||||
{
|
||||
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
|
||||
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
ggml_context * ctx_layer = ctx_for_layer(i);
|
||||
ggml_context * ctx_split = ctx_for_layer_split(i);
|
||||
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
|
||||
|
||||
layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + (hparams.n_embd_head_k << 2)});
|
||||
layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + (hparams.n_embd_head_k << 2)});
|
||||
|
||||
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
|
||||
|
||||
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
|
||||
|
||||
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2});
|
||||
|
||||
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
throw std::runtime_error("unknown architecture");
|
||||
}
|
||||
@ -7703,6 +7767,7 @@ enum llm_ffn_op_type {
|
||||
LLM_FFN_GELU,
|
||||
LLM_FFN_RELU,
|
||||
LLM_FFN_RELU_SQR,
|
||||
LLM_FFN_SWIGLU,
|
||||
};
|
||||
|
||||
enum llm_ffn_gate_type {
|
||||
@ -7934,6 +7999,19 @@ static struct ggml_tensor * llm_build_ffn(
|
||||
cur = ggml_sqr(ctx, cur);
|
||||
cb(cur, "ffn_sqr(relu)", il);
|
||||
} break;
|
||||
case LLM_FFN_SWIGLU:
|
||||
{
|
||||
// Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
|
||||
int64_t split_point = cur->ne[0] / 2;
|
||||
struct ggml_tensor * x0 = ggml_cont(ctx, ggml_view_2d(ctx, cur, split_point, cur->ne[1], cur->nb[1], 0));
|
||||
struct ggml_tensor * x1 = ggml_cont(ctx, ggml_view_2d(ctx, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
|
||||
|
||||
x0 = ggml_silu(ctx, x0);
|
||||
cb(cur, "ffn_silu", il);
|
||||
|
||||
cur = ggml_mul(ctx, x0, x1);
|
||||
cb(cur, "ffn_mul", il);
|
||||
} break;
|
||||
}
|
||||
|
||||
if (type_gate == LLM_FFN_PAR) {
|
||||
@ -10785,19 +10863,12 @@ struct llm_build_context {
|
||||
// special-case: the up and gate tensors are merged into a single tensor
|
||||
// TOOD: support into llm_build_ffn
|
||||
{
|
||||
struct ggml_tensor* up = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur);
|
||||
cb(up, "ffn_up", il);
|
||||
|
||||
auto g = ggml_cont(ctx0, ggml_view_2d(ctx0, up, up->ne[0] / 2, up->ne[1], ggml_row_size(up->type, up->ne[0]), 0));
|
||||
auto y = ggml_cont(ctx0, ggml_view_2d(ctx0, up, up->ne[0] / 2, up->ne[1], ggml_row_size(up->type, up->ne[0]), up->nb[1] / 2));
|
||||
|
||||
y = ggml_mul(ctx0, y, ggml_silu(ctx0, g));
|
||||
cb(y, "ffn_gate", il);
|
||||
|
||||
auto down = ggml_mul_mat(ctx0, model.layers[il].ffn_down, y);
|
||||
cb(down, "ffn_down", il);
|
||||
|
||||
cur = down;
|
||||
cur = llm_build_ffn(ctx0, cur,
|
||||
model.layers[il].ffn_up, NULL, NULL,
|
||||
NULL, NULL, NULL,
|
||||
model.layers[il].ffn_down, NULL, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SWIGLU, LLM_FFN_SEQ, cb, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
}
|
||||
|
||||
@ -11567,7 +11638,7 @@ struct llm_build_context {
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
|
||||
n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
@ -11576,7 +11647,7 @@ struct llm_build_context {
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||
n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
@ -11680,7 +11751,7 @@ struct llm_build_context {
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
|
||||
n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
@ -11689,7 +11760,7 @@ struct llm_build_context {
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||
n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
@ -13489,6 +13560,120 @@ struct llm_build_context {
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
struct ggml_cgraph * build_chatglm() {
|
||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
|
||||
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
|
||||
struct ggml_tensor * cur;
|
||||
struct ggml_tensor * inpL;
|
||||
|
||||
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
struct ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * inpSA = inpL;
|
||||
|
||||
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||
model.layers[il].attn_norm,
|
||||
NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
struct ggml_tensor * Qcur = nullptr;
|
||||
struct ggml_tensor * Kcur = nullptr;
|
||||
struct ggml_tensor * Vcur = nullptr;
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
|
||||
cb(cur, "wqkv", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
||||
cb(cur, "bqkv", il);
|
||||
|
||||
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
||||
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
||||
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
//printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor);
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
cb(Qcur, "Qcur_rope", il);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
cb(Kcur, "Kcur_rope", il);
|
||||
|
||||
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||
model.layers[il].wo, NULL,
|
||||
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
|
||||
}
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
// skip computing output for unused tokens
|
||||
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
// Add the input
|
||||
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// FF
|
||||
{
|
||||
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
||||
model.layers[il].ffn_norm,
|
||||
NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
cur = llm_build_ffn(ctx0, cur,
|
||||
model.layers[il].ffn_up, NULL, NULL,
|
||||
NULL, NULL, NULL,
|
||||
model.layers[il].ffn_down, NULL, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SWIGLU, LLM_FFN_SEQ, cb, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
}
|
||||
|
||||
inpL = ggml_add(ctx0, cur, ffn_inp);
|
||||
cb(inpL, "l_out", il);
|
||||
}
|
||||
|
||||
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||
model.output_norm,
|
||||
NULL,
|
||||
LLM_NORM_RMS, cb, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.output, cur);
|
||||
cb(cur, "result_output", -1);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
return gf;
|
||||
}
|
||||
};
|
||||
|
||||
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
|
||||
@ -13720,6 +13905,10 @@ static struct ggml_cgraph * llama_build_graph(
|
||||
{
|
||||
result = llm.build_deepseek2();
|
||||
} break;
|
||||
case LLM_ARCH_CHATGLM:
|
||||
{
|
||||
result = llm.build_chatglm();
|
||||
} break;
|
||||
case LLM_ARCH_BITNET:
|
||||
{
|
||||
result = llm.build_bitnet();
|
||||
@ -15335,6 +15524,11 @@ struct llm_tokenizer_bpe {
|
||||
" ?[^(\\s|.,!?…。,、।۔،)]+",
|
||||
};
|
||||
break;
|
||||
case LLAMA_VOCAB_PRE_TYPE_CHATGLM4:
|
||||
regex_exprs = {
|
||||
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
};
|
||||
break;
|
||||
case LLAMA_VOCAB_PRE_TYPE_VIKING:
|
||||
regex_exprs = {
|
||||
"\\p{N}",
|
||||
@ -16236,7 +16430,6 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
|
||||
if (add_special) {
|
||||
tokenizer.append_bos(output);
|
||||
}
|
||||
|
||||
for (const auto & fragment : fragment_buffer) {
|
||||
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
|
||||
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
|
||||
@ -19130,6 +19323,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
||||
case LLM_ARCH_OLMO:
|
||||
case LLM_ARCH_ARCTIC:
|
||||
case LLM_ARCH_DEEPSEEK2:
|
||||
case LLM_ARCH_CHATGLM:
|
||||
return LLAMA_ROPE_TYPE_NORM;
|
||||
|
||||
// the pairs of head values are offset by n_rot/2
|
||||
@ -20864,7 +21058,6 @@ int32_t llama_tokenize(
|
||||
bool add_special,
|
||||
bool parse_special) {
|
||||
auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_special, parse_special);
|
||||
|
||||
if (n_tokens_max < (int) res.size()) {
|
||||
// LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
|
||||
return -((int) res.size());
|
||||
@ -21283,6 +21476,25 @@ static int32_t llama_chat_apply_template_internal(
|
||||
if (add_ass) {
|
||||
ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
|
||||
}
|
||||
} else if (tmpl == "chatglm3" || tmpl_contains("[gMASK]sop")) {
|
||||
// chatglm3-6b
|
||||
ss << "[gMASK]" << "sop";
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
ss << "<|" << role << "|>" << "\n " << message->content;
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "<|assistant|>";
|
||||
}
|
||||
} else if (tmpl == "chaglm4" || tmpl_contains("[gMASK]<sop>")) {
|
||||
ss << "[gMASK]" << "<sop>";
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
ss << "<|" << role << "|>" << "\n" << message->content;
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "<|assistant|>";
|
||||
}
|
||||
} else if (tmpl == "minicpm" || tmpl_contains(u8"<用户>")) {
|
||||
// MiniCPM-3B-OpenHermes-2.5-v2-GGUF
|
||||
for (auto message : chat) {
|
||||
|
@ -58,6 +58,10 @@ int main(void) {
|
||||
"{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
|
||||
//Phi-3-vision
|
||||
"{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}",
|
||||
// ChatGLM3
|
||||
"{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
|
||||
// ChatGLM4
|
||||
u8"[gMASK]<sop>{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
|
||||
// MiniCPM-3B-OpenHermes-2.5-v2-GGUF
|
||||
u8"{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + '<AI>'}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}",
|
||||
// DeepSeek-V2
|
||||
@ -98,6 +102,10 @@ int main(void) {
|
||||
"<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
|
||||
//Phi-3-vision
|
||||
"<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
|
||||
// ChatGLM3
|
||||
"[gMASK]sop<|system|>\n You are a helpful assistant<|user|>\n Hello<|assistant|>\n Hi there<|user|>\n Who are you<|assistant|>\n I am an assistant <|user|>\n Another question<|assistant|>",
|
||||
// ChatGLM4
|
||||
"[gMASK]<sop><|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
|
||||
// MiniCPM-3B-OpenHermes-2.5-v2-GGUF
|
||||
u8"You are a helpful assistant<用户>Hello<AI>Hi there<用户>Who are you<AI>I am an assistant<用户>Another question<AI>",
|
||||
// DeepSeek-V2
|
||||
|
Loading…
Reference in New Issue
Block a user