diff --git a/README.md b/README.md
index 9c0cd1ef7b4b92fced881f3a300d25b3df82ccf7..a5611671d22aee4c20e85fd3092a07c9a961567e 100644
--- a/README.md
+++ b/README.md
@@ -157,4 +157,5 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
 - DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru
 - Security advice - RyotaK
 - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
+- Sampling in float32 precision from a float16 UNet - marunine for the idea, Birch-san for the example Diffusers implementation (https://github.com/Birch-san/diffusers-play/tree/92feee6)
 - (You)
diff --git a/modules/deepbooru_model.py b/modules/deepbooru_model.py
index edd40c81fcf49ac57bf8f77bac98640c77adcd18..83d2ff0902f965ac3c69d830203ad36d0b067089 100644
--- a/modules/deepbooru_model.py
+++ b/modules/deepbooru_model.py
@@ -2,6 +2,8 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from modules import devices
+
 # see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more
 
 
@@ -196,7 +198,7 @@ class DeepDanbooruModel(nn.Module):
         t_358, = inputs
         t_359 = t_358.permute(*[0, 3, 1, 2])
         t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
-        t_360 = self.n_Conv_0(t_359_padded)
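+        # under --upcast-sampling the model may be kept in float16, so cast the input to the first conv layer's dtype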
+        t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded)
         t_361 = F.relu(t_360)
         t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
         t_362 = self.n_MaxPool_0(t_361)
diff --git a/modules/devices.py b/modules/devices.py
index 524ec7af4227343eba416df324c239aaaa6c09bc..0981ef80a2454a649ae32643b092ec2ea0e8521e 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -79,6 +79,8 @@ cpu = torch.device("cpu")
 device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
 dtype = torch.float16
 dtype_vae = torch.float16
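+# dtype of the loaded UNet and whether sampling should upcast it to float32 (both set in sd_models.load_model_weights)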
+dtype_unet = torch.float16
+unet_needs_upcast = False
 
 
 def randn(seed, shape):
diff --git a/modules/processing.py b/modules/processing.py
index bc541e2f75dcbde8b4c6788d7925c262cdd24391..2d186ba09d6f58f278a250f9e0d4fea09d955fe5 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -172,7 +172,8 @@ class StableDiffusionProcessing:
         midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
         midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
 
-        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
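+        # when upcasting, encode the source image in the UNet's dtype, but keep the resulting latent in float32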
+        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_unet) if devices.unet_needs_upcast else source_image))
+        conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image
         conditioning = torch.nn.functional.interpolate(
             self.sd_model.depth_model(midas_in),
             size=conditioning_image.shape[2:],
@@ -203,7 +204,7 @@ class StableDiffusionProcessing:
 
         # Create another latent image, this time with a masked version of the original input.
         # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
-        conditioning_mask = conditioning_mask.to(source_image.device).to(source_image.dtype)
+        conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype)
         conditioning_image = torch.lerp(
             source_image,
             source_image * (1.0 - conditioning_mask),
@@ -211,7 +212,7 @@ class StableDiffusionProcessing:
         )
 
         # Encode the new masked image using first stage of network.
-        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
+        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_unet) if devices.unet_needs_upcast else conditioning_image))
 
         # Create the concatenated conditioning tensor to be fed to `c_concat`
         conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
@@ -225,10 +226,10 @@ class StableDiffusionProcessing:
         # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
         # identify itself with a field common to all models. The conditioning_key is also hybrid.
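+        # under upcast sampling the source image may be float16, so the conditioning is built in float32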
         if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
-            return self.depth2img_image_conditioning(source_image)
+            return self.depth2img_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image)
 
         if self.sampler.conditioning_key in {'hybrid', 'concat'}:
-            return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
+            return self.inpainting_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image, latent_image, image_mask=image_mask)
 
         # Dummy zero conditioning if we're not using inpainting or depth model.
         return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
@@ -610,7 +611,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             if p.n_iter > 1:
                 shared.state.job = f"Batch {n+1} out of {p.n_iter}"
 
-            with devices.autocast():
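+            # skip autocast when upcast sampling is active; the sd_hijack_unet patches manage dtypes themselves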
+            with devices.autocast(disable=devices.unet_needs_upcast):
                 samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
 
             x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
@@ -988,7 +989,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
 
         image = torch.from_numpy(batch_images)
         image = 2. * image - 1.
-        image = image.to(shared.device)
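+        # dtype=None leaves the image's dtype unchanged when upcast sampling is off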
+        image = image.to(device=shared.device, dtype=devices.dtype_unet if devices.unet_needs_upcast else None)
 
         self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
 
diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py
index 18daf8c14020a3d4affd75a0d0082feb58d05f90..88c94e54adc55c30ea18b051e90c54f8960a6a0b 100644
--- a/modules/sd_hijack_unet.py
+++ b/modules/sd_hijack_unet.py
@@ -1,4 +1,8 @@
 import torch
+from packaging import version
+
+from modules import devices
+from modules.sd_hijack_utils import CondFunc
 
 
 class TorchHijackForUnet:
@@ -28,3 +32,28 @@ class TorchHijackForUnet:
 
 
 th = TorchHijackForUnet()
+
+
+# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
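+# The patched apply_model casts the noisy latents, timesteps, and conditioning tensors to the
+# UNet's dtype and returns a float32 result, so samplers can keep working in float32.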
+def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
+    for y in cond.keys():
+        cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
+    with devices.autocast():
+        return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
+
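+# GELU replacement that runs in float32 while upcast sampling is active, casting the result back to the UNet dtype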
+class GELUHijack(torch.nn.GELU):
+    def forward(self, x):
+        if devices.unet_needs_upcast:
+            return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet)
+        else:
+            return torch.nn.GELU.forward(self, x)
+
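+# condition shared by the patches below: only intercept while upcast sampling is active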
+unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
+CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
+CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).to(devices.dtype_unet), unet_needs_upcast)
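+# for torch 1.13.1 and earlier, additionally run GroupNorm32, GEGLU, and open_clip's GELU in float32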
+if version.parse(torch.__version__) <= version.parse("1.13.1"):
+    CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
+    CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
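+    # dict.update returns None, so `and False or` always falls through to orig_func(*args, **kwargs)
+    # after swapping the default activation layer for GELUHijack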
+    CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
diff --git a/modules/sd_hijack_utils.py b/modules/sd_hijack_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f81b169ae993425e2c30475923cf3bbe5bac1d2e
--- /dev/null
+++ b/modules/sd_hijack_utils.py
@@ -0,0 +1,28 @@
+import importlib
+
+class CondFunc:
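+    """Monkey-patches `orig_func` so that calls run `sub_func(orig_func, ...)` whenever
+    `cond_func(orig_func, ...)` returns True, and the original function otherwise.
+    `orig_func` may be the function itself or its dotted import path as a string."""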
+    def __new__(cls, orig_func, sub_func, cond_func):
+        self = super(CondFunc, cls).__new__(cls)
+        if isinstance(orig_func, str):
+            func_path = orig_func.split('.')
+            resolved_obj = None
+            for i in range(len(func_path) - 1, 0, -1):
+                try:
+                    resolved_obj = importlib.import_module('.'.join(func_path[:i]))
+                    break
+                except ImportError:
+                    pass
+            if resolved_obj is None:
+                raise ImportError(f"Could not resolve module for {orig_func}")
+            for attr_name in func_path[i:-1]:
+                resolved_obj = getattr(resolved_obj, attr_name)
+            orig_func = getattr(resolved_obj, func_path[-1])
+            setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
+        self.__init__(orig_func, sub_func, cond_func)
+        return lambda *args, **kwargs: self(*args, **kwargs)
+    def __init__(self, orig_func, sub_func, cond_func):
+        self.__orig_func = orig_func
+        self.__sub_func = sub_func
+        self.__cond_func = cond_func
+    def __call__(self, *args, **kwargs):
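+        # with no condition, or when the condition holds, run the substitute; otherwise run the original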
+        if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
+            return self.__sub_func(self.__orig_func, *args, **kwargs)
+        else:
+            return self.__orig_func(*args, **kwargs)
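+
+# Example usage (see modules/sd_hijack_unet.py): patch LatentDiffusion.apply_model so the
+# substitute only takes effect while upcast sampling is active:
+#   CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model,
+#            lambda *args, **kwargs: devices.unet_needs_upcast)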
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 12083848066285f882dfcb6ccf87940639fc1ec9..7c98991a30b32c48f905abd5dacffd83d8c4f895 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -257,16 +257,24 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo):
 
         if not shared.cmd_opts.no_half:
             vae = model.first_stage_model
+            depth_model = getattr(model, 'depth_model', None)
 
             # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
             if shared.cmd_opts.no_half_vae:
                 model.first_stage_model = None
+            # with --upcast-sampling, don't convert the depth model weights to float16
+            if shared.cmd_opts.upcast_sampling and depth_model:
+                model.depth_model = None
 
             model.half()
             model.first_stage_model = vae
+            if depth_model:
+                model.depth_model = depth_model
 
         devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
         devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
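+        # record the UNet dtype; upcasting is needed only when both the global dtype and the UNet are float16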
+        devices.dtype_unet = model.model.diffusion_model.dtype
+        devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
 
         model.first_stage_model.to(devices.dtype_vae)
 
@@ -372,6 +380,8 @@ def load_model(checkpoint_info=None):
 
     if shared.cmd_opts.no_half:
         sd_config.model.params.unet_config.params.use_fp16 = False
+    elif shared.cmd_opts.upcast_sampling:
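+        # keep the UNet weights in float16; the sd_hijack_unet patches upcast activations to float32 where needed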
+        sd_config.model.params.unet_config.params.use_fp16 = True
 
     timer = Timer()
 
diff --git a/modules/shared.py b/modules/shared.py
index 5f713bee78a5e1a3564fb1561d231c716f6dfa72..4ce1209b09c2764e6fcb5111692a54d9fe60783e 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -45,6 +45,7 @@ parser.add_argument("--lowram", action='store_true', help="load stable diffusion
 parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
 parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
 parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
+parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
 parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
 parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
 parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")