[Hands-On] Building a Text Classifier with PyTorch: From Data Preparation to Model Deployment

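Introduction

Text classification is one of the foundational tasks in natural language processing, with wide use in sentiment analysis, spam detection, news categorization, and similar scenarios. This article walks through a complete text classification system in PyTorch, covering data preparation, model construction, training and optimization, and deployment.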

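I. Task Definition and Dataset Selection

1.1 Task Definition

The goal of text classification is to map a piece of text to one of a set of predefined categories. For example:

  • Sentiment analysis: decide whether a text is positive, negative, or neutral
  • Topic classification: assign news articles to categories such as politics, sports, or technology
  • Intent recognition: identify the user's intent (query, complaint, suggestion, and so on)

1.2 Dataset Selection

This article uses the IMDB movie review dataset for a sentiment analysis task:

  • Dataset size: 50,000 movie reviews
  • Training set: 25,000 reviews (12,500 positive, 12,500 negative)
  • Test set: 25,000 reviews (12,500 positive, 12,500 negative)
  • Data format: each review is a piece of English text, labeled 0 (negative) or 1 (positive)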

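II. Data Preprocessing

2.1 Loading the Data

import pandas as pd
from sklearn.model_selection import train_test_split

# Load the dataset
df = pd.read_csv('IMDB_Dataset.csv')

# Inspect the data
print(df.head())
print(f"Dataset size: {len(df)}")
print(f"Class distribution: {df['sentiment'].value_counts()}")

# Split into training and test sets.
# Note: this 80/20 split of the single CSV yields 40,000/10,000 rows,
# not the 25,000/25,000 split described in 1.2.
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)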

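2.2 Text Cleaning

Text data usually needs to be cleaned first:

import re
import string

def clean_text(text):
    """Clean a raw review string."""
    # Convert to lowercase
    text = text.lower()

    # Remove HTML tags, leaving a space so words around <br /> do not fuse
    text = re.sub(r'<.*?>', ' ', text)

    # Remove punctuation
    text = text.translate(str.maketrans('', '', string.punctuation))

    # Remove digits
    text = re.sub(r'\d+', '', text)

    # Collapse extra whitespace
    text = ' '.join(text.split())

    return text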

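# Apply the cleaning function
train_df['cleaned_text'] = train_df['review'].apply(clean_text)
test_df['cleaned_text'] = test_df['review'].apply(clean_text)

# Write the train.csv/test.csv files that section 2.3 reads. Assumption: the CSV's
# 'sentiment' column holds the strings 'positive'/'negative', mapped here to 1/0.
for name, frame in (('train', train_df), ('test', test_df)):
    out = frame[['cleaned_text']].rename(columns={'cleaned_text': 'review'})
    out['sentiment'] = frame['sentiment'].map({'negative': 0, 'positive': 1})
    out.to_csv(f'{name}.csv', index=False)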
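
To see what the cleaning steps do, here is a self-contained sketch that applies the same sequence of operations to a made-up review. The sample string and the name clean_text_demo are hypothetical, and this variant replaces HTML tags with a space (an assumption) so that the words around `<br />` do not fuse together.

```python
import re
import string

def clean_text_demo(text):
    """Lowercase, strip HTML tags, punctuation, and digits, then collapse whitespace."""
    text = text.lower()
    text = re.sub(r'<.*?>', ' ', text)  # tags become a space (assumption)
    text = text.translate(str.maketrans('', '', string.punctuation))
    text = re.sub(r'\d+', '', text)
    return ' '.join(text.split())

sample = "Great movie!<br />I watched it 3 times."  # hypothetical review
cleaned = clean_text_demo(sample)
print(cleaned)
```

Note that removing digits and punctuation also discards signals such as "10/10", so it may be worth checking whether each step actually helps on your data.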

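2.3 Tokenization and Vectorization

# Note: Field, TabularDataset, and BucketIterator belong to torchtext's legacy API.
# They moved to torchtext.legacy in 0.9 and were removed in 0.12, so this
# section needs an older torchtext release.
import torch
from torchtext.data import Field, TabularDataset, BucketIterator

# Define the fields; batch_first=True yields (batch_size, seq_len) tensors,
# which is the shape the models in section 3 expect
TEXT = Field(tokenize='spacy', tokenizer_language='en_core_web_sm', lower=True, batch_first=True)
LABEL = Field(sequential=False, use_vocab=False)

# Define the datasets (columns are matched to fields by position)
fields = [('review', TEXT), ('sentiment', LABEL)]
train_data, test_data = TabularDataset.splits(
    path='.',
    train='train.csv',
    test='test.csv',
    format='csv',
    fields=fields,
    skip_header=True
)

# Build the vocabulary from the training data only, with pretrained GloVe vectors
TEXT.build_vocab(train_data, max_size=10000, vectors='glove.6B.100d')

# Create the iterators
BATCH_SIZE = 64
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_iterator, test_iterator = BucketIterator.splits(
    (train_data, test_data),
    batch_size=BATCH_SIZE,
    device=device,
    sort_key=lambda x: len(x.review),
    sort_within_batch=False
)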
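
Because the legacy torchtext classes hide most of the work, the following library-free sketch shows roughly what vocabulary building, numericalization, and padding amount to. It is an illustration under stated assumptions, not torchtext's implementation: the function names are hypothetical, and the index layout (0 for `<unk>`, 1 for `<pad>`) simply mirrors the usual legacy default.

```python
from collections import Counter

PAD, UNK = '<pad>', '<unk>'

def build_vocab(tokenized_texts, max_size):
    """Map the most frequent tokens to indices; 0 and 1 are reserved for <unk> and <pad>."""
    counts = Counter(tok for toks in tokenized_texts for tok in toks)
    itos = [UNK, PAD] + [tok for tok, _ in counts.most_common(max_size)]
    return {tok: idx for idx, tok in enumerate(itos)}

def numericalize(tokens, stoi):
    """Replace each token with its index, falling back to <unk>."""
    return [stoi.get(tok, stoi[UNK]) for tok in tokens]

def pad_batch(sequences, pad_idx):
    """Right-pad every sequence to the length of the longest one."""
    max_len = max(len(s) for s in sequences)
    return [s + [pad_idx] * (max_len - len(s)) for s in sequences]

# Tiny hypothetical corpus, already tokenized
corpus = [['good', 'movie'], ['bad', 'movie', 'really', 'bad'], ['movie']]
stoi = build_vocab(corpus, max_size=2)
batch = pad_batch([numericalize(t, stoi) for t in corpus], stoi[PAD])
print(stoi)
print(batch)
```

With max_size=2 only the two most frequent tokens keep their own index; everything else maps to `<unk>`, which is the same trade-off max_size=10000 makes at a larger scale.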

2.4 Data Augmentation (Optional)

For text classification, the following augmentation techniques can be used:

import random

def augment_text(text):
    """Augment a text with a random swap and random deletions."""
    words = text.split()

    # Swap one randomly chosen pair of adjacent words
    if len(words) > 1:
        idx = random.randint(0, len(words)-2)
        words[idx], words[idx+1] = words[idx+1], words[idx]

    # Drop each word with 10% probability
    words = [word for word in words if random.random() > 0.1]

    return ' '.join(words)

# Augment the training data
train_df['augmented_text'] = train_df['cleaned_text'].apply(augment_text)

III. Model Architecture

3.1 A Simple LSTM Model

import torch
import torch.nn as nn
import torch.nn.functional as F

class LSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, 
                 bidirectional, dropout, pad_idx):
        super().__init__()
        
        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
        
        # LSTM layer
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=n_layers,
            bidirectional=bidirectional,
            dropout=dropout,
            batch_first=True
        )
        
        # Fully connected layer
        self.fc = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)
        
        # Dropout layer
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, text):
        """
        Args:
            text: shape (batch_size, seq_len)
        
        Returns:
            output: shape (batch_size, output_dim)
        """
        # Embed the tokens
        embedded = self.dropout(self.embedding(text))  # (batch_size, seq_len, embedding_dim)

        # Run the LSTM
        output, (hidden, cell) = self.lstm(embedded)

        # Use the top layer's final hidden state (forward and backward concatenated when bidirectional)
        if self.lstm.bidirectional:
            hidden = self.dropout(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1))
        else:
            hidden = self.dropout(hidden[-1, :, :])
        
        # Fully connected layer
        output = self.fc(hidden)
        
        return output
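
The indexing `hidden[-2]` and `hidden[-1]` in forward relies on how nn.LSTM lays out its final hidden states. The sketch below, with made-up sizes, checks the shapes and confirms that those two slices are the top layer's forward and backward states. It runs on random input and is meant only as a sanity check.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_len, emb_dim, hidden_dim, n_layers = 4, 7, 10, 16, 2

lstm = nn.LSTM(emb_dim, hidden_dim, num_layers=n_layers,
               bidirectional=True, batch_first=True)
x = torch.randn(batch_size, seq_len, emb_dim)
output, (hidden, cell) = lstm(x)

# hidden has shape (num_layers * num_directions, batch, hidden_dim),
# ordered layer by layer: [l0_fwd, l0_bwd, l1_fwd, l1_bwd]
last_fwd, last_bwd = hidden[-2], hidden[-1]
sentence_repr = torch.cat((last_fwd, last_bwd), dim=1)

print(output.shape)         # (batch, seq_len, 2 * hidden_dim)
print(hidden.shape)         # (n_layers * 2, batch, hidden_dim)
print(sentence_repr.shape)  # (batch, 2 * hidden_dim)
```

The forward state summarizes the sequence read left to right and the backward state summarizes it read right to left, which is why the classifier's linear layer takes hidden_dim * 2 inputs.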

3.2 Model Initialization

# Hyperparameters
INPUT_DIM = len(TEXT.vocab)
EMBEDDING_DIM = 100
HIDDEN_DIM = 256
OUTPUT_DIM = 1
N_LAYERS = 2
BIDIRECTIONAL = True
DROPOUT = 0.5
PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]

# Create the model
model = LSTMClassifier(
    INPUT_DIM,
    EMBEDDING_DIM,
    HIDDEN_DIM,
    OUTPUT_DIM,
    N_LAYERS,
    BIDIRECTIONAL,
    DROPOUT,
    PAD_IDX
)

# Load the pretrained word vectors
pretrained_embeddings = TEXT.vocab.vectors
model.embedding.weight.data.copy_(pretrained_embeddings)

# Zero-initialize the vectors for the padding and unknown tokens
UNK_IDX = TEXT.vocab.stoi[TEXT.unk_token]
model.embedding.weight.data[UNK_IDX] = torch.zeros(EMBEDDING_DIM)
model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM)

print(f"Number of model parameters: {sum(p.numel() for p in model.parameters())}")

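3.3 Transformer-Based Model

For more complex tasks, a Transformer architecture can be used. Note that the class below relies on a PositionalEncoding module that this article does not define; it must be supplied separately: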

class TransformerClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_dim, n_heads, n_layers, output_dim, dropout, pad_idx):
        super().__init__()
        
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
        self.pos_encoder = PositionalEncoding(embedding_dim, dropout)
        
        encoder_layers = nn.TransformerEncoderLayer(embedding_dim, n_heads, dim_feedforward=512, dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, n_layers)
        
        self.fc = nn.Linear(embedding_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, text):
        embedded = self.dropout(self.embedding(text))
        embedded = self.pos_encoder(embedded)
        
        # By default the Transformer encoder expects input of shape (seq_len, batch_size, dim)
        embedded = embedded.permute(1, 0, 2)
        
        output = self.transformer_encoder(embedded)
        
        # Use the first token's output as the sentence representation
        cls_output = output[0, :, :]
        
        return self.fc(cls_output)
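
TransformerClassifier references PositionalEncoding without defining it. One possible definition is the standard sinusoidal encoding, sketched below under two assumptions that are not stated in the article: the input is batch-first, matching how forward calls pos_encoder before the permute, and embedding_dim is even. The max_len default is also an arbitrary choice.

```python
import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for batch-first input (batch, seq_len, dim)."""

    def __init__(self, embedding_dim, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, embedding_dim, 2)
                             * (-math.log(10000.0) / embedding_dim))
        pe = torch.zeros(1, max_len, embedding_dim)
        pe[0, :, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[0, :, 1::2] = torch.cos(position * div_term)  # odd dimensions
        self.register_buffer('pe', pe)  # stored with the model, not trained

    def forward(self, x):
        # x: (batch, seq_len, dim); add the encoding for the first seq_len positions
        return self.dropout(x + self.pe[:, :x.size(1)])

# Quick check on an all-zero input, so the output is the encoding itself
pos = PositionalEncoding(embedding_dim=8, dropout=0.0)
out = pos(torch.zeros(2, 5, 8))
print(out.shape)
```

The constructor signature matches the call `PositionalEncoding(embedding_dim, dropout)` in the classifier. Sequences longer than max_len would fail, so that value should cover the longest review you expect to feed in.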

IV. Training and Optimization

4.1 Loss Function and Optimizer

import torch.optim as optim

# Loss function (takes raw logits and applies the sigmoid internally)
criterion = nn.BCEWithLogitsLoss()

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Move the model and the loss function to the device
model = model.to(device)
criterion = criterion.to(device)
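
BCEWithLogitsLoss is used because the model outputs a single raw logit per review. As a rough illustration of what it computes per sample, the pure-Python sketch below evaluates binary cross-entropy directly on a logit in a numerically stable form. It is mathematically equivalent to applying a sigmoid followed by cross-entropy; it is not a claim about PyTorch's internal code, and the function names are hypothetical.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def bce_with_logits(logit, label):
    """Binary cross-entropy on a raw logit, in the stable form
    max(x, 0) - x * y + log(1 + exp(-|x|))."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

def bce_naive(logit, label):
    """Textbook form: -[y * log(p) + (1 - y) * log(1 - p)] with p = sigmoid(x)."""
    p = sigmoid(logit)
    return -(label * math.log(p) + (1 - label) * math.log(1 - p))

print(bce_with_logits(0.0, 1.0))     # an undecided logit costs log(2)
print(bce_with_logits(2.0, 1.0))     # confident and correct: small loss
print(bce_with_logits(2.0, 0.0))     # confident and wrong: large loss
print(bce_with_logits(1000.0, 0.0))  # the stable form does not overflow
```

The naive form would fail on very large logits because sigmoid saturates to exactly 1.0 and log(1 - p) becomes log(0), which is the practical reason to pass logits to the loss rather than probabilities.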

4.2 Training Loop

def train(model, iterator, optimizer, criterion):
    """Train for one epoch and return the mean batch loss and accuracy."""
    epoch_loss = 0
    epoch_acc = 0
    
    model.train()
    
    for batch in iterator:
        optimizer.zero_grad()
        
        text = batch.review
        labels = batch.sentiment.float()
        
        predictions = model(text).squeeze(1)
        
        loss = criterion(predictions, labels)
        
        # Compute the batch accuracy
        rounded_preds = torch.round(torch.sigmoid(predictions))
        correct = (rounded_preds == labels).float()
        acc = correct.sum() / len(correct)
        
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_acc += acc.item()
    
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

4.3 Evaluation Function

def evaluate(model, iterator, criterion):
    """Evaluate without gradient updates and return the mean batch loss and accuracy."""
    epoch_loss = 0
    epoch_acc = 0
    
    model.eval()
    
    with torch.no_grad():
        for batch in iterator:
            text = batch.review
            labels = batch.sentiment.float()
            
            predictions = model(text).squeeze(1)
            
            loss = criterion(predictions, labels)
            
            rounded_preds = torch.round(torch.sigmoid(predictions))
            correct = (rounded_preds == labels).float()
            acc = correct.sum() / len(correct)
            
            epoch_loss += loss.item()
            epoch_acc += acc.item()
    
    return epoch_loss / len(iterator), epoch_acc / len(iterator)
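
Both train and evaluate report the mean of the per-batch accuracies. That equals the true accuracy only when every batch has the same size; a smaller final batch gets the same weight as a full one. The sketch below contrasts the two ways of averaging on made-up numbers; the function names and counts are hypothetical.

```python
def mean_of_batch_accuracies(batches):
    """Average the per-batch accuracies, as the functions above do."""
    return sum(correct / size for correct, size in batches) / len(batches)

def sample_weighted_accuracy(batches):
    """Total correct predictions divided by total samples."""
    return sum(c for c, _ in batches) / sum(s for _, s in batches)

# (correct predictions, batch size); the last batch is much smaller
batches = [(60, 64), (60, 64), (1, 2)]

batch_mean = mean_of_batch_accuracies(batches)
weighted = sample_weighted_accuracy(batches)
print(f"mean of batch accuracies: {batch_mean:.4f}")
print(f"sample-weighted accuracy: {weighted:.4f}")
```

On a dataset of tens of thousands of reviews with batches of 64, the gap is usually tiny, but accumulating correct counts and sample counts is a cheap way to remove it.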

4.4 Training Process

import time

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

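# Training loop
# Note: test_iterator doubles as the validation set here, so the final test
# accuracy is optimistic; a separate validation split would give a cleaner estimate.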
N_EPOCHS = 10

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    start_time = time.time()
    
    train_loss, train_acc = train(model, train_iterator, optimizer, criterion)
    valid_loss, valid_acc = evaluate(model, test_iterator, criterion)
    
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    # Save the best model
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'text_classifier.pt')
    
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')
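
The loop above always runs for all N_EPOCHS and keeps the best checkpoint. A common complement is early stopping, which halts training once the validation loss stops improving. The sketch below assumes a hypothetical EarlyStopping helper (not part of PyTorch) and uses made-up loss values so it runs on its own.

```python
class EarlyStopping:
    """Signal a stop once validation loss has not improved for `patience` epochs."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.bad_epochs = 0

    def step(self, valid_loss):
        """Record this epoch's loss; return True if training should stop."""
        if valid_loss < self.best_loss - self.min_delta:
            self.best_loss = valid_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

# Hypothetical validation losses, one per epoch
losses = [0.60, 0.50, 0.52, 0.51, 0.53, 0.40]
stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, loss in enumerate(losses, start=1):
    if stopper.step(loss):
        stopped_at = epoch
        break

print(f"stopped at epoch {stopped_at}, best loss {stopper.best_loss}")
```

In the real loop, `stopper.step(valid_loss)` would be called right after evaluate, followed by a `break` when it returns True. The example also shows the cost of a small patience: the improvement at epoch 6 is never seen.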

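4.5 Learning Rate Scheduling

# Add a learning rate scheduler that halves the rate when validation loss plateaus
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.5)

# Update the learning rate inside the training loop
for epoch in range(N_EPOCHS):
    # ... training code ...

    scheduler.step(valid_loss)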
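
To make the effect of patience=2 and factor=0.5 concrete, here is a simplified pure-Python simulation of a reduce-on-plateau rule. It is a sketch, not the PyTorch scheduler: it treats any strict decrease as an improvement and ignores the threshold, cooldown, and minimum learning rate options. The loss values are made up.

```python
def simulate_plateau_schedule(losses, lr, patience=2, factor=0.5):
    """Return the learning rate in effect after each epoch's scheduler step."""
    best, bad_epochs, history = float('inf'), 0, []
    for loss in losses:
        if loss < best:
            best, bad_epochs = loss, 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:  # more than `patience` epochs without improvement
            lr *= factor
            bad_epochs = 0
        history.append(lr)
    return history

# Hypothetical validation losses: improvement stalls after the second epoch
history = simulate_plateau_schedule([0.70, 0.60, 0.61, 0.62, 0.63, 0.59], lr=1e-3)
print(history)
```

With patience=2 the rate is cut on the third consecutive epoch without improvement, not the second, which is worth keeping in mind when N_EPOCHS is as small as 10.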

V. Model Evaluation and Analysis

5.1 Test Set Evaluation

# Load the best model
model.load_state_dict(torch.load('text_classifier.pt'))

# Evaluate on the test set
test_loss, test_acc = evaluate(model, test_iterator, criterion)

print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')

5.2 Confusion Matrix

from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

def get_predictions(model, iterator):
    """Collect the model's predicted labels and the true labels."""
    model.eval()
    
    predictions = []
    labels = []
    
    with torch.no_grad():
        for batch in iterator:
            text = batch.review
            label = batch.sentiment
            
            prediction = model(text).squeeze(1)
            prediction = torch.round(torch.sigmoid(prediction))
            
            predictions.extend(prediction.cpu().numpy())
            labels.extend(label.cpu().numpy())
    
    return predictions, labels

# Get the predictions
predictions, labels = get_predictions(model, test_iterator)

# Plot the confusion matrix
cm = confusion_matrix(labels, predictions)

plt.figure(figsize=(10, 7))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=['Negative', 'Positive'],
            yticklabels=['Negative', 'Positive'])
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.show()
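
Beyond the plot, the four cells of a binary confusion matrix give precision, recall, and F1 directly. The sketch below assumes the layout returned by sklearn's confusion_matrix for labels 0 and 1, which is [[TN, FP], [FN, TP]] with rows as actual and columns as predicted. The counts and the function name are hypothetical.

```python
def binary_metrics(cm):
    """Compute metrics from a 2x2 confusion matrix laid out as
    [[TN, FP], [FN, TP]] (rows = actual, columns = predicted)."""
    (tn, fp), (fn, tp) = cm
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1}

cm_example = [[40, 10], [5, 45]]  # hypothetical counts
metrics = binary_metrics(cm_example)
for name, value in metrics.items():
    print(f"{name}: {value:.3f}")
```

The sketch does not guard against empty rows or columns, so a model that never predicts the positive class would raise a ZeroDivisionError here.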

5.3 Error Analysis

def analyze_errors(model, iterator, n_examples=5):
    """Collect up to n_examples misclassified samples."""
    model.eval()
    
    errors = []
    
    with torch.no_grad():
        for batch in iterator:
            text = batch.review
            labels = batch.sentiment
            
            predictions = model(text).squeeze(1)
            predictions = torch.round(torch.sigmoid(predictions))
            
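            # Collect the misclassified samples
            for i in range(len(labels)):
                if predictions[i] != labels[i]:
                    errors.append({
                        # text is (batch_size, seq_len), as the model's forward expects, so row i is sample i
                        'text': ' '.join(TEXT.vocab.itos[idx] for idx in text[i].tolist()),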
                        'predicted': int(predictions[i].item()),
                        'actual': int(labels[i].item())
                    })
                
                if len(errors) >= n_examples:
                    return errors
    
    return errors

# Analyze the misclassified samples
errors = analyze_errors(model, test_iterator)

for error in errors:
    print(f"Text: {error['text'][:100]}...")
    print(f"Predicted: {'Positive' if error['predicted'] == 1 else 'Negative'}")
    print(f"Actual: {'Positive' if error['actual'] == 1 else 'Negative'}")
    print('---')

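VI. Model Deployment

6.1 Exporting the Model

# Save the full model (this pickles the whole object, so the LSTMClassifier class must be importable when loading)
torch.save(model, 'text_classifier_full.pt')

# Alternatively, export to ONNX format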
dummy_input = torch.randint(0, INPUT_DIM, (1, 50)).to(device)
torch.onnx.export(model, dummy_input, 'text_classifier.onnx')

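6.2 Building an Inference API

# Assumes clean_text, TEXT, UNK_IDX, device, and the LSTMClassifier class from
# the earlier sections are importable in this script
import torch
from flask import Flask, request, jsonify
import spacy

app = Flask(__name__)

# Load the model
model = torch.load('text_classifier_full.pt')
model.eval()

# Load the tokenizer
nlp = spacy.load('en_core_web_sm')

def predict_sentiment(text):
    """Predict the sentiment of a single text."""
    # Preprocess the text
    text = clean_text(text)

    # Tokenize
    tokens = [token.text for token in nlp.tokenizer(text)]

    # Convert tokens to indices
    indices = [TEXT.vocab.stoi.get(token, UNK_IDX) for token in tokens]

    # Add the batch dimension
    tensor = torch.LongTensor(indices).unsqueeze(0).to(device)

    # Predict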
    with torch.no_grad():
        prediction = model(tensor).squeeze(1)
        probability = torch.sigmoid(prediction).item()
    
    return {
        'sentiment': 'positive' if probability >= 0.5 else 'negative',
        'confidence': float(probability)
    }

@app.route('/predict', methods=['POST'])
def predict():
    data = request.get_json()
    text = data.get('text', '')
    
    if not text:
        return jsonify({'error': 'No text provided'}), 400
    
    result = predict_sentiment(text)
    return jsonify(result)

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)

6.3 Testing the API

import requests

# Test the API
response = requests.post(
    'http://localhost:5000/predict',
    json={'text': 'This movie is absolutely amazing! I loved every minute of it.'}
)

print(response.json())
# Example output (actual values depend on the trained model): {'sentiment': 'positive', 'confidence': 0.985}

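VII. Performance Optimization

7.1 Model Compression

# Quantize the model (dynamic int8 quantization of the LSTM and Linear layers, intended for CPU inference)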
quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.LSTM, nn.Linear}, dtype=torch.qint8
)

# Save the quantized model
torch.save(quantized_model, 'text_classifier_quantized.pt')

7.2 Inference Acceleration

# Use TorchScript
scripted_model = torch.jit.script(model)
scripted_model.save('text_classifier_scripted.pt')

# Load the scripted model for inference
loaded_model = torch.jit.load('text_classifier_scripted.pt')

7.3 Batch Inference

def batch_predict(texts):
    """Predict the sentiment of several texts in one forward pass."""
    processed_texts = []
    max_len = 0
    
    for text in texts:
        tokens = [token.text for token in nlp.tokenizer(clean_text(text))]
        indices = [TEXT.vocab.stoi.get(token, UNK_IDX) for token in tokens]
        processed_texts.append(indices)
        max_len = max(max_len, len(indices))
    
    # Pad to the same length
    padded_texts = [indices + [PAD_IDX] * (max_len - len(indices)) 
                    for indices in processed_texts]
    
    tensor = torch.LongTensor(padded_texts).to(device)
    
    with torch.no_grad():
        predictions = model(tensor).squeeze(1)
        probabilities = torch.sigmoid(predictions).cpu().numpy()
    
    return [
        {'sentiment': 'positive' if p >= 0.5 else 'negative', 'confidence': float(p)}
        for p in probabilities
    ]
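
batch_predict pads every text to the longest one and runs a single forward pass, so memory grows with both the number of texts and the longest review. For large inputs, one option is to feed it fixed-size chunks. The sketch below shows the chunking logic only; fake_batch_predict is a hypothetical stand-in for batch_predict so the example runs without a trained model.

```python
def chunked(items, batch_size):
    """Yield consecutive slices of at most batch_size items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

def predict_in_chunks(texts, predict_fn, batch_size=64):
    """Apply a batch prediction function chunk by chunk and concatenate the results."""
    results = []
    for chunk in chunked(texts, batch_size):
        results.extend(predict_fn(chunk))
    return results

# Stand-in for batch_predict (hypothetical), returning one dict per text
def fake_batch_predict(texts):
    return [{'text': t, 'n_tokens': len(t.split())} for t in texts]

texts = [f"review number {i}" for i in range(10)]
results = predict_in_chunks(texts, fake_batch_predict, batch_size=4)
sizes = [len(c) for c in chunked(texts, 4)]
print(sizes)
print(len(results))
```

Sorting the texts by length before chunking would further reduce padding within each chunk, at the cost of having to restore the original order afterwards.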

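VIII. Summary and Outlook

8.1 What Was Done

This article implemented a complete text classification system, including:

  1. Data preprocessing: text cleaning, tokenization, vectorization
  2. Model construction: two architectures, LSTM and Transformer
  3. Training and optimization: a complete training loop with learning rate scheduling
  4. Model evaluation: confusion matrix and error analysis
  5. Model deployment: a Flask API and performance optimization

8.2 Directions for Improvement

Further improvements are possible in several areas:

  • Pretrained models: BERT, RoBERTa, and similar
  • Hyperparameter tuning: Bayesian optimization
  • Ensemble learning: combine the predictions of multiple models
  • Domain adaptation: fine-tune for a specific domain

8.3 Key Lessons

  1. Data quality is critical: good preprocessing can noticeably improve model performance
  2. Start with a simple model: try a simple LSTM first, then consider a more complex Transformer
  3. Monitor training: use tools such as TensorBoard to visualize the training process
  4. Watch generalization: avoid overfitting with regularization and early stopping

Text classification is a foundational NLP task, and mastering it is essential for understanding more complex NLP applications. I hope the hands-on experience in this article is useful to you!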

#PyTorch #TextClassification #NLP #DeepLearning
