diff --git a/modules/shared.py b/modules/shared.py
index a9e28b9c405efaa0517f3774342697c8391d45c1..e83cbcdff1bf804b3ac1530a1d2ff43beb178c6a 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -86,6 +86,7 @@ parser.add_argument("--nowebui", action='store_true', help="use api=True to laun
 parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
 parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
 parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
+parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origins", default=None)
 
 cmd_opts = parser.parse_args()
 restricted_opts = {
@@ -147,9 +148,9 @@ class State:
         self.interrupted = True
 
     def nextjob(self):
-        if opts.show_progress_every_n_steps == -1: 
+        if opts.show_progress_every_n_steps == -1:
             self.do_set_current_image()
-            
+
         self.job_no += 1
         self.sampling_step = 0
         self.current_image_sampling_step = 0
@@ -198,7 +199,7 @@ class State:
             return
         if self.current_latent is None:
             return
-            
+
         if opts.show_progress_grid:
             self.current_image = sd_samplers.samples_to_image_grid(self.current_latent)
         else:
diff --git a/webui.py b/webui.py
index 81df09dd29306762138503ab282a55c6bcbbc7b5..3788af0ba194cef188b9036bf42dce6f5e65373c 100644
--- a/webui.py
+++ b/webui.py
@@ -5,6 +5,7 @@ import importlib
 import signal
 import threading
 from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.gzip import GZipMiddleware
 
 from modules.paths import script_path
@@ -93,6 +94,11 @@ def initialize():
     signal.signal(signal.SIGINT, sigint_handler)
 
 
+def setup_cors(app):
+    if cmd_opts.cors_allow_origins:
+        app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'])
+
+
 def create_api(app):
     from modules.api.api import Api
     api = Api(app, queue_lock)
@@ -114,6 +120,7 @@ def api_only():
     initialize()
 
     app = FastAPI()
+    setup_cors(app)
     app.add_middleware(GZipMiddleware, minimum_size=1000)
     api = create_api(app)
 
@@ -147,6 +154,8 @@ def webui():
         # runnnig its code. We disable this here. Suggested by RyotaK.
         app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']
 
+        setup_cors(app)
+
         app.add_middleware(GZipMiddleware, minimum_size=1000)
 
         if launch_api: