Update Imports and Add Notes for Future Reference

- Updated import statements in `convert.py`.
- Added a guarded import for `AutoTokenizer` from the `transformers` module (a minimal usage sketch follows these notes).
- Added a conditional import for `gguf` from the local `gguf-py` directory.
- Added comments and notes for future reference.

Additional Notes:

- Noted that the `TypeAlias` import under `TYPE_CHECKING` is redundant and can be removed.
- Noted that the `gguf` path debug statement (`print(gguf_py_dir)`) should be removed once the path is verified.
- Commented on the presence of `ARCH` and `NDArray` definitions.
- Commented on cleaning up and refactoring data type definitions.
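
Below is a minimal sketch of how the guarded `AutoTokenizer` import could be consumed downstream so that `transformers` remains an optional dependency. It is illustrative only: the `load_hf_tokenizer` helper and the model path are assumptions made for this example, not part of `convert.py`, which only adds the import guard itself.

```python
# Illustrative sketch only -- not part of convert.py.
import warnings

try:
    from transformers import AutoTokenizer  # optional dependency
except ModuleNotFoundError as e:
    AutoTokenizer = None  # degrade gracefully instead of failing at import time
    warnings.warn(f"Could not import AutoTokenizer from transformers: {e}")


def load_hf_tokenizer(model_dir: str):
    """Return a Hugging Face tokenizer for model_dir, or None if transformers is missing.

    NOTE: hypothetical helper shown for illustration; convert.py only adds the guard.
    """
    if AutoTokenizer is None:
        return None
    # from_pretrained accepts a local directory with tokenizer files or a hub model id.
    return AutoTokenizer.from_pretrained(model_dir)


if __name__ == "__main__":
    tokenizer = load_hf_tokenizer("path/to/model")  # hypothetical path
    if tokenizer is None:
        print("transformers not installed; fall back to the SentencePiece path.")
```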
Author: teleprint-me
Date:   2024-01-07 18:51:51 -05:00
Parent: c75ca5d96f
Commit: b69021ef7f

@@ -17,29 +17,58 @@ import signal
 import struct
 import sys
 import time
 import warnings
 import zipfile
 from abc import ABCMeta, abstractmethod
 from collections import OrderedDict
 from argparse import ArgumentParser
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 from dataclasses import dataclass
 from pathlib import Path
-from typing import IO, TYPE_CHECKING, Any, Callable, Iterable, Literal, Optional, TypeVar, cast
+from typing import (
+    IO,
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Iterable,
+    Literal,
+    Optional,
+    Tuple,
+    TypeVar,
+)
 import numpy as np
 from sentencepiece import SentencePieceProcessor
-if 'NO_LOCAL_GGUF' not in os.environ:
-    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
-import gguf
-if TYPE_CHECKING:
-    from typing import TypeAlias
-if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'):
+try:
+    from transformers import AutoTokenizer
+except ModuleNotFoundError as e:
+    warnings.warn(f"Could not import AutoTokenizer from transformers: {e}")
+# If NO_LOCAL_GGUF is not set, try to import gguf from the local gguf-py directory
+if "NO_LOCAL_GGUF" not in os.environ:
+    # Use absolute path to the gguf-py directory
+    gguf_py_dir = str(Path(__file__).resolve().parent / "gguf-py")
+    print(gguf_py_dir)  # NOTE: Remove this once path is verified after changes are completed
+    if gguf_py_dir not in sys.path:
+        sys.path.insert(1, gguf_py_dir)
+# Import gguf module
+try:
+    import gguf
+except ModuleNotFoundError as e:
+    print(f"Could not import gguf: {e}")
+    sys.exit(1)
+if TYPE_CHECKING:  # NOTE: This isn't necessary.
+    from typing import TypeAlias  # This can technically be omitted.
+if hasattr(faulthandler, "register") and hasattr(signal, "SIGUSR1"):
     faulthandler.register(signal.SIGUSR1)
-NDArray: TypeAlias = 'np.ndarray[Any, Any]'
+# NOTE: n-dimensional arrays should be directly referenced
+NDArray: TypeAlias = "np.ndarray[Any, Any]"
+# Why is this here? LLAMA and GPT are technically the only compatible ARCHs.
 ARCH = gguf.MODEL_ARCH.LLAMA
 DEFAULT_CONCURRENCY = 8
@@ -48,7 +77,7 @@ DEFAULT_CONCURRENCY = 8
 # data types
 #
+# TODO: Clean up and refactor data types
 @dataclass(frozen=True)
 class DataType:
     name: str