【PAPER-introduction】DDPM
Personal study notes of work DDPM.
DDPM逐步学习
前向过程
x0->xt
- x0真实图像
- xt最终噪声
- q(xt|xt-1)以前一个状态为均值的高斯分布,从这里采xt
- 上面一行的分布方差为beta t (t属于1-T)
- T步数、beta 1和beta t是需要设置的
betas = torch.linspace(start=0.0001, end=0.02,steps=1000)
- linspace()生成随机数
但是其实xt并不是直接通过q(xt|xt-1)得到的,还要经过重参数化
重参数化介绍TODO
betas = torch.linspace(start=0.0001, end=0.02,steps=1000)
noise = torch.randn_like(x_0)
xt = sqrt(1-betas[t] * xt-1 + sqrt(betas[t])*noise) # 均值和方差开根乘噪声
torch.randn_like()返回与输入大小相同的张量,填满了均值为0方差为1的随机数
前向过程中可以取到任意一个xt,即从这个分布q(xt|x0)采样
# 初始设置
betas = torch.linspace(start=0.0001, end=0.02,steps=1000)
alphas = 1-betas
alphas_cum = torch.cumprod(alphas, 0) # 累乘得到alphat均值
alphas_cum_s = torch.sqrt(alphas_cum) # 根号alphat
alphas_cum_sm = torch.sqrt(1-alphas_cum) # 1-alphat
# 重参数化得到xt
noise = torch.randn_like(x_0)
xt = alphas_cum_s[t] * x0 + alphas_cum_sm * noise) # 均值和方差开根乘噪声
后向过程
反向的生成q(xt-1|xt,x0)是前向过程q(xt|xt-1)的后验概率分布
可以根据贝叶斯公式推导得到,并且它也是一个高斯分布,并且方差是个常量
# 伪代码
betas = torch.linspace(start=0.0001, end=0.02,steps=1000)
alphas = 1-betas
alphas_cum = torch.cumprod(alphas, 0) # 累乘得到alphat均值
alphas_cum_prev = torch.cat((torch.tensor([1.0]),alphas_cum[:-1]),0)
# torch.cat将两个矩阵连接起来
posterior_variance = betas * (1 - alphas_cum_prev) / (1 - alphas_cum)
训练目标
生成的时候我们并没有x0,那么我们要如何得到q(xt-1|xt,x0)呢?
这时候我们构造一个分布p(xt-1|xt)
它的方差和q的一致
均值则更改,将x0改为x0(xt,t),一个神经网络Unet,输入噪声图像xt和时间步t
通过一个损失函数缩小q和p之间的差距
最后模型无法直接通过p来生成,因为变量太多
但是,我们根据前向过程知道xt可以由x0获得,因此变换一下,就可以得到有xt得x0的式子
代入到q(xt-1|xt,x0)的均值式子我们可以得到q(xt-1|xt,x0)的均值只与xt和前向过程中t时刻加的噪声有关
所以我们修改p(xt-1|xt)的均值,将模型修改为根据xt和t预测t时刻添加的噪声,就得到了一个常量比较多的目标函数
训练伪代码如下
betas = torch.linspace(start=0.0001, end=0.02,steps=1000)
alphas = 1-betas
alphas_cum = torch.cumprod(alphas, 0) # 累乘得到alphat均值
alphas_cum_s = torch.sqrt(alphas_cum) # 根号alphat
alphas_cum_sm = torch.sqrt(1-alphas_cum) # 1-alphat
def diffusion_loss(model, x0, t, noise):
# 根据公式计算xt
xt = alphas_cum_s[t] * x0 + alphas_cum_sm[t] * noise
# Unet预测噪声
predicted_noise = model(xt, t)
# 计算loss
return mse_loss(predicted_noise, noise)
for i in len(data_loader):
# 从数据集里读取一个batch,是真实图片
x0 = next(data_loader)
# 采样时间步
t = torch.randint(0,1000,(batch_size,))
# 生成高斯噪声
noise = torch.randn_like(x_0)
loss = diffusion_loss(model, x0, t, noise)
optimizer.zero_grad()
loss.backward()
optimizer.step()
采样过程
真实的推理过程必须从T步开始向前逐步生成图片
先生成一个噪声图片,每个时间步t将xt传入模型预测噪声,然后采样一个噪声
由重参数化技巧和后验概率的均值方差,可以计算xt-1,最后直到T=1
改进
- Improved DDPM对beta进行了改进,使得前向过程加噪声的过程更合理了一点
- Improved DDPM还指出可以通过直接设置更小的时间步s来减小采样的时间,具体是把s与原本的t对应
参考
[1]微信公众号 GiantPandaCV 一文弄懂Diffusion Model