py : linting with mypy and isort

This commit is contained in:
Jared Van Bortel 2024-01-19 12:38:18 -05:00
parent ffdd051ab5
commit 4a3bc1522e
5 changed files with 15 additions and 12 deletions

View File

@ -10,7 +10,7 @@ import re
import sys import sys
from enum import IntEnum from enum import IntEnum
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, ContextManager, Iterator, cast, Optional from typing import TYPE_CHECKING, Any, ContextManager, Iterator, cast
import numpy as np import numpy as np
import torch import torch
@ -487,6 +487,7 @@ class MPTModel(Model):
# map tensor names # map tensor names
if "scales" in name: if "scales" in name:
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias", ".scales")) new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias", ".scales"))
if new_name is not None:
new_name = new_name.replace("scales", "act.scales") new_name = new_name.replace("scales", "act.scales")
else: else:
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
@ -904,7 +905,7 @@ class QwenModel(Model):
return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')]) return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')])
@staticmethod @staticmethod
def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: Optional[int] = None) -> list[bytes]: def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]:
parts = [bytes([b]) for b in token] parts = [bytes([b]) for b in token]
while True: while True:
min_idx = None min_idx = None
@ -1285,7 +1286,7 @@ def main() -> None:
if args.awq_path: if args.awq_path:
sys.path.insert(1, str(Path(__file__).parent / 'awq-py')) sys.path.insert(1, str(Path(__file__).parent / 'awq-py'))
from awq.apply_awq import add_scale_weights from awq.apply_awq import add_scale_weights # type: ignore[import-not-found]
tmp_model_path = args.model / "weighted_model" tmp_model_path = args.model / "weighted_model"
dir_model = tmp_model_path dir_model = tmp_model_path
if tmp_model_path.is_dir(): if tmp_model_path.is_dir():

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import argparse import argparse
import os
import struct import struct
import sys import sys
from enum import IntEnum from enum import IntEnum
@ -9,7 +10,6 @@ from pathlib import Path
import numpy as np import numpy as np
import os
if 'NO_LOCAL_GGUF' not in os.environ: if 'NO_LOCAL_GGUF' not in os.environ:
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
import gguf import gguf

View File

@ -5,17 +5,16 @@ import json
import os import os
import struct import struct
import sys import sys
from pathlib import Path
from typing import Any, BinaryIO, Sequence from typing import Any, BinaryIO, Sequence
import numpy as np import numpy as np
import torch import torch
from pathlib import Path
if 'NO_LOCAL_GGUF' not in os.environ: if 'NO_LOCAL_GGUF' not in os.environ:
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf')) sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
import gguf import gguf
NUMPY_TYPE_TO_FTYPE: dict[str, int] = {"float32": 0, "float16": 1} NUMPY_TYPE_TO_FTYPE: dict[str, int] = {"float32": 0, "float16": 1}

View File

@ -1,11 +1,13 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import torch
import os
from pprint import pprint
import sys
import argparse import argparse
import os
import sys
from pathlib import Path from pathlib import Path
from pprint import pprint
import torch
from sentencepiece import SentencePieceProcessor from sentencepiece import SentencePieceProcessor
if 'NO_LOCAL_GGUF' not in os.environ: if 'NO_LOCAL_GGUF' not in os.environ:
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
import gguf import gguf
@ -69,7 +71,7 @@ def main():
persimmon_model = torch.load(args.ckpt_path) persimmon_model = torch.load(args.ckpt_path)
hparams = persimmon_model['args'] hparams = persimmon_model['args']
pprint(hparams) pprint(hparams)
tensors = {} tensors: dict[str, torch.Tensor] = {}
_flatten_dict(persimmon_model['model'], tensors, None) _flatten_dict(persimmon_model['model'], tensors, None)
arch = gguf.MODEL_ARCH.PERSIMMON arch = gguf.MODEL_ARCH.PERSIMMON

View File

@ -4,3 +4,4 @@ allow_untyped_calls = true
allow_untyped_defs = true allow_untyped_defs = true
allow_incomplete_defs = true allow_incomplete_defs = true
disable_error_code = import-untyped disable_error_code = import-untyped
warn_return_any = false