【PAPER-introduction】DDPM

Personal study notes of work DDPM.

DDPM逐步学习

前向过程

x0->xt

  • x0真实图像
  • xt最终噪声
  • q(xt|xt-1)以前一个状态为均值的高斯分布,从这里采xt
  • 上面一行的分布方差为beta t (t属于1-T)
  • T步数、beta 1和beta t是需要设置的
betas = torch.linspace(start=0.0001, end=0.02,steps=1000)
  • linspace()生成随机数
    但是其实xt并不是直接通过q(xt|xt-1)得到的,还要经过重参数化
    重参数化介绍TODO
betas = torch.linspace(start=0.0001, end=0.02,steps=1000)
noise = torch.randn_like(x_0)
xt = sqrt(1-betas[t] * xt-1 + sqrt(betas[t])*noise)  # 均值和方差开根乘噪声

torch.randn_like()返回与输入大小相同的张量,填满了均值为0方差为1的随机数

前向过程中可以取到任意一个xt,即从这个分布q(xt|x0)采样

# 初始设置
betas = torch.linspace(start=0.0001, end=0.02,steps=1000)
alphas = 1-betas
alphas_cum = torch.cumprod(alphas, 0) # 累乘得到alphat均值
alphas_cum_s = torch.sqrt(alphas_cum)  # 根号alphat
alphas_cum_sm = torch.sqrt(1-alphas_cum)  # 1-alphat

# 重参数化得到xt
noise = torch.randn_like(x_0)
xt = alphas_cum_s[t] * x0 + alphas_cum_sm * noise)  # 均值和方差开根乘噪声

后向过程

反向的生成q(xt-1|xt,x0)是前向过程q(xt|xt-1)的后验概率分布
可以根据贝叶斯公式推导得到,并且它也是一个高斯分布,并且方差是个常量

# 伪代码
betas = torch.linspace(start=0.0001, end=0.02,steps=1000)
alphas = 1-betas
alphas_cum = torch.cumprod(alphas, 0) # 累乘得到alphat均值
alphas_cum_prev = torch.cat((torch.tensor([1.0]),alphas_cum[:-1]),0)
# torch.cat将两个矩阵连接起来
posterior_variance = betas * (1 - alphas_cum_prev) / (1 - alphas_cum)

训练目标

生成的时候我们并没有x0,那么我们要如何得到q(xt-1|xt,x0)呢?
这时候我们构造一个分布p(xt-1|xt)
它的方差和q的一致
均值则更改,将x0改为x0(xt,t),一个神经网络Unet,输入噪声图像xt和时间步t
通过一个损失函数缩小q和p之间的差距

最后模型无法直接通过p来生成,因为变量太多
但是,我们根据前向过程知道xt可以由x0获得,因此变换一下,就可以得到有xt得x0的式子
代入到q(xt-1|xt,x0)的均值式子我们可以得到q(xt-1|xt,x0)的均值只与xt和前向过程中t时刻加的噪声有关
所以我们修改p(xt-1|xt)的均值,将模型修改为根据xt和t预测t时刻添加的噪声,就得到了一个常量比较多的目标函数

训练伪代码如下

betas = torch.linspace(start=0.0001, end=0.02,steps=1000)
alphas = 1-betas
alphas_cum = torch.cumprod(alphas, 0) # 累乘得到alphat均值
alphas_cum_s = torch.sqrt(alphas_cum)  # 根号alphat
alphas_cum_sm = torch.sqrt(1-alphas_cum)  # 1-alphat

def diffusion_loss(model, x0, t, noise):
    # 根据公式计算xt
    xt = alphas_cum_s[t] * x0 + alphas_cum_sm[t] * noise
    # Unet预测噪声
    predicted_noise = model(xt, t)
    # 计算loss
    return mse_loss(predicted_noise, noise)
   
for i in len(data_loader):
    # 从数据集里读取一个batch,是真实图片
    x0 = next(data_loader)
    # 采样时间步
    t = torch.randint(0,1000,(batch_size,))
    # 生成高斯噪声
    noise = torch.randn_like(x_0)
    loss = diffusion_loss(model, x0, t, noise)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

采样过程

真实的推理过程必须从T步开始向前逐步生成图片
先生成一个噪声图片,每个时间步t将xt传入模型预测噪声,然后采样一个噪声
由重参数化技巧和后验概率的均值方差,可以计算xt-1,最后直到T=1

改进

  • Improved DDPM对beta进行了改进,使得前向过程加噪声的过程更合理了一点
  • Improved DDPM还指出可以通过直接设置更小的时间步s来减小采样的时间,具体是把s与原本的t对应

参考

[1]微信公众号 GiantPandaCV 一文弄懂Diffusion Model