现有的去噪模型的一个问题是同一个网络去噪网络同时要“编码”噪声输入来提取低频语义信息,之后通过“解码”获得高频信息。这是一个优化困境(optimization dilemma):编码低频语义必须要减少高频信号。DDT提出将二者解耦,从原来的only-decoder转变为encoder-decoder架构。DDT包含一个Condition Encoder用于低频语义信息的提取,Velocity Decoder用于速度预测生成更高质量的图像。

DDT将Condition Encoder的输出称为self-condition,在此基础上,他们采用了sharing strategy提高推理速度,这通常伴随极小的性能损失。DDT仅训练256个epoch,4倍于REPA的训练速度即可达到1.31的FID。

sharing strategy是指相邻去噪步的生成结果具有一致性(local consistency),假设推理步数N,编码器给定计算成本K,则共享率计算为\(1-(K/N)\)。因此,有\(|\Phi|=K\) 是self-conditon重新计算的时间步的集合:

$$ \[\begin{equation} z_t = \begin{cases} z_{t-{\delta}t}, & \text{if }t \notin \Phi \\ Encoder(x_t, t, y), & \text{if }t \in \Phi \end{cases} \end{equation}\] $$

\(|\Phi|=K\) 通过动态规划求解 ????????

ICLR2025 扩散模型中的去噪过程中,在模型内部存在有意义的表征,尽管这些表征的质量仍然落后于通过最近的自监督学习方法学习到的表征。论文训练大规模扩散模型用于生成的一个主要瓶颈在于有效地学习这些表示。通过结合高质量的外部视觉表征,而不是仅仅依靠扩散模型来独立学习,训练可以变得更容易。作者通过引入一种称为表示对齐REPresentation Alignment(REPA)的简单正则化来研究这一点,该正则化将去噪网络中噪声输入隐藏状态的投影与从外部预训练的视觉编码器获得的干净图像表示对齐。当应用于DiT和SiT时,REPA在训练效率和生成质量方面都有了显著的改进。REPA可以将SiT训练的速度提高17.5倍以上,与在不到400K步的情况下训练7M步的SiT-XL\2模型的性能(无cfg)相匹配。在最终生成质量方面,作者使用cfg和guidance interval实现了FID=1.42的最先进结果。

Transformer

训练扩散模型的主要挑战源于需要学习高质量的内部表征。REPA表明:当生成式扩散模型得到来自另一个模型(例如自监督视觉编码器)的外部高质量表征的支持时,其性能可以得到大幅提升。论文使用 DINOv2 作为外部自监督视觉编码器提供表征对齐。

Transformer

作者使用Linear Probing 和 CKNNA 来评估扩散模型的内部表征能力,图2a所示,作者观察到预训练扩散变压器的隐藏状态表示,在第20层实现了相当高的线性探测峰(t表示dropout ratio used in linear probing)。然而,它的性能仍然远远低于DINOv2,这表明两种表示之间存在很大的语义差距。此外,在达到这个峰值后,线性探测性能迅速下降,这表明扩散转换器必须从仅仅专注于学习语义丰富的表示转变为生成具有高频细节的图像。

在图2b中,作者使用CKNNA报告了SiT和DINOv2之间的代表性比对。特别是,SiT模型表示已经表现出比MAE更好的一致性。然而,绝对对齐得分仍然低于其他自监督学习方法(例如,MoCov3 vs. DINOv2)。这些结果表明,虽然扩散变压器表示与自监督视觉表示表现出一定的一致性,但一致性仍然很弱。

如图2c所示,作者通过更大的模型和扩展的训练观察到改进的对齐。然而,对齐仍然很低,并且没有达到其他自监督视觉编码器(例如,MoCov3和DINOv2)之间观察到的水平,即使经过大量的7M迭代训练。

Transformer

作者发现在SiT上,第8层施加REPA,Linear Probing 和 CKNNA 达到了峰值。

\[ \begin{equation} \mathcal{L}_{\text{REPA}}(\theta, \phi) := - \mathbb{E}_{\mathbf{x}_*, \epsilon, t} \left[ \frac{1}{N} \sum_{n=1}^{N} \operatorname{sim}(\mathbf{y}_*^{[n]}, \, h_\phi(h_t^{[n]})) \right] \end{equation} \]

REPA对齐损失如上,一张干净的图像\(x_*\),预训练编码器(如DINOV2)定义为\(f\)\(y_*=f(x_*) \in \mathcal{R}^{N×D}\),N和D分别是pathch的数量和embedding dimension of encoder output,\(h_t=f_{\theta}(z_t)\)是扩散模型的输出,并通过可学习映射层\(h_{\phi}(h_t)\)是一个MLP。

Transformer

图5a分析了多种target encoder的效果,DINOV2实现最低FID的同时,预训练之后的DINOV2+SiT-L,在线性探测上实现了更高的线性探测准确率。

视频虚拟试穿(video virtual try-on,VVT)技术在电子商务、广告和娱乐等领域的应用前景广阔,引起了学术界的广泛关注。然而,大多数现有的端到端方法严重依赖于稀缺的以服装为中心的成对数据集,无法有效地利用先进视觉模型的先验性和测试-时间输入(test-time inputs)。使得在不受约束的场景中准确保存细粒度的服装细节和保持时间一致性具有挑战性。为了应对这些挑战,我们提出了DreamVVT,这是一个精心设计的两阶段框架,建立在扩散变压器(dit)之上,它本质上能够利用各种不成对的以人为中心的数据来增强现实世界场景的适应性。为了进一步利用来自预训练模型和测试时间输入的先验知识,在第一阶段,我们从输入视频中采样代表性帧并利用结合视觉语言模型(VLM)的多帧试戴模型,进一步合成高保真和语义一致的关键帧试穿图像。这些图像作为后续视频生成的补充外观指导。在第二阶段,从输入内容中提取骨架映射以及细粒度的运动和外观描述,然后将这些与关键帧试镜图像一起输入到使用LoRA适配器增强的预训练视频生成模型中。这确保了看不见的区域的长期时间一致性,并使高度可信的动态运动成为可能。大量的定量和定性实验表明,DreamVVT在真实场景中保留服装的详细信息和时间稳定性方面优于现有的方法。

LatentSync
图1: DreamVVT can generate high-fidelity and temporally coherent virtual try-on videos for diverse garments and in unconstrained scenarios. Specifically, the first row shows its ability to handle complex human motions like runway walks and 360-degree rotations;the second row illustrates robustness to complex backgrounds and challenging camera movements; the third row highlights visually coherent try-on results for cartoon characters with real garments.

现有的问题

现有的方法难以准确地保留细粒度的服装细节,并在不受约束的场景中保持时间一致性,例如复杂的主体或摄像机运动、动态场景和不同的角色风格。我们认为,这些限制主要源于对端到端训练范式的依赖,而这固有地限制了对不成对数据、先进视觉模型先验和推理阶段附加信息的有效利用。

首先,这些方法所严重依赖的配对的服装视频数据不足,其中大多数是在均匀的室内环境中收集的。这通常会导致服装视觉保真度降低和时间不稳定性增加,特别是对于任意服装和复杂的视频输入。此外,在不同的现实世界场景中收集大规模成对的服装视频数据集仍然极具挑战性。

其次,这些方法采用预训练的文本到视频生成模型,将空间错位的服装图像逐帧变形到人身上。然而,这种方法破坏了预训练模型平滑时空建模的固有能力,使模型收敛更具挑战性。此外,在预训练的模型中对所有参数进行完全微调时,由于数据量有限,容易破坏预训练的先验,这反过来又降低了生成视频的质量和时间稳定性。即使在大规模数据集和各种视频任务上训练了很大一部分模型参数,统一的视频创建和编辑方法仍然难以准确地保留服装细节并保持时间一致性,这主要是由于缺乏针对虚拟试穿的任务特定设计。

第三,在推理阶段,仅提供服装的正面图像来指导虚拟试穿过程,往往会导致当人转身或摄像机视点(camera viewpoint)发生重大变化时,不可见区域的结果不可信。

为了解决这些问题,作者引入了DreamVVT,这是一种基于扩散转换器(Diffusion transformer, dit)的改进的阶段式框架,它本质上能够利用来自不同来源的不成对的以人为中心的数据来提高现实场景中的泛化。为了进一步利用预训练模型和推理过程中的先验信息,

在第一阶段,作者首先从输入视频中采样具有显著运动变化的关键帧,然后使用视觉语言模型(VLM)生成文本描述,将输入服装映射到每个关键帧。将这些描述以及服装图像和其他相关条件提供给配备LoRA适配器的多帧试穿模型,从而获得每个关键帧的高保真度和语义一致的试穿图像。这些图像作为后续视频生成的补充外观指导。

在第二阶段,作者采用时间平滑姿态引导器(pose guider)进行骨骼特征编码,并利用先进的视频大语言模型(video LLM)从输入内容中提取细粒度的动作描述和其他高级视觉信息。这些特征以及空间对齐的关键帧试镜图像随后作为输入提供给使用LoRA适配器增强的预训练视频生成模型。通过利用大规模视频生成模型的预训练先验,该模型在野外场景中具有增强的泛化能力。此外,结合多个关键帧试穿图像和精确的运动指导,可以生成长期的(long-term)虚拟试穿视频,表现出强烈的时间一致性和高度可信的动态运动。此外,还引入了多任务学习策略来保持不同模态条件的可控性。

DreamVVT

DreamVVT采用了一种基于大型(large-scale)扩散变压器的分阶段框架,实现了无约束场景下的高保真虚拟试穿视频生成。如图2所示,它由两个连续的阶段组成。 在第一阶段,从输入视频中抽取具有显著运动变化的帧作为关键帧,然后开发多帧试穿模型,在保持内容一致性和保留精细细节的同时将服装图像适配到这些关键帧。 在第二阶段,我们提出了一种改进的视频生成模型,该模型综合了基于关键帧试穿图像、姿态特征和文本描述的可信(plausible)试穿视频。

LatentSync
图2: Overview of DreamVVT.

输入条件

Pose Conditions

为了实现真实人物和卡通人物的虚拟试穿,作者采用RTMPose作为鲁棒的和高效的姿态表示。考虑到原始视频中的角色可能只占据有限的空间区域,直接将其原始分辨率的姿势序列输入到试衣模型中可能会导致生成视频中显著的服装细节丢失,这主要是由于空间降采样。为了解决这个问题,作者首先使用统一宽度和高度的跟踪边界框裁剪每一帧,从而隔离每一帧中的人物主体所在区域(character region)。

Agnostic Masks

以往的方法大多是通过直接扩大视频中分割的服装区域来生成agnostic mask,这很容易导致原始服装风格线索的泄露。为了缓解这个问题,作者使用人体边界框和扩展的姿势骨架来生成与服装无关的掩码(clothing-agnostic mask),有效地防止信息泄漏,同时尽可能多地保留原始背景。

Agnostic Images

通过应用agnostic masks,遮挡输入人物主体视频或图像中的服装区域,从而生成agnostic images。

Garment Images

对于服装图像输入,作者首先使用显著性分割检测模型提取前景区域,然后用白色像素填充去除背景区域。为了进一步促进服装细节的保存,作者根据提取的分割计算一个紧密的边界框,然后裁剪感兴趣的区域。最后,裁剪后的图像在输入到网络之前被调整到一个特定的分辨率。

阶段1:关键帧的高保真试穿

作者选择具有显著运动变化的帧,为视频生成提供更全面的指导。最初,考虑到大多数输入的服装图像都是从正面角度捕获的,作者预先定义了一个A姿势(A-pose)的正面视图人物图像作为锚帧。 随后,作者通过测量每个视频帧和锚帧各自骨骼关节方向向量之间的余弦距离来计算它们之间的运动相似度。这种相似性进一步通过主体在整个框架中的面积比例来加权,从而产生最终分数。最后,根据帧的最终分数按降序排序,并在最小分数区间约束下进行逆序搜索,得到一组信息冗余最小的关键图像。

给定选定的关键帧,作者利用具有最小可学习参数集的扩散转换器 \(G*\),来自于预训练的Seedream模型 \(G\) (增加了lora),生成最终的多帧试戴结果。通过将注意力模块与可插拔的LoRA集成,作者修改了 \(G\) 中的每个 \(MMDiT\) 块,并引入了一个额外的参数共享网络分支来处理实现后的参考图像输入。值得注意的是,\(G∗\) 将多关键帧图像条件对与精心设计的一致图像指令一起作为输入,这澄清了不同的输入组件,并有助于指导模型合成理想的结果。具体而言,作者首先通过并行网络架构对每个条件输入进行标记,以对齐不同的模态,然后在注意过程中通过Q, K, V 的交换聚合关键帧的信息。该机制确保了每个条件输入和关键帧中间特征之间的鲁棒的信息交互,从而实现了具有一致细节的连贯多帧试穿结果。对于文本输入,作者使用字节内部的 Seed1.5-VL 进行详细描述,包括每个关键帧的服装类别、材料和图案。随后,引入文本对齐过程,要求VLM重写并收集所有文本结果,进一步增强关键帧描述的一致性。

阶段2:多模态引导虚拟试穿视频生成

虚拟视频试穿模型基于预训练的图像到视频生成框架,该框架采用堆叠顺序MMDiT块,每个块都集成了文本和视频流。为了准确地重建输入视频中的身体运动,作者提取了相应的二维骨架序列。裁剪后,具有时间关注的定制姿态引导器(pose guider)将逐帧骨架图转换为在每个时间帧上与潜空间噪声分辨率匹配的平滑姿势潜在向量。同样,将裁剪后的 agnostic images 输入视频 VAE 编码器以获得 agnostic latents,并将裁剪后的agnostic masks 调整为与 agnostic latents 相同的分辨率。然后,将 agnostic latents、调整大小的 agnostic masks、nosie latents 和 pose latents 沿着通道维度进行 concatenate,并 patchfy 为视频 tokens,记为 \(F_{vid} \in R^{l_v×c}\)(其中\(l_v = t × h × w\), \(t = T/4\), \(h = H/16\), \(w = W/16\), \(t, h, w\)为输入视频的形状)。由于姿势骨架(pose skeletons)只捕获粗粒度的身体运动,不能完全表示细粒度的服装交互,我们使用 Qwen2.5-VL 提取属性解纠缠文本描述(attribute-disentangled textual description)——包含详细运动描述和高级视觉信息(在推理过程中,与外观相关的描述被替换为与目标服装对应的描述)。这些文本描述随后被Qwen LLM 处理成 text tokens,记为 \(F_{text} \in R_{l_t×c_t}\)。对于 appearance branch,首先由视频VAE编码器对关键帧尝试图像逐帧处理,提取 image latents,然后将其转化为image tokens,表示为 \(F_{img} \in R^{l_i×c}\) (其中\(l_i = k × h × w\), k 为关键帧的数量)。为了保持模型的时空建模和快速粘附能力,作者冻结了文本流的参数。轻量级的LoRA适配器,只包含10%的可训练参数,被插入到直接从视频流复制的冻结视频流和图像流中,具有共享内存。随着视频和图像tokens的通道数的增加,视频和图像流的输入投影层被设置为可训练的。最后,所有这些标记集通过各自的 QKV投影层 进行处理,然后沿着 l 维进行连接。结果序列被输入到一个完整的自注意力机制模块中,该模块使模型能够自适应地将视觉内容与文本描述跨空间和时间维度对齐。在自注意力机制操作之后,通过索引将 joint tokens 解复用为text、image、video tokens,这些 tokens 随后由以下DiT块处理。在DiT主干内进行多次去噪迭代后,生成试穿视频 tokens,之后通过 Video VAE decoder被解码成视频序列。然后采用高效的拉普拉斯金字塔融合方法,将生成的试穿视频无缝地融合到原始视频的相应区域中。在训练过程中,作者引入了多任务学习策略,其中一个任务(例如,文本到视频,带文本的姿势和关键帧到视频)是基于预定义的 probabilistic schedule 随机选择的,以充分利用各种模式的互补优势。

视频虚拟试穿(video virtual try-on,VVT)技术在电子商务、广告和娱乐等领域的应用前景广阔,引起了学术界的广泛关注。之前关于视频尝试的研究主要集中在将产品服装图像转移到具有简单人体姿势的视频中,而在复杂动作方面表现不佳。为了更好地保存服装细节,这些方法配备了一个额外的服装编码器,导致更高的计算资源消耗。目前该领域核心挑战:(1)利用服装编码器的视频试戴功能,同时降低计算需求;(2)确保人体各部位合成的时间一致性,特别是在快速运动时。

本文提出了Dynamic Try-On,包含两个模块:

1)Dynamic Feature Fusion Module (DFFM)动态特征融合模块:通过DiT主干提取并存储整合服装特征,保存服装细节。

2)Limb-aware Dynamic Attention Module (LDAM)肢体感知动态注意模块以及用于保存人的姿势和身份的轻量级身份保存编码器(ID Encoder):有效地跨帧传递身体信息,并产生时间一致的视频。

模型结构

Dynamic
图1: Overview of the proposed Dynamic Try-On.

ID-Encoder

在虚拟试穿中,作者认为其可以视为一个修复问题(inpainting problem),提出了一个四元组 \({x_a, d_p, m_c, c}\),c代表目标服装将被应用于指定的人物视频\(x\)\(x_a\)代表衣服不可知的图像(cloth-agnostic image),\(d_p\)代表姿势骨架(pose skeleton),\(m_c\)代表要修复的掩码。在实现ID-Encoder时,作者采用类似ControlNet的方式,通过一个SiT-DiT块保存人物姿势、身份和背景信息。

Dynamic Attention Mechanism

动态注意力机制包括DFFM(动态特征融合模块)和LDAM(肢体感知动态注意模块)。从总体上看,网络先执行Garment的输入,经过一次前向计算得到特征图存储在Feature Bank中,之后网络进行第二次去噪前向传播,第二次前向传播包括了一个类似ControlNet的ID-Encoder模块来融合人物ID身份信息,在每个SiT-DiT块中,通过交叉注意力融合目标服装特征的信息,实现虚拟试穿。

DFFM

本质上,作者将SiT-DiT的去噪主干网进行共用,用来并行提取服装特征,如图1左边所示(灰色部分不参与服装特征提取过程),作者将每次经过SiT-DiT块之后的服装特征向量保存在一个Feature Bank中,这个过程需要一次完整的前向传播。之后,图1右边的去噪DiT-SiT主干网络进行去噪,在每次经过一个DiT-SiT块时,通过交叉注意力机制与Feature Bank中事先存好的服装特征进行融合,值得注意的是,左侧块与右侧块是对应的关系,Feature Bank中总共存储了N个特征图用于计算N次交叉注意。

LDAM

如图2所示,本质上,作者希望网络能够关注视频中人物骨架运动过程的信息,作者通过骨架图和掩码找到人体骨骼在视频帧中所对应的具体的patch,然后作者将这些patch拼接在一起进行自注意力,注意力计算结果则直接加到原来的特征图中。算法1说明了详细的计算过程。

Dynamic
图2: Visulization of Limb-aware Dynamic Attention Module.
Dynamic

Multi-Stage Training

作者使用了预训练模型Opensora的模型权重。

  • 在第一阶段作者只训练空间自注意力和交叉注意力层,任务是根据指定的服装图像来重建人物图像。
  • 在第二阶段,作者训练ID-Encoder,训练目标与第一阶段相同,所有参数都进行训练。这一阶段LDAM没有加入到网络中。
  • 在第三阶段,作者加入LDAM,并且只训练新加入的层。尽管这里没有提到,但第三阶段应该是也训练了时间自注意力层。

端到端音频条件潜在扩散模型(ldm)已被广泛应用于音频驱动的人像动画,证明了它们在生成逼真的高分辨率谈话视频方面的有效性。当前的口型同步(lip-sync)任务中口型同步的精度并不理想。本文指出了其根本性的问题:捷径学习(shortcut Learning problem)——模型经常学习视觉-视觉(visual-visual)的捷径而忽略关键的试听相关性(audio-visual correlation)。作者将SyncNet 集成到audio-conditioned LDMs中监督强制学习试听相关性。作者认为SyncNet对对口型精度的影响至关重要,进行了全面的实证分析来确定影响SyncNet收敛的关键因素。

本文主要贡献:

1)提出了LatentSync——第一个使用音频条件在高分辨率视频上实现端到端逼真口型同步的方法,提出了TREPA(Temporal Representation Alignment)增强视频的时间一致性.

2)发现了捷进学习问题,探索了将SyncNet 集成到audio-conditioned LDMs中监督强制学习试听相关性的解决方法。

3)进行了全面的实证研究分析SyncNet收敛的关键因素,提出了稳定收敛的StableSyncNet。

LatentSync
图1: The overview of LatentSync framework.

LatentSync Framework

LatentSync架构如图1所示。U-Net的大部分参数直接来自预训练模型SD 1.5。为了适配新任务,作者更换了输入层(conv_in),以及加入的交叉注意力层(cross-attn)采用了随机初始化。U-Net的输入包含13个通道,分别是mask, mask frames, reference frames以及Noised latents。

Audio layers

音频信息通过Whisper Encoder提取音频输入得到音频嵌入。通常口型动作收到当前音频以及周围音频的共同影响,作者将音频输入特征\(A^{(f)}\)定义为:\(A^{(f)} = {A^{(f)-m},...,A^{(f)},...A^{(f)+m}}\),其中f表示第f帧,m是一侧周围音频特征的数量。最后通过交叉注意力层合并到U-Net中。

Affine transformation and fixed mask

作者使用仿射变换将人脸正面化,这有利于模型学习面部特征。同时在掩膜的选取上,作者采用覆盖几乎整个脸部的掩膜大小,希望最小化模型对捷径的依赖。

LatentSync
图2: The illustration of affine transformation and fixed mask.

SyncNet supervision

SyncNet需要输入图像进行监督,而LDMs每次预测的是噪声,因此作者使用一步采样的结果作为图像输入:

\[ (\hat{z}_0) = (z_t - \sqrt{1-\bar{\alpha}}\epsilon_{\theta}(z_t)) / \sqrt{\bar{\alpha}_t} \]

作者讨论了SyncNet的两种空间设计,在后续的分析中证明了像素空间下具有更快的收敛速度。作者认为可能是在VAE编码过程中唇区信息丢失造成的。

LatentSync
图3: Two methods to add SyncNet supervision to latent diffusion models.

两阶段训练策略

第一阶只训练U-Net学习从参考图像中提取视觉特征来修复图像,训练U-Net的全部参数,但不加入时间层:

\[ \mathcal{L}_{\text{simple}} = \mathbb{E}_{x, A, \epsilon \sim \mathcal{N}(0,1), t} \left[\left\lVert \epsilon - \epsilon_\theta \big( z_t, t, \tau_\theta(A) \big) \right\rVert_2^2 \right] \] 其中A是输入音频,\(\epsilon_{\theta}(z_t, t, \tau_{\theta(A)})\)是预测噪声,_{(A)}是音频特征提取器。

在训练的第二阶段,作者只训练时间层(temporal layer)和音频层(audio layer),而冻结 U-Net 的其他参数。假设有 16 帧解码后的视频片段 \(\mathcal{D}(\hat{z}_0)_{f:f+16}\) 以及对应的音频序列 \(a_{f:f+16}\),则 同步网络损失(SyncNet Loss) 定义为:

\[ \mathcal{L}_{\text{sync}} = \mathbb{E}_{x, a, \epsilon, t}\Big[\text{SyncNet}\big(\mathcal{D}(\hat{z}_0)_{f:f+16}, \, a_{f:f+16}\big)\Big] \] 其中 \(\mathcal{D}\) 表示 VAE 解码器。由于唇形同步(lipsync)任务需要生成细节区域(如嘴唇、牙齿、胡须等),为了提升生成图像的视觉质量,作者引入 LPIPS 损失:

\[ \mathcal{L}_{\text{lpips}} = \mathbb{E}_{x, \epsilon, t} \Big[\big\lVert \mathcal{V}_l(\mathcal{D}(\hat{z}_0)_f) - \mathcal{V}_l(x_f) \big\rVert_2^2 \Big] \] 其中 \(\mathcal{V}_l(\cdot)\) 表示从 预训练的 VGG 网络第 \(l\) 层提取的特征。

此外,为了增强时间一致性,还引入了提出的 TREPA 损失。 最终,第二阶段的总损失函数定义为:

\[ \mathcal{L}_{\text{total}} = \lambda_1 \mathcal{L}_{\text{simple}} + \lambda_2 \mathcal{L}_{\text{sync}} + \lambda_3 \mathcal{L}_{\text{lpips}} + \lambda_4 \mathcal{L}_{\text{trepa}} \]

TREPA(Temporal Representation Alignment)

考虑到一半的损失只能改善生成图像的内容质量而并不关注时序一致性,作者使用自监督视频模型VideoMAE-v2提取时序表征。设 \(\mathcal{T}\) 是一个自监督视频模型编码器(self-supervised video encoder),其输出为在 projection head 之前的嵌入表示。TREPA 损失定义为:

\[ \mathcal{L}_{\text{trepa}} = \mathbb{E}_{x, \epsilon, t} \Big[ \big\lVert \mathcal{T}(\mathcal{D}(\hat{z}_0)_{f:f+16}) - \mathcal{T}(x_{f:f+16}) \big\rVert_2^2 \Big] \] 其中,采用均方误差(MSE)来衡量生成视频片段与真实视频片段的时间表示(temporal representations)之间的差异。在计算 MSE 之前,特征表示会先经过 \(\ell_2\) 归一化。

SyncNet收敛性分析

这个部分探索了为什么SyncNet训练不收敛的原因,在我看来更像是寻找合适超参数的过程。作者最终确定batch size大小为1024,Embedding dimension为2048,number of frames为16。除此之外,作者介绍了SyncNet的结构和数据处理的一些过程,在仓库中可以看到详细的数据处理管道。

实验结果

LatentSync
表1: Quantitative comparisons on HDTF and VoxCeleb2.

表1结果表明作者所提出的方法在HDTF数据集上取得的领先性能。更多的实验结果,实验设置细节和消融分析请阅读原文。

稳定扩散模型通常由数十亿个参数组成,文本生成图像任务中在生成高质量图像时需要很高的计算要求。为了提高效率,最近的研究包括剪枝、蒸馏、量化的相关工作致力于此。BK-SDM证明了在稳定扩散模型上剪枝再蒸馏训练的巨大潜力。在蒸馏训练中,第一次在特征和输出水平上构建蒸馏损失函数进行训练,用更少的资源和成本训练出具有竞争力的T2I模型,且更易部署在边缘设备上。

手工剪枝

BK-SDM通过经验手动设计了三种型号的剪枝后SDM。具体结构如图1所示。

Transformer
图1: Block removal from the denoising U-Net.

在图2的相关性分析中,作者认为CLIP分数无法正确反映层重要性,(a)中删除较多的Attn层而CLIP分数并没有显著下降,CLIP评分修剪敏感性的重要性标准会导致注意力块的过度修剪,最终结果不如手工设计的剪枝网络性能,结果对比展示在图3中。

Transformer
图2: Importance of (a) each block and (b) each group of paired/triplet blocks.
Transformer
图3: Different block-level criteria at similar parameters. Taylor pruning removes solely the inner blocks, leading to insufficient reduction in MACs. Our method attains a favorable compromise between performance an inference speed. Results on MS-COCO 30K.

作者还发现删除中间块(Middle Block)对于SD模型的性能影响是微小的,如图4所示。

Transformer
图4: Minor impact of removing the midstage from the U-Net. Results without retraining.

剪枝后蒸馏训练

为了恢复剪枝后模型的性能,BK-SDM采用蒸馏训练来对齐教师模型的输出。BK-SDM在feature-level和output-level上计算损失来训练学生模型。

Transformer
图5: Distillation-based retraining. The block-removed U-Net is trained effectively through the guidance of the original U-Net.

具体来说,损失函数由三部分组成,分别是任务损失,特征损失和输出损失:

\[ \begin{equation} L_{Task} = \mathbb{E}_{z,\epsilon,y,t}[||\epsilon - \epsilon_S(z_t, y, t)||_{2}^2] \end{equation} \]

\[ \begin{equation} L_{FeatKD} = \mathbb{E}_{z,\epsilon,y,t}[||\epsilon_T(z_t, y, t) - \epsilon_S(z_t, y, t)||_{2}^2] \end{equation} \]

\[ \begin{equation} L_{OutKD} = \mathbb{E}_{z,\epsilon,y,t}[\sum_{l} ||f_T^l(z_t, y, t) - f_S^l(z_t, y, t)||_{2}^2] \end{equation} \]

其中\(f_T^l(o)\)\(f_S^l(o)\)分别代表预先定义的第l层的输出特征图。最终,损失公式如下:

\[ \begin{equation} L = L_{Task} + \lambda_{FeatKD} L_{FeatKD} + \lambda_{OutKD} L_{OutKD} \end{equation} \]

其中{FeatKD}和{OutKD}是可设置的超参数。

CVPR2024

潜在扩散模型(Latent Diffusion Models,LDMs)已经成为最强大的生成模型之一,其在有限计算资源的条件下展示了出众的结果。尽管如此,这些庞大的模型仍然难以在资源受限的环境下部署。LD-Pruner提出了一种严格的剪枝评估指标来对LDM进行结构化修剪。并在剪枝后使用知识蒸馏恢复模型性能,使其保持和教师模型同样的推理性能的同时大幅减少部署资源和推理时间。

Transformer
图1: Samples generated using compressed models.

一种新颖的评估指标

本文提出了一种全新的评估指标,旨在潜空间中(latent)对层重要性进行排序。假设\(L_{orig}\)\(L_{mod}\)分别表示原始层集合和修剪后层集合,\(N_{gen}\)表示N次前向传播的输出潜在表示(latent representation)。定义第i次前向传播输出的潜在向量为\(l_{orig,i}\)\(l_{mod,i}\)

Transformer
图2: Overview of LD-Pruner.

得分公式包括两部分:距离均值\(avg_{dist}\)和距离方差\(std_{dist}\)

\[ \begin{equation} avg_{dist} = |avg_{orig}-avg_{mod}|_2 \\ std_{dist} = |std_{orig}-std_{mod}|_2 \end{equation} \]

其中 \(||_{2}\)表示欧几里得范数。

\[ \begin{equation} avg_{orig} = \frac{1}{N_{gen}} \sum_{i = 1}^{N_{gen}} l_{orig,i} \\ avg_{mod} = \frac{1}{N_{gen}} \sum_{i = 1}^{N_{gen}} l_{mod,i} \end{equation} \]

\[ \begin{equation} std_{orig} = \sqrt{\frac{1}{N_{gen}} \sum_{i = 1}^{N_{gen}} {avg_{orig} - l_{orig,i}}^2} \\ std_{mod} = \sqrt{\frac{1}{N_{gen}} \sum_{i = 1}^{N_{gen}} {avg_{mod} - l_{mod,i}}^2} \end{equation} \]

最终,定义修改后模型的得分公式(scoring formula):

\[ \begin{equation} score = avg_{dist} + std_{dist} \end{equation} \]

也许你会问为什么是均值和方差相加,而不是单独用均值或方差,或者二者相乘。实际上,如何组合以及用哪种方式,是要有前置实验做相关性分析。

图3: Qualitative comparison of the impact of various combination methods for average and standard deviation in our proposed scoring metric, with SD. The results are without finetuning.

如图3所示,作者在定性比较中发现,具有较小均值的层,文章里称之为operator(算子)延迟了图像退化为噪声的时间,即删除带来的损害更小;而较小的方差意味着可以保留更多的图像细节。

图4: Quantitative comparison of the impact of various combination methods for average and standard deviation in our proposed scoring metric, with UIG. The FID is measured after 20k iterations of finetuning.

在定量实验中,在FID的比较中,相加略优于相乘,最终作者选用了相加作为得分公式。

图5: Efficient Pruning for LDMs.

LD-Pruner算法如上,其中k类似于剪枝率,表示要剪枝多少层网络。

剪枝后训练

通常剪枝都会伴随着模型的重新训练,本文中使用知识蒸馏(Knowledge Distillation, KD)训练剪枝后的模型。LD-Pruner沿用BK-SDM中提出的蒸馏方法,即在特征水平和输出水平上对齐教师模型。特征水平上,在每个stage(Up、Down、Middle)之后的结果进行对齐;输出水平上,是指最终模型的输出结果进行对齐。

实验

作者分别在文生图(T2I)、无条件图像生成(UIG)、无条件音频生成(UAG)任务上实现他们的方法,证明了LD-Pruner的有效性。

图6: Comparison of different models for T2I Generation, on the MS-COCO 256 × 256 validation set. Speedup values are measured relatively to SD-v1.4.
图7: Qualitative comparison on zero-shot MS-COCO benchmark on T2I.

作者在MS-COCO数据集上详细展示了定性和定量的实验结果,T2I上剪枝后的模型FID上甚至超越了基线模型。

图8: Compression performance on UAG task with AudioDiffusion. When finetuning, we proceed for 12k steps.

图8展示了在UAG上的实验结果,剪枝之后的模型微调后,在FAD上超越了基线模型。

图9: FID scores for our compressed model (31 operators modified) trained from scratch and with preserved pre-training weights, for UIG on CelebA-HQ 256 × 256. In both case, the exact sametraining is applied. The FID for the original model is 13.85.

图9展示了UIG实验结果,对比了不加载预训练权重的模型,证明了剪枝后加载预训练权重的重要性,实验在CeleBA-HQ上进行。

归一化技术从神经网络发展到今天,已经诞生了众多种类。批归一化,层归一化,组归一化等。这篇文章总结了系统的归一化技术。

批归一化|Batch Normalization

总所周知,神经网络都是通过梯度下降算法来优化训练,而更流行的一种方式是小批量随机梯度下降,介于one-sample随机梯度下降和full-samples随机梯度下降之间。然而,不合适的学习率和参数初始化有可能导致训练坍塌,BN正是为了缓解这一问题而设计。

在提出BN的论文中,作者认为BN有效减弱了内部协变量偏倚(ICS)问题——as the change in the distribution of network activations due to the change in network parameters during training。

具体来说,BN layer对每个Batch做归一化。假设一个Batch有n个样本,每个样本有d个特征,那么对于第k个特征归一化的结果:

\[ \begin{equation} \hat{x}^{(k)} = \frac{x^{(k)} - E[x^{(k)}]}{\sqrt{\mathrm{Var}[x^{(k)}]}} \end{equation} \]

其中E和Var分别表示在第k个特征上这个批量中所有的样本的均值和方差。

这其中也存在一些问题,比如新的Batch分布可能会破坏前面已经学习到的数据分布。所以BN还添加了两个可以学习的变量\(\beta\)\(\gamma\):

\[ \begin{equation} y^{(k)} = \gamma^{(k)}\hat{x}^{(k)} + \beta^{(k)} \end{equation} \]

\(\gamma^{(k)} = \sqrt{\mathrm{Var}[x^{(k)}]}\) 并且 \(\beta^{(k)} = E[x^{(k)}]\) 时,\(y^{(k)} = x^{(k)}\)

在训练时,BN layer每次使用的均值和方差是从每个batch上单独计算的,同时,BN layer会维护一个全局均值和方差,在推理时,BN layer使用全局均值与方差对每个样本做归一化。

类比:有一堆书,有 N 本,每本有 C 页,每页有 H 行,每行 W 个字符,而每个字符可以看成具有不同信息量或者复杂度的数值。BN 求均值时,相当于把这些书按页码一一对应地加起来(例如第1本书第36页,第2本书第36页......),再除以每个页码下的字符总数:N×H×W,因此可以把 BN 看成求“所有书对应页的平均字符信息量”的操作,求标准差时也是同理。

层归一化|Layer Normalization

批归一化要求每个Batch中的样本长度必须一致,然而在现实情况下,文本等序列数据并不是等长的1,尽管会在数据预处理时统一大小,但很明显填充位置并不代表真实数据的分布,这样的归一化是没有效果的,甚至损害模型的收敛。BN在面对RNN这种网络时显得束手无策;同时,BN在面对小批次数据时,如资源受限的环境中,性能也会大幅下降。

令人惊喜的是,LN完美的解决了这两个问题;BN对同一特征的同一通道所有批次内样本做归一化,而LN则对同一样本的同一通道的所有特征做归一化。本质上LN避开了在多个样本上进行归一化的操作,只在一个样本的不同通道上对所有特征做归一化。

假设某个样本在经过某一层之后a,特征数(神经元数量)为H,均值与方差:

\[ \begin{equation} \mu = \frac{1}{H} \sum_{i = 1}^{H} a_i \quad \sigma = \sqrt{\frac{1}{H} \sum_{i = 1}^{H} (a_i - \mu)^2} \end{equation} \]

归一化之后\(\hat{a}\):

\[ \hat{\mathbf{a}} \frac{\mathbf{a}-\mu}{\sqrt{\left(\sigma\right)^{2}+\epsilon}} \]

与BN类似,LN也加入了增益(gain)\(g\)和偏置(bias)\(b\),对\(\hat{a}\)进行二次缩放:

\[ \begin{equation} y = f(\mathbf{g} \odot \hat{\mathbf{a}}+\mathbf{b}) \end{equation} \]

类比:有一堆书,有 N 本,每本有 C 页,每页有 H 行,每行 W 个字符,而每个字符可以看成具有不同信息量或者复杂度的数值。LN 求均值时,相当于把每本书所有页的字符数相加(例如第1本书第36页,第2本书第40页......),再除以每本书的字符总数:C×H×W,因此可以把 LN 看成求“每本书的平均字符信息量”的操作,求标准差时也是同理。

实例归一化|Instance Normalization

IN最初用于图像的风格迁移,作者认为不同通道的均值与方差会影响最终生成图像的风格,因此将图像在每个通道上做归一化,并使用参考的风格图像的每个通道上的均值和方差做“去归一化”。

类比:有一堆书,有 N 本,每本有 C 页,每页有 H 行,每行 W 个字符,而每个字符可以看成具有不同信息量或者复杂度的数值。IN 求均值时,相当于把每本书每一页的字符数单独统计(例如第1本书第36页,第2本书第40页......),再除以每页的字符总数:H×W,因此可以把 IN 看成求“每本书每页的平均字符信息量”的操作,求标准差时也是同理。

组归一化|Group Normalization

介于LN和IN之间,将通道数分组,主要目的是为了节约显存。

类比:有一堆书,有 N 本,每本有 C 页,每页有 H 行,每行 W 个字符,而每个字符可以看成具有不同信息量或者复杂度的数值。GN 求均值时,相当于把每本书分成很多小册子,每个小册子有很多页,对每个小册子的字符数求和,再除以每个册子的字符总数:(C/G)×H×W,因此可以把 GN 看成求“每个册子的平均字符信息量”的操作,求标准差时也是同理。

总结与回顾

在众多归一化方法中,只有BN多维护一个全局均值和方差(注意均值与方差的全局统计量是通过指数滑动平均实现的),其他方法只维护可学习参数\(\gamma\)\(\beta\),在训练和推理中计算的统计量均值与方差用完就会丢弃。

学习笔记|大模型微调方法

最近在看一些模型微调的相关方法,阅读了不少论文,想做个小结,分享给需要的朋友。

直接微调

直接微调是指在不改变模型结构的基础上 ,加载预训练模型并在有限的数据集上微调训练,通常选取更小的学习率,通常的一些技术有全量微调,冻结部分参数等。

  • 全量微调

全量微调是指每次更新模型的全部参数,这种方式在已有预训练权重的基础上进行,训练和预训练一样,但是由于预训练模型已经学习到了一定的先验知识,所以会降低微调的训练成本,模型可以更快地收敛。

  • 部分微调(冻结)

部分微调是指冻结模型的一部分层的参数不参与梯度更新的过程,这种方式在某些情况下被证明具有比全量微调更大的优势。除了简单的冻结层之外,不同的论文提出了不同的冻结策略,比如只微调偏置bias,或者只对归一化层等进行调整。

适应性微调(Adapter)

Adapter最初由论文“Parameter-Efficient Transfer Learning for NLP”提出,策略是在Transformer的模型结构中在注意力层和前馈网络层之后加入两层全连接层和非线性激活函数,被称之为Adapter,第一层将d维向量映射到m维向量,第二层将m维向量映射回d维,以保持后续正确的残差连接,一个Adapter的总参数量为\((dm+m)+(md+d)\).如果你了解Transformer的实现,那么并不难知道Adapter加在哪里。

图1: Adapter

之后,借鉴Adapter的方法,论文“LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS”提出了LoRA低秩自适应微调技术。LoRA的策略是在Transformer的自主力层中,对q、v(注意力层中计算的三个向量分别是q、k、v)加入额外的类似于漏斗的全连接网络,与Adapter相似,第一个全连接层将维度缩小到r,第二个全连接层将维度恢复,之后将计算的q值与LORA相加。更精炼的概括是,LoRA在自注意力子层之间的q和v结果中添加了两个低秩矩阵进行微调。在研究中,LORA通常并不归类于Adapter中。

图2: Lora

提示微调(P-Tuning)

与P-Tuning相关的工作:

  • 论文 “The Power of Scale for Parameter-Efficient Prompt Tuning",Prompt Tuning的首创。

  • 论文 ”Li and Liang, "Prefix-Tuning: Optimizing Continuous Prompts for Generation",提出前缀调整Prefix-Tuning。

  • 论文“P-Tuning v2: Prompt Tuning Can Be Comparable to Fine-tuning Universally Across Scales and Tasks”在二者之上改进提出了P-Tuning v2。

以上是目前主流的三大类微调方法,每种类别下又有这各式各样不同的小设计,尽管很多方法最先出自NLP领域,但随着技术的不断发展,CV与NLP领域的交叉融合趋势不断演进,期待未来更多高效的微调方法的出现。

摘要

现有蒸馏方法展示了图像领域中一步生成的潜力,但它们仍然遭受严重的质量下降。论文提出了对抗后训练(Adversarial Post-Training, APT)来进行一步视频生成。为了提高训练稳定性和质量,论文改进了模型架构和训练过程,以及一种R1近似正则化目标。

GAN回顾

生成对抗网络(Generative Adversarial Networks)是生成模型的一种,由生成器(Generator)和判别器(Discriminator)两个模块组成。生成器的训练目标是最小化生成数据分布和真实数据分布的差异;而判别器的目标是最大化生成数据分布和真实数据分布之间的差异(损失),在实际设计中,一般会取负数来实现最小化的优化目标。

在传统的 GAN 中,判别器使用交叉熵损失来区分真实样本和生成样本,这种方法容易导致训练不稳定,如梯度消失或梯度爆炸。

GAN的优化目标:

\[ \begin{equation} \min_G \max_D \mathbb{E}_{x \sim p_{\textup{data}}(x)} [\log D(x)] + \mathbb{E}_{z \sim p_z(z)} [\log (1 - D(G(z)))] \end{equation} \]

我们希望判别器对于真实样本的输出尽可能接近1,即log 1 = 0; 同时希望判别器对于生成样本的输出尽可能接近0,即log (1-D(G(z))) = 0。

这符合交叉熵损失的逻辑,假设真实样本为1,生成样本为0,判别器的目标就是尽可能预测生成样本的概率为0,,预测真实样本的概率为1。

\[ \begin{equation} {\cal L}_D = - \mathbb{E}_{x \sim p_{\rm data}} [\log D(x)] - \mathbb{E}_{z \sim p_z} [\log (1-D(G(z)))] \end{equation} \]

这里期望就是求和求平均,等价于优化目标,最小化负对数似然即是最大化对数似然。

对于生成器,优化目标就是欺骗判别器,使得log D(G(z))尽可能接近1,即最大化对数似然,但一般简便起见,优化为以下形式:

\[ \begin{equation} {\cal L}_G = - \mathbb{E}_{z\sim p_z}[\log D(G(z))] \end{equation} \]

交叉熵损失在GAN中存在梯度消失/爆炸的问题:

  • 梯度消失: 当判别器太强时,生成器面对 \(D(G(z)) \approx 0\),此时 \(\log(1 - D(G(z))) \approx 0\),梯度趋近于 0。

  • 梯度爆炸: 当判别器太弱时,( D(x) ),梯度可能过大。

WGAN通过Wasserstein距离来替代交叉熵,同时引入Lipschitz条件限制了函数的变化率,进而有效解决了梯度消失/爆炸问题,保证了函数在定义域内任意两点之间的输出值的距离不会超过输入值距离的K倍。这意味着函数不会有过于剧烈的变化,是一种相对 “平滑” 的函数。而通常K值为1。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
Class Discriminator_wassertein(nn.Module):

def __init__(self, **kwargs):

super(Discriminator_wassertein, self).__init__()

self.in_size = kwargs['in_size'] # Dim of the random variable to model (PV, wind power, etc)
self.cond_in = kwargs['cond_in'] # Dim of context (weather forecasts, etc)
self.latent_s = kwargs['latent_s'] # Dim of the latent space
self.lambda_gp = kwargs['lambda_gp']

# Set GPU if available
if kwargs['gpu']:
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
else:
self.device = 'cpu'

l_dis_net = [self.in_size + self.cond_in] + [kwargs['gen_w']] * kwargs['gen_l'] + [1]

# Build the discriminator
alpha = 0.01
self.dis_net = []
for l1, l2 in zip(l_dis_net[:-1], l_dis_net[1:]):
self.dis_net += [nn.Linear(l1, l2), nn.LeakyReLU(alpha)]
self.dis_net.pop() # The last activation function is a ReLU to return a positive number
self.dis_net.append(nn.ReLU())
self.dis = nn.Sequential(*self.dis_net)

def loss(self, generated_samples: torch.Tensor, true_samples: torch.Tensor, context: torch.Tensor):

# Discriminator's answers to generated and true samples
D_true = self.dis(torch.cat((true_samples, context), dim=1))
D_generated = self.dis(torch.cat((generated_samples, context), dim=1))
# Compute Discriminator's loss with a gradient penalty to force Lipschitz condition
gp = self.grad_pen(real=true_samples, samples=generated_samples, context=context)
loss = -(torch.mean(D_true) - torch.mean(D_generated)) + self.lambda_gp * gp

return loss


def forward(self, input: torch.Tensor, context: torch.Tensor):

pred = self.dis(torch.cat((input, context), dim=1))

return pred

def grad_pen(self, real: torch.tensor, samples: torch.tensor, context: torch.Tensor):

# Interpolated sample
bs, sample_size = real.shape[0], real.shape[1]
epsilon = torch.rand((bs, sample_size), device=self.device)
interpolated_sample = real * epsilon + samples * (1 - epsilon)
# Compute critic scores
mixed_score = self.dis(torch.cat((interpolated_sample, context), dim=1))
# Gradient of the mixed_score with respect with the interpolated_sample
gradient = torch.autograd.grad(inputs=interpolated_sample,
outputs=mixed_score,
grad_outputs=torch.ones_like(mixed_score),
create_graph=True, retain_graph=True)[0]

gradient = gradient.view(gradient.shape[0], -1)
gradient_norm = gradient.norm(2, dim=1)
gradient_pen = torch.mean((gradient_norm - 1) ** 2)

return gradient_pen
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
class Generator_linear(nn.Module):

def __init__(self, **kwargs):

super(Generator_linear, self).__init__()
self.in_size = kwargs['in_size'] # Dim of the random variable to model (PV, wind power, etc)
self.cond_in = kwargs['cond_in'] # Dim of context (weather forecasts, etc)
self.latent_s = kwargs['latent_s'] # Dim of the latent space

# Set GPU if available
if kwargs['gpu']:
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
else:
self.device = 'cpu'

l_gen_net = [self.latent_s + self.cond_in] + [kwargs['gen_w']] * kwargs['gen_l'] + [self.in_size]

# Build the generator
self.gen_net = []
for l1, l2 in zip(l_gen_net[:-1], l_gen_net[1:]):
self.gen_net += [nn.Linear(l1, l2), nn.ReLU()]
self.gen_net.pop() # Regression problem, no activation function at the last layer
self.gen = nn.Sequential(*self.gen_net)

def forward(self, noise: torch.Tensor, context: torch.Tensor):

pred = self.gen(torch.cat((noise, context), dim=1))

return pred

def sample(self, n_s=1, x_cond:np.array=None):

# Generate samples from a multivariate Gaussian
z = torch.randn(n_s, self.latent_s).to(self.device)
context = torch.tensor(np.tile(x_cond, n_s).reshape(n_s, self.cond_in)).to(self.device).float()
scenarios = self.gen(torch.cat((z, context), dim=1)).view(n_s, -1).cpu().detach().numpy()

return scenarios

方法概括

论文采用MMDiT架构,使用流匹配采样方法。

生成器通过加载预训练模型并进行确定性蒸馏,使用均方误差损失,得到一个蒸馏模型:

\[ \hat{v} = \hat{G}(z,c,T) \hat{x} = z - \hat{v} \]

尽管模型的生成结果\(\hat{x}\)是模糊的,但很好避免了直接使用预训练模型导致的模式崩溃,是一种有效的初始化。最后,生成器可以等价为:

\[ G(z,c) := z - \hat{G}(z,c,T) \]

对于鉴别器,使用同样的模型结构,但在第16,26,36层插入交叉注意力层和MLP层,提取特征,并通过拼接、归一化和全连接层得到logits用于对抗目标。结构如图所示。

R1损失需要计算高阶梯度,而现有的加速方法如deepspeed并不支持,论文提出了一种近似R1损失的方案。

0%