实现一个简单的生成对抗网络(Generative Adversarial Network, GAN)通常涉及以下几个步骤:
(1) 定义生成器(Generator)和判别器(Discriminator)网络;(2) 定义损失函数和优化器;(3) 训练生成器和判别器。
下面是一个使用PyTorch实现简单GAN的示例。这个示例将使用MNIST数据集来生成手写数字图片。
1. 导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
2. 定义生成器和判别器
# 定义生成器
class Generator(nn.Module):
    """Fully connected generator network.

    Stacks three linear layers, with ReLU after the first two and Tanh
    after the last, so every output element lies in [-1, 1].

    Args:
        input_size: Number of features in each input vector.
        hidden_size: Width of the two hidden layers.
        output_size: Number of features in each output vector.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(True),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(True),
            nn.Linear(hidden_size, output_size),
            # Tanh bounds the output to [-1, 1].
            nn.Tanh()
        )

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        return self.main(x)
# 定义判别器
class Discriminator(nn.Module):
    """Fully connected discriminator network.

    Stacks three linear layers, with LeakyReLU (negative slope 0.2) after
    the first two and Sigmoid after the last, so every output element is a
    value between 0 and 1.

    Args:
        input_size: Number of features in each input vector.
        hidden_size: Width of the two hidden layers.
        output_size: Number of output values per input vector.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden_size, hidden_size),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden_size, output_size),
            # Sigmoid squashes the final activation into the range 0..1.
            nn.Sigmoid()
        )

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        return self.main(x)


# 3. Define the loss function and optimizers
# Hyperparameters
batch_size = 100
learning_rate = 0.0002
num_epochs = 200
latent_size = 64  # length of the noise vector fed to the generator
hidden_size = 256
image_size = 784  # 28*28
# Despite the name, this is the discriminator's output width (see the
# Discriminator(...) call below), not a number of digit classes.
num_classes = 1

# Load the MNIST dataset.
transform = transforms.Compose([
    transforms.ToTensor(),
    # (x - 0.5) / 0.5; presumably maps ToTensor's output onto [-1, 1],
    # the same range as the generator's Tanh output.
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])
mnist = dsets.MNIST(root='./data', train=True, transform=transform, download=True)
data_loader = DataLoader(dataset=mnist, batch_size=batch_size, shuffle=True)

# Instantiate the generator and the discriminator.
G = Generator(latent_size, hidden_size, image_size)
D = Discriminator(image_size, hidden_size, num_classes)

# Loss function and one Adam optimizer per network.
criterion = nn.BCELoss()
d_optimizer = optim.Adam(D.parameters(), lr=learning_rate)
g_optimizer = optim.Adam(G.parameters(), lr=learning_rate)

# 4. Train the generator and the discriminator
# Train the GAN
import os

from torchvision.utils import save_image

# save_image does not create missing directories, so make sure the output
# folder exists before the first sample is written.
os.makedirs('./samples', exist_ok=True)

total_step = len(data_loader)
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        # Size everything from the batch actually received, so a final
        # batch shorter than `batch_size` does not break view()/BCELoss.
        current_batch = images.size(0)

        # Build the target labels: 1 for real images, 0 for generated ones.
        real_labels = torch.ones(current_batch, 1)
        fake_labels = torch.zeros(current_batch, 1)

        # ---- Train the discriminator ----
        # Flatten each image into a single row before feeding it to D.
        outputs = D(images.view(current_batch, -1))
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        z = torch.randn(current_batch, latent_size)
        # Detach: this step only updates D, so there is no reason to
        # back-propagate through the generator.
        fake_images = G(z).detach()
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ---- Train the generator ----
        # G is rewarded when D labels its output as real, hence real_labels.
        z = torch.randn(current_batch, latent_size)
        fake_images = G(z)
        outputs = D(fake_images)
        g_loss = criterion(outputs, real_labels)
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i+1) % 200 == 0:
            # Epoch is reported 1-based, consistent with the sample file names.
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], d_loss: {d_loss.item()}, g_loss: {g_loss.item()}, D(x): {real_score.mean().item()}, D(G(z)): {fake_score.mean().item()}')

    # Save the most recent generated batch after the first epoch and then
    # after every 20th epoch.
    if (epoch+1) == 1 or (epoch+1) % 20 == 0:
        fake_images = fake_images.detach().reshape(fake_images.size(0), 1, 28, 28)
        # The generator ends in Tanh, so pixels are in [-1, 1]; shift them to
        # [0, 1] so the negative half is not clipped to black on save.
        save_image(((fake_images + 1) / 2).clamp(0, 1), f'./samples/fake_images-{epoch+1}.png')

# 5. Visualize the generated images
import torchvision.utils as vutils
def imshow(img):
    """Display an image tensor with matplotlib.

    Args:
        img: A 3-D tensor; its first axis is moved to the end before
            plotting (presumably channels-first in, channels-last out).
            Values are shifted with ``img / 2 + 0.5`` before display, i.e.
            the input is expected to be normalised with mean 0.5 / std 0.5.
    """
    img = img / 2 + 0.5  # unnormalize
    # detach()/cpu() so tensors that still require grad, or that live on
    # another device, can be converted to a NumPy array.
    npimg = img.detach().cpu().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()
# Arrange the most recently generated batch into a grid and display it.
# detach() because the batch comes straight out of the generator and may
# still be attached to the autograd graph.
fake_images = fake_images.detach().reshape(fake_images.size(0), 1, 28, 28)
# imshow() already un-normalises with img / 2 + 0.5, so the grid is built
# without normalize=True; rescaling twice would wash the image out.
grid = vutils.make_grid(fake_images, padding=2)
imshow(grid)
# The code above implements a simple GAN that generates MNIST handwritten
# digit images. Adjust the hyperparameters and the network architecture as
# needed to get better results.
网友回复
有没有不依赖embedding向量的RAG技术?
有没有支持实时打断语音通话并后台帮你执行任何的ai模型?
开源ai大模型文件格式GGUF、MLX、Safetensors、 ONNX 有什么区别?
出海挣钱支付收款PayPal、Wise 、PingPong、Stripe如何选择?
如何实现类似google的图片隐形水印添加和识别技术?
linux上如何运行任意windows程序?
ai能写出比黑客还厉害的零日漏洞等攻击工具攻击任意软件系统工程?
js如何获取浏览器的音频上下文指纹、Canvas指纹、WebGL渲染特征?
为啥ai开始抛弃markdown文本,重新偏好html文本了?
网站有没有办法鉴别访问请求是由ai操控chrome-devtools-mcp发出的?


