Skip to content
GitLab
菜单
为什么选择 GitLab
定价
联系销售
探索
为什么选择 GitLab
定价
联系销售
探索
登录
获取免费试用
主导航
搜索或转到…
项目
C
ComfyUI
管理
动态
成员
代码
仓库
分支
提交
标签
仓库图
比较修订版本
锁定的文件
部署
模型注册表
分析
模型实验
帮助
帮助
支持
GitLab 文档
比较 GitLab 各版本
社区论坛
为极狐GitLab 提交贡献
提交反馈
隐私声明
快捷键
?
新增功能
4
代码片段
群组
项目
Show more breadcrumbs
hanamizuki
ComfyUI
提交
60653004
提交
60653004
编辑于
1个月前
作者:
comfyanonymous
浏览文件
操作
下载
补丁
差异文件
Use regular numbers for rope in lumina model.
上级
a57d635c
No related branches found
分支 包含提交
No related tags found
标签 包含提交
无相关合并请求
变更
1
隐藏空白变更内容
行内
左右并排
显示
1 个更改的文件
comfy/ldm/lumina/model.py
+8
-6
8 个添加, 6 个删除
comfy/ldm/lumina/model.py
有
8 个添加
和
6 个删除
comfy/ldm/lumina/model.py
+
8
−
6
浏览文件 @
60653004
...
...
@@ -9,6 +9,7 @@ import torch.nn.functional as F
from
comfy.ldm.modules.diffusionmodules.mmdit
import
TimestepEmbedder
,
RMSNorm
from
comfy.ldm.modules.attention
import
optimized_attention_masked
from
comfy.ldm.flux.layers
import
EmbedND
def
modulate
(
x
,
scale
):
...
...
@@ -92,10 +93,9 @@ class JointAttention(nn.Module):
and key tensor with rotary embeddings.
"""
x
=
torch
.
view_as_complex
(
x_in
.
float
().
reshape
(
*
x_in
.
shape
[:
-
1
],
-
1
,
2
))
freqs_cis
=
freqs_cis
.
unsqueeze
(
2
)
x_out
=
torch
.
view_as_real
(
x
*
freqs_cis
).
flatten
(
3
)
return
x_out
.
type_as
(
x_in
)
t_
=
x_in
.
reshape
(
*
x_in
.
shape
[:
-
1
],
-
1
,
1
,
2
).
float
()
t_out
=
freqs_cis
[...,
0
]
*
t_
[...,
0
]
+
freqs_cis
[...,
1
]
*
t_
[...,
1
]
return
t_out
.
reshape
(
*
x_in
.
shape
).
type_as
(
x_in
)
def
forward
(
self
,
...
...
@@ -130,6 +130,7 @@ class JointAttention(nn.Module):
xq
=
self
.
q_norm
(
xq
)
xk
=
self
.
k_norm
(
xk
)
xq
=
JointAttention
.
apply_rotary_emb
(
xq
,
freqs_cis
=
freqs_cis
)
xk
=
JointAttention
.
apply_rotary_emb
(
xk
,
freqs_cis
=
freqs_cis
)
...
...
@@ -480,7 +481,8 @@ class NextDiT(nn.Module):
assert
(
dim
//
n_heads
)
==
sum
(
axes_dims
)
self
.
axes_dims
=
axes_dims
self
.
axes_lens
=
axes_lens
self
.
rope_embedder
=
RopeEmbedder
(
axes_dims
=
axes_dims
,
axes_lens
=
axes_lens
)
# self.rope_embedder = RopeEmbedder(axes_dims=axes_dims, axes_lens=axes_lens)
self
.
rope_embedder
=
EmbedND
(
dim
=
dim
//
n_heads
,
theta
=
10000.0
,
axes_dim
=
axes_dims
)
self
.
dim
=
dim
self
.
n_heads
=
n_heads
...
...
@@ -550,7 +552,7 @@ class NextDiT(nn.Module):
position_ids
[
i
,
cap_len
:
cap_len
+
img_len
,
1
]
=
row_ids
position_ids
[
i
,
cap_len
:
cap_len
+
img_len
,
2
]
=
col_ids
freqs_cis
=
self
.
rope_embedder
(
position_ids
)
freqs_cis
=
self
.
rope_embedder
(
position_ids
)
.
movedim
(
1
,
2
)
# build freqs_cis for cap and image individually
cap_freqs_cis_shape
=
list
(
freqs_cis
.
shape
)
...
...
This diff is collapsed.
点击以展开。
预览
0%
加载中
请重试
或
添加新附件
.
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
保存评论
取消
想要评论请
注册
或
登录