Skip to content
代码片段 群组 项目
未验证 提交 f04229b8 编辑于 作者: Jedrzej Kosinski's avatar Jedrzej Kosinski 提交者: GitHub
浏览文件

Add emb_patch support to UNetModel forward (#4779)

上级 f067ad15
No related branches found
No related tags found
无相关合并请求
......@@ -842,6 +842,11 @@ class UNetModel(nn.Module):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
emb = self.time_embed(t_emb)
if "emb_patch" in transformer_patches:
patch = transformer_patches["emb_patch"]
for p in patch:
emb = p(emb, self.model_channels, transformer_options)
if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)
......
0% 加载中 .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册