import cv2
import numpy as np
import torch
from PIL import Image

class GridMask(object):
    def __init__(
        self,
        use_h=True,
        use_w=True,
        rotate=1,
        offset=False,
        ratio=0.5,
        mode=0,
        prob=1.0,
    ):
        self.use_h = use_h
        self.use_w = use_w
        self.rotate = rotate
        self.offset = offset
        self.ratio = ratio
        self.mode = mode
        self.prob = prob

    def __call__(self, img):
        """
        img:
            torch.Tensor [C,H,W]
        """

        if np.random.rand() > self.prob:
            return img

        c, h, w = img.shape

        hh = int(1.5 * h)
        ww = int(1.5 * w)

        #d = np.random.randint(2, h)
        d = np.random.randint(50,80)

        l = min(max(int(d * self.ratio + 0.5), 1), d - 1)

        mask = np.ones((hh, ww), np.float32)

        st_h = np.random.randint(d)
        st_w = np.random.randint(d)

        # 横向mask
        if self.use_h:
            for i in range(hh // d):
                s = d * i + st_h
                t = min(s + l, hh)
                mask[s:t, :] *= 0

        # 纵向mask
        if self.use_w:
            for i in range(ww // d):
                s = d * i + st_w
                t = min(s + l, ww)
                mask[:, s:t] *= 0

        # rotate
        r = np.random.randint(self.rotate)

        mask = Image.fromarray(np.uint8(mask))
        mask = mask.rotate(r)

        mask = np.asarray(mask)

        # crop回原图大小
        mask = mask[
            (hh - h) // 2 : (hh - h) // 2 + h,
            (ww - w) // 2 : (ww - w) // 2 + w,
        ]

        mask = torch.from_numpy(mask).float().to(img.device)

        if self.mode == 1:
            mask = 1 - mask

        # [H,W] -> [C,H,W]
        mask = mask.unsqueeze(0).expand_as(img)

        if self.offset:
            offset = (
                torch.from_numpy(
                    2 * (np.random.rand(h, w) - 0.5)
                )
                .float()
                .to(img.device)
            )

            offset = offset.unsqueeze(0).expand_as(img)

            img = img * mask + offset * (1 - mask)

        else:
            img = img * mask

        return img

def visualize_gridmask_opencv_only(image_path, gridmask, save_augmented_path=None):
    # 使用 OpenCV 读取图片
    img_bgr = cv2.imread(image_path)
    if img_bgr is None:
        print(f"Error: Cannot read image from {image_path}")
        return None
    
    img_bgr = cv2.resize(img_bgr, (384,128))
    
    # 归一化到 [0, 1]
    img_normalized = img_bgr.astype(np.float32) / 255.0
    
    # OpenCV 是 HWC BGR 格式,需要转换为 CHW RGB 格式
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    img_normalized_rgb = img_rgb.astype(np.float32) / 255.0
    
    # 转换为 tensor [C, H, W]
    img_tensor = torch.from_numpy(img_normalized_rgb).permute(2, 0, 1)
    
    # 应用 GridMask
    augmented_tensor = gridmask(img_tensor)
    
    # 转换回 numpy [H, W, C]
    augmented_rgb = augmented_tensor.permute(1, 2, 0).cpu().numpy()
    augmented_rgb = np.clip(augmented_rgb, 0, 1)
    
    # 转回 0-255 范围并转为 BGR 格式用于 OpenCV 显示
    augmented_uint8_rgb = (augmented_rgb * 255).astype(np.uint8)
    augmented_bgr = cv2.cvtColor(augmented_uint8_rgb, cv2.COLOR_RGB2BGR)
    
    # 水平拼接原图和增强图(宽度相同,高度相同)
    # 如果图片太大,可以调整大小
    h, w = img_bgr.shape[:2]
    if max(h, w) > 800:
        scale = 800 / max(h, w)
        new_w = int(w * scale)
        new_h = int(h * scale)
        img_resized = cv2.resize(img_bgr, (new_w, new_h))
        aug_resized = cv2.resize(augmented_bgr, (new_w, new_h))
    else:
        img_resized = img_bgr
        aug_resized = augmented_bgr
    
    # 水平拼接
    comparison = np.hstack((img_resized, aug_resized))
    
    # 添加文字标注
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.7
    thickness = 2
    text_color = (0, 255, 0)  # 绿色
    
    # 在左侧图片上方添加 "Original"
    cv2.putText(comparison, "Original", (10, 30), font, font_scale, text_color, thickness)
    # 在右侧图片上方添加 "GridMask"
    cv2.putText(comparison, "GridMask", (comparison.shape[1]//2 + 10, 30), font, font_scale, text_color, thickness)
    
    # 显示图像
    cv2.imshow('GridMask Augmentation Comparison', comparison)
    
    # 保存增强后的图片
    if save_augmented_path:
        cv2.imwrite(save_augmented_path, augmented_bgr)
        print(f"Augmented image saved to {save_augmented_path}")
    
    print("Press any key to close the window...")
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    
    return augmented_bgr

# 创建 GridMask 实例
gridmask = GridMask(
    use_h=True,
    use_w=True,
    rotate=4,
    offset=False,
    ratio=0.65,
    mode=1,
    prob=1,  # 70% 概率应用增强
)

# 调用可视化(替换为你的图片路径)
image_path = "test.jpg"  # 修改为你的图片路径
augmented = visualize_gridmask_opencv_only(image_path, gridmask, save_augmented_path="augmented.jpg")

这是一个经典的数据增强模块 GridMask,常用于目标检测、BEV、分类等视觉任务。
它的核心思想:
随机用“网格状”的遮挡去盖住图片的一部分,迫使模型学习更鲁棒的特征。
类似:
Cutout(随机挖洞)
Random Erasing
DropBlock
但 GridMask 是:
“规则网格”遮挡,而不是随机矩形。

set_prob()作用:
训练前期弱增强,
后期逐渐增强。
epoch=0 -> prob=0
epoch=50 -> prob=0.5
epoch=100 -> prob=1
属于 curriculum augmentation。

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
import matplotlib.pyplot as plt
import cv2


class GridMask(nn.Module):
    def __init__(
        self, use_h, use_w, rotate=1, offset=False, ratio=0.5, mode=0, prob=1.0
    ):
        super(GridMask, self).__init__()
        self.use_h = use_h
        self.use_w = use_w
        self.rotate = rotate
        self.offset = offset
        self.ratio = ratio
        self.mode = mode
        self.st_prob = prob
        self.prob = prob

    def set_prob(self, epoch, max_epoch):
        self.prob = self.st_prob * epoch / max_epoch

    def forward(self, x):
        if np.random.rand() > self.prob or not self.training:
            return x

        n, c, h, w = x.size()

        x = x.view(-1, h, w)

        hh = int(1.5 * h)
        ww = int(1.5 * w)

        d = np.random.randint(2, h)

        self.l = min(max(int(d * self.ratio + 0.5), 1), d - 1)

        mask = np.ones((hh, ww), np.float32)

        st_h = np.random.randint(d)
        st_w = np.random.randint(d)

        if self.use_h:
            for i in range(hh // d):
                s = d * i + st_h
                t = min(s + self.l, hh)
                mask[s:t, :] *= 0

        if self.use_w:
            for i in range(ww // d):
                s = d * i + st_w
                t = min(s + self.l, ww)
                mask[:, s:t] *= 0

        r = np.random.randint(self.rotate)

        mask = Image.fromarray(np.uint8(mask))
        mask = mask.rotate(r)
        mask = np.asarray(mask)

        mask = mask[
            (hh - h) // 2 : (hh - h) // 2 + h,
            (ww - w) // 2 : (ww - w) // 2 + w,
        ]

        device = x.device

        mask = torch.from_numpy(mask).float().to(device)

        if self.mode == 1:
            mask = 1 - mask

        mask = mask.expand_as(x)

        if self.offset:
            offset = (
                torch.from_numpy(2 * (np.random.rand(h, w) - 0.5))
                .float()
                .to(device)
            )
            x = x * mask + offset * (1 - mask)
        else:
            x = x * mask

        return x.view(n, c, h, w)


# =========================
# 创建 GridMask
# =========================

grid_mask = GridMask(
    True,
    True,
    rotate=10,
    offset=False,
    ratio=0.65,
    mode=1,
    prob=1.0,   # 这里改成1,保证一定触发
)

# 一定要train模式
# eval模式不会增强

grid_mask.train()


# =========================
# 读取图片
# =========================

img_path = 'test.jpg'

img = cv2.imread(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# 保存原图
img_ori = img.copy()

# 转 tensor
img_tensor = torch.from_numpy(img).float() / 255.0
img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0)

# =========================
# GridMask增强
# =========================

with torch.no_grad():
    out = grid_mask(img_tensor)

# 转回 numpy
out = out.squeeze(0).permute(1, 2, 0).numpy()

# clip
out = np.clip(out, 0, 1)


# =========================
# 可视化
# =========================

plt.figure(figsize=(12, 6))

plt.subplot(1, 2, 1)
plt.imshow(img_ori)
plt.title('Original')
plt.axis('off')

plt.subplot(1, 2, 2)
plt.imshow(out)
plt.title('GridMask')
plt.axis('off')

plt.tight_layout()
plt.show()

这是一个经典的数据增强模块 GridMask,常用于目标检测、BEV、分类等视觉任务。

它的核心思想:

随机用“网格状”的遮挡去盖住图片的一部分,迫使模型学习更鲁棒的特征。

类似:

  • Cutout(随机挖洞)
  • Random Erasing
  • DropBlock

但 GridMask 是:

“规则网格”遮挡,而不是随机矩形。


整体效果

例如原图:

################
################
################
################

GridMask 后:

##..##..##..####
##..##..##..####
################
################
##..##..##..####

形成周期性遮挡。


代码整体流程

forward 流程:

输入图像
  ↓
随机生成网格mask
  ↓
随机旋转
  ↓
裁剪回原图大小
  ↓
mask乘到图像
  ↓
输出增强图

1. 初始化

def __init__(
    self,
    use_h,
    use_w,
    rotate=1,
    offset=False,
    ratio=0.5,
    mode=0,
    prob=1.0
)

参数解释


use_h

是否沿高度方向遮挡

self.use_h = use_h

例如:

True

会生成横条遮挡。



use_w

是否沿宽度方向遮挡

self.use_w = use_w

例如:

True

会生成竖条遮挡。


如果:

use_h=True
use_w=True

就形成网格。


rotate

mask随机旋转角度范围

r = np.random.randint(self.rotate)

例如:

rotate=10

表示:

0~9°

随机旋转。


offset

是否用随机值填充被遮挡区域

默认:

False

即:

遮挡区域 = 0

如果:

True

则:

遮挡区域 = 随机噪声

ratio

遮挡比例

self.l = d * ratio

例如:

d = 100
ratio = 0.5

则:

遮挡宽度 = 50

mode

mask翻转模式

mode=0

正常:

0 -> 被遮挡
1 -> 保留

如果:

mode=1

则反转。


prob

应用增强概率

if np.random.rand() > self.prob:
    return x

2. set_prob()

def set_prob(self, epoch, max_epoch):
    self.prob = self.st_prob * epoch / max_epoch

作用:

训练前期弱增强,
后期逐渐增强。

例如:

st_prob = 1.0

训练:

epoch=0   -> prob=0
epoch=50  -> prob=0.5
epoch=100 -> prob=1

属于 curriculum augmentation。


3. forward()


Step1 判断是否增强

if np.random.rand() > self.prob or not self.training:
    return x

推理阶段不增强。


Step2 获取shape

n, c, h, w = x.size()

例如:

[8, 3, 256, 256]

Step3 reshape

x = x.view(-1, h, w)

变成:

[8*3, 256, 256]

方便mask broadcast。


Step4 扩大mask尺寸

hh = int(1.5 * h)
ww = int(1.5 * w)

为什么?

因为后面会旋转。

旋转后角落会缺失。

扩大后再中心裁剪。

这是经典 trick。


Step5 随机网格间距 d

d = np.random.randint(2, h)

例如:

d = 80

表示:

每80像素一个周期

Step6 计算遮挡宽度

self.l = min(max(int(d * self.ratio + 0.5), 1), d - 1)

即:

l = d * ratio

并保证:

1 <= l < d

例如:

d=80
ratio=0.5

得到:

l=40

Step7 初始化mask

mask = np.ones((hh, ww), np.float32)

开始全部为1:

111111111
111111111

Step8 横向遮挡

for i in range(hh // d):
    s = d * i + st_h
    t = min(s + self.l, hh)
    mask[s:t, :] *= 0

例如:

11111111
00000000
11111111
00000000

Step9 纵向遮挡

mask[:, s:t] *= 0

形成网格:

11001100
11001100
00000000
00000000
11001100

Step10 随机旋转

mask = Image.fromarray(np.uint8(mask))
mask = mask.rotate(r)

避免:

总是水平垂直

提升随机性。


Step11 中心裁剪

mask = mask[
    (hh - h) // 2 : (hh - h) // 2 + h,
    (ww - w) // 2 : (ww - w) // 2 + w,
]

恢复原图大小。


Step12 转tensor

mask = torch.from_numpy(mask).float().cuda()

shape:

[h,w]

Step13 mode翻转

if self.mode == 1:
    mask = 1 - mask

Step14 broadcast

mask = mask.expand_as(x)

扩展成:

[n*c, h, w]

Step15 应用mask


普通模式

x = x * mask

遮挡部分变0。


offset模式

x = x * mask + offset * (1 - mask)

被遮挡区域填随机噪声。

类似:

CutMix + Noise

Step16 reshape回去

return x.view(n, c, h, w)

恢复:

[B,C,H,W]

GridMask核心思想

本质:

强迫模型:
不能依赖局部纹理
必须学习全局特征

特别适合:

  • 检测
  • 分割
  • BEV
  • 自动驾驶

因为:

现实中:

  • 遮挡
  • 雨雾
  • 行人遮挡
  • 传感器缺失

都很多。


为什么 BEVDepth 里经常用?

BEV任务:

camera → depth → BEV

非常容易:

过拟合局部纹理

GridMask 能:

  • 提升泛化
  • 提升鲁棒性
  • 防止依赖单摄像头区域

所以:

很多 BEV 系列:

  • BEVDet
  • BEVDepth
  • PETR
  • BEVFormer

都会默认带 GridMask。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐