生成式任务输入就是标签

transformers在进入compute_metrics前会有一个判断,源码如下:

# 版本 transformers==4.41.2
# 在trainer.py 的 3842 行
# Metrics!
if (
    self.compute_metrics is not None
    and all_preds is not None
    and all_labels is not None
    and not self.args.batch_eval_metrics
):
    if args.include_inputs_for_metrics:
        metrics = self.compute_metrics(
            EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
        )
    else:
        metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
elif metrics is None:
    metrics = {}

生成式任务如果没有标签字段,即labels那么这里的all_labels is not None就会是false,从而无法进入compute_metrics方法。
此时可以在TrainingArguments中加入一个变量label_names把输入文本作为标签,如下:

training_args = TrainingArguments(
...
label_names=['input_ids'], # 这里假设我的文本输入叫 ‘input_ids’
...
)

这样就可以进入compute_metrics函数了。
此外,若需要将输入的变量传入compute_metrics,可以在TrainingArguments中设置include_inputs_for_metrics=True

GitHub 加速计划 / tra / transformers
72
5
下载
huggingface/transformers: 是一个基于 Python 的自然语言处理库,它使用了 PostgreSQL 数据库存储数据。适合用于自然语言处理任务的开发和实现,特别是对于需要使用 Python 和 PostgreSQL 数据库的场景。特点是自然语言处理库、Python、PostgreSQL 数据库。
最近提交(Master分支:7 个月前 )
c9d1e523 * Update installation.md * Update README.md 7 小时前
d253de6d * initial * fix * fix * update * fix * fixes * quantization * attention mask visualizer * multimodal * small changes * fix code samples 8 小时前
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐