diff --git a/convert.py b/convert.py index 244eb7582..b0dd6628a 100755 --- a/convert.py +++ b/convert.py @@ -33,7 +33,7 @@ if 'NO_LOCAL_GGUF' not in os.environ: import gguf if TYPE_CHECKING: - from typing import TypeAlias + from typing_extensions import Self, TypeAlias if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'): faulthandler.register(signal.SIGUSR1) @@ -646,16 +646,17 @@ def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray: class Tensor(ABC): + ndarray: NDArray data_type: DataType @abstractmethod - def astype(self, data_type: DataType) -> Tensor: ... + def astype(self, data_type: DataType) -> Self: ... @abstractmethod - def permute(self, n_head: int, n_head_kv: int) -> Tensor: ... + def permute(self, n_head: int, n_head_kv: int) -> Self: ... @abstractmethod - def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor: ... + def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> Self: ... @abstractmethod - def part(self, n_part: int) -> UnquantizedTensor: ... + def part(self, n_part: int) -> Self: ... @abstractmethod def to_ggml(self) -> GGMLCompatibleTensor: ... @@ -672,13 +673,13 @@ class UnquantizedTensor(Tensor): self.ndarray = ndarray self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype] - def astype(self, data_type: DataType) -> Tensor: + def astype(self, data_type: DataType) -> UnquantizedTensor: dtype = data_type.dtype if self.data_type == DT_BF16: self.ndarray = bf16_to_fp32(self.ndarray) return UnquantizedTensor(self.ndarray.astype(dtype)) - def to_ggml(self) -> UnquantizedTensor: + def to_ggml(self) -> Self: return self def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor: