From a609bd56b4206460d1df3c3022025fc78b66718f Mon Sep 17 00:00:00 2001
From: papuSpartan <macabeg@icloud.com>
Date: Sat, 1 Apr 2023 22:18:35 -0500
Subject: [PATCH] Transition to using settings through UI instead of cmd line
 args. Added feature to only apply to hr-fix. Install package using
 requirements_versions.txt

---
 launch.py                 |  3 ---
 modules/processing.py     | 35 +++++++++++++++++++++++++++++++
 modules/sd_models.py      |  7 -------
 modules/shared.py         | 44 +++++++++++++++++++++++++++++++++++++++
 requirements_versions.txt |  1 +
 5 files changed, 80 insertions(+), 10 deletions(-)

diff --git a/launch.py b/launch.py
index 846c4c20..68e08114 100644
--- a/launch.py
+++ b/launch.py
@@ -280,9 +280,6 @@ def prepare_environment():
         elif platform.system() == "Linux":
             run_pip(f"install {xformers_package}", "xformers")
 
-    if not is_installed("tomesd") and args.token_merging:
-        run_pip(f"install tomesd")
-
     if not is_installed("pyngrok") and args.ngrok:
         run_pip("install pyngrok", "ngrok")
 
diff --git a/modules/processing.py b/modules/processing.py
index 6d9c6a8d..e115aadd 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -29,6 +29,7 @@ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
 
 from einops import repeat, rearrange
 from blendmodes.blend import blendLayers, BlendType
+import tomesd
 
 # some of those options should not be changed at all because they would break the model, so I removed them from options.
 opt_C = 4
@@ -500,9 +501,28 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
             if k == 'sd_vae':
                 sd_vae.reload_vae_weights()
 
+        if opts.token_merging:
+
+            if p.hr_second_pass_steps < 1 and not opts.token_merging_hr_only:
+                tomesd.apply_patch(
+                    p.sd_model,
+                    ratio=opts.token_merging_ratio,
+                    max_downsample=opts.token_merging_maximum_down_sampling,
+                    sx=opts.token_merging_stride_x,
+                    sy=opts.token_merging_stride_y,
+                    use_rand=opts.token_merging_random,
+                    merge_attn=opts.token_merging_merge_attention,
+                    merge_crossattn=opts.token_merging_merge_cross_attention,
+                    merge_mlp=opts.token_merging_merge_mlp
+                )
+
         res = process_images_inner(p)
 
     finally:
+        # undo model optimizations made by tomesd
+        if opts.token_merging:
+            tomesd.remove_patch(p.sd_model)
+
         # restore opts to original state
         if p.override_settings_restore_afterwards:
             for k, v in stored_opts.items():
@@ -938,6 +958,21 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         x = None
         devices.torch_gc()
 
+        # apply token merging optimizations from tomesd for high-res pass
+        # check if hr_only so we don't redundantly apply patch
+        if opts.token_merging and opts.token_merging_hr_only:
+            tomesd.apply_patch(
+                self.sd_model,
+                ratio=opts.token_merging_ratio,
+                max_downsample=opts.token_merging_maximum_down_sampling,
+                sx=opts.token_merging_stride_x,
+                sy=opts.token_merging_stride_y,
+                use_rand=opts.token_merging_random,
+                merge_attn=opts.token_merging_merge_attention,
+                merge_crossattn=opts.token_merging_merge_cross_attention,
+                merge_mlp=opts.token_merging_merge_mlp
+            )
+
         samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
 
         return samples
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 2c05ec17..87c49b83 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -431,13 +431,6 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
         with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
             sd_model = instantiate_from_config(sd_config.model)
 
-            if shared.cmd_opts.token_merging:
-                import tomesd
-                ratio = shared.cmd_opts.token_merging_ratio
-
-                tomesd.apply_patch(sd_model, ratio=ratio)
-                print(f"Model accelerated using {(ratio * 100)}% token merging via tomesd.")
-                timer.record("token merging")
     except Exception as e:
         pass
 
diff --git a/modules/shared.py b/modules/shared.py
index 5fd0eecb..d7379e24 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -427,6 +427,50 @@ options_templates.update(options_section((None, "Hidden options"), {
     "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
 }))
 
+options_templates.update(options_section(('token_merging', 'Token Merging'), {
+    "token_merging": OptionInfo(
+        False, "Enable redundant token merging via tomesd. (currently incompatible with controlnet extension)",
+        gr.Checkbox
+    ),
+    "token_merging_ratio": OptionInfo(
+        0.5, "Merging Ratio",
+        gr.Slider, {"minimum": 0, "maximum": 0.9, "step": 0.1}
+    ),
+    "token_merging_hr_only": OptionInfo(
+        True, "Apply only to high-res fix pass. Disabling can yield a ~20-35% speedup on contemporary resolutions.",
+        gr.Checkbox
+    ),
+    # More advanced/niche settings:
+    "token_merging_random": OptionInfo(
+        True, "Use random perturbations - Disabling might help with certain samplers",
+        gr.Checkbox
+    ),
+    "token_merging_merge_attention": OptionInfo(
+        True, "Merge attention",
+        gr.Checkbox
+    ),
+     "token_merging_merge_cross_attention": OptionInfo(
+        False, "Merge cross attention",
+        gr.Checkbox
+    ),
+    "token_merging_merge_mlp": OptionInfo(
+        False, "Merge mlp",
+        gr.Checkbox
+    ),
+    "token_merging_maximum_down_sampling": OptionInfo(
+        1, "Maximum down sampling",
+        gr.Dropdown, lambda: {"choices": ["1", "2", "4", "8"]}
+    ),
+    "token_merging_stride_x": OptionInfo(
+        2, "Stride - X",
+        gr.Slider, {"minimum": 2, "maximum": 8, "step": 2}
+    ),
+    "token_merging_stride_y": OptionInfo(
+        2, "Stride - Y",
+        gr.Slider, {"minimum": 2, "maximum": 8, "step": 2}
+    )
+}))
+
 options_templates.update()
 
 
diff --git a/requirements_versions.txt b/requirements_versions.txt
index df65431a..045230ab 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -28,3 +28,4 @@ torchsde==0.2.5
 safetensors==0.3.0
 httpcore<=0.15
 fastapi==0.94.0
+tomesd>=0.1
\ No newline at end of file
-- 
GitLab