Problem Description

  1. The training process was already described in "HuggingFace (transformers): training a regression task on a custom image dataset with the DeiT model and Trainer".
  2. After training finishes, checkpoint files like the following are generated (a typical layout is sketched after this list):
    (screenshot: the gaze_trainer output directory with its checkpoint folders, including checkpoint-500)
  3. We now want to run prediction with the model saved in checkpoint-500 to see how it performs on the test set, i.e. what its loss value is.
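
For reference, a checkpoint directory written by Trainer usually looks like the sketch below; the exact file names are the common defaults and can vary between transformers versions (e.g. newer releases save safetensors instead of pytorch_model.bin):

gaze_trainer/
└── checkpoint-500/
    ├── config.json          # model configuration (num_labels, problem_type, ...)
    ├── pytorch_model.bin    # model weights
    ├── optimizer.pt         # optimizer state
    ├── scheduler.pt         # learning-rate scheduler state
    ├── trainer_state.json   # Trainer bookkeeping (global step, log history)
    └── training_args.bin    # the TrainingArguments used for the run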

Solution

Key Code

from dataset import GazeCaptureDataset
from transformers import TrainingArguments
from transformers import DeiTForImageClassification
from torch import nn
from transformers import Trainer
from transformers import DeiTConfig

# Root path of the dataset
root_path = r"D:\datasets\GazeCapture_new"
# 1. Define the Dataset
test_dataset = GazeCaptureDataset(root_path, data_type='test')

# 2. Define the DeiT image model
configuration = DeiTConfig(num_labels=2, problem_type="regression")
# from_pretrained is a classmethod, so call it on the class itself and pass the
# config explicitly (instantiating a model first would just be discarded)
model = DeiTForImageClassification.from_pretrained('gaze_trainer/checkpoint-500', config=configuration)

# 3. Test
## 3.1 Define the test arguments
testing_args = TrainingArguments(output_dir="pred_trainer")

## 3.2 Customize the Trainer
class CustomTester(Trainer):
    # Override the loss computation
    def compute_loss(self, model, inputs, return_outputs=False):
        # Get the labels
        labels = inputs.get("labels")
        # Get the inputs
        x = inputs.get("pixel_values")
        # Forward pass through the model
        outputs = model(x)
        logits = outputs.get('logits')
        # Use smooth L1 as the loss function
        loss_fct = nn.SmoothL1Loss()
        # Loss between the model outputs and the labels
        loss = loss_fct(logits, labels)
        return (loss, outputs) if return_outputs else loss

## 3.3 Create the Trainer object
tester = CustomTester(
    model=model,
    args=testing_args,
)

## 3.4 Call the predict method to run the test
output = tester.predict(test_dataset=test_dataset)

# 4. Test results
print(output)
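
If you want Trainer to report additional metrics alongside the loss, predict also honors a compute_metrics callback. A minimal sketch; the metric name mean_euclidean and the callback body are illustrative, not part of the original post:

import numpy as np
from transformers import EvalPrediction

# Hypothetical metrics callback: mean Euclidean distance between the
# predicted and true (XYcam) gaze points, in the units of the labels
def compute_metrics(eval_pred: EvalPrediction):
    preds, labels = eval_pred.predictions, eval_pred.label_ids
    return {"mean_euclidean": float(np.linalg.norm(preds - labels, axis=1).mean())}

# Passed at construction time, it appears in output.metrics as test_mean_euclidean:
# tester = CustomTester(model=model, args=testing_args, compute_metrics=compute_metrics)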

Dataset

The code of dataset.py is as follows:

import os.path

from torch.utils.data import Dataset
from transform import transform
import numpy as np

# Read the contents of every label file under label_path
def get_label_list(label_path):
    # Holds the lines of all label files
    full_lines = []
    # Names of all label files, e.g. 00002.label, 00003.label, ...
    label_names = os.listdir(label_path)
    # Iterate over the label files and read their contents
    for label_name in label_names:
        # Absolute path of the label file, e.g. D:\datasets\GazeCapture_new\Label\train\00002.label
        label_abs_path = os.path.join(label_path, label_name)
        # Read the contents of this label file
        with open(label_abs_path) as flist:
            # Holds all lines of this file
            full_line = []
            for line in flist:
                full_line.append(line.strip())
            # Drop the header line 'Face Left Right Grid Xcam, Ycam Xdot, Ydot Device'
            full_line.pop(0)
            full_lines.extend(full_line)
    return full_lines


class GazeCaptureDataset(Dataset):
    def __init__(self, root_path, data_type):
        self.data_dir = root_path
        # Root path of the label files, e.g. D:\datasets\GazeCapture_new\Label\train
        label_root_path = os.path.join(root_path, 'Label', data_type)
        # All lines from all label files
        self.full_lines = get_label_list(label_root_path)
        # Delimiter within each line
        self.delimiter = ' '
        # Dataset length, i.e. the total number of images
        self.num_samples = len(self.full_lines)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # One line of a label file corresponds to one sample
        line = self.full_lines[idx]
        # Split the line by the delimiter
        Face, Left, Right, Grid, XYcam, XYdot, Device = line.split(self.delimiter)
        # Network input: the face image
        face_path = os.path.join(self.data_dir, 'Image', Face)
        # Read the face image
        with open(face_path, 'rb') as f:
            img = f.read()
        # Transform the face image: resize, crop, normalize
        pixel_values = transform(img)
        # Labels: the (Xcam, Ycam) gaze point
        labels = np.array(XYcam.split(","), np.float32)
        # Note: the return value must have the form {"labels": xxx, "pixel_values": xxx}
        result = {"labels": labels}
        result["pixel_values"] = pixel_values
        return result
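
transform.py is not shown in the original post. Below is a minimal sketch of what it might contain, consistent with the comment above (resize, crop, normalize) and DeiT's usual 224×224 input; the torchvision pipeline and the 0.5 mean/std values are assumptions, so match them to whatever was used at training time:

# transform.py (illustrative sketch, not the original implementation)
import io
from PIL import Image
from torchvision import transforms

# Resize, center-crop to DeiT's 224x224 input, convert to tensor, normalize;
# mean/std of 0.5 are a common ViT-family choice, adjust to your training setup
_pipeline = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

def transform(img_bytes):
    # dataset.py passes the raw bytes read from disk, so decode them first
    img = Image.open(io.BytesIO(img_bytes)).convert('RGB')
    return _pipeline(img)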

The output is as follows:

***** Running Prediction *****
  Num examples = 1716
  Batch size = 8
100%|██████████| 215/215 [01:52<00:00,  1.90it/s]
PredictionOutput(predictions=array([[-2.309026 , -2.752627 ],
       [-2.0178156, -3.0546618],
       [-1.8222798, -3.309564 ],
       ...,
       [-2.6463585, -2.3462727],
       [-2.2149038, -2.7406967],
       [-1.7267275, -3.3450181]], dtype=float32), label_ids=array([[ 0.969375, -7.525975],
       [ 0.969375, -7.525975],
       [ 0.969375, -7.525975],
       ...,
       [ 5.5845  ,  1.93875 ],
       [ 5.5845  ,  1.93875 ],
       [ 5.5845  ,  1.93875 ]], dtype=float32), metrics={'test_loss': 2.8067691326141357, 'test_runtime': 118.2811, 'test_samples_per_second': 14.508, 'test_steps_per_second': 1.818})

As shown, the model's loss on the test set is 2.8067691326141357.
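
Since predict returns a PredictionOutput named tuple, the individual fields can also be read directly instead of printing the whole object; for example:

# The loss computed by CustomTester.compute_loss, averaged over the test set
print(output.metrics["test_loss"])                       # 2.8067691326141357
# Raw predictions and ground-truth labels as numpy arrays
print(output.predictions.shape, output.label_ids.shape)  # (1716, 2) (1716, 2)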
