HuggingFace (transformers): Loading a Custom Checkpoint and Testing a Regression Task with Trainer
Requirements
- The training procedure was already described in "HuggingFace (transformers): Training a Regression Task on a Custom Image Dataset with the DeiT Model and Trainer".
- After training finishes, checkpoint directories like the following are generated (a typical layout is sketched after this list):
- Now we want to use the model saved in checkpoint-500 to run predictions and see how it performs on the test set, i.e. what the loss value is.
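For reference, a Trainer run with output_dir="gaze_trainer" typically leaves a directory tree like the one below. The exact files depend on the transformers version; this listing is an assumption, not copied from the original run:

gaze_trainer/
└── checkpoint-500/
    ├── config.json          # model configuration
    ├── pytorch_model.bin    # model weights
    ├── optimizer.pt
    ├── scheduler.pt
    ├── trainer_state.json
    ├── training_args.bin
    └── rng_state.pth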
Solution
Key code
- Call the Trainer's predict method and pass the test-set Dataset as its argument.
For more on predict, see the official documentation: https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Trainer.predict
from dataset import GazeCaptureDataset
from transformers import TrainingArguments
from transformers import DeiTForImageClassification
from torch import nn
from transformers import Trainer
from transformers import DeiTConfig

# Root path of the dataset
root_path = r"D:\datasets\GazeCapture_new"

# 1. Define the Dataset
test_dataset = GazeCaptureDataset(root_path, data_type='test')

# 2. Define the DeiT image model and load the weights saved in checkpoint-500
configuration = DeiTConfig(num_labels=2, problem_type="regression")
model = DeiTForImageClassification.from_pretrained('gaze_trainer/checkpoint-500', config=configuration)

# 3. Test
## 3.1 Define the test arguments
testing_args = TrainingArguments(output_dir="pred_trainer")

## 3.2 Customize the Trainer
class CustomTester(Trainer):
    # Override the loss computation
    def compute_loss(self, model, inputs, return_outputs=False):
        # Ground-truth labels
        labels = inputs.get("labels")
        # Input images
        x = inputs.get("pixel_values")
        # Model outputs
        outputs = model(x)
        logits = outputs.get('logits')
        # Use Smooth L1 as the loss function
        loss_fct = nn.SmoothL1Loss()
        # Loss between the model outputs and the labels
        loss = loss_fct(logits, labels)
        return (loss, outputs) if return_outputs else loss

## 3.3 Create the Trainer object
tester = CustomTester(
    model=model,
    args=testing_args,
)

## 3.4 Call predict to run the test
output = tester.predict(test_dataset=test_dataset)

# 4. Test results
print(output)
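The print(output) above dumps the whole PredictionOutput named tuple. Its fields can also be read individually, for example:

# Individual fields of the PredictionOutput returned by predict()
print(output.metrics["test_loss"])   # scalar Smooth L1 loss on the test set
print(output.predictions.shape)      # (num_examples, 2) predicted (Xcam, Ycam)
print(output.label_ids.shape)        # (num_examples, 2) ground-truth (Xcam, Ycam)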
Dataset
dataset.py
The code is as follows:
import os.path
from torch.utils.data import Dataset
from transform import transform
import numpy as np

# Read all lines from the label files
def get_label_list(label_path):
    # Holds the contents of all label files
    full_lines = []
    # Names of all label files, e.g. 00002.label, 00003.label, ...
    label_names = os.listdir(label_path)
    # Iterate over every label file and read its contents
    for label_name in label_names:
        # Absolute path of the label file, e.g. D:\datasets\GazeCapture_new\Label\train\00002.label
        label_abs_path = os.path.join(label_path, label_name)
        # Read the contents of this label file
        with open(label_abs_path) as flist:
            # Holds all lines of this label file
            full_line = []
            for line in flist:
                full_line.append(line.strip())
            # Remove the header line 'Face Left Right Grid Xcam, Ycam Xdot, Ydot Device'
            full_line.pop(0)
            full_lines.extend(full_line)
    return full_lines

class GazeCaptureDataset(Dataset):
    def __init__(self, root_path, data_type):
        self.data_dir = root_path
        # Root path of the label files, e.g. D:\datasets\GazeCapture_new\Label\train
        label_root_path = os.path.join(root_path + '/Label', data_type)
        # Read the contents of all label files
        self.full_lines = get_label_list(label_root_path)
        # Delimiter used within each line
        self.delimiter = ' '
        # Length of the dataset, i.e. the total number of images
        self.num_samples = len(self.full_lines)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # One line of a label file corresponds to one sample
        line = self.full_lines[idx]
        # Split the line by the delimiter
        Face, Left, Right, Grid, XYcam, XYdot, Device = line.split(self.delimiter)
        # Network input: the face image
        face_path = os.path.join(self.data_dir + '/Image/', Face)
        # Read the face image
        with open(face_path, 'rb') as f:
            img = f.read()
        # Preprocess the face image: resize, crop, normalize
        pixel_values = transform(img)
        # Ground-truth labels
        labels = np.array(XYcam.split(","), np.float32)
        # Note: the return value must have the form {"labels": xxx, "pixel_values": xxx}
        result = {"labels": labels}
        result["pixel_values"] = pixel_values
        return result
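The transform module imported at the top of dataset.py is not shown in this post. A minimal sketch of what it could look like, assuming torchvision preprocessing with the 224x224 input size DeiT expects and ImageNet-style normalization (the actual values should match whatever was used during training):

# transform.py -- minimal sketch (assumptions: torchvision, 224x224 input,
# ImageNet-style mean/std; adjust to match the preprocessing used in training)
import io
from PIL import Image
from torchvision import transforms

_pipeline = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

def transform(img_bytes):
    # img_bytes: raw bytes read from the image file in __getitem__
    image = Image.open(io.BytesIO(img_bytes)).convert('RGB')
    return _pipeline(image)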
The output is as follows:
***** Running Prediction *****
Num examples = 1716
Batch size = 8
100%|██████████| 215/215 [01:52<00:00, 1.90it/s]
PredictionOutput(predictions=array([[-2.309026 , -2.752627 ],
[-2.0178156, -3.0546618],
[-1.8222798, -3.309564 ],
...,
[-2.6463585, -2.3462727],
[-2.2149038, -2.7406967],
[-1.7267275, -3.3450181]], dtype=float32), label_ids=array([[ 0.969375, -7.525975],
[ 0.969375, -7.525975],
[ 0.969375, -7.525975],
...,
[ 5.5845 , 1.93875 ],
[ 5.5845 , 1.93875 ],
[ 5.5845 , 1.93875 ]], dtype=float32), metrics={'test_loss': 2.8067691326141357, 'test_runtime': 118.2811, 'test_samples_per_second': 14.508, 'test_steps_per_second': 1.818})
As shown above, the loss of this model on the test set is 2.8067691326141357.
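If a more interpretable number than the Smooth L1 loss is wanted, the predictions and label_ids arrays can be compared directly. A quick sketch, assuming output is the PredictionOutput printed above (GazeCapture's Xcam/Ycam labels are camera-relative coordinates in centimeters):

import numpy as np

# Mean Euclidean distance between predicted and true (Xcam, Ycam) points
errors = np.linalg.norm(output.predictions - output.label_ids, axis=1)
print(f"mean gaze error: {errors.mean():.4f} (label units, cm for GazeCapture)")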