paddle.vision 与 torchvision 中的box NMS使用方式
vision
pytorch/vision: 一个基于 PyTorch 的计算机视觉库,提供了各种计算机视觉算法和工具,适合用于实现计算机视觉应用程序。
项目地址:https://gitcode.com/gh_mirrors/vi/vision
免费下载资源
·
torchvision 中有多个用于计算 BBox NMS 的 API, 在本篇氵文中, 使用
torchvision.ops.boxes.batched_nms
paddle.vision 中通过 paddle.vision.ops.nms
来进行多个 Box 的 NMS 操作
1. torchvision 中 batched_nms 操作
torchvision batched_nms
def batched_nms(
boxes: torch.Tensor,
scores: torch.Tensor,
idxs: torch.Tensor,
iou_threshold: float,
) -> torch.Tensor
传入的参数分别为
- 边界框
boxes
, 格式[x1, y1, x2, y2]
,shape 为[num, 4]
,dtype 为float
- 置信度
scores
, shape 为[num,]
,dtype 为float
- 类别
idxs
, shape 为[num,]
,dtype 为int
来举个例子:
import numpy as np
import torch
from torchvision.ops import boxes as box_ops
seed = 1107
iou_threshold = 0.35
box_num = 100000
cls_num = 80
np.random.seed(seed)
boxes = np.random.rand(box_num, 4).astype("float32")
boxes = torch.from_numpy(boxes)
scores = np.random.rand(box_num).astype("float32")
scores = torch.from_numpy(scores)
idxs = np.random.randint(0, cls_num, size=(box_num,))
idxs = torch.from_numpy(idxs)
assert boxes.shape[-1] == 4
keep = box_ops.batched_nms(boxes.float(), scores, idxs, iou_threshold)
2. paddle.vision.ops.nms 操作
paddle.vision.ops.nms(
boxes,
iou_threshold=0.3,
scores=None,
category_idxs=None,
categories=None,
top_k=None)
boxes
、iou_threshold
、scores
和 category_idxs
等参数和上述 torchvision 中 batched_nms
参数一样
不同的是 paddle 中还需要 categories
参数,(其实没什么必要)
category_idxs
是每个 bbox 的类别,而 categories
是一共的类别
比如 COCO 一共80类,则:
categories = paddle.arange(80)
Paddle 中的例子:
import numpy as np
import paddle
seed = 1107
iou_threshold = 0.35
box_num = 100000
cls_num = 80
np.random.seed(seed)
boxes = np.random.rand(box_num, 4).astype("float32")
boxes = paddle.to_tensor(boxes)
scores = np.random.rand(box_num).astype("float32")
scores = paddle.to_tensor(scores)
idxs = np.random.randint(0, cls_num, size=(box_num,))
idxs = paddle.to_tensor(idxs)
cls_list = paddle.arange(0, cls_num)
assert boxes.shape[-1] == 4
keep = paddle.vision.ops.nms(boxes, iou_threshold, scores, idxs, cls_list)
GitHub 加速计划 / vi / vision
15.85 K
6.89 K
下载
pytorch/vision: 一个基于 PyTorch 的计算机视觉库,提供了各种计算机视觉算法和工具,适合用于实现计算机视觉应用程序。
最近提交(Master分支:3 个月前 )
518ee93d
3 天前
7d077f13
5 天前
更多推荐
已为社区贡献14条内容
所有评论(0)