mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 06:39:25 +01:00
convert scripts : fix python 3.8 compatibility
This commit is contained in:
parent
6a9d3c0911
commit
909f6be291
@ -2311,7 +2311,8 @@ class MambaModel(Model):
|
|||||||
data = data.astype(np.float32)
|
data = data.astype(np.float32)
|
||||||
|
|
||||||
# if f16 desired, convert big float32 2-dim weight tensors to float16
|
# if f16 desired, convert big float32 2-dim weight tensors to float16
|
||||||
if self.ftype == 1 and data_dtype == np.float32 and new_name.removesuffix(".weight").endswith((".ssm_in", ".ssm_out", "token_embd", "output")) and n_dims == 2:
|
new_weight_name = new_name[:-len(".weight")] if new_name.endswith(".weight") else ""
|
||||||
|
if self.ftype == 1 and data_dtype == np.float32 and new_weight_name.endswith((".ssm_in", ".ssm_out", "token_embd", "output")) and n_dims == 2:
|
||||||
data = data.astype(np.float16)
|
data = data.astype(np.float16)
|
||||||
|
|
||||||
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
|
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
Loading…
Reference in New Issue
Block a user