【深度学习基础篇11】从CT影像到医学报告:用BART实现医学文本生成的工程实践
文章目录
项目概览
项目背景与意义
在现代医疗体系中,医学影像(如CT、MRI)是临床诊断的重要依据。医生需要根据影像表现,撰写一份结构化的“印象/结论”报告,这份报告不仅是后续诊疗的核心依据,也直接影响患者的治疗方案。
然而,这一过程存在显著痛点:
- 效率低下:一位影像科医生每天需要处理大量影像,手动撰写报告耗时耗力,且容易因疲劳导致疏漏。
- 标准化不足:不同医生的表述习惯和专业侧重不同,导致报告风格和质量参差不齐,不利于后续的信息提取与分析。
- 资源分配不均:优质的影像科医生资源高度集中在大型医院,基层医疗机构的诊断能力亟待提升。
本项目正是为了解决这一痛点:利用深度学习技术,特别是序列到序列(Seq2Seq)的生成式模型,自动根据CT影像的结构化描述,生成符合规范的医学报告“印象”部分。
这不仅能大幅提升医生的工作效率,还能通过AI辅助,为基层医疗提供标准化、高质量的诊断支持,是AI赋能医疗的典型应用场景。
任务本质:序列到序列的文本生成
从机器学习的视角看,这个项目的本质非常清晰:
- 输入(Source):一段由数字序列表示的CT影像结构化描述。例如:
14 108 28 30 15 13 294 ...。这些数字并非随机,而是对应着医学术语的编码(如ICD编码或自定义词表),代表了“肝脏大小正常”、“胆囊无扩张”等具体发现。 - 输出(Target):一段自然语言的医学结论(Impression)。例如:“剖宫产术后所见,右下腹大片状及类圆形混杂密度影,出血?请密切结合临床及实验室检查…”。
这是一个典型的序列到序列(Sequence-to-Sequence, Seq2Seq)文本生成任务。我们的目标,就是训练一个模型,学习从“编码化的医学发现”到“自然语言医学结论”的映射关系。
核心技术选型:BART模型
在众多Seq2Seq模型中,我们选择了 BART (Bidirectional and Auto-Regressive Transformers) 作为核心架构,原因如下:
- Encoder-Decoder 架构:BART 采用了标准的 Encoder-Decoder 结构。Encoder 负责理解和编码输入的医学发现序列,捕捉其中的复杂关系;Decoder 则根据 Encoder 的输出,自回归地生成流畅、准确的医学结论。这与我们“理解影像→生成报告”的任务逻辑完美匹配。
- 预训练优势:BART 是在大规模文本语料上预训练得到的,具备强大的语言理解和生成能力。通过在中文医学数据上进行微调(Fine-tuning),它能快速适应医学文本的专业术语和行文风格。
- 灵活性与成熟度:[[Hugging Face Transformers]] 库对 BART 提供了极佳的支持,使得我们可以快速地进行模型加载、训练和部署,大大降低了工程实现的难度。
在本项目中,我们使用了 mybart-base-chinese 这一预训练好的中文 BART 模型作为起点。
项目目录结构解析
首先,先看看我们的项目目录,可以划分为几个核心模块:
bart/
├── data/ # 数据目录,存放所有训练、验证、测试数据
│ ├── train.csv # 训练集:输入序列 -> 输出序列
│ ├── test.csv # 测试集:只有输入序列,用于生成预测
│ └── ...
├── model_utils/ # 核心工具模块,封装了模型、数据、配置等
│ ├── models.py # 模型定义与加载
│ ├── data.py # 数据加载与预处理
│ ├── config.py # 全局配置
│ └── ...
├── mybart-base-chinese/ # 预训练的中文BART模型权重和配置
├── save/ # 训练好的模型和日志保存目录
├── finetune.py # 核心训练脚本
├── inference.py # 推理/预测脚本
└── ...
【看懂数据】解码医学文本生成的数据源:CT 影像编码与报告的映射关系
在进入模型训练和代码解析之前,我们首先要搞懂数据到底是什么——这是所有AI项目的根基。本项目的核心是“CT影像编码序列 → 医学报告结论”的映射,而 train.csv 和 test.csv 就是承载这一映射关系的核心数据源。
数据格式直观分析
基础格式对比
先看两份文件的核心差异:
| 文件 | 格式 | 核心区别 |
|---|---|---|
train.csv |
行号,输入序列,输出序列 |
有输入+输出,用于模型训练 |
test.csv |
行号,输入序列 |
只有输入,用于模型预测 |
数据样例拆解(以train.csv第0行为例)
0,14 108 28 30 15 13 294 29 20 18 23 21 25 32 16 14 39 27 14 47 46 69 70 11 24 42 26 37 61 24 10 79 46 62 19 13 31 95 19 28 20 18 10 22 12 38 41 17 23 21 36 53 25 10,22 12 38 41 17 81 10
- 行号:
0→ 样本唯一标识 - 输入序列:
14 108 28 30 ... 36 53 25 10→ CT影像的结构化编码(核心输入) - 输出序列:
22 12 38 41 17 81 10→ 医学报告结论的编码(模型要生成的目标)
关键观察结论
- 编码化存储:无论输入还是输出,都不是自然语言,而是数字序列——这是医学文本生成的典型处理方式(将专业术语映射为数字,降低模型学习难度)。
- 序列长度差异:输入序列通常很长(几十到上百个数字),输出序列较短(几个到几十个数字),符合“详细影像描述 → 精简结论”的医疗逻辑。
- 数字复用性:部分数字(如
22 12 38 41 17)同时出现在输入和输出中,说明这些编码对应“核心医学发现”,是从影像到结论的关键映射。
数据背后的医学逻辑
这些看似随机的数字,本质是医学术语/影像特征的“词典编码”,我们可以用通俗的逻辑拆解:
数字编码的本质
| 层级 | 示例 | 对应关系 |
|---|---|---|
| 原始医学文本 | “肝脏大小形态正常,未见明显异常密度影” | 医生手写的原始报告 |
| 术语分词 | 肝脏/大小形态/正常/未见/异常密度影 | 对原始文本的专业分词 |
| 数字编码 | 22/12/38/41/17 | 每个术语映射为唯一数字 |
输入序列的构成逻辑
输入序列的长数字串,对应CT影像的结构化描述维度,例如:
14 108 28 30→ 可能对应“扫描部位:腹部,层厚:5mm,窗宽:350,窗位:50”15 13 294 29→ 可能对应“肝脏:形态规则,实质密度均匀,无占位性病变”20 18 23 21→ 可能对应“胆囊:大小正常,壁无增厚,腔内无结石”
输出序列的构成逻辑
输出序列的短数字串,对应医学报告“印象”部分的核心结论,例如:
22 12 38 41 17 81 10→ 解码后可能是“腹部CT平扫:未见明显异常”35 48 49 150 167 308 282 10→ 解码后可能是“右肺下叶小结节,建议随访”
注:具体数字对应的医学术语,需要结合项目的
vocab.txt(词表文件)才能精准解码——这是我们后续解析model_utils/data.py时的核心内容。
数据质量与分布特征
从样例数据中,我们还能发现几个关键特征:
- 完整性:训练集每个样本都有“输入-输出”成对数据,且无空值、无乱码,数据清洗程度高。
- 多样性:输入序列的长度和数字组合差异大,覆盖了不同部位、不同病变类型的CT影像,保证了模型的泛化能力。
- 终止符特征:几乎所有序列都以
10结尾——这是典型的“EOS(End of Sequence)终止符”,告诉模型“序列到此结束”。 - 高频数字:
10(终止符)、11、12、14、22、34等数字高频出现,对应医学报告中最常用的基础术语(如“CT平扫”、“未见异常”、“建议随访”)。
数据预处理的核心目标(提前铺垫)
理解了数据格式后,我们就能明确后续数据预处理的核心任务:
- 构建词表(Vocab):建立“数字 ↔ 医学术语”的映射字典。
- 序列标准化:将不同长度的序列统一为固定长度(或使用padding),适配模型输入。
- 数据加载:将CSV文件转换为模型可处理的Tensor格式。
- 数据划分:(若需要)将train.csv进一步划分为训练集和验证集,用于监控模型过拟合。
数据篇核心总结
- 本项目的数据源是编码化的CT影像描述 + 编码化的医学结论,而非原始文本,这是医疗AI项目的典型特征(兼顾专业性和模型效率)。
train.csv是“输入-输出”成对的监督数据,是模型学习的核心;test.csv只有输入,用于验证模型的生成能力。- 数字序列的长度差异、高频数字、终止符
10等特征,是后续数据预处理和模型设计的关键依据。
【数据预处理】从 CSV 到模型输入:医学序列数据的加载与标准化
理解了数据源的格式和含义后,接下来的核心任务是将原始CSV文件转换为模型能直接处理的格式。本项目中 process_data.py 和 model_utils/data.py 这两个文件,正是完成这一核心转换的关键——前者负责数据划分,后者负责数据加载和标准化。
process_data.py:训练集与验证集的划分
这是整个数据预处理的第一步,核心目标是将原始的 train.csv 划分为训练集和验证集,用于模型的训练和效果监控。
核心代码解析
import pandas as pd #处理表格数据
pre_train_file= "data/train.csv"
# 1. 读取原始训练数据
train_df = pd.read_csv(pre_train_file,header=None,names=["id","input","tgt"]) #读入数据
print(train_df.head())
# 2. 随机采样90%的数据作为训练集
train_data = train_df.sample(frac=0.9, random_state=0, axis=0) #采样0.9的比例
# 3. 剩余10%的数据作为验证集(~表示取反,筛选不在训练集索引中的数据)
val_data = train_df[~train_df.index.isin(train_data.index)] #筛选训练集之外的数据
# 4. 保存划分后的数据
train_data.to_csv("data/pro_train_data.csv", index=False,header=False)
val_data.to_csv("data/pro_val_data.csv", index=False,header=False)
关键知识点拆解
- 数据读取:
header=None:原始CSV文件没有列名,需要手动指定names=["id","input","tgt"]:为三列数据命名(行号、输入序列、输出序列)
- 数据划分逻辑:
sample(frac=0.9, random_state=0):随机抽取90%的数据作为训练集,random_state=0保证划分结果可复现~train_df.index.isin(train_data.index):这是Pandas的经典用法——筛选出索引不在训练集中的行,作为验证集
- 保存规则:
index=False:不保存Pandas的行索引(避免额外列干扰模型读取)header=False:不保存列名(保持和原始数据一致的格式)
为什么要划分验证集?
- 监控模型训练过程中的过拟合问题:如果训练集准确率持续上升,但验证集准确率下降,说明模型过拟合
- 调整模型参数:通过验证集的效果,优化学习率、批次大小等超参数
- 评估模型的泛化能力:验证集是未参与训练的数据,能更真实反映模型效果
data.py:数据加载与序列标准化
这是数据预处理的核心文件,基于PyTorch的 Dataset 类自定义了 TranslationDataset,实现了从“数字序列”到“模型输入Tensor”的完整转换。
核心类:TranslationDataset
from torch.utils.data import Dataset, DataLoader
import numpy as np
import csv
from transformers import AutoTokenizer
class TranslationDataset(Dataset):
def __init__(self, data_file, args):
# 1. 读取CSV文件并加载样本([:16]是调试用,实际应去掉,使用全部数据)
with open(data_file, 'r') as fp:
reader = csv.reader(fp)
self.samples = [row for row in reader][:16]
# 2. 从参数中获取关键配置
self.input_l = args.input_l #输入序列最大长度
self.output_l = args.output_l #输出序列最大长度
self.sos_id = args.sos_id #开始标记token id
self.pad_id = args.pad_id #填充token id
self.eos_id = args.eos_id #结束标记token id
self.tgt_pad_id = args.tgt_pad_id #输出序列填充token id
# 3. 加载预训练模型的Tokenizer
self.tk=AutoTokenizer.from_pretrained(args.pre_model_path)
def __len__(self):
# 返回样本总数
return len(self.samples)
def __getitem__(self, idx):
# 核心:处理单个样本,转换为模型可接受的格式
# 步骤1:处理输入序列
# - 开头加SOS(开始标记),结尾加EOS(结束标记)
# - convert_tokens_to_ids:将数字字符串转换为对应的id(这里数字本身就是id)
source =[self.sos_id]+ self.tk.convert_tokens_to_ids([x for x in self.samples[idx][1].split()]) + [self.eos_id]
# 步骤2:输入序列padding(填充到指定长度)
if len(source)<self.input_l:
source.extend([self.pad_id] * (self.input_l-len(source)))
# 步骤3:处理测试集(只有输入,无输出)
if len(self.samples[idx])<3:
return np.array(source)[:self.input_l]
# 步骤4:处理输出序列(和输入序列逻辑一致)
target = [self.sos_id] + self.tk.convert_tokens_to_ids([x for x in self.samples[idx][2].split()]) + [self.eos_id]
if len(target)<self.output_l:
target.extend([self.tgt_pad_id] * (self.output_l-len(target)))
# 步骤5:返回处理后的输入和输出(截断到最大长度)
return np.array(source)[:self.input_l], np.array(target)[:self.output_l]
关键逻辑拆解
- 序列标准化的核心步骤:
| 处理阶段 | 操作 | 目的 |
|---|---|---|
| 序列补全 | 开头加SOS,结尾加EOS |
告诉模型序列的开始和结束位置 |
| 长度统一 | 短序列用PAD填充,长序列截断 |
适配模型固定长度的输入要求 |
| 格式转换 | 列表→Numpy数组 | 便于后续转换为PyTorch Tensor |
-
Tokenizer的特殊作用:
- 这里的
AutoTokenizer并非传统意义上的“分词”,而是利用convert_tokens_to_ids方法,将数字字符串(如"22")转换为整数id(22) - 因为数据本身已经是编码后的数字序列,无需额外分词,这是医学文本生成的特殊处理方式
- 这里的
-
训练集/验证集 vs 测试集的区别处理:
- 训练集/验证集:返回
(source, target)(输入+输出) - 测试集:只有输入序列,返回
source即可
- 训练集/验证集:返回
create_dataloaders函数:构建数据加载器
def create_dataloaders(args, test=False):
if not test:
# 训练/验证模式:加载划分后的训练集和验证集
train_data_path = args.data_path+"/pro_train_data.csv"
val_data_path = args.data_path + "/pro_val_data.csv"
train_data = TranslationDataset(train_data_path, args)
valid_data = TranslationDataset(val_data_path, args)
# 构建DataLoader(自动批量加载数据)
train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, drop_last=False)
valid_loader = DataLoader(valid_data, batch_size=args.val_batch_size, shuffle=True, num_workers=args.num_workers, drop_last=False)
return train_loader, valid_loader
else:
# 测试模式:只加载测试集
test_data_path = args.data_path + "/preliminary_a_test.csv"
test_data = TranslationDataset(test_data_path, args)
test_loader = DataLoader(test_data, batch_size=args.test_batch_size, shuffle=False, num_workers=args.num_workers, drop_last=False)
return test_loader
核心参数说明
batch_size:每次加载的样本数(训练时通常设为16/32/64)shuffle=True:训练集打乱顺序(避免模型学习到序列顺序的偏差),测试集shuffle=False(保证结果可复现)num_workers:多线程加载数据(加速数据读取)drop_last=False:保留最后一个不足批次大小的样本(避免数据浪费)
数据预处理完整流程总结
数据预处理篇核心总结
process_data.py的核心是数据划分:将原始训练集按9:1拆分为训练集和验证集,保证模型训练的可监控性。data.py的核心是序列标准化:通过添加SOS/EOS标记、Padding填充,将长度不一的数字序列转换为模型可处理的固定长度格式。- 整个预处理流程遵循“数据划分→加载→标准化→批量加载”的经典范式,是PyTorch处理序列数据的标准做法。
【数据预处理–扩展】适配医学场景的BART词表扩展与重构
在之前的解析中,我们提到“数字编码对应医学术语”,但未深入讲解词表的构建过程。pro_vocab.py 是项目中定制化词表的核心脚本,负责从原始数据中提取医学编码、扩展BART模型词表、适配医学文本生成场景
词表的核心作用与项目背景
词表(vocab.txt)的本质
词表是“token(标记)→ id(数字)”的映射字典,是模型理解文本的基础:
- 通用BART模型的词表:包含通用中文词汇(如“的”、“是”、“我”),但缺乏医学专业术语/编码
- 本项目的特殊需求:模型需要处理医学编码数字(如“22”、“38”),而非自然语言词汇
项目中vocab.txt的特征
- 开头是大量
[unused*]:通用BART模型预留的未使用token,用于扩展自定义词汇 - 中间/末尾是生僻字符/医学编码:项目中实际使用的CT影像编码数字对应的占位符
pro_vocab.py:词表定制化全流程解析
该脚本的核心目标是:从训练数据中提取所有医学编码,扩展BART模型的词表,并更新模型的embedding层以适配新词表。
完整代码逐段解析
import sys
import torch
from collections import Counter
from transformers import BertTokenizer
from transformers import BartConfig
from transformers import BartForConditionalGeneration
from model_utils.config import parse_args
# 1. 加载配置参数
args = parse_args() # 读取config.py中的参数配置
# 2. 加载原始数据,提取所有token
def load_data(path):
with open(path, 'r', encoding='utf-8') as f:
lines = f.readlines()
datas = []
for line in lines:
line = line.strip().split(",")
if len(line) == 3:
# 训练集:合并输入序列和输出序列的token
text, target = line[1].split(" "), line[2].split(" ")
datas.append(text + target)
else:
# 测试集:仅提取输入序列的token
text = line[1].split(" ")
datas.append(text)
return datas
# 加载train.csv中的所有token
train_data = load_data('./data/train.csv')
# 3. 统计token出现频率(Counter是高效的计数工具)
token2count = Counter() # 初始化计数字典
for i in train_data:
token2count.update(i) # 统计每个token(医学编码)的出现次数
# 4. 筛选token,构建自定义词表
tail = []
ct = 0 # 频率阈值:设为0表示保留所有token(无过滤)
for k, v in token2count.items():
if v >= ct:
tail.append(k)
tail.sort() # 按编码数字排序,保证词表顺序稳定
vocab = tail
# 5. 插入特殊token(适配BART模型的要求)
vocab.insert(0,"[PAD]") # 填充token,id=0
vocab.insert(100,"[UNK]") # 未知token,id=100
vocab.insert(101,"[CLS]") # 分类token,id=101(对应config中的sos_id=101)
vocab.insert(102,"[SEP]") # 分隔token,id=102(decoder起始token)
vocab.insert(103,"[MASK]") # 掩码token,id=103
vocab.insert(104,"[EOS]") # 结束token,id=104(对应config中的eos_id=105,此处需注意配置一致性)
# 6. 扩展模型词表(注释部分是通用BART词表扩展逻辑,本项目直接使用自定义词表)
# tokenizer = BertTokenizer.from_pretrained(args.pre_model_path)
# vocabs = tokenizer.get_vocab() #获取模型原始词表
# new_vocabs = list(vocabs.keys())
# count = 0
# for v in vocab: #遍历自定义token,补充到原始词表
# if v not in vocabs:
# count += 1
# new_vocabs.append(v)
new_vocabs = vocab # 本项目直接使用自定义词表(医学编码)
# 7. 保存新的vocab.txt到模型目录
with open(args.pre_model_path+'/vocab.txt', 'w', encoding='utf-8') as f:
for v in new_vocabs:
f.write(f"{v}\n") # 按行保存token,每行对应一个id
# 8. 更新BART模型的embedding层,适配新词表
model = BartForConditionalGeneration.from_pretrained(args.pre_model_path) # 加载原始BART模型
model.resize_token_embeddings(len(new_vocabs)) # 调整token embedding层的大小(核心!)
# 9. 保存更新后的模型权重和配置
state_dict = model.state_dict() # 获取模型参数
torch.save(state_dict, args.pre_model_path+'/pytorch_model.bin') # 保存权重
# 更新配置文件中的词表大小
bartconfig = BartConfig.from_pretrained(args.pre_model_path)
bartconfig.vocab_size = len(new_vocabs) # 同步更新vocab_size参数
bartconfig.save_pretrained(args.pre_model_path) # 保存新配置
关键步骤拆解
数据加载与token提取
load_data函数:解析CSV文件,提取所有输入/输出序列的医学编码token(如“14”、“108”、“22”)。- 合并输入+输出序列:保证词表包含所有模型需要处理的token(输入的影像编码+输出的报告编码)。
token统计与筛选
Counter:高效统计每个token的出现次数,是处理文本数据的常用工具。- 频率阈值
ct=0:保留所有token(医学编码无低频噪声,无需过滤)。 tail.sort():对编码数字排序,保证词表顺序固定,避免每次运行生成不同的id映射。
特殊token插入
插入BART模型必需的特殊token,且指定固定id,和config.py中的参数对应:
| 特殊token | 插入位置 | 对应id | config.py参数 | 作用 |
|---|---|---|---|---|
| [PAD] | 0 | 0 | pad_id=0 | 序列填充 |
| [CLS] | 101 | 101 | sos_id=101 | 序列开始 |
| [EOS] | 104 | 104 | eos_id=105 | 序列结束(注:此处id需和config对齐) |
| [SEP] | 102 | 102 | decoder_start_token_id=102 | 解码器起始标记 |
模型词表适配(核心!)
model.resize_token_embeddings(len(new_vocabs)):
这是扩展词表的关键操作!原始BART模型的embedding层维度是“原始vocab_size × hidden_size”,扩展词表后必须调整embedding层的大小,否则模型无法识别新词表的token。- 同步更新配置文件:
bartconfig.vocab_size = len(new_vocabs),保证模型加载时使用新的词表大小。
三种处理词表未见过词的方法
在医学文本生成任务中,通用BART模型的词表必然无法覆盖所有医学编码/术语,针对这些“未见过的词”,有三种典型处理策略:
方法1:数字直接当id
核心逻辑
将医学编码数字(如“22”、“108”)直接作为模型的token id使用,无需映射到通用词表。
实现方式
- 无需修改原始BART词表,在数据预处理阶段:
# 伪代码:数字字符串→id def str2id(seq_str): return [int(x) for x in seq_str.split()] - 模型输入时直接使用这些数字作为token id,跳过Tokenizer的
convert_tokens_to_ids步骤。
优缺点
| 优点 | 缺点 |
|---|---|
| 实现最简单,无需扩展词表 | 1. id可能超出模型原始vocab_size,导致embedding层索引越界 2. 数字id无语义关联(如“22”和“23”无语义相似性) 3. 无法利用预训练模型的embedding权重 |
方法2:直接加字(扩展通用BART词表)
核心逻辑
保留通用BART词表,将医学编码作为“新token”添加到词表末尾,复用预训练模型的大部分embedding权重。
实现方式
对应pro_vocab.py中注释掉的代码:
# 加载通用BART的原始词表
tokenizer = BertTokenizer.from_pretrained(args.pre_model_path)
vocabs = tokenizer.get_vocab() # 获取模型原始词表
new_vocabs = list(vocabs.keys())
# 遍历自定义医学编码,补充到原始词表
count = 0
for v in vocab:
if v not in vocabs:
count += 1
new_vocabs.append(v)
# 保存扩展后的词表,更新模型embedding
with open(args.pre_model_path+'/vocab.txt', 'w', encoding='utf-8') as f:
for v in new_vocabs:
f.write(f"{v}\n")
model.resize_token_embeddings(len(new_vocabs))
优缺点
| 优点 | 缺点 |
|---|---|
| 1. 保留通用词表的预训练权重 2. 医学编码作为独立token,id可控 |
1. 词表冗余(包含大量无用的通用词汇) 2. 医学编码的embedding是随机初始化的,无预训练语义 3. 词表体积大,增加模型显存占用 |
方法3:重新制作词表(本项目最终选型)
核心逻辑
抛弃通用BART词表,完全基于项目数据构建“医学编码专用词表”,仅保留必要的特殊token([PAD]、[CLS]、[EOS]等)。
实现方式
对应pro_vocab.py中的核心逻辑:
# 直接使用从数据中提取的医学编码作为新词表
new_vocabs = vocab
# 仅插入必要的特殊token
vocab.insert(0,"[PAD]")
vocab.insert(100,"[UNK]")
vocab.insert(101,"[CLS]")
# ... 其他特殊token
# 覆盖原始词表,更新模型
with open(args.pre_model_path+'/vocab.txt', 'w', encoding='utf-8') as f:
for v in new_vocabs:
f.write(f"{v}\n")
model.resize_token_embeddings(len(new_vocabs))
bartconfig.vocab_size = len(new_vocabs)
选型理由(核心!)
本项目最终选择该方法,核心原因是:
- 场景适配性:项目输入/输出均为“数字编码序列”,无通用自然语言词汇,保留通用词表无意义;
- 效率优化:专用词表体积小(仅包含医学编码+必要特殊token),减少embedding层参数数量,降低显存占用;
- 可控性强:特殊token id可手动指定,完全匹配
config.py中的参数(如sos_id=101、eos_id=105),避免id冲突; - 训练效率:模型只需学习医学编码的映射关系,无需适配通用词汇,训练速度更快、收敛更稳定。
优缺点
| 优点 | 缺点 |
|---|---|
| 1. 词表轻量化,适配医学编码场景 2. 特殊token id完全可控 3. 训练效率高、显存占用低 |
1. 完全抛弃通用预训练embedding权重,需要从头微调 2. 词表仅适用于当前项目,无通用性 |
词表定制化的核心价值与注意事项
核心价值
- 适配医学场景:将通用BART模型改造为能处理“医学编码数字”的专用模型。
- 保证模型兼容性:通过
resize_token_embeddings和配置更新,确保模型能正常加载和运行。 - 词表可控性:从数据中自动提取token,避免手动维护词表的繁琐和错误。
关键注意事项
- 特殊token id对齐:
pro_vocab.py中插入的特殊token id(如[CLS]=101)必须和config.py中的参数(sos_id=101)完全一致,否则模型会出现输入/输出错位。 - 模型配置同步:
更新词表后必须同步修改BartConfig中的vocab_size,否则加载模型时会报“embedding维度不匹配”错误。 - 词表保存路径:
新的vocab.txt必须保存到模型目录(args.pre_model_path),确保AutoTokenizer能正确加载。
词表定制化核心总结
- 处理词表未见过词的三种方法各有优劣,选型需结合场景:通用文本生成选“直接加字”,专用编码生成选“重新制作词表”,快速验证选“数字直接当id”。
- 本项目选择“重新制作词表”的核心原因是:输入/输出均为医学编码,通用词表无价值,专用词表更高效、可控。
resize_token_embeddings和vocab_size配置更新是扩展/重构词表的两个关键操作,缺一不可。- 特殊token id必须和
config.py严格对齐,这是模型正常运行的核心前提。
【数据预处理–进阶】MLM任务适配与数据标准化
predata.py 是项目中预训练阶段的数据处理核心脚本,主要实现两大功能:① 原始CSV数据的加载与标准化;② 掩码语言模型(MLM)任务的数据构建(适配preModel类的预训练需求)。
predata.py核心功能定位
该脚本是为preModel(掩码语言模型)服务的,核心解决两个问题:
- 从CSV文件中加载训练/验证/测试数据,统一序列长度(padding/truncation);
- 实现MLM任务的核心逻辑(随机掩码、标签构建、批量拼接),适配预训练需求。
完整代码逐段解析
基础导入与全局配置
import json
import glob
import cv2
from pathlib import Path
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from transformers import BertTokenizer
import numpy as np
from PIL import Image
import random
import time
import csv
import traceback
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from transformers import AutoTokenizer
# 设备自动检测
device = 'cuda' if torch.cuda.is_available() else 'cpu'
- 核心依赖:
torch.utils.data.Dataset(自定义数据集基类)、AutoTokenizer(词表映射)、csv(数据读取); - 全局设备配置:保证数据处理后能直接适配模型运行设备。
原始数据加载:loadData
def loadData(path):
# 定义数据文件路径
train_data_path = path+"/pro_train_data.csv"
val_data_path = path + "/pro_val_data.csv"
test_data_path = path + "/preliminary_a_test.csv"
path_list = [train_data_path, val_data_path,test_data_path]
all_data = []
# 遍历所有数据文件
for index,path in enumerate(path_list):
with open(path,"r") as f:
csv_data = csv.reader(f)
for i in csv_data:
if len(i)==0:# 过滤空行
break
if len(i)==3:# 训练集/验证集:包含id+input+target
id, input, target=i
input=input.split(' ') # 拆分编码序列为列表
target=target.split(' ')
else:# 测试集:仅id+input,target设为-1
id, input,target=i[0], i[1], -1
input=input.split(' ')
# 收集数据:训练集保留input+target,验证/测试集仅保留input
if index == 0:
all_data.append(input)
all_data.append(target)
else:
all_data.append(input) #验证和测试
return all_data
核心逻辑:
- 统一读取训练/验证/测试集的CSV文件,拆分编码序列为列表格式;
- 区分数据集类型:训练集保留输入+输出序列(用于MLM预训练),验证/测试集仅保留输入序列;
- 过滤空行:避免数据格式错误导致后续处理异常。
基础数据集类:PreTrainDataset
class PreTrainDataset(Dataset):
def __init__(self, data_file, input_l, output_l, sos_id=1, eos_id=2, pad_id=0):
with open(data_file, 'r') as fp:
reader = csv.reader(fp)
self.samples = [row for row in reader]
self.input_l = input_l # 输入序列最大长度
self.output_l = output_l # 输出序列最大长度
self.sos_id = sos_id # 开始token id
self.pad_id = pad_id # 填充token id
self.eos_id = eos_id # 结束token id
def __len__(self):
return len(self.samples)
def _try_getitem(self, idx):
# 处理输入序列:转换为数字+padding到固定长度
source = [int(x) for x in self.samples[idx][1].split()]
if len(source)<self.input_l:
source.extend([self.pad_id] * (self.input_l-len(source)))
# 测试集:仅返回输入序列
if len(self.samples[idx])<3:
return np.array(source)[:self.input_l]
# 训练/验证集:处理输出序列(添加SOS/EOS+padding)
target = [self.sos_id] + [int(x) for x in self.samples[idx][2].split()] + [self.eos_id]
if len(target)<self.output_l:
target.extend([self.pad_id] * (self.output_l-len(target)))
# 截断到最大长度,返回numpy数组
return np.array(source)[:self.input_l], np.array(target)[:self.output_l]
核心作用:
- 实现PyTorch Dataset的基础接口(
__init__/__len__/_try_getitem); - 标准化序列长度:通过padding将输入/输出序列统一到
input_l/output_l; - 序列装饰:为输出序列添加SOS(开头)和EOS(结尾)标记,符合Seq2Seq模型的输入要求。
通用工具函数:paddingList & truncate
def paddingList(ls:list,val,returnTensor=False):
"""
批量padding函数:将列表中的每个子列表padding到相同长度
:param ls: 二维列表(批量序列)
:param val: padding值(通常是pad_id)
:param returnTensor: 是否返回Tensor
:return: 标准化后的序列
"""
ls=ls[:]# 避免修改原列表
maxLen=max([len(i) for i in ls])
for i in range(len(ls)):
ls[i]=ls[i]+[val]*(maxLen-len(ls[i]))
return torch.tensor(ls,device=device) if returnTensor else ls
def truncate(a:list,b:list,maxLen):
"""
序列截断函数:将两个序列的总长度控制在maxLen内(预留3个位置给特殊token)
:param a: 第一个序列
:param b: 第二个序列
:param maxLen: 总最大长度
:return: 截断后的a和b
"""
maxLen-=3# 预留CLS/SEP/SEP的位置
assert maxLen>=0
len2=maxLen//2# 均分长度(奇数时左边更长)
len1=maxLen-len2
# 四种截断场景的处理
if len(a)+len(b)>maxLen:
if len(a)<=len1 and len(b)>len2:
b=b[:maxLen-len(a)]
elif len(a)>len1 and len(b)<=len2:
a=a[:maxLen-len(b)]
elif len(a)>len1 and len(b)>len2:
a=a[:len1]
b=b[:len2]
return a,b
核心价值:
paddingList:批量处理序列padding,是DataLoader中collate_fn的基础;truncate:避免序列过长导致显存溢出,适配MLM任务的序列长度限制。
MLM任务核心数据集:MLM_Data
class MLM_Data(Dataset):
"""适配掩码语言模型(MLM)的数据集类"""
def __init__(self, data, args):
super().__init__()
self.data=data # 原始token序列列表
self.maxLen= args.input_l-3 # 预留CLS/SEP/SEP的位置
self.tk=AutoTokenizer.from_pretrained(args.pre_model_path) # 加载自定义词表
self.spNum=len(self.tk.all_special_tokens) # 特殊token数量
self.tkNum=self.tk.vocab_size # 词表总大小
def __len__(self):
return len(self.data)
def random_mask(self, text_ids):
"""
MLM核心逻辑:随机掩码token,生成输入和标签
:param text_ids: 原始token id序列
:return: input_ids(掩码后的序列)、output_ids(掩码位置的真实标签)
"""
input_ids, output_ids = [], []
rands = np.random.random(len(text_ids)) # 生成随机数(0-1)
idx=0
# 控制掩码粒度:支持1/2/3-gram掩码
while idx<len(rands):
if rands[idx]<0.15:# 15%的概率掩码当前token
# 随机选择n-gram掩码(1-gram占70%,2-gram20%,3-gram10%)
ngram=np.random.choice([1,2,3], p=[0.7,0.2,0.1])
# 短序列限制gram大小,避免过度掩码
if ngram==3 and len(rands)<7:
ngram=2
if ngram==2 and len(rands)<4:
ngram=1
# 扩展掩码范围到n-gram
L=idx+1
R=idx+ngram
while L<R and L<len(rands):
rands[L]=np.random.random()*0.15# 强制掩码n-gram内的token
L+=1
idx=R
if idx<len(rands):
rands[idx]=1# 禁止掩码片段的下一个token被掩码
idx+=1
# 生成掩码后的输入和标签
for r, i in zip(rands, text_ids):
if r < 0.15 * 0.8:
# 80%概率:替换为MASK token,标签为原token
input_ids.append(self.tk.mask_token_id)
output_ids.append(i)
elif r < 0.15 * 0.9:
# 10%概率:保留原token,标签为原token(自预测)
input_ids.append(i)
output_ids.append(i)
elif r < 0.15:
# 10%概率:替换为随机token,标签为原token
input_ids.append(np.random.randint(self.spNum,self.tkNum))
output_ids.append(i)
else:
# 85%概率:不掩码,标签设为-100(CrossEntropyLoss会忽略)
input_ids.append(i)
output_ids.append(-100)
return input_ids, output_ids
def __getitem__(self, item):
"""单样本处理:token转换→掩码→添加特殊token"""
text= self.data[item]
text_ids = self.tk.convert_tokens_to_ids(text) # token→id
text_ids, out_ids = self.random_mask(text_ids) # 随机掩码
# 添加CLS(开头)和SEP(结尾)特殊token
input_ids = [self.tk.cls_token_id] + text_ids + [self.tk.sep_token_id]
token_type_ids=[ 0 ]*(len(text_ids)+2) # 单句,token_type_ids全0
labels = [-100] + out_ids + [-100] # 标签:忽略CLS/SEP位置
assert len(input_ids)==len(token_type_ids)==len(labels)
return {'input_ids':input_ids,'token_type_ids':token_type_ids,'labels':labels}
@classmethod
def collate(cls,batch):
"""批量拼接函数:padding到相同长度,生成attention_mask"""
input_ids=[i['input_ids'] for i in batch]
token_type_ids=[i['token_type_ids'] for i in batch]
labels=[i['labels'] for i in batch]
# 批量padding
input_ids=paddingList(input_ids,0,returnTensor=True)
token_type_ids=paddingList(token_type_ids,0,returnTensor=True)
labels=paddingList(labels,-100,returnTensor=True)
# 生成attention_mask(0为padding,1为有效token)
attention_mask=(input_ids!=0)
return {'input_ids':input_ids,'token_type_ids':token_type_ids
,'attention_mask':attention_mask,'labels':labels}
核心解析:
random_mask函数(MLM核心)
遵循BERT的MLM掩码规则,保证预训练效果:
| 掩码规则 | 概率 | 输入处理 | 标签处理 |
|---|---|---|---|
| 替换为MASK | 80% | input_ids = MASK_ID | output_ids = 原token id |
| 保留原token | 10% | input_ids = 原token id | output_ids = 原token id |
| 替换为随机token | 10% | input_ids = 随机id | output_ids = 原token id |
| 不掩码 | 85% | input_ids = 原token id | output_ids = -100(忽略) |
__getitem__函数
- 完成“token→id→掩码→添加特殊token”的完整流程;
- 为序列添加CLS(开头)和SEP(结尾),符合Transformer模型的输入格式;
- 标签处理:CLS/SEP位置设为-100,避免计算损失。
collate函数
- 批量padding:将不同长度的序列统一到批次最大长度;
- 生成attention_mask:告诉模型哪些位置是有效token,哪些是padding;
- 输出格式适配
preModel的forward函数输入(input_ids/attention_mask/labels)。
核心关联说明
在解析preModel类时,需关联predata.py的MLM_Data类:
preModel的forward函数接收
input_ids/attention_mask/labels格式的输入,这些输入正是由predata.py中的MLM_Data类生成的——MLM_Data实现了MLM任务的核心数据处理逻辑,是preModel预训练的基础。
核心总结
predata.py是预训练阶段的核心数据处理脚本,主要为preModel(MLM模型)提供标准化的输入数据;- 核心亮点是
random_mask函数,严格遵循BERT的MLM掩码规则,保证预训练效果; - 实现了完整的“原始数据→标准化序列→MLM掩码→批量拼接”流程,适配Transformer模型的输入要求;
- 工具函数(paddingList/truncate)是处理变长序列的通用方案,可复用于其他NLP任务。
【模型设计】基于 BART 的医学文本生成模型:从加载到定制化
数据预处理完成后,核心任务就是构建模型。本项目中 model_utils/models.py 封装了两个核心模型类:myModel(核心生成模型)和 preModel(预训练辅助模型),其中 myModel 是实现“CT编码→医学报告”生成的核心。
模型整体设计思路
本项目选择 BARTForConditionalGeneration 作为基础模型,原因是:
- BART是经典的Encoder-Decoder架构,天然适配Seq2Seq生成任务
BartForConditionalGeneration内置了生成逻辑(generate方法),无需手动实现自回归生成- 支持beam search等高级生成策略,能生成更流畅、准确的医学文本
myModel类:核心生成模型解析
这是整个项目的核心模型类,封装了BART模型的加载、训练和推理逻辑。
完整代码解析
from transformers import AutoTokenizer, BartForConditionalGeneration, AutoModelForMaskedLM
import torch.nn as nn
import torch
class myModel(nn.Module):
def __init__(self, args):
super(myModel, self).__init__()
# 1. 加载预训练的BART生成模型
self.model = BartForConditionalGeneration.from_pretrained(args.pre_model_path)
# 2. 加载对应的Tokenizer(用于后续解码)
self.tokenizer = AutoTokenizer.from_pretrained(args.pre_model_path)
# 3. 初始化关键参数(从外部配置传入)
self.pad_id = args.pad_id # 输入序列填充id
self.tgt_pad_id = args.tgt_pad_id # 输出序列填充id
self.max_l = args.output_l # 生成序列最大长度
self.beam = args.beam # beam search的beam数
self.length_penalty = args.length_penalty # 长度惩罚系数
self.no_repeat = args.no_repeat # 禁止重复n-gram的长度
self.device = args.device # 运行设备(cpu/gpu)
def build_bart_inputs(self, input, tgt=None):
"""生成注意力掩码(mask):告诉模型哪些位置是有效数据,哪些是padding"""
# 输入mask:True表示有效token,False表示padding
input_mask = (input != self.pad_id)
if tgt == None:
return input_mask, None
else:
# 输出mask:同理处理目标序列
tgt_mask = (tgt != self.tgt_pad_id)
return input_mask, tgt_mask
def forward(self, inputs, tgts=None):
"""
核心前向传播函数:区分训练和推理两种模式
- 训练模式:tgts不为None,返回logits用于计算损失
- 推理模式:tgts为None,调用generate方法生成文本
"""
# 1. 构建注意力掩码
input_mask, tgt_mask = self.build_bart_inputs(inputs, tgts)
# 2. 推理模式(生成医学报告)
if tgts == None:
return self.model.generate(
inputs,
max_length=self.max_l, # 生成序列最大长度
attention_mask=input_mask, # 输入注意力掩码
min_length=2, # 生成序列最小长度(避免空输出)
num_beams=self.beam, # beam search数量
length_penalty=self.length_penalty, # 长度惩罚(避免过短/过长)
no_repeat_ngram_size=self.no_repeat, # 禁止重复n-gram(避免废话)
decoder_start_token_id=102 # decoder起始token(BART默认的SOS)
# early_stopping=True, # 可选:遇到EOS提前停止
)
# 3. 训练模式(计算损失)
outputs = self.model(
input_ids=inputs, # 输入序列
attention_mask=input_mask, # 输入mask
decoder_input_ids=tgts, # 解码器输入(目标序列)
decoder_attention_mask=tgt_mask # 解码器mask
)
return outputs.logits # 返回预测的logits,用于计算交叉熵损失
关键模块拆解
1. 模型初始化(init)
- 核心选择:
BartForConditionalGeneration是专门用于条件生成任务的BART变体,内置了:- Encoder:编码输入的CT影像编码序列
- Decoder:自回归生成医学报告编码序列
- 生成头:将Decoder的输出转换为词汇表概率分布
- 参数初始化:所有关键参数从外部配置(args)传入,保证模型的灵活性。
2. 注意力掩码构建(build_bart_inputs)
- 核心作用:Transformer模型需要知道哪些位置是真实数据(
pad_id以外的数字),哪些是填充的pad_id - 实现逻辑:
input != self.pad_id返回布尔数组,True表示有效token,False表示padding - 为什么需要mask:避免模型将padding部分纳入计算,导致训练偏差
3. 前向传播(forward):训练/推理双模式
这是模型最核心的部分,区分了两种运行模式:
| 模式 | 触发条件 | 核心操作 | 返回值 | 应用场景 |
|---|---|---|---|---|
| 训练 | tgts≠None | 调用model(),并行计算 | logits(预测概率) | 模型训练,计算损失 |
| 推理 | tgts=None | 调用model.generate(),串行生成 | 生成的序列id | 模型预测,生成报告 |
推理模式关键参数说明(beam search)
医学文本生成对准确性要求极高,因此使用beam search而非随机采样:
num_beams=args.beam:保留top-k个最优候选序列,平衡效果和效率(通常设为4/8)length_penalty:惩罚过长序列,公式: s c o r e = s c o r e ( l e n g t h ) α score = \frac{score}{(length)^\alpha} score=(length)αscore,α>1鼓励长序列,α<1鼓励短序列no_repeat_ngram_size:禁止重复n个连续token(如设为2,避免“异常异常”这类重复)decoder_start_token_id=102:BART默认的解码器起始token(<s>)
训练模式关键说明
- 训练时采用并行计算(Teacher Forcing):直接将完整的目标序列输入解码器,一次性计算所有位置的预测值
- 返回
logits:形状为[batch_size, seq_len, vocab_size],后续通过交叉熵损失计算预测误差
preModel类:预训练辅助模型
class preModel(nn.Module):
def __init__(self, args):
super(preModel, self).__init__()
# 加载掩码语言模型(用于预训练/特征提取)
self.model = AutoModelForMaskedLM.from_pretrained(args.pre_model_path)
print(self.model)
def forward(self, inputs, tgts=None):
# 适配Masked LM的输入格式
input_ids, attention_mask, labels = inputs["input_ids"], inputs["attention_mask"], inputs["labels"]
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
return outputs.loss
作用解析
AutoModelForMaskedLM:掩码语言模型(如BERT/BART的MLM模式),用于:- 对医学文本进行预训练(如果有未标注的医学语料)
- 提取文本特征,辅助主模型训练
- 本项目中该类未实际使用(主模型直接用
BartForConditionalGeneration),属于预留的预训练模块
模型工作流程可视化
模型篇核心总结
myModel是核心:基于BartForConditionalGeneration封装,实现了训练(并行计算)和推理(beam search生成)双模式。- 注意力掩码是关键:通过
build_bart_inputs生成mask,避免padding干扰模型计算。 - 生成策略适配医疗场景:使用beam search+长度惩罚+禁止重复n-gram,保证生成报告的准确性和流畅性。
preModel是辅助:基于掩码语言模型,用于医学文本的预训练,提升模型的医学领域适配能力。
【训练】BART 模型的微调实战:从训练循环到效果评估
模型和数据都准备好后,核心就是模型训练(微调) ——让预训练的BART模型学习“CT影像编码→医学报告”的映射关系。finetune.py 是整个项目的训练入口,包含了完整的训练循环、验证逻辑和模型保存策略。
训练脚本整体架构
核心代码逐段解析
环境配置与工具导入
import logging
import os
import time
import torch
from transformers import PretrainedBartModel
from model_utils.config import parse_args # 配置解析
from model_utils.data import create_dataloaders # 数据加载
from model_utils.models import myModel # 核心模型
from model_utils.score import CiderD, CE # 评估指标/损失函数
from model_utils.utils import setup_device, setup_seed, setup_logging, build_optimizer,array2str # 工具函数
from torch.cuda.amp import autocast as ac # 混合精度训练(本脚本未启用)
from tqdm import tqdm as tqdm # 进度条
# 指定GPU设备
os.environ['CUDA_VISIBLE_DEVICES']='0'
关键说明:
os.environ['CUDA_VISIBLE_DEVICES']='0':指定使用第0块GPU(单机单卡训练)- 导入的工具函数各司其职:
setup_device:设置训练设备(CPU/GPU)setup_seed:设置随机种子(保证结果可复现)setup_logging:配置日志输出build_optimizer:构建优化器和学习率调度器
验证函数:validate
这是训练过程中评估模型效果的核心函数,使用CIDEr-D指标(文本生成任务的经典评估指标)评估生成报告的质量。
def validate(model, loader, args, output_file=None, beam=1, n=-1):
res, gts = [], {} # res:模型生成结果;gts:真实标签
tot = 0 # 样本计数
# 遍历验证集
for (source, targets) in tqdm(loader):
if n>0 and tot>n: # 可选:只验证前n个样本(调试用)
break
# 1. 数据移到GPU
source = source.cuda()
# 2. 模型推理:生成报告编码序列(tgts=None,触发generate模式)
pred = model(source[:, :args.input_l])
# 3. 转换为numpy数组(便于后续处理)
pred = pred.cpu().detach().numpy()
# 4. 整理生成结果和真实标签
for i in range(pred.shape[0]):
# array2str:将数字序列转换为字符串(适配CIDEr-D计算)
res.append({'image_id':tot, 'caption': [array2str(pred[i], args)]})
gts[tot] = [array2str(targets[i][1:], args)] # 去掉target的SOS标记
tot += 1
# 5. 计算CIDEr-D分数(越高表示生成结果越接近真实标签)
CiderD_scorer = CiderD(df='corpus', sigma=15)
cider_score, cider_scores = CiderD_scorer.compute_score(gts, res)
return cider_score
核心知识点:
- CIDEr-D指标:全称Consensus-Based Image Description Evaluation,原本用于图像描述评估,适配文本生成任务的核心逻辑是“计算生成文本与参考文本的相似度”,分数越高效果越好。
- array2str:工具函数,将数字编码序列转换为空格分隔的字符串(如
[22,12,38]→"22 12 38"),是计算CIDEr-D的前置处理。 - targets[:, 1:]:去掉目标序列的第一个token(SOS标记),只保留有效内容。
训练主函数:train_and_validate
这是整个训练流程的核心,包含“加载数据→初始化模型→训练循环→验证→保存模型”全流程。
def train_and_validate(args):
# 1. 加载训练/验证数据加载器
train_dataloader, val_dataloader = create_dataloaders(args)
# 2. 初始化模型
model = myModel(args)
# 3. 加载预训练权重(可选)
use_pre = True
if use_pre:
print('use_pre')
checkpoint = torch.load(args.my_pre_model_path, map_location='cpu')
# 加载权重(strict=True:严格匹配参数名和形状)
new_KEY = model.load_state_dict(checkpoint['model_state_dict'],strict=True)
# 4. 构建优化器和学习率调度器
optimizer, scheduler = build_optimizer(args, model)
# 5. 模型移到指定设备(GPU)
model = model.to(args.device)
# 6. 模型设为训练模式
model.train()
# 7. 初始化最优分数(用于保存最优模型)
step = 0
best_score = args.best_score # 初始最优CIDEr-D分数
# 8. 开始训练循环
for epoch in range(args.max_epochs):
# 遍历训练集批次
for (source, targets) in tqdm(train_dataloader):
# 数据移到GPU
source = source.cuda()
targets = targets.cuda()
# 确保模型在训练模式
model.train()
# 前向传播:训练模式(tgts≠None,返回logits)
pred = model(source[:, :args.input_l], targets[:, :args.output_l])
# 计算损失:交叉熵损失(CE)
# pred[:, :-1]:去掉最后一个预测值
# targets[:, 1:]:去掉第一个SOS标记(对齐预测和标签)
loss = CE(pred[:, :-1], targets[:, 1:])
loss = loss.mean() # 批次损失均值
# 反向传播
loss.backward()
# 更新参数
optimizer.step()
# 清空梯度
model.zero_grad()
# 更新学习率
scheduler.step()
# 步数计数
step += 1
# 每1个epoch验证一次
if epoch % 1 == 0:
# 模型设为评估模式
model.eval()
# 禁止梯度计算(加速验证)
with torch.no_grad():
cider_score = validate(model, val_dataloader, args)
# 记录日志
logging.info(f"Epoch {epoch} step {step}: loss {loss:.3f}, cider_score {cider_score}")
# 保存最优模型
if cider_score >= best_score:
best_score = cider_score
torch.save(
{'epoch': epoch, 'model_state_dict': model.state_dict()},
f'{args.savedmodel_path}/model_epoch_{epoch}_cider_score_{cider_score}.bin'
)
# 恢复训练模式
model.train()
核心逻辑拆解:
训练模式的前向传播
pred = model(source, targets):触发myModel的训练模式,返回logits(形状:[batch_size, seq_len, vocab_size])- 损失计算的关键对齐:
这样保证“预测第i个token”对应“标签第i个token”,符合自回归生成的逻辑。pred: [SOS, token1, token2, token3, EOS] → 取[:, :-1] → [SOS, token1, token2, token3] targets: [SOS, token1, token2, token3, EOS] → 取[:, 1:] → [token1, token2, token3, EOS]
损失函数:CE(交叉熵)
- 文本生成任务的核心损失函数,计算预测概率分布与真实标签的差距。
loss.mean():对批次内所有样本的损失取均值,作为当前批次的损失。
模型保存策略
- 只保存CIDEr-D分数最高的模型(避免保存效果差的模型,节省存储空间)。
- 保存内容:
epoch(训练轮数) +model_state_dict(模型参数),便于后续恢复训练/推理。
主函数:main
脚本的入口函数,负责初始化环境和启动训练。
def main():
# 1. 解析命令行/配置文件参数
args = parse_args()
# 2. 初始化日志、设备、随机种子
setup_logging()
setup_device(args)
setup_seed(args)
# 3. 创建模型保存目录(不存在则创建)
os.makedirs(args.savedmodel_path, exist_ok=True)
# 4. 打印训练参数(便于日志排查)
logging.info("Training/evaluation parameters: %s", args)
# 5. 启动训练
train_and_validate(args)
if __name__ == '__main__':
main()
训练关键细节补充
训练模式 vs 验证模式
| 阶段 | 模型模式 | 梯度计算 | 核心操作 |
|---|---|---|---|
| 训练 | model.train() | 开启(loss.backward()) | 反向传播更新参数 |
| 验证 | model.eval() | 关闭(torch.no_grad()) | 推理生成,计算评估指标 |
学习率调度器(scheduler)
scheduler.step():每步更新学习率(常见策略:线性衰减、余弦退火等)- 作用:训练后期降低学习率,让模型收敛更稳定。
混合精度训练(本脚本未启用)
- 代码中导入了
torch.cuda.amp.autocast,但未实际使用。 - 启用方式:在
pred = model(...)外层加with ac():,可减少显存占用,加速训练。
训练篇核心总结
- 训练流程遵循“初始化→训练循环→验证→保存最优模型”的经典范式,核心是每轮训练后用CIDEr-D评估效果。
- 损失计算的关键是预测序列和标签序列的对齐(去掉pred最后一位,去掉target第一位)。
- 验证阶段必须关闭梯度计算(
torch.no_grad())和设置model.eval(),避免影响模型参数和加速推理。 - 模型保存策略:只保存CIDEr-D分数最高的模型,是文本生成任务的通用最佳实践。
【推理】从训练好的模型到最终报告:医学文本生成的落地实践
训练完成后,最终的目标是用训练好的模型解决实际问题——对新的CT影像编码数据,生成对应的医学报告结论。inference.py 是整个项目的推理入口,实现了“加载模型→处理测试数据→生成报告→保存结果”的完整流程。
推理脚本核心逻辑
完整代码逐段解析
导入依赖与工具函数
from tqdm import tqdm # 进度条(显示推理进度)
import csv # 保存结果到CSV文件
from model_utils.utils import to_device, array2str # 设备迁移/序列转文本
from model_utils.models import myModel # 核心生成模型
from model_utils.data import create_dataloaders # 测试数据加载
import torch # PyTorch核心
from model_utils.config import parse_args # 配置解析
关键工具函数说明:
to_device:将数据迁移到指定设备(GPU/CPU),简化设备管理array2str:将模型生成的数字编码序列转换为空格分隔的字符串(如[22,12,38]→22 12 38),是连接“模型输出”和“可读文本”的关键
核心推理函数:inference
def inference(args):
# 1. 加载测试数据加载器(test=True:只加载测试集,无目标序列)
test_loader = create_dataloaders(args, test=True)
# 2. 初始化模型(和训练时的模型结构一致)
model = myModel(args)
print(args.ckpt_file) # 打印加载的模型权重文件路径(便于排查)
# 3. 加载训练好的模型权重
checkpoint = torch.load(args.ckpt_file, map_location='cpu') # 先加载到CPU,避免GPU显存问题
# strict=False:允许参数不完全匹配(适配微调后的模型结构)
model.load_state_dict(checkpoint['model_state_dict'], strict=False)
# 4. 模型移到GPU(推理加速)
model.to('cuda:0')
# 5. 模型设为评估模式(关键!禁用Dropout/BatchNorm等训练特有的层)
model.eval()
# 6. 打开输出文件,准备保存结果
fp = open(args.test_output_csv, 'w', newline='')
writer = csv.writer(fp)
# 7. 初始化样本计数
tot = 0
# 8. 遍历测试集,逐批次生成报告
for source in tqdm(test_loader):
# 8.1 将数据移到GPU
source = to_device(source, 'cuda:0')
# 8.2 模型推理:生成报告编码序列(tgts=None,触发generate模式)
pred = model(source)
# 8.3 转换为numpy数组(便于后续处理)
pred = pred.cpu().numpy()
# 8.4 遍历批次内的每个样本,保存结果
for i in range(pred.shape[0]):
# pred[i][2:]:去掉前两个token(SOS和冗余标记,根据实际编码规则调整)
# array2str:将数字序列转为字符串
writer.writerow([tot, array2str(pred[i][2:], args)])
tot += 1 # 样本计数+1
# 9. 关闭文件(避免数据丢失)
fp.close()
脚本入口:main函数
if __name__ == '__main__':
# 1. 解析配置参数(模型路径、数据路径、生成参数等)
args = parse_args()
# 2. 启动推理
inference(args)
推理关键细节拆解
模型加载的核心要点
| 步骤 | 操作 | 目的 |
|---|---|---|
torch.load(..., map_location='cpu') |
先加载到CPU | 避免直接加载到GPU导致显存溢出(尤其是大模型) |
strict=False |
非严格加载权重 | 适配微调后模型结构的微小变化(如新增层/删除层) |
model.to('cuda:0') |
移到GPU | 推理速度比CPU快10-100倍,是生产环境的必选项 |
model.eval() |
评估模式 | 禁用Dropout(防止随机失活导致结果不稳定)、固定BatchNorm均值/方差 |
推理模式的核心触发
pred = model(source):调用myModel的forward方法时,tgts=None,触发model.generate()逻辑,生成医学报告编码序列。- 生成过程使用训练时配置的beam search策略(beam数、长度惩罚等),保证生成结果的准确性。
结果处理与保存
pred[i][2:]:去掉前两个token(通常是SOS标记和冗余的起始符),只保留有效编码序列。writer.writerow([tot, array2str(...)]):按“样本编号,生成的编码序列”格式保存,和测试集输入格式一致,便于后续解码为自然语言报告。newline='':避免CSV文件出现空行(Windows系统的常见坑)。
推理 vs 训练的核心差异
| 维度 | 训练 | 推理 |
|---|---|---|
| 模型模式 | model.train() |
model.eval() |
| 梯度计算 | 开启(loss.backward()) |
关闭(默认,无需手动设置) |
| 输入数据 | source + targets |
仅source |
| 模型输出 | logits(用于计算损失) |
生成的序列id(最终结果) |
| 核心操作 | 反向传播更新参数 | 自回归生成序列 |
推理结果的后续处理(补充)
模型生成的是数字编码序列,要得到可读的医学报告,还需要最后一步:
# 伪代码:编码序列→自然语言报告
# 1. 加载词表(数字→医学术语的映射)
vocab = {}
with open('vocab.txt', 'r') as f:
for line in f:
idx, term = line.strip().split('\t')
vocab[int(idx)] = term
# 2. 解码生成的序列
def decode_sequence(seq_str):
# 将字符串转为数字列表
seq = [int(x) for x in seq_str.split()]
# 数字→术语映射,过滤PAD/EOS标记
terms = [vocab[idx] for idx in seq if idx not in [args.pad_id, args.eos_id]]
# 拼接为自然语言报告
return ','.join(terms) + '。'
# 3. 对推理结果解码
with open(args.test_output_csv, 'r') as f_in, open('final_report.csv', 'w') as f_out:
reader = csv.reader(f_in)
writer = csv.writer(f_out)
for row in reader:
tot, seq_str = row
report = decode_sequence(seq_str)
writer.writerow([tot, report])
效果示例:
- 模型生成的编码序列:
22 12 38 41 17 81 10 - 解码后自然语言:
腹部CT平扫,未见明显异常密度影,建议定期随访。
推理篇核心总结
- 推理的核心流程是“加载模型→加载测试数据→生成编码序列→保存结果”,关键是设置
model.eval()和禁用梯度计算。 - 模型权重加载时先加载到CPU再移到GPU,是避免显存溢出的通用技巧。
array2str是连接模型输出和可读文本的关键,最终的编码序列还需通过词表解码为自然语言医学报告。- 推理阶段无需计算损失,核心目标是高效、稳定地生成准确的医学报告。
项目完整总结:从CT影像到医学报告的BART文本生成实践
项目整体流程回顾
核心技术要点
- 任务本质:序列到序列(Seq2Seq)文本生成,输入是CT影像编码序列,输出是医学报告编码序列。
- 核心模型:
BartForConditionalGeneration,适配Seq2Seq生成任务,支持beam search生成策略。 - 数据处理:编码化存储医学术语,通过SOS/EOS标记和Padding统一序列长度。
- 训练策略:基于CIDEr-D指标的效果评估,只保存最优模型,保证生成质量。
- 推理落地:评估模式+GPU加速,生成编码序列后解码为可读的医学报告。
项目价值与拓展方向
价值
- 提升医生撰写影像报告的效率,降低重复劳动。
- 标准化医学报告格式,提升基层医疗的诊断规范性。
- 为AI辅助医疗诊断提供了可落地的工程实践方案。
拓展方向
- 优化生成策略:引入对比学习、提示学习,提升生成报告的医学准确性。
- 多模态扩展:直接输入CT影像(图片),而非编码序列,实现端到端的报告生成。
- 交互优化:支持医生对生成的报告进行修改,模型学习人工修正的反馈,持续迭代。
【补充】项目核心工具模块解析:配置、通用工具与评估指标
在前几篇的核心流程解析之外,config.py、utils.py、score.py 这三个文件是整个项目的“基础设施”——它们分别负责参数配置、通用工具函数和损失/评估指标计算,是保证项目可运行、可调试、可评估的关键。
config.py:全项目参数统一管理
config.py 是整个项目的“参数中心”,通过 argparse 定义了所有可配置的参数,避免了硬编码,让项目的灵活性和可维护性大幅提升。
核心功能解析
import argparse
def parse_args():
parser = argparse.ArgumentParser(description="Baseline for Weixin Challenge 2022")
# 按功能模块分组定义参数,结构清晰
# 1. 基础配置(随机种子、dropout、设备等)
parser.add_argument("--seed", type=int, default=2025, help="random seed.")
parser.add_argument('--dropout', type=float, default=0.2, help='dropout ratio')
parser.add_argument('--device', default="cuda", type=str, help="device")
# 2. 生成策略配置(beam search相关)
parser.add_argument('--beam', default=5, type=int, help='beamnum?')
parser.add_argument('--length_penalty', default=1, type=float, help='length_penalty')
parser.add_argument('--no_repeat', default=4, type=int, help='no_repeat_ngram_size')
# 3. 数据配置(路径、批次大小等)
parser.add_argument('--data_path', type=str, default='data')
parser.add_argument('--batch_size', default=2, type=int, help="use for training duration per worker")
# 4. 模型保存配置(路径、最优分数等)
parser.add_argument('--savedmodel_path', type=str, default='save/pretrain')
parser.add_argument('--ckpt_file', type=str, default=R'F:\pycharm\beike\bart\save\pretrain\model_epoch_48_cider_score_0.0.bin')
# 5. 训练配置(学习率、轮数等)
parser.add_argument('--learning_rate', default=3e-5, type=float, help='initial learning rate')
parser.add_argument('--max_epochs', type=int, default=50, help='How many epochs')
# 6. BART模型专属配置(序列长度、特殊token id等)
parser.add_argument('--input_l', type=int, default=150, help="输入序列最大长度")
parser.add_argument('--output_l', type=int, default=80, help="输出序列最大长度")
parser.add_argument('--sos_id', type=int, default=101, help="开始token id")
parser.add_argument('--eos_id', type=int, default=105, help="结束token id")
parser.add_argument('--pad_id', type=int, default=0, help="填充token id")
return parser.parse_args()
关键价值
- 参数集中管理:所有参数在一个文件中定义,修改时无需在多个脚本中查找,降低维护成本。
- 默认值+命令行覆盖:每个参数都有合理默认值,也可通过命令行传入自定义值(如
python finetune.py --batch_size 8)。 - 注释清晰:每个参数都有明确的注释,新人接手时能快速理解参数含义。
- 适配多场景:参数覆盖“训练/验证/推理”全流程,一套配置适配所有脚本。
utils.py:通用工具函数合集
utils.py 封装了项目中所有通用的工具函数,避免重复编码,是提升代码复用性的核心。
核心函数分类解析
设备管理:to_device
def to_device(data, device):
if isinstance(data, torch.Tensor):
data = data.to(device)
elif isinstance(data, np.ndarray):
data = to_device(torch.from_numpy(data), device)
elif isinstance(data, tuple):
data = tuple(to_device(item, device) for item in data)
elif isinstance(data, list):
data = list(to_device(item, device) for item in data)
elif isinstance(data, dict):
data = dict((k, to_device(v, device)) for k, v in data.items())
else:
raise TypeError('Unsupported Datatype! Must be a Tensor/List/Tuple/Dict.', type(data), data)
return data
- 作用:统一处理不同类型数据的设备迁移(CPU→GPU/GPU→CPU),无需手动逐个转换。
- 适配性:支持Tensor、Numpy数组、列表、元组、字典等多种数据类型,覆盖项目所有数据场景。
环境初始化:setup_device/setup_seed/setup_logging
def setup_device(args):
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
args.n_gpu = torch.cuda.device_count()
def setup_seed(args):
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
def setup_logging():
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
level=logging.INFO)
logger = logging.getLogger(__name__)
return logger
- setup_device:自动检测GPU是否可用,设置全局设备参数,适配不同硬件环境。
- setup_seed:固定随机种子,保证实验结果可复现(深度学习调试的核心技巧)。
- setup_logging:配置日志格式,输出训练/推理过程中的关键信息(时间、日志级别、内容)。
优化器构建:build_optimizer
def build_optimizer(args, model):
optimizer = AdamW(model.parameters(), lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps,
num_training_steps=args.max_steps)
return optimizer, scheduler
- 核心选择:使用
AdamW(带权重衰减的Adam),是Transformer模型的标配优化器。 - 学习率调度:搭配线性预热调度器(warmup),避免训练初期学习率过大导致模型不收敛。
序列转文本:array2str
def array2str(arr, args):
tk = AutoTokenizer.from_pretrained(args.pre_model_path)
out = ''
for i in range(len(arr)):
if arr[i]==args.pad_id or arr[i]==args.eos_id:
break # 遇到PAD/EOS停止,只保留有效序列
if arr[i]==args.sos_id:
continue # 跳过SOS标记
out += tk.convert_ids_to_tokens([arr[i]])[0] + ' '
if len(out.strip())==0:
out = '0' # 空序列兜底
return out.strip()
- 作用:将模型生成的数字编码序列转换为空格分隔的字符串,是连接“模型输出”和“评估指标计算”的桥梁。
- 关键处理:自动过滤PAD/EOS/SOS等特殊标记,只保留有效编码,保证评估结果准确。
score.py:损失函数与评估指标
score.py 是文本生成任务的“效果评判标准”,包含核心损失函数(CE)和评估指标(CIDEr-D)。
交叉熵损失:CE
def CE(output, target):
'''
Output: (B,L,C)。未经过softmax的logits
Target: (B,L)
'''
output = output.reshape(-1, output.shape[-1]) # (*,C)
target = target.reshape(-1).long() # (*)
loss = nn.CrossEntropyLoss()(output, target)
return loss
- 输入形状:
- output:
[批次大小, 序列长度, 词汇表大小](模型输出的logits) - target:
[批次大小, 序列长度](真实标签)
- output:
- 核心处理:将二维序列展平为一维,统一计算交叉熵损失,是序列生成任务的标准做法。
- 默认行为:
CrossEntropyLoss默认对所有token的损失取均值,保证损失值的稳定性。
CIDEr-D评估指标
这是文本生成任务的核心评估指标,代码实现了完整的CIDEr-D计算逻辑,核心步骤:
- precook:将文本转换为n-gram(1-4gram)的词频统计。
- cook_refs/cook_test:分别处理参考文本和生成文本的n-gram统计。
- sim:计算生成文本和参考文本的余弦相似度,加入长度惩罚(高斯函数)。
- compute_score:汇总所有样本的相似度,计算最终的CIDEr-D分数(越高表示生成效果越好)。
核心价值
- 适配医学文本:CIDEr-D相比BLEU更适合长文本生成,能更好地评估医学报告的语义相似度。
- 长度惩罚:通过
sigma参数控制长度惩罚强度,避免生成过短/过长的报告。 - 可复用性:封装为
CiderD类,可直接调用compute_score计算分数,适配训练/验证/推理全流程。
补充模块核心总结
config.py:全项目参数的“统一入口”,通过分组定义+默认值+命令行覆盖,保证项目的灵活性和可维护性。utils.py:封装通用工具函数,覆盖设备管理、环境初始化、优化器构建、序列转换等场景,提升代码复用性。score.py:实现文本生成任务的核心损失函数(CE)和评估指标(CIDEr-D),是模型效果评判的核心依据。
项目完整技术栈总结
| 模块文件 | 核心功能 | 关键技术点 |
|---|---|---|
config.py |
参数配置 | argparse、参数分组、默认值设置 |
process_data.py |
数据划分 | Pandas、随机采样、数据保存 |
data.py |
数据加载 | PyTorch Dataset/DataLoader、序列标准化 |
models.py |
模型构建 | BARTForConditionalGeneration、训练/推理双模式 |
finetune.py |
模型训练 | 训练循环、验证逻辑、模型保存 |
inference.py |
模型推理 | 评估模式、beam search生成、结果保存 |
utils.py |
通用工具 | 设备迁移、随机种子、优化器构建、序列转换 |
score.py |
损失/评估 | 交叉熵损失、CIDEr-D指标、n-gram相似度 |
至此,整个BART医学文本生成项目的所有代码模块都已解析完毕。从参数配置到数据处理,从模型构建到训练推理,从损失计算到效果评估,形成了一套完整、可落地的工业级文本生成解决方案,希望能帮助你全面掌握Seq2Seq模型在医疗领域的应用实践!
项目整体脉络与数据/代码流通全解析
为了让你一眼看清项目全貌,本文将从核心目标、整体流程、代码流通路径、数据流转与形态变化四个维度,系统化梳理整个医学文本生成项目的逻辑。
项目核心目标
基于预训练的BART模型,实现「CT影像编码序列 → 医学报告编码序列」的生成任务,本质是序列到序列(Seq2Seq)的文本生成,核心适配医疗场景下的编码化文本生成需求。
项目整体流程(一站式流程图)
代码流通路径(按执行顺序)
| 执行阶段 | 核心脚本 | 功能作用 | 依赖模块 |
|---|---|---|---|
| 1. 词表定制 | pro_vocab.py | 从训练数据提取医学编码→构建专用词表→更新BART模型embedding | config.py(参数) |
| 2. 数据预处理 | process_data.py | 划分训练/验证集,生成pro_train_data.csv/pro_val_data.csv | - |
| 3. 预训练数据适配(可选) | predata.py | 加载原始CSV→生成MLM任务数据→标准化序列长度/掩码处理 | config.py、utils.py |
| 4. 模型预训练(可选) | models.py(preModel) | 基于MLM数据预训练模型,适配医学编码 | predata.py、config.py |
| 5. 模型微调 | finetune.py | 加载数据→初始化myModel→训练循环→验证(CIDEr-D)→保存最优模型 | models.py、data.py、score.py、utils.py、config.py |
| 6. 推理生成 | inference.py | 加载最优模型→处理测试数据→生成编码序列→保存结果 | models.py、data.py、utils.py、config.py |
| 辅助模块 | config.py | 全项目参数统一管理(路径/超参/特殊token id) | - |
| 辅助模块 | utils.py | 设备迁移/随机种子/优化器/序列转文本等通用工具 | - |
| 辅助模块 | score.py | 交叉熵损失(训练)、CIDEr-D指标(验证) | - |
数据流转与形态变化(核心重点)
数据是项目的核心,以下按“原始形态→最终形态”梳理完整的流转过程,标注每一步的形态变化和处理逻辑:
阶段1:原始数据(初始形态)
- 存储格式:CSV文件(train.csv/val.csv/test.csv)
- 数据内容:
id,输入编码序列(空格分隔),输出编码序列(空格分隔)(测试集无输出序列) - 示例:
1001,14 22 38 41,17 81 10 99
阶段2:词表定制(数据→词表)
- 处理脚本:pro_vocab.py
- 输入:原始CSV中的编码序列(如“14”、“22”、“38”)
- 处理逻辑:
- 提取所有编码→统计频率→筛选有效编码;
- 插入特殊token([PAD]=0、[CLS]=101、[EOS]=104等);
- 生成专用vocab.txt,更新BART模型embedding层。
- 输出:自定义词表(vocab.txt),编码→id的映射关系(如“14”→50、“22”→51)。
阶段3:数据预处理(标准化)
-
处理脚本:data.py/predata.py
-
输入:原始CSV序列 + 自定义词表
-
处理逻辑(分两路):
路1(微调myModel):
- 字符串序列→数字id序列(如“14 22”→[50,51]);
- 输入序列:padding到input_l=150(不足补[PAD]=0,超长截断);
- 输出序列:添加[SOS]=101(开头)、[EOS]=104(结尾)→padding到output_l=80。
示例变化:
原始输出序列:17 81 10 → [101,52,53,54,104,0,0,...0](长度80)路2(预训练preModel):
- 编码序列→token id序列→随机掩码(MLM规则);
- 添加[CLS](开头)、[SEP](结尾)→padding到固定长度;
- 生成标签序列(掩码位置为原id,非掩码位置为-100)。
示例变化:
原始序列:14 22 38 → [101,50,103,52,102](22被掩码为103) 标签序列:[-100,50,51,-100,-100]
阶段4:模型训练(标准化数据→模型参数)
- 处理脚本:finetune.py
- 输入:标准化后的数字id序列(输入+输出)
- 处理逻辑:
- 输入序列喂入myModel→生成logits([batch, seq_len, vocab_size]);
- 计算交叉熵损失(CE):logits[:,:-1] vs 输出序列[:,1:](对齐预测和标签);
- 反向传播更新模型参数→每轮验证CIDEr-D分数→保存最优模型。
- 输出:训练好的模型权重(.bin文件),包含适配医学编码的参数。
阶段5:推理生成(模型→结果)
- 处理脚本:inference.py
- 输入:测试集标准化输入序列 + 最优模型权重
- 处理逻辑:
- 输入序列喂入模型→触发generate模式(beam=5搜索);
- 生成数字id序列→过滤[PAD]/[EOS]/[SOS]→转换为字符串序列;
- 保存到CSV文件(id,生成的编码序列)。
- 输出:推理结果CSV(如
1001,17 81 10 99)。
阶段6:最终解码(可选,项目外)
- 处理逻辑:生成的编码序列→通过词表/医学术语映射→自然语言报告;
- 示例变化:
生成编码序列:17 81 10 99 → 医学术语映射 → “腹部CT平扫,未见明显异常密度影,建议定期随访。”
核心总结
- 项目本质:基于BART的Seq2Seq生成,核心是将“CT影像编码”转为“医学报告编码”,适配医疗场景的专用化改造;
- 代码逻辑:先定制词表适配数据→再标准化数据→预训练/微调模型→最终推理生成结果,辅助模块(config/utils/score)贯穿全程;
- 数据核心:从“空格分隔的字符串序列”→“数字id序列(标准化长度)”→“模型生成的id序列”→“可读的医学报告”,每一步都围绕“编码→id→编码”的映射展开;
- 关键改造:抛弃通用词表,构建医学编码专用词表,是项目适配医疗场景的核心决策。
整个项目的逻辑闭环:数据定制词表→词表标准化数据→数据训练模型→模型生成新数据,所有代码都是为这个闭环服务,工具模块则是提升闭环效率和稳定性的支撑。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)