diff --git a/modules/devices.py b/modules/devices.py
new file mode 100644
index 0000000000000000000000000000000000000000..25008a04211a10f9305e6e2e6bb6367b04f1323b
--- /dev/null
+++ b/modules/devices.py
@@ -0,0 +1,13 @@
+import torch
+
+
+# torch.has_mps is only available in nightly pytorch (for now); `getattr` keeps stable releases working
+has_mps = getattr(torch, 'has_mps', False)
+
+def get_optimal_device():
+    """Return the best available torch.device: cuda if available, else mps, else cpu."""
+    if torch.cuda.is_available():
+        return torch.device("cuda")
+    if has_mps:
+        return torch.device("mps")
+    return torch.device("cpu")
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index e86ad775a50afc050095c9c8a9734f47688b1318..7f3baf31ffce32cab099201573f2020fbb527e8d 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -9,12 +9,13 @@ from PIL import Image
 import modules.esrgam_model_arch as arch
 from modules import shared
 from modules.shared import opts
+from modules.devices import has_mps
 import modules.images
 
 
 def load_model(filename):
     # this code is adapted from https://github.com/xinntao/ESRGAN
-    pretrained_net = torch.load(filename, map_location='cpu' if torch.has_mps else None)
+    pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)  # NOTE(review): presumably MPS cannot map CUDA-saved tensors, hence CPU mapping — confirm
     crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
 
     if 'conv_first.weight' in pretrained_net:
diff --git a/modules/lowvram.py b/modules/lowvram.py
index bd1174915f19da6cb0fee65d326c02e25ea42f70..079386c36c6b251c46e8fc74bc2a9286866376a9 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -1,13 +1,9 @@
 import torch
+from modules.devices import get_optimal_device
 
 module_in_gpu = None
 cpu = torch.device("cpu")
-if torch.has_cuda:
-    device = gpu = torch.device("cuda")
-elif torch.has_mps:
-    device = gpu = torch.device("mps")
-else:
-    device = gpu = torch.device("cpu")
+device = gpu = get_optimal_device()  # device selection centralized in modules.devices (cuda > mps > cpu)
 
 def setup_for_low_vram(sd_model, use_medvram):
     parents = {}
diff --git a/modules/shared.py b/modules/shared.py
index 6ca9106ca66b4f3de73dd69da924c957d2d0e68b..74b0ad89d7c2423d64b5273610792e1fe0aa3d6d 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -9,6 +9,7 @@ import tqdm
 
 import modules.artists
 from modules.paths import script_path, sd_path
+from modules.devices import get_optimal_device
 import modules.styles
 
 config_filename = "config.json"
@@ -43,12 +44,8 @@ parser.add_argument("--ui-config-file", type=str, help="filename to use for ui c
 
 cmd_opts = parser.parse_args()
 
-if torch.has_cuda:
-    device = torch.device("cuda")
-elif torch.has_mps:
-    device = torch.device("mps")
-else:
-    device = torch.device("cpu")
+device = get_optimal_device()  # device selection centralized in modules.devices (cuda > mps > cpu)
+
 batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
 parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram