BERT代码解读
文章目录
本文框架
1. BERT定义
BERT全称Bidirectional Encoder Representations from Transformers,即双向的Transformer的Encoder。是谷歌于2018年10月提出的一个语言表示模型(language representation model)。
2. BERT实现
2.1 BERT输入
2.1.1 Token Embeddings(input_ids)
token embedding 层是要将各个词转换成固定维度的向量。在BERT中,每个词会被转换成1024维的向量表示。
"Here is some text to encode"
#转为token
['[CLS]', 'here', 'is', 'some', 'text', 'to', 'en', '##code', '[SEP]']
#将token映射为索引
[101, 2182, 2003, 2070, 3793, 2000, 4372, 16044, 102]
序列:
MSDLVTKFESLIISKYPVSFTKEQSAQAAQWESVLKSGQIQPHLDQLNLVLRDNTFIVSTLYPTSTDVHVFEVALPLIKDLVASSKDVKSTYTTYRHILRWIDYMQNLLEVSSTDKLEINH
1. vocab.txt:
[PAD]
[UNK]
[CLS]
[SEP]
[MASK]
L
A
G
V
E
2. bert_config.json 模型中参数的配置
{
"attention_probs_dropout_prob": 0.0, #乘法attention时,softmax后dropout概率
"hidden_act": "gelu",#激活函数
"hidden_dropout_prob": 0.0,#隐藏层dropout概率
"hidden_size": 1024,#隐藏单元数
"initializer_range": 0.02,#初始化范围
"intermediate_size": 4096,#升维维度
"max_position_embeddings": 40000,#一个大于seq_length的参数,用于生成position_embedding
"num_attention_heads": 16,#每个隐藏层中的attention head数
"num_hidden_layers": 30,#隐藏层数
"type_vocab_size": 2,#segment_ids类别 [0,1]
"vocab_size": 30#词典中词数
}
3. pytorch_model.bin: 预训练权重
2.1.2. Segment Embeddings(token_type_ids)
用来区别两种句子,预训练除了LM,还需要做判断两个句子先后顺序的分类任务。
1.句子:”[CLS] my dog is cute [SEP] he likes play ##ing [SEP]“ 表示成”0 0 0 0 0 0 1 1 1 1 1“
2. 蛋白质序列:
MSDLVTKFESLIISKYPVSFTKEQSAQAAQWESVLKSGQIQPHLDQLNLVLRDNTFIVSTLYPTSTDVHVFEVALPLIKDLVASSKDVKSTYTTYRHILRWIDYMQNLLEVSSTDKLEINH
2.1.3 Position Embeddings(attention_mask)
attention_mask用于指定对哪些词进行self-Attention操作(取值为1的位置参与计算,取值为0的位置被屏蔽),它和Position Embeddings是两个不同的输入。Position Embeddings用于编码词在序列中的位置,这和Transformer的Position Embeddings不一样:在Transformer中使用的是公式法,在BERT中是通过训练得到的。加入position embeddings会让BERT理解“I think, therefore I am”中的第一个“I”和第二个“I”应该有着不同的向量表示。
具体代码:
import numpy as np
import torch
from transformers import BertTokenizer, BertConfig, BertForMaskedLM, BertForNextSentencePrediction
from transformers import BertModel
# Name used to load the tokenizer and the configuration.
model_name = 'bert-base-chinese'
# Local directory the model weights are loaded from.
MODEL_PATH = 'E:/transformer_file/bert-base-chinese/'
# a. Load the tokenizer from the vocabulary
tokenizer = BertTokenizer.from_pretrained(model_name)
# b. Load the configuration file
model_config = BertConfig.from_pretrained(model_name)
# Adjust the configuration so hidden states and attentions are returned
model_config.output_hidden_states = True
model_config.output_attentions = True
# Load the model from the local path using the configuration above
bert_model = BertModel.from_pretrained(MODEL_PATH, config = model_config)
print(tokenizer.encode('我不喜欢你')) #[101, 2769, 679, 1599, 3614, 872, 102]
# Encode a sentence pair; the commented dict below is the output quoted by the article
sen_code = tokenizer.encode_plus('我不喜欢这世界','我只喜欢你')
print(sen_code)
# {'input_ids': [101, 2769, 679, 1599, 3614, 6821, 686, 4518, 102, 2769, 1372, 1599, 3614, 872, 102],
# 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
# 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
2.2 预训练
2.2.1 获取序列的embeddings
长度为n的输入序列将获得的三种不同的向量表示,分别是:
- Token Embeddings, (1, n, 1024) ,词的向量表示
- Segment Embeddings, (1, n, 1024),辅助BERT区别句子对中的两个句子的向量表示
- Position Embeddings ,(1, n, 1024) ,让BERT学习到输入的顺序属性
# BertEmbeddings core forward code:
def forward(self, input_ids=None, token_type_ids=None,
            position_ids=None, inputs_embeds=None):
    """Sum token, position and segment embeddings, then apply LayerNorm and dropout.

    This is an abridged excerpt: part of the original method is elided below,
    so how ``position_ids`` / ``token_type_ids`` are defaulted is not shown here.
    """
    # ignore some codes here...
    # step 1: token embeddings
    if inputs_embeds is None:
        inputs_embeds = self.word_embeddings(input_ids) # token embeddings
    # step 2: position embeddings
    position_embeddings = self.position_embeddings(position_ids)
    # step 3: segment embeddings
    token_type_embeddings = self.token_type_embeddings(token_type_ids)
    # The three embeddings are added element-wise before normalisation.
    embeddings = inputs_embeds + position_embeddings + token_type_embeddings
    embeddings = self.LayerNorm(embeddings)
    embeddings = self.dropout(embeddings)
    return embeddings # per the article: shape B x S x D, i.e. (1, n, 1024)
embeddings层的模型结构:
BertModel(
(embeddings): BertEmbeddings(
(word_embeddings): Embedding(30, 1024, padding_idx=0)
(position_embeddings): Embedding(40000, 1024)
(token_type_embeddings): Embedding(2, 1024)
(LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.0, inplace=False)
)
2.2.2 利用Transformer对序列进行编码
1. BertEncoder模型
(encoder): BertEncoder(
(layer): ModuleList(
(0): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=1024, out_features=1024, bias=True)
(key): Linear(in_features=1024, out_features=1024, bias=True)
(value): Linear(in_features=1024, out_features=1024, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
# encoder is an instance of BertEncoder
encoder_outputs = self.encoder(
    embedding_output, # sequence embeddings, B x S x D
    attention_mask=extended_attention_mask, # used by the sequence's self-attention
    head_mask=head_mask, # used by the sequence's self-attention
    encoder_hidden_states=encoder_hidden_states, # decoder side, for cross-attention
    encoder_attention_mask=encoder_extended_attention_mask, # cross-attention
    output_attentions=output_attentions, # whether to return the attentions
    output_hidden_states=output_hidden_states) # whether to return every layer's hidden state
embedding_output: BertEmbeddings的输出,batch中样本序列的每个token的嵌入。B×S×D
extended_attention_mask: self-attention使用。根据attention_mask做维度广播(B×H×S×S),H是head数量,此时,方便下文做self-attention时作mask,即:softmax前对logits作处理,logits+extended_attention_mask,即:attention_mask取值为1时,extended_attention_mask对应位置的取值为0;否则,attention_mask为0时,extended_attention_mask对应位置的取值为-10000.0 (很小的一个数),这样softmax后,mask很小的值对应的位置概率接近0达到mask的目的。
head_mask: self-attention使用。同样可以基于原始输入head_mask作维度广播,广播前的shape为H or L x H;广播后的shape为:L x B x H x S x S。即每个样本序列中每个token对其他tokens的head attentions 值作mask,head attentions数量为L x H。
encoder_hidden_states:可选,cross-attention使用。即:decoder端做编码时,要传入encoder的隐状态,B x S x D。
encoder_attention_mask: 可选,cross-attention使用。即,decoder端做编码时,encoder的隐状态的attention mask。和extended_attention_mask类似,B x S。
output_attentions: 是否输出attention值,bool。可用于可视化attention scores。
output_hidden_states: 是否输出每层得到的隐向量,bool。
# BertEncoder is a stack of config.num_hidden_layers BertLayer modules
# (30 with the bert_config.json shown earlier in this article)
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
# BertEncoder Forward核心代码
def forward(self, hidden_states,
            attention_mask=None, head_mask=None,
            encoder_hidden_states=None, encoder_attention_mask=None,
            output_attentions=False, output_hidden_states=False):
    """Run ``hidden_states`` through every module in ``self.layer`` in order.

    Returns a tuple: the last hidden state, followed by the tuple of all
    hidden states when ``output_hidden_states`` is true, followed by the tuple
    of per-layer attentions when ``output_attentions`` is true.

    This is an abridged excerpt: part of the original method is elided below.
    """
    # ignore some codes here...
    all_hidden_states = ()
    all_attentions = ()
    for i, layer_module in enumerate(self.layer): # one iteration per stacked BertLayer
        if output_hidden_states:
            # Record the input of this layer (the previous layer's output).
            all_hidden_states = all_hidden_states + (hidden_states,)
        # step 1: BertLayer iteration
        layer_outputs = layer_module(
            hidden_states,
            attention_mask,
            head_mask[i],
            encoder_hidden_states,
            encoder_attention_mask,
            output_attentions) # BertLayer forward -- the core of the encoder
        hidden_states = layer_outputs[0] # override for next iteration
        if output_attentions:
            all_attentions = all_attentions + (layer_outputs[1],) # keep each layer's attentions, e.g. for visualisation
    # Add last layer
    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)
    outputs = (hidden_states,)
    if output_hidden_states:
        outputs = outputs + (all_hidden_states,)
    if output_attentions:
        outputs = outputs + (all_attentions,)
    return outputs # last-layer hidden state, (all hidden states), (all attentions)
- BertLayer:
上述代码最重要的是循环内的BertLayer迭代过程,其核心代码
def forward(self, hidden_states, attention_mask=None, head_mask=None,
            encoder_hidden_states=None, encoder_attention_mask=None,
            output_attentions=False,):
    """Apply self-attention, optional cross-attention, then the feed-forward sub-layers.

    Returns a tuple whose first element is the layer output; any extra
    elements returned by the attention modules follow it.
    """
    # step 1.0: self-attention; self.attention is a BertAttention instance
    self_attention_outputs = self.attention(
        hidden_states, attention_mask, head_mask, output_attentions=output_attentions,
    )
    attention_output = self_attention_outputs[0]
    outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
    # step 1.1: when acting as a decoder, run cross-attention. The step 1.0 output
    # (the decoder-side self-attention result) is the input of this step, and its
    # output is the decoder-side cross-attention result. self.crossattention is
    # also a BertAttention instance.
    if self.is_decoder and encoder_hidden_states is not None:
        cross_attention_outputs = self.crossattention(
            attention_output,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            output_attentions,
        )
        attention_output = cross_attention_outputs[0]
        outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
    # step 2: intermediate transformation, the feed-forward network (FFN) of the original paper
    intermediate_output = self.intermediate(attention_output)
    # step 3: skip-connection (self.output receives both the FFN output and its input)
    layer_output = self.output(intermediate_output, attention_output)
    outputs = (layer_output,) + outputs
    return outputs
3. BertAttention
BertAttention是上述代码中attention实例对应的类,也是transformer进行self-attention的核心类。包括了BertSelfAttention和BertSelfOutput成员。
class BertAttention(nn.Module):
    """Attention block made of a BertSelfAttention followed by a BertSelfOutput."""

    def __init__(self, config):
        super().__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(self, hidden_states, attention_mask=None,
                head_mask=None, encoder_hidden_states=None,
                encoder_attention_mask=None, output_attentions=False):
        """Run the attention sub-module, then combine its result with the input.

        Returns a tuple: the combined attention output first, followed by any
        extra elements the attention sub-module returned (e.g. attentions).
        """
        # step 1: attention; the first element is the context, the rest are extras
        context, *extras = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            output_attentions,
        )
        # step 2: skip-connection -- self.output receives the context and the block input
        combined = self.output(context, hidden_states)
        return (combined, *extras)
BertSelfAttention: 是self-attention,BertSelfAttention可以被实例化为encoder侧的self-attention,也可以被实例化为decoder侧的self-attention,此时attention_mask是非空的 (类似下三角形式的mask矩阵)。同时,还可以实例化为decoder侧的cross-attention,此时,hidden_states即为decoder侧序列的self-attention结果,同时需要传入encoder侧的encoder_hidden_states和encoder_attention_mask来进行cross-attention。
def forward(self, hidden_states, attention_mask=None, head_mask=None,
            encoder_hidden_states=None, encoder_attention_mask=None,
            output_attentions=False):
    """Multi-head scaled dot-product attention.

    Queries always come from ``hidden_states``. Keys and values come from
    ``encoder_hidden_states`` when it is given (cross-attention, in which case
    ``encoder_attention_mask`` replaces ``attention_mask``), otherwise from
    ``hidden_states`` (self-attention).

    Returns ``(context_layer, attention_probs)`` when ``output_attentions`` is
    true, else ``(context_layer,)``. Shape comments below use the article's
    notation: B batch, S sequence length, H heads, d head size, D = H*d.
    """
    # step 1: mapping Query/Key/Value to sub-space
    # step 1.1: query mapping
    mixed_query_layer = self.query(hidden_states) # B x S x (H*d)
    # If this is instantiated as a cross-attention module, the keys
    # and values come from an encoder; the attention mask needs to be
    # such that the encoder's padding tokens are not attended to.
    # step 1.2: key/value mapping
    if encoder_hidden_states is not None:
        mixed_key_layer = self.key(encoder_hidden_states) # B x S x (H*d)
        mixed_value_layer = self.value(encoder_hidden_states)
        attention_mask = encoder_attention_mask
    else:
        mixed_key_layer = self.key(hidden_states) # B x S x (H*d)
        mixed_value_layer = self.value(hidden_states)
    query_layer = self.transpose_for_scores(mixed_query_layer) # B x H x S x d
    key_layer = self.transpose_for_scores(mixed_key_layer) # B x H x S x d
    value_layer = self.transpose_for_scores(mixed_value_layer) # B x H x S x d
    # step 2: compute attention scores
    # step 2.1: raw attention scores
    # B x H x S x d  B x H x d x S -> B x H x S x S
    # Take the dot product between "query" and "key" to get the raw attention scores.
    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    attention_scores = attention_scores / math.sqrt(self.attention_head_size)
    # step 2.2: mask if necessary
    if attention_mask is not None:
        # Apply the attention mask, B x H x S x S
        # (the mask is additive: it is summed onto the scores before the softmax)
        attention_scores = attention_scores + attention_mask
    # step 2.3: Normalize the attention scores to probabilities, B x H x S x S
    attention_probs = nn.Softmax(dim=-1)(attention_scores)
    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = self.dropout(attention_probs)
    # Mask heads if we want to
    if head_mask is not None:
        attention_probs = attention_probs * head_mask
    # B x H x S x S  B x H x S x d -> B x H x S x d
    # step 4: aggregate values by attention probs to form context encodings
    context_layer = torch.matmul(attention_probs, value_layer)
    # B x S x H x d
    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
    # B x S x D
    new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
    # B x S x D, equivalent to concatenating the heads
    context_layer = context_layer.view(*new_context_layer_shape)
    outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
    return outputs
具体的计算过程:链接
- BertPooler
BertEncoder的输出O,shape为B×S×D,其中每个样本序列(S维度)的第一个token为[CLS]标识的hidden state,标识为o,即:B×D。则得到序列级别的嵌入表征:pooled-sentence-encoding=tanh(W⋅o),shape为B×D。这个主要用于下游任务的fine-tuning。
def forward(self, hidden_states):
    """Pool a sequence by transforming the hidden state of its first token.

    The first position along the second axis is selected, passed through
    ``self.dense`` and then through ``self.activation`` (nn.tanh per the
    original comment).
    """
    cls_state = hidden_states[:, 0]
    return self.activation(self.dense(cls_state))
3.我的做法
from time import time
import pandas as pd
import visdom
from tensorboardX import SummaryWriter
import torch
from transformers import BertModel, BertTokenizer,BertConfig
import re
from collections import namedtuple
from typing import Any
import tensorboard as tb
import csv
import codecs
import os
import requests
from tqdm.auto import tqdm
import numpy as np
import gzip
import pickle
# class ModelWrapper(torch.nn.Module):
# """
# Wrapper class for model with dict/list rvalues.
# """
#
# def __init__(self, model: torch.nn.Module) -> None:
# """
# Init call.
# """
# super().__init__()
# self.model = model
#
# def forward(self, input_x: torch.Tensor) -> Any:
# """
# Wrap forward call.
# """
# data = self.model(input_x)
#
# if isinstance(data, dict):
# data_named_tuple = namedtuple("ModelEndpoints", sorted(data.keys())) # type: ignore
# data = data_named_tuple(**data) # type: ignore
#
# elif isinstance(data, list):
# data = tuple(data)
#
# return data
def _read_labeled_sequences(path):
    """Read ``id / sequence / label`` records (three lines each) from *path*.

    Returns ``(sequences, labels)`` where ``sequences`` is a list of sequence
    strings and ``labels`` is a flat list with one entry per character of
    every label line. Records whose sequence line is empty are skipped
    together with their labels, so the two lists stay aligned.
    """
    sequences = []
    labels = []
    with open(path, 'r') as f:
        while True:
            header = f.readline()
            if not header:  # readline() returns '' only at end of file
                break
            seq = f.readline().rstrip("\n").strip(' ')
            label = f.readline().rstrip('\n').strip(' ')
            if not seq:
                continue
            sequences.append(seq)
            labels.extend(label)
    return sequences, labels


def generate_protbert_features(root_dir):
    """Extract per-token ProtBert embeddings and write them, with labels, to CSV.

    The model weights are loaded from ``root_dir + '/inputs/ProtBert_model/'``;
    ``./config.json``, ``./vocab.txt`` and the input ``./test_3.txt`` are read
    from the current working directory. For every sequence the last hidden
    state is computed, the first and last attended positions (the special
    tokens added by the tokenizer) are dropped, and the remaining rows are
    stacked. The stacked features are concatenated column-wise with the labels
    and written to ``./features_3label.csv`` without header or index.

    Args:
        root_dir: directory that contains ``inputs/ProtBert_model/``.

    Raises:
        ValueError: if the input file contains no non-empty sequence.
    """
    import warnings

    t0 = time()
    modelFolderPath = root_dir + '/inputs/ProtBert_model/'
    configFilePath = './config.json'
    vocabFilePath = './vocab.txt'

    # Load the tokenizer from the vocabulary file
    tokenizer = BertTokenizer(vocabFilePath, do_lower_case=False)
    # Load the pretrained model
    model_config = BertConfig.from_pretrained(configFilePath)
    model = BertModel.from_pretrained(modelFolderPath, config=model_config)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model = model.eval()
    print(model)

    sequences, labels = _read_labeled_sequences('./test_3.txt')
    if not sequences:
        raise ValueError("no sequences found in './test_3.txt'")
    # Space-separate the residues and map the characters -, U, Z, O, B to X.
    sequences_Example = [re.sub(r"[-UZOB]", "X", ' '.join(seq)) for seq in sequences]
    labels = np.array(labels)

    all_protein_features = []
    print(len(sequences_Example))
    for seq in sequences_Example:
        # One sequence per batch, so no padding is requested.
        ids = tokenizer.batch_encode_plus([seq], add_special_tokens=True)
        input_ids = torch.tensor(ids['input_ids']).to(device)
        attention_mask = torch.tensor(ids['attention_mask']).to(device)
        seqments_tensor = torch.tensor(ids["token_type_ids"]).to(device)
        print("token_type_ids", seqments_tensor)
        print('attention_mask', attention_mask)
        with torch.no_grad():
            # Index 0 of the model output is the last layer's hidden state.
            embedding = model(input_ids=input_ids,
                              token_type_ids=seqments_tensor,
                              attention_mask=attention_mask)[0]
        # Move to CPU before converting to a numpy array.
        embedding = embedding.cpu().numpy()
        print('embedding的形状', embedding.shape)
        # Convert to a plain int so it can be used as a numpy slice bound
        # regardless of the device the mask lives on.
        seq_len = int((attention_mask[0] == 1).sum().item())
        # Drop the first and last attended positions (special tokens).
        all_protein_features.append(embedding[0][1:seq_len - 1])

    protein_features = np.concatenate(all_protein_features, axis=0)
    if len(labels) != len(protein_features):
        # pd.concat would silently pad the shorter side with NaN.
        warnings.warn(
            "label count ({}) differs from feature row count ({})".format(
                len(labels), len(protein_features)))
    all_protein_features = pd.DataFrame(protein_features)
    all_labels = pd.DataFrame(labels)
    all_data = pd.concat((all_protein_features, all_labels), axis=1)
    # newline='' lets the csv writer control line endings (avoids blank rows on Windows).
    with open("./features_3label.csv", "w", newline='') as pf:
        all_data.to_csv(pf, index=False, header=False)
    print('Total time spent for ProtBERT:', time() - t0)
# Script entry point: extract ProtBert features using the model stored under ./EGRET.
if __name__ == "__main__":
    root_dir = './EGRET'
    generate_protbert_features(root_dir)
4. 结果
5. 下一步打算
import datetime
import time
import random
import torch
import numpy as np
import re
from transformers import AdamW,BertTokenizer,BertConfig,BertForSequenceClassification,get_linear_schedule_with_warmup
from torch.utils.data import TensorDataset,DataLoader,RandomSampler
from torch.nn.utils import clip_grad_norm
# Seed the Python, NumPy and PyTorch (CPU and CUDA) random number generators
# with the same fixed value.
random.seed(20)
np.random.seed(20)
torch.manual_seed(20)
torch.cuda.manual_seed(20)
def read_sequence():
    """Read sequences and labels from ``./training_fasta_total.txt``.

    The file is read as consecutive three-line records: an id line, a
    sequence line and a label line. Records whose sequence line is empty are
    skipped together with their labels so that sequences and labels stay
    aligned.

    Returns:
        A tuple ``(all_sequences, labels)``. ``all_sequences`` is a list of
        strings in which the residues are separated by single spaces and the
        characters ``-``, ``U``, ``Z``, ``O``, ``B`` are replaced by ``X``.
        ``labels`` is a float numpy array with one entry per character of
        every label line.
    """
    sequences = []
    labels = []
    with open('./training_fasta_total.txt', 'r') as f:
        while True:
            header = f.readline()
            if not header:  # readline() returns '' only at end of file
                break
            seq = f.readline().rstrip("\n").strip(' ')
            label = f.readline().rstrip('\n').strip(' ')
            if not seq:
                continue
            sequences.append(seq)
            labels.extend(label)
    all_sequences = [re.sub(r"[-UZOB]", "X", ' '.join(seq)) for seq in sequences]
    labels = np.array(labels).astype(float)
    return all_sequences, labels
def define_finetuning_model(root_dir):
    """Build the training DataLoader and the sequence-classification model.

    Sequences and labels come from ``read_sequence()``. All sequences are
    tokenized in one call and padded to the longest one so the resulting
    tensors share a common length. Each dataset item is
    ``(input_ids, attention_mask, token_type_ids, label)``.

    Args:
        root_dir: directory that contains ``inputs/ProtBert_model/``.

    Returns:
        A tuple ``(train_dataloader, model)``.

    Raises:
        ValueError: if there are no sequences, or the number of labels does
            not match the number of sequences.
    """
    all_sequences, labels = read_sequence()
    if not all_sequences:
        raise ValueError("no training sequences were read")
    if len(all_sequences) != len(labels):
        # TensorDataset requires one label per sequence.
        raise ValueError(
            "got {} sequences but {} labels".format(len(all_sequences), len(labels)))
    device = ('cuda' if torch.cuda.is_available() else 'cpu')
    modelFolderPath = root_dir + '/inputs/ProtBert_model/'
    vocabFilePath = './vocab.txt'
    # Load the tokenizer from the vocabulary file
    tokenizer = BertTokenizer(vocabFilePath, do_lower_case=False)

    # Encode everything at once: encoding sequence by sequence yields tensors
    # of different lengths that cannot be concatenated along dim 0.
    encoded = tokenizer.batch_encode_plus(all_sequences,
                                          add_special_tokens=True,
                                          padding=True,
                                          return_tensors='pt')
    all_input_id = encoded['input_ids']
    all_attention_mask = encoded['attention_mask']
    all_segment_tensor = encoded['token_type_ids']
    # Class indices must be integer typed for a 2-class classification loss.
    all_label = torch.tensor(labels, dtype=torch.long)
    print('第一个序列:', all_sequences[0])
    print('第一个序列的编码:', all_input_id[0])
    print(all_input_id.shape, all_attention_mask.shape, all_label.shape)

    # Keep mask and segment ids so padded positions can be masked out; the
    # label is the fourth element of every batch.
    train_data = TensorDataset(all_input_id, all_attention_mask,
                               all_segment_tensor, all_label)
    train_dataloader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=5)
    model = BertForSequenceClassification.from_pretrained(modelFolderPath,
                                                          num_labels=2,
                                                          output_attentions=False,
                                                          output_hidden_states=False)
    # Use the detected device instead of requiring CUDA unconditionally.
    model.to(device)
    return train_dataloader, model
def train_model(root_dir):
    """Fine-tune the classification model and print the mean loss and elapsed time.

    The DataLoader and model come from ``define_finetuning_model(root_dir)``.
    The first element of every batch is used as the input ids and the last
    element as the labels. When a batch has four elements, the second and
    third are passed as attention mask and token type ids.

    Args:
        root_dir: forwarded to ``define_finetuning_model``.
    """
    # In-place variant; the un-suffixed clip_grad_norm is deprecated.
    from torch.nn.utils import clip_grad_norm_

    total_train_loss = 0.0
    t0 = time.time()
    num_epochs = 4  # number of passes over the training data
    device = ('cuda' if torch.cuda.is_available() else 'cpu')
    train_dataloader, model = define_finetuning_model(root_dir)
    model.to(device)
    optimizer = AdamW(model.parameters(),
                      lr=2e-5,
                      eps=1e-8)
    total_steps = len(train_dataloader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=0,
                                                num_training_steps=total_steps)
    for _ in range(num_epochs):
        model.train()
        for batch in train_dataloader:
            b_input_id = batch[0].to(device)
            # The label is always the last tensor of the batch.
            b_input_label = batch[-1].to(device).long()
            extra_inputs = {}
            if len(batch) == 4:
                extra_inputs['attention_mask'] = batch[1].to(device)
                extra_inputs['token_type_ids'] = batch[2].to(device)
            model.zero_grad()
            outputs = model(b_input_id,
                            labels=b_input_label,
                            **extra_inputs)
            # Index instead of tuple-unpacking so both tuple and
            # dict-like model outputs are handled.
            loss = outputs[0]
            # .item() detaches the value so the autograd graph is not kept alive.
            total_train_loss += loss.item()
            loss.backward()
            clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
    # Average over every optimisation step, not only the batches of one epoch.
    aver_train_loss = total_train_loss / total_steps if total_steps else 0.0
    training_time = format(time.time() - t0)
    print("平均的训练损失是{}".format(aver_train_loss))
    print("训练的总时间是{}".format(training_time))
# Script entry point: fine-tune using the model stored under ./EGRET.
if __name__== "__main__":
    root_dir = './EGRET'
    train_model(root_dir)
参考博客:
http://xtf615.com/2020/07/05/transformers/
更多推荐
所有评论(0)