Skip to content
代码片段 群组 项目
提交 530412cb 编辑于 作者: comfyanonymous's avatar comfyanonymous
浏览文件

Refactor torch version checks to be more future proof.

上级 61c8c70c
No related branches found
No related tags found
无相关合并请求
......@@ -50,7 +50,9 @@ xpu_available = False
torch_version = ""
try:
torch_version = torch.version.__version__
xpu_available = (int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) <= 4)) and torch.xpu.is_available()
temp = torch_version.split(".")
torch_version_numeric = (int(temp[0]), int(temp[1]))
xpu_available = (torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] <= 4)) and torch.xpu.is_available()
except:
pass
......@@ -227,7 +229,7 @@ if args.use_pytorch_cross_attention:
try:
if is_nvidia():
if int(torch_version[0]) >= 2:
if torch_version_numeric[0] >= 2:
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
if is_intel_xpu() or is_ascend_npu():
......@@ -242,7 +244,7 @@ try:
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
logging.info("AMD arch: {}".format(arch))
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
if int(torch_version[0]) >= 2 and int(torch_version[2]) >= 7: # works on 2.6 but doesn't actually seem to improve much
if torch_version_numeric[0] >= 2 and torch_version_numeric[1] >= 7: # works on 2.6 but doesn't actually seem to improve much
if arch in ["gfx1100"]: #TODO: more arches
ENABLE_PYTORCH_ATTENTION = True
except:
......@@ -261,7 +263,7 @@ except:
pass
try:
if int(torch_version[0]) == 2 and int(torch_version[2]) >= 5:
if torch_version_numeric[0] == 2 and torch_version_numeric[1] >= 5:
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
except:
logging.warning("Warning, could not set allow_fp16_bf16_reduction_math_sdp")
......@@ -1136,11 +1138,11 @@ def supports_fp8_compute(device=None):
if props.minor < 9:
return False
if int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) < 3):
if torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 3):
return False
if WINDOWS:
if (int(torch_version[0]) == 2 and int(torch_version[2]) < 4):
if (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 4):
return False
return True
......
0% 加载中 .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册