概述

生成模型总览(原图来自lilianweng.blog)

Diffusion属于生成模型的一种,相比较于GAN等其他生成模型,Diffusion模型最大的不同之处就在于其latent code是和原输入图相同尺寸的。Diffusion模型其实也可以看成是一个隐变量模型,并且与VAE,GAN的单隐变量不同,其可以看成存在多个隐变量(即加噪过程中的每个加噪结果都可以看成一个隐变量)。Diffusion模型总体包括前向加噪逆向去噪两个过程:

  • 前向过程-加噪扩散:对给定的真实图像不断添加高斯噪声,经过中间状态最终变成纯高斯噪声
  • 逆向过程-去噪生成:从完全的纯噪声不断去噪,经过中间状态最终变成其对应的真实图像

上面两个过程示意图可以表示如下:

Diffusion的前向和逆向过程(原图来自Ho et al.2020)

图中的分别表示前向过程和逆向过程中具体某一步的采样转换过程,这也是后面要重点讨论的两个内容。

前向过程-加噪扩散

前向过程就是对给定的真实图像不断添加高斯噪声,经过中间状态后最终变成纯的高斯噪声,我们可以直接使用高斯噪声用于每一步的加噪并规定每一步的添加的高斯噪声的噪声方差为,而均值则和前一步的结果有关,具体设置为,注意以上的参数和过程都是确定的,还没牵扯到任何神经网络模型或者可学习的参数,那么就有以下两个推导:

  • 前向单步:前向过程中间的某一步()可表示为条件概率: 而且其实可以直观看到整个前向的过程是一个马尔科夫的过程:即当前的时刻的结果()只与其前一个结果()有关而与起始状态()无关,换句话说就是:在仅仅知道前一个状态的条件下就可以得到了,无需其他条件。所以根据马尔科夫链的性质我们也可以得到:这个小细节后面会用到。此外这种预先定义好的方差设计规则我们也叫做Variance shedule,比如DDPM中就采用一个线性变换Variance shedule:

    betas = torch.linspace(start=0.0001, end=0.02, steps=1000)

    而至于这里方差为什么设置在0~1之间,均值系数又为什么设置成,后面会进一步讨论。

  • 总体前向:整个前向过程()可以表达为条件概率下的联合概率:

现在我们思考一下:是通过公式(1)从原始数据一步步推理得到的,那么有没有一个公式能够从直接得到呢?答案是有的!我们在公式(1)的基础上再利用重参数技巧可以从直接推导得到第步的加噪结果,这对于后续Diffusion的理解和推导也具有比较重要的作用,为了方便表达我们再令且令: 【一句题外话:公式(3)从另一个角度看其实可以发现是原始数据和高斯噪声的线性插值,插值系数分别为,满足平方和为1,有的论文也称这两者分别为Signal_rate和Noise_rate】

上述的重参数化不仅使得模型可微,还使得前向过程对的计算并行化,即计算可以跳过、...、直接得到,上面公式(3)的推导结果再根据重参数化反写成如下的概率分布形式:

其中是可以预先计算出来的并且该过程是一个无学习参数的过程,至此我们就完成了"图像噪声"的前向过程

逆向过程-去噪生成

逆向过程就是从完全的纯噪声不断去噪,经过中间状态最终变成其原来对应的真实图像:

回顾一下,前向过程中我们添加的噪声都是已知的,逆向过程我们的总体目标就是希望通过网络预测这些添加的噪声然后一步步去噪声。和前向过程类似,逆向过程我们也写成两种形式:

  • 单步逆向:中间某个逆向过程(从)可表示为条件概率:

  • 总体逆向:整个逆向过程的(从)可以表达为联合概率:

【其实上面的等价于,只不过由前面说说在T足够大时,就接近纯高斯噪声了,那么其实也就与模型无关了,所以可以去掉下标直接变成

其实如果我们知道其真实的逆过程,那么还是能够很容易根据该过程进行逆向去噪,但是问题就在于该真实逆过程依赖于全体所有数据集,所以上面我们才期望通过引入一个额外的采样过程来逼近该过程。公式(5)和公式(6) 的逆向过程基本和前向过程是相对应的,而且采样器也默认是一个高斯分布,前向和逆向两个过程的除了输入输出的区别,其他重要区别则在于逆向过程的采样器是网络学出来的,该网络主要用来预测逆向过程中每一步的用于计算去噪使用的高斯噪声的均值和高斯方差,这两个网络网络均接受当前步骤下尚未去噪的和时间戳为输入。

训练过程

损失函数形式

根据逆向过程中的阐述,Diffusion训练的主要目标就是计算用于预测噪声均值和方差的网络。和其他生成模型一样,优化模型我们可以在真实数据的分布下最大化模型预测分布的对数似然,等价于最小化在下的负对数似然【其实就是最小化的交叉熵】: 对上面优化目标求变分下限(VLB)可以得到:

题外话:这里还可以利用KL散度的非负性配合Fubini定理也可以得到类似的结果: 【这里有个细节要额外解释一下,原本变分下限VLB的定义是指的下限,即,上面公式(8)中的其实是在将VLB取了个负号然后外面又套了一层变成用于最小化的目标损失项】

这里再进一步对进行推导:

【上面最后一步推理其实还是没有特别搞懂,期望的下标不是吗?答:这里暂时没完全搞懂,最后一步我就直接按照lilian.blog里面贴过来了,也是和DDPM论文中完全一致的】,有了解的欢迎在评论区补充~

整理上面的公式可以得到我们要优化的损失函数,其可以写成下面的形式: 其中每一项对应于公式(9)中的每一项有:

损失函数拆解-L_0和L_T

损失函数中的可以用估计的来构建一个离散化的decoder来计算(见DDPM论文的3.3部分)。而中的噪声分布和先验分布都服从标准高斯分布,计算这个KL散度没有训练参数,近似为0。

损失函数拆解-L_{t-1}

重点讨论损函数中的其实就是计算两个分布的KL散度,其中由公式(5)已经知道是一个高斯分布,现在我们来推理的形式:

这时候我们发现上面存在函数的,我们再根据高斯分布的函数表达形式: 联想到是不是可以认为就是一个高斯分布,而该高斯分布是在条件下的,所以其均值和方差也与有关,我们将其分别定义为,所以我们就得到了的表达式: 现在我们就来求它的均值和方差,然后我们回过头看一下公式(11)和公式(12)中二次项(红色部分)、一次项(绿色部分),然后通过配平方可以比较容易得到的均值和方差【这里对可以稍微化简一下表达为】:

通过上面我们其实可以看到方差是一个和无关的常数项。

上面推了这么多公式,可能会有点忘记最初的目标,到此我们回过头再捋一下思路:Diffsion训练的主要目标就是计算用于预测噪声均值和方差的网络,然后具体通过最小化的交叉熵来优化模型,然后通过计算该交叉熵的变分下限得到了优化损失,计算得到 该损失在抛开两个常数项后其实主要就是由多个KL散度组成,而其中由公式(5)可知其是一个高斯分布且通过公式(11)的推导知道其也是一个高斯分布,那么其实我们整个Diffusion模型训练的损失就是通过KL散度来拉近若干个高斯分布pair对:【一口气讲完,呼~顺畅了】

【其实换个角度来说,就可以看成是前向过程的后验概率分布,即可以理解为前向过程的真实逆向过程(当然这个过程我们无法显示建模),上面公式(15)也可以看成是我们希望有一个网络能够去预测导致真实的去噪过程】。

计算高斯分布对的KL散度

在计算公式(15)中的高斯分布pair对的KL散度前,DDPM论文中做了进一步的简化,即采用固定的方差,即令,而这里的可以直接设定为。(这其实是两个极端,分别是上限和下限,也可以采用可训练的方差,见论文Improved Denoising Diffusion Probabilistic ModelsAnalytic-DPM: )。这里直接假定

根据上面固定方差的前提,那么我们就可以进一步简化公式(15): 至此再根据多元高斯分布的KL散度可以计算得到:

【其中第三行用到了对角矩阵的逆等于该对角矩阵对角线上每个元素取倒数】

由公式(17)我们其实可以发现了!!在固定方差的情况下,原来公式(5)中我们需要得到的网络就是希望其能够尽可能地接近那么其实我们就不妨就直接令两者相等。

到这里再捋一次思路:通过上面一系列式子的推导可以看到我们需要优化的总体损失函数在抛开近乎常数项后,最小化的目标其实就是希望预测的噪声均值和前向过程中对应步骤真实使用的噪声均值是一致的,那么我们在观察公式(14-3)中的形式后,我们可以直接先定义具有相同的公式形式,并且额外引入一个噪声预测网络希望该网络输出的噪声能够尽可能接近公式(14-3)的噪声(这里再稍微回顾下:根据一开始的公式(3)知道其实是从直接计算使用的噪声)这样就能保证尽可能地相等,于是可以得到: 再将公式(14-3)和公式(18)带入公式(17)得到:

从上面其实就可以看到我们在固定方差预测后,我们的目标就由公式(17)的预测噪声均值一致性变成了预测噪声一致性。将公式(9)的期望移动到后,则DDPM对上式进行了进一步简化(即去掉了权重系数)后最终的我们Diffusion的模型的优化目标就变成了:

从DDPM的对比实验结果来看,预测噪声也确实比预测均值的效果要好,并且采用简化版本的优化目标比原VLB目标效果要好。根据公式(3)中的关系,其实可以在公式(20)中进一步将所有的换成,那么模型就又从预测噪声一致变成了预测一致。但是在DDPM中也说明了预测一致的效果会相对差一点。

此外我们回头看一下公式(5),其中的均值预测网络由公式(18)可以获得,而方差则通过上面的固定化,那么我们就可以从得到的推导结果(相当于完成了一步去噪): 不断迭代公式(21)就完成了一步步"噪声图像"的去噪过程。

总结

根据公式(14)总结一下Diffusion的训练过程:

  • 随机选择一个训练样本,从的前向过程随机选择一个时间戳t,从高斯噪声中采样随机噪声
  • 计算第步加噪的结果
  • 和时间戳t送入到噪声预测网络得到预测噪声
  • 根据预测的噪声和真实噪声计算L2/L1损失并进行梯度下降

网络训练完后,可以进行逆向过程进行去噪生成图像:

  • 从高斯噪声中采样随机噪声
  • 根据公式(21)可以进行逆向过程从预测得到得到去噪后的
  • t从T到1不断重复上面过程,直至最后完成新样本的生成
DDPM训练和采样过程(原图来自Ho et al.2020)

训练过程示意图可以表示为下面这样:

DDPM训练过程可视化(原图来自zhihu.answer)

思考和QA环节

以下思考和回答都是我在推导DDPM公式的时候突然想到的,回答也是基于自己的理解进行解释,如有问题,欢迎指出~

  • Q: 公式(1)中均值系数直接设置为,为什么要和方差系数扯上关系?:
    A: 通过公式(3)或者公式(4)可以看到在设定了均值系数为在0~1范围后,当前向过程足够久的时候,即时,均值,方差,这时候就保证的分布收敛于一个标准的高斯分布。同时我也推理了一下如果就是普通的均值,那么在T=3时候有: 可以看到均值系数是连乘的形式,这只需要每一项就能保证连乘的结果趋近于0即均值系数为0了,对于方差部分在均值系数设置为是一个可以保证时方差收敛于1的解,至于还有没有其他解法:cry:本人暂时有限,也欢迎大家评论补充。

  • Q:公式(20)直接将每个KL散度的系数去掉了,会不会有什么副作用?
    A:由于去掉了不同的权重系数,所以这个简化的目标其实是对原本损失中进行了reweight。对于这种reweight,先将系数随着时间的变化函数画出来如下:其中按照DDPM论文中取在范围均匀采样1000次(如下图所示),从下图可以看到虽然不是严格单调递减,但是总体趋势是:随着t的增加,的权重系数越来越小,所以DDPM论文中的说法是:reweight相比较而言降低了在较小时候的损失权重而增加了在较大时候的损失权重(因为全部都变成了1),而在较小的时候(即已经接近原图时),这时候去噪使用的噪声已经很少了,所以降低他们的权重对于整体的网络训练是有好处的,而在较大的时候(即降噪初期),降噪任务相对较难,reweight也能够使网络更加专注于t较大时候的去噪。

    L_t权重系数变化过程
  • Q:根据公式 (3)知道是从直接计算使用的噪声,而根据公式(5)知道其是在从预测使用的噪声,虽然都是服从标准高斯分布,但是从数值上看应该是两个不同的噪声吧?两个能直接计算L2损失?
    A:表扬下自己,感觉思考还是比较深入。我的理解如下: 也确实是从直接计算使用的噪声,但是回顾一下的定义,其就是为了单纯拟合的并没有什么有实际含义的物理定义,如果非要有什么明确的物理含义,那么可以认为的目的就是去预测从计算所使用的噪声

  • Q:既然就是去预测从计算所使用的噪声,那么这里为什么这里能够直接在的时候使用这个噪声进行去噪?
    A:我想的是这里只是用到了这个变量,而不仅仅就是简单地用减去去噪得到了,其实仔细想想,是为了计算用到的噪声的均值-即公式(14-3)的,通过这个这个噪声均值和方差再去得到的,至于为什么能通过预测的的噪声来计算的噪声均值,那就是公式(18)做的事情。

  • Q:有一个可能比较蠢但是对于理解含隐变量生成模型比较有帮助,就是如果我们的损失函数是为了保证网络预测的噪声/均值和原本前向过程使用的噪声/均值完全一致,而前向的噪声是我们自己指定的,肯定是已知的,那么在反向生成的过程中直接就用原始前向过程中实际用到的噪声不就能去噪了吗?
    A:首先肯定的是:直接用前向的噪声去噪,肯定是能恢复得到原始输入图像的(其实就是加一减一等于自己的过程),但是问题就在这里了,这时候我们也就只能去噪得到训练集中的原始输入图像了,无法生成其他图像,因为我们想生成其他不存在的图像,但是现在我们拿到手的只有最原始的噪声啊,我们不知道他在前向过程中是怎么加噪得到的,所以我们这才需要通过额外训练好的网络来"想象"它被添加了哪些噪声,然后进行去噪。其实想一下VAE的过程就知道了,VAE的目标其实就是重构原图,其将一个原图映射到一个高斯分布的隐空间,然后再从隐空间回到图像空间,看起来也是一个加一减一的相反过程,但是其能保证在训练完成后我们在隐空间随便采样(尤其是对于其原来没见过的样本)时,都能够用原本VAE的能力生成对应的图像。所以相当于来说VAE是学到了从隐空间映射到图像空间的能力,而Diffusion是学到了从采样噪声中预测噪声的能力来去噪生成图像(采样噪声-预测噪声=生成图像),两者的相似性在于VAE中的encoder对应Diffusion的前向过程,VAE中的decoder对应Diffusion中的逆向过程。

✅2023.8.25补充:

  • Q:DDPM代表的Diffusion模型为什么要迭代优化,不能一次性添加噪声一次性预测噪声吗?

    DDPM的加噪和去噪都是一个马尔科夫链的过程,加噪是在上一步的结果上继续添加噪声,并且在上述公式(4)也表明了只有在添加噪声次数足够多(T足够大)的情况下,最终的加噪结果才是符合高斯分布的。那么问题又来到“我们为什么要将噪声变成纯高斯噪声?我们一次性在原图上加大幅度的噪声不就让原图变得噪乱不堪了吗?”这没错,但是你知道这时候的噪声图像是什么分布吗?如果不知道的话,我们怎样在训练完模型后采样这个原始噪声进行去噪生成呢?没办法了啊!所以为了好采样,我们要求噪声符合高斯分布,而让一张原始图变成高斯分布,又需要通过马尔科夫链的形式将其一步步变成纯高斯噪声。

    其实一步步添加小噪声的另一个好处是降低逆向过程中噪声预测的难度,提升生成效果!

Reference