Skip to content
代码片段 群组 项目
未验证 提交 57eb54b8 编辑于 作者: Extraltodeus's avatar Extraltodeus 提交者: GitHub
浏览文件

implement CUDA device selection by ID

上级 f49c08ea
No related branches found
No related tags found
无相关合并请求
import sys, os, shlex
import contextlib
import torch
from modules import errors
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
......@@ -9,10 +8,26 @@ has_mps = getattr(torch, 'has_mps', False)
cpu = torch.device("cpu")
def extract_device_id(args, name):
for x in range(len(args)):
if name in args[x]: return args[x+1]
return None
def get_optimal_device():
if torch.cuda.is_available():
return torch.device("cuda")
# CUDA device selection support:
if "shared" not in sys.modules:
commandline_args = os.environ.get('COMMANDLINE_ARGS', "") #re-parse the commandline arguments because using the shared.py module creates an import loop.
sys.argv += shlex.split(commandline_args)
device_id = extract_device_id(sys.argv, '--device-id')
else:
device_id = shared.cmd_opts.device_id
if device_id is not None:
cuda_device = f"cuda:{device_id}"
return torch.device(cuda_device)
else:
return torch.device("cuda")
if has_mps:
return torch.device("mps")
......
0% 加载中 .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册