对抗生成网络:AI的想象力
第19讲:对抗生成网络:AI的想象力
目录
1. GAN:让AI学会"造假"
1.1 什么是生成对抗网络
通俗理解:GAN就像假币制造者(生成器)和警察(判别器)的博弈。
- 制造者不断改进假币,试图骗过警察
- 警察不断提高鉴别能力,试图抓住假币
- 最终:制造者的技术炉火纯青,警察也无法分辨真假 → 假币和真币一模一样
真实数据分布(如真实人脸照片)
│
┌────┴────┐
│ │
判别器(D) 生成器(G)
│ │
└────┬────┘
│
博弈均衡:G生成的样本与真实数据不可区分
1.2 GAN的核心价值
| 能力 | 应用 | 示例 |
|---|---|---|
| 数据生成 | 扩充训练数据 | 生成稀有病症的医学影像 |
| 图像翻译 | 风格转换 | 照片→油画、白天→黑夜 |
| 超分辨率 | 图像增强 | 模糊照片→高清 |
| 图像修复 | 补全缺失 | 老照片修复、去水印 |
| 特征学习 | 无监督预训练 | 用GAN特征做分类 |
2. 生成器与判别器:博弈论视角
2.1 数学框架
生成器 G:输入随机噪声 z,输出生成图像 G(z)
判别器 D:输入图像 x,输出是真实图像的概率 D(x)
博弈目标(Minimax Game):
$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$$
直观理解:
- D的目标:最大化 V,即让 D(x)→1(真图判断为真),D(G(z))→0(假图判断为假)
- G的目标:最小化 V,即让 D(G(z))→1(骗过判别器)
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
# ========== 极简GAN:用全连接层生成1D数据 ==========
# 目标:学习生成服从正态分布 N(0, 1) 的数据
class SimpleGenerator(nn.Module):
    """Minimal generator: maps a noise vector to one scalar sample.

    Architecture: noise -> two hidden ReLU layers -> 1-D output.
    """

    def __init__(self, noise_dim=10, hidden_dim=50):
        super().__init__()
        self.noise_dim = noise_dim
        self.net = nn.Sequential(
            nn.Linear(noise_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),  # single-value output per sample
        )

    def forward(self, z):
        return self.net(z)

    def sample(self, num_samples, device='cpu'):
        """Draw `num_samples` generated values from fresh Gaussian noise."""
        noise = torch.randn(num_samples, self.noise_dim, device=device)
        return self.forward(noise)
class SimpleDiscriminator(nn.Module):
    """Minimal discriminator: scores a 1-D sample with a realness probability."""

    def __init__(self, hidden_dim=50):
        super().__init__()
        layers = [
            nn.Linear(1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),  # probability in (0, 1)
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
# Train the minimal GAN
def train_simple_gan():
    """Train the toy 1-D GAN so G learns to reproduce samples from N(0, 1).

    Alternates 5 discriminator steps per generator step, snapshots 1000
    generated samples every 500 epochs, then plots the distribution drift
    against the true N(0, 1) density.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    G = SimpleGenerator().to(device)
    D = SimpleDiscriminator().to(device)
    g_optimizer = torch.optim.Adam(G.parameters(), lr=0.001)
    d_optimizer = torch.optim.Adam(D.parameters(), lr=0.001)
    # Real data comes from the target distribution N(0, 1)
    def sample_real_data(batch_size):
        return torch.randn(batch_size, 1, device=device)
    epochs = 5000
    batch_size = 64
    # Snapshots of the generated distribution taken during training
    snapshots = []
    for epoch in range(epochs):
        # ========== Train the discriminator ==========
        for _ in range(5):  # give D several updates per G update
            d_optimizer.zero_grad()
            # Real samples: push D(x) toward 1
            real_data = sample_real_data(batch_size)
            real_pred = D(real_data)
            d_loss_real = -torch.log(real_pred + 1e-8).mean()
            # Fake samples (detached so no gradient reaches G): push D(G(z)) toward 0
            fake_data = G.sample(batch_size, device).detach()
            fake_pred = D(fake_data)
            d_loss_fake = -torch.log(1 - fake_pred + 1e-8).mean()
            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            d_optimizer.step()
        # ========== Train the generator ==========
        g_optimizer.zero_grad()
        fake_data = G.sample(batch_size, device)
        fake_pred = D(fake_data)
        # Non-saturating loss: G wants D(fake) close to 1
        g_loss = -torch.log(fake_pred + 1e-8).mean()
        g_loss.backward()
        g_optimizer.step()
        # Periodically record generated samples
        if epoch % 500 == 0:
            with torch.no_grad():
                samples = G.sample(1000, device).cpu().numpy()
            snapshots.append((epoch, samples))
            print(f"Epoch {epoch}: D_loss={d_loss.item():.4f}, G_loss={g_loss.item():.4f}")
    # Visualize how the generated distribution evolves over training
    fig, axes = plt.subplots(2, 5, figsize=(20, 8))
    axes = axes.flatten()
    for idx, (epoch, samples) in enumerate(snapshots):
        ax = axes[idx]
        ax.hist(samples, bins=30, density=True, alpha=0.7, color='blue', label='Generated')
        # Overlay the true N(0, 1) density for reference
        x_range = np.linspace(-3, 3, 100)
        real_pdf = (1/np.sqrt(2*np.pi)) * np.exp(-0.5 * x_range**2)
        ax.plot(x_range, real_pdf, 'r-', linewidth=2, label='Real N(0,1)')
        ax.set_title(f'Epoch {epoch}')
        ax.set_xlim(-3, 3)
        ax.set_ylim(0, 0.5)
        if idx == 0:
            ax.legend()
    plt.suptitle('GAN Distribution Evolution', fontsize=14)
    plt.tight_layout()
    plt.show()
# 运行
# train_simple_gan()
2.2 训练动态可视化
Epoch 0: 生成数据分布混乱(均匀或单峰)
↓
Epoch 500: 开始聚集,但位置偏移
↓
Epoch 1000: 接近目标分布,但方差不对
↓
Epoch 2000: 基本匹配,但仍有偏差
↓
Epoch 4500(最后一次采样): 与N(0,1)几乎重合!
3. DCGAN:深度卷积GAN
3.1 从全连接到卷积
问题:全连接GAN只能生成简单数据(如1D分布),图像需要卷积结构。
DCGAN核心:用转置卷积(ConvTranspose)上采样生成图像,用卷积判别真假。
class DCGANGenerator(nn.Module):
    """DCGAN generator: transposed-conv stack mapping noise to a 64x64 RGB image.

    Input:  [N, noise_dim] (reshaped internally to [N, noise_dim, 1, 1])
    Output: [N, image_channels, 64, 64], values in [-1, 1]
    """

    def __init__(self, noise_dim=100, num_channels=64, image_channels=3):
        super().__init__()
        self.noise_dim = noise_dim

        def up_block(in_c, out_c, kernel, stride, pad):
            # ConvTranspose -> BatchNorm -> ReLU: the standard DCGAN upsampling unit
            return [
                nn.ConvTranspose2d(in_c, out_c, kernel, stride, pad, bias=False),
                nn.BatchNorm2d(out_c),
                nn.ReLU(True),
            ]

        # Spatial growth: 1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64
        self.net = nn.Sequential(
            *up_block(noise_dim, num_channels * 8, 4, 1, 0),         # -> [N, 8c, 4, 4]
            *up_block(num_channels * 8, num_channels * 4, 4, 2, 1),  # -> [N, 4c, 8, 8]
            *up_block(num_channels * 4, num_channels * 2, 4, 2, 1),  # -> [N, 2c, 16, 16]
            *up_block(num_channels * 2, num_channels, 4, 2, 1),      # -> [N, c, 32, 32]
            nn.ConvTranspose2d(num_channels, image_channels, 4, 2, 1, bias=False),
            nn.Tanh(),  # squash to [-1, 1] -> [N, image_channels, 64, 64]
        )

    def forward(self, z):
        # [N, noise_dim] -> [N, noise_dim, 1, 1] so the conv stack can consume it
        return self.net(z.view(z.size(0), z.size(1), 1, 1))

    def sample(self, num_samples, device='cpu'):
        """Generate `num_samples` images from fresh Gaussian noise."""
        z = torch.randn(num_samples, self.noise_dim, device=device)
        return self.forward(z)
class DCGANDiscriminator(nn.Module):
    """DCGAN discriminator: conv stack mapping a 64x64 image to a realness probability.

    Input:  [N, image_channels, 64, 64]
    Output: [N] probabilities in (0, 1)
    """

    def __init__(self, num_channels=64, image_channels=3):
        super().__init__()
        # Downsampling path: 64x64 -> 32 -> 16 -> 8 -> 4 -> 1
        self.net = nn.Sequential(
            # First block has no BatchNorm (standard DCGAN choice)
            nn.Conv2d(image_channels, num_channels, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> [N, c, 32, 32]
            nn.Conv2d(num_channels, num_channels * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(num_channels * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # -> [N, 2c, 16, 16]
            nn.Conv2d(num_channels * 2, num_channels * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(num_channels * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # -> [N, 4c, 8, 8]
            nn.Conv2d(num_channels * 4, num_channels * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(num_channels * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # -> [N, 8c, 4, 4]
            nn.Conv2d(num_channels * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
            # -> [N, 1, 1, 1]
        )

    def forward(self, x):
        # Flatten to [N] explicitly. The original `.squeeze()` collapsed ALL
        # singleton dims, so a batch of size 1 produced a 0-dim scalar, which
        # breaks losses expecting a batch dimension.
        return self.net(x).view(-1)
# Smoke test: wire the generator and discriminator together once and
# check the tensor shapes end to end.
G = DCGANGenerator()
D = DCGANDiscriminator()
z = torch.randn(2, 100)
fake_img = G(z)
real_pred = D(fake_img)
print(f"生成器输入: {z.shape}")
print(f"生成器输出: {fake_img.shape}")  # [2, 3, 64, 64]
print(f"判别器输出: {real_pred.shape}")  # [2] — squeeze() drops the singleton dims
3.2 DCGAN的设计原则
| 设计选择 | 原因 | 效果 |
|---|---|---|
| 转置卷积 | 上采样生成图像 | 可学习的上采样方式,比固定插值更灵活 |
| BatchNorm | 稳定训练 | 防止模式坍塌 |
| ReLU(G) / LeakyReLU(D) | 避免梯度消失 | 负值有微小梯度 |
| Tanh输出 | 归一化到[-1,1] | 匹配归一化输入 |
| 无全连接层 | 保留空间结构 | 更好的空间一致性 |
| 对称结构 | G和D镜像 | 平衡训练 |
4. 训练技巧与模式坍塌
4.1 模式坍塌(Mode Collapse)
问题:生成器只学会生成某几种样本,忽略其他模式。
正常GAN:生成各种数字 0,1,2,3,4,5,6,7,8,9
↓
模式坍塌:只生成数字"1"(最容易骗过判别器)
↓
严重坍塌:所有样本几乎相同
4.2 解决模式坍塌的技巧
class GANTrainer:
    """Stable GAN trainer bundling several anti-mode-collapse tricks.

    Tricks applied:
      1. Adam with beta1=0.5 (common choice for GAN training).
      2. Label smoothing (real=0.9, fake=0.1) so D never gets perfectly confident.
      3. Configurable D/G update ratio.
      4. Non-saturating generator loss (-log D(G(z))).
      5. A diversity probe (`check_mode_collapse`) to spot collapsing output.

    `G` must expose a `noise_dim` attribute and map [N, noise_dim] noise to
    samples; `D` must map samples to probabilities in (0, 1).
    """

    def __init__(self, G, D, device='cuda'):
        self.G = G.to(device)
        self.D = D.to(device)
        self.device = device
        # NOTE(fix): reference torch.optim explicitly — the bare name `optim`
        # is not imported in this snippet's scope, so the original raised
        # NameError when used standalone.
        self.g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
        # Label smoothing: keep D from becoming over-confident
        self.real_label = 0.9  # instead of 1.0
        self.fake_label = 0.1  # instead of 0.0

    def train_step(self, real_images):
        """Run one (or several) D updates then one G update on a real batch.

        Returns a dict with both losses and D's accuracy on real/fake samples.
        """
        batch_size = real_images.size(0)
        # ========== Train the discriminator ==========
        for _ in range(1):  # D/G update ratio (sometimes D needs more steps)
            self.d_optimizer.zero_grad()
            # Real batch -> target = smoothed real label
            real_pred = self.D(real_images)
            d_loss_real = nn.functional.binary_cross_entropy(
                real_pred,
                torch.full_like(real_pred, self.real_label)
            )
            # Fake batch (detached: no gradient into G) -> smoothed fake label
            z = torch.randn(batch_size, self.G.noise_dim, device=self.device)
            fake_images = self.G(z).detach()
            fake_pred = self.D(fake_images)
            d_loss_fake = nn.functional.binary_cross_entropy(
                fake_pred,
                torch.full_like(fake_pred, self.fake_label)
            )
            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            self.d_optimizer.step()
        # ========== Train the generator ==========
        self.g_optimizer.zero_grad()
        # Non-saturating loss: -log D(G(z)) instead of log(1 - D(G(z)))
        z = torch.randn(batch_size, self.G.noise_dim, device=self.device)
        fake_images = self.G(z)
        fake_pred = self.D(fake_images)
        # G is rewarded when D scores its samples as real (target ~ 1)
        g_loss = nn.functional.binary_cross_entropy(
            fake_pred,
            torch.full_like(fake_pred, self.real_label)  # target is "real"!
        )
        g_loss.backward()
        self.g_optimizer.step()
        return {
            'd_loss': d_loss.item(),
            'g_loss': g_loss.item(),
            'real_acc': (real_pred > 0.5).float().mean().item(),
            'fake_acc': (fake_pred < 0.5).float().mean().item()
        }

    def check_mode_collapse(self, num_samples=1000):
        """Estimate sample diversity via mean pairwise distance of generated samples.

        A tiny average distance means most samples look alike -> likely collapse.
        """
        with torch.no_grad():
            z = torch.randn(num_samples, self.G.noise_dim, device=self.device)
            samples = self.G(z)
            # Mean pairwise L2 distance between flattened samples
            distances = torch.cdist(samples.view(num_samples, -1),
                                    samples.view(num_samples, -1))
            avg_distance = distances.mean().item()
        print(f"Average sample distance: {avg_distance:.4f}")
        if avg_distance < 0.1:
            print("⚠️ Warning: Possible mode collapse detected!")
        return avg_distance
# Trick 6: use the Wasserstein distance (WGAN)
class WGANTrainer(GANTrainer):
    """WGAN trainer: Wasserstein distance instead of the JS-divergence loss.

    Differences from the vanilla trainer:
      * RMSprop (no momentum), as recommended by the WGAN paper.
      * Critic trained 5x per generator step.
      * Weight clipping to (approximately) enforce the Lipschitz constraint.
    """

    def __init__(self, G, D, device='cuda'):
        super().__init__(G, D, device)
        # WGAN recommends RMSprop/SGD rather than momentum-based Adam.
        # NOTE(fix): reference torch.optim explicitly — the bare name `optim`
        # is not imported in this snippet's scope.
        self.g_optimizer = torch.optim.RMSprop(G.parameters(), lr=0.00005)
        self.d_optimizer = torch.optim.RMSprop(D.parameters(), lr=0.00005)
        # Clipping bound (original WGAN method)
        self.clip_value = 0.01

    def train_step(self, real_images):
        """One WGAN step: 5 critic updates with clipping, then one G update."""
        batch_size = real_images.size(0)
        # Train the critic several times per generator step (WGAN requirement)
        for _ in range(5):
            self.d_optimizer.zero_grad()
            real_pred = self.D(real_images)
            z = torch.randn(batch_size, self.G.noise_dim, device=self.device)
            fake_images = self.G(z).detach()
            fake_pred = self.D(fake_images)
            # Wasserstein objective: maximize E[D(real)] - E[D(fake)]
            d_loss = -(real_pred.mean() - fake_pred.mean())
            d_loss.backward()
            self.d_optimizer.step()
            # Clip weights to keep the critic (approximately) 1-Lipschitz
            for p in self.D.parameters():
                p.data.clamp_(-self.clip_value, self.clip_value)
        # Train the generator: maximize E[D(fake)]
        self.g_optimizer.zero_grad()
        z = torch.randn(batch_size, self.G.noise_dim, device=self.device)
        fake_images = self.G(z)
        fake_pred = self.D(fake_images)
        g_loss = -fake_pred.mean()
        g_loss.backward()
        self.g_optimizer.step()
        return {'d_loss': d_loss.item(), 'g_loss': g_loss.item()}
5. 损失函数进化
5.1 GAN损失函数演进
| 版本 | 损失函数 | 问题 | 改进 |
|---|---|---|---|
| 原始GAN | JS散度 | 梯度消失、模式坍塌 | 基础 |
| WGAN | Wasserstein距离 | 需要权重裁剪 | 训练稳定 |
| WGAN-GP | + 梯度惩罚 | 权重裁剪限制容量 | 更好的Lipschitz约束 |
| LSGAN | 最小二乘损失 | 避免sigmoid饱和 | 更平滑的梯度 |
| BEGAN | BEGAN损失 | 自编码器判别器 | 自动平衡 |
| SAGAN | Hinge Loss + 注意力 | 长距离依赖 | 更好的细节 |
5.2 WGAN-GP实现
class WGANGPTrainer(GANTrainer):
    """WGAN-GP trainer: gradient penalty instead of weight clipping."""

    def __init__(self, G, D, device='cuda', lambda_gp=10):
        super().__init__(G, D, device)
        self.lambda_gp = lambda_gp
        # Adam betas (0.0, 0.9), as used in the WGAN-GP paper.
        # NOTE(fix): reference torch.optim explicitly — the bare name `optim`
        # is not imported in this snippet's scope.
        self.g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0001, betas=(0.0, 0.9))
        self.d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0001, betas=(0.0, 0.9))

    def compute_gradient_penalty(self, real_images, fake_images):
        """Penalize the critic when its gradient norm deviates from 1.

        Samples random points on straight lines between real and fake images
        and returns ((||∇D(x̂)|| - 1)^2).mean().
        """
        batch_size = real_images.size(0)
        # Random interpolation between real and fake samples
        alpha = torch.rand(batch_size, 1, 1, 1, device=self.device)
        interpolates = alpha * real_images + (1 - alpha) * fake_images
        interpolates.requires_grad_(True)
        d_interpolates = self.D(interpolates)
        # Gradient of D w.r.t. the interpolated inputs; create_graph=True so
        # the penalty itself is differentiable during the critic update
        gradients = torch.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(d_interpolates),
            create_graph=True,
            retain_graph=True
        )[0]
        gradients = gradients.view(batch_size, -1)
        gradient_norm = gradients.norm(2, dim=1)
        # Penalty term: (||grad|| - 1)^2
        penalty = ((gradient_norm - 1) ** 2).mean()
        return penalty

    def train_step(self, real_images):
        """One WGAN-GP step: 5 critic updates with GP, then one G update."""
        batch_size = real_images.size(0)
        # Train the critic
        for _ in range(5):
            self.d_optimizer.zero_grad()
            real_pred = self.D(real_images)
            z = torch.randn(batch_size, self.G.noise_dim, device=self.device)
            fake_images = self.G(z).detach()
            fake_pred = self.D(fake_images)
            # Wasserstein term
            d_loss = -(real_pred.mean() - fake_pred.mean())
            # Gradient penalty replaces weight clipping
            gp = self.compute_gradient_penalty(real_images, fake_images)
            d_loss += self.lambda_gp * gp
            d_loss.backward()
            self.d_optimizer.step()
        # Train the generator
        self.g_optimizer.zero_grad()
        z = torch.randn(batch_size, self.G.noise_dim, device=self.device)
        fake_images = self.G(z)
        fake_pred = self.D(fake_images)
        g_loss = -fake_pred.mean()
        g_loss.backward()
        self.g_optimizer.step()
        return {
            'd_loss': d_loss.item(),
            'g_loss': g_loss.item(),
            'gp': gp.item()
        }
6. StyleGAN:控制生成风格
6.1 核心创新
问题:传统GAN的噪声输入直接控制全局,无法精细控制生成图像的风格(如头发颜色、脸型)。
StyleGAN解决:引入映射网络和自适应实例归一化(AdaIN),将噪声映射为风格向量,在不同层注入不同风格。
class MappingNetwork(nn.Module):
    """StyleGAN mapping network: embeds noise z into the intermediate w space.

    A stack of `num_layers` Linear + LeakyReLU layers; the first layer maps
    z_dim -> w_dim, the remaining layers keep w_dim.
    """

    def __init__(self, z_dim=512, w_dim=512, num_layers=8):
        super().__init__()
        blocks = []
        for layer_idx in range(num_layers):
            src_dim = z_dim if layer_idx == 0 else w_dim
            blocks.append(nn.Linear(src_dim, w_dim))
            blocks.append(nn.LeakyReLU(0.2))
        self.net = nn.Sequential(*blocks)

    def forward(self, z):
        # Normalize the input vector first (guards against extreme values)
        normalized = z / (z.norm(dim=1, keepdim=True) + 1e-8)
        return self.net(normalized)
class AdaIN(nn.Module):
    """Adaptive instance normalization: modulates features with a style vector."""

    def __init__(self, num_features, w_dim=512):
        super().__init__()
        # Per-channel scale and shift are predicted from the style vector w
        self.style_scale = nn.Linear(w_dim, num_features)
        self.style_bias = nn.Linear(w_dim, num_features)

    def forward(self, x, w):
        """Instance-normalize x, then re-style it with w.

        x: [N, C, H, W] feature map; w: [N, w_dim] style vector.
        """
        normalized = nn.functional.instance_norm(x)
        gamma = self.style_scale(w).unsqueeze(2).unsqueeze(3)  # [N, C, 1, 1]
        beta = self.style_bias(w).unsqueeze(2).unsqueeze(3)
        return gamma * normalized + beta
class StyleGANGenerator(nn.Module):
    """Simplified StyleGAN generator.

    Pipeline: z --mapping--> w; a learned constant 4x4 tensor is then run
    through (conv -> AdaIN) x2 + upsample blocks with the style w injected
    at every AdaIN, ending in a 1x1 conv to RGB with a Tanh.
    """

    def __init__(self, z_dim=512, w_dim=512, image_size=64):
        super().__init__()
        self.mapping = MappingNetwork(z_dim, w_dim)
        # Learned constant input (StyleGAN starts from a constant, not noise)
        self.const_input = nn.Parameter(torch.randn(1, 512, 4, 4))
        # NOTE(fix): the truncation trick needs a running mean of w. The
        # original referenced self.w_avg without ever defining it, so
        # forward() crashed with the default truncation_psi=0.7. Track it as
        # a buffer, updated by EMA while in training mode.
        self.register_buffer('w_avg', torch.zeros(w_dim))
        self.w_avg_beta = 0.995  # EMA decay for w_avg
        # Progressive upsampling: 4x4 -> 8x8 -> ... -> image_size,
        # each block injects its own style
        self.blocks = nn.ModuleList()
        in_channels = 512
        for i in range(int(np.log2(image_size)) - 2):
            out_channels = min(512, 512 // (2 ** i))
            self.blocks.append(nn.ModuleDict({
                'conv1': nn.Conv2d(in_channels, out_channels, 3, padding=1),
                'adain1': AdaIN(out_channels, w_dim),
                'conv2': nn.Conv2d(out_channels, out_channels, 3, padding=1),
                'adain2': AdaIN(out_channels, w_dim),
                'upsample': nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
            }))
            in_channels = out_channels
        self.to_rgb = nn.Conv2d(in_channels, 3, 1)
        self.activation = nn.Tanh()

    def forward(self, z, truncation_psi=0.7):
        """Generate images from noise z, optionally truncated toward w_avg."""
        # Map to w space
        w = self.mapping(z)
        # Keep a running average of w for the truncation trick
        if self.training:
            with torch.no_grad():
                self.w_avg.lerp_(w.detach().mean(dim=0), 1 - self.w_avg_beta)
        # Truncation trick: trade diversity for sample quality
        if truncation_psi < 1.0:
            w = self.w_avg + truncation_psi * (w - self.w_avg)
        # Start from the learned constant
        x = self.const_input.repeat(z.size(0), 1, 1, 1)
        # Inject the style layer by layer.
        # NOTE(fix): convolve BEFORE each AdaIN — the original applied adain1
        # to the pre-conv1 tensor, whose channel count differs from the
        # AdaIN's num_features once the blocks start shrinking channels.
        for block in self.blocks:
            x = torch.relu(block['conv1'](x))
            x = block['adain1'](x, w)
            x = torch.relu(block['conv2'](x))
            x = block['adain2'](x, w)
            x = block['upsample'](x)
        return self.activation(self.to_rgb(x))
# Style mixing: drive different layers with two different latent codes
def style_mixing(generator, z1, z2, layer_idx=4):
    """Sketch of StyleGAN style mixing (stub).

    Idea: w1 (from z1) would control the coarse layers (pose, face shape)
    while w2 (from z2) controls the fine layers (color, texture): the first
    `layer_idx` blocks use w1 and the rest use w2. A real implementation
    needs a forward() that accepts per-layer styles, so this stub only
    computes the two style vectors and returns nothing.
    """
    w1 = generator.mapping(z1)
    w2 = generator.mapping(z2)
    return None
7. 实战:生成手写数字
7.1 完整训练代码
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
# Data pipeline: MNIST tensors, shifted from [0, 1] into [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # map pixels to [-1, 1] to match the generator's Tanh
])
mnist = datasets.MNIST('./data', train=True, download=True, transform=transform)
dataloader = DataLoader(mnist, batch_size=64, shuffle=True)
# Lightweight DCGAN tailored to MNIST (28x28 grayscale)
class MNISTGenerator(nn.Module):
    """Small DCGAN-style generator producing 28x28 grayscale digit images."""

    def __init__(self, noise_dim=100):
        super().__init__()
        self.noise_dim = noise_dim
        layers = [
            # [N, noise_dim, 1, 1] -> [N, 128, 7, 7]
            nn.ConvTranspose2d(noise_dim, 128, 7, 1, 0),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # -> [N, 64, 14, 14]
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # -> [N, 1, 28, 28], squashed to [-1, 1]
            nn.ConvTranspose2d(64, 1, 4, 2, 1),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        # Reshape flat noise into a 1x1 "image" with noise_dim channels
        return self.net(z.view(z.size(0), z.size(1), 1, 1))
class MNISTDiscriminator(nn.Module):
    """Small DCGAN-style discriminator for 28x28 grayscale digits.

    Input: [N, 1, 28, 28]; output: [N] probabilities in (0, 1).
    """

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            # [N, 1, 28, 28] -> [N, 64, 14, 14]
            nn.Conv2d(1, 64, 4, 2, 1),
            nn.LeakyReLU(0.2, True),
            # -> [N, 128, 7, 7]
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            # -> [N, 1, 1, 1]
            nn.Conv2d(128, 1, 7, 1, 0),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Flatten to [N]. The original `.squeeze()` also removed the batch
        # dimension when N == 1, yielding a 0-dim tensor that breaks BCELoss
        # against a [1]-shaped label tensor.
        return self.net(x).view(-1)
# Training
def train_mnist_gan(epochs=50):
    """Train the MNIST DCGAN and save a grid of samples after every epoch.

    Uses label smoothing (0.9 / 0.1) and the non-saturating generator loss.
    Sample grids are rendered from a fixed noise batch so epochs are directly
    comparable, and saved under ./gan_samples/. Returns the trained generator.
    """
    import os  # local import: only needed for the sample directory
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    G = MNISTGenerator().to(device)
    D = MNISTDiscriminator().to(device)
    g_optimizer = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    d_optimizer = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
    criterion = nn.BCELoss()
    # Fixed noise -> comparable sample grids across epochs
    fixed_noise = torch.randn(64, 100, device=device)
    # NOTE(fix): plt.savefig below fails if the output directory is missing
    os.makedirs('gan_samples', exist_ok=True)
    for epoch in range(epochs):
        for batch_idx, (real_images, _) in enumerate(dataloader):
            real_images = real_images.to(device)
            batch_size = real_images.size(0)
            # Smoothed labels keep D from becoming over-confident
            real_labels = torch.ones(batch_size, device=device) * 0.9
            fake_labels = torch.zeros(batch_size, device=device) + 0.1
            # ========== Train the discriminator ==========
            d_optimizer.zero_grad()
            # Real batch
            real_output = D(real_images)
            d_loss_real = criterion(real_output, real_labels)
            # Fake batch (detached so G receives no gradient here)
            noise = torch.randn(batch_size, 100, device=device)
            fake_images = G(noise).detach()
            fake_output = D(fake_images)
            d_loss_fake = criterion(fake_output, fake_labels)
            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            d_optimizer.step()
            # ========== Train the generator ==========
            g_optimizer.zero_grad()
            noise = torch.randn(batch_size, 100, device=device)
            fake_images = G(noise)
            fake_output = D(fake_images)
            # Non-saturating loss: target "real" so G learns to fool D
            g_loss = criterion(fake_output, real_labels)
            g_loss.backward()
            g_optimizer.step()
            if batch_idx % 100 == 0:
                print(f"Epoch {epoch}, Batch {batch_idx}: "
                      f"D_loss={d_loss.item():.4f}, G_loss={g_loss.item():.4f}")
        # Save a grid of generated digits at the end of each epoch
        with torch.no_grad():
            fake_samples = G(fixed_noise).cpu()
        fig, axes = plt.subplots(8, 8, figsize=(10, 10))
        for idx, ax in enumerate(axes.flat):
            ax.imshow(fake_samples[idx].squeeze(), cmap='gray', vmin=-1, vmax=1)
            ax.axis('off')
        plt.suptitle(f'Epoch {epoch}', fontsize=14)
        plt.tight_layout()
        plt.savefig(f'gan_samples/epoch_{epoch:03d}.png')
        plt.close()
    return G
# 运行
# G = train_mnist_gan(epochs=50)
7.2 插值实验:潜在空间的平滑性
def latent_space_interpolation(G, device='cuda', num_steps=10):
    """Decode a straight line between two random latent codes.

    Smooth visual morphing along the path is evidence that the generator's
    latent space is continuous rather than memorized. Displays one row of
    num_steps images via matplotlib.
    """
    G.eval()
    # Two random endpoints in latent space
    start = torch.randn(1, 100, device=device)
    end = torch.randn(1, 100, device=device)
    # Interpolation coefficients from 0 to 1
    weights = torch.linspace(0, 1, num_steps)
    fig, axes = plt.subplots(1, num_steps, figsize=(20, 2))
    with torch.no_grad():
        for col, weight in enumerate(weights):
            # lerp(end, start, w) == w*start + (1-w)*end,
            # the same path as alpha*z1 + (1-alpha)*z2
            latent = torch.lerp(end, start, weight)
            picture = G(latent).cpu().squeeze()
            axes[col].imshow(picture, cmap='gray', vmin=-1, vmax=1)
            axes[col].set_title(f'α={weight:.2f}')
            axes[col].axis('off')
    plt.suptitle('Latent Space Interpolation')
    plt.tight_layout()
    plt.show()
# latent_space_interpolation(G)
8. 实战:人脸动漫化
8.1 CycleGAN:无配对图像翻译
class ResidualBlock(nn.Module):
    """CycleGAN residual block: two InstanceNorm'd 3x3 convs plus a skip add."""

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.in1 = nn.InstanceNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.in2 = nn.InstanceNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        skip = x
        y = self.relu(self.in1(self.conv1(x)))
        y = self.in2(self.conv2(y))
        # Residual connection: learn only the delta on top of the input
        return y + skip
class Generator(nn.Module):
    """CycleGAN generator: downsampling encoder, residual trunk, upsampling decoder."""

    def __init__(self, in_channels=3, num_residual=9):
        super().__init__()

        def conv_in_relu(in_c, out_c, kernel, **kwargs):
            # Conv -> InstanceNorm -> ReLU triple used throughout the encoder
            return [nn.Conv2d(in_c, out_c, kernel, **kwargs),
                    nn.InstanceNorm2d(out_c),
                    nn.ReLU(True)]

        # Encoder: 7x7 stem then two stride-2 halvings (C: in -> 64 -> 128 -> 256)
        self.down = nn.Sequential(
            *conv_in_relu(in_channels, 64, 7, padding=3, padding_mode='reflect'),
            *conv_in_relu(64, 128, 3, stride=2, padding=1),
            *conv_in_relu(128, 256, 3, stride=2, padding=1),
        )
        # Translation trunk: stack of residual blocks at 256 channels
        self.residuals = nn.Sequential(*(ResidualBlock(256) for _ in range(num_residual)))
        # Decoder: two stride-2 transposed convs back up, then 7x7 to image space
        self.up = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(True),
            nn.Conv2d(64, in_channels, 7, padding=3, padding_mode='reflect'),
            nn.Tanh()  # output image values in [-1, 1]
        )

    def forward(self, x):
        encoded = self.down(x)
        translated = self.residuals(encoded)
        return self.up(translated)
class CycleGAN(nn.Module):
    """Container for the four CycleGAN networks.

    G:  X -> Y (photo -> anime)      F:  Y -> X (anime -> photo)
    Dx: discriminates domain X        Dy: discriminates domain Y
    """

    def __init__(self):
        super().__init__()
        self.G = Generator()       # photo -> anime
        self.F = Generator()       # anime -> photo
        self.Dx = Discriminator()  # judges the photo domain
        self.Dy = Discriminator()  # judges the anime domain

    def compute_cycle_loss(self, real_x, real_y):
        """L1 cycle-consistency: F(G(x)) ≈ x and G(F(y)) ≈ y."""
        reconstructed_x = self.F(self.G(real_x))
        reconstructed_y = self.G(self.F(real_y))
        return ((reconstructed_x - real_x).abs().mean()
                + (reconstructed_y - real_y).abs().mean())

    def compute_identity_loss(self, real_x, real_y):
        """L1 identity terms: G(y) ≈ y and F(x) ≈ x (discourages color drift)."""
        return ((self.G(real_y) - real_y).abs().mean()
                + (self.F(real_x) - real_x).abs().mean())
class Discriminator(nn.Module):
    """PatchGAN discriminator: emits a realness score per image patch."""

    def __init__(self, in_channels=3):
        super().__init__()

        def halve(in_c, out_c, normalize=True):
            # Stride-2 conv block; InstanceNorm is skipped on the first block
            block = [nn.Conv2d(in_c, out_c, 4, stride=2, padding=1)]
            if normalize:
                block.append(nn.InstanceNorm2d(out_c))
            block.append(nn.LeakyReLU(0.2, True))
            return block

        self.model = nn.Sequential(
            *halve(in_channels, 64, normalize=False),
            *halve(64, 128),
            *halve(128, 256),
            *halve(256, 512),
            nn.Conv2d(512, 1, 4, padding=1)  # per-patch score map (no sigmoid)
        )

    def forward(self, x):
        return self.model(x)
# CycleGAN training (simplified)
def train_cyclegan(dataloader_x, dataloader_y, epochs=200):
    """Jointly train both CycleGAN generators and discriminators.

    Loss = least-squares adversarial terms + 10 * cycle-consistency
    + 5 * identity. NOTE(review): device is hard-coded to CUDA, and zip()
    truncates each epoch to the shorter of the two dataloaders.
    """
    device = torch.device('cuda')
    model = CycleGAN().to(device)
    # One optimizer over both generators, one over both discriminators
    g_optimizer = optim.Adam(list(model.G.parameters()) + list(model.F.parameters()),
                             lr=0.0002, betas=(0.5, 0.999))
    d_optimizer = optim.Adam(list(model.Dx.parameters()) + list(model.Dy.parameters()),
                             lr=0.0002, betas=(0.5, 0.999))
    for epoch in range(epochs):
        for (real_x, _), (real_y, _) in zip(dataloader_x, dataloader_y):
            real_x, real_y = real_x.to(device), real_y.to(device)
            # Forward translations in both directions
            fake_y = model.G(real_x)
            fake_x = model.F(real_y)
            # Cycle reconstructions (computed here for reference; the loss
            # helper below recomputes them internally)
            rec_x = model.F(fake_y)
            rec_y = model.G(fake_x)
            # ========== Discriminator loss (least-squares GAN) ==========
            # Fakes are detached so D updates do not backprop into G/F
            d_x_real = model.Dx(real_x)
            d_x_fake = model.Dx(fake_x.detach())
            d_y_real = model.Dy(real_y)
            d_y_fake = model.Dy(fake_y.detach())
            d_loss = ((d_x_real - 1)**2).mean() + (d_x_fake**2).mean() + \
                     ((d_y_real - 1)**2).mean() + (d_y_fake**2).mean()
            # ========== Generator loss ==========
            # Adversarial: push D's score on the fakes toward 1
            g_loss_gan = ((model.Dy(fake_y) - 1)**2).mean() + \
                         ((model.Dx(fake_x) - 1)**2).mean()
            cycle_loss = model.compute_cycle_loss(real_x, real_y)
            identity_loss = model.compute_identity_loss(real_x, real_y)
            g_loss = g_loss_gan + 10 * cycle_loss + 5 * identity_loss
            # Update D first, then the generators
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()
            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()
        print(f"Epoch {epoch}: D_loss={d_loss.item():.4f}, G_loss={g_loss.item():.4f}")
        # Periodically keep a qualitative sample batch
        if epoch % 10 == 0:
            with torch.no_grad():
                sample = model.G(real_x[:8])
                # save / visualize the sample here...
9. 小结
| 知识点 | 核心要点 |
|---|---|
| GAN原理 | 生成器 vs 判别器的Minimax博弈 |
| 生成器 | 噪声→图像,转置卷积上采样 |
| 判别器 | 图像→概率,卷积下采样 |
| DCGAN | 卷积GAN,BatchNorm稳定训练 |
| 模式坍塌 | 生成器只生成少数模式 |
| 解决坍塌 | 标签平滑、非饱和损失、WGAN、梯度惩罚 |
| WGAN | Wasserstein距离,训练更稳定 |
| StyleGAN | 映射网络+AdaIN,精细控制风格 |
| CycleGAN | 循环一致性,无配对数据翻译 |
| 应用 | 数据增强、风格迁移、超分辨率、图像修复 |
课后练习
-
损失对比:在MNIST上对比原始GAN、WGAN、WGAN-GP的损失曲线和生成质量。
-
插值可视化:训练MNIST GAN后,在潜在空间做2D网格插值,观察数字的连续变化。
-
条件GAN:实现CGAN(Conditional GAN),输入类别标签+噪声,控制生成特定数字。
-
模式坍塌检测:实现生成样本的多样性度量(如FID分数),监控训练过程。
-
挑战:实现StyleGAN2-ADA,用自适应判别器增强(ADA)训练自定义数据集。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)