1 背景

onnx模型推理单张图片,网上的教程非常多,我自己以前也写了很多这些内容,但如何推理整个数据集来验证精度呢?

如果你只是想验证导出的onnx模型在整个数据集上的精度,可以参考本文的做法。

为了保证模型前后处理完全一致,前后处理都直接复用原本的代码,输入输出数据涉及到tensor和numpy转换时直接用torch.from_numpy和.numpy实现。

到嵌入式开发板上跑的话,前后处理都是需要自己写的,而且无法依赖torch。

2 评测Imagenet数据集

imagenet 验证集val(需先按类别整理成torchvision ImageFolder格式),内部有1000个文件夹,每个文件夹下对应有50张图片。
pytorch默认使用PIL读取图片,刚读取的单张图片像素顺序为RGB,layout为HWC(此时还没有batch维N)。
经过transforms.ToTensor()后,像素顺序仍为RGB,layout变为CHW,经DataLoader按batch堆叠后才是NCHW。当然,transforms.ToTensor()还有数据归一化(像素值除以255,缩放到[0, 1])的作用,具体细节可参考另一篇博客《不使用torchvision.transforms 对图片预处理python实现》。

主程序如下,主要修改该代码即可:

import torch
import torch.nn as nn
import sys
import os
import time
import numpy as np
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import utils.common as utils		# 下面给出代码
from tqdm import tqdm


class Data:
    """Build the ImageNet validation DataLoader used for the accuracy check.

    ``<data_path>/val`` is read with ``datasets.ImageFolder``. Every image goes
    through Resize(256) -> CenterCrop(224) -> Resize(224) -> ToTensor ->
    Normalize with the usual ImageNet mean/std. The resulting loader is
    exposed as ``self.loader_test``: batch size 1, no shuffling, 2 workers.
    """

    def __init__(self, data_path):
        scale_size = 224
        imagenet_mean = [0.485, 0.456, 0.406]
        imagenet_std = [0.229, 0.224, 0.225]

        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            # Final resize to the network input size (a no-op while scale_size == 224).
            transforms.Resize(scale_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
        ])
        val_dataset = datasets.ImageFolder(os.path.join(data_path, 'val'), preprocess)

        self.loader_test = DataLoader(
            val_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=2,
            pin_memory=True,
        )

def test_onnxruntime(ort_session, testLoader, logger, topk=(1,), input_name=None):
    """Run an onnxruntime session over a DataLoader and log top-k accuracy.

    Args:
        ort_session: an ``onnxruntime.InferenceSession``. Its first output is
            treated as the class scores (TODO confirm for multi-output models).
        testLoader: iterable yielding ``(inputs, targets)`` torch tensors.
        logger: logger that receives the final accuracy/time summary line.
        topk: the k values to evaluate. Any tuple works; the original code
            crashed unless it had at least two entries.
        input_name: graph input to feed. Defaults to the name of the session's
            first input instead of a hard-coded name.

    Returns:
        Tuple of average accuracies (percent), in reverse ``topk`` order.
        For ``topk=(1, 5)`` this is ``(top5, top1)``, as before.
    """
    if input_name is None:
        # Query the graph; the input name depends on how the model was exported.
        input_name = ort_session.get_inputs()[0].name

    meters = [utils.AverageMeter('Acc@{}'.format(k), ':6.2f') for k in topk]

    start_time = time.time()
    with torch.no_grad():
        for inputs, targets in tqdm(testLoader, file=sys.stdout):
            batch_size = inputs.size(0)
            ort_outputs = ort_session.run(None, {input_name: inputs.numpy()})
            outputs = torch.from_numpy(ort_outputs[0])

            # utils.accuracy returns one value per k, in topk order.
            batch_acc = utils.accuracy(outputs, targets, topk=topk)
            for meter, acc in zip(meters, batch_acc):
                meter.update(acc, batch_size)

    elapsed = time.time() - start_time
    summary = '\t'.join(
        'Top{} {:.2f}%'.format(k, float(meter.avg)) for k, meter in zip(topk, meters)
    )
    logger.info('Test {}\tTime {:.2f}s\n'.format(summary, elapsed))

    return tuple(meter.avg for meter in reversed(meters))

def onnx_inference_imagenet(data_path='/home/users/dataset/imagenet/',
                            onnx_path="./weights/resnet50/resnet50_pruned.onnx",
                            job_dir='./experiment',
                            providers=None):
    """Evaluate an exported ONNX classifier on the ImageNet val split with onnxruntime.

    Args:
        data_path: dataset root containing the ``val`` folder. Another
            location used previously: '/data/horizon_j5/data/imagenet/'.
        onnx_path: path of the ONNX model to evaluate.
        job_dir: directory receiving ``logger.log``. It is created if missing.
        providers: onnxruntime execution providers. Defaults to every
            provider available in the installed build.

    Returns:
        The accuracies returned by ``test_onnxruntime``.
    """
    # logging.FileHandler does not create parent directories.
    os.makedirs(job_dir, exist_ok=True)
    # Join as a path component; plain concatenation gave './experimentlogger.log'.
    logger = utils.get_logger(os.path.join(job_dir, 'logger.log'))

    # Data
    print('==> Preparing data..')
    testLoader = Data(data_path).loader_test

    #---------------------------------------------------------#
    #   Use onnxruntime
    #---------------------------------------------------------#
    import onnxruntime
    if providers is None:
        # Since ORT 1.9, builds with several execution providers (e.g. GPU
        # builds) raise ValueError unless providers is passed explicitly.
        providers = onnxruntime.get_available_providers()
    ort_session = onnxruntime.InferenceSession(onnx_path, providers=providers)

    #---------------------------------------------------------#
    #   Evaluate through test_onnxruntime
    #---------------------------------------------------------#
    return test_onnxruntime(ort_session, testLoader, logger, topk=(1, 5))

if __name__ == '__main__':
    # Run the full validation pass only when executed as a script, not on import.
    onnx_inference_imagenet()

在utils文件夹下,有common.py文件,其中代码如下:

import os
import sys
import shutil
import time, datetime
import logging
import numpy as np
from PIL import Image
from pathlib import Path

import torch
import torch.nn as nn
import torch.utils


class record_config:
    """Record the run configuration to ``<args.job_dir>/config.txt``.

    Writes a timestamp line followed by one ``name: value`` line per
    attribute of ``args`` (as listed by ``vars(args)``).

    When ``args.resume`` is truthy the block is appended, so earlier runs are
    kept. Otherwise the file is overwritten. ``args.job_dir`` is created
    (including parents) if it does not exist.
    """

    def __init__(self, args):
        now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')

        self.args = args
        self.job_dir = Path(args.job_dir)
        self.job_dir.mkdir(parents=True, exist_ok=True)

        config_dir = self.job_dir / 'config.txt'
        # Only the open mode differed between the resume/fresh branches.
        mode = 'a' if args.resume else 'w'
        with open(config_dir, mode, encoding='utf-8') as f:
            f.write(now + '\n\n')
            for arg in vars(args):
                f.write('{}: {}\n'.format(arg, getattr(args, arg)))
            f.write('\n')


def get_logger(file_path):
    """Return the shared 'gal' logger, logging at INFO to ``file_path`` and stderr.

    The function is idempotent. Calling it again with the same path does not
    attach another FileHandler, and at most one plain StreamHandler is ever
    attached. Previously every call stacked a new handler pair, so each
    message was emitted once per call. A call with a different path still
    adds a handler for that new file.
    """
    logger = logging.getLogger('gal')
    log_format = '%(asctime)s | %(message)s'
    formatter = logging.Formatter(log_format, datefmt='%m/%d %I:%M:%S %p')

    # FileHandler stores the absolute path in .baseFilename.
    abs_path = os.path.abspath(file_path)
    has_file_handler = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == abs_path
        for h in logger.handlers
    )
    # FileHandler subclasses StreamHandler, so match the exact type here.
    has_stream_handler = any(type(h) is logging.StreamHandler for h in logger.handlers)

    if not has_file_handler:
        file_handler = logging.FileHandler(file_path, encoding='utf-8')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    if not has_stream_handler:
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)

    logger.setLevel(logging.INFO)

    return logger

# label smooth
class CrossEntropyLabelSmooth(nn.Module):
    """Cross-entropy loss against label-smoothed one-hot targets.

    The target distribution is ``(1 - epsilon) * one_hot + epsilon / num_classes``.
    The loss is the negative log-likelihood of the log-softmax of ``inputs``
    under that distribution: summed over classes (dim 1), averaged over the
    batch (dim 0).
    """

    def __init__(self, num_classes, epsilon):
        super().__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)
        one_hot = torch.zeros_like(log_probs)
        one_hot.scatter_(1, targets.unsqueeze(1), 1)
        smoothed = one_hot * (1 - self.epsilon) + self.epsilon / self.num_classes
        per_class_nll = -(smoothed * log_probs)
        return per_class_nll.mean(0).sum()


class AverageMeter(object):
    """Track the latest value and the count-weighted running mean of a metric.

    ``fmt`` is a format spec (e.g. ``':6.2f'``) applied to both the current
    value and the average in ``str(meter)``.
    """

    def __init__(self, name, fmt=':f'):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        # Zero the latest value, running average, weighted sum and sample count.
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        # `val` is the mean over `n` samples, so weight it by n in the sum.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = f'{{name}} {{val{self.fmt}}} ({{avg{self.fmt}}})'
        return template.format(**self.__dict__)


class ProgressMeter(object):
    """Print one-line progress reports: a ``[batch/total]`` counter plus each meter's str()."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        parts = [self.prefix + self.batch_fmtstr.format(batch)]
        parts.extend(map(str, self.meters))
        print(' '.join(parts))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the running batch index to the width of the total, e.g. '[{:3d}/100]'.
        width = len(str(num_batches // 1))
        counter = f'{{:{width}d}}'
        return f'[{counter}/{counter.format(num_batches)}]'


def save_checkpoint(state, is_best, save):
    """Save ``state`` to ``<save>/checkpoint.pth.tar`` with torch.save.

    The directory is created if it is missing. When ``is_best`` is truthy,
    the checkpoint is also copied to ``<save>/model_best.pth.tar``.
    """
    if not os.path.exists(save):
        os.makedirs(save)
    ckpt_path = os.path.join(save, 'checkpoint.pth.tar')
    torch.save(state, ckpt_path)
    if not is_best:
        return
    shutil.copyfile(ckpt_path, os.path.join(save, 'model_best.pth.tar'))


def adjust_learning_rate(optimizer, epoch, args):
    """Step decay: set every param group's lr to ``args.lr * 0.1 ** (epoch // 30)``."""
    decay_steps = epoch // 30
    new_lr = args.lr * (0.1 ** decay_steps)
    for group in optimizer.param_groups:
        group['lr'] = new_lr


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each k in ``topk``.

    ``output`` is indexed as (batch, classes) via ``topk(..., dim=1)``.
    ``target`` holds one class index per sample. Returns a list with one
    1-element tensor per k, in ``topk`` order.
    """
    with torch.no_grad():
        batch_size = target.size(0)
        _, top_idx = output.topk(max(topk), 1, True, True)
        top_idx = top_idx.t()  # rank-major: row r holds each sample's r-th guess
        hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / batch_size)
            for k in topk
        ]



# Timestamps shared across progress_bar() calls so Step/Tot timings accumulate.
# As function locals they were reset on every call and always read ~0ms.
_progress_last_time = time.time()
_progress_begin_time = _progress_last_time


def progress_bar(current, total, msg=None):
    """Draw a one-line text progress bar on stdout for step ``current`` of ``total``.

    Args:
        current: zero-based index of the finished step. ``0`` restarts the total timer.
        total: number of steps. A newline is written after the last step.
        msg: optional text appended after the step/total timings.
    """
    global _progress_last_time, _progress_begin_time

    # `stty size` yields nothing (and the unpack crashes) when stdout is not a
    # terminal. shutil falls back to $COLUMNS or 80 columns instead.
    term_width = shutil.get_terminal_size().columns

    TOTAL_BAR_LENGTH = 65.

    if current == 0:
        _progress_begin_time = time.time()  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH * current / total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    sys.stdout.write(' [' + '=' * cur_len + '>' + '.' * rest_len + ']')

    cur_time = time.time()
    step_time = cur_time - _progress_last_time
    _progress_last_time = cur_time
    tot_time = cur_time - _progress_begin_time

    parts = ['  Step: %s' % format_time(step_time), ' | Tot: %s' % format_time(tot_time)]
    if msg:
        parts.append(' | ' + msg)

    msg = ''.join(parts)
    sys.stdout.write(msg)
    sys.stdout.write(' ' * (term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3))

    # Go back to the center of the bar.
    sys.stdout.write('\b' * (term_width - int(TOTAL_BAR_LENGTH / 2) + 2))
    sys.stdout.write(' %d/%d ' % (current + 1, total))

    sys.stdout.write('\r' if current < total - 1 else '\n')
    sys.stdout.flush()


def format_time(seconds):
    """Render a duration compactly using at most its two largest non-zero units.

    Units are D (days), h, m, s and ms, e.g. ``3661.5 -> '1h1m'``,
    ``1.25 -> '1s250ms'``. Returns ``'0ms'`` when every unit rounds down to zero.
    """
    remaining = seconds
    days = int(remaining / 3600/24)
    remaining = remaining - days*3600*24
    hours = int(remaining / 3600)
    remaining = remaining - hours*3600
    minutes = int(remaining / 60)
    remaining = remaining - minutes*60
    whole_secs = int(remaining)
    remaining = remaining - whole_secs
    millis = int(remaining*1000)

    pieces = []
    for amount, unit in ((days, 'D'), (hours, 'h'), (minutes, 'm'),
                         (whole_secs, 's'), (millis, 'ms')):
        if amount > 0 and len(pieces) < 2:
            pieces.append(str(amount) + unit)
    return ''.join(pieces) or '0ms'
GitHub 加速计划 / on / onnxruntime
13.76 K
2.79 K
下载
microsoft/onnxruntime: 是一个用于运行各种机器学习模型的开源库。适合对机器学习和深度学习有兴趣的人,特别是在开发和部署机器学习模型时需要处理各种不同框架和算子的人。特点是支持多种机器学习框架和算子,包括 TensorFlow、PyTorch、Caffe 等,具有高性能和广泛的兼容性。
最近提交(Master分支:23 天前 )
6d7235ba ### Description Exposes `SetDeterministicCompute` in Java, added to the C API by #18944. ### Motivation and Context Parity between C and Java APIs. 18 小时前
02e00dc0 ### Description Adds support for constructing an `OrtSession` from a `java.nio.ByteBuffer`. These buffers can be memory mapped from files which means there doesn't need to be copies of the model protobuf held in Java, reducing peak memory usage during session construction. ### Motivation and Context Reduces memory usage on model construction by not requiring as many copies on the Java side. Should help with #19599. 21 小时前
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐