Skip to content
代码片段 群组 项目

比较版本

更改显示为版本正在合并到目标版本。了解更多关于比较版本的信息。

来源

选择目标项目
No results found

目标

选择目标项目
  • hanamizuki/comfyui
1 个结果
显示更改
源代码提交(28)
显示 259 个添加100 个删除
...@@ -19,5 +19,6 @@ ...@@ -19,5 +19,6 @@
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata /app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata /utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
# Extra nodes # Node developers
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink /comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered
/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered
...@@ -11,33 +11,44 @@ from dataclasses import dataclass ...@@ -11,33 +11,44 @@ from dataclasses import dataclass
from functools import cached_property from functools import cached_property
from pathlib import Path from pathlib import Path
from typing import TypedDict, Optional from typing import TypedDict, Optional
from importlib.metadata import version
import requests import requests
from typing_extensions import NotRequired from typing_extensions import NotRequired
from comfy.cli_args import DEFAULT_VERSION_STRING from comfy.cli_args import DEFAULT_VERSION_STRING
import app.logger
# The path to the requirements.txt file
req_path = Path(__file__).parents[1] / "requirements.txt"
def frontend_install_warning_message(): def frontend_install_warning_message():
req_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'requirements.txt')) """The warning message to display when the frontend version is not up to date."""
extra = "" extra = ""
if sys.flags.no_user_site: if sys.flags.no_user_site:
extra = "-s " extra = "-s "
return f"Please install the updated requirements.txt file by running:\n{sys.executable} {extra}-m pip install -r {req_path}\n\nThis error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem" return f"Please install the updated requirements.txt file by running:\n{sys.executable} {extra}-m pip install -r {req_path}\n\nThis error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem"
try:
import comfyui_frontend_package
except ImportError:
# TODO: Remove the check after roll out of 0.3.16
logging.error(f"\n\n********** ERROR ***********\n\ncomfyui-frontend-package is not installed. {frontend_install_warning_message()}\n********** ERROR **********\n")
exit(-1)
def check_frontend_version():
"""Check if the frontend version is up to date."""
def parse_version(version: str) -> tuple[int, int, int]:
return tuple(map(int, version.split(".")))
try:
frontend_version_str = version("comfyui-frontend-package")
frontend_version = parse_version(frontend_version_str)
with open(req_path, "r", encoding="utf-8") as f:
required_frontend = parse_version(f.readline().split("=")[-1])
if frontend_version < required_frontend:
app.logger.log_startup_warning("________________________________________________________________________\nWARNING WARNING WARNING WARNING WARNING\n\nInstalled frontend version {} is lower than the recommended version {}.\n\n{}\n________________________________________________________________________".format('.'.join(map(str, frontend_version)), '.'.join(map(str, required_frontend)), frontend_install_warning_message()))
else:
logging.info("ComfyUI frontend version: {}".format(frontend_version_str))
except Exception as e:
logging.error(f"Failed to check frontend version: {e}")
try:
frontend_version = tuple(map(int, comfyui_frontend_package.__version__.split(".")))
except:
frontend_version = (0,)
pass
REQUEST_TIMEOUT = 10 # seconds REQUEST_TIMEOUT = 10 # seconds
...@@ -133,9 +144,17 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None: ...@@ -133,9 +144,17 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
class FrontendManager: class FrontendManager:
DEFAULT_FRONTEND_PATH = str(importlib.resources.files(comfyui_frontend_package) / "static")
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions") CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
@classmethod
def default_frontend_path(cls) -> str:
try:
import comfyui_frontend_package
return str(importlib.resources.files(comfyui_frontend_package) / "static")
except ImportError:
logging.error(f"\n\n********** ERROR ***********\n\ncomfyui-frontend-package is not installed. {frontend_install_warning_message()}\n********** ERROR **********\n")
sys.exit(-1)
@classmethod @classmethod
def parse_version_string(cls, value: str) -> tuple[str, str, str]: def parse_version_string(cls, value: str) -> tuple[str, str, str]:
""" """
...@@ -172,7 +191,8 @@ class FrontendManager: ...@@ -172,7 +191,8 @@ class FrontendManager:
main error source might be request timeout or invalid URL. main error source might be request timeout or invalid URL.
""" """
if version_string == DEFAULT_VERSION_STRING: if version_string == DEFAULT_VERSION_STRING:
return cls.DEFAULT_FRONTEND_PATH check_frontend_version()
return cls.default_frontend_path()
repo_owner, repo_name, version = cls.parse_version_string(version_string) repo_owner, repo_name, version = cls.parse_version_string(version_string)
...@@ -225,4 +245,5 @@ class FrontendManager: ...@@ -225,4 +245,5 @@ class FrontendManager:
except Exception as e: except Exception as e:
logging.error("Failed to initialize frontend: %s", e) logging.error("Failed to initialize frontend: %s", e)
logging.info("Falling back to the default frontend.") logging.info("Falling back to the default frontend.")
return cls.DEFAULT_FRONTEND_PATH check_frontend_version()
return cls.default_frontend_path()
...@@ -82,3 +82,17 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool ...@@ -82,3 +82,17 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool
logger.addHandler(stdout_handler) logger.addHandler(stdout_handler)
logger.addHandler(stream_handler) logger.addHandler(stream_handler)
STARTUP_WARNINGS = []
def log_startup_warning(msg):
logging.warning(msg)
STARTUP_WARNINGS.append(msg)
def print_startup_warnings():
for s in STARTUP_WARNINGS:
logging.warning(s)
STARTUP_WARNINGS.clear()
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Literal, TypedDict from typing import Literal, TypedDict
from typing_extensions import NotRequired
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
...@@ -26,6 +27,7 @@ class IO(StrEnum): ...@@ -26,6 +27,7 @@ class IO(StrEnum):
BOOLEAN = "BOOLEAN" BOOLEAN = "BOOLEAN"
INT = "INT" INT = "INT"
FLOAT = "FLOAT" FLOAT = "FLOAT"
COMBO = "COMBO"
CONDITIONING = "CONDITIONING" CONDITIONING = "CONDITIONING"
SAMPLER = "SAMPLER" SAMPLER = "SAMPLER"
SIGMAS = "SIGMAS" SIGMAS = "SIGMAS"
...@@ -66,6 +68,7 @@ class IO(StrEnum): ...@@ -66,6 +68,7 @@ class IO(StrEnum):
b = frozenset(value.split(",")) b = frozenset(value.split(","))
return not (b.issubset(a) or a.issubset(b)) return not (b.issubset(a) or a.issubset(b))
class RemoteInputOptions(TypedDict): class RemoteInputOptions(TypedDict):
route: str route: str
"""The route to the remote source.""" """The route to the remote source."""
...@@ -80,6 +83,14 @@ class RemoteInputOptions(TypedDict): ...@@ -80,6 +83,14 @@ class RemoteInputOptions(TypedDict):
refresh: int refresh: int
"""The TTL of the remote input's value in milliseconds. Specifies the interval at which the remote input's value is refreshed.""" """The TTL of the remote input's value in milliseconds. Specifies the interval at which the remote input's value is refreshed."""
class MultiSelectOptions(TypedDict):
placeholder: NotRequired[str]
"""The placeholder text to display in the multi-select widget when no items are selected."""
chip: NotRequired[bool]
"""Specifies whether to use chips instead of comma separated values for the multi-select widget."""
class InputTypeOptions(TypedDict): class InputTypeOptions(TypedDict):
"""Provides type hinting for the return type of the INPUT_TYPES node function. """Provides type hinting for the return type of the INPUT_TYPES node function.
...@@ -133,9 +144,22 @@ class InputTypeOptions(TypedDict): ...@@ -133,9 +144,22 @@ class InputTypeOptions(TypedDict):
"""Specifies which folder to get preview images from if the input has the ``image_upload`` flag. """Specifies which folder to get preview images from if the input has the ``image_upload`` flag.
""" """
remote: RemoteInputOptions remote: RemoteInputOptions
"""Specifies the configuration for a remote input.""" """Specifies the configuration for a remote input.
Available after ComfyUI frontend v1.9.7
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2422"""
control_after_generate: bool control_after_generate: bool
"""Specifies whether a control widget should be added to the input, adding options to automatically change the value after each prompt is queued. Currently only used for INT and COMBO types.""" """Specifies whether a control widget should be added to the input, adding options to automatically change the value after each prompt is queued. Currently only used for INT and COMBO types."""
options: NotRequired[list[str | int | float]]
"""COMBO type only. Specifies the selectable options for the combo widget.
Prefer:
["COMBO", {"options": ["Option 1", "Option 2", "Option 3"]}]
Over:
[["Option 1", "Option 2", "Option 3"]]
"""
multi_select: NotRequired[MultiSelectOptions]
"""COMBO type only. Specifies the configuration for a multi-select widget.
Available after ComfyUI frontend v1.13.4
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
class HiddenInputTypeDict(TypedDict): class HiddenInputTypeDict(TypedDict):
......
...@@ -688,10 +688,10 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N ...@@ -688,10 +688,10 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
if len(sigmas) <= 1: if len(sigmas) <= 1:
return x return x
extra_args = {} if extra_args is None else extra_args
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
seed = extra_args.get("seed", None) seed = extra_args.get("seed", None)
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]]) s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp() sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg() t_fn = lambda sigma: sigma.log().neg()
...@@ -762,10 +762,10 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl ...@@ -762,10 +762,10 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
if solver_type not in {'heun', 'midpoint'}: if solver_type not in {'heun', 'midpoint'}:
raise ValueError('solver_type must be \'heun\' or \'midpoint\'') raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None) seed = extra_args.get("seed", None)
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]]) s_in = x.new_ones([x.shape[0]])
old_denoised = None old_denoised = None
...@@ -808,10 +808,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl ...@@ -808,10 +808,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
if len(sigmas) <= 1: if len(sigmas) <= 1:
return x return x
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None) seed = extra_args.get("seed", None)
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]]) s_in = x.new_ones([x.shape[0]])
denoised_1, denoised_2 = None, None denoised_1, denoised_2 = None, None
...@@ -858,7 +858,7 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl ...@@ -858,7 +858,7 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
if len(sigmas) <= 1: if len(sigmas) <= 1:
return x return x
extra_args = {} if extra_args is None else extra_args
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler) return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
...@@ -867,7 +867,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di ...@@ -867,7 +867,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
if len(sigmas) <= 1: if len(sigmas) <= 1:
return x return x
extra_args = {} if extra_args is None else extra_args
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
...@@ -876,7 +876,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di ...@@ -876,7 +876,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
if len(sigmas) <= 1: if len(sigmas) <= 1:
return x return x
extra_args = {} if extra_args is None else extra_args
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r) return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
...@@ -1366,3 +1366,59 @@ def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, ...@@ -1366,3 +1366,59 @@ def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None,
x = x + d_bar * dt x = x + d_bar * dt
old_d = d old_d = d
return x return x
@torch.no_grad()
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
"""
Extended Reverse-Time SDE solver (VE ER-SDE-Solver-3). Arxiv: https://arxiv.org/abs/2309.06169.
Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
"""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
def default_noise_scaler(sigma):
return sigma * ((sigma ** 0.3).exp() + 10.0)
noise_scaler = default_noise_scaler if noise_scaler is None else noise_scaler
num_integration_points = 200.0
point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device)
old_denoised = None
old_denoised_d = None
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
stage_used = min(max_stage, i + 1)
if sigmas[i + 1] == 0:
x = denoised
elif stage_used == 1:
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
x = r * x + (1 - r) * denoised
else:
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
x = r * x + (1 - r) * denoised
dt = sigmas[i + 1] - sigmas[i]
sigma_step_size = -dt / num_integration_points
sigma_pos = sigmas[i + 1] + point_indice * sigma_step_size
scaled_pos = noise_scaler(sigma_pos)
# Stage 2
s = torch.sum(1 / scaled_pos) * sigma_step_size
denoised_d = (denoised - old_denoised) / (sigmas[i] - sigmas[i - 1])
x = x + (dt + s * noise_scaler(sigmas[i + 1])) * denoised_d
if stage_used >= 3:
# Stage 3
s_u = torch.sum((sigma_pos - sigmas[i]) / scaled_pos) * sigma_step_size
denoised_u = (denoised_d - old_denoised_d) / ((sigmas[i] - sigmas[i - 2]) / 2)
x = x + ((dt ** 2) / 2 + s_u * noise_scaler(sigmas[i + 1])) * denoised_u
old_denoised_d = denoised_d
if s_noise != 0 and sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt()
old_denoised = denoised
return x
...@@ -159,20 +159,20 @@ class DoubleStreamBlock(nn.Module): ...@@ -159,20 +159,20 @@ class DoubleStreamBlock(nn.Module):
) )
self.flipped_img_txt = flipped_img_txt self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None): def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
img_mod1, img_mod2 = self.img_mod(vec) img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec) txt_mod1, txt_mod2 = self.txt_mod(vec)
# prepare image for attention # prepare image for attention
img_modulated = self.img_norm1(img) img_modulated = self.img_norm1(img)
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims) img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
img_qkv = self.img_attn.qkv(img_modulated) img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention # prepare txt for attention
txt_modulated = self.txt_norm1(txt) txt_modulated = self.txt_norm1(txt)
txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims) txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
txt_qkv = self.txt_attn.qkv(txt_modulated) txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
...@@ -195,12 +195,12 @@ class DoubleStreamBlock(nn.Module): ...@@ -195,12 +195,12 @@ class DoubleStreamBlock(nn.Module):
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:] txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
# calculate the img bloks # calculate the img bloks
img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims) img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims)), img_mod2.gate, None, modulation_dims) img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
# calculate the txt bloks # calculate the txt bloks
txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims) txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims)), txt_mod2.gate, None, modulation_dims) txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)
if txt.dtype == torch.float16: if txt.dtype == torch.float16:
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504) txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
......
...@@ -244,9 +244,11 @@ class HunyuanVideo(nn.Module): ...@@ -244,9 +244,11 @@ class HunyuanVideo(nn.Module):
vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1) vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2]) frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)] modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
modulation_dims_txt = [(0, None, 1)]
else: else:
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
modulation_dims = None modulation_dims = None
modulation_dims_txt = None
if self.params.guidance_embed: if self.params.guidance_embed:
if guidance is not None: if guidance is not None:
...@@ -273,14 +275,14 @@ class HunyuanVideo(nn.Module): ...@@ -273,14 +275,14 @@ class HunyuanVideo(nn.Module):
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"]) out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"])
return out return out
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap}) out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap})
txt = out["txt"] txt = out["txt"]
img = out["img"] img = out["img"]
else: else:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims) img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt)
if control is not None: # Controlnet if control is not None: # Controlnet
control_i = control.get("input") control_i = control.get("input")
...@@ -295,10 +297,10 @@ class HunyuanVideo(nn.Module): ...@@ -295,10 +297,10 @@ class HunyuanVideo(nn.Module):
if ("single_block", i) in blocks_replace: if ("single_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"]) out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"])
return out return out
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap}) out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap})
img = out["img"] img = out["img"]
else: else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims) img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
......
...@@ -973,11 +973,11 @@ class WAN21(BaseModel): ...@@ -973,11 +973,11 @@ class WAN21(BaseModel):
self.image_to_video = image_to_video self.image_to_video = image_to_video
def concat_cond(self, **kwargs): def concat_cond(self, **kwargs):
if not self.image_to_video: noise = kwargs.get("noise", None)
if self.diffusion_model.patch_embedding.weight.shape[1] == noise.shape[1]:
return None return None
image = kwargs.get("concat_latent_image", None) image = kwargs.get("concat_latent_image", None)
noise = kwargs.get("noise", None)
device = kwargs["device"] device = kwargs["device"]
if image is None: if image is None:
...@@ -987,6 +987,9 @@ class WAN21(BaseModel): ...@@ -987,6 +987,9 @@ class WAN21(BaseModel):
image = self.process_latent_in(image) image = self.process_latent_in(image)
image = utils.resize_to_batch_size(image, noise.shape[0]) image = utils.resize_to_batch_size(image, noise.shape[0])
if not self.image_to_video:
return image
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if mask is None: if mask is None:
mask = torch.zeros_like(noise)[:, :4] mask = torch.zeros_like(noise)[:, :4]
......
...@@ -210,12 +210,21 @@ def get_total_memory(dev=None, torch_total_too=False): ...@@ -210,12 +210,21 @@ def get_total_memory(dev=None, torch_total_too=False):
else: else:
return mem_total return mem_total
def mac_version():
try:
return tuple(int(n) for n in platform.mac_ver()[0].split("."))
except:
return None
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024) total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024) total_ram = psutil.virtual_memory().total / (1024 * 1024)
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram)) logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
try: try:
logging.info("pytorch version: {}".format(torch_version)) logging.info("pytorch version: {}".format(torch_version))
mac_ver = mac_version()
if mac_ver is not None:
logging.info("Mac Version {}".format(mac_ver))
except: except:
pass pass
...@@ -997,12 +1006,6 @@ def pytorch_attention_flash_attention(): ...@@ -997,12 +1006,6 @@ def pytorch_attention_flash_attention():
return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
return False return False
def mac_version():
try:
return tuple(int(n) for n in platform.mac_ver()[0].split("."))
except:
return None
def force_upcast_attention_dtype(): def force_upcast_attention_dtype():
upcast = args.force_upcast_attention upcast = args.force_upcast_attention
......
...@@ -1201,7 +1201,6 @@ class ModelPatcher: ...@@ -1201,7 +1201,6 @@ class ModelPatcher:
def patch_hooks(self, hooks: comfy.hooks.HookGroup): def patch_hooks(self, hooks: comfy.hooks.HookGroup):
with self.use_ejected(): with self.use_ejected():
self.unpatch_hooks()
if hooks is not None: if hooks is not None:
model_sd_keys = list(self.model_state_dict().keys()) model_sd_keys = list(self.model_state_dict().keys())
memory_counter = None memory_counter = None
...@@ -1212,12 +1211,16 @@ class ModelPatcher: ...@@ -1212,12 +1211,16 @@ class ModelPatcher:
# if have cached weights for hooks, use it # if have cached weights for hooks, use it
cached_weights = self.cached_hook_patches.get(hooks, None) cached_weights = self.cached_hook_patches.get(hooks, None)
if cached_weights is not None: if cached_weights is not None:
model_sd_keys_set = set(model_sd_keys)
for key in cached_weights: for key in cached_weights:
if key not in model_sd_keys: if key not in model_sd_keys:
logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}") logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
continue continue
self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter) self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
model_sd_keys_set.remove(key)
self.unpatch_hooks(model_sd_keys_set)
else: else:
self.unpatch_hooks()
relevant_patches = self.get_combined_hook_patches(hooks=hooks) relevant_patches = self.get_combined_hook_patches(hooks=hooks)
original_weights = None original_weights = None
if len(relevant_patches) > 0: if len(relevant_patches) > 0:
...@@ -1228,6 +1231,8 @@ class ModelPatcher: ...@@ -1228,6 +1231,8 @@ class ModelPatcher:
continue continue
self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights, self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
memory_counter=memory_counter) memory_counter=memory_counter)
else:
self.unpatch_hooks()
self.current_hooks = hooks self.current_hooks = hooks
def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter): def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
...@@ -1284,17 +1289,23 @@ class ModelPatcher: ...@@ -1284,17 +1289,23 @@ class ModelPatcher:
del out_weight del out_weight
del weight del weight
def unpatch_hooks(self) -> None: def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
with self.use_ejected(): with self.use_ejected():
if len(self.hook_backup) == 0: if len(self.hook_backup) == 0:
self.current_hooks = None self.current_hooks = None
return return
keys = list(self.hook_backup.keys()) keys = list(self.hook_backup.keys())
for k in keys: if whitelist_keys_set:
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1])) for k in keys:
if k in whitelist_keys_set:
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
self.hook_backup.pop(k)
else:
for k in keys:
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
self.hook_backup.clear() self.hook_backup.clear()
self.current_hooks = None self.current_hooks = None
def clean_hooks(self): def clean_hooks(self):
self.unpatch_hooks() self.unpatch_hooks()
......
...@@ -903,7 +903,7 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c ...@@ -903,7 +903,7 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp", "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
"gradient_estimation"] "gradient_estimation", "er_sde"]
class KSAMPLER(Sampler): class KSAMPLER(Sampler):
def __init__(self, sampler_function, extra_options={}, inpaint_options={}): def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
......
...@@ -19,8 +19,6 @@ class Load3D(): ...@@ -19,8 +19,6 @@ class Load3D():
"image": ("LOAD_3D", {}), "image": ("LOAD_3D", {}),
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
"material": (["original", "normal", "wireframe", "depth"],),
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
}} }}
RETURN_TYPES = ("IMAGE", "MASK", "STRING") RETURN_TYPES = ("IMAGE", "MASK", "STRING")
...@@ -55,8 +53,6 @@ class Load3DAnimation(): ...@@ -55,8 +53,6 @@ class Load3DAnimation():
"image": ("LOAD_3D_ANIMATION", {}), "image": ("LOAD_3D_ANIMATION", {}),
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
"material": (["original", "normal", "wireframe", "depth"],),
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
}} }}
RETURN_TYPES = ("IMAGE", "MASK", "STRING") RETURN_TYPES = ("IMAGE", "MASK", "STRING")
...@@ -82,8 +78,6 @@ class Preview3D(): ...@@ -82,8 +78,6 @@ class Preview3D():
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { return {"required": {
"model_file": ("STRING", {"default": "", "multiline": False}), "model_file": ("STRING", {"default": "", "multiline": False}),
"material": (["original", "normal", "wireframe", "depth"],),
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
}} }}
OUTPUT_NODE = True OUTPUT_NODE = True
...@@ -102,8 +96,6 @@ class Preview3DAnimation(): ...@@ -102,8 +96,6 @@ class Preview3DAnimation():
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { return {"required": {
"model_file": ("STRING", {"default": "", "multiline": False}), "model_file": ("STRING", {"default": "", "multiline": False}),
"material": (["original", "normal", "wireframe", "depth"],),
"up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],),
}} }}
OUTPUT_NODE = True OUTPUT_NODE = True
......
...@@ -99,12 +99,13 @@ class LTXVAddGuide: ...@@ -99,12 +99,13 @@ class LTXVAddGuide:
"negative": ("CONDITIONING", ), "negative": ("CONDITIONING", ),
"vae": ("VAE",), "vae": ("VAE",),
"latent": ("LATENT",), "latent": ("LATENT",),
"image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames." \ "image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames."
"If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames."}), "If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames."}),
"frame_idx": ("INT", {"default": 0, "min": -9999, "max": 9999, "frame_idx": ("INT", {"default": 0, "min": -9999, "max": 9999,
"tooltip": "Frame index to start the conditioning at. Must be divisible by 8. " \ "tooltip": "Frame index to start the conditioning at. For single-frame images or "
"If a frame is not divisible by 8, it will be rounded down to the nearest multiple of 8. " \ "videos with 1-8 frames, any frame_idx value is acceptable. For videos with 9+ "
"Negative values are counted from the end of the video."}), "frames, frame_idx must be divisible by 8, otherwise it will be rounded down to "
"the nearest multiple of 8. Negative values are counted from the end of the video."}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
} }
} }
...@@ -127,12 +128,13 @@ class LTXVAddGuide: ...@@ -127,12 +128,13 @@ class LTXVAddGuide:
t = vae.encode(encode_pixels) t = vae.encode(encode_pixels)
return encode_pixels, t return encode_pixels, t
def get_latent_index(self, cond, latent_length, frame_idx, scale_factors): def get_latent_index(self, cond, latent_length, guide_length, frame_idx, scale_factors):
time_scale_factor, _, _ = scale_factors time_scale_factor, _, _ = scale_factors
_, num_keyframes = get_keyframe_idxs(cond) _, num_keyframes = get_keyframe_idxs(cond)
latent_count = latent_length - num_keyframes latent_count = latent_length - num_keyframes
frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * 8 + 1 + frame_idx, 0) frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
frame_idx = frame_idx // time_scale_factor * time_scale_factor # frame index must be divisible by 8 if guide_length > 1:
frame_idx = frame_idx // time_scale_factor * time_scale_factor # frame index must be divisible by 8
latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor
...@@ -191,7 +193,7 @@ class LTXVAddGuide: ...@@ -191,7 +193,7 @@ class LTXVAddGuide:
_, _, latent_length, latent_height, latent_width = latent_image.shape _, _, latent_length, latent_height, latent_width = latent_image.shape
image, t = self.encode(vae, latent_width, latent_height, image, scale_factors) image, t = self.encode(vae, latent_width, latent_height, image, scale_factors)
frame_idx, latent_idx = self.get_latent_index(positive, latent_length, frame_idx, scale_factors) frame_idx, latent_idx = self.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence." assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
num_prefix_frames = min(self._num_prefix_frames, t.shape[2]) num_prefix_frames = min(self._num_prefix_frames, t.shape[2])
......
# This file is automatically generated by the build process when version is # This file is automatically generated by the build process when version is
# updated in pyproject.toml. # updated in pyproject.toml.
__version__ = "0.3.24" __version__ = "0.3.26"
...@@ -634,6 +634,13 @@ def validate_inputs(prompt, item, validated): ...@@ -634,6 +634,13 @@ def validate_inputs(prompt, item, validated):
continue continue
else: else:
try: try:
# Unwraps values wrapped in __value__ key. This is used to pass
# list widget value to execution, as by default list value is
# reserved to represent the connection between nodes.
if isinstance(val, dict) and "__value__" in val:
val = val["__value__"]
inputs[x] = val
if type_input == "INT": if type_input == "INT":
val = int(val) val = int(val)
inputs[x] = val inputs[x] = val
......
...@@ -139,7 +139,7 @@ from server import BinaryEventTypes ...@@ -139,7 +139,7 @@ from server import BinaryEventTypes
import nodes import nodes
import comfy.model_management import comfy.model_management
import comfyui_version import comfyui_version
import app.frontend_management import app.logger
def cuda_malloc_warning(): def cuda_malloc_warning():
...@@ -293,28 +293,14 @@ def start_comfyui(asyncio_loop=None): ...@@ -293,28 +293,14 @@ def start_comfyui(asyncio_loop=None):
return asyncio_loop, prompt_server, start_all return asyncio_loop, prompt_server, start_all
def warn_frontend_version(frontend_version):
try:
required_frontend = (0,)
req_path = os.path.join(os.path.dirname(__file__), 'requirements.txt')
with open(req_path, 'r') as f:
required_frontend = tuple(map(int, f.readline().split('=')[-1].split('.')))
if frontend_version < required_frontend:
logging.warning("________________________________________________________________________\nWARNING WARNING WARNING WARNING WARNING\n\nInstalled frontend version {} is lower than the recommended version {}.\n\n{}\n________________________________________________________________________".format('.'.join(map(str, frontend_version)), '.'.join(map(str, required_frontend)), app.frontend_management.frontend_install_warning_message()))
except:
pass
if __name__ == "__main__": if __name__ == "__main__":
# Running directly, just start ComfyUI. # Running directly, just start ComfyUI.
logging.info("ComfyUI version: {}".format(comfyui_version.__version__)) logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
frontend_version = app.frontend_management.frontend_version
logging.info("ComfyUI frontend version: {}".format('.'.join(map(str, frontend_version))))
event_loop, _, start_all_func = start_comfyui() event_loop, _, start_all_func = start_comfyui()
try: try:
x = start_all_func() x = start_all_func()
warn_frontend_version(frontend_version) app.logger.print_startup_warnings()
event_loop.run_until_complete(x) event_loop.run_until_complete(x)
except KeyboardInterrupt: except KeyboardInterrupt:
logging.info("\nStopped server") logging.info("\nStopped server")
......
...@@ -489,7 +489,7 @@ class SaveLatent: ...@@ -489,7 +489,7 @@ class SaveLatent:
file = os.path.join(full_output_folder, file) file = os.path.join(full_output_folder, file)
output = {} output = {}
output["latent_tensor"] = samples["samples"] output["latent_tensor"] = samples["samples"].contiguous()
output["latent_format_version_0"] = torch.tensor([]) output["latent_format_version_0"] = torch.tensor([])
comfy.utils.save_torch_file(output, file, metadata=metadata) comfy.utils.save_torch_file(output, file, metadata=metadata)
...@@ -1785,14 +1785,7 @@ class LoadImageOutput(LoadImage): ...@@ -1785,14 +1785,7 @@ class LoadImageOutput(LoadImage):
DESCRIPTION = "Load an image from the output folder. When the refresh button is clicked, the node will update the image list and automatically select the first image, allowing for easy iteration." DESCRIPTION = "Load an image from the output folder. When the refresh button is clicked, the node will update the image list and automatically select the first image, allowing for easy iteration."
EXPERIMENTAL = True EXPERIMENTAL = True
FUNCTION = "load_image_output" FUNCTION = "load_image"
def load_image_output(self, image):
return self.load_image(f"{image} [output]")
@classmethod
def VALIDATE_INPUTS(s, image):
return True
class ImageScale: class ImageScale:
......
[project] [project]
name = "ComfyUI" name = "ComfyUI"
version = "0.3.24" version = "0.3.26"
readme = "README.md" readme = "README.md"
license = { file = "LICENSE" } license = { file = "LICENSE" }
requires-python = ">=3.9" requires-python = ">=3.9"
......
...@@ -70,7 +70,7 @@ def test_get_release_invalid_version(mock_provider): ...@@ -70,7 +70,7 @@ def test_get_release_invalid_version(mock_provider):
def test_init_frontend_default(): def test_init_frontend_default():
version_string = DEFAULT_VERSION_STRING version_string = DEFAULT_VERSION_STRING
frontend_path = FrontendManager.init_frontend(version_string) frontend_path = FrontendManager.init_frontend(version_string)
assert frontend_path == FrontendManager.DEFAULT_FRONTEND_PATH assert frontend_path == FrontendManager.default_frontend_path()
def test_init_frontend_invalid_version(): def test_init_frontend_invalid_version():
...@@ -84,24 +84,29 @@ def test_init_frontend_invalid_provider(): ...@@ -84,24 +84,29 @@ def test_init_frontend_invalid_provider():
with pytest.raises(HTTPError): with pytest.raises(HTTPError):
FrontendManager.init_frontend_unsafe(version_string) FrontendManager.init_frontend_unsafe(version_string)
@pytest.fixture @pytest.fixture
def mock_os_functions(): def mock_os_functions():
with patch('app.frontend_management.os.makedirs') as mock_makedirs, \ with (
patch('app.frontend_management.os.listdir') as mock_listdir, \ patch("app.frontend_management.os.makedirs") as mock_makedirs,
patch('app.frontend_management.os.rmdir') as mock_rmdir: patch("app.frontend_management.os.listdir") as mock_listdir,
patch("app.frontend_management.os.rmdir") as mock_rmdir,
):
mock_listdir.return_value = [] # Simulate empty directory mock_listdir.return_value = [] # Simulate empty directory
yield mock_makedirs, mock_listdir, mock_rmdir yield mock_makedirs, mock_listdir, mock_rmdir
@pytest.fixture @pytest.fixture
def mock_download(): def mock_download():
with patch('app.frontend_management.download_release_asset_zip') as mock: with patch("app.frontend_management.download_release_asset_zip") as mock:
mock.side_effect = Exception("Download failed") # Simulate download failure mock.side_effect = Exception("Download failed") # Simulate download failure
yield mock yield mock
def test_finally_block(mock_os_functions, mock_download, mock_provider): def test_finally_block(mock_os_functions, mock_download, mock_provider):
# Arrange # Arrange
mock_makedirs, mock_listdir, mock_rmdir = mock_os_functions mock_makedirs, mock_listdir, mock_rmdir = mock_os_functions
version_string = 'test-owner/test-repo@1.0.0' version_string = "test-owner/test-repo@1.0.0"
# Act & Assert # Act & Assert
with pytest.raises(Exception): with pytest.raises(Exception):
...@@ -128,3 +133,42 @@ def test_parse_version_string_invalid(): ...@@ -128,3 +133,42 @@ def test_parse_version_string_invalid():
version_string = "invalid" version_string = "invalid"
with pytest.raises(argparse.ArgumentTypeError): with pytest.raises(argparse.ArgumentTypeError):
FrontendManager.parse_version_string(version_string) FrontendManager.parse_version_string(version_string)
def test_init_frontend_default_with_mocks():
# Arrange
version_string = DEFAULT_VERSION_STRING
# Act
with (
patch("app.frontend_management.check_frontend_version") as mock_check,
patch.object(
FrontendManager, "default_frontend_path", return_value="/mocked/path"
),
):
frontend_path = FrontendManager.init_frontend(version_string)
# Assert
assert frontend_path == "/mocked/path"
mock_check.assert_called_once()
def test_init_frontend_fallback_on_error():
# Arrange
version_string = "test-owner/test-repo@1.0.0"
# Act
with (
patch.object(
FrontendManager, "init_frontend_unsafe", side_effect=Exception("Test error")
),
patch("app.frontend_management.check_frontend_version") as mock_check,
patch.object(
FrontendManager, "default_frontend_path", return_value="/default/path"
),
):
frontend_path = FrontendManager.init_frontend(version_string)
# Assert
assert frontend_path == "/default/path"
mock_check.assert_called_once()