Commit b85fc718 authored by Aarni Koskela

Fix MPS cache cleanup

Importing torch does not import torch.mps so the call failed.
Parent 7b833291
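As a minimal illustration (not part of the commit) of the failure mode described above, assuming a PyTorch build where `import torch` does not also import the `torch.mps` submodule, the removed guard crashes before hasattr() can return False:

import torch

try:
    # The expression `torch.mps` is evaluated before hasattr() runs, so on such
    # a build this raises AttributeError instead of falling through to False.
    if hasattr(torch.mps, 'empty_cache'):   # the guard removed in the diff below
        torch.mps.empty_cache()
except AttributeError as err:
    print("old guard crashed:", err)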
@@ -54,8 +54,9 @@ def torch_gc():
         with torch.cuda.device(get_cuda_device_string()):
             torch.cuda.empty_cache()
             torch.cuda.ipc_collect()
-    elif has_mps() and hasattr(torch.mps, 'empty_cache'):
-        torch.mps.empty_cache()
+
+    if has_mps():
+        mac_specific.torch_mps_gc()
 
 
 def enable_tf32():
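For context, a sketch (not a self-contained program) of how the patched helper reads after this change; the torch.cuda.is_available() guard sits just above the hunk, and get_cuda_device_string, has_mps and the mac_specific import are assumed to be defined elsewhere in the same module:

import torch

def torch_gc():
    if torch.cuda.is_available():  # guard assumed from context above the hunk
        with torch.cuda.device(get_cuda_device_string()):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

    if has_mps():
        mac_specific.torch_mps_gc()  # helper added in the second file of this diff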
...
+import logging
+
 import torch
 import platform
 from modules.sd_hijack_utils import CondFunc
 from packaging import version
 
+log = logging.getLogger()
+
 
 # before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
 # use check `getattr` and try it for compatibility.
@@ -19,9 +23,19 @@ def check_for_mps() -> bool:
         return False
     else:
         return torch.backends.mps.is_available() and torch.backends.mps.is_built()
 
+
 has_mps = check_for_mps()
 
+
+def torch_mps_gc() -> None:
+    try:
+        from torch.mps import empty_cache
+        empty_cache()
+    except Exception:
+        log.warning("MPS garbage collection failed", exc_info=True)
+
+
 # MPS workaround for https://github.com/pytorch/pytorch/issues/89784
 def cumsum_fix(input, cumsum_func, *args, **kwargs):
     if input.device.type == 'mps':
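The new helper relies on a lazy import inside the try block plus a broad except, so callers never see an exception on setups where torch.mps is missing or broken. A hypothetical standalone sketch of the same pattern (cache_gc and its argument are illustrative, not from the commit):

import importlib
import logging

log = logging.getLogger(__name__)

def cache_gc(module_name: str) -> None:
    # Same shape as torch_mps_gc above: the backend import happens lazily
    # inside the try block, so a missing submodule is logged, never raised.
    try:
        backend = importlib.import_module(module_name)
        backend.empty_cache()
    except Exception:
        log.warning("%s garbage collection failed", module_name, exc_info=True)

logging.basicConfig(level=logging.WARNING)
cache_gc("torch.mps")  # logs a warning instead of raising if torch.mps cannot
                       # be imported or empty_cache() fails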