only support GGUFv3

This commit is contained in:
ddh0 2024-06-27 15:25:11 -05:00
parent 8a6784a7d3
commit 65c4c15314

View File

@ -23,7 +23,7 @@ class GGUFValueType(IntEnum):
# the GGUF format versions that this module supports # the GGUF format versions that this module supports
SUPPORTED_GGUF_VERSIONS = [2, 3] SUPPORTED_GGUF_VERSIONS = [3]
# GGUF only supports execution on little or big endian machines # GGUF only supports execution on little or big endian machines
if sys.byteorder not in ['little', 'big']: if sys.byteorder not in ['little', 'big']:
@ -76,7 +76,10 @@ def get_single(
string_length = unpack(GGUFValueType.UINT64, file=file) string_length = unpack(GGUFValueType.UINT64, file=file)
value = file.read(string_length) value = file.read(string_length)
# officially, strings that cannot be decoded into utf-8 are invalid # officially, strings that cannot be decoded into utf-8 are invalid
value = value.decode("utf-8") try:
value = value.decode("utf-8")
except:
pass
else: else:
value = unpack(value_type, file=file) value = unpack(value_type, file=file)
return value return value
@ -113,20 +116,10 @@ def load_metadata(
) )
tensor_count = unpack(GGUFValueType.UINT64, file=file) tensor_count = unpack(GGUFValueType.UINT64, file=file)
if version == 3: metadata_kv_count = unpack(GGUFValueType.UINT64, file=file)
metadata_kv_count = unpack(GGUFValueType.UINT64, file=file)
elif version == 2:
metadata_kv_count = unpack(GGUFValueType.UINT32, file=file)
for _ in range(metadata_kv_count): for _ in range(metadata_kv_count):
if version == 3: key_length = unpack(GGUFValueType.UINT64, file=file)
key_length = unpack(GGUFValueType.UINT64, file=file)
elif version == 2:
key_length = 0
while key_length == 0:
# seek until next key is found
key_length = unpack(GGUFValueType.UINT32, file=file)
file.read(4) # 4 byte offset for GGUFv2
key = file.read(key_length) key = file.read(key_length)
value_type = GGUFValueType( value_type = GGUFValueType(
unpack(GGUFValueType.UINT32, file=file) unpack(GGUFValueType.UINT32, file=file)
@ -136,11 +129,7 @@ def load_metadata(
unpack(GGUFValueType.UINT32, file=file) unpack(GGUFValueType.UINT32, file=file)
) )
# array_length is the number of items in the array # array_length is the number of items in the array
if version == 3: array_length = unpack(GGUFValueType.UINT64, file=file)
array_length = unpack(GGUFValueType.UINT64, file=file)
elif version == 2:
array_length = unpack(GGUFValueType.UINT32, file=file)
file.read(4) # 4 byte offset for GGUFv2
array = [ array = [
get_single( get_single(
array_value_type, array_value_type,