From b1717c0a4804f8ed3bb8cc2f3aea5d095778b447 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 2 May 2023 09:08:00 +0300
Subject: [PATCH] do not load wait for shared.sd_model to load at startup

---
 modules/sd_models.py | 54 ++++++++++++++++++++++++++++++++------------
 modules/shared.py    | 31 +++++++++++++++++++++----
 modules/ui.py        | 10 ++++----
 webui.py             | 16 ++++---------
 4 files changed, 76 insertions(+), 35 deletions(-)

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 4f7613a1..59adc7cc 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -2,6 +2,8 @@ import collections
 import os.path
 import sys
 import gc
+import threading
+
 import torch
 import re
 import safetensors.torch
@@ -404,13 +406,39 @@ def repair_config(sd_config):
 sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
 sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
 
-def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
+
+class SdModelData:
+    def __init__(self):
+        self.sd_model = None
+        self.lock = threading.Lock()
+
+    def get_sd_model(self):
+        if self.sd_model is None:
+            with self.lock:
+                try:
+                    load_model()
+                except Exception as e:
+                    errors.display(e, "loading stable diffusion model")
+                    print("", file=sys.stderr)
+                    print("Stable diffusion model failed to load", file=sys.stderr)
+                    self.sd_model = None
+
+        return self.sd_model
+
+    def set_sd_model(self, v):
+        self.sd_model = v
+
+
+model_data = SdModelData()
+
+
+def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     from modules import lowvram, sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
 
-    if shared.sd_model:
-        sd_hijack.model_hijack.undo_hijack(shared.sd_model)
-        shared.sd_model = None
+    if model_data.sd_model:
+        sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
+        model_data.sd_model = None
         gc.collect()
         devices.torch_gc()
 
@@ -464,7 +492,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
     timer.record("hijack")
 
     sd_model.eval()
-    shared.sd_model = sd_model
+    model_data.sd_model = sd_model
 
     sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model
 
@@ -484,7 +512,7 @@ def reload_model_weights(sd_model=None, info=None):
     checkpoint_info = info or select_checkpoint()
 
     if not sd_model:
-        sd_model = shared.sd_model
+        sd_model = model_data.sd_model
 
     if sd_model is None:  # previous model load failed
         current_checkpoint_info = None
@@ -512,7 +540,7 @@ def reload_model_weights(sd_model=None, info=None):
         del sd_model
         checkpoints_loaded.clear()
         load_model(checkpoint_info, already_loaded_state_dict=state_dict)
-        return shared.sd_model
+        return model_data.sd_model
 
     try:
         load_model_weights(sd_model, checkpoint_info, state_dict, timer)
@@ -535,17 +563,15 @@ def reload_model_weights(sd_model=None, info=None):
 
     return sd_model
 
+
 def unload_model_weights(sd_model=None, info=None):
     from modules import lowvram, devices, sd_hijack
     timer = Timer()
 
-    if shared.sd_model:
-
-        # shared.sd_model.cond_stage_model.to(devices.cpu)
-        # shared.sd_model.first_stage_model.to(devices.cpu)
-        shared.sd_model.to(devices.cpu)
-        sd_hijack.model_hijack.undo_hijack(shared.sd_model)
-        shared.sd_model = None
+    if model_data.sd_model:
+        model_data.sd_model.to(devices.cpu)
+        sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
+        model_data.sd_model = None
         sd_model = None
         gc.collect()
         devices.torch_gc()
diff --git a/modules/shared.py b/modules/shared.py
index 6a2b3c2b..151bab9e 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -16,6 +16,7 @@ import modules.styles
 import modules.devices as devices
 from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
 from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir
+from ldm.models.diffusion.ddpm import LatentDiffusion
 
 demo = None
 
@@ -600,13 +601,37 @@ class Options:
         return value
 
 
-
 opts = Options()
 if os.path.exists(config_filename):
     opts.load(config_filename)
 
+
+class Shared(sys.modules[__name__].__class__):
+    """
+    this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than
+    at program startup.
+    """
+
+    sd_model_val = None
+
+    @property
+    def sd_model(self):
+        import modules.sd_models
+
+        return modules.sd_models.model_data.get_sd_model()
+
+    @sd_model.setter
+    def sd_model(self, value):
+        import modules.sd_models
+
+        modules.sd_models.model_data.set_sd_model(value)
+
+
+sd_model: LatentDiffusion = None  # this var is here just for IDE's type checking; it cannot be accessed because the class field above will be accessed instead
+sys.modules[__name__].__class__ = Shared
+
 settings_components = None
-"""assinged from ui.py, a mapping on setting anmes to gradio components repsponsible for those settings"""
+"""assinged from ui.py, a mapping on setting names to gradio components repsponsible for those settings"""
 
 latent_upscale_default_mode = "Latent"
 latent_upscale_modes = {
@@ -620,8 +645,6 @@ latent_upscale_modes = {
 
 sd_upscalers = []
 
-sd_model = None
-
 clip_model = None
 
 progress_print_out = sys.stdout
diff --git a/modules/ui.py b/modules/ui.py
index 7b45f131..16c46515 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -828,7 +828,7 @@ def create_ui():
                         with FormGroup():
                             with FormRow():
                                 cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
-                                image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
+                                image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=False)
                             denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
 
                     elif category == "seed":
@@ -1693,11 +1693,9 @@ def create_ui():
                 show_progress=info.refresh is not None,
             )
 
-        text_settings.change(
-            fn=lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit"),
-            inputs=[],
-            outputs=[image_cfg_scale],
-        )
+        update_image_cfg_scale_visibility = lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
+        text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
+        demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
 
         button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
         button_set_checkpoint.click(
diff --git a/webui.py b/webui.py
index 357bf4c1..0873a26c 100644
--- a/webui.py
+++ b/webui.py
@@ -6,6 +6,8 @@ import signal
 import re
 import warnings
 import json
+from threading import Thread
+
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.gzip import GZipMiddleware
@@ -191,18 +193,10 @@ def initialize():
     modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
     startup_timer.record("refresh textual inversion templates")
 
-    try:
-        modules.sd_models.load_model()
-    except Exception as e:
-        errors.display(e, "loading stable diffusion model")
-        print("", file=sys.stderr)
-        print("Stable diffusion model failed to load, exiting", file=sys.stderr)
-        exit(1)
-    startup_timer.record("load SD checkpoint")
-
-    shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title
+    # load model in parallel to other startup stuff
+    Thread(target=lambda: shared.sd_model).start()
 
-    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
+    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
     shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
     shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
     shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
-- 
GitLab