实现一个简单的生成对抗网络(Generative Adversarial Network, GAN)通常涉及以下几个步骤:
- 定义生成器(Generator)和判别器(Discriminator)网络。
- 定义损失函数和优化器。
- 训练生成器和判别器。

下面是一个使用PyTorch实现简单GAN的示例。这个示例将使用MNIST数据集来生成手写数字图片。
1. 导入必要的库

import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image

2. 定义生成器和判别器
# Generator network
class Generator(nn.Module):
    """Fully connected generator: Linear -> ReLU -> Linear -> ReLU -> Linear -> Tanh.

    Args:
        input_size: Number of features in each input vector.
        hidden_size: Width of the two hidden layers.
        output_size: Number of features in each output vector.

    The closing ``Tanh`` keeps every output value inside [-1, 1].
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        layers = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, output_size),
            nn.Tanh(),
        ]
        # The attribute is named "main" so state_dict keys stay "main.<index>.*".
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        return self.main(x)


# Discriminator network
class Discriminator(nn.Module):
    """Fully connected discriminator with LeakyReLU(0.2) activations.

    Args:
        input_size: Number of features in each input vector.
        hidden_size: Width of the two hidden layers.
        output_size: Number of output units.

    The closing ``Sigmoid`` maps every output into (0, 1), which is the
    form ``nn.BCELoss`` expects.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        layers = [
            nn.Linear(input_size, hidden_size),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden_size, hidden_size),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden_size, output_size),
            nn.Sigmoid(),
        ]
        # The attribute is named "main" so state_dict keys stay "main.<index>.*".
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        return self.main(x)


# 3. Define the loss function and optimizers
# Hyperparameters
batch_size = 100
learning_rate = 2e-4
num_epochs = 200
latent_size = 64        # length of the noise vector fed to the generator
hidden_size = 256
image_size = 28 * 28    # 784 values per flattened image
num_classes = 1         # the discriminator emits a single score per sample

# Load the MNIST training set.
# Normalize(mean=0.5, std=0.5) computes (x - 0.5) / 0.5 on each pixel.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
mnist = dsets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
data_loader = DataLoader(
    dataset=mnist,
    batch_size=batch_size,
    shuffle=True,
)

# Instantiate the generator and the discriminator (G first, then D).
G = Generator(latent_size, hidden_size, image_size)
D = Discriminator(image_size, hidden_size, num_classes)

# Binary cross-entropy loss, and one Adam optimizer per network.
criterion = nn.BCELoss()
d_optimizer = optim.Adam(D.parameters(), lr=learning_rate)
g_optimizer = optim.Adam(G.parameters(), lr=learning_rate)


# 4. Train the generator and discriminator
# Train the GAN.
# Every batch performs one discriminator update followed by one generator
# update; sample images are written to ./samples on selected epochs.

# save_image does not create missing directories, so make sure it exists.
os.makedirs('./samples', exist_ok=True)

total_step = len(data_loader)
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        # Size labels and noise from the batch actually received, so a
        # smaller final batch does not break view() or the loss.
        current_batch = images.size(0)

        # Targets: 1 for real images, 0 for generated ones.
        real_labels = torch.ones(current_batch, 1)
        fake_labels = torch.zeros(current_batch, 1)

        # ---- Discriminator step ----
        outputs = D(images.view(current_batch, -1))
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        z = torch.randn(current_batch, latent_size)
        # detach(): this step only updates D, so skip backprop through G.
        fake_images = G(z).detach()
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ---- Generator step ----
        # G is trained to make D label fresh fakes as real.
        z = torch.randn(current_batch, latent_size)
        fake_images = G(z)
        outputs = D(fake_images)
        g_loss = criterion(outputs, real_labels)

        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 200 == 0:
            # Epoch is reported 1-based, matching the step counter and the
            # sample file names below.
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], d_loss: {d_loss.item()}, g_loss: {g_loss.item()}, D(x): {real_score.mean().item()}, D(G(z)): {fake_score.mean().item()}')

    # Save the last generated batch after the first epoch and every 20th epoch.
    if (epoch + 1) == 1 or (epoch + 1) % 20 == 0:
        fake_images = fake_images.detach().reshape(fake_images.size(0), 1, 28, 28)
        # The generator ends in Tanh, so map [-1, 1] to [0, 1] before saving;
        # otherwise every negative pixel would be clipped to black.
        save_image(((fake_images + 1) / 2).clamp(0, 1), f'./samples/fake_images-{epoch+1}.png')


# 5. Visualize the generated images
import torchvision.utils as vutils


def imshow(img):
    """Display a channel-first image tensor with matplotlib.

    Args:
        img: Tensor laid out as (channels, height, width) whose values are
            normalized to [-1, 1]. It is mapped back to [0, 1] here, so the
            caller must not rescale it beforehand.
    """
    img = img / 2 + 0.5  # unnormalize: [-1, 1] -> [0, 1]
    # detach()/cpu() make the conversion safe for tensors that still carry
    # autograd history or live on another device.
    npimg = img.detach().cpu().numpy()
    # matplotlib expects (height, width, channels).
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# Show the most recently generated batch as a single image grid.
fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
# No normalize=True here: imshow() already maps [-1, 1] to [0, 1], and
# rescaling twice squeezes the picture into [0.5, 1] (washed out).
# pad_value=-1.0 makes the grid borders black after imshow() unnormalizes.
grid = vutils.make_grid(fake_images.detach(), padding=2, pad_value=-1.0)
imshow(grid)
以上代码实现了一个简单的GAN,用于生成MNIST手写数字图片。你可以根据需要调整超参数和网络结构以获得更好的生成效果。
网友回复
python如何实现torrent的服务端进行文件分发p2p下载?
如何在浏览器中录制摄像头和麦克风数据为mp4视频保存下载本地?
go如何编写一个类似docker的linux的虚拟容器?
python如何写一个bittorrent的种子下载客户端?
ai能通过看一个网页的交互过程视频自主模仿复制网页编写代码吗?
ai先写功能代码通过chrome mcp来进行测试功能最后ai美化页面这个流程能行吗?
vue在手机端上下拖拽元素的时候如何禁止父元素及body的滚动导致无法拖拽完成?
使用tailwindcss如何去掉响应式自适应?
有没有直接在浏览器中运行的离线linux系统?
nginx如何保留post或get数据进行url重定向?