Python实现生成对抗网络:生成逼真数据
2025-01-09
Python实现GAN (生成对抗网络) – 从0到1的深度学习之旅
嘿,小伙伴们!今天咱们要玩一个有趣的项目 – 用Python实现GAN网络。这个项目会帮你理解如何训练AI来生成超逼真的数据。我们会用MNIST手写数字数据集来演示,让AI学会画数字!
第一步:导入必要的包
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
# Fix the global PyTorch RNG seed so weight initialization and noise
# sampling are reproducible across runs (CPU RNG only).
torch.manual_seed(42)
第二步:准备数据集
# Preprocessing: convert PIL images to tensors in [0, 1], then shift/scale
# to [-1, 1] so real images match the generator's Tanh output range.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST training split, downloaded on first use.
train_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, transform=transform, download=True
)

# Shuffled mini-batches of 64 images for the training loop.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
第三步:构建生成器网络
class Generator(nn.Module):
    """Maps a 100-dim noise vector to a flattened 28x28 image in [-1, 1].

    Architecture: three Linear -> LeakyReLU(0.2) -> BatchNorm1d stages
    widening 100 -> 256 -> 512 -> 1024, then a Linear projection to 784
    pixels squashed by Tanh.
    """

    def __init__(self):
        super(Generator, self).__init__()
        layers = []
        width = 100  # latent dimension
        for next_width in (256, 512, 1024):
            layers.append(nn.Linear(width, next_width))
            layers.append(nn.LeakyReLU(0.2))
            layers.append(nn.BatchNorm1d(next_width))
            width = next_width
        # Output head: 28*28 = 784 pixels, Tanh keeps values in [-1, 1]
        # to match the normalized real images.
        layers.append(nn.Linear(width, 784))
        layers.append(nn.Tanh())
        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        """Return generated images of shape (batch, 784) for noise x of shape (batch, 100)."""
        return self.gen(x)
第四步:构建判别器网络
class Discriminator(nn.Module):
    """Scores a flattened 28x28 image with the probability that it is real.

    Architecture: three Linear -> LeakyReLU(0.2) -> Dropout(0.3) stages
    narrowing 784 -> 1024 -> 512 -> 256, then a Linear -> Sigmoid head
    producing a single probability.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        stages = []
        width = 784  # flattened 28x28 input
        for next_width in (1024, 512, 256):
            stages.append(nn.Linear(width, next_width))
            stages.append(nn.LeakyReLU(0.2))
            stages.append(nn.Dropout(0.3))  # regularizes D so it doesn't overpower G
            width = next_width
        # Single-logit head squashed to a probability in (0, 1).
        stages.append(nn.Linear(width, 1))
        stages.append(nn.Sigmoid())
        self.disc = nn.Sequential(*stages)

    def forward(self, x):
        """Return realness probabilities of shape (batch, 1) for x of shape (batch, 784)."""
        return self.disc(x)
第五步:训练模型
# ---- Model and optimizer setup ----
generator = Generator()
discriminator = Discriminator()
# Binary cross-entropy over the discriminator's sigmoid output;
# Adam with beta1 = 0.5 is the usual choice for stabilizing GAN training.
criterion = nn.BCELoss()
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
# Training hyperparameters
num_epochs = 100
latent_dim = 100
fixed_noise = torch.randn(16, latent_dim)  # fixed latent batch reused for visualization
# NOTE(review): everything runs on CPU; move models/tensors to a device
# (`.to(device)`) for GPU training — confirm intent.
# Training loop: one discriminator step then one generator step per batch.
for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(train_loader):
        batch_size = real_images.size(0)
        # Target labels: 1 for real images, 0 for generated ones
        real_label = torch.ones(batch_size, 1)
        fake_label = torch.zeros(batch_size, 1)
        # Flatten 1x28x28 images to 784-dim vectors for the MLP networks
        real_images = real_images.view(-1, 784)
        # --- Discriminator step: push D(real) -> 1 and D(fake) -> 0 ---
        d_optimizer.zero_grad()
        output_real = discriminator(real_images)
        d_loss_real = criterion(output_real, real_label)
        # Generate fake images from fresh noise
        noise = torch.randn(batch_size, latent_dim)
        fake_images = generator(noise)
        # detach() blocks gradients from flowing into the generator on D's step
        output_fake = discriminator(fake_images.detach())
        d_loss_fake = criterion(output_fake, fake_label)
        # Total discriminator loss is the sum of the real and fake terms
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        d_optimizer.step()
        # --- Generator step: train G so that D labels its samples as real ---
        g_optimizer.zero_grad()
        output_fake = discriminator(fake_images)  # no detach: gradients reach G
        g_loss = criterion(output_fake, real_label)  # non-saturating G objective
        g_loss.backward()
        g_optimizer.step()
        # Periodic progress logging
        if i % 100 == 0:
            print(f'Epoch [{epoch}/{num_epochs}], Step [{i}/{len(train_loader)}], '
                  f'd_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')
第六步:生成和显示结果
def show_images(images):
    """Display a batch of flattened generated images as a 4-wide grid.

    Args:
        images: Tensor of shape (N, 784); values are normalized by
            make_grid, so the generator's [-1, 1] Tanh range is fine.
    """
    plt.figure(figsize=(4, 4))
    # Bug fix: original used curly "smart quotes" (“off”), a SyntaxError.
    plt.axis("off")
    # make_grid returns a (C, H, W) tensor; matplotlib expects (H, W, C).
    grid = torchvision.utils.make_grid(
        images.reshape(-1, 1, 28, 28), nrow=4, padding=2, normalize=True
    )
    plt.imshow(np.transpose(grid.cpu(), (1, 2, 0)))
    plt.show()
# Sample from the trained generator using the fixed noise batch;
# no_grad() skips autograd bookkeeping during inference.
# NOTE(review): the generator stays in train mode here, so BatchNorm uses
# the statistics of these 16 samples rather than running averages —
# call generator.eval() first if running-stat inference is intended.
with torch.no_grad():
    fake_images = generator(fixed_noise)
show_images(fake_images)
小贴士:
- 增加训练稳定性:
* 使用标签平滑化技术
* 添加梯度裁剪
* 尝试不同的激活函数
- 提高生成质量:
* 增加网络层数和通道数
* 使用更复杂的网络结构(如卷积层)
* 调整学习率和批次大小
- 避免模式崩溃:
* 使用Wasserstein GAN
* 实现批次归一化
* 适当调整判别器和生成器的训练比例
这个GAN实现虽然简单,但已经能产生不错的效果了。记得调参是提升效果的关键,慢慢调整,你一定能生成出更逼真的图像!
文章转自微信公众号@寒江映孤月
同话题下的热门内容