生成式任务输入就是标签

transformers在进入compute_metrics前会有一个判断,源码如下:

# 版本 transformers==4.41.2
# 在trainer.py 的 3842 行
# Metrics!
if (
    self.compute_metrics is not None
    and all_preds is not None
    and all_labels is not None
    and not self.args.batch_eval_metrics
):
    if args.include_inputs_for_metrics:
        metrics = self.compute_metrics(
            EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
        )
    else:
        metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
elif metrics is None:
    metrics = {}

生成式任务如果没有标签字段,即labels那么这里的all_labels is not None就会是false,从而无法进入compute_metrics方法。
此时可以在TrainingArguments中加入一个变量label_names把输入文本作为标签,如下:

training_args = TrainingArguments(
...
label_names=['input_ids'], # 这里假设我的文本输入叫 ‘input_ids’
...
)

这样就可以进入compute_metrics函数了。
此外,若需要将输入的变量传入compute_metrics,可以在TrainingArguments中设置include_inputs_for_metrics=True

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐