Skip to content
Snippets Groups Projects
server.py 15.7 KB
Newer Older
  • Learn how to ignore specific revisions
  • import os
    import sys
    import asyncio
    import nodes
    
    import folder_paths
    
    import execution
    
    import uuid
    import json
    
    import glob
    
    from PIL import Image
    from io import BytesIO
    
    
    try:
        import aiohttp
        from aiohttp import web
    except ImportError:
        print("Module 'aiohttp' not installed. Please install it via:")
        print("pip install aiohttp")
        print("or")
        print("pip install -r requirements.txt")
        sys.exit()
    
    
    comfyanonymous's avatar
    comfyanonymous committed
    import mimetypes
    
    EllangoK's avatar
    EllangoK committed
    from comfy.cli_args import args
    
    
    @web.middleware
    async def cache_control(request: web.Request, handler):
        """Middleware: ask browsers to revalidate script and stylesheet responses.

        Any response whose request path ends in ``.js`` or ``.css`` gets
        ``Cache-Control: no-cache``, unless the handler already set a
        ``Cache-Control`` header (``setdefault`` never overwrites).
        """
        resp: web.Response = await handler(request)
        is_static_asset = request.path.endswith(('.js', '.css'))
        if is_static_asset:
            resp.headers.setdefault('Cache-Control', 'no-cache')
        return resp
    
    
    EllangoK's avatar
    EllangoK committed
    def create_cors_middleware(allowed_origin: str):
        """Build an aiohttp middleware that adds CORS headers to every response.

        Args:
            allowed_origin: value written to ``Access-Control-Allow-Origin``.

        Returns:
            A middleware that short-circuits ``OPTIONS`` (pre-flight) requests
            with an empty 200 response. Other requests go to the real handler.
            The fixed CORS header set is then stamped onto the response.
        """
        cors_headers = (
            ('Access-Control-Allow-Origin', allowed_origin),
            ('Access-Control-Allow-Methods', 'POST, GET, DELETE, PUT, OPTIONS'),
            ('Access-Control-Allow-Headers', 'Content-Type, Authorization'),
            ('Access-Control-Allow-Credentials', 'true'),
        )

        @web.middleware
        async def cors_middleware(request: web.Request, handler):
            # Pre-flight requests are answered here and never reach a route.
            if request.method != "OPTIONS":
                resp = await handler(request)
            else:
                resp = web.Response()

            for header_name, header_value in cors_headers:
                resp.headers[header_name] = header_value
            return resp

        return cors_middleware
    
    EllangoK's avatar
    EllangoK committed
    
    
    class PromptServer():
        """HTTP + WebSocket server exposing the prompt queue and the web UI.

        ``__init__`` builds the aiohttp application and registers every route
        on ``self.routes``. :meth:`add_routes` attaches them (plus the static
        web UI) to ``self.app``, and :meth:`start` begins listening.

        ``prompt_queue`` is ``None`` after construction. The queue/history
        routes and :meth:`get_queue_info` assume the caller assigns it before
        they are used.
        """

        def __init__(self, loop):
            """Build the aiohttp app and register all routes.

            Args:
                loop: asyncio event loop that :meth:`send_sync` schedules
                    messages onto (via ``call_soon_threadsafe``).
            """
            # The most recently constructed server is reachable as PromptServer.instance.
            PromptServer.instance = self

            mimetypes.init()
            # Force an explicit JavaScript MIME type (with charset) for .js files.
            mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'

            self.prompt_queue = None
            self.loop = loop
            # Messages queued by send_sync() and drained by publish_loop().
            self.messages = asyncio.Queue()
            # Counter used to order prompts that do not carry an explicit "number".
            self.number = 0

            middlewares = [cache_control]
            if args.enable_cors_header:
                middlewares.append(create_cors_middleware(args.enable_cors_header))

            # client_max_size: 20 MiB (20 * 1024 * 1024 bytes) request-body limit.
            self.app = web.Application(client_max_size=20971520, middlewares=middlewares)

            # client id (sid) -> open WebSocketResponse
            self.sockets = dict()
            self.web_root = os.path.join(os.path.dirname(os.path.realpath(__file__)), "web")

            routes = web.RouteTableDef()
            self.routes = routes

            # Read on websocket (re)connect below. Presumably updated elsewhere
            # while a prompt executes (TODO: confirm against the executor).
            self.last_node_id = None
            self.client_id = None

            @routes.get('/ws')
            async def websocket_handler(request):
                ws = web.WebSocketResponse()
                await ws.prepare(request)

                sid = request.rel_url.query.get('clientId', '')
                if sid:
                    # Reusing existing session, remove old
                    self.sockets.pop(sid, None)
                else:
                    # No clientId supplied: assign a fresh id to this connection.
                    sid = uuid.uuid4().hex

                self.sockets[sid] = ws

                try:
                    # Send initial state to the new client
                    await self.send("status", { "status": self.get_queue_info(), 'sid': sid }, sid)

                    # On reconnect if we are the currently executing client send the current node
                    if self.client_id == sid and self.last_node_id is not None:
                        await self.send("executing", { "node": self.last_node_id }, sid)

                    async for msg in ws:
                        if msg.type == aiohttp.WSMsgType.ERROR:
                            print('ws connection closed with exception %s' % ws.exception())
                finally:
                    # Only unregister if a newer connection has not already taken over
                    # this sid; otherwise a stale handler would drop the live socket.
                    if self.sockets.get(sid) is ws:
                        del self.sockets[sid]

                return ws

            @routes.get("/")
            async def get_root(request):
                return web.FileResponse(os.path.join(self.web_root, "index.html"))

            @routes.get("/embeddings")
            async def get_embeddings(request):
                embeddings = folder_paths.get_filename_list("embeddings")
                # Names are reported lower-cased and without their file extension.
                return web.json_response([os.path.splitext(name)[0].lower() for name in embeddings])

            @routes.get("/extensions")
            async def get_extensions(request):
                files = glob.glob(os.path.join(self.web_root, 'extensions/**/*.js'), recursive=True)
                # URL paths relative to web_root, always with forward slashes.
                return web.json_response(["/" + os.path.relpath(f, self.web_root).replace("\\", "/") for f in files])

            def get_dir_by_type(dir_type):
                """Map an upload ``type`` form value to its base directory.

                ``None`` and ``"input"`` map to the input directory, and
                ``"temp"``/``"output"`` map to their folders. Any other value
                returns ``None``.
                """
                if dir_type is None or dir_type == "input":
                    return folder_paths.get_input_directory()
                if dir_type == "temp":
                    return folder_paths.get_temp_directory()
                if dir_type == "output":
                    return folder_paths.get_output_directory()
                return None

            def is_within(base, path):
                """Return True if absolute *path* is *base* itself or lies beneath it."""
                try:
                    return os.path.commonpath((base, path)) == base
                except ValueError:
                    # Mixed absolute/relative paths, or different drives on Windows.
                    return False

            def image_upload(post, image_save_function=None):
                """Save the uploaded ``image`` form field and report where it went.

                Form fields read from ``post``:
                    ``image``: the uploaded file (required).
                    ``type``: "input"/"temp"/"output", defaulting to input.
                    ``subfolder``: relative to the type directory.
                    ``overwrite``: "true"/"1" replaces an existing file; otherwise
                        a " (n)" suffix is added.

                ``image_save_function(image, post, filepath)``, when given,
                replaces the default raw byte copy.

                Returns a JSON response with ``name``/``subfolder``/``type``.
                Returns HTTP 400 for a missing file, an unknown type, or a path
                that would escape the target directory.
                """
                image = post.get("image")
                overwrite = post.get("overwrite")
                image_upload_type = post.get("type")

                upload_dir = get_dir_by_type(image_upload_type)
                if upload_dir is None:
                    return web.Response(status=400)
                upload_dir = os.path.abspath(upload_dir)

                if not (image and image.file):
                    return web.Response(status=400)

                filename = image.filename
                if not filename:
                    return web.Response(status=400)

                subfolder = post.get("subfolder", "")
                full_output_folder = os.path.abspath(os.path.join(upload_dir, os.path.normpath(subfolder)))
                if not is_within(upload_dir, full_output_folder):
                    return web.Response(status=400)

                # The client-controlled filename must be a plain name directly inside
                # full_output_folder: no separators, "..", or absolute paths.
                filepath = os.path.abspath(os.path.join(full_output_folder, filename))
                if os.path.dirname(filepath) != full_output_folder:
                    return web.Response(status=400)

                os.makedirs(full_output_folder, exist_ok=True)

                if overwrite not in ("true", "1"):
                    # Pick the first free "name (n).ext" instead of clobbering a file.
                    base, ext = os.path.splitext(filename)
                    i = 1
                    while os.path.exists(filepath):
                        filename = f"{base} ({i}){ext}"
                        filepath = os.path.join(full_output_folder, filename)
                        i += 1

                if image_save_function is not None:
                    image_save_function(image, post, filepath)
                else:
                    with open(filepath, "wb") as f:
                        f.write(image.file.read())

                return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type})

            @routes.post("/upload/image")
            async def upload_image(request):
                post = await request.post()
                return image_upload(post)

            @routes.post("/upload/mask")
            async def upload_mask(request):
                post = await request.post()

                # A mask upload must also carry the image its alpha is applied to.
                original_image = post.get("original_image")
                if not (original_image and getattr(original_image, "file", None)):
                    return web.Response(status=400)

                def image_save_function(image, post, filepath):
                    original_pil = Image.open(post.get("original_image").file).convert('RGBA')
                    mask_pil = Image.open(image.file).convert('RGBA')

                    # Copy the mask's alpha channel onto the original image.
                    new_alpha = mask_pil.getchannel('A')
                    original_pil.putalpha(new_alpha)

                    original_pil.save(filepath, compress_level=4)

                return image_upload(post, image_save_function)

            @routes.get("/view")
            async def view_image(request):
                """Serve a file from an output/input/temp directory.

                Query parameters:
                    ``filename``: required. ``folder_paths.annotated_filepath``
                        may resolve it into a (name, directory) pair.
                    ``type``: directory type, default "output".
                    ``subfolder``: optional, relative to that directory.
                    ``channel``: "rgb" or "a" re-encode the image as PNG; any
                        other value serves the file unchanged.
                """
                query = request.rel_url.query
                if "filename" not in query:
                    return web.Response(status=404)

                filename, output_dir = folder_paths.annotated_filepath(query["filename"])

                # validation for security: prevent accessing arbitrary path
                if not filename or filename[0] == '/' or '..' in filename:
                    return web.Response(status=400)

                if output_dir is None:
                    dir_type = query.get("type", "output")
                    output_dir = folder_paths.get_directory_by_type(dir_type)
                    if output_dir is None:
                        return web.Response(status=400)

                if "subfolder" in query:
                    full_output_dir = os.path.join(output_dir, query["subfolder"])
                    if not is_within(os.path.abspath(output_dir), os.path.abspath(full_output_dir)):
                        return web.Response(status=403)
                    output_dir = full_output_dir

                filename = os.path.basename(filename)
                file = os.path.join(output_dir, filename)
                if not os.path.isfile(file):
                    return web.Response(status=404)

                channel = query.get('channel', 'rgba')
                headers = {"Content-Disposition": f"filename=\"{filename}\""}

                if channel == 'rgb':
                    with Image.open(file) as img:
                        if img.mode == "RGBA":
                            r, g, b, a = img.split()
                            new_img = Image.merge('RGB', (r, g, b))
                        else:
                            new_img = img.convert("RGB")
                        buffer = BytesIO()
                        new_img.save(buffer, format='PNG')
                    return web.Response(body=buffer.getvalue(), content_type='image/png', headers=headers)

                if channel == 'a':
                    with Image.open(file) as img:
                        if img.mode == "RGBA":
                            _, _, _, a = img.split()
                        else:
                            # No alpha channel: treat the image as fully opaque.
                            a = Image.new('L', img.size, 255)

                        # alpha img
                        alpha_img = Image.new('RGBA', img.size)
                        alpha_img.putalpha(a)
                        alpha_buffer = BytesIO()
                        alpha_img.save(alpha_buffer, format='PNG')
                    return web.Response(body=alpha_buffer.getvalue(), content_type='image/png', headers=headers)

                return web.FileResponse(file, headers=headers)

            @routes.get("/prompt")
            async def get_prompt(request):
                return web.json_response(self.get_queue_info())

            @routes.get("/object_info")
            async def get_object_info(request):
                out = {}
                for x in nodes.NODE_CLASS_MAPPINGS:
                    obj_class = nodes.NODE_CLASS_MAPPINGS[x]
                    info = {}
                    info['input'] = obj_class.INPUT_TYPES()
                    info['output'] = obj_class.RETURN_TYPES
                    # Fall back to the type names when a node declares no RETURN_NAMES.
                    info['output_name'] = getattr(obj_class, 'RETURN_NAMES', info['output'])
                    info['name'] = x
                    info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[x] if x in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else x
                    info['description'] = ''
                    info['category'] = getattr(obj_class, 'CATEGORY', 'sd')
                    out[x] = info
                return web.json_response(out)

            @routes.get("/history")
            async def get_history(request):
                return web.json_response(self.prompt_queue.get_history())

            @routes.get("/queue")
            async def get_queue(request):
                current_queue = self.prompt_queue.get_current_queue()
                queue_info = {
                    'queue_running': current_queue[0],
                    'queue_pending': current_queue[1],
                }
                return web.json_response(queue_info)

            @routes.post("/prompt")
            async def post_prompt(request):
                """Validate a prompt and enqueue it.

                Returns ``{"prompt_id": ...}`` on success. Returns HTTP 400 with
                the validation message for an invalid prompt, or with
                "no prompt" when the body has no "prompt" key.
                """
                print("got prompt")
                json_data = await request.json()

                if "number" in json_data:
                    number = float(json_data['number'])
                else:
                    number = self.number
                    if "front" in json_data and json_data['front']:
                        # "front" requests are ordered ahead via a negated number.
                        number = -number
                    self.number += 1

                if "prompt" not in json_data:
                    # Previously answered 200 with an empty body, hiding the mistake.
                    return web.Response(body="no prompt", status=400)

                prompt = json_data["prompt"]
                valid = execution.validate_prompt(prompt)

                extra_data = {}
                if "extra_data" in json_data:
                    extra_data = json_data["extra_data"]
                if "client_id" in json_data:
                    extra_data["client_id"] = json_data["client_id"]

                if valid[0]:
                    prompt_id = str(uuid.uuid4())
                    self.prompt_queue.put((number, prompt_id, prompt, extra_data, valid[2]))
                    return web.json_response({"prompt_id": prompt_id})

                print("invalid prompt:", valid[1])
                return web.Response(body=valid[1], status=400)

            @routes.post("/queue")
            async def post_queue(request):
                json_data = await request.json()
                if "clear" in json_data:
                    if json_data["clear"]:
                        self.prompt_queue.wipe_queue()
                if "delete" in json_data:
                    to_delete = json_data['delete']
                    for id_to_delete in to_delete:
                        # Queue items are put as (number, prompt_id, ...) with prompt_id a
                        # uuid string, so compare as strings (int() fails on a uuid).
                        # delete_queue_item presumably applies this predicate per item.
                        # The default argument binds the current id rather than
                        # late-binding the loop variable.
                        delete_func = lambda a, target=str(id_to_delete): a[1] == target
                        self.prompt_queue.delete_queue_item(delete_func)

                return web.Response(status=200)

            @routes.post("/interrupt")
            async def post_interrupt(request):
                nodes.interrupt_processing()
                return web.Response(status=200)

            @routes.post("/history")
            async def post_history(request):
                json_data = await request.json()
                if "clear" in json_data:
                    if json_data["clear"]:
                        self.prompt_queue.wipe_history()

                if "delete" in json_data:
                    to_delete = json_data['delete']
                    for id_to_delete in to_delete:
                        self.prompt_queue.delete_history_item(id_to_delete)

                return web.Response(status=200)

        def add_routes(self):
            """Attach the registered API routes, then the static web UI at ``/``."""
            self.app.add_routes(self.routes)

            self.app.add_routes([
                web.static('/', self.web_root),
            ])

        def get_queue_info(self):
            """Return ``{"exec_info": {"queue_remaining": <tasks remaining>}}``."""
            return {'exec_info': {'queue_remaining': self.prompt_queue.get_tasks_remaining()}}

        async def send(self, event, data, sid=None):
            """Send ``{"type": event, "data": data}`` as JSON text over websockets.

            Args:
                event: message type string.
                data: JSON-serialisable payload.
                sid: target client id. ``None`` broadcasts to every socket, and
                    an unknown sid is silently ignored.
            """
            message = json.dumps({"type": event, "data": data})

            if sid is None:
                # Snapshot: websocket handlers may add/remove entries while this
                # coroutine is suspended inside send_str().
                targets = list(self.sockets.values())
            elif sid in self.sockets:
                targets = [self.sockets[sid]]
            else:
                targets = []

            for ws in targets:
                try:
                    await ws.send_str(message)
                except ConnectionResetError as err:
                    # A client that went away mid-send must neither abort delivery to
                    # the others nor propagate out of publish_loop().
                    print("send error:", err)

        def send_sync(self, event, data, sid=None):
            """Queue a message from another thread; :meth:`publish_loop` sends it."""
            self.loop.call_soon_threadsafe(
                self.messages.put_nowait, (event, data, sid))

        def queue_updated(self):
            """Broadcast the current queue status to all clients."""
            self.send_sync("status", { "status": self.get_queue_info() })

        async def publish_loop(self):
            """Forever drain ``self.messages`` and deliver each via :meth:`send`."""
            while True:
                msg = await self.messages.get()
                await self.send(*msg)

        async def start(self, address, port, verbose=True, call_on_start=None):
            """Start serving ``self.app`` on ``address:port``.

            Args:
                address: host to bind. An empty string is reported (printed and
                    passed to the callback) as ``0.0.0.0``.
                port: TCP port.
                verbose: print the GUI URL once the site is up.
                call_on_start: optional callback, invoked as
                    ``call_on_start(address, port)``.
            """
            runner = web.AppRunner(self.app)
            await runner.setup()
            site = web.TCPSite(runner, address, port)
            await site.start()

            if address == '':
                address = '0.0.0.0'

            if verbose:
                print("Starting server\n")
                print("To see the GUI go to: http://{}:{}".format(address, port))

            if call_on_start is not None:
                call_on_start(address, port)