mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-28 15:18:26 +01:00
convert : fix Tensor type annotations
This commit is contained in:
parent
0d052cbe39
commit
6a9d3c0911
15
convert.py
15
convert.py
@ -33,7 +33,7 @@ if 'NO_LOCAL_GGUF' not in os.environ:
|
|||||||
import gguf
|
import gguf
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing import TypeAlias
|
from typing_extensions import Self, TypeAlias
|
||||||
|
|
||||||
if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'):
|
if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'):
|
||||||
faulthandler.register(signal.SIGUSR1)
|
faulthandler.register(signal.SIGUSR1)
|
||||||
@ -646,16 +646,17 @@ def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray:
|
|||||||
|
|
||||||
|
|
||||||
class Tensor(ABC):
|
class Tensor(ABC):
|
||||||
|
ndarray: NDArray
|
||||||
data_type: DataType
|
data_type: DataType
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def astype(self, data_type: DataType) -> Tensor: ...
|
def astype(self, data_type: DataType) -> Self: ...
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def permute(self, n_head: int, n_head_kv: int) -> Tensor: ...
|
def permute(self, n_head: int, n_head_kv: int) -> Self: ...
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor: ...
|
def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> Self: ...
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def part(self, n_part: int) -> UnquantizedTensor: ...
|
def part(self, n_part: int) -> Self: ...
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def to_ggml(self) -> GGMLCompatibleTensor: ...
|
def to_ggml(self) -> GGMLCompatibleTensor: ...
|
||||||
|
|
||||||
@ -672,13 +673,13 @@ class UnquantizedTensor(Tensor):
|
|||||||
self.ndarray = ndarray
|
self.ndarray = ndarray
|
||||||
self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype]
|
self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype]
|
||||||
|
|
||||||
def astype(self, data_type: DataType) -> Tensor:
|
def astype(self, data_type: DataType) -> UnquantizedTensor:
|
||||||
dtype = data_type.dtype
|
dtype = data_type.dtype
|
||||||
if self.data_type == DT_BF16:
|
if self.data_type == DT_BF16:
|
||||||
self.ndarray = bf16_to_fp32(self.ndarray)
|
self.ndarray = bf16_to_fp32(self.ndarray)
|
||||||
return UnquantizedTensor(self.ndarray.astype(dtype))
|
return UnquantizedTensor(self.ndarray.astype(dtype))
|
||||||
|
|
||||||
def to_ggml(self) -> UnquantizedTensor:
|
def to_ggml(self) -> Self:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor:
|
def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor:
|
||||||
|
Loading…
Reference in New Issue
Block a user