扩散模型实战(五):采样过程
作者:weixin01 · 2024-12-09 · 阅读时间:4分钟
在扩散模型实战(四):从零构建扩散模型文章中已经介绍了在原始数据集MNIST中添加噪声以及基于基本的UNet网络训练扩散模型,模型已经可以进行预测,但是发现输入数据噪声量很大的时候预测的效果并不好,如下图所示:

那如何改进呢?
其实思路比较简单,就是按照预测的方向多迭代几次就可以,比如我们从完全的随机数开始按照上述思路进行扩散,下面是实现的代码:
# 采样策略:把采样过程拆解为5步,每次只前进一步
n_steps = 5
x = torch.rand(8, 1, 28, 28).to(device) # 从完全随机的值开始
step_history = [x.detach().cpu()]
pred_output_history = []
for i in range(n_steps):
with torch.no_grad(): # 在推理时不需要考虑张量的导数
pred = net(x) # 预测“去噪”后的图像
pred_output_history.append(pred.detach().cpu())
# 将模型的输出保存下来,以便后续绘图时使用
mix_factor = 1/(n_steps - i) # 设置朝着预测方向移动多少
x = x*(1-mix_factor) + pred*mix_factor # 移动过程
step_history.append(x.detach().cpu()) # 记录每一次移动,以便后续
# 绘图时使用
fig, axs = plt.subplots(n_steps, 2, figsize=(9, 4), sharex=True)
axs[0,0].set_title('x (model input)')
axs[0,1].set_title('model prediction')
for i in range(n_steps):
axs[i, 0].imshow(torchvision.utils.make_grid(step_history[i])
[0].clip(0, 1), cmap='Greys')
axs[i, 1].imshow(torchvision.utils.make_grid(pred_output_
history[i])[0].clip(0, 1), cmap='Greys')
我们执行5次迭代,观察一下模型预测的变化,输出结果如下图所示:

从上图可以看出,模型在第一步就已经输出了去噪的图片,只是往最终的目标前进了一小步,效果不佳,但是迭代5次以后,发现效果越来越好。如果迭代更多次数,效果如何呢?
# 将采样过程拆解成40步
n_steps = 40
x = torch.rand(64, 1, 28, 28).to(device)
for i in range(n_steps):
noise_amount = torch.ones((x.shape[0],)).to(device) * (1-(i/n_
steps))# 将噪声量从高到低移动
with torch.no_grad():
pred = net(x)
mix_factor = 1/(n_steps - i)
x = x*(1-mix_factor) + pred*mix_factor
fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8)
[0].clip(0, 1), cmap='Greys')

从上图可以看出,虽然在迭代多次以后,生成的图像越来越清晰,但是最终的效果仍然不是很好,我们可以尝试训练更长时间的扩散模型,并调整模型参数、学习率、优化器等。
文章转自微信公众号@ArronAI
热门推荐
一个账号试用1000+ API
助力AI无缝链接物理世界 · 无需多次注册
3000+提示词助力AI大模型
和专业工程师共享工作效率翻倍的秘密
热门API
- 1. AI文本生成
- 2. AI图片生成_文生图
- 3. AI图片生成_图生图
- 4. AI图像编辑
- 5. AI视频生成_文生视频
- 6. AI视频生成_图生视频
- 7. AI语音合成_文生语音
- 8. AI文本生成(中国)
最新文章
- GPT-4o API全攻略:多模态AI模型的功能解析与实战指南
- Python 使用 话费 API:轻松实现自动话费查询功能
- 构建现代RESTful API:C#中的关键标准和最佳实践
- 优化 ASP.NET Core Web API 性能方法
- 如何设计一个对外的安全接口?
- 2025 LangGraph AI 工作流引擎|可视化多 Agent 协作+节点扩展教程
- 动漫百科全书API:你准备好探索动漫世界的无限可能了吗?
- Claude API在中国停用后的迁移与替代方案详解
- Grafana API 入门指南:自动化仪表板管理与高级功能
- 常用的14条API文档编写基本准则
- 如何获取 Kimi K2 API 密钥(分步指南)
- 为什么需要隐藏您的 API Key 密钥
热门推荐
一个账号试用1000+ API
助力AI无缝链接物理世界 · 无需多次注册