hanamizuki / ComfyUI · Commit 0bef826a

Authored 1 week ago by comfyanonymous

Support llava clip vision model.

Parent: 85ef2950
Showing 4 changed files with 61 additions and 3 deletions:

- comfy/clip_model.py (+19 −1)
- comfy/clip_vision.py (+5 −1)
- comfy/clip_vision_config_vitl_336_llava.json (+19 −0, new file)
- comfy/sd1_clip.py (+18 −1)
comfy/clip_model.py (+19 −1)
@@ -211,6 +211,15 @@ class CLIPVision(torch.nn.Module):
         pooled_output = self.post_layernorm(x[:, 0, :])
         return x, i, pooled_output
 
+class LlavaProjector(torch.nn.Module):
+    def __init__(self, in_dim, out_dim, dtype, device, operations):
+        super().__init__()
+        self.linear_1 = operations.Linear(in_dim, out_dim, bias=True, device=device, dtype=dtype)
+        self.linear_2 = operations.Linear(out_dim, out_dim, bias=True, device=device, dtype=dtype)
+
+    def forward(self, x):
+        return self.linear_2(torch.nn.functional.gelu(self.linear_1(x[:, 1:])))
+
 class CLIPVisionModelProjection(torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
@@ -220,7 +229,16 @@ class CLIPVisionModelProjection(torch.nn.Module):
         else:
             self.visual_projection = lambda a: a
 
+        if "llava3" == config_dict.get("projector_type", None):
+            self.multi_modal_projector = LlavaProjector(config_dict["hidden_size"], 4096, dtype, device, operations)
+        else:
+            self.multi_modal_projector = None
+
     def forward(self, *args, **kwargs):
         x = self.vision_model(*args, **kwargs)
         out = self.visual_projection(x[2])
-        return (x[0], x[1], out)
+        projected = None
+        if self.multi_modal_projector is not None:
+            projected = self.multi_modal_projector(x[1])
+
+        return (x[0], x[1], out, projected)
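In plain terms, the new LlavaProjector is the usual LLaVA multimodal projector: a two-layer MLP with a GELU in between that drops the CLS token at position 0 and maps every remaining patch token from the vision hidden size up to the language-model width (hard-coded to 4096 here). A minimal sketch of the same computation, using plain torch.nn.Linear in place of ComfyUI's operations.Linear factory and the ViT-L/336 shapes from this commit:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LlavaProjectorSketch(nn.Module):
    """Stand-in for LlavaProjector: two Linear layers with a GELU,
    applied to the patch tokens only (CLS token at index 0 is dropped)."""
    def __init__(self, in_dim=1024, out_dim=4096):
        super().__init__()
        self.linear_1 = nn.Linear(in_dim, out_dim, bias=True)
        self.linear_2 = nn.Linear(out_dim, out_dim, bias=True)

    def forward(self, x):
        # x: [batch, 1 + num_patches, in_dim] hidden states from the vision tower
        return self.linear_2(F.gelu(self.linear_1(x[:, 1:])))

# ViT-L/14 at 336px: (336 // 14) ** 2 = 576 patch tokens plus one CLS token.
hidden = torch.randn(1, 577, 1024)
print(LlavaProjectorSketch()(hidden).shape)  # torch.Size([1, 576, 4096])
```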
comfy/clip_vision.py (+5 −1)
@@ -65,6 +65,7 @@ class ClipVisionModel():
         outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
         outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
         outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
+        outputs["mm_projected"] = out[3]
         return outputs
 
 def convert_to_transformers(sd, prefix):
@@ -104,7 +105,10 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
         if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
             json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
         elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
-            json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
+            if "multi_modal_projector.linear_1.bias" in sd:
+                json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json")
+            else:
+                json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
         else:
             json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
     else:
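The loader change is a single extra key check: a 577-row position embedding still identifies the ViT-L/14 336px tower, but when the checkpoint also ships multi_modal_projector weights the new llava config is chosen, and the projector output then surfaces as outputs["mm_projected"] (the fourth element of the model's return tuple). A rough sketch of that selection logic, with pick_vitl_336_config as a hypothetical helper name:

```python
import os

def pick_vitl_336_config(sd, config_dir):
    # Hypothetical helper mirroring the branch added to load_clipvision_from_sd:
    # same ViT-L/336 backbone either way, llava config only when projector weights exist.
    if "multi_modal_projector.linear_1.bias" in sd:
        return os.path.join(config_dir, "clip_vision_config_vitl_336_llava.json")
    return os.path.join(config_dir, "clip_vision_config_vitl_336.json")
```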
comfy/clip_vision_config_vitl_336_llava.json (new file, mode 100644, +19 −0)
{
  "attention_dropout": 0.0,
  "dropout": 0.0,
  "hidden_act": "quick_gelu",
  "hidden_size": 1024,
  "image_size": 336,
  "initializer_factor": 1.0,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-5,
  "model_type": "clip_vision_model",
  "num_attention_heads": 16,
  "num_channels": 3,
  "num_hidden_layers": 24,
  "patch_size": 14,
  "projection_dim": 768,
  "projector_type": "llava3",
  "torch_dtype": "float32"
}
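As a quick sanity check, these numbers are consistent with the detection logic above: at image_size 336 and patch_size 14 the tower produces (336 / 14)² = 576 patch tokens, so the position embedding has 576 + 1 = 577 rows, exactly the shape load_clipvision_from_sd keys on. A short check (the relative path assumes it is run from the repository root):

```python
import json

with open("comfy/clip_vision_config_vitl_336_llava.json") as f:
    cfg = json.load(f)

patches_per_side = cfg["image_size"] // cfg["patch_size"]  # 336 // 14 = 24
num_positions = patches_per_side ** 2 + 1                  # 576 patches + 1 CLS token
print(num_positions)  # 577, matching the position_embedding row count check
```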
comfy/sd1_clip.py (+18 −1)
@@ -196,8 +196,25 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             index = 0
             pad_extra = 0
             for o in other_embeds:
+                emb = o[1]
+                if torch.is_tensor(emb):
+                    emb = {"type": "embedding", "data": emb}
+
+                emb_type = emb.get("type", None)
+                if emb_type == "embedding":
+                    emb = emb.get("data", None)
+                else:
+                    if hasattr(self.transformer, "preprocess_embed"):
+                        emb = self.transformer.preprocess_embed(emb, device=device)
+                    else:
+                        emb = None
+
+                if emb is None:
+                    index += -1
+                    continue
+
                 ind = index + o[0]
-                emb = o[1].view(1, -1, o[1].shape[-1]).to(device=device, dtype=torch.float32)
+                emb = emb.view(1, -1, emb.shape[-1]).to(device=device, dtype=torch.float32)
                 emb_shape = emb.shape[1]
                 if emb.shape[-1] == tokens_embed.shape[-1]:
                     tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1)
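With this change an entry in other_embeds may be either a bare tensor (handled exactly as before) or a dict carrying a "type" field: dicts of type "embedding" are unwrapped, anything else is passed to the transformer's preprocess_embed hook if it defines one, and entries that cannot be resolved are skipped. A small sketch of that per-entry normalization, with resolve_embed as a hypothetical name:

```python
import torch

def resolve_embed(emb, transformer, device):
    # Hypothetical helper mirroring the new handling in SDClipModel:
    # wrap bare tensors, unwrap "embedding" dicts, defer other types to the model.
    if torch.is_tensor(emb):
        emb = {"type": "embedding", "data": emb}
    if emb.get("type", None) == "embedding":
        return emb.get("data", None)
    if hasattr(transformer, "preprocess_embed"):
        return transformer.preprocess_embed(emb, device=device)
    return None  # unknown embed type and no hook: the caller skips this entry
```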