更新
更旧
import os
import sys
import asyncio
import nodes
from PIL import Image
from io import BytesIO
try:
import aiohttp
from aiohttp import web
except ImportError:
print("Module 'aiohttp' not installed. Please install it via:")
print("pip install aiohttp")
print("or")
print("pip install -r requirements.txt")
sys.exit()
@web.middleware
async def cache_control(request: web.Request, handler):
    """Middleware that marks JavaScript and CSS responses as ``no-cache``.

    The wrapped handler runs first. The ``Cache-Control`` header is only
    filled in when the request path ends in ``.js`` or ``.css`` and the
    handler did not already set that header itself.
    """
    response: web.Response = await handler(request)
    is_script_or_stylesheet = request.path.endswith(('.js', '.css'))
    if is_script_or_stylesheet:
        # setdefault keeps whatever Cache-Control value the handler chose.
        response.headers.setdefault('Cache-Control', 'no-cache')
    return response
def create_cors_middleware(allowed_origin: str):
    """Build an aiohttp middleware that stamps CORS headers on every response.

    Args:
        allowed_origin: Value written to ``Access-Control-Allow-Origin``.

    Returns:
        The middleware coroutine function, ready to be passed to
        ``web.Application(middlewares=...)``.
    """
    cors_headers = (
        ('Access-Control-Allow-Origin', allowed_origin),
        ('Access-Control-Allow-Methods', 'POST, GET, DELETE, PUT, OPTIONS'),
        ('Access-Control-Allow-Headers', 'Content-Type, Authorization'),
        ('Access-Control-Allow-Credentials', 'true'),
    )

    @web.middleware
    async def cors_middleware(request: web.Request, handler):
        is_preflight = request.method == "OPTIONS"
        # Pre-flight requests are answered directly with an empty response;
        # everything else goes through the real handler first.
        response = web.Response() if is_preflight else await handler(request)

        for header_name, header_value in cors_headers:
            response.headers[header_name] = header_value
        return response

    return cors_middleware
mimetypes.init();
mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'
self.prompt_queue = None
self.loop = loop
self.messages = asyncio.Queue()
self.number = 0
if args.enable_cors_header:
middlewares.append(create_cors_middleware(args.enable_cors_header))
self.app = web.Application(client_max_size=20971520, middlewares=middlewares)
self.sockets = dict()
self.web_root = os.path.join(os.path.dirname(
self.last_node_id = None
self.client_id = None
@routes.get('/ws')
async def websocket_handler(request):
    """Accept a websocket client and keep it registered while it is connected.

    The session id is taken from the ``clientId`` query parameter; when it
    is absent or empty a fresh ``uuid4`` hex id is generated. The socket is
    stored in ``self.sockets`` under that id so that ``self.send`` can reach
    it, the initial queue status is pushed, and incoming frames are then
    consumed until the connection ends. Only error frames are acted on
    (they are logged); other incoming messages are ignored.

    Returns:
        The prepared ``web.WebSocketResponse``.
    """
    ws = web.WebSocketResponse()
    await ws.prepare(request)

    sid = request.rel_url.query.get('clientId', '')
    if sid:
        # Reusing existing session, remove old
        self.sockets.pop(sid, None)
    else:
        sid = uuid.uuid4().hex

    # Without this registration self.send(..., sid) finds no socket and the
    # status messages below would be silently dropped.
    self.sockets[sid] = ws

    try:
        # Send initial state to the new client
        await self.send("status", {"status": self.get_queue_info(), 'sid': sid}, sid)
        # On reconnect if we are the currently executing client send the current node
        if self.client_id == sid and self.last_node_id is not None:
            await self.send("executing", {"node": self.last_node_id}, sid)

        async for msg in ws:
            if msg.type == aiohttp.WSMsgType.ERROR:
                print('ws connection closed with exception %s' % ws.exception())
    finally:
        # Unregister on the way out, but only if the entry is still this
        # connection: a reconnect with the same id may already have
        # replaced it. Exceptions propagate normally (no return in finally).
        if self.sockets.get(sid) is ws:
            del self.sockets[sid]
    return ws
@routes.get("/")
async def get_root(request):
    """Serve ``index.html`` from ``self.web_root`` for the site root."""
    index_path = os.path.join(self.web_root, "index.html")
    return web.FileResponse(index_path)
@routes.get("/embeddings")
async def get_embeddings(request):
    """Return the known embedding names as a JSON list.

    Each entry of ``folder_paths.get_filename_list("embeddings")`` is
    reported with its file extension removed and lower-cased.

    The handler is a coroutine taking ``request`` like its sibling handlers;
    aiohttp deprecates plain-function handlers, and the former parameter
    name ``self`` shadowed the enclosing server instance.
    """
    embeddings = folder_paths.get_filename_list("embeddings")
    names = [os.path.splitext(entry)[0].lower() for entry in embeddings]
    return web.json_response(names)
@routes.get("/extensions")
async def get_extensions(request):
    """List every ``.js`` file under ``<web_root>/extensions`` as a URL path.

    Paths are made relative to ``self.web_root``, prefixed with ``/`` and
    use forward slashes regardless of the platform separator.
    """
    pattern = os.path.join(self.web_root, 'extensions/**/*.js')
    urls = []
    for script_path in glob.glob(pattern, recursive=True):
        relative = os.path.relpath(script_path, self.web_root)
        urls.append("/" + relative.replace("\\", "/"))
    return web.json_response(urls)
def get_dir_by_type(dir_type):
    """Resolve a directory type name to its directory.

    Args:
        dir_type: ``"input"``, ``"temp"`` or ``"output"``; ``None`` is
            treated as ``"input"``.

    Returns:
        A ``(directory, dir_type)`` tuple, where ``dir_type`` is the
        effective type name after the ``None`` default has been applied.

    Raises:
        ValueError: If ``dir_type`` is not one of the supported names.
            Previously this case fell through to the return statement with
            ``type_dir`` unassigned and failed with ``UnboundLocalError``.
    """
    if dir_type is None:
        dir_type = "input"

    if dir_type == "input":
        type_dir = folder_paths.get_input_directory()
    elif dir_type == "temp":
        type_dir = folder_paths.get_temp_directory()
    elif dir_type == "output":
        type_dir = folder_paths.get_output_directory()
    else:
        raise ValueError(f"unknown directory type: {dir_type!r}")

    return type_dir, dir_type
overwrite = post.get("overwrite")
upload_dir, image_upload_type = get_dir_by_type(image_upload_type)
if image and image.file:
filename = image.filename
if not filename:
return web.Response(status=400)
subfolder = post.get("subfolder", "")
full_output_folder = os.path.join(upload_dir, os.path.normpath(subfolder))
if os.path.commonpath((upload_dir, os.path.abspath(full_output_folder))) != upload_dir:
return web.Response(status=400)
if not os.path.exists(full_output_folder):
os.makedirs(full_output_folder)
filepath = os.path.join(full_output_folder, filename)
if overwrite is not None and (overwrite == "true" or overwrite == "1"):
pass
else:
i = 1
while os.path.exists(filepath):
filename = f"{split[0]} ({i}){split[1]}"
filepath = os.path.join(full_output_folder, filename)
i += 1
if image_save_function is not None:
image_save_function(image, post, filepath)
else:
with open(filepath, "wb") as f:
f.write(image.file.read())
return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type})
@routes.post("/upload/image")
async def upload_image(request):
    """Read the posted form and hand it to ``image_upload``."""
    form = await request.post()
    return image_upload(form)
@routes.post("/upload/mask")
async def upload_mask(request):
    """Store an uploaded mask by merging its alpha channel into the original image.

    The posted form is passed to ``image_upload`` together with a custom
    save callback, so the handler returns whatever response
    ``image_upload`` produces. Previously the callback was defined but
    never used and the handler returned ``None``.
    """
    post = await request.post()

    def image_save_function(image, post, filepath):
        """Write ``original_image`` with the uploaded image's alpha to ``filepath``."""
        original_pil = Image.open(post.get("original_image").file).convert('RGBA')
        mask_pil = Image.open(image.file).convert('RGBA')

        # alpha copy
        new_alpha = mask_pil.getchannel('A')
        original_pil.putalpha(new_alpha)
        original_pil.save(filepath, compress_level=4)

    # Passed by keyword: image_upload's body refers to the callback by this name.
    return image_upload(post, image_save_function=image_save_function)
filename = request.rel_url.query["filename"]
filename,output_dir = folder_paths.annotated_filepath(filename)
# validation for security: prevent accessing arbitrary path
if filename[0] == '/' or '..' in filename:
return web.Response(status=400)
if output_dir is None:
type = request.rel_url.query.get("type", "output")
output_dir = folder_paths.get_directory_by_type(type)
if output_dir is None:
if "subfolder" in request.rel_url.query:
full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"])
if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
return web.Response(status=403)
output_dir = full_output_dir
filename = os.path.basename(filename)
file = os.path.join(output_dir, filename)
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
if 'channel' not in request.rel_url.query:
channel = 'rgba'
else:
channel = request.rel_url.query["channel"]
if channel == 'rgb':
with Image.open(file) as img:
if img.mode == "RGBA":
r, g, b, a = img.split()
new_img = Image.merge('RGB', (r, g, b))
else:
new_img = img.convert("RGB")
buffer = BytesIO()
new_img.save(buffer, format='PNG')
buffer.seek(0)
return web.Response(body=buffer.read(), content_type='image/png',
headers={"Content-Disposition": f"filename=\"{filename}\""})
elif channel == 'a':
with Image.open(file) as img:
if img.mode == "RGBA":
_, _, _, a = img.split()
else:
a = Image.new('L', img.size, 255)
# alpha img
alpha_img = Image.new('RGBA', img.size)
alpha_img.putalpha(a)
alpha_buffer = BytesIO()
alpha_img.save(alpha_buffer, format='PNG')
alpha_buffer.seek(0)
return web.Response(body=alpha_buffer.read(), content_type='image/png',
headers={"Content-Disposition": f"filename=\"{filename}\""})
else:
return web.FileResponse(file, headers={"Content-Disposition": f"filename=\"{filename}\""})
@routes.get("/prompt")
async def get_prompt(request):
    """Report the current queue info (see ``get_queue_info``) as JSON."""
    queue_status = self.get_queue_info()
    return web.json_response(queue_status)
@routes.get("/object_info")
async def get_object_info(request):
    """Describe every class in ``nodes.NODE_CLASS_MAPPINGS`` as JSON.

    For each node the response carries its inputs (``INPUT_TYPES()``),
    outputs (``RETURN_TYPES``), per-output list flags, output names,
    name, display name, an empty description and its category. Optional
    class attributes fall back to defaults: all-``False`` list flags, the
    output types as output names, the node name as display name and
    ``'sd'`` as category.
    """
    out = {}
    for node_name, node_cls in nodes.NODE_CLASS_MAPPINGS.items():
        inputs = node_cls.INPUT_TYPES()
        outputs = node_cls.RETURN_TYPES

        if hasattr(node_cls, 'OUTPUT_IS_LIST'):
            list_flags = node_cls.OUTPUT_IS_LIST
        else:
            list_flags = [False] * len(outputs)

        out[node_name] = {
            'input': inputs,
            'output': outputs,
            'output_is_list': list_flags,
            'output_name': getattr(node_cls, 'RETURN_NAMES', outputs),
            'name': node_name,
            'display_name': nodes.NODE_DISPLAY_NAME_MAPPINGS.get(node_name, node_name),
            'description': '',
            'category': getattr(node_cls, 'CATEGORY', 'sd'),
        }
    return web.json_response(out)
@routes.get("/history")
async def get_history(request):
    """Return ``self.prompt_queue.get_history()`` serialized as JSON."""
    history = self.prompt_queue.get_history()
    return web.json_response(history)
@routes.get("/queue")
async def get_queue(request):
    """Return the running and pending queue entries as JSON.

    The first two elements of ``prompt_queue.get_current_queue()`` are
    reported under ``queue_running`` and ``queue_pending`` respectively.
    """
    current_queue = self.prompt_queue.get_current_queue()
    return web.json_response({
        'queue_running': current_queue[0],
        'queue_pending': current_queue[1],
    })
@routes.post("/prompt")
async def post_prompt(request):
    """Validate a posted prompt and put it on the prompt queue.

    The JSON body may carry:
      * ``number``: explicit queue number (converted to ``float``). When
        absent, ``self.number`` is used (negated if ``front`` is truthy)
        and then incremented.
      * ``prompt``: the prompt to validate with ``execution.validate_prompt``.
      * ``extra_data`` / ``client_id``: forwarded with the queued item.

    Returns:
        JSON ``{"prompt_id": ...}`` when the prompt validates, a 400
        response whose body is the validation message when it does not,
        and an empty 200 response when no ``prompt`` key was sent.
    """
    print("got prompt")
    json_data = await request.json()

    if "number" in json_data:
        number = float(json_data['number'])
    else:
        number = self.number
        if "front" in json_data and json_data['front']:
            number = -number
        self.number += 1

    if "prompt" not in json_data:
        return web.Response(body="", status=200)

    prompt = json_data["prompt"]
    valid = execution.validate_prompt(prompt)

    extra_data = json_data["extra_data"] if "extra_data" in json_data else {}
    if "client_id" in json_data:
        extra_data["client_id"] = json_data["client_id"]

    if not valid[0]:
        print("invalid prompt:", valid[1])
        return web.Response(body=valid[1], status=400)

    prompt_id = str(uuid.uuid4())
    self.prompt_queue.put((number, prompt_id, prompt, extra_data, valid[2]))
    return web.json_response({"prompt_id": prompt_id})
# POST /queue: modifies the prompt queue according to the JSON request body.
@routes.post("/queue")
async def post_queue(request):
json_data = await request.json()
# A truthy "clear" entry calls prompt_queue.wipe_queue().
if "clear" in json_data:
if json_data["clear"]:
self.prompt_queue.wipe_queue()
# "delete" is iterated item by item; presumably each item is the id of a
# queued entry to remove -- verify against the client.
if "delete" in json_data:
to_delete = json_data['delete']
for id_to_delete in to_delete:
# NOTE(review): the loop body and the handler's return statement are not
# present here. Restore the per-id delete call and a response before use:
# an empty `for` body is a syntax error and aiohttp handlers must return
# a response.
@routes.post("/interrupt")
async def post_interrupt(request):
    """Call ``nodes.interrupt_processing()`` and acknowledge with HTTP 200."""
    nodes.interrupt_processing()
    acknowledgement = web.Response(status=200)
    return acknowledgement
@routes.post("/history")
async def post_history(request):
    """Clear the history and/or delete individual history items.

    The JSON body may carry:
      * ``clear``: when truthy, ``prompt_queue.wipe_history()`` is called.
      * ``delete``: an iterable of ids, each passed to
        ``prompt_queue.delete_history_item``.

    Returns:
        An empty HTTP 200 response. Previously the handler fell off the
        end and returned ``None``, which aiohttp does not accept from a
        request handler.
    """
    json_data = await request.json()

    if "clear" in json_data and json_data["clear"]:
        self.prompt_queue.wipe_history()

    if "delete" in json_data:
        for id_to_delete in json_data['delete']:
            self.prompt_queue.delete_history_item(id_to_delete)

    return web.Response(status=200)
def add_routes(self):
    """Register the collected routes and the static front-end on the app.

    ``self.routes`` is added first; a static route then serves files from
    ``self.web_root`` at ``/`` (following symlinks). The static
    registration had lost its opening ``self.app.add_routes([`` call,
    leaving an unbalanced ``])``; it is restored here.
    """
    self.app.add_routes(self.routes)
    self.app.add_routes([
        web.static('/', self.web_root, follow_symlinks=True),
    ])
def get_queue_info(self):
    """Return the queue summary sent to clients.

    Returns:
        ``{'exec_info': {'queue_remaining': <n>}}`` where ``<n>`` is
        ``self.prompt_queue.get_tasks_remaining()``.
    """
    remaining = self.prompt_queue.get_tasks_remaining()
    return {'exec_info': {'queue_remaining': remaining}}
async def send(self, event, data, sid=None):
    """Send one event to a single client or to every connected client.

    The message is serialized as JSON ``{"type": event, "data": data}``.

    Args:
        event: Event name placed under ``"type"``.
        data: JSON-serializable payload placed under ``"data"``.
        sid: Target session id. ``None`` broadcasts to all sockets in
            ``self.sockets``; an id that is not registered is ignored.
    """
    message = json.dumps({"type": event, "data": data})

    if sid is None:
        # Iterate over a snapshot: each await yields to the event loop, and
        # a client connecting or disconnecting meanwhile would otherwise
        # raise "dictionary changed size during iteration".
        for ws in list(self.sockets.values()):
            await ws.send_str(message)
    elif sid in self.sockets:
        await self.sockets[sid].send_str(message)
def send_sync(self, event, data, sid=None):
    """Queue an event for delivery without awaiting.

    The ``(event, data, sid)`` tuple is put on ``self.messages`` by
    scheduling ``put_nowait`` on ``self.loop`` through
    ``call_soon_threadsafe``.
    """
    schedule = self.loop.call_soon_threadsafe
    schedule(self.messages.put_nowait, (event, data, sid))
def queue_updated(self):
    """Publish the current queue info as a ``"status"`` event via ``send_sync``."""
    status_payload = {"status": self.get_queue_info()}
    self.send_sync("status", status_payload)
async def publish_loop(self):
    """Forward queued messages to ``send`` forever.

    Each item taken from ``self.messages`` is unpacked as the positional
    arguments of ``self.send``. The loop never returns on its own; it runs
    until the task is cancelled or ``send`` raises.
    """
    while True:
        pending = await self.messages.get()
        await self.send(*pending)
async def start(self, address, port, verbose=True, call_on_start=None):
    """Start serving ``self.app`` over TCP.

    Args:
        address: Host/interface passed to ``web.TCPSite``.
        port: TCP port passed to ``web.TCPSite``.
        verbose: When true, print a start-up banner with the GUI URL.
        call_on_start: Optional callable invoked as
            ``call_on_start(address, port)`` once the site has started.
    """
    runner = web.AppRunner(self.app)
    await runner.setup()
    await web.TCPSite(runner, address, port).start()

    if verbose:
        gui_hint = "To see the GUI go to: http://{}:{}".format(address, port)
        print("Starting server\n")
        print(gui_hint)

    if call_on_start is None:
        return
    call_on_start(address, port)