Fine-tuning with Hugging Face Transformers: Sentiment Classification on YelpReviewFull
1. Code source
The code follows the official Hugging Face Transformers fine-tuning tutorial; project mirror: https://gitcode.com/gh_mirrors/tra/transformers
2. Code
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
from transformers import TrainingArguments, Trainer
import numpy as np
import evaluate
# 1. Download the dataset (fetched automatically from the Hugging Face Hub over the network)
dataset = load_dataset("yelp_review_full")
# 2. Preprocess the data
"""
With the dataset downloaded, use a Tokenizer to process the text.
Inputs of unequal length can be handled with padding and truncation.
The Datasets `map` method applies a preprocessing function to the entire dataset in one pass.
The custom function below pads every example to the maximum length and processes the whole dataset:"""
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)
tokenized_datasets = dataset.map(tokenize_function, batched=True)
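# Added sketch (not part of the original tutorial): peek at one tokenized example to see
# what `map` produced. "label" and "text" come from the yelp_review_full schema; the other
# keys are what the bert-base-cased tokenizer returns.
sample = tokenized_datasets["train"][0]
print(sample.keys())             # e.g. dict_keys(['label', 'text', 'input_ids', 'token_type_ids', 'attention_mask'])
print(len(sample["input_ids"]))  # 512, i.e. padded to the model's maximum length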
# 3. Subsample the data
"""
Use 200 samples to demonstrate small-scale training
of BERT (with the PyTorch Trainer)."""
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(200))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(200))
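# Added sanity check (sketch): confirm the subsample size and that all five star
# ratings (labels 0-4) are represented after shuffling.
print(small_train_dataset.num_rows)               # 200
print(sorted(set(small_train_dataset["label"])))  # typically [0, 1, 2, 3, 4]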
# 4. Load the BERT model
"""
num_labels is the number of label classes in the classification task.
Loading prints a warning that some checkpoint weights are discarded (the heads used for the
pre-training tasks) and that others (the new `classifier` head) are randomly initialized.
This is perfectly normal when fine-tuning: the masked-language-modeling head used during
pre-training is dropped and replaced by a new head with no pretrained weights, so the
library warns that the model should be fine-tuned before being used for inference,
which is exactly what we are about to do.
"""
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)
# 5. Training hyperparameters (TrainingArguments)
"""
logging_steps defaults to 500; given our amount of training data and step count, set it to 100.
Set evaluation_strategy so that evaluation metrics are reported at the end of each epoch."""
model_dir = "fine_model/bert-base-cased-finetune-yelp"
training_args = TrainingArguments(output_dir=model_dir,
                                  per_device_train_batch_size=4,
                                  num_train_epochs=4,
                                  evaluation_strategy="epoch",
                                  logging_steps=100)
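# Added note (sketch): per_device_eval_batch_size is left at its default of 8 here,
# which is why each evaluation pass over 200 samples takes 25 steps (see section 3.1).
print(training_args.per_device_eval_batch_size)  # 8
print(training_args.logging_steps)               # 100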
# 6. Metric computation during training (Evaluate)
"""
The Hugging Face Evaluate library (https://huggingface.co/docs/evaluate/index) gives one-line access
to dozens of evaluation methods across domains (NLP, computer vision, reinforcement learning, ...).
Full list of supported metrics: https://huggingface.co/evaluate-metric"""
metric = evaluate.load("accuracy")
"""
Call the metric's compute function to calculate prediction accuracy.
Before passing predictions to compute, the logits must be converted to class predictions,
since all Transformers models return logits."""
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)
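# Added sanity check (sketch): call compute_metrics on fabricated logits for three
# examples and five classes; the numbers below are illustrative only.
fake_logits = np.array([
    [2.0, 0.1, 0.1, 0.1, 0.1],  # argmax -> class 0
    [0.1, 0.1, 3.0, 0.1, 0.1],  # argmax -> class 2
    [0.1, 0.1, 0.1, 0.1, 1.5],  # argmax -> class 4
])
fake_labels = np.array([0, 2, 3])
print(compute_metrics((fake_logits, fake_labels)))  # {'accuracy': 0.6666...}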
# 7. Instantiate the Trainer and start training
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
)
trainer.train()
# 8. Evaluate on test data
small_test_dataset = tokenized_datasets["test"].shuffle(seed=64).select(range(200))
trainer.evaluate(small_test_dataset)
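# Note: the evaluate() call above, on the seed-64 test sample, is what produces the extra
# eval entry (eval_accuracy 0.465 at step 200) that appears later in trainer_state.json.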
# 9. Save the model and the trainer state
trainer.save_model(model_dir)
trainer.save_state()
Run output:
Map: 100%|██████████| 650000/650000 [01:31<00:00, 7067.71 examples/s]
Map: 100%|██████████| 50000/50000 [00:07<00:00, 6877.28 examples/s]
{'eval_loss': 1.5957962274551392, 'eval_accuracy': 0.245, 'eval_runtime': 49.6374, 'eval_samples_per_second': 4.029, 'eval_steps_per_second': 0.504, 'epoch': 1.0}
{'loss': 1.5932, 'grad_norm': 8.658624649047852, 'learning_rate': 2.5e-05, 'epoch': 2.0}
{'eval_loss': 1.5762381553649902, 'eval_accuracy': 0.315, 'eval_runtime': 45.2419, 'eval_samples_per_second': 4.421, 'eval_steps_per_second': 0.553, 'epoch': 2.0}
{'eval_loss': 1.2508662939071655, 'eval_accuracy': 0.41, 'eval_runtime': 46.2305, 'eval_samples_per_second': 4.326, 'eval_steps_per_second': 0.541, 'epoch': 3.0}
{'loss': 1.1756, 'grad_norm': 16.391265869140625, 'learning_rate': 0.0, 'epoch': 4.0}
{'eval_loss': 1.197533369064331, 'eval_accuracy': 0.505, 'eval_runtime': 44.4626, 'eval_samples_per_second': 4.498, 'eval_steps_per_second': 0.562, 'epoch': 4.0}
{'train_runtime': 798.4673, 'train_samples_per_second': 1.002, 'train_steps_per_second': 0.25, 'train_loss': 1.3843913650512696, 'epoch': 4.0}
100%|██████████| 200/200 [13:18<00:00, 3.99s/it]
Process finished with exit code 0
3. Explanation
3.1 Training hyperparameters
print(training_args)
Result:
...
per_device_eval_batch_size=8,
per_device_train_batch_size=4,
...
During training, batch_size=4 over 200 samples gives 50 steps per epoch; with 4 epochs that is 200 steps in total.
During evaluation, batch_size=8 over 200 samples gives 25 steps per evaluation pass.
Evaluation runs once per epoch, i.e. every 50 training steps.
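These counts can be checked with a couple of lines (a back-of-the-envelope sketch, assuming a single device and no gradient accumulation):
import math

train_samples, eval_samples = 200, 200
train_bs, eval_bs, epochs = 4, 8, 4                        # per-device batch sizes, epochs

steps_per_epoch = math.ceil(train_samples / train_bs)      # 50
total_steps = steps_per_epoch * epochs                     # 200
eval_steps_per_pass = math.ceil(eval_samples / eval_bs)    # 25
print(steps_per_epoch, total_steps, eval_steps_per_pass)   # 50 200 25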
3.2 trainer.save_state
The main body of trainer.save_state():
def save_state(self):
    """
    Saves the Trainer state, since Trainer.save_model saves only the tokenizer with the model

    Under distributed environment this is done only for a process with rank 0.
    """
    if not self.is_world_process_zero():
        return

    path = os.path.join(self.args.output_dir, "trainer_state.json")
    self.state.save_to_json(path)
It takes no arguments and writes trainer_state.json under the output directory.
3.3 Saved model files
fine_model$ tree
.
└── bert-base-cased-finetune-yelp
├── config.json
├── model.safetensors
├── runs
│ └── ...
├── trainer_state.json
└── training_args.bin
3 directories, 6 files
config.json:
{
"_name_or_path": "bert-base-cased",
"architectures": [
"BertForSequenceClassification"
],
"attention_probs_dropout_prob": 0.1,
"classifier_dropout": null,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"id2label": {
"0": "LABEL_0",
"1": "LABEL_1",
"2": "LABEL_2",
"3": "LABEL_3",
"4": "LABEL_4"
},
"initializer_range": 0.02,
"intermediate_size": 3072,
"label2id": {
"LABEL_0": 0,
"LABEL_1": 1,
"LABEL_2": 2,
"LABEL_3": 3,
"LABEL_4": 4
},
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"position_embedding_type": "absolute",
"problem_type": "single_label_classification",
"torch_dtype": "float32",
"transformers_version": "4.38.2",
"type_vocab_size": 2,
"use_cache": true,
"vocab_size": 28996
}
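The id2label / label2id entries show the generic names LABEL_0 through LABEL_4 because no mapping was supplied when the model was loaded. If readable names are wanted in config.json, a mapping can be passed to from_pretrained; a minimal sketch (the star-name strings are illustrative, not part of the original code):
id2label = {0: "1 star", 1: "2 stars", 2: "3 stars", 3: "4 stars", 4: "5 stars"}
label2id = {name: i for i, name in id2label.items()}
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-cased", num_labels=5, id2label=id2label, label2id=label2id,
)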
trainer_state.json :
{
"best_metric": null,
"best_model_checkpoint": null,
"epoch": 4.0,
"eval_steps": 500,
"global_step": 200,
"is_hyper_param_search": false,
"is_local_process_zero": true,
"is_world_process_zero": true,
"log_history": [
{
"epoch": 1.0,
"eval_accuracy": 0.245,
"eval_loss": 1.5957962274551392,
"eval_runtime": 49.6374,
"eval_samples_per_second": 4.029,
"eval_steps_per_second": 0.504,
"step": 50
},
{
"epoch": 2.0,
"grad_norm": 8.658624649047852,
"learning_rate": 2.5e-05,
"loss": 1.5932,
"step": 100
},
{
"epoch": 2.0,
"eval_accuracy": 0.315,
"eval_loss": 1.5762381553649902,
"eval_runtime": 45.2419,
"eval_samples_per_second": 4.421,
"eval_steps_per_second": 0.553,
"step": 100
},
{
"epoch": 3.0,
"eval_accuracy": 0.41,
"eval_loss": 1.2508662939071655,
"eval_runtime": 46.2305,
"eval_samples_per_second": 4.326,
"eval_steps_per_second": 0.541,
"step": 150
},
{
"epoch": 4.0,
"grad_norm": 16.391265869140625,
"learning_rate": 0.0,
"loss": 1.1756,
"step": 200
},
{
"epoch": 4.0,
"eval_accuracy": 0.505,
"eval_loss": 1.197533369064331,
"eval_runtime": 44.4626,
"eval_samples_per_second": 4.498,
"eval_steps_per_second": 0.562,
"step": 200
},
{
"epoch": 4.0,
"step": 200,
"total_flos": 210494513971200.0,
"train_loss": 1.3843913650512696,
"train_runtime": 798.4673,
"train_samples_per_second": 1.002,
"train_steps_per_second": 0.25
},
{
"epoch": 4.0,
"eval_accuracy": 0.465,
"eval_loss": 1.1850619316101074,
"eval_runtime": 44.7757,
"eval_samples_per_second": 4.467,
"eval_steps_per_second": 0.558,
"step": 200
}
],
"logging_steps": 100,
"max_steps": 200,
"num_input_tokens_seen": 0,
"num_train_epochs": 4,
"save_steps": 500,
"total_flos": 210494513971200.0,
"train_batch_size": 4,
"trial_name": null,
"trial_params": null
}
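Note that the saved directory contains no tokenizer files, because no tokenizer was passed to the Trainer, so trainer.save_model only wrote the model weights and training arguments. Reloading the fine-tuned model for inference therefore looks roughly like the sketch below (the review text is an invented example):
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_dir = "fine_model/bert-base-cased-finetune-yelp"
reloaded = AutoModelForSequenceClassification.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")  # tokenizer was not saved with the model above

inputs = tokenizer("The food was great but the service was painfully slow.",
                   return_tensors="pt", truncation=True)
with torch.no_grad():
    logits = reloaded(**inputs).logits
print(logits.argmax(dim=-1).item())  # predicted class 0-4, i.e. 1 to 5 stars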