From 1ca5e76f7b122ba35fc807350624a8d3fc25058a Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 4 Jun 2023 13:07:22 +0300
Subject: [PATCH] fix for conds of second hires fox pass being calculated using
 first pass's networks, and add an option to revert to old behavior

---
 modules/lowvram.py    |  6 ++++++
 modules/processing.py | 32 +++++++++++++++++++++++++++++---
 modules/shared.py     |  1 +
 3 files changed, 36 insertions(+), 3 deletions(-)

diff --git a/modules/lowvram.py b/modules/lowvram.py
index e254cc13..d95bcfbf 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -15,6 +15,8 @@ def send_everything_to_cpu():
 
 
 def setup_for_low_vram(sd_model, use_medvram):
+    sd_model.lowvram = True
+
     parents = {}
 
     def send_me_to_gpu(module, _):
@@ -96,3 +98,7 @@ def setup_for_low_vram(sd_model, use_medvram):
         diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
         for block in diff_model.output_blocks:
             block.register_forward_pre_hook(send_me_to_gpu)
+
+
+def is_enabled(sd_model):
+    return getattr(sd_model, 'lowvram', False)
diff --git a/modules/processing.py b/modules/processing.py
index fae83788..b65c7beb 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -739,7 +739,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
 
             del samples_ddim
 
-            if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+            if lowvram.is_enabled(shared.sd_model):
                 lowvram.send_everything_to_cpu()
 
             devices.torch_gc()
@@ -894,6 +894,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         self.hr_negative_prompts = None
         self.hr_extra_network_data = None
 
+        self.cached_hr_uc = [None, None]
+        self.cached_hr_c = [None, None]
         self.hr_c = None
         self.hr_uc = None
 
@@ -1056,6 +1058,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             with devices.autocast():
                 extra_networks.activate(self, self.hr_extra_network_data)
 
+        with devices.autocast():
+            self.calculate_hr_conds()
+
         sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
 
         samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
@@ -1067,6 +1072,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         return samples
 
     def close(self):
+        self.cached_hr_uc = [None, None]
+        self.cached_hr_c = [None, None]
         self.hr_c = None
         self.hr_uc = None
 
@@ -1095,12 +1102,31 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         self.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_hr_prompts]
         self.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_hr_negative_prompts]
 
+    def calculate_hr_conds(self):
+        if self.hr_c is not None:
+            return
+
+        self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, self.cached_hr_uc, self.hr_extra_network_data)
+        self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, self.cached_hr_c, self.hr_extra_network_data)
+
     def setup_conds(self):
         super().setup_conds()
 
+        self.hr_uc = None
+        self.hr_c = None
+
         if self.enable_hr:
-            self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, self.cached_uc, self.hr_extra_network_data)
-            self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, self.cached_c, self.hr_extra_network_data)
+            if shared.opts.hires_fix_use_firstpass_conds:
+                self.calculate_hr_conds()
+
+            elif lowvram.is_enabled(shared.sd_model):  # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded
+                with devices.autocast():
+                    extra_networks.activate(self, self.hr_extra_network_data)
+
+                self.calculate_hr_conds()
+
+                with devices.autocast():
+                    extra_networks.activate(self, self.extra_network_data)
 
     def parse_extra_network_prompts(self):
         res = super().parse_extra_network_prompts()
diff --git a/modules/shared.py b/modules/shared.py
index 7d056a4d..2bd7c6ec 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -429,6 +429,7 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
     "no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
     "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
     "dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
+    "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
 }))
 
 options_templates.update(options_section(('interrogate', "Interrogate Options"), {
-- 
GitLab