中文新闻文本分类全栈实战:从FastText到BERT知识蒸馏,20万条数据上的完整技术演进
中文新闻文本分类全栈实战:从FastText到BERT知识蒸馏,20万条数据上的完整技术演进
这是完整源代码哦,免费的
前言
在自然语言处理(NLP)领域,文本分类是最基础也是最核心的任务之一。无论是新闻推荐、舆情分析还是内容审核,都离不开高效的文本分类系统。本文将带你从传统机器学习到深度学习预训练模型,再到模型压缩部署,完整走通一个中文新闻文本分类项目的全流程。(模型蒸馏那一块好像有点问题,感兴趣的兄弟们可以帮我去GitHub上抓bug,我写那个写的有点累了,懒得管了哈哈哈)
- 注:GitHub上的项目是不包含Bert预训练模型的,因为模型有点大,上传不了,可以在我的主页看完整源文件;但本来整个项目我是蒸馏的文件夹和原来的bert文件夹都有一个bert预训练模型,但是太大了,csdn也上传不了,所以只保留了bert文件夹下的预训练模型,只需要复制bert文件夹下的bert_pretrain文件夹到bert_distil文件夹下的data下即可(和data下的data平级)
本文基于我在 GitHub 上开源的项目 text_classification,覆盖了以下技术方案:
- 🔹 TF-IDF + 随机森林(传统机器学习基线)
- 🔹 FastText(浅层神经网络 + 自动超参搜索)
- 🔹 BERT 全量微调(深层预训练模型)
- 🔹 PyTorch 动态量化(模型压缩)
- 🔹 知识蒸馏 BERT → TextCNN(大模型指导小模型)
- 🔹 Flask API 部署(所有方案均可上线推理)
GitHub 仓库地址:https://github.com/Happy-Chen-CH/text_classification
GitHub 用户名:@Happy-Chen-CH
欢迎 Star ⭐ 和 PR 交流!
目录
- 项目背景与数据集介绍
- 数据分析与预处理(EDA)
- 方案一:TF-IDF + 随机森林(基线模型)
- 方案二:FastText(轻量级利器)
- 方案三:BERT 全量微调(高精度方案)
- 方案四:模型量化(PyTorch Dynamic Quantization)
- 方案五:知识蒸馏 BERT → TextCNN
- 模型部署:Flask API 服务
- 模型对比与总结
- 代码获取与学习建议
1. 项目背景与数据集介绍
1.1 业务场景
本项目的业务场景源自今日头条新闻推荐系统中的一个功能子集:将新闻、资讯等短文本进行多分类,然后将各个类别的资讯推送到对应的推荐流中。本质上,这是一个文本分类 → 推荐系统的经典流水线。
1.2 数据集概览(这个数据比较规整,正常的数据还要自己再处理的)
数据集包含 20 万条中文新闻标题,已预标注好类别标签,共分为 10 个类别:
| 编号 | 类别 | 英文 | 训练集 | 验证集 | 测试集 |
|---|---|---|---|---|---|
| 0 | 财经 | finance | 18,000 | 1,000 | 1,000 |
| 1 | 房产 | realty | 18,000 | 1,000 | 1,000 |
| 2 | 股票 | stocks | 18,000 | 1,000 | 1,000 |
| 3 | 教育 | education | 18,000 | 1,000 | 1,000 |
| 4 | 科技 | science | 18,000 | 1,000 | 1,000 |
| 5 | 社会 | society | 18,000 | 1,000 | 1,000 |
| 6 | 政治 | politics | 18,000 | 1,000 | 1,000 |
| 7 | 体育 | sports | 18,000 | 1,000 | 1,000 |
| 8 | 游戏 | game | 18,000 | 1,000 | 1,000 |
| 9 | 娱乐 | entertainment | 18,000 | 1,000 | 1,000 |
- 训练集:180,000 条(90%)
- 验证集:10,000 条(5%)
- 测试集:10,000 条(5%)
数据格式(Tab 分隔):
中华女子学院:本科层次仅1专业招男生 3
两天价网站背后重重迷雾:做个网站究竟要多少钱 4
卡佩罗:告诉你德国脚生猛的原因 不希望英德战踢点球 7
可以看到这是一个类平衡的数据集,每个类别在训练集中各占 18,000 条(10%),不存在类别不均衡问题,这为我们后面的实验提供了一个很好的基础。
2. 数据分析与预处理(EDA)
在进行模型训练之前,我们首先对数据做探索性分析,代码位于 analysis.py。
2.1 数据加载与统计
import pandas as pd
from collections import Counter
import numpy as np
import jieba
# 读取数据
content = pd.read_csv('./data/train.txt', sep='\t')
print(content.head(10))
print(f"总样本数量: {len(content)}")
# 统计每个类别数量
count = Counter(content.label.values)
print(count)
# Counter({3: 18000, 4: 18000, 1: 18000, 7: 18000, 5: 18000,
# 9: 18000, 8: 18000, 2: 18000, 6: 18000, 0: 18000})
2.2 样本长度分析
# 统计每行样本长度
content['sentence_len'] = content['sentence'].apply(len)
length_mean = np.mean(content['sentence_len'])
length_std = np.std(content['sentence_len'])
print(f'文本平均长度: {length_mean:.2f}')
print(f'文本长度标准差: {length_std:.2f}')
# 输出: 文本平均长度: 19.21
# 文本长度标准差: 3.86
- 文本平均长度约 19 个字符,标准差约 4,说明新闻标题普遍较短且长度分布集中。
- 这一发现非常重要:对于 BERT 模型来说,这意味着我们可以将
pad_size设为 32 即可覆盖绝大多数样本,从而大幅减少计算量。
2.3 分词预处理
使用 jieba 对中文文本进行分词:
def cut_sentence(s):
return list(jieba.cut(s))
content['words'] = content['sentence'].apply(
lambda s: ' '.join(jieba.lcut(s))
)
content.to_csv('./data/train_new.csv')
同样的方式处理验证集 dev.txt 和测试集 test.txt。
3. 方案一:TF-IDF + 随机森林(基线模型)
这是整个项目的 baseline,目标是快速建立一个可用的基准模型。
3.1 技术原理
TF-IDF(词频-逆文档频率):
- TF(Term Frequency):某个词在文档中出现的频率
- IDF(Inverse Document Frequency):包含该词的文档数的倒数取对数
- TF-IDF = TF × IDF:衡量一个词对某篇文档的重要程度
核心思想:如果一个词在某篇文档中频繁出现,但在整个语料库中很少出现,则这个词对该文档具有很好的类别区分能力。
3.2 代码实现
代码位于 random_forest.py:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import pandas as pd
from icecream import ic
# 读取数据
content = pd.read_csv('./data/train_new.csv')
corpus = content['words'].values # 分词后的文本
# 读取停用词
stop_words = open('./data/stopwords.txt').read().split()
# TF-IDF 特征提取
tfidf = TfidfVectorizer(stop_words=stop_words)
text_vectors = tfidf.fit_transform(corpus)
# 划分训练测试集
targets = content['label']
X_train, X_test, y_train, y_test = train_test_split(
text_vectors, targets, test_size=0.2, random_state=0
)
# 随机森林分类器
model = RandomForestClassifier(n_jobs=-1)
model.fit(X_train, y_train)
# 评估
accuracy = accuracy_score(y_test, model.predict(X_test))
ic(accuracy) # 输出: 0.8148 (81.48%)
3.3 结果分析
| 指标 | 值 |
|---|---|
| 准确率 (Accuracy) | 81.48% |
对于 10 分类任务来说,81.48% 的准确率是一个不错的基线。随机森林训练快速、不需要 GPU,可解释性强,适合作为快速验证的基准模型。我们有了这个 baseline 1.0,后续的所有优化都有了参照系。
4. 方案二:FastText(轻量级利器)
FastText 是 Facebook 开源的一个高效的文本分类和词向量训练工具。它的核心优势是速度极快,在工业界被广泛使用。
4.1 FastText 数据格式
FastText 要求训练数据的格式为:
__label__<类别名> <分词或分字后的文本>
我们需要将原始数据转换为这种格式,代码位于 preprocess1.py:
import jieba
# 构建 id 到 label 的映射
id_to_label = {}
with open('class.txt', 'r', encoding='utf-8') as f:
for idx, line in enumerate(f.readlines()):
id_to_label[idx] = line.strip()
# 处理训练集
train_data = []
with open('train.txt', 'r', encoding='utf-8') as f:
for line in f.readlines():
sentence, label = line.strip().split('\t')
label_name = id_to_label[int(label)]
new_label = f'__label__{label_name}'
# 使用 jieba 分词
sent_words = ' '.join(jieba.lcut(sentence))
train_data.append(f'{new_label} {sent_words}')
with open('train_fast1.txt', 'w', encoding='utf-8') as f:
for data in train_data:
f.write(data + '\n')
4.2 自动超参搜索训练
FastText 提供了 autotune 功能,可以在验证集上自动搜索最优超参数。代码位于 FastText-Train2.py:
import fasttext
import time
train_path = "./data/train_fast1.txt"
dev_path = "./data/dev_fast.txt"
test_path = "./data/test_fast.txt"
# 训练:autotune 自动搜索最优超参数
model = fasttext.train_supervised(
input=train_path,
autotuneValidationFile=dev_path,
autotuneDuration=300, # 随机搜索 300 秒
wordNgrams=2,
verbose=3
)
# 测试
result = model.test(test_path)
print(result) # (10000, 0.9172, 0.9172)
# 保存模型
model.save_model(f"./fasttext_model_{int(time.time())}.bin")
4.3 超参数搜索空间
autotune 自动调节的超参数包括:
| 参数 | 说明 | 默认值 |
|---|---|---|
| lr | 学习率 | 0.1 |
| dim | 词向量维度 | 100 |
| ws | 上下文窗口大小 | 5 |
| epoch | 训练轮数 | 5 |
| minCount | 最低词频 | 5 |
| wordNgrams | n-gram 设置 | 1 |
4.4 结果分析
| 指标 | 值 |
|---|---|
| 精确率 (Precision) | 91.72% |
| 召回率 (Recall) | 91.72% |
| 相比随机森林提升 | +10.24% |
FastText 相比随机森林有了大幅度的提升,精确率和召回率均达到 91.72%。而且训练速度非常快(几分钟即可完成),模型推理也极快(单条推理约 5ms)。
💡 工业场景启示:FastText 在 GPU 环境下单条预测仅需约 5ms,这就是它在工业界备受欢迎的原因——轻量、高效、够用。
5. 方案三:BERT 全量微调(高精度方案)
BERT(Bidirectional Encoder Representations from Transformers)自 2018 年提出以来,已经成为 NLP 领域的事实标准。本项目使用中文 BERT 预训练模型进行微调。
5.1 预训练模型
使用的是 Google 发布的中文 BERT Base 模型:
- 层数:12 层 Transformer
- 隐藏层维度:768
- 注意力头数:12
- 词表大小:21,128
- 总参数量:约 110M
5.2 模型定义
代码位于 Bert_project/src/models/bert.py:
from transformers import BertModel, BertTokenizer, BertConfig
class Config(object):
def __init__(self):
self.model_name = "bert"
self.pad_size = 32 # 序列长度(根据EDA结果设定)
self.batch_size = 128
self.epoches = 3
self.learning_rate = 5e-5
self.num_classes = 10
self.hidden_size = 768
# 设备自适应:支持 CUDA / MPS (Apple Silicon) / CPU
self.device = torch.device(
"cuda" if torch.cuda.is_available()
else "mps" if torch.backends.mps.is_available()
else "cpu"
)
class Model(nn.Module):
def __init__(self, config):
super().__init__()
self.bert = BertModel.from_pretrained(config.bert_path)
self.fc = nn.Linear(config.hidden_size, config.num_classes)
def forward(self, x):
context = x[0] # 输入的 token ids
mask = x[2] # attention mask
_, pooled = self.bert(context, attention_mask=mask, return_dict=False)
out = self.fc(pooled)
return out
亮点:
- 使用
BertModel.from_pretrained()加载预训练权重 - 简单高效的分类头设计:BERT Encoder → Linear(768, 10)
- 支持 CUDA / MPS (Apple Silicon) / CPU 多设备自动适配
5.3 训练与评估函数
代码位于 Bert_project/src/train_eval.py:
训练函数核心逻辑:
def train(config, model, train_iter, dev_iter):
optimizer = AdamW(optimizer_grouped_parameters, lr=config.learning_rate)
dev_best_loss = float("inf")
model.train()
for epoch in range(config.epoches):
for i, (trains, labels) in enumerate(train_iter):
model.zero_grad()
outputs = model(trains)
loss = nn.CrossEntropyLoss()(outputs, labels)
loss.backward()
optimizer.step()
# 每 100 个 batch 评估一次
if total_batch % 100 == 0:
dev_acc, dev_loss = evaluate(config, model, dev_iter)
if dev_loss < dev_best_loss:
torch.save(model.state_dict(), config.save_path)
improve = "*" # 标记模型有提升
评估函数:计算准确率、分类报告、混淆矩阵。
5.4 实验结果
Test Loss: 0.2, Test Acc: 93.64%
Precision, Recall and F1-Score...
precision recall f1-score support
finance 0.9246 0.9320 0.9283 1000
realty 0.9484 0.9370 0.9427 1000
stocks 0.8787 0.8980 0.8882 1000
education 0.9511 0.9730 0.9619 1000
science 0.9236 0.8950 0.9091 1000
society 0.9430 0.9270 0.9349 1000
politics 0.9267 0.9100 0.9183 1000
sports 0.9780 0.9780 0.9780 1000
game 0.9514 0.9600 0.9557 1000
entertainment 0.9390 0.9540 0.9464 1000
----------------------------------------------
accuracy 0.9364 10000
macro avg 0.9365 0.9364 0.9364 10000
5.5 结果分析
| 指标 | 值 | 相比 FastText 提升 |
|---|---|---|
| 准确率 | 93.64% | +1.92% |
| Macro F1 | 93.64% | +1.91% |
BERT 模型在测试集上达到了 93.64% 的准确率,相比 FastText 的 91.72% 有了显著性提升。值得注意的是:
- 体育类别 F1 最高(97.80%),区分度最好
- 股票类别 F1 最低(88.82%),容易和财经、科技混淆
- 从混淆矩阵可以看出,“股票"常被误判为"财经”(49例)——这在实际场景中也是合理的
6. 方案四:模型量化(PyTorch Dynamic Quantization)
BERT 虽然精度高,但模型体积 390MB,对部署不太友好。PyTorch 提供动态量化功能,可以将模型参数从 float32 压缩到 int8。
6.1 量化原理
简单理解:用更少的比特位(int8,8位)代替较多的比特位(float32,32位),从而缩减模型体积并加速推理。
- 原始模型:float32 精度 → 像素高,看得清晰
- 量化模型:int8 精度 → 像素低,但依然可以准确识别
6.2 代码实现
代码位于 Bert_project/src/run1.py:
import torch
from models.bert import Config, Model
config = Config()
config.device = 'cpu' # ⚠️ 量化必须在 CPU 上执行
model = Model(config)
model.load_state_dict(torch.load(config.save_path, map_location='cpu'))
# 🔥 核心:PyTorch 动态量化
quantized_model = torch.quantization.quantize_dynamic(
model,
{torch.nn.Linear}, # 指定要量化的层类型
dtype=torch.qint8 # 量化目标精度
)
# 测试量化后模型
test(config, quantized_model, test_iter)
# 保存量化模型
torch.save(quantized_model, config.save_path2)
6.3 量化结果
| 对比维度 | 原始 BERT | 量化 BERT | 变化 |
|---|---|---|---|
| 模型大小 | 390MB | 146MB | ↓ 62.6% |
| 测试准确率 | 93.64% | 91.92% | ↓ 1.72% |
结论:模型体积缩减了 256.6MB(62.6%),而准确率仅下降不到 2 个百分点,效果非常优异!这说明 BERT 模型的鲁棒性非常高,即使在 int8 精度下也能保持优秀的分类能力。
7. 方案五:知识蒸馏 BERT → TextCNN
在工业级应用中,除了要求模型效果好之外,还希望其「消耗」足够小。知识蒸馏(Knowledge Distillation)就是解决这一问题的重要方法。
7.1 核心原理
知识蒸馏由 Hinton 于 2015 年提出,核心思想是:用一个复杂的教师模型(Teacher)指导一个简单的学生模型(Student)学习。
蒸馏损失函数:
L K D = α ⋅ T 2 ⋅ K L ( p s T ∥ p t T ) + ( 1 − α ) ⋅ C E ( p s , y ) L_{KD} = \alpha \cdot T^2 \cdot KL(p_s^T \parallel p_t^T) + (1-\alpha) \cdot CE(p_s, y) LKD=α⋅T2⋅KL(psT∥ptT)+(1−α)⋅CE(ps,y)
其中:
- 软目标损失:学生模型输出与教师模型输出的 KL 散度(温度 T T T 平滑后的软标签)
- 硬目标损失:学生模型输出与真实标签的交叉熵
- α = 0.8 \alpha = 0.8 α=0.8:软损失权重
- T = 2 T = 2 T=2:温度参数,分布越平缓保留的相似信息越多
7.2 教师模型(BERT)
使用训练好的 BERT 模型作为教师,参数量 110M,准确率 93.64%。
7.3 学生模型(TextCNN)
代码位于 Bert_distil/src/models/textCNN.py:
class Model(nn.Module):
def __init__(self, config):
super(Model, self).__init__()
# 字符级词嵌入层
self.embedding = nn.Embedding(config.n_vocab, config.embed,
padding_idx=config.n_vocab - 1)
# 多尺度卷积核 (2, 3, 4)
self.convs = nn.ModuleList([
nn.Conv2d(1, config.num_filters, (k, config.embed))
for k in config.filter_sizes
])
self.dropout = nn.Dropout(config.dropout)
self.fc = nn.Linear(
config.num_filters * len(config.filter_sizes),
config.num_classes
)
def conv_and_pool(self, x, conv):
x = F.relu(conv(x)).squeeze(3) # 卷积 + ReLU
x = F.max_pool1d(x, x.size(2)).squeeze(2) # 最大池化
return x
def forward(self, x):
out = self.embedding(x[0])
out = out.unsqueeze(1)
out = torch.cat([self.conv_and_pool(out, conv) for conv in self.convs], 1)
out = self.dropout(out)
out = self.fc(out)
return out
TextCNN 采用多尺度卷积核 (2, 3, 4),分别捕捉二词、三词、四词的局部特征,参数量仅约 2M。
7.4 蒸馏训练
代码位于 Bert_distil/src/train_eval.py:
# 步骤1:获取教师模型对所有训练样本的软标签
def fetch_teacher_outputs(teacher_model, train_iter):
teacher_model.eval()
teacher_outputs = []
with torch.no_grad():
for data_batch, _ in train_iter:
outputs = teacher_model(data_batch)
teacher_outputs.append(outputs)
return teacher_outputs
# 步骤2:定义蒸馏损失函数
def loss_fn_kd(outputs, labels, teacher_outputs, alpha=0.8, T=2):
# 学生网络的 log_softmax(带温度)
output_student = F.log_softmax(outputs / T, dim=1)
# 教师网络的 softmax(带温度)
output_teacher = F.softmax(teacher_outputs / T, dim=1)
# 软目标损失(KL散度)
soft_loss = nn.KLDivLoss()(output_student, output_teacher)
# 硬目标损失(交叉熵)
hard_loss = F.cross_entropy(outputs, labels)
# 加权总损失(soft_loss 乘以 T² 补偿梯度)
return soft_loss * alpha * T * T + hard_loss * (1.0 - alpha)
# 步骤3:蒸馏训练
def train_kd(cnn_config, bert_model, cnn_model,
bert_train_iter, cnn_train_iter, cnn_dev_iter, cnn_test_iter):
bert_model.eval()
teacher_outputs = fetch_teacher_outputs(bert_model, bert_train_iter) # 离线预计算软标签
cnn_model.train()
for epoch in range(cnn_config.num_epochs):
for i, (trains, labels) in enumerate(cnn_train_iter):
outputs = cnn_model(trains)
loss = loss_fn_kd(outputs, labels, teacher_outputs[i])
loss.backward()
optimizer.step()
⚠️ 重要技巧:引入温度 T 后,软目标产生的梯度会缩小为原来的 1 / T 2 1/T^2 1/T2,因此需要在软损失项乘以 T 2 T^2 T2 来补偿梯度。
7.5 蒸馏结果对比
| 维度 | BERT (Teacher) | TextCNN 基础 | TextCNN 调参 | 结论 |
|---|---|---|---|---|
| 参数量 | 110M | ~2M | ~5M | 显著减少 |
| 模型大小 | 409.2MB | 11.3MB | 23.1MB | 缩小 17.7 倍 |
| 测试准确率 | 93.64% | 89.89% | 91.25% | 仅下降 2.39% |
| 卷积核 | — | (2,3,4) | (2,3,4,5) | 增加感受野 |
| 卷积核数 | — | 256 | 1024 | 增加通道 |
| 训练轮数 | 2 | 3 | 30 | 充分收敛 |
7.6 关键发现
- 模型压缩效果显著:23.1MB 的 TextCNN 模型大小仅为原始 BERT 的 5.65%,缩小了 17.7 倍
- 精度损失可控:经过调参后准确率仅从 93.64% 下降到 91.25%,下降 2.39 个百分点
- 部署友好:小模型更适合移动端、边缘设备等资源受限场景
- 调参有收益:增加卷积核尺寸到 (2,3,4,5)、卷积核数到 1024、训练轮数到 30,准确率从 89.89% 提升到 91.25%
8. 模型部署:Flask API 服务(服务端和客户端记得分两个端口运行,我一开始就是运行到一个端口了,抓了1小时bug,最后才发现是这个问题)
一个好的模型不仅要训得好,更要部署得上。本项目三个方案均提供了 Flask RESTful API 推理服务。
8.1 FastText 服务端
from flask import Flask, request
import jieba
import fasttext
app = Flask(__name__)
model = fasttext.load_model('./model/fasttext_model.bin')
@app.route('/v1/main_server/', methods=["POST"])
def main_server():
uid = request.form['uid']
text = request.form['text']
input_text = ' '.join(jieba.lcut(text))
res = model.predict(input_text)
return res[0][0]
if __name__ == '__main__':
app.run(host='0.0.0.0', port=5000)
8.2 BERT 服务端
from flask import Flask, request
import torch
from models.bert import Config, Model
app = Flask(__name__)
config = Config()
model = Model(config).to(config.device)
model.load_state_dict(torch.load(config.save_path, map_location='cpu'))
@app.route('/predict', methods=["POST"])
def predict():
text = request.json['text']
# 分词 + 填充 + 推理
result = inference(model, config, text)
return {'category': result}
if __name__ == '__main__':
app.run(host='0.0.0.0', port=5000)
8.3 客户端调用
import requests
url = "http://127.0.0.1:5000/v1/main_server/"
data = {"uid": "test-001", "text": "雷佳音获飞天奖"}
resp = requests.post(url, data=data)
print(resp.text) # __label__entertainment
8.4 推理性能
| 方案 | 单条推理耗时 |
|---|---|
| FastText | ~5ms |
| BERT (量化) | ~120ms |
| BERT (原始) | ~180ms |
| TextCNN (蒸馏) | ~10ms |
9. 模型对比与总结
9.1 全方案对比
| 方法 | 准确率 | 模型大小 | 推理速度 | 参数量 | 适用场景 |
|---|---|---|---|---|---|
| 🎯 随机森林 | 81.48% | — | ⚡ 极快 | — | 基线模型/快速验证 |
| 🚀 FastText | 91.72% | ~370MB | ⚡ 快 (~5ms) | — | 实时在线服务 |
| 🧠 BERT | 93.64% | 390MB | 🐢 慢 (~180ms) | 110M | 高精度离线分析 |
| 📦 BERT (量化) | 91.92% | 146MB | 🚀 较快 | 110M | 精度与速度折中 |
| 💡 TextCNN (蒸馏) | 91.25% | 23.1MB | ⚡ 快 (~10ms) | ~5M | 移动端/边缘部署 |
9.2 技术演进路线
随机森林 (81.48%) ← 传统 ML 基线,快速可解释
↓
FastText (91.72%) ← 浅层神经网络,速度快,提升 +10%
↓
BERT 微调 (93.64%) ← 深层预训练,高精度,SOTA 水平
↓
BERT 量化 (91.92%) ← 模型压缩 62.6%,精度下降 < 2%
↓
知识蒸馏 (91.25%) ← 模型缩小 17.7 倍,速度提升 18 倍
9.3 技术栈一览
| 层级 | 使用技术 |
|---|---|
| 数据层 | Pandas、NumPy、jieba 分词 |
| 特征层 | TF-IDF、停用词过滤、字符级/词级分词 |
| 模型层 | RandomForest、FastText、BERT、TextCNN |
| 优化层 | AdamW、autotune、动态量化、知识蒸馏 |
| 部署层 | Flask RESTful API、PyTorch model serving |
| 设备层 | CUDA / MPS (Apple Silicon) / CPU 自适应 |
10. 代码获取与学习建议
🔗 获取方式
📦 GitHub 仓库:https://github.com/Happy-Chen-CH/text_classification
👤 GitHub 主页:https://github.com/Happy-Chen-CH
如果觉得有帮助,请给个 Star ⭐ 鼓励一下~
📁 项目结构
text_classification/
├── randomforest_and_fasttext/ # 传统 ML 方案
│ ├── analysis.py # 数据 EDA 分析
│ ├── random_forest.py # TF-IDF + 随机森林
│ ├── FastText-Train2.py # FastText + autotune 训练
│ ├── app.py # Flask 推理服务
│ └── test.py # API 测试客户端
│
├── Bert_project/ # BERT 深度方案
│ ├── src/models/bert.py # BERT 模型定义 + 配置
│ ├── src/train_eval.py # 训练/评估/测试函数
│ ├── src/run.py # 标准训练入口
│ ├── src/run1.py # 训练 + 动态量化
│ └── src/app.py # Flask 推理服务
│
└── Bert_distil/ # 知识蒸馏方案
├── src/models/
│ ├── bert.py # BERT 教师模型
│ └── textCNN.py # TextCNN 学生模型
├── src/train_eval.py # 蒸馏损失 + 训练函数
└── src/run.py # 蒸馏训练入口
📖 学习建议
建议按以下顺序循序渐进地学习:
- 第一步:从
randomforest_and_fasttext/开始,理解 TF-IDF + 随机森林的基线思路 - 第二步:学习 FastText 的数据预处理和 autotune 自动调参
- 第三步:进入
Bert_project/,理解 BERT 微调、模型量化 - 第四步:挑战
Bert_distil/,掌握知识蒸馏的核心原理和实现 - 第五步:学习各方案的 Flask 部署代码,打通训练→部署的全链路
⚠️ 注意事项
- 路径配置:项目中部分代码包含硬编码路径,使用前请根据实际路径修改
- 模型文件:预训练 BERT 权重 (~393MB) 体积较大,需自行下载或通过 Git LFS 管理
- 设备兼容:BERT 项目已针对 Apple Silicon (MPS) 做适配
- 蒸馏方案:该部分代码可能还有一些小问题,欢迎 PR 一起完善!
写在最后
本文完整展示了一个中文文本分类项目从数据探索 → 基线建立 → 深度学习微调 → 模型压缩 → 部署上线的全流程。希望这篇文章能够帮助正在学习 NLP 的朋友们建立起一个完整的知识框架。(最后的bert_distil那个文件夹里的代码好像有点问题,欢迎大家来修改,另外两个文件夹应该都是没什么问题的)
如果你有任何问题或建议,欢迎在 GitHub 上提 Issue 或 PR,我们一起讨论进步!
如果觉得本文对你有帮助,欢迎点赞 👍、收藏 ⭐、转发 🔄,也欢迎关注我的 CSDN 博客获取更多技术分享!
本文首发于 CSDN,作者:Happy-Chen-CH,转载请注明出处。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐
所有评论(0)