[论文学习]Manifold Mixup和PatchUp的代码重新实现(实现即插即用且速度更快)
Manifold Mixup和PatchUp是对mixup数据增强算法的两种改进方法,作者都来自Yoshua Bengio团队。这两种方法都是mixup方法在中间隐层的推广,因此原文开源代码都需要对网络各层的内部代码进行修改,使用起来并不方便,不能做到即插即用。我用pytorch中的钩子方法(hook)对这两个方法进行重新实现,这样就可以实现即插即用,方便的应用到各种网络结构中,而且我实现的代码比原开源代码速度还能提高60%左右。
Manifold Mixup 论文:https://arxiv.org/abs/1806.05236
Manifold Mixup 官方开源:https://github.com/vikasverma1077/manifold_mixup
PatchUp 论文:https://arxiv.org/abs/2006.07794
PatchUp 官方开源:https://github.com/chandar-lab/PatchUp
一、Manifold Mixup简介及代码
manifold mixup是对mixup的扩展,把输入数据(raw input data)混合扩展到对中间隐层输出混合。至于对中间隐层混合更有效的原因,作者的解释比较深奥。首先给出了现象级的解释,即这种混合带来了三个优势:平滑决策边界、拉大低置信空间(拉开各类别高置信空间的间距)、展平隐层输出的数值。至于这三点为什么有效,从作者说法看这应该是一种业界共识。然后作者又从数学上分析了第三点,即为什么manifold mixup可以实现展平中间隐层输出。
由于需要修改网络中间层的输出张量,如果不修改网络内部,也可以使用钩子操作(hook)在外部进行。核心部分代码如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
def to_one_hot(inp, num_classes):
y_onehot = torch.FloatTensor(inp.size(0), num_classes).to(inp.device)
y_onehot.zero_()
y_onehot.scatter_(1, inp.unsqueeze(1).data, 1)
return y_onehot
bce_loss = nn.BCELoss()
softmax = nn.Softmax(dim=1)
class ManifoldMixupModel(nn.Module):
def __init__(self, model, num_classes = 10, alpha = 1):
super().__init__()
self.model = model
self.alpha = alpha
self.lam = None
self.num_classes = num_classes
##选择需要操作的层,在ResNet中各block的层名为layer1,layer2...所以可以写成如下。其他网络请自行修改
self. module_list = []
for n,m in self.model.named_modules():
#if 'conv' in n:
if n[:-1]=='layer':
self.module_list.append(m)
def forward(self, x, target=None):
if target==None:
out = self.model(x)
return out
else:
if self.alpha <= 0:
self.lam = 1
else:
self.lam = np.random.beta(self.alpha, self.alpha)
k = np.random.randint(-1, len(self.module_list))
self.indices = torch.randperm(target.size(0)).cuda()
target_onehot = to_one_hot(target, self.num_classes)
target_shuffled_onehot = target_onehot[self.indices]
if k == -1:
x = x * self.lam + x[self.indices] * (1 - self.lam)
out = self.model(x)
else:
modifier_hook = self.module_list[k].register_forward_hook(self.hook_modify)
out = self.model(x)
modifier_hook.remove()
target_reweighted = target_onehot* self.lam + target_shuffled_onehot * (1 - self.lam)
loss = bce_loss(softmax(out), target_reweighted)
return out, loss
def hook_modify(self, module, input, output):
output = self.lam * output + (1 - self.lam) * output[self.indices]
return output
调用代码如下:
net = ResNet18()
net = ManifoldMixupModel(net,num_classes=10, alpha=args.alpha)
def train(epoch):
net.train()
for batch_idx, (inputs, targets) in enumerate(trainloader):
inputs, targets = inputs.cuda(), targets.cuda()
outputs, loss = net(inputs, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
def test(epoch):
net.eval()
with torch.no_grad():
for batch_idx, (inputs, targets) in enumerate(testloader):
inputs, targets = inputs.cuda(), targets.cuda()
outputs = net(inputs)
二、PatchUp简介及代码
PatchUp方法在manifold mixup基础上,又借鉴了cutMix在空间维度剪裁的思路,对中间隐层输出也进行剪裁,对两个不同样本的中间隐层剪裁块(patches)进行互换或插值,文中称互换法为硬patchUp,插值法为软patchUp。试验发现互换法在识别精度上更好,插值法在对抗攻击的鲁棒性上更好。这篇论文中没有对方法理论进行深度解释,仅仅给出了一个现象级对比,就是patchUp方法的隐层激活值比较高。
使用hook实现的核心代码PatchUpModel类如下,注意在该代码中强制k=-1就可以变成CutMix:
class PatchUpModel(nn.Module):
def __init__(self, model, num_classes = 10, block_size=7, gamma=.9, patchup_type='hard',keep_prob=.9):
super().__init__()
self.patchup_type = patchup_type
self.block_size = block_size
self.gamma = gamma
self.gamma_adj = None
self.kernel_size = (block_size, block_size)
self.stride = (1, 1)
self.padding = (block_size // 2, block_size // 2)
self.computed_lam = None
self.model = model
self.num_classes = num_classes
self. module_list = []
for n,m in self.model.named_modules():
if n[:-1]=='layer':
#if 'conv' in n:
self.module_list.append(m)
def adjust_gamma(self, x):
return self.gamma * x.shape[-1] ** 2 / \
(self.block_size ** 2 * (x.shape[-1] - self.block_size + 1) ** 2)
def forward(self, x, target=None):
if target==None:
out = self.model(x)
return out
else:
self.lam = np.random.beta(2.0, 2.0)
k = np.random.randint(-1, len(self.module_list))
self.indices = torch.randperm(target.size(0)).cuda()
self.target_onehot = to_one_hot(target, self.num_classes)
self.target_shuffled_onehot = self.target_onehot[self.indices]
if k == -1: #CutMix
W,H = x.size(2),x.size(3)
cut_rat = np.sqrt(1. - self.lam)
cut_w = np.int(W * cut_rat)
cut_h = np.int(H * cut_rat)
cx = np.random.randint(W)
cy = np.random.randint(H)
bbx1 = np.clip(cx - cut_w // 2, 0, W)
bby1 = np.clip(cy - cut_h // 2, 0, H)
bbx2 = np.clip(cx + cut_w // 2, 0, W)
bby2 = np.clip(cy + cut_h // 2, 0, H)
x[:, :, bbx1:bbx2, bby1:bby2] = x[self.indices, :, bbx1:bbx2, bby1:bby2]
lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (W * H))
out = self.model(x)
loss = bce_loss(softmax(out), self.target_onehot) * lam +\
bce_loss(softmax(out), self.target_shuffled_onehot) * (1. - lam)
else:
modifier_hook = self.module_list[k].register_forward_hook(self.hook_modify)
out = self.model(x)
modifier_hook.remove()
loss = 1.0 * bce_loss(softmax(out), self.target_a) * self.total_unchanged_portion + \
bce_loss(softmax(out), self.target_b) * (1. - self.total_unchanged_portion) + \
1.0 * bce_loss(softmax(out), self.target_reweighted)
return out, loss
def hook_modify(self, module, input, output):
self.gamma_adj = self.adjust_gamma(output)
p = torch.ones_like(output[0]) * self.gamma_adj
m_i_j = torch.bernoulli(p)
mask_shape = len(m_i_j.shape)
m_i_j = m_i_j.expand(output.size(0), m_i_j.size(0), m_i_j.size(1), m_i_j.size(2))
holes = F.max_pool2d(m_i_j, self.kernel_size, self.stride, self.padding)
mask = 1 - holes
unchanged = mask * output
if mask_shape == 1:
total_feats = output.size(1)
else:
total_feats = output.size(1) * (output.size(2) ** 2)
total_changed_pixels = holes[0].sum()
total_changed_portion = total_changed_pixels / total_feats
self.total_unchanged_portion = (total_feats - total_changed_pixels) / total_feats
if self.patchup_type == 'hard':
self.target_reweighted = self.total_unchanged_portion * self.target_onehot +\
total_changed_portion * self.target_shuffled_onehot
patches = holes * output[self.indices]
self.target_b = self.target_onehot[self.indices]
elif self.patchup_type == 'soft':
self.target_reweighted = self.total_unchanged_portion * self.target_onehot +\
self.lam * total_changed_portion * self.target_onehot +\
(1 - self.lam) * total_changed_portion * self.target_shuffled_onehot
patches = holes * output
patches = patches * self.lam + patches[self.indices] * (1 - self.lam)
self.target_b = self.lam * self.target_onehot + (1 - self.lam) * self.target_shuffled_onehot
else:
raise ValueError("patchup_type must be \'hard\' or \'soft\'.")
output = unchanged + patches
self.target_a = self.target_onehot
return output
调用过程同上,其中模型包装语句如下:
net = ResNet18()
net = PatchUpModel(net,num_classes=10, block_size=7, gamma=.9, patchup_type='hard')
三、在CIFAR-10上试验结果
试验主要目的是验证代码可运行。仅靠在一个简单数据集上一次试验非常不充分,不能公平对比效果,所以不作为各方法的性能对比。
更多推荐
所有评论(0)