diff --git a/convert.py b/convert.py index c3f3fc0a1..b73d90344 100755 --- a/convert.py +++ b/convert.py @@ -17,29 +17,58 @@ import signal import struct import sys import time +import warnings import zipfile from abc import ABCMeta, abstractmethod -from collections import OrderedDict +from argparse import ArgumentParser from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor from dataclasses import dataclass from pathlib import Path -from typing import IO, TYPE_CHECKING, Any, Callable, Iterable, Literal, Optional, TypeVar, cast +from typing import ( + IO, + TYPE_CHECKING, + Any, + Callable, + Iterable, + Literal, + Optional, + Tuple, + TypeVar, +) import numpy as np from sentencepiece import SentencePieceProcessor -if 'NO_LOCAL_GGUF' not in os.environ: - sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) -import gguf +try: + from transformers import AutoTokenizer +except ModuleNotFoundError as e: + warnings.warn(f"Could not import AutoTokenizer from transformers: {e}") -if TYPE_CHECKING: - from typing import TypeAlias +# If NO_LOCAL_GGUF is not set, try to import gguf from the local gguf-py directory +if "NO_LOCAL_GGUF" not in os.environ: + # Use absolute path to the gguf-py directory + gguf_py_dir = str(Path(__file__).resolve().parent / "gguf-py") + print(gguf_py_dir) # NOTE: Remove this once path is verified after changes are completed + if gguf_py_dir not in sys.path: + sys.path.insert(1, gguf_py_dir) -if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'): +# Import gguf module +try: + import gguf +except ModuleNotFoundError as e: + print(f"Could not import gguf: {e}") + sys.exit(1) + +if TYPE_CHECKING: # NOTE: This isn't necessary. + from typing import TypeAlias # This can technically be omitted. + +if hasattr(faulthandler, "register") and hasattr(signal, "SIGUSR1"): faulthandler.register(signal.SIGUSR1) -NDArray: TypeAlias = 'np.ndarray[Any, Any]' +# NOTE: n-dimensional arrays should be directly referenced +NDArray: TypeAlias = "np.ndarray[Any, Any]" +# Why is this here? LLAMA and GPT are technically the only compatible ARCHs. ARCH = gguf.MODEL_ARCH.LLAMA DEFAULT_CONCURRENCY = 8 @@ -48,7 +77,7 @@ DEFAULT_CONCURRENCY = 8 # data types # - +# TODO: Clean up and refactor data types @dataclass(frozen=True) class DataType: name: str