hanamizuki / ComfyUI · Commit e1474150

Support fp8_scaled diffusion models that don't use fp8 matrix mult.

Authored 1 week ago by comfyanonymous
Parent: e62d72e8
No related branches, tags, or merge requests found.

3 changed files, 8 additions and 2 deletions:
- comfy/model_base.py (+1, -1)
- comfy/model_detection.py (+4, -0)
- comfy/ops.py (+3, -1)
comfy/model_base.py  +1 −1

@@ -108,7 +108,7 @@ class BaseModel(torch.nn.Module):
         if not unet_config.get("disable_unet_model_creation", False):
             if model_config.custom_operations is None:
-                fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None)
+                fp8 = model_config.optimizations.get("fp8", False)
                 operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
             else:
                 operations = model_config.custom_operations
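The change above makes the fp8 optimization default to off: previously the mere presence of a scaled_fp8 dtype turned fp8 matrix mult on, while now the detection step must opt in explicitly via optimizations["fp8"]. A minimal sketch of the before/after defaulting, where ModelConfig, fp8_flag_old, and fp8_flag_new are hypothetical stand-ins for the real config object and the two versions of this line:

```python
class ModelConfig:
    """Hypothetical stub mirroring the two fields the diff touches."""
    def __init__(self, optimizations, scaled_fp8):
        self.optimizations = optimizations  # dict of optimization flags
        self.scaled_fp8 = scaled_fp8        # dtype marker or None

def fp8_flag_old(cfg):
    # Old behavior: any scaled_fp8 dtype implied fp8 optimizations.
    return cfg.optimizations.get("fp8", cfg.scaled_fp8 is not None)

def fp8_flag_new(cfg):
    # New behavior: default False; detection must set the flag itself.
    return cfg.optimizations.get("fp8", False)

cfg = ModelConfig(optimizations={}, scaled_fp8="float8_e4m3fn")
assert fp8_flag_old(cfg) is True   # old: scaled fp8 forced the flag on
assert fp8_flag_new(cfg) is False  # new: off unless detection opts in
```

This is what lets a checkpoint carry scaled-fp8 weights without implying fp8 matrix multiplication.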
comfy/model_detection.py  +4 −0

@@ -471,6 +471,10 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
         model_config.scaled_fp8 = scaled_fp8_weight.dtype
         if model_config.scaled_fp8 == torch.float32:
             model_config.scaled_fp8 = torch.float8_e4m3fn
+        if scaled_fp8_weight.nelement() == 2:
+            model_config.optimizations["fp8"] = False
+        else:
+            model_config.optimizations["fp8"] = True

     return model_config
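The detection rule added above can be mimicked without torch: a scaled_fp8 marker tensor with exactly two elements signals that the checkpoint's fp8 weights should not go through fp8 matrix multiplication. A minimal sketch, where MarkerTensor and detect_fp8_optimization are hypothetical stand-ins for the torch tensor and the branch in model_config_from_unet:

```python
class MarkerTensor:
    """Hypothetical stand-in exposing only the element count the diff checks."""
    def __init__(self, n_elements):
        self._n = n_elements

    def nelement(self):
        return self._n

def detect_fp8_optimization(scaled_fp8_weight):
    # Mirror of the commit's rule: a 2-element scaled_fp8 marker means
    # "scaled fp8 weights, but no fp8 matrix mult".
    return scaled_fp8_weight.nelement() != 2

assert detect_fp8_optimization(MarkerTensor(1)) is True   # fp8 mat mult stays on
assert detect_fp8_optimization(MarkerTensor(2)) is False  # new marker: opt out
```

Encoding the opt-out in the marker tensor's size keeps the checkpoint format backward compatible: older single-element markers still enable the optimization.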
comfy/ops.py  +3 −1

@@ -17,6 +17,7 @@
 """
 import torch
+import logging
 import comfy.model_management
 from comfy.cli_args import args, PerformanceFeature
 import comfy.float

@@ -308,6 +309,7 @@ class fp8_ops(manual_cast):
         return torch.nn.functional.linear(input, weight, bias)

 def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
+    logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
     class scaled_fp8_op(manual_cast):
         class Linear(manual_cast.Linear):
             def __init__(self, *args, **kwargs):

@@ -358,7 +360,7 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
 def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
     fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
     if scaled_fp8 is not None:
-        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8)
+        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=True, override_dtype=scaled_fp8)

     if (
         fp8_compute and
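With the pick_operations change above, fp8 matrix multiplication for scaled-fp8 models now requires both hardware fp8 support and the checkpoint's opt-in flag, instead of hardware support alone. A minimal truth-table sketch, where pick_fp8_matrix_mult is a hypothetical reduction of that one argument expression:

```python
def pick_fp8_matrix_mult(fp8_compute, fp8_optimizations):
    # Before the commit: fp8 matrix mult followed hardware support alone.
    # After: the checkpoint's fp8 optimization flag also gates it.
    return fp8_compute and fp8_optimizations

# fp8-capable hardware AND checkpoint opted in -> fp8 matrix mult
assert pick_fp8_matrix_mult(True, True) is True
# fp8-capable hardware but checkpoint opted out -> scaled weights only
assert pick_fp8_matrix_mult(True, False) is False
# no hardware fp8 compute -> never fp8 matrix mult
assert pick_fp8_matrix_mult(False, True) is False
```

This is the gating that lets the 2-element scaled_fp8 marker from model_detection.py disable fp8 matrix mult while still using the scaled fp8 weight path.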