Diffusion Adversarial Post-Training for One-Step Video Generation

Abstract

Existing distillation methods have demonstrated the potential of one-step generation in the image domain, but they still suffer from severe quality degradation. The paper proposes Adversarial Post-Training (APT) for one-step video generation. To improve training stability and quality, the paper improves the model architecture and training procedure, and introduces an approximated R1 regularization objective.

GAN Review

Generative Adversarial Networks (GANs) are a class of generative models consisting of two modules: a generator and a discriminator. The generator is trained to minimize the discrepancy between the generated data distribution and the real data distribution, while the discriminator is trained to maximize that discrepancy (the loss); in practice, the discriminator objective is usually negated so that it, too, can be written as a minimization.

In a traditional GAN, the discriminator uses a cross-entropy loss to distinguish real samples from generated ones. This approach easily leads to unstable training, such as vanishing or exploding gradients.

The GAN objective:

\[ \begin{equation} \min_G \max_D \mathbb{E}_{x \sim p_{\textup{data}}(x)} [\log D(x)] + \mathbb{E}_{z \sim p_z(z)} [\log (1 - D(G(z)))] \end{equation} \]

We want the discriminator's output on real samples to be as close to 1 as possible, so that \(\log D(x) = \log 1 = 0\); at the same time, we want its output on generated samples to be as close to 0 as possible, so that \(\log(1 - D(G(z))) = \log 1 = 0\).

This matches the logic of the cross-entropy loss: labeling real samples as 1 and generated samples as 0, the discriminator's goal is to predict a probability close to 0 for generated samples and close to 1 for real ones.

\[ \begin{equation} {\cal L}_D = - \mathbb{E}_{x \sim p_{\rm data}} [\log D(x)] - \mathbb{E}_{z \sim p_z} [\log (1-D(G(z)))] \end{equation} \]

Here the expectation is just an average over samples, so this loss is equivalent to the objective above: minimizing the negative log-likelihood is the same as maximizing the log-likelihood.

For the generator, the objective is to fool the discriminator, i.e., to push \(D(G(z))\) as close to 1 as possible (maximizing \(\log D(G(z))\)). For convenience, the objective is usually optimized in the following non-saturating form:

\[ \begin{equation} {\cal L}_G = - \mathbb{E}_{z\sim p_z}[\log D(G(z))] \end{equation} \]
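The two losses above map directly onto binary cross-entropy in PyTorch; a minimal sketch (the helper names are mine, and the inputs are the discriminator's pre-sigmoid logits):

```python
import torch
import torch.nn as nn

bce = nn.BCEWithLogitsLoss()  # combines sigmoid + cross-entropy

def discriminator_loss(d_real_logits, d_fake_logits):
    # L_D = -E[log D(x)] - E[log(1 - D(G(z)))]
    loss_real = bce(d_real_logits, torch.ones_like(d_real_logits))   # real -> 1
    loss_fake = bce(d_fake_logits, torch.zeros_like(d_fake_logits))  # fake -> 0
    return loss_real + loss_fake

def generator_loss(d_fake_logits):
    # Non-saturating form: L_G = -E[log D(G(z))], i.e. label fakes as real
    return bce(d_fake_logits, torch.ones_like(d_fake_logits))
```

When the discriminator outputs 0.5 everywhere (logit 0), both terms of the discriminator loss equal \(\log 2\), the value at the minimax saddle point.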

The cross-entropy loss causes vanishing/exploding gradient problems in GANs:

  • Vanishing gradients: when the discriminator is too strong, the generator faces \(D(G(z)) \approx 0\); the term \(\log(1 - D(G(z))) \approx 0\) then saturates, and its gradient with respect to the generator vanishes.

  • Exploding gradients: the log terms blow up near the boundary of \([0, 1]\); for example, with the non-saturating loss \(-\log D(G(z))\), the gradient scales like \(1/D(G(z))\) and can become very large when \(D(G(z)) \approx 0\).
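The saturation effect can be checked numerically. Writing the discriminator output as \(D = \sigma(f)\) for a logit \(f\), the gradient of \(\log(1 - \sigma(f))\) with respect to \(f\) is \(-\sigma(f)\), which vanishes exactly when the discriminator is confident; a small self-contained check (the numbers are illustrative):

```python
import torch

# A confident discriminator: D(G(z)) = sigmoid(-4.6) ≈ 0.01
logit = torch.tensor(-4.6, requires_grad=True)

# Saturating loss log(1 - D): gradient w.r.t. the logit is -sigmoid(logit)
torch.log(1 - torch.sigmoid(logit)).backward()
grad_saturating = logit.grad.item()       # ≈ -0.01, vanishes as D gets stronger

logit.grad = None
# Non-saturating loss -log D: gradient w.r.t. the logit is -(1 - sigmoid(logit))
(-torch.log(torch.sigmoid(logit))).backward()
grad_non_saturating = logit.grad.item()   # ≈ -0.99, stays informative
```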

WGAN replaces the cross-entropy loss with the Wasserstein distance and introduces a Lipschitz constraint that limits the critic's rate of change, which effectively resolves the vanishing/exploding gradient problems. The constraint guarantees that the distance between the function's outputs at any two points is at most K times the distance between the inputs, so the function cannot change too abruptly and is relatively "smooth". K is usually taken to be 1.
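In practice, the Lipschitz constraint is usually enforced with a gradient penalty on points interpolated between real and generated samples (WGAN-GP), as in the discriminator code below. The penalty term in isolation can be sketched as follows (the function name `gradient_penalty` is mine):

```python
import torch

def gradient_penalty(critic, real, fake):
    """WGAN-GP: push ||grad_x critic(x)||_2 toward 1 at interpolated points."""
    eps = torch.rand(real.shape[0], 1)  # per-sample mixing weight
    mixed = (eps * real + (1 - eps) * fake).requires_grad_(True)
    score = critic(mixed)
    grad = torch.autograd.grad(outputs=score, inputs=mixed,
                               grad_outputs=torch.ones_like(score),
                               create_graph=True)[0]
    grad_norm = grad.view(grad.shape[0], -1).norm(2, dim=1)
    return torch.mean((grad_norm - 1.0) ** 2)
```

For a linear critic \(f(x) = 3\,\mathbf{1}^\top x\) in two dimensions, the gradient norm is \(3\sqrt{2}\) everywhere, so the penalty is \((3\sqrt{2} - 1)^2\) regardless of the interpolation point.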

```python
import numpy as np
import torch
import torch.nn as nn


class Discriminator_wassertein(nn.Module):

    def __init__(self, **kwargs):
        super(Discriminator_wassertein, self).__init__()

        self.in_size = kwargs['in_size']    # Dim of the random variable to model (PV, wind power, etc.)
        self.cond_in = kwargs['cond_in']    # Dim of context (weather forecasts, etc.)
        self.latent_s = kwargs['latent_s']  # Dim of the latent space
        self.lambda_gp = kwargs['lambda_gp']

        # Set GPU if available
        if kwargs['gpu']:
            self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        else:
            self.device = 'cpu'

        l_dis_net = [self.in_size + self.cond_in] + [kwargs['gen_w']] * kwargs['gen_l'] + [1]

        # Build the discriminator
        alpha = 0.01
        self.dis_net = []
        for l1, l2 in zip(l_dis_net[:-1], l_dis_net[1:]):
            self.dis_net += [nn.Linear(l1, l2), nn.LeakyReLU(alpha)]
        self.dis_net.pop()  # The last activation function is a ReLU to return a positive number
        self.dis_net.append(nn.ReLU())
        self.dis = nn.Sequential(*self.dis_net)

    def loss(self, generated_samples: torch.Tensor, true_samples: torch.Tensor, context: torch.Tensor):
        # Discriminator's answers to generated and true samples
        D_true = self.dis(torch.cat((true_samples, context), dim=1))
        D_generated = self.dis(torch.cat((generated_samples, context), dim=1))
        # Compute the discriminator's loss with a gradient penalty to enforce the Lipschitz condition
        gp = self.grad_pen(real=true_samples, samples=generated_samples, context=context)
        loss = -(torch.mean(D_true) - torch.mean(D_generated)) + self.lambda_gp * gp
        return loss

    def forward(self, input: torch.Tensor, context: torch.Tensor):
        pred = self.dis(torch.cat((input, context), dim=1))
        return pred

    def grad_pen(self, real: torch.Tensor, samples: torch.Tensor, context: torch.Tensor):
        # Interpolated sample; requires_grad so that we can differentiate w.r.t. it
        bs, sample_size = real.shape[0], real.shape[1]
        epsilon = torch.rand((bs, sample_size), device=self.device)
        interpolated_sample = (real * epsilon + samples * (1 - epsilon)).requires_grad_(True)
        # Compute critic scores
        mixed_score = self.dis(torch.cat((interpolated_sample, context), dim=1))
        # Gradient of mixed_score with respect to interpolated_sample
        gradient = torch.autograd.grad(inputs=interpolated_sample,
                                       outputs=mixed_score,
                                       grad_outputs=torch.ones_like(mixed_score),
                                       create_graph=True, retain_graph=True)[0]
        gradient = gradient.view(gradient.shape[0], -1)
        gradient_norm = gradient.norm(2, dim=1)
        gradient_pen = torch.mean((gradient_norm - 1) ** 2)
        return gradient_pen
```
```python
import numpy as np
import torch
import torch.nn as nn


class Generator_linear(nn.Module):

    def __init__(self, **kwargs):
        super(Generator_linear, self).__init__()
        self.in_size = kwargs['in_size']    # Dim of the random variable to model (PV, wind power, etc.)
        self.cond_in = kwargs['cond_in']    # Dim of context (weather forecasts, etc.)
        self.latent_s = kwargs['latent_s']  # Dim of the latent space

        # Set GPU if available
        if kwargs['gpu']:
            self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        else:
            self.device = 'cpu'

        l_gen_net = [self.latent_s + self.cond_in] + [kwargs['gen_w']] * kwargs['gen_l'] + [self.in_size]

        # Build the generator
        self.gen_net = []
        for l1, l2 in zip(l_gen_net[:-1], l_gen_net[1:]):
            self.gen_net += [nn.Linear(l1, l2), nn.ReLU()]
        self.gen_net.pop()  # Regression problem, no activation function at the last layer
        self.gen = nn.Sequential(*self.gen_net)

    def forward(self, noise: torch.Tensor, context: torch.Tensor):
        pred = self.gen(torch.cat((noise, context), dim=1))
        return pred

    def sample(self, n_s=1, x_cond: np.ndarray = None):
        # Generate latents from a multivariate Gaussian, tiled with the conditioning context
        z = torch.randn(n_s, self.latent_s).to(self.device)
        context = torch.tensor(np.tile(x_cond, n_s).reshape(n_s, self.cond_in)).to(self.device).float()
        scenarios = self.gen(torch.cat((z, context), dim=1)).view(n_s, -1).cpu().detach().numpy()
        return scenarios
```

Method Overview

The paper adopts the MMDiT architecture and uses flow-matching sampling.

The generator is initialized by loading the pre-trained model and performing deterministic distillation with a mean-squared-error loss, yielding a distilled model:

\[ \begin{equation} \hat{v} = \hat{G}(z,c,T), \qquad \hat{x} = z - \hat{v} \end{equation} \]

Although the distilled model's outputs \(\hat{x}\) are blurry, this step avoids the mode collapse caused by using the pre-trained model directly and serves as an effective initialization. Finally, the generator can be written equivalently as:

\[ G(z,c) := z - \hat{G}(z,c,T) \]

For the discriminator, the same model architecture is used, but cross-attention and MLP layers are inserted at layers 16, 26, and 36 to extract features; the extracted features are concatenated, normalized, and passed through a fully connected layer to produce the logits for the adversarial objective. The structure is shown in the figure.

The R1 loss requires computing higher-order gradients, which existing acceleration frameworks such as DeepSpeed do not support, so the paper proposes a scheme that approximates the R1 loss.
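The idea is to replace the exact R1 term \(\|\nabla_x D(x)\|^2\), which needs double backpropagation, with a finite-difference-style penalty: perturb the real input with small Gaussian noise and penalize the change in the discriminator's output. A hedged sketch of this idea (the function and hyperparameter names `sigma` and `lambda_r1` are mine, not the paper's exact values):

```python
import torch

def approx_r1_loss(disc, x_real, c, sigma=0.01, lambda_r1=100.0):
    """Approximated R1: penalize the discriminator's sensitivity to a small
    Gaussian perturbation of the real input, avoiding higher-order gradients."""
    x_perturbed = x_real + sigma * torch.randn_like(x_real)
    d_real = disc(x_real, c)
    d_perturbed = disc(x_perturbed, c)
    return lambda_r1 * torch.mean((d_real - d_perturbed) ** 2)
```

A discriminator that is locally flat around real data incurs no penalty, which is exactly the behavior the original R1 regularizer encourages.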