From 02d7abf5141431b9a3a8a189bb3136c71abd5e79 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 4 Jan 2023 12:35:07 +0300
Subject: [PATCH] helpful error message when trying to load 2.0 without config

Failing to load model weights from settings no longer breaks generation for
the currently loaded model.

---
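Notes (this area between the "---" above and the diffstat is ignored by
git am):

The new helpers in modules/errors.py are meant to be called as sketched
below. This is a minimal illustration based only on the signatures added in
this patch; the task name and the failing function are made up.

    from modules import errors

    def load_weights():
        raise RuntimeError("boom")

    # run() catches the exception and prints "<task>: <ExceptionName>" plus
    # the traceback to stderr instead of propagating it:
    errors.run(load_weights, "loading weights")

    # display() can also be called directly from an except block; when the
    # message matches the known torch.Size([640, 1024]) vs [640, 768] shape
    # mismatch, it also prints the boxed hint about Stable Diffusion 2.0
    # config files:
    try:
        load_weights()
    except Exception as e:
        errors.display(e, "loading weights")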
 modules/errors.py    | 25 +++++++++++++++++++++++--
 modules/sd_models.py | 26 ++++++++++++++++++--------
 modules/shared.py    |  9 +++++++--
 webui.py             | 12 ++++++++++--
 4 files changed, 58 insertions(+), 14 deletions(-)
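
For reference, the reload_model_weights() change restores the previously
loaded weights when a new checkpoint fails to load, so the active model stays
usable. A sketch of the resulting behavior; bad_checkpoint_info stands in for
a CheckpointInfo whose weights fail to load:

    from modules import sd_models, shared

    try:
        sd_models.reload_model_weights(info=bad_checkpoint_info)
    except Exception:
        # re-raised after the previous weights were reloaded in the except
        # branch; shared.sd_model can still generate with the old checkpoint
        pass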

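Similarly, the Options change in modules/shared.py rolls a setting back when
its onchange callback raises. A sketch, assuming the enclosing method shown
in the shared.py hunk is Options.set and that "sd_model_checkpoint" is wired
to reload_model_weights() as in webui.py; the checkpoint name is made up:

    from modules import shared

    # if reloading weights for the new value fails, the error is shown via
    # errors.display(), the old value is restored, and set() returns False
    changed = shared.opts.set("sd_model_checkpoint", "broken-v2.ckpt")
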
diff --git a/modules/errors.py b/modules/errors.py
index 372dc51a..a668c014 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -2,9 +2,30 @@ import sys
 import traceback
 
 
+def print_error_explanation(message):
+    lines = message.strip().split("\n")
+    max_len = max([len(x) for x in lines])
+
+    print('=' * max_len, file=sys.stderr)
+    for line in lines:
+        print(line, file=sys.stderr)
+    print('=' * max_len, file=sys.stderr)
+
+
+def display(e: Exception, task):
+    print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
+    print(traceback.format_exc(), file=sys.stderr)
+
+    message = str(e)
+    if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
+        print_error_explanation("""
+The most likely cause of this is that you are trying to load a Stable Diffusion 2.0 model without specifying its config file.
+See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
+        """)
+
+
 def run(code, task):
     try:
         code()
     except Exception as e:
-        print(f"{task}: {type(e).__name__}", file=sys.stderr)
-        print(traceback.format_exc(), file=sys.stderr)
+        display(e, task)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index b98b05fc..6846b74a 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -278,6 +278,7 @@ def enable_midas_autodownload():
 
     midas.api.load_model = load_model_wrapper
 
+
 def load_model(checkpoint_info=None):
     from modules import lowvram, sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
@@ -312,6 +313,7 @@ def load_model(checkpoint_info=None):
         sd_config.model.params.unet_config.params.use_fp16 = False
 
     sd_model = instantiate_from_config(sd_config.model)
+
     load_model_weights(sd_model, checkpoint_info)
 
     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
@@ -336,10 +338,12 @@ def load_model(checkpoint_info=None):
 def reload_model_weights(sd_model=None, info=None):
     from modules import lowvram, devices, sd_hijack
     checkpoint_info = info or select_checkpoint()
- 
+
     if not sd_model:
         sd_model = shared.sd_model
 
+    current_checkpoint_info = sd_model.sd_checkpoint_info
+
     if sd_model.sd_model_checkpoint == checkpoint_info.filename:
         return
 
@@ -356,13 +360,19 @@ def reload_model_weights(sd_model=None, info=None):
 
     sd_hijack.model_hijack.undo_hijack(sd_model)
 
-    load_model_weights(sd_model, checkpoint_info)
-
-    sd_hijack.model_hijack.hijack(sd_model)
-    script_callbacks.model_loaded_callback(sd_model)
-
-    if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
-        sd_model.to(devices.device)
+    try:
+        load_model_weights(sd_model, checkpoint_info)
+    except Exception:
+        print("Failed to load checkpoint, restoring previous weights")
+        load_model_weights(sd_model, current_checkpoint_info)
+        raise
+    finally:
+        sd_hijack.model_hijack.hijack(sd_model)
+        script_callbacks.model_loaded_callback(sd_model)
+
+        if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
+            sd_model.to(devices.device)
 
     print("Weights loaded.")
+
     return sd_model
diff --git a/modules/shared.py b/modules/shared.py
index 23657a93..7588c47b 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -14,7 +14,7 @@ import modules.interrogate
 import modules.memmon
 import modules.styles
 import modules.devices as devices
-from modules import localization, sd_vae, extensions, script_loading
+from modules import localization, sd_vae, extensions, script_loading, errors
 from modules.paths import models_path, script_path, sd_path
 
 
@@ -494,7 +494,12 @@ class Options:
             return False
 
         if self.data_labels[key].onchange is not None:
-            self.data_labels[key].onchange()
+            try:
+                self.data_labels[key].onchange()
+            except Exception as e:
+                errors.display(e, f"changing setting {key} to {value}")
+                setattr(self, key, oldval)
+                return False
 
         return True
 
diff --git a/webui.py b/webui.py
index c7d55a97..13375e71 100644
--- a/webui.py
+++ b/webui.py
@@ -9,7 +9,7 @@ from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.gzip import GZipMiddleware
 
-from modules import import_hook
+from modules import import_hook, errors
 from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
 from modules.paths import script_path
 
@@ -61,7 +61,15 @@ def initialize():
     modelloader.load_upscalers()
 
     modules.sd_vae.refresh_vae_list()
-    modules.sd_models.load_model()
+
+    try:
+        modules.sd_models.load_model()
+    except Exception as e:
+        errors.display(e, "loading stable diffusion model")
+        print("", file=sys.stderr)
+        print("Stable diffusion model failed to load, exiting", file=sys.stderr)
+        exit(1)
+
     shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
     shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
     shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
-- 
GitLab