Lecture 19: Generative Adversarial Networks: AI's Imagination


Table of Contents

  1. GANs: Teaching AI to "Forge"
  2. Generator and Discriminator: A Game-Theoretic View
  3. DCGAN: Deep Convolutional GAN
  4. Training Tricks and Mode Collapse
  5. The Evolution of Loss Functions
  6. StyleGAN: Controlling the Style of Generation
  7. Hands-On: Generating Handwritten Digits
  8. Hands-On: Turning Faces into Anime
  9. Summary

1. GANs: Teaching AI to "Forge"

1.1 What Is a Generative Adversarial Network

Intuition: a GAN is a game between a counterfeiter (the generator) and the police (the discriminator).

  • The counterfeiter keeps refining its fake bills, trying to fool the police
  • The police keep sharpening their detection skills, trying to catch the fakes
  • Eventually the counterfeiter's craft becomes so good that the police can no longer tell real from fake → the counterfeits are indistinguishable from genuine bills
Real data distribution (e.g. real face photos)
         │
    ┌────┴────┐
    │         │
Discriminator (D)   Generator (G)
    │         │
    └────┬────┘
         │
    Equilibrium of the game: samples from G are indistinguishable from real data

1.2 The Core Value of GANs

| Capability | Application | Example |
|---|---|---|
| Data generation | Augmenting training data | Generating medical images of rare conditions |
| Image-to-image translation | Style transfer | Photo → oil painting, day → night |
| Super-resolution | Image enhancement | Blurry photo → high definition |
| Image inpainting | Filling in missing regions | Restoring old photos, removing watermarks |
| Feature learning | Unsupervised pre-training | Using GAN features for classification |

2. Generator and Discriminator: A Game-Theoretic View

2.1 The Mathematical Framework

Generator G: takes random noise z as input and outputs a generated image G(z).

Discriminator D: takes an image x as input and outputs D(x), the probability that x is real.

The objective (a minimax game)

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$$

Intuition

  • D's goal: maximize V, i.e. push D(x)→1 (real images judged real) and D(G(z))→0 (fakes judged fake)
  • G's goal: minimize V, i.e. push D(G(z))→1 (fool the discriminator)
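
One consequence worth keeping in mind (the standard result from the original GAN paper, stated here without proof): for a fixed generator, V(D, G) is maximized pointwise by

$$D^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}$$

and substituting D* back in gives

$$\max_D V(D, G) = -\log 4 + 2 \cdot \mathrm{JSD}(p_{data} \,\|\, p_g),$$

which is minimized exactly when p_g = p_{data}. That is the equilibrium the toy experiment below converges toward.
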
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# ========== Minimal GAN: generate 1D data with fully connected layers ==========
# Goal: learn to generate data that follows the normal distribution N(0, 1)

class SimpleGenerator(nn.Module):
    """
    Minimal generator: noise → fully connected layers → generated data
    """
    def __init__(self, noise_dim=10, hidden_dim=50):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(noise_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)  # output is 1-dimensional
        )
        self.noise_dim = noise_dim
    
    def forward(self, z):
        return self.net(z)
    
    def sample(self, num_samples, device='cpu'):
        """生成样本"""
        z = torch.randn(num_samples, self.noise_dim, device=device)
        return self.forward(z)

class SimpleDiscriminator(nn.Module):
    """
    Minimal discriminator: data → fully connected layers → probability of being real
    """
    def __init__(self, hidden_dim=50):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()  # output a probability
        )
    
    def forward(self, x):
        return self.net(x)

# Train the minimal GAN
def train_simple_gan():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    G = SimpleGenerator().to(device)
    D = SimpleDiscriminator().to(device)
    
    g_optimizer = torch.optim.Adam(G.parameters(), lr=0.001)
    d_optimizer = torch.optim.Adam(D.parameters(), lr=0.001)
    
    # Real data: N(0, 1)
    def sample_real_data(batch_size):
        return torch.randn(batch_size, 1, device=device)
    
    epochs = 5000
    batch_size = 64
    
    # Track how the generated distribution evolves
    snapshots = []
    
    for epoch in range(epochs):
        # ========== Train the discriminator ==========
        for _ in range(5):  # give the discriminator a few extra steps
            d_optimizer.zero_grad()
            
            # Real data
            real_data = sample_real_data(batch_size)
            real_pred = D(real_data)
            d_loss_real = -torch.log(real_pred + 1e-8).mean()
            
            # Generated data
            fake_data = G.sample(batch_size, device).detach()
            fake_pred = D(fake_data)
            d_loss_fake = -torch.log(1 - fake_pred + 1e-8).mean()
            
            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            d_optimizer.step()
        
        # ========== Train the generator ==========
        g_optimizer.zero_grad()
        
        fake_data = G.sample(batch_size, device)
        fake_pred = D(fake_data)
        
        # G wants D(fake) to approach 1
        g_loss = -torch.log(fake_pred + 1e-8).mean()
        
        g_loss.backward()
        g_optimizer.step()
        
        # Logging
        if epoch % 500 == 0:
            with torch.no_grad():
                samples = G.sample(1000, device).cpu().numpy()
                snapshots.append((epoch, samples))
            print(f"Epoch {epoch}: D_loss={d_loss.item():.4f}, G_loss={g_loss.item():.4f}")
    
    # Visualize how the distribution evolves
    fig, axes = plt.subplots(2, 5, figsize=(20, 8))
    axes = axes.flatten()
    
    for idx, (epoch, samples) in enumerate(snapshots):
        ax = axes[idx]
        ax.hist(samples, bins=30, density=True, alpha=0.7, color='blue', label='Generated')
        
        # Target distribution
        x_range = np.linspace(-3, 3, 100)
        real_pdf = (1/np.sqrt(2*np.pi)) * np.exp(-0.5 * x_range**2)
        ax.plot(x_range, real_pdf, 'r-', linewidth=2, label='Real N(0,1)')
        
        ax.set_title(f'Epoch {epoch}')
        ax.set_xlim(-3, 3)
        ax.set_ylim(0, 0.5)
        if idx == 0:
            ax.legend()
    
    plt.suptitle('GAN Distribution Evolution', fontsize=14)
    plt.tight_layout()
    plt.show()

# Run
# train_simple_gan()

2.2 Visualizing the Training Dynamics

Epoch 0:    generated distribution is a mess (uniform or a single off-target peak)
            ↓
Epoch 500:  samples start to cluster, but in the wrong place
            ↓
Epoch 1000: close to the target distribution, but the variance is off
            ↓
Epoch 2000: roughly matches, with some remaining bias
            ↓
Epoch 5000: almost exactly overlaps N(0, 1)!
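
Beyond eyeballing the histograms, a quick quantitative check is to compare the generated samples against the target N(0, 1), for example with a one-sample Kolmogorov-Smirnov test. A minimal sketch, assuming SciPy is installed and G is a trained SimpleGenerator from the code above:

from scipy import stats

def check_against_normal(G, num_samples=2000, device='cpu'):
    """Compare generated samples to N(0, 1): summary statistics plus a KS test."""
    with torch.no_grad():
        samples = G.sample(num_samples, device).cpu().numpy().ravel()
    stat, p_value = stats.kstest(samples, 'norm')  # one-sample KS test against the standard normal
    print(f"mean={samples.mean():.3f}, std={samples.std():.3f}, "
          f"KS statistic={stat:.3f}, p={p_value:.3f}")
    return stat, p_value

# check_against_normal(G)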

3. DCGAN: Deep Convolutional GAN

3.1 From Fully Connected Layers to Convolutions

The problem: a fully connected GAN can only generate simple data (such as 1D distributions); images need a convolutional architecture.

The core of DCGAN: use transposed convolutions (ConvTranspose) to upsample noise into an image, and regular convolutions to judge real vs. fake.
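
A quick sanity check on how transposed convolutions grow the spatial size (output = (input - 1) * stride - 2 * padding + kernel); the two configurations below are exactly the ones used in the generator that follows:

x = torch.randn(1, 100, 1, 1)
up1 = nn.ConvTranspose2d(100, 512, kernel_size=4, stride=1, padding=0)  # 1x1 -> 4x4
up2 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)  # 4x4 -> 8x8
print(up1(x).shape)        # torch.Size([1, 512, 4, 4])
print(up2(up1(x)).shape)   # torch.Size([1, 256, 8, 8])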

class DCGANGenerator(nn.Module):
    """
    DCGAN generator: noise → transposed convolutions → image
    Input:  [N, 100, 1, 1] (100-dim noise)
    Output: [N, 3, 64, 64] (64x64 RGB image)
    """
    def __init__(self, noise_dim=100, num_channels=64, image_channels=3):
        super().__init__()
        self.noise_dim = noise_dim
        
        # Upsample stage by stage: 1x1 → 4x4 → 8x8 → 16x16 → 32x32 → 64x64
        self.net = nn.Sequential(
            # Input: [N, 100, 1, 1]
            nn.ConvTranspose2d(noise_dim, num_channels * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(num_channels * 8),
            nn.ReLU(True),
            # Output: [N, 512, 4, 4]
            
            nn.ConvTranspose2d(num_channels * 8, num_channels * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(num_channels * 4),
            nn.ReLU(True),
            # Output: [N, 256, 8, 8]
            
            nn.ConvTranspose2d(num_channels * 4, num_channels * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(num_channels * 2),
            nn.ReLU(True),
            # Output: [N, 128, 16, 16]
            
            nn.ConvTranspose2d(num_channels * 2, num_channels, 4, 2, 1, bias=False),
            nn.BatchNorm2d(num_channels),
            nn.ReLU(True),
            # Output: [N, 64, 32, 32]
            
            nn.ConvTranspose2d(num_channels, image_channels, 4, 2, 1, bias=False),
            nn.Tanh()  # output in [-1, 1]
            # Output: [N, 3, 64, 64]
        )
    
    def forward(self, z):
        # z: [N, noise_dim] → [N, noise_dim, 1, 1]
        z = z.view(z.size(0), z.size(1), 1, 1)
        return self.net(z)
    
    def sample(self, num_samples, device='cpu'):
        z = torch.randn(num_samples, self.noise_dim, device=device)
        return self.forward(z)

class DCGANDiscriminator(nn.Module):
    """
    DCGAN discriminator: image → convolutions → probability of being real
    Input:  [N, 3, 64, 64]
    Output: [N] (one probability per image)
    """
    def __init__(self, num_channels=64, image_channels=3):
        super().__init__()
        
        # Downsample stage by stage: 64x64 → 32x32 → 16x16 → 8x8 → 4x4 → 1x1
        self.net = nn.Sequential(
            # Input: [N, 3, 64, 64]
            nn.Conv2d(image_channels, num_channels, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # Output: [N, 64, 32, 32]
            
            nn.Conv2d(num_channels, num_channels * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(num_channels * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # Output: [N, 128, 16, 16]
            
            nn.Conv2d(num_channels * 2, num_channels * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(num_channels * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # Output: [N, 256, 8, 8]
            
            nn.Conv2d(num_channels * 4, num_channels * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(num_channels * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # Output: [N, 512, 4, 4]
            
            nn.Conv2d(num_channels * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
            # Output: [N, 1, 1, 1] → flattened to [N] in forward
        )
    
    def forward(self, x):
        return self.net(x).view(-1)  # flatten to [N]; a plain squeeze() would drop the batch dim when N == 1

# Quick shape check
G = DCGANGenerator()
D = DCGANDiscriminator()

z = torch.randn(2, 100)
fake_img = G(z)
real_pred = D(fake_img)

print(f"生成器输入: {z.shape}")
print(f"生成器输出: {fake_img.shape}")  # [2, 3, 64, 64]
print(f"判别器输出: {real_pred.shape}")  # [2, 1]

3.2 DCGAN Design Principles

| Design choice | Reason | Effect |
|---|---|---|
| Transposed convolutions | Upsampling to generate images | Learned upsampling, more flexible than fixed interpolation |
| BatchNorm | Stabilizes training | Helps prevent mode collapse |
| ReLU (G) / LeakyReLU (D) | Avoids vanishing gradients | Negative inputs keep a small gradient |
| Tanh output | Normalizes outputs to [-1, 1] | Matches the normalized inputs |
| No fully connected layers | Preserves spatial structure | Better spatial consistency |
| Symmetric architecture | G and D mirror each other | Balanced training |
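
One DCGAN detail not captured in the table is weight initialization: the original paper draws all weights from a normal distribution with standard deviation 0.02. A small helper in that spirit (a sketch, to be applied with G.apply(weights_init) and D.apply(weights_init)):

def weights_init(m):
    """DCGAN-style init: conv weights ~ N(0, 0.02), BatchNorm gamma ~ N(1, 0.02), bias = 0."""
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)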

4. Training Tricks and Mode Collapse

4.1 Mode Collapse

The problem: the generator learns to produce only a few kinds of samples and ignores the remaining modes.

Healthy GAN:     generates all digits 0,1,2,3,4,5,6,7,8,9
            ↓
Mode collapse:   only generates the digit "1" (the easiest way to fool the discriminator)
            ↓
Severe collapse: all samples are nearly identical
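
To make "nearly identical" concrete, here is a toy illustration (my own sketch, not part of the trainer that follows) of measuring diversity with the mean pairwise distance; the check_mode_collapse method later in this section applies the same idea to real generator outputs:

healthy = torch.randn(1000, 784)                     # diverse "samples"
collapsed = torch.randn(1, 784).repeat(1000, 1)      # (almost) the same sample repeated
collapsed = collapsed + 0.001 * torch.randn_like(collapsed)

for name, s in [('healthy', healthy), ('collapsed', collapsed)]:
    avg_dist = torch.cdist(s, s).mean().item()       # mean pairwise distance
    print(f"{name}: average pairwise distance = {avg_dist:.4f}")
# The collapsed set's average distance is close to zero, which is the warning signal.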

4.2 Tricks for Fighting Mode Collapse

import torch.optim as optim  # needed by the optimizers below

class GANTrainer:
    """
    A more stable GAN trainer that bundles several anti-collapse tricks.
    """
    def __init__(self, G, D, device='cuda'):
        self.G = G.to(device)
        self.D = D.to(device)
        self.device = device
        
        # Trick 1: separate optimizers, so the G and D learning rates can be tuned independently (beta1=0.5 is standard for GANs)
        self.g_optimizer = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.d_optimizer = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
        
        # Trick 2: label smoothing (keeps D from becoming over-confident)
        self.real_label = 0.9  # instead of 1.0
        self.fake_label = 0.1  # instead of 0.0
    
    def train_step(self, real_images):
        batch_size = real_images.size(0)
        
        # ========== Train the discriminator ==========
        for _ in range(1):  # Trick 3: the D/G update ratio (sometimes D needs extra steps)
            self.d_optimizer.zero_grad()
            
            # Real images
            real_pred = self.D(real_images)
            d_loss_real = nn.functional.binary_cross_entropy(
                real_pred, 
                torch.full_like(real_pred, self.real_label)
            )
            
            # Generated images
            z = torch.randn(batch_size, self.G.noise_dim, device=self.device)
            fake_images = self.G(z).detach()
            fake_pred = self.D(fake_images)
            d_loss_fake = nn.functional.binary_cross_entropy(
                fake_pred,
                torch.full_like(fake_pred, self.fake_label)
            )
            
            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            self.d_optimizer.step()
        
        # ========== Train the generator ==========
        self.g_optimizer.zero_grad()
        
        # Trick 4: the non-saturating loss (-log D(G(z)) instead of log(1 - D(G(z))))
        z = torch.randn(batch_size, self.G.noise_dim, device=self.device)
        fake_images = self.G(z)
        fake_pred = self.D(fake_images)
        
        # Non-saturating loss: G wants D(fake) to approach 1
        g_loss = nn.functional.binary_cross_entropy(
            fake_pred,
            torch.full_like(fake_pred, self.real_label)  # the target is "real"!
        )
        
        g_loss.backward()
        self.g_optimizer.step()
        
        return {
            'd_loss': d_loss.item(),
            'g_loss': g_loss.item(),
            'real_acc': (real_pred > 0.5).float().mean().item(),
            'fake_acc': (fake_pred < 0.5).float().mean().item()
        }
    
    # Trick 5: monitor mode collapse (check the diversity of generated samples)
    def check_mode_collapse(self, num_samples=1000):
        with torch.no_grad():
            z = torch.randn(num_samples, self.G.noise_dim, device=self.device)
            samples = self.G(z)
            
            # Mean pairwise distance between samples
            distances = torch.cdist(samples.view(num_samples, -1), 
                                 samples.view(num_samples, -1))
            avg_distance = distances.mean().item()
            
            # A very small distance indicates mode collapse
            print(f"Average sample distance: {avg_distance:.4f}")
            if avg_distance < 0.1:
                print("⚠️ Warning: Possible mode collapse detected!")
            
            return avg_distance

# Trick 6: use the Wasserstein distance (WGAN)
class WGANTrainer(GANTrainer):
    """
    WGAN: replaces the JS divergence with the Wasserstein distance for more stable training.
    Note: the WGAN critic should output a raw score, i.e. no final Sigmoid (unlike the discriminators above).
    """
    def __init__(self, G, D, device='cuda'):
        super().__init__(G, D, device)
        
        # The WGAN paper uses RMSprop/SGD (Adam's momentum was found to destabilize training)
        self.g_optimizer = optim.RMSprop(G.parameters(), lr=0.00005)
        self.d_optimizer = optim.RMSprop(D.parameters(), lr=0.00005)
        
        # Weight clipping (the original WGAN approach)
        self.clip_value = 0.01
    
    def train_step(self, real_images):
        batch_size = real_images.size(0)
        
        # Train the critic (several steps per generator step, as WGAN requires)
        for _ in range(5):
            self.d_optimizer.zero_grad()
            
            real_pred = self.D(real_images)
            
            z = torch.randn(batch_size, self.G.noise_dim, device=self.device)
            fake_images = self.G(z).detach()
            fake_pred = self.D(fake_images)
            
            # Wasserstein objective: maximize E[D(real)] - E[D(fake)]
            d_loss = -(real_pred.mean() - fake_pred.mean())
            d_loss.backward()
            self.d_optimizer.step()
            
            # Weight clipping (enforces the Lipschitz constraint)
            for p in self.D.parameters():
                p.data.clamp_(-self.clip_value, self.clip_value)
        
        # Train the generator
        self.g_optimizer.zero_grad()
        
        z = torch.randn(batch_size, self.G.noise_dim, device=self.device)
        fake_images = self.G(z)
        fake_pred = self.D(fake_images)
        
        # G maximizes E[D(fake)]
        g_loss = -fake_pred.mean()
        g_loss.backward()
        self.g_optimizer.step()
        
        return {'d_loss': d_loss.item(), 'g_loss': g_loss.item()}

5. The Evolution of Loss Functions

5.1 How GAN Losses Evolved

| Variant | Loss | Issue / key idea | Benefit |
|---|---|---|---|
| Original GAN | JS divergence | Vanishing gradients, mode collapse | Baseline |
| WGAN | Wasserstein distance | Requires weight clipping | More stable training |
| WGAN-GP | + gradient penalty | Weight clipping limits capacity | Better Lipschitz constraint |
| LSGAN | Least-squares loss | Avoids sigmoid saturation | Smoother gradients |
| BEGAN | BEGAN loss | Autoencoder-based discriminator | Automatic balancing |
| SAGAN | Hinge loss + attention | Long-range dependencies | Better detail |
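
For reference, minimal sketches of the LSGAN and hinge losses mentioned in the table, written in my own condensed form against raw (pre-sigmoid) discriminator scores rather than the probability outputs used elsewhere in this lecture:

# LSGAN: least-squares loss on raw discriminator scores
def lsgan_d_loss(real_scores, fake_scores):
    return 0.5 * ((real_scores - 1) ** 2).mean() + 0.5 * (fake_scores ** 2).mean()

def lsgan_g_loss(fake_scores):
    return 0.5 * ((fake_scores - 1) ** 2).mean()

# Hinge loss (as used in SAGAN/BigGAN), also on raw scores
def hinge_d_loss(real_scores, fake_scores):
    return torch.relu(1 - real_scores).mean() + torch.relu(1 + fake_scores).mean()

def hinge_g_loss(fake_scores):
    return -fake_scores.mean()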

5.2 A WGAN-GP Implementation

class WGANGPTrainer(GANTrainer):
    """
    WGAN-GP: replaces weight clipping with a gradient penalty.
    (The paper also recommends avoiding BatchNorm in the critic, since the penalty is computed per sample.)
    """
    def __init__(self, G, D, device='cuda', lambda_gp=10):
        super().__init__(G, D, device)
        self.lambda_gp = lambda_gp
        
        self.g_optimizer = optim.Adam(G.parameters(), lr=0.0001, betas=(0.0, 0.9))
        self.d_optimizer = optim.Adam(D.parameters(), lr=0.0001, betas=(0.0, 0.9))
    
    def compute_gradient_penalty(self, real_images, fake_images):
        """
        Gradient penalty: pushes the critic's gradient norm toward 1.
        """
        batch_size = real_images.size(0)
        
        # Random interpolation between real and fake samples
        alpha = torch.rand(batch_size, 1, 1, 1, device=self.device)
        interpolates = alpha * real_images + (1 - alpha) * fake_images
        interpolates.requires_grad_(True)
        
        # Critic output on the interpolated samples
        d_interpolates = self.D(interpolates)
        
        # Gradient of the critic output w.r.t. the interpolates
        gradients = torch.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(d_interpolates),
            create_graph=True,
            retain_graph=True
        )[0]
        
        # Per-sample gradient norm
        gradients = gradients.view(batch_size, -1)
        gradient_norm = gradients.norm(2, dim=1)
        
        # Penalty term: (||grad|| - 1)^2
        penalty = ((gradient_norm - 1) ** 2).mean()
        
        return penalty
    
    def train_step(self, real_images):
        batch_size = real_images.size(0)
        
        # Train the critic
        for _ in range(5):
            self.d_optimizer.zero_grad()
            
            real_pred = self.D(real_images)
            
            z = torch.randn(batch_size, self.G.noise_dim, device=self.device)
            fake_images = self.G(z).detach()
            fake_pred = self.D(fake_images)
            
            # Wasserstein objective
            d_loss = -(real_pred.mean() - fake_pred.mean())
            
            # Gradient penalty
            gp = self.compute_gradient_penalty(real_images, fake_images)
            d_loss += self.lambda_gp * gp
            
            d_loss.backward()
            self.d_optimizer.step()
        
        # Train the generator
        self.g_optimizer.zero_grad()
        
        z = torch.randn(batch_size, self.G.noise_dim, device=self.device)
        fake_images = self.G(z)
        fake_pred = self.D(fake_images)
        
        g_loss = -fake_pred.mean()
        g_loss.backward()
        self.g_optimizer.step()
        
        return {
            'd_loss': d_loss.item(),
            'g_loss': g_loss.item(),
            'gp': gp.item()
        }

6. StyleGAN: Controlling the Style of Generation

6.1 Core Innovations

The problem: in a conventional GAN the noise input controls the whole image at once, so there is no fine-grained control over stylistic attributes such as hair color or face shape.

StyleGAN's answer: introduce a mapping network and adaptive instance normalization (AdaIN). The noise is first mapped to a style vector, and different layers are injected with different styles.

class MappingNetwork(nn.Module):
    """
    Mapping network: maps random noise z into an intermediate latent space w.
    This is what enables style disentanglement.
    """
    def __init__(self, z_dim=512, w_dim=512, num_layers=8):
        super().__init__()
        
        layers = []
        for i in range(num_layers):
            in_dim = z_dim if i == 0 else w_dim
            layers.extend([
                nn.Linear(in_dim, w_dim),
                nn.LeakyReLU(0.2)
            ])
        
        self.net = nn.Sequential(*layers)
    
    def forward(self, z):
        # Normalize the input (guards against extreme values)
        z = z / (z.norm(dim=1, keepdim=True) + 1e-8)
        return self.net(z)

class AdaIN(nn.Module):
    """
    Adaptive instance normalization: modulates features with a style vector.
    """
    def __init__(self, num_features, w_dim=512):
        super().__init__()
        
        # Produce per-channel scale and bias from w
        self.style_scale = nn.Linear(w_dim, num_features)
        self.style_bias = nn.Linear(w_dim, num_features)
    
    def forward(self, x, w):
        # x: [N, C, H, W]
        # w: [N, w_dim] style vector
        
        # Instance normalization
        x_norm = nn.functional.instance_norm(x)
        
        # Apply the style
        gamma = self.style_scale(w).unsqueeze(2).unsqueeze(3)  # [N, C, 1, 1]
        beta = self.style_bias(w).unsqueeze(2).unsqueeze(3)
        
        return gamma * x_norm + beta

class StyleGANGenerator(nn.Module):
    """
    A simplified StyleGAN generator.
    """
    def __init__(self, z_dim=512, w_dim=512, image_size=64):
        super().__init__()
        
        self.mapping = MappingNetwork(z_dim, w_dim)
        
        # Start from a learned constant input (not from noise)
        self.const_input = nn.Parameter(torch.randn(1, 512, 4, 4))
        
        # Running average of w, used by the truncation trick in forward()
        # (in real training this would be updated as a moving average of mapped w vectors)
        self.register_buffer('w_avg', torch.zeros(w_dim))
        
        # Upsample block by block, injecting a (potentially different) style at every stage
        self.blocks = nn.ModuleList()
        in_channels = 512
        
        # 4x4 → 8x8 → 16x16 → 32x32 → 64x64
        for i in range(int(np.log2(image_size)) - 2):
            out_channels = min(512, 512 // (2 ** i))
            
            self.blocks.append(nn.ModuleDict({
                'conv1': nn.Conv2d(in_channels, out_channels, 3, padding=1),
                'adain1': AdaIN(out_channels, w_dim),
                'conv2': nn.Conv2d(out_channels, out_channels, 3, padding=1),
                'adain2': AdaIN(out_channels, w_dim),
                'upsample': nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
            }))
            
            in_channels = out_channels
        
        self.to_rgb = nn.Conv2d(in_channels, 3, 1)
        self.activation = nn.Tanh()
    
    def forward(self, z, truncation_psi=0.7):
        # Map z into the w space
        w = self.mapping(z)
        
        # Truncation trick (trades diversity for sample quality)
        if truncation_psi < 1.0:
            w = self.w_avg + truncation_psi * (w - self.w_avg)
        
        # Start from the learned constant
        x = self.const_input.repeat(z.size(0), 1, 1, 1)
        
        # Inject the style block by block (convolution first, then AdaIN,
        # so each AdaIN sees the channel count it was built for)
        for block in self.blocks:
            x = torch.relu(block['conv1'](x))
            x = block['adain1'](x, w)
            x = torch.relu(block['conv2'](x))
            x = block['adain2'](x, w)
            x = block['upsample'](x)
        
        return self.activation(self.to_rgb(x))

# Style mixing: control different layers with two different w vectors
def style_mixing(generator, z1, z2, layer_idx=4):
    """
    Use z1 to control the low-resolution layers (pose, face shape)
    and z2 to control the high-resolution layers (color, texture).
    """
    w1 = generator.mapping(z1)
    w2 = generator.mapping(z2)
    
    # The first layer_idx blocks use w1, the rest use w2.
    # A real implementation needs a forward pass that accepts multiple w's; see the sketch below.
    pass
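
A minimal sketch of what that modified forward pass could look like, written against the simplified StyleGANGenerator above (my own illustration, not the official StyleGAN code): the first few blocks are driven by w1 (coarse attributes), the remaining blocks by w2 (fine details).

def forward_with_two_styles(generator, z1, z2, crossover=2):
    """Style mixing for the simplified generator: blocks [0, crossover) use w1, the rest use w2."""
    w1 = generator.mapping(z1)
    w2 = generator.mapping(z2)
    
    x = generator.const_input.repeat(z1.size(0), 1, 1, 1)
    for i, block in enumerate(generator.blocks):
        w = w1 if i < crossover else w2
        x = torch.relu(block['conv1'](x))
        x = block['adain1'](x, w)
        x = torch.relu(block['conv2'](x))
        x = block['adain2'](x, w)
        x = block['upsample'](x)
    
    return generator.activation(generator.to_rgb(x))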

7. Hands-On: Generating Handwritten Digits

7.1 Full Training Code

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Data loading
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # normalize to [-1, 1]
])

mnist = datasets.MNIST('./data', train=True, download=True, transform=transform)
dataloader = DataLoader(mnist, batch_size=64, shuffle=True)

# A lightweight DCGAN tailored to MNIST (28x28 grayscale images)
class MNISTGenerator(nn.Module):
    def __init__(self, noise_dim=100):
        super().__init__()
        self.noise_dim = noise_dim
        
        self.net = nn.Sequential(
            # [N, 100, 1, 1] → [N, 128, 7, 7]
            nn.ConvTranspose2d(noise_dim, 128, 7, 1, 0),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            
            # → [N, 64, 14, 14]
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            
            # → [N, 1, 28, 28]
            nn.ConvTranspose2d(64, 1, 4, 2, 1),
            nn.Tanh()
        )
    
    def forward(self, z):
        z = z.view(z.size(0), z.size(1), 1, 1)
        return self.net(z)

class MNISTDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.net = nn.Sequential(
            # [N, 1, 28, 28] → [N, 64, 14, 14]
            nn.Conv2d(1, 64, 4, 2, 1),
            nn.LeakyReLU(0.2, True),
            
            # → [N, 128, 7, 7]
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            
            # → [N, 1, 1, 1]
            nn.Conv2d(128, 1, 7, 1, 0),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        return self.net(x).squeeze()

# Training
def train_mnist_gan(epochs=50):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    G = MNISTGenerator().to(device)
    D = MNISTDiscriminator().to(device)
    
    g_optimizer = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    d_optimizer = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
    
    criterion = nn.BCELoss()
    
    # Fixed noise for tracking generated images across epochs
    fixed_noise = torch.randn(64, 100, device=device)
    os.makedirs('gan_samples', exist_ok=True)  # output directory for the sample grids saved below
    
    for epoch in range(epochs):
        for batch_idx, (real_images, _) in enumerate(dataloader):
            real_images = real_images.to(device)
            batch_size = real_images.size(0)
            
            # Real and fake labels (softened so D does not become over-confident)
            real_labels = torch.ones(batch_size, device=device) * 0.9
            fake_labels = torch.zeros(batch_size, device=device) + 0.1
            
            # ========== Train the discriminator ==========
            d_optimizer.zero_grad()
            
            # Real images
            real_output = D(real_images)
            d_loss_real = criterion(real_output, real_labels)
            
            # Generated images
            noise = torch.randn(batch_size, 100, device=device)
            fake_images = G(noise).detach()
            fake_output = D(fake_images)
            d_loss_fake = criterion(fake_output, fake_labels)
            
            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            d_optimizer.step()
            
            # ========== Train the generator ==========
            g_optimizer.zero_grad()
            
            noise = torch.randn(batch_size, 100, device=device)
            fake_images = G(noise)
            fake_output = D(fake_images)
            
            # Non-saturating loss
            g_loss = criterion(fake_output, real_labels)  # the goal is to fool D
            
            g_loss.backward()
            g_optimizer.step()
            
            if batch_idx % 100 == 0:
                print(f"Epoch {epoch}, Batch {batch_idx}: "
                      f"D_loss={d_loss.item():.4f}, G_loss={g_loss.item():.4f}")
        
        # Save a grid of generated images after every epoch
        with torch.no_grad():
            fake_samples = G(fixed_noise).cpu()
            
            fig, axes = plt.subplots(8, 8, figsize=(10, 10))
            for idx, ax in enumerate(axes.flat):
                ax.imshow(fake_samples[idx].squeeze(), cmap='gray', vmin=-1, vmax=1)
                ax.axis('off')
            
            plt.suptitle(f'Epoch {epoch}', fontsize=14)
            plt.tight_layout()
            plt.savefig(f'gan_samples/epoch_{epoch:03d}.png')
            plt.close()
    
    return G

# Run
# G = train_mnist_gan(epochs=50)

7.2 Interpolation Experiment: Smoothness of the Latent Space

def latent_space_interpolation(G, device='cuda', num_steps=10):
    """
    在潜在空间中插值,观察生成图像的平滑变化
    """
    G.eval()
    
    # Two random noise vectors
    z1 = torch.randn(1, 100, device=device)
    z2 = torch.randn(1, 100, device=device)
    
    # Linear interpolation
    alphas = torch.linspace(0, 1, num_steps)
    
    fig, axes = plt.subplots(1, num_steps, figsize=(20, 2))
    
    with torch.no_grad():
        for idx, alpha in enumerate(alphas):
            z = alpha * z1 + (1 - alpha) * z2
            img = G(z).cpu().squeeze()
            
            axes[idx].imshow(img, cmap='gray', vmin=-1, vmax=1)
            axes[idx].set_title(f'α={alpha:.2f}')
            axes[idx].axis('off')
    
    plt.suptitle('Latent Space Interpolation')
    plt.tight_layout()
    plt.show()

# latent_space_interpolation(G)

8. Hands-On: Turning Faces into Anime

8.1 CycleGAN: Unpaired Image-to-Image Translation

class ResidualBlock(nn.Module):
    """CycleGAN的残差块"""
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.in1 = nn.InstanceNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.in2 = nn.InstanceNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x):
        residual = x
        out = self.relu(self.in1(self.conv1(x)))
        out = self.in2(self.conv2(out))
        return out + residual

class Generator(nn.Module):
    """
    CycleGAN generator: encoder → transformation → decoder
    """
    def __init__(self, in_channels=3, num_residual=9):
        super().__init__()
        
        # Downsampling encoder
        self.down = nn.Sequential(
            nn.Conv2d(in_channels, 64, 7, padding=3, padding_mode='reflect'),
            nn.InstanceNorm2d(64),
            nn.ReLU(True),
            
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(True),
            
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.ReLU(True)
        )
        
        # Residual transformation blocks
        self.residuals = nn.Sequential(*[
            ResidualBlock(256) for _ in range(num_residual)
        ])
        
        # Upsampling decoder
        self.up = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(True),
            
            nn.Conv2d(64, in_channels, 7, padding=3, padding_mode='reflect'),
            nn.Tanh()
        )
    
    def forward(self, x):
        x = self.down(x)
        x = self.residuals(x)
        x = self.up(x)
        return x

class CycleGAN(nn.Module):
    """
    CycleGAN: two generators + two discriminators
    G:  X→Y (photo → anime)
    F:  Y→X (anime → photo)
    Dx: judges whether an X-domain (photo) image is real
    Dy: judges whether a Y-domain (anime) image is real
    """
    def __init__(self):
        super().__init__()
        
        self.G = Generator()  # photo → anime
        self.F = Generator()  # anime → photo
        
        self.Dx = Discriminator()  # discriminates photos
        self.Dy = Discriminator()  # discriminates anime images
    
    def compute_cycle_loss(self, real_x, real_y):
        """
        Cycle-consistency loss: F(G(x)) ≈ x, G(F(y)) ≈ y
        """
        fake_y = self.G(real_x)
        rec_x = self.F(fake_y)
        
        fake_x = self.F(real_y)
        rec_y = self.G(fake_x)
        
        cycle_loss = (torch.abs(rec_x - real_x).mean() + 
                     torch.abs(rec_y - real_y).mean())
        
        return cycle_loss
    
    def compute_identity_loss(self, real_x, real_y):
        """
        Identity loss: G(y) ≈ y, F(x) ≈ x (discourages unnecessary color shifts)
        """
        idt_y = self.G(real_y)
        idt_x = self.F(real_x)
        
        identity_loss = (torch.abs(idt_y - real_y).mean() +
                        torch.abs(idt_x - real_x).mean())
        
        return identity_loss

class Discriminator(nn.Module):
    """
    PatchGAN discriminator: classifies each image patch as real or fake
    """
    def __init__(self, in_channels=3):
        super().__init__()
        
        def discriminator_block(in_c, out_c, normalize=True):
            layers = [nn.Conv2d(in_c, out_c, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_c))
            layers.append(nn.LeakyReLU(0.2, True))
            return layers
        
        self.model = nn.Sequential(
            *discriminator_block(in_channels, 64, normalize=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.Conv2d(512, 1, 4, padding=1)  # one score per patch
        )
    
    def forward(self, x):
        return self.model(x)
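
Unlike the earlier discriminators, this one returns a grid of per-patch scores rather than a single probability; with the layer stack above, a 256x256 input should yield a 15x15 score map. A quick shape check (my own sketch):

D_patch = Discriminator()
patch_scores = D_patch(torch.randn(2, 3, 256, 256))
print(patch_scores.shape)  # expected: torch.Size([2, 1, 15, 15]), one score per patch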

# CycleGAN training (simplified)
def train_cyclegan(dataloader_x, dataloader_y, epochs=200):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    model = CycleGAN().to(device)
    
    g_optimizer = optim.Adam(list(model.G.parameters()) + list(model.F.parameters()), 
                            lr=0.0002, betas=(0.5, 0.999))
    d_optimizer = optim.Adam(list(model.Dx.parameters()) + list(model.Dy.parameters()),
                            lr=0.0002, betas=(0.5, 0.999))
    
    for epoch in range(epochs):
        for (real_x, _), (real_y, _) in zip(dataloader_x, dataloader_y):
            real_x, real_y = real_x.to(device), real_y.to(device)
            
            # ========== Generator (G / F) losses ==========
            # Update G and F first: the adversarial loss backpropagates through Dx/Dy,
            # so the discriminators must not be stepped before this backward pass.
            fake_y = model.G(real_x)
            fake_x = model.F(real_y)
            
            g_loss_gan = ((model.Dy(fake_y) - 1)**2).mean() + \
                        ((model.Dx(fake_x) - 1)**2).mean()
            
            cycle_loss = model.compute_cycle_loss(real_x, real_y)
            identity_loss = model.compute_identity_loss(real_x, real_y)
            
            g_loss = g_loss_gan + 10 * cycle_loss + 5 * identity_loss
            
            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()
            
            # ========== Discriminator (Dx / Dy) losses, LSGAN-style ==========
            d_x_real = model.Dx(real_x)
            d_x_fake = model.Dx(fake_x.detach())
            d_y_real = model.Dy(real_y)
            d_y_fake = model.Dy(fake_y.detach())
            
            d_loss = ((d_x_real - 1)**2).mean() + (d_x_fake**2).mean() + \
                     ((d_y_real - 1)**2).mean() + (d_y_fake**2).mean()
            
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()
        
        print(f"Epoch {epoch}: D_loss={d_loss.item():.4f}, G_loss={g_loss.item():.4f}")
        
        # Save sample translations
        if epoch % 10 == 0:
            with torch.no_grad():
                sample = model.G(real_x[:8])
                # save / visualize the samples...

9. Summary

| Topic | Key points |
|---|---|
| GAN principle | A minimax game between generator and discriminator |
| Generator | Noise → image, upsampling with transposed convolutions |
| Discriminator | Image → probability, downsampling with convolutions |
| DCGAN | Convolutional GAN; BatchNorm stabilizes training |
| Mode collapse | The generator produces only a few modes |
| Fighting collapse | Label smoothing, non-saturating loss, WGAN, gradient penalty |
| WGAN | Wasserstein distance, more stable training |
| StyleGAN | Mapping network + AdaIN, fine-grained style control |
| CycleGAN | Cycle consistency, translation without paired data |
| Applications | Data augmentation, style transfer, super-resolution, inpainting |

Exercises

  1. Loss comparison: on MNIST, compare the loss curves and sample quality of the original GAN, WGAN, and WGAN-GP.

  2. Interpolation visualization: after training the MNIST GAN, interpolate over a 2D grid in latent space and observe how the digits change continuously.

  3. Conditional GAN: implement a CGAN (Conditional GAN) that takes a class label along with the noise, so you can control which digit is generated.

  4. Mode collapse detection: implement a diversity metric for generated samples (such as the FID score) and monitor it during training.

  5. Challenge: implement StyleGAN2-ADA and train it on a custom dataset with adaptive discriminator augmentation (ADA).

