【实战分享】用PyTorch实现文本分类器:从数据准备到模型部署
·
【实战分享】用PyTorch实现文本分类器:从数据准备到模型部署
引言
文本分类是自然语言处理领域的基础任务之一,广泛应用于情感分析、垃圾邮件检测、新闻分类等场景。本文将详细介绍如何使用PyTorch实现一个完整的文本分类系统,包括数据准备、模型构建、训练优化和部署上线的全过程。
一、任务定义与数据集选择
1.1 任务定义
文本分类的目标是将一段文本映射到预定义的类别中。例如:
- 情感分析:判断文本是正面、负面还是中性
- 主题分类:将新闻归类到政治、体育、科技等类别
- 意图识别:识别用户的意图(查询、投诉、建议等)
1.2 数据集选择
本文使用IMDB电影评论数据集进行情感分析任务。下面列出的是该数据集的官方划分;2.1 节的代码则是把全部 50000 条评论读入后按 8:2 重新随机划分(训练集 40000 条、测试集 10000 条):
- 数据集规模:50000条电影评论
- 训练集:25000条(正面12500,负面12500)
- 测试集:25000条(正面12500,负面12500)
- 数据格式:每条评论是一段英文文本,标签为0(负面)或1(正面)
二、数据预处理
2.1 数据加载
import pandas as pd
from sklearn.model_selection import train_test_split
# Load the dataset from the CSV file
df = pd.read_csv('IMDB_Dataset.csv')
# Inspect the data: first rows, row count, and label distribution
print(df.head())
print(f"数据集大小: {len(df)}")
print(f"类别分布: {df['sentiment'].value_counts()}")
# Split into training and test sets: 20% held out, fixed seed so the split is reproducible
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
2.2 文本清洗
文本数据通常需要进行清洗处理:
import re
import string
def clean_text(text):
    """Normalize a raw review string before tokenization.

    Steps, in order: lowercase, replace HTML tags with a space, delete
    ASCII punctuation, delete digit runs, collapse runs of whitespace.

    Args:
        text: Raw text (str).

    Returns:
        The cleaned text, with words separated by single spaces and no
        leading or trailing whitespace.
    """
    text = text.lower()
    # Replace tags with a space, not the empty string: otherwise the words on
    # either side of a tag fuse together ("good<br />movie" -> "goodmovie").
    text = re.sub(r'<.*?>', ' ', text)
    # Delete ASCII punctuation characters.
    text = text.translate(str.maketrans('', '', string.punctuation))
    # Delete digits.
    text = re.sub(r'\d+', '', text)
    # Collapse whitespace left behind by the removals above.
    return ' '.join(text.split())
# Apply the cleaning function to the raw 'review' column of both splits
# NOTE(review): the torchtext pipeline further down reads 'train.csv' and
# 'test.csv'; nothing visible here writes these frames to disk — confirm how
# those files are produced and which text column they contain.
train_df['cleaned_text'] = train_df['review'].apply(clean_text)
test_df['cleaned_text'] = test_df['review'].apply(clean_text)
2.3 分词与向量化
import torch  # needed below for torch.device; this snippet did not import it
from torchtext.data import Field, TabularDataset, BucketIterator

# Field definitions.
# batch_first=True makes every batch (batch_size, seq_len), which is the input
# layout the classifiers in this file document for forward() and the layout the
# inference helpers build; the Field default is (seq_len, batch_size).
TEXT = Field(
    tokenize='spacy',
    tokenizer_language='en_core_web_sm',
    lower=True,
    batch_first=True,
)
# Labels are read as plain numbers (no vocabulary lookup).
LABEL = Field(sequential=False, use_vocab=False)

# Dataset definition: CSV columns are mapped to fields in this order.
fields = [('review', TEXT), ('sentiment', LABEL)]
train_data, test_data = TabularDataset.splits(
    path='.',
    train='train.csv',
    test='test.csv',
    format='csv',
    fields=fields,
    skip_header=True,
)

# Build the vocabulary from the training split only, capped at 10000 tokens,
# with 100-dimensional GloVe vectors attached.
TEXT.build_vocab(train_data, max_size=10000, vectors='glove.6B.100d')

# Create the batch iterators; examples are bucketed by review length.
BATCH_SIZE = 64
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_iterator, test_iterator = BucketIterator.splits(
    (train_data, test_data),
    batch_size=BATCH_SIZE,
    device=device,
    sort_key=lambda x: len(x.review),
    sort_within_batch=False,
)
2.4 数据增强(可选)
对于文本分类任务,可以使用以下数据增强技术:
import random
def augment_text(text, delete_prob=0.1, rng=None):
    """Return a randomly perturbed copy of a whitespace-tokenized text.

    Two perturbations are applied: one random swap of two adjacent words
    (when there are at least two words), then independent random deletion
    of each word.

    Args:
        text: Input text; words are split on whitespace.
        delete_prob: Probability of dropping each word (default 0.1).
        rng: Optional ``random.Random``-like object providing ``randint``,
            ``random`` and ``choice``; defaults to the ``random`` module.
            Pass a seeded instance for reproducible augmentation.

    Returns:
        The augmented text joined by single spaces. A non-empty input never
        yields an empty output: if every word is deleted, one word is kept.
    """
    rng = random if rng is None else rng
    words = text.split()
    # Swap one random pair of neighbours.
    if len(words) > 1:
        idx = rng.randint(0, len(words) - 2)
        words[idx], words[idx + 1] = words[idx + 1], words[idx]
    # Drop each word independently with probability delete_prob.
    kept = [word for word in words if rng.random() > delete_prob]
    # Deleting everything would produce an empty training sample; keep one word.
    if not kept and words:
        kept = [rng.choice(words)]
    return ' '.join(kept)
# Augment the cleaned training texts (the test split is left untouched)
train_df['augmented_text'] = train_df['cleaned_text'].apply(augment_text)
三、模型架构设计
3.1 简单的LSTM模型
import torch
import torch.nn as nn
import torch.nn.functional as F
class LSTMClassifier(nn.Module):
    """(Bi)LSTM classifier over right-padded batches of token ids.

    The sentence representation is the final hidden state of the top LSTM
    layer (both directions concatenated when bidirectional), followed by
    dropout and a linear layer.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Embedding size.
        hidden_dim: LSTM hidden size per direction.
        output_dim: Size of the output layer.
        n_layers: Number of stacked LSTM layers.
        bidirectional: Whether the LSTM runs in both directions.
        dropout: Dropout probability (embeddings, final state, and between
            stacked LSTM layers).
        pad_idx: Token id used for padding, or None if inputs are unpadded.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers,
                 bidirectional, dropout, pad_idx):
        super().__init__()
        # Kept so forward() can recover each sequence's true length.
        self.pad_idx = pad_idx
        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
        # LSTM layer
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=n_layers,
            bidirectional=bidirectional,
            # Inter-layer dropout only exists for stacked LSTMs; passing a
            # non-zero value with a single layer just triggers a warning.
            dropout=dropout if n_layers > 1 else 0.0,
            batch_first=True,
        )
        # Fully connected output layer
        self.fc = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)
        # Dropout layer
        self.dropout = nn.Dropout(dropout)

    def forward(self, text):
        """Compute logits for a batch.

        Args:
            text: Token ids, shape (batch_size, seq_len). Padding, if any,
                is assumed to be trailing.

        Returns:
            Tensor of shape (batch_size, output_dim).
        """
        embedded = self.dropout(self.embedding(text))  # (batch_size, seq_len, embedding_dim)
        # True lengths, so the final hidden state is taken at the last real
        # token instead of after the padding. clamp keeps all-pad rows valid.
        if self.pad_idx is None:
            lengths = torch.full((text.size(0),), text.size(1), dtype=torch.long)
        else:
            lengths = (text != self.pad_idx).sum(dim=1).clamp(min=1).cpu()
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths, batch_first=True, enforce_sorted=False
        )
        _packed_output, (hidden, _cell) = self.lstm(packed)
        # hidden: (n_layers * n_directions, batch_size, hidden_dim); the last
        # two entries are the top layer's forward and backward states.
        if self.lstm.bidirectional:
            hidden = self.dropout(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1))
        else:
            hidden = self.dropout(hidden[-1, :, :])
        return self.fc(hidden)
3.2 模型初始化
# Hyperparameters
INPUT_DIM = len(TEXT.vocab)
# Must match the dimensionality of the 'glove.6B.100d' vectors copied in below
EMBEDDING_DIM = 100
HIDDEN_DIM = 256
# One logit per example (binary classification)
OUTPUT_DIM = 1
N_LAYERS = 2
BIDIRECTIONAL = True
DROPOUT = 0.5
PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]
# Build the model
model = LSTMClassifier(
INPUT_DIM,
EMBEDDING_DIM,
HIDDEN_DIM,
OUTPUT_DIM,
N_LAYERS,
BIDIRECTIONAL,
DROPOUT,
PAD_IDX
)
# Copy the pretrained word vectors attached to the vocabulary into the embedding table
pretrained_embeddings = TEXT.vocab.vectors
model.embedding.weight.data.copy_(pretrained_embeddings)
# Zero the embedding rows of the padding and unknown tokens
UNK_IDX = TEXT.vocab.stoi[TEXT.unk_token]
model.embedding.weight.data[UNK_IDX] = torch.zeros(EMBEDDING_DIM)
model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM)
# Report the total parameter count
print(f"模型参数数量: {sum(p.numel() for p in model.parameters())}")
3.3 Transformer-based模型
对于更复杂的任务,可以使用Transformer架构:
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for batch-first input.

    Args:
        d_model: Embedding size.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Longest sequence length supported.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        import math  # local import: only needed to build the table

        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        # The slice keeps this valid for odd d_model (one fewer cosine column).
        pe[0, :, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Buffer: moves with .to(device) but is not a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add position information to x of shape (batch_size, seq_len, d_model)."""
        if x.size(1) > self.pe.size(1):
            raise ValueError(
                f"sequence length {x.size(1)} exceeds max_len {self.pe.size(1)}"
            )
        return self.dropout(x + self.pe[:, : x.size(1)])


class TransformerClassifier(nn.Module):
    """Transformer-encoder classifier over right-padded batches of token ids.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Embedding / model size (must be divisible by n_heads).
        n_heads: Number of attention heads.
        n_layers: Number of encoder layers.
        output_dim: Size of the output layer.
        dropout: Dropout probability.
        pad_idx: Token id used for padding, or None if inputs are unpadded.
    """

    def __init__(self, vocab_size, embedding_dim, n_heads, n_layers, output_dim, dropout, pad_idx):
        super().__init__()
        # Kept so forward() can mask padding positions out of attention.
        self.pad_idx = pad_idx
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
        self.pos_encoder = PositionalEncoding(embedding_dim, dropout)
        encoder_layers = nn.TransformerEncoderLayer(
            embedding_dim, n_heads, dim_feedforward=512, dropout=dropout
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, n_layers)
        self.fc = nn.Linear(embedding_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, text):
        """Compute logits for a batch.

        Args:
            text: Token ids, shape (batch_size, seq_len).

        Returns:
            Tensor of shape (batch_size, output_dim).
        """
        # True at padding positions so attention ignores them.
        padding_mask = None if self.pad_idx is None else text == self.pad_idx
        embedded = self.dropout(self.embedding(text))
        embedded = self.pos_encoder(embedded)
        # The encoder (built without batch_first) expects (seq_len, batch_size, dim).
        embedded = embedded.permute(1, 0, 2)
        output = self.transformer_encoder(embedded, src_key_padding_mask=padding_mask)
        # Use the output at the first position as the sentence representation.
        # No dedicated [CLS] token is inserted here, so this is the first word.
        cls_output = output[0, :, :]
        return self.fc(cls_output)
四、训练与优化
4.1 损失函数与优化器
import torch.optim as optim
# Loss function: binary cross-entropy that takes raw logits
criterion = nn.BCEWithLogitsLoss()
# Optimizer: Adam over all model parameters
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Move the model and the loss function to the selected device
model = model.to(device)
criterion = criterion.to(device)
4.2 训练循环
def train(model, iterator, optimizer, criterion):
    """Run one training epoch.

    Args:
        model: Module mapping ``batch.review`` to logits of shape (batch, 1).
        iterator: Iterable of batches exposing ``review`` and ``sentiment``.
        optimizer: Optimizer over the model's parameters.
        criterion: Loss taking (logits, float labels); assumed to return the
            mean over the batch, as ``nn.BCEWithLogitsLoss()`` does by default.

    Returns:
        (loss, accuracy) averaged per sample over the whole epoch.

    Raises:
        ValueError: If the iterator yields no samples.
    """
    total_loss = 0.0
    total_correct = 0.0
    total_samples = 0
    model.train()
    for batch in iterator:
        optimizer.zero_grad()
        text = batch.review
        labels = batch.sentiment.float()
        predictions = model(text).squeeze(1)
        loss = criterion(predictions, labels)
        loss.backward()
        optimizer.step()
        # Accumulate per-sample totals: averaging per-batch means would
        # over-weight a smaller final batch.
        with torch.no_grad():
            rounded_preds = torch.round(torch.sigmoid(predictions))
            batch_size = labels.size(0)
            total_correct += (rounded_preds == labels).sum().item()
            total_loss += loss.item() * batch_size
            total_samples += batch_size
    if total_samples == 0:
        raise ValueError("train() received an iterator with no samples")
    return total_loss / total_samples, total_correct / total_samples
4.3 验证函数
def evaluate(model, iterator, criterion):
    """Evaluate the model without updating it.

    Args:
        model: Module mapping ``batch.review`` to logits of shape (batch, 1).
        iterator: Iterable of batches exposing ``review`` and ``sentiment``.
        criterion: Loss taking (logits, float labels); assumed to return the
            mean over the batch, as ``nn.BCEWithLogitsLoss()`` does by default.

    Returns:
        (loss, accuracy) averaged per sample over all batches.

    Raises:
        ValueError: If the iterator yields no samples.
    """
    total_loss = 0.0
    total_correct = 0.0
    total_samples = 0
    model.eval()
    with torch.no_grad():
        for batch in iterator:
            text = batch.review
            labels = batch.sentiment.float()
            predictions = model(text).squeeze(1)
            loss = criterion(predictions, labels)
            rounded_preds = torch.round(torch.sigmoid(predictions))
            # Accumulate per-sample totals: averaging per-batch means would
            # over-weight a smaller final batch.
            batch_size = labels.size(0)
            total_correct += (rounded_preds == labels).sum().item()
            total_loss += loss.item() * batch_size
            total_samples += batch_size
    if total_samples == 0:
        raise ValueError("evaluate() received an iterator with no samples")
    return total_loss / total_samples, total_correct / total_samples
4.4 训练过程
import time
def epoch_time(start_time, end_time):
    """Split the span between two timestamps into whole minutes and seconds.

    Args:
        start_time: Start timestamp in seconds.
        end_time: End timestamp in seconds.

    Returns:
        (minutes, seconds) as integers, both truncated toward zero.
    """
    duration = end_time - start_time
    whole_minutes = int(duration / 60)
    leftover_seconds = int(duration - (whole_minutes * 60))
    return whole_minutes, leftover_seconds
# Training run: train for N_EPOCHS, evaluate after every epoch, and keep the
# weights of the epoch with the lowest evaluation loss.
# NOTE(review): test_iterator is used for model selection here, so a later
# score on the same iterator is not a held-out estimate — consider a separate
# validation split.
N_EPOCHS = 10
best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    start_time = time.time()
    train_loss, train_acc = train(model, train_iterator, optimizer, criterion)
    valid_loss, valid_acc = evaluate(model, test_iterator, criterion)
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    # Checkpoint whenever the evaluation loss improves.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'text_classifier.pt')

    # Per-epoch progress report.
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:.2f}%')
4.5 学习率调度
# Learning-rate scheduler: halve the learning rate once the monitored loss has
# stopped decreasing for more than 2 consecutive epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 'min', patience=2, factor=0.5
)

# Inside the training loop, step the scheduler once per epoch with that
# epoch's evaluation loss.
for epoch in range(N_EPOCHS):
    # ... training code (must compute valid_loss for this epoch) ...
    scheduler.step(valid_loss)
五、模型评估与分析
5.1 测试集评估
# Reload the saved checkpoint 'text_classifier.pt' into the model
model.load_state_dict(torch.load('text_classifier.pt'))
# Evaluate on the test iterator
# NOTE(review): the training loop picks this checkpoint using the same
# test_iterator, so this number is not a held-out estimate.
test_loss, test_acc = evaluate(model, test_iterator, criterion)
print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')
5.2 混淆矩阵
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
def get_predictions(model, iterator):
    """Collect hard 0/1 predictions and true labels over an iterator.

    Args:
        model: Module mapping ``batch.review`` to logits of shape (batch, 1).
        iterator: Iterable of batches exposing ``review`` and ``sentiment``.

    Returns:
        (predictions, labels): two flat lists of NumPy scalars, in iteration
        order. Predictions are the sigmoid of the logits rounded to 0 or 1.
    """
    predictions, labels = [], []
    model.eval()
    with torch.no_grad():
        for batch in iterator:
            logits = model(batch.review).squeeze(1)
            hard_labels = torch.sigmoid(logits).round()
            # Extending with a NumPy array appends its elements one by one.
            predictions.extend(hard_labels.cpu().numpy())
            labels.extend(batch.sentiment.cpu().numpy())
    return predictions, labels
# Collect predictions and true labels on the test iterator
predictions, labels = get_predictions(model, test_iterator)
# Plot the confusion matrix (rows: actual class, columns: predicted class)
cm = confusion_matrix(labels, predictions)
plt.figure(figsize=(10, 7))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
xticklabels=['Negative', 'Positive'],
yticklabels=['Negative', 'Positive'])
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.show()
5.3 错误分析
def analyze_errors(model, iterator, n_examples=5, itos=None):
    """Collect up to ``n_examples`` misclassified samples.

    Args:
        model: Module mapping ``batch.review`` to logits of shape (batch, 1).
        iterator: Iterable of batches exposing ``review`` (token ids, shape
            (batch_size, seq_len) — the layout the model is fed) and
            ``sentiment``.
        n_examples: Maximum number of errors to return.
        itos: Optional index-to-token list used to decode token ids;
            defaults to ``TEXT.vocab.itos``.

    Returns:
        A list of dicts with keys 'text' (decoded tokens joined by spaces,
        padding tokens included), 'predicted' and 'actual' (ints).
    """
    errors = []
    if n_examples <= 0:
        return errors
    if itos is None:
        itos = TEXT.vocab.itos
    model.eval()
    with torch.no_grad():
        for batch in iterator:
            text = batch.review
            labels = batch.sentiment
            predictions = torch.round(torch.sigmoid(model(text).squeeze(1)))
            # Row i of the batch-first tensor is sample i; indexing columns
            # would stitch together one token from every sample.
            rows = zip(text.tolist(), predictions.tolist(), labels.tolist())
            for token_ids, predicted, actual in rows:
                if int(predicted) == int(actual):
                    continue
                errors.append({
                    'text': ' '.join(itos[idx] for idx in token_ids),
                    'predicted': int(predicted),
                    'actual': int(actual),
                })
                if len(errors) >= n_examples:
                    return errors
    return errors
# Collect misclassified test samples and print a short report for each one:
# the first 100 characters of the text, then predicted vs. actual label.
errors = analyze_errors(model, test_iterator)
for error in errors:
    print(f"Text: {error['text'][:100]}...")
    print(f"Predicted: {'Positive' if error['predicted'] == 1 else 'Negative'}")
    print(f"Actual: {'Positive' if error['actual'] == 1 else 'Negative'}")
    print('---')
六、模型部署
6.1 模型导出
# Save the complete model object (architecture + weights, via pickle)
torch.save(model, 'text_classifier_full.pt')

# Alternatively, export to ONNX.
# Switch to inference mode first so dropout is not part of the exported graph.
model.eval()
dummy_input = torch.randint(0, INPUT_DIM, (1, 50)).to(device)
torch.onnx.export(
    model,
    dummy_input,
    'text_classifier.onnx',
    input_names=['text'],
    output_names=['logits'],
    # Without this the graph is frozen to the dummy input's length of 50
    # tokens; real reviews have arbitrary lengths.
    dynamic_axes={'text': {1: 'seq_len'}},
)
6.2 构建推理API
from flask import Flask, request, jsonify
import spacy
app = Flask(__name__)

# Load the pickled model.
# map_location remaps tensors saved on a GPU onto the current device, so the
# service also starts on a CPU-only host.
# NOTE(review): torch.load unpickles arbitrary objects — only load files you
# trust. On PyTorch >= 2.6 the default is weights_only=True, which rejects a
# fully pickled model; pass weights_only=False there if that applies.
model = torch.load('text_classifier_full.pt', map_location=device)
model.eval()

# Load the spaCy pipeline used for tokenization
nlp = spacy.load('en_core_web_sm')
def predict_sentiment(text):
    """Predict the sentiment of a single text.

    Args:
        text: Raw input text (str).

    Returns:
        Dict with 'sentiment' ('positive' or 'negative') and 'confidence',
        the model's probability for the returned label (always >= 0.5).
    """
    # Same preprocessing as training: clean, tokenize, map tokens to ids.
    tokens = [token.text for token in nlp.tokenizer(clean_text(text))]
    indices = [TEXT.vocab.stoi.get(token, UNK_IDX) for token in tokens]
    # Cleaning can remove everything (e.g. "!!!" or "123"); the model cannot
    # process a zero-length sequence, so fall back to one unknown token.
    if not indices:
        indices = [UNK_IDX]
    # Add the batch dimension: (1, seq_len)
    tensor = torch.LongTensor(indices).unsqueeze(0).to(device)
    with torch.no_grad():
        probability = torch.sigmoid(model(tensor).squeeze(1)).item()
    is_positive = probability >= 0.5
    return {
        'sentiment': 'positive' if is_positive else 'negative',
        # Report the probability of the predicted label, not P(positive):
        # a confident negative must not show up as confidence 0.02.
        'confidence': float(probability if is_positive else 1.0 - probability),
    }
@app.route('/predict', methods=['POST'])
def predict():
    """Handle POST /predict: expects a JSON object with a 'text' string.

    Returns:
        200 with the prediction dict, or 400 with an error message when the
        body is not a JSON object or 'text' is missing, not a string, or blank.
    """
    # silent=True yields None instead of raising on a missing or invalid JSON
    # body, so malformed requests get the 400 below rather than a 500.
    data = request.get_json(silent=True)
    text = data.get('text', '') if isinstance(data, dict) else ''
    if not isinstance(text, str) or not text.strip():
        return jsonify({'error': 'No text provided'}), 400
    result = predict_sentiment(text)
    return jsonify(result)
# Start the server only when this file is executed as a script; it listens on
# all interfaces, port 5000.
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)
6.3 API测试
import requests
# Call the prediction endpoint of the locally running service
response = requests.post(
'http://localhost:5000/predict',
json={'text': 'This movie is absolutely amazing! I loved every minute of it.'}
)
print(response.json())
# Example output (illustrative): {'sentiment': 'positive', 'confidence': 0.985}
七、性能优化
7.1 模型压缩
# Dynamically quantize the model's LSTM and Linear modules to 8-bit integers
# NOTE(review): dynamic quantization presumably requires the model to be on
# the CPU — confirm before running this on a CUDA model.
quantized_model = torch.quantization.quantize_dynamic(
model, {nn.LSTM, nn.Linear}, dtype=torch.qint8
)
# Save the quantized model
torch.save(quantized_model, 'text_classifier_quantized.pt')
7.2 推理加速
# Compile the model with TorchScript and save the scripted module
scripted_model = torch.jit.script(model)
scripted_model.save('text_classifier_scripted.pt')
# Load the scripted module back for inference
loaded_model = torch.jit.load('text_classifier_scripted.pt')
7.3 批量推理
def batch_predict(texts):
    """Predict the sentiment of several texts in one forward pass.

    Args:
        texts: Iterable of raw input strings.

    Returns:
        A list with one dict per input, in order, each holding 'sentiment'
        ('positive' or 'negative') and 'confidence', the model's probability
        for the returned label (always >= 0.5). Empty input gives [].
    """
    texts = list(texts)
    # Nothing to do; also avoids building a zero-row tensor for the model.
    if not texts:
        return []

    processed_texts = []
    for text in texts:
        tokens = [token.text for token in nlp.tokenizer(clean_text(text))]
        indices = [TEXT.vocab.stoi.get(token, UNK_IDX) for token in tokens]
        # A text that cleans down to nothing becomes one unknown token, so no
        # row is empty or made of padding only.
        processed_texts.append(indices or [UNK_IDX])

    # Right-pad every sequence to the longest one in the batch.
    max_len = max(len(indices) for indices in processed_texts)
    padded_texts = [indices + [PAD_IDX] * (max_len - len(indices))
                    for indices in processed_texts]
    tensor = torch.LongTensor(padded_texts).to(device)

    with torch.no_grad():
        predictions = model(tensor).squeeze(1)
        probabilities = torch.sigmoid(predictions).cpu().tolist()

    results = []
    for p in probabilities:
        is_positive = p >= 0.5
        results.append({
            'sentiment': 'positive' if is_positive else 'negative',
            # Probability of the predicted label, not P(positive).
            'confidence': float(p if is_positive else 1.0 - p),
        })
    return results
八、总结与展望
8.1 完成的工作
本文实现了一个完整的文本分类系统,包括:
- 数据预处理:文本清洗、分词、向量化
- 模型构建:LSTM和Transformer两种架构
- 训练优化:完整的训练循环和学习率调度
- 模型评估:混淆矩阵和错误分析
- 模型部署:Flask API和性能优化
8.2 改进方向
可以从以下几个方面进一步改进:
- 使用预训练模型:如BERT、RoBERTa等
- 超参数调优:使用贝叶斯优化
- 集成学习:融合多个模型的预测结果
- 领域自适应:针对特定领域进行微调
8.3 关键经验
- 数据质量至关重要:好的预处理可以显著提升模型性能
- 从简单模型开始:先尝试简单的LSTM,再考虑复杂的Transformer
- 监控训练过程:使用TensorBoard等工具可视化训练过程
- 关注泛化能力:避免过拟合,使用正则化和早停策略
文本分类是NLP的基础任务,掌握这个任务对于理解更复杂的NLP应用至关重要。希望本文的实战经验能对你有所帮助!
#PyTorch #文本分类 #NLP #深度学习
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)