所有文章 > AI驱动 > Python实现生成对抗网络:生成逼真数据

Python实现生成对抗网络:生成逼真数据

Python实现GAN (生成对抗网络) – 从0到1的深度学习之旅

嘿,小伙伴们!今天咱们要玩一个有趣的项目 – 用Python实现GAN网络。这个项目会帮你理解如何训练AI来生成超逼真的数据。我们会用MNIST手写数字数据集来演示,让AI学会画数字!

第一步:导入必要的包

import torch

import torch.nn as nn

import torch.optim as optim

import torchvision

import torchvision.transforms as transforms

from torch.utils.data import DataLoader

import numpy as np

import matplotlib.pyplot as plt

# 设置随机种子,确保结果可复现

torch.manual_seed(42)

第二步:准备数据集

# 数据预处理和加载

transform = transforms.Compose([

transforms.ToTensor(),

transforms.Normalize((0.5,), (0.5,))

])

# 加载MNIST数据集

train_dataset = torchvision.datasets.MNIST(

root='./data',

train=True,

transform=transform,

download=True

)

# 创建数据加载器

train_loader = DataLoader(

dataset=train_dataset,

batch_size=64,

shuffle=True

)

第三步:构建生成器网络

class Generator(nn.Module):

def __init__(self):

super(Generator, self).__init__()

# 生成器结构:用简单的全连接层

self.gen = nn.Sequential(

# 输入是随机噪声(latent_dim)

nn.Linear(100, 256),

nn.LeakyReLU(0.2),

nn.BatchNorm1d(256),

nn.Linear(256, 512),

nn.LeakyReLU(0.2),

nn.BatchNorm1d(512),

nn.Linear(512, 1024),

nn.LeakyReLU(0.2),

nn.BatchNorm1d(1024),

# 输出层,生成28*28=784维的图像

nn.Linear(1024, 784),

nn.Tanh()

)

def forward(self, x):

return self.gen(x)

第四步:构建判别器网络

class Discriminator(nn.Module):

def __init__(self):

super(Discriminator, self).__init__()

# 判别器结构

self.disc = nn.Sequential(

# 输入是展平的图像(784维)

nn.Linear(784, 1024),

nn.LeakyReLU(0.2),

nn.Dropout(0.3),

nn.Linear(1024, 512),

nn.LeakyReLU(0.2),

nn.Dropout(0.3),

nn.Linear(512, 256),

nn.LeakyReLU(0.2),

nn.Dropout(0.3),

# 输出一个概率值

nn.Linear(256, 1),

nn.Sigmoid()

)

def forward(self, x):

return self.disc(x)

第五步:训练模型

# 初始化模型

generator = Generator()

discriminator = Discriminator()

# 损失函数和优化器

criterion = nn.BCELoss()

g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))

d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# 训练参数

num_epochs = 100

latent_dim = 100

fixed_noise = torch.randn(16, latent_dim) # 用于可视化

# 训练循环

for epoch in range(num_epochs):

for i, (real_images, _) in enumerate(train_loader):

batch_size = real_images.size(0)

# 准备真实和虚假的标签

real_label = torch.ones(batch_size, 1)

fake_label = torch.zeros(batch_size, 1)

# 展平图像

real_images = real_images.view(-1, 784)

# 训练判别器

d_optimizer.zero_grad()

output_real = discriminator(real_images)

d_loss_real = criterion(output_real, real_label)

# 生成假图像

noise = torch.randn(batch_size, latent_dim)

fake_images = generator(noise)

output_fake = discriminator(fake_images.detach())

d_loss_fake = criterion(output_fake, fake_label)

# 计算判别器总损失

d_loss = d_loss_real + d_loss_fake

d_loss.backward()

d_optimizer.step()

# 训练生成器

g_optimizer.zero_grad()

output_fake = discriminator(fake_images)

g_loss = criterion(output_fake, real_label)

g_loss.backward()

g_optimizer.step()

if i % 100 == 0:

print(f'Epoch [{epoch}/{num_epochs}], Step [{i}/{len(train_loader)}], '

f'd_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')

第六步:生成和显示结果

def show_images(images):

plt.figure(figsize=(4, 4))

plt.axis(“off”)

plt.imshow(np.transpose(torchvision.utils.make_grid(

images.reshape(-1, 1, 28, 28), nrow=4, padding=2, normalize=True

).cpu(), (1, 2, 0)))

plt.show()

# 生成示例图像

with torch.no_grad():

fake_images = generator(fixed_noise)

show_images(fake_images)

小贴士:

  1. 增加训练稳定性:

* 使用标签平滑化技术

* 添加梯度裁剪

* 尝试不同的激活函数

  1. 提高生成质量:

* 增加网络层数和通道数

* 使用更复杂的网络结构(如卷积层)

* 调整学习率和批次大小

  1. 避免模式崩溃:

* 使用Wasserstein GAN

* 实现批次归一化

* 适当调整判别器和生成器的训练比例

这个GAN实现虽然简单,但已经能产生不错的效果了。记得调参是提升效果的关键,慢慢调整,你一定能生成出更逼真的图像!

文章转自微信公众号@寒江映孤月