convert_hf : fix memory leak in lazy MoE conversion

The '_lazy' queue was sometimes self-referential,
creating reference cycles among long-lived objects
that the garbage collector rarely reclaimed, so memory
use could keep growing until exhaustion.
Francis Couture-Harpin 2024-07-15 21:09:04 -04:00
parent 2a49a68d70
commit b971122eb1
2 changed files with 23 additions and 51 deletions
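
For context, a minimal standalone sketch of the leak described above (illustrative only, not the converter's code): when every lazy tensor appends itself to a deque that it also holds a reference to, the node and the queue reference each other, so the graph can only be reclaimed by the cyclic garbage collector, and a long conversion can allocate such cycles faster than the collector visits the older generations they get promoted into.

# Illustrative sketch only, not gguf-py code: how a shared '_lazy' deque
# makes every node part of a reference cycle.
import gc
from collections import deque


class LazyNode:
    def __init__(self, shared: deque | None = None):
        # the queue is shared across the whole lazy graph...
        self._lazy = shared if shared is not None else deque()
        # ...and each node appends itself to it, so the node references the
        # queue and the queue references the node: a reference cycle.
        self._lazy.append(self)


def build_graph(n: int) -> LazyNode:
    node = LazyNode()
    for _ in range(n):
        node = LazyNode(shared=node._lazy)
    return node


build_graph(10_000)
# The dropped graph is unreachable but cyclic, so plain reference counting
# cannot free it; only a cycle-collection pass reclaims the nodes.
print("objects found in cycles:", gc.collect())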

convert_hf_to_gguf.py

@@ -3456,20 +3456,19 @@ class LazyTorchTensor(gguf.LazyBase):
         dtype = self._dtype_map[self.dtype]
         return gguf.LazyNumpyTensor(
             meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape),
-            lazy=self._lazy,
             args=(self,),
-            func=(lambda s: s[0].numpy())
+            func=(lambda s: s.numpy())
         )
 
     @classmethod
-    def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: torch.Size) -> Tensor:
+    def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: tuple[int, ...]) -> Tensor:
         return torch.empty(size=shape, dtype=dtype, device="meta")
 
     @classmethod
     def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
         dtype = cls._dtype_str_map[st_slice.get_dtype()]
-        shape = st_slice.get_shape()
-        lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[0][:])
+        shape: tuple[int, ...] = tuple(st_slice.get_shape())
+        lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
         return cast(torch.Tensor, lazy)
 
     @classmethod
@@ -3482,7 +3481,7 @@ class LazyTorchTensor(gguf.LazyBase):
         if func is torch.Tensor.numpy:
             return args[0].numpy()
 
-        return LazyTorchTensor._wrap_fn(func)(*args, **kwargs)
+        return cls._wrap_fn(func)(*args, **kwargs)
 
 
 def parse_args() -> argparse.Namespace:
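
For reference, a hedged sketch of where such a safetensors slice typically comes from; the path and tensor name below are placeholders, not values taken from the converter:

# Hypothetical usage sketch; the file path and tensor name are made up.
from safetensors import safe_open

with safe_open("model-00001-of-00002.safetensors", framework="pt") as f:
    st_slice = f.get_slice("model.layers.0.mlp.experts.0.w1.weight")
    # get_dtype()/get_shape() only read metadata; the data is materialized
    # when the slice is indexed (st_slice[:]), which is exactly what the
    # new func=lambda s: s[:] defers until the tensor is made eager.
    lazy = LazyTorchTensor.from_safetensors_slice(st_slice)
    # note: the file must remain open until the lazy tensor is evaluated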

gguf-py/gguf/lazy.py

@@ -3,7 +3,6 @@ from abc import ABC, ABCMeta, abstractmethod
 import logging
 from typing import Any, Callable
-from collections import deque
 
 import numpy as np
 from numpy.typing import DTypeLike
 
@@ -74,20 +73,18 @@ class LazyBase(ABC, metaclass=LazyMeta):
     _tensor_type: type
     _meta: Any
     _data: Any | None
-    _lazy: deque[LazyBase]  # shared within a graph, to avoid deep recursion when making eager
     _args: tuple
-    _func: Callable[[tuple], Any] | None
+    _kwargs: dict[str, Any]
+    _func: Callable[[Any], Any] | None
 
-    def __init__(self, *, meta: Any, data: Any | None = None, lazy: deque[LazyBase] | None = None, args: tuple = (), func: Callable[[tuple], Any] | None = None):
+    def __init__(self, *, meta: Any, data: Any | None = None, args: tuple = (), kwargs: dict[str, Any] | None = None, func: Callable[[Any], Any] | None = None):
         super().__init__()
         self._meta = meta
         self._data = data
-        self._lazy = lazy if lazy is not None else deque()
         self._args = args
+        self._kwargs = kwargs if kwargs is not None else {}
         self._func = func
         assert self._func is not None or self._data is not None
-        if self._data is None:
-            self._lazy.append(self)
 
     def __init_subclass__(cls) -> None:
         if "_tensor_type" not in cls.__dict__:
@@ -117,6 +114,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
             args = ((use_self,) if use_self is not None else ()) + args
 
             meta_args = LazyBase._recurse_apply(args, lambda t: t._meta)
+            # TODO: maybe handle tensors in kwargs too
 
             if isinstance(meta_noop, bool) and not meta_noop:
                 try:
@@ -140,23 +138,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
                         res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)
 
             if isinstance(res, cls._tensor_type):
-                class CollectSharedLazy:
-                    # emulating a static variable
-                    shared_lazy: None | deque[LazyBase] = None
-
-                    @staticmethod
-                    def collect_replace(t: LazyBase):
-                        if CollectSharedLazy.shared_lazy is None:
-                            CollectSharedLazy.shared_lazy = t._lazy
-                        else:
-                            CollectSharedLazy.shared_lazy.extend(t._lazy)
-                            t._lazy = CollectSharedLazy.shared_lazy
-
-                LazyBase._recurse_apply(args, CollectSharedLazy.collect_replace)
-
-                shared_lazy = CollectSharedLazy.shared_lazy
-
-                return cls(meta=cls.eager_to_meta(res), lazy=shared_lazy, args=args, func=lambda a: fn(*a, **kwargs))
+                return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn)
             else:
                 del res  # not needed
                 # non-tensor return likely relies on the contents of the args
@@ -168,26 +150,18 @@ class LazyBase(ABC, metaclass=LazyMeta):
     @classmethod
     def to_eager(cls, t: Any) -> Any:
         def simple_to_eager(_t: LazyBase) -> Any:
-            def already_eager_to_eager(_t: LazyBase) -> Any:
-                assert _t._data is not None
-                return _t._data
+            if _t._data is not None:
+                return _t._data
 
-            while _t._data is None:
-                lt = _t._lazy.popleft()
-                if lt._data is not None:
-                    # Lazy tensor did not belong in the lazy queue.
-                    # Weirdly only happens with Bloom models...
-                    # likely because tensors aren't unique in the queue.
-                    # The final output is still the same as in eager mode,
-                    # so it's safe to ignore this.
-                    continue
-                assert lt._func is not None
-                lt._args = cls._recurse_apply(lt._args, already_eager_to_eager)
-                lt._data = lt._func(lt._args)
-                # sanity check
-                assert lt._data is not None
-                assert lt._data.dtype == lt._meta.dtype
-                assert lt._data.shape == lt._meta.shape
+            # NOTE: there's a recursion limit in Python (usually 1000)
+
+            assert _t._func is not None
+            _t._args = cls._recurse_apply(_t._args, simple_to_eager)
+            _t._data = _t._func(*_t._args, **_t._kwargs)
+            # sanity check
+            assert _t._data is not None
+            assert _t._data.dtype == _t._meta.dtype
+            assert _t._data.shape == _t._meta.shape
 
             return _t._data
@@ -206,7 +180,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
     @classmethod
     def from_eager(cls, t: Any) -> Any:
         if type(t) is cls:
-            # already eager
+            # already lazy
             return t
         elif isinstance(t, cls._tensor_type):
             return cls(meta=cls.eager_to_meta(t), data=t)
@@ -228,8 +202,7 @@ class LazyNumpyTensor(LazyBase):
     def astype(self, dtype, *args, **kwargs):
         meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape)
         full_args = (self, dtype,) + args
-        # very important to pass the shared _lazy deque, or else there's an infinite loop somewhere.
-        return type(self)(meta=meta, args=full_args, lazy=self._lazy, func=(lambda a: a[0].astype(*a[1:], **kwargs)))
+        return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs)))
 
     def tofile(self, *args, **kwargs):
         eager = LazyNumpyTensor.to_eager(self)
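
As a rough standalone illustration of the new evaluation model (a sketch with made-up names, not the gguf-py implementation): each lazy node now only holds its own args, kwargs and func, and making it eager recurses into its arguments, so there is no shared queue left to form cycles; the trade-off is Python's recursion limit on very deep graphs, as the NOTE in the diff points out.

# Minimal sketch of arg-recursing lazy evaluation; names are illustrative only.
from typing import Any, Callable


class MiniLazy:
    def __init__(self, *, data: Any = None, args: tuple = (),
                 kwargs: dict[str, Any] | None = None,
                 func: Callable[..., Any] | None = None):
        self._data = data
        self._args = args
        self._kwargs = kwargs if kwargs is not None else {}
        self._func = func

    @classmethod
    def to_eager(cls, t: "MiniLazy") -> Any:
        if t._data is not None:  # already materialized
            return t._data
        assert t._func is not None
        # recurse into lazy arguments first (bounded by sys.getrecursionlimit())
        args = tuple(cls.to_eager(a) if isinstance(a, MiniLazy) else a
                     for a in t._args)
        t._data = t._func(*args, **t._kwargs)
        return t._data


# usage: a small chain of deferred ops, evaluated on demand
x = MiniLazy(data=3)
y = MiniLazy(args=(x, 4), func=lambda a, b: a + b)
z = MiniLazy(args=(y,), kwargs={"exp": 2}, func=lambda a, exp: a ** exp)
print(MiniLazy.to_eager(z))  # 49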