Diffusion学习笔记

3 minute read

Published:

扩散模型学习过程中一些文章的阅读笔记以及模型的相关内容的简单整理

一般框架SDE

扩散模型的前向过程是由随机微分方程(SDE)描述的线性扩散

dx=Ftxdt+Gtdω

其中,MN为输入数据的维度,FtRM×NGtRM×N分别为漂移系数和扩散系数,ω为标准的为维纳过程

目前对SDE的研究较为成熟,前向过程(1)对应的反向过程可由式(2)所示的SDE函数族表示:

dx=[Ftx1+λ22GtGtlogpt(x)]dt+λGtdω

其中,λ0,当λ=1时,即为式(3)对应的SDE反向过程

dx=[FtxdtGtGtlogpt(x)]+Gtdω

而当λ=0时,反向过程中的方差为0,SDE退化为概率流ODE。

在实际应用中,通过对网络训练得到logpt(x)的近似值sθ(xt,t),借助离散化的方式对式(2)实现数值求解,完成扩散模型的反向过程

拟合误差

概率流ODE

概率流ODE为确定性的微分方程,

dx=[Ftx12GtGtlogpt(x)]dt

Euler解法

ˆxtΔt=ˆxt[Ftxt+12GtGtLtϵθ(xt,t)]Δt

指数积分(EI)解法

xtΔt=etΔttFτdτxt+tΔtt12etΔtτFrdrGτGτLtϵθ(xτ,τ)dτ

DDIM

在DDIM中,作者提出,p(x1:t)分解为马尔可夫过程不是必须的,只要保证p(xt|x0)p(xt1|xt,x0)与DDPM是相同的就可以得到DDPM的等价模型。文中,作者将p(xt1|xt,x0)的分布定义为(7)

p(xt1|xt,x0)N(xt1;ˉαt1 x0+ˉβt1σ2txtˉαt x01ˉαt,σ2tI)

其中,

σt=ηˉβt1βtˉβt

η为可调节的参数,使用网络估计得到的

$\boldsymbol{\epsilon}{\theta}\left(\boldsymbol{x}{t}, t\right)$

替换

x0

得到

xt1=1αt(xt(ˉβtαtˉβt1σ2t)ϵθ(xt,t))+σtϵ

DPM-Solver

半线性公式

DPM-Solver提出,ODE方程是一个半线性的方程,线性项ftxt是可以准确的计算的,之前的采样算法忽略了这一点,从而导致对ODE方程的数值解法会产生较大的拟合误差,算法的加速性能不好。因此,在DPM-Solver中,作者将线性项与非线性项分开,对非线性项采用数值解法,从而减少了拟合误差。

引入参数$\lambda_{t} = log({\bar{\alpha}{t}} / {\sigma{t}})g^2(t)\lambda_t$的函数:

g2(t)=ˉα2tddt(σ2tˉα2t)=2σ2dλtdt

对式(6)进行参数替换,同时代入VP-SDE前向过程对应的参数,得到:

xt=ˉαtˉαsxsˉαtts(dλτdτ)στˉατϵθ(xτ,τ)dτ

考虑到λt为前向过程中信噪比的一半,是严格单调递减的,因此存在一个函数tλ()使得t=tλ(λ(t)),因此,对式(11)进行变量替代后,得到:

xt=ˉαtˉαsxsˉαtλtλseλϵθ(xλ,λ)dλ

(12)给出了ODE解法的新视角——只需要对指数积分项进行估计,从而避免了估计线性项带来的误差。

数值估计

在对非线性项的估计中,DPM-Solver对$\boldsymbol{\epsilon}{\theta}\left(\boldsymbol{x}{\lambda}, \lambda\right)(13)(14)$:

ϵθ(xλ,λ)=k1n=0(λλti)nn!ϵ(n)θ(xλti,λti)+O((λλti)k) xtiti1=αti1αtixtiαti1k1n=0ϵ(n)θ(xλti,λti)λti1λtieλ(λλti)nn!dλ+O((λλti)k+1)

其中,$\boldsymbol{\epsilon}{\theta}^{(n)}\left(\boldsymbol{x}{\lambda_{t_{i}}}, \lambda_{t_{i}}\right)\boldsymbol{\epsilon}{\theta}\left(\boldsymbol{x}{\lambda_{t_{i}}}, \lambda_{t_{i}}\right)nk\int e^{-\lambda} \frac{(\lambda - \lambda_{t_{i}})^{n}}{n!} \mathrm{d}\lambda\boldsymbol{\epsilon}{\theta}\left(\boldsymbol{x}{\lambda_{t_{i}}}, \lambda_{t_{i}}\right)$的导数,即可实现对非线性部分的估计,而对于其导数的估计已经在现有的文章中有较好的研究。

考虑到k比较大时需要引入过多的中间点来进行导数的估计,因此作者只使用了k=1,2,3,三种不同阶数的DPM-Solver。

与DDIM联系

DDIM是较早提出的确定性采样算法,但是一直没有较好的理论将其与ODE联系起来,将λti带入到一阶DPM-Solver(式(15))中可以得到DDIM对应的微分表达式(待引用DDIM)

xtiti1=αti1αtixtiαti1ϵθ(xλti,λti)(eλtieλti1)

因此,DDIM可以看作是DPM-Solver的一种特殊情况,由于充分利用了半线性的特点,因此DDIM相比于传统的Euler数值解法,具有更好的性能。

DEIS

Diffusion Exponential Integrator Sampler (DEIS)同样利用了ODE方程的半线性的性质,该方法与DPM-Solver最本质的区别在于对非线性项的估计中使用端点处的ϵθ(xti,ti)代替积分区间内的ϵθ(xt,t),同时构建r阶$\boldsymbol{P}{r}(t)(16)\epsilon{\theta}(x_{t}, t)$估计的估计误差:

Pr(t)=rj=0[kjtti+jti+jti+k]ϵθ(xti+j,ti+j)

因此,DEIS的采样过程为:

xtiti1=Ψ(ti1,ti)xti+rj=0ti1ti12Ψ(ti1,τ)GτGτLτ[kjτti+jti+jti+k]ϵθ(xti+j,ti+j)dτ

其中,

Ψ(ti1,ti)=eti1tiFτdτ

同时,DEIS还提出利用$\boldsymbol{y}{t} = \Psi(0, t)\boldsymbol{x}{t}$进行参数替换,消除ODE方程的非线性,从而使现有成熟的ODE数值解法具有更好的表现。

拟合误差

模型训练

方差估计

Analytic-DPM

Σ(xt)=Ex0p(x0|xt)[(x0ˉμ(xt))(x0ˉμ(xt))]=Ex0p(x0|xt)[((x0xtˉαt)+ˉβtˉαtϵθ(xt,t))((x0xtˉαt)+ˉβtˉαtϵθ(xt,t))]=Ex0p(x0|xt)[(x0xtˉαt)(x0xtˉαt)]ˉβtˉαtϵθ(xt,t)ϵθ(xt,t)=1ˉαtEx0p(x0|xt)[(xtˉαtx0)(xtˉαtx0)]ˉβtˉαtϵθ(xt,t)ϵθ(xt,t) Extp(xt)Ex0p(x0|xt)[(xtˉαtx0)(xtˉαtx0)]=Ex0p(x0)Extp(xt|x0)[(xtˉαtx0)(xtˉαtx0)] ˉσ2t=ˉβtˉαt(11dExtp(xt)[ϵθ(xt,t)2])

SN-DPM

Σ(xt)=Ex0p(x0|xt)[(x0ˉμ(xt))(x0ˉμ(xt))]=1ˉαtEx0p(x0|xt)[(xtˉαtx0)(xtˉαtx0)]ˉβtˉαtϵθ(xt,t)ϵθ(xt,t)=ˉβtˉαtEx0p(x0|xt)[ϵθ(xt,t)ϵθ(xt,t)]ˉβtˉαtϵθ(xt,t)ϵθ(xt,t)

NPR-DPM

Σ(xt)=Ex0p(x0|xt)[(x0ˉμ(xt))(x0ˉμ(xt))]=ˉβtˉαtEx0p(x0|xt)[(ϵtϵθ(xt,t))(ϵtϵθ(xt,t))]

乔列斯基(Cholesky)分解

##

Latent Diffusion Model(LDM)

LDM在原本的DDPM的基础上使用预训练的VAE将输入压缩到潜空间,模型被训练用来生成图像在潜空间的表示。

VAE编码

潜空间训练

Blurring Diffusion Model(BDM)

BDM利用DCT将模型定义在了频率空间,对图像在频率空间的表征进行扩散模型的训练,令 $\boldsymbol{u}{t} = \boldsymbol{V}^{\top}\boldsymbol{x}{t}\boldsymbol{u}{\boldsymbol{\epsilon},t} = \boldsymbol{V}^{\top}\boldsymbol{\epsilon}{t}\boldsymbol{V}^{\top}DCT(24)$所示:

ut=αtut+σtuϵ,t

同时,由于对噪声的估计在标准的像素空间表现更好,因此在使用网络去噪时,使用逆变换将频率空间内的图像表示转换到像素空间输入网络进行噪声预测,如式(25)所示:

L:=ϵθ(zt,t)ϵt2

其中,

$\boldsymbol{z}{t} = \boldsymbol{V}(\boldsymbol{\alpha}{t}\boldsymbol{u}{t} + \boldsymbol{\sigma}{t} \boldsymbol{u}_{\boldsymbol{\epsilon},t})$

V

表示DCT逆变化,在频率空间的采样过程与原DDPM保持相同。