llama : bump max layers from 256 to 512 (#8530)

* llama : bump max layers from 256 to 512

* llama : replace asserts with exceptions
This commit is contained in:
Georgi Gerganov 2024-07-19 16:50:47 +03:00 committed by GitHub
parent be0cfb4175
commit d197545530
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 10 additions and 6 deletions

View File

@@ -40,7 +40,7 @@
#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq' #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
#define LLAMA_SESSION_VERSION 6 #define LLAMA_SESSION_VERSION 7
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ #define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
#define LLAMA_STATE_SEQ_VERSION 1 #define LLAMA_STATE_SEQ_VERSION 1

View File

@@ -114,7 +114,7 @@
// bump if necessary // bump if necessary
#define LLAMA_MAX_NODES 8192 #define LLAMA_MAX_NODES 8192
#define LLAMA_MAX_LAYERS 256 #define LLAMA_MAX_LAYERS 512
#define LLAMA_MAX_EXPERTS 160 // DeepSeekV2 #define LLAMA_MAX_EXPERTS 160 // DeepSeekV2
// //
@@ -4007,7 +4007,9 @@ struct llama_model_loader {
throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str())); throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
} }
GGML_ASSERT(arr_info.length <= N_MAX); if (arr_info.length > N_MAX) {
throw std::runtime_error(format("array length %u for key %s exceeds max %u", (uint32_t) arr_info.length, key.c_str(), (uint32_t) N_MAX));
}
std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin()); std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin());
@@ -4043,8 +4045,6 @@ struct llama_model_loader {
// get array of n <= N_MAX elements, or a single element repeated n times // get array of n <= N_MAX elements, or a single element repeated n times
template<typename T, size_t N_MAX> template<typename T, size_t N_MAX>
bool get_key_or_arr(const std::string & key, std::array<T, N_MAX> & result, uint32_t n, const bool required = true) { bool get_key_or_arr(const std::string & key, std::array<T, N_MAX> & result, uint32_t n, const bool required = true) {
GGML_ASSERT(n <= N_MAX);
const int kid = gguf_find_key(meta, key.c_str()); const int kid = gguf_find_key(meta, key.c_str());
if (kid < 0) { if (kid < 0) {
@@ -4054,6 +4054,10 @@ struct llama_model_loader {
return false; return false;
} }
if (n > N_MAX) {
throw std::runtime_error(format("n > N_MAX: %u > %u for key %s", (uint32_t) n, (uint32_t) N_MAX, key.c_str()));
}
if (gguf_get_kv_type(meta, kid) == GGUF_TYPE_ARRAY) { if (gguf_get_kv_type(meta, kid) == GGUF_TYPE_ARRAY) {
struct GGUFMeta::ArrayInfo arr_info = struct GGUFMeta::ArrayInfo arr_info =
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta, kid); GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta, kid);
@@ -19920,7 +19924,7 @@ size_t llama_state_get_size(const struct llama_context * ctx) {
); );
// on session change it is very likely that the state size has changed - so we need to update this function // on session change it is very likely that the state size has changed - so we need to update this function
static_assert(LLAMA_SESSION_VERSION == 6, "So you just bumped the session version - good. But did you remember to update llama_state_get_size?"); static_assert(LLAMA_SESSION_VERSION == 7, "So you just bumped the session version - good. But did you remember to update llama_state_get_size?");
return s_total; return s_total;
} }