语音标点恢复是指在语音识别(ASR)输出的无标点文本中,自动添加逗号、句号、问号等标点符号,以提升文本可读性和语义清晰度。

FireRedPunc

模型地址:https://www.modelscope.cn/models/xukaituo/FireRedPunc

FireRedPunc支持中英文标点预测,包括对英文大小写的后处理功能
支持的标点为逗号(,)句号(。)问号(?)感叹号(!)
其模型架构以BERT为基础,其F1 为 78.90%,优于 FunASR-Punc的62.77%。(官方说法)

使用方法

from fireredasr2s.fireredpunc.punc import FireRedPunc, FireRedPuncConfig

config = FireRedPuncConfig(use_gpu=True)
model = FireRedPunc.from_pretrained("xukaituo/FireRedPunc", config)

batch_text = ["你好世界", "Hello world"]
results = model.process(batch_text)

print(results)
# [{'punc_text': '你好世界。', 'origin_text': '你好世界'}, {'punc_text': 'Hello world!', 'origin_text': 'Hello world'}]

FunASR-Punc

模型地址:https://www.modelscope.cn/models/iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch

FunASR-Punc支持中文通用标点预测
支持的标点为下划线(_)、逗号(,)句号(。)问号(?)顿号(、)

模型介绍

Controllable Time-delay Transformer是达摩院语音团队提出的高效后处理框架中的标点模块。本项目为中文通用标点模型,模型可以被应用于文本类输入的标点预测,也可应用于语音识别结果的后处理步骤,协助语音识别模块输出具有可读性的文本结果。

在这里插入图片描述

Controllable Time-delay Transformer 模型结构如上图所示,由 Embedding、Encoder 和 Predictor 三部分组成。Embedding 是词向量叠加位置向量。Encoder可以采用不同的网络结构,例如self-attention,conformer,SAN-M等。Predictor 预测每个token后的标点类型。

更详细的细节见:

使用方法

from funasr import AutoModel

model = AutoModel(model="ct-punc", model_revision="v2.0.4")

res = model.generate(input="那今天的会就到这里吧 happy new year 明年见", disable_pbar=True)
print(res)
#[{'key': 'rand_key_2yW4Acq9GFz6Y', 'text': '那今天的会就到这里吧,happy new year,明年见。', 'punc_array': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 3])}]

评测对比

测试数据集

IWSLT2012-zh

下载地址:https://github.com/jiangnanboy/punctuation_prediction

数据处理

def get_iwslt20212_zh_test_data(test_data_path):
    with open(test_data_path, "r", encoding="utf-8") as f:
        lines = f.readlines()

    pattern = re.compile(r'\(.*?\)|\(.*?\)|\(.*?\)|\(.*?\)')

    data_list = []
    for line in lines:
    	#去除数据中括号带有的描述词,如(笑声)
        line_without_brackets = pattern.sub('', line.strip().replace(" ", ""))
        
        #去除非法标点
        reference = line_without_brackets.replace(".", "").replace("-", "").replace("—", "").replace(",", ",")
         #去除标点
        raw_text = reference.replace(",", "").replace("。", "").replace("?", "").replace("·", "").replace(":", "").replace("“", "").replace("”", "").replace("‘", "").replace("《", "").replace("》", "").replace(":", "")
        data = {
            "raw_text": raw_text,
            "reference": reference
        }
        data_list.append(data)

    print(f"Total test data: {len(data_list)}")
    return data_list

CDCPP

下载地址:https://github.com/NLPBLCU/Cross-Domain-Chinese-Punctuation-Prediction

数据处理

def get_CDCPP_test_data(file_path):
    with open(file_path, "r", encoding="utf-8") as f:
        lines = f.readlines()
    
    offset = 0
    all_data_list = []
    data_list = []
    while offset < len(lines):
        line = lines[offset].strip()
        #空行分割
        if len(line) == 0:
        	#每2-3行为一组,其中第2行为黄金标准
            if len(data_list) >= 2:
            	#去掉中文文本中多余空格
                reference = data_list[1].replace(" ", "")
                #去掉标点
                raw_text = reference.replace(",", "").replace("。", "").replace("?", "").replace("!", "").replace("、", "").replace(",", "")
                data = {
                    "raw_text": raw_text,
                    "reference": reference
                }
                all_data_list.append(data)
                data_list = []
        else:
            data_list.append(line)
        offset += 1
    print(f"load {file_path}, Total test data: {len(all_data_list)}")
    return all_data_list

def main():
	test_data_dir = "./Cross-Domain-Chinese-Punctuation-Prediction-master"
	for name in ["QA__text", "weibo_long", "weibo_short"]:
    	punc_data_list = get_CDCPP_test_data(os.path.join(test_data_dir, f"{name}.txt"))

评价指标

F1分数(F1 Score)是精确率(Precision)和召回率(Recall)的调和平均数,其计算公式为

F1 = 2 × (Precision × Recall) / (Recall +Recall)

‌精确率指模型预测为正类的样本中,实际为正类的比例,计算公式为

Precision = TP / (TP + FP)

‌召回率指所有实际为正类的样本中,被模型正确预测为正类的比例,计算公式为

Recall = TP / (TP + FN)

代码实现

def evaluate_punctuation_recovery(result_list):
    stats_dict = {
        mark: {"tp": 0, "fp": 0, "fn": 0} for mark in ["all", ",", "。", "?", "!", "other"]
    }
    for id, data in enumerate(result_list):
        reference = data['reference']
        hypothesis = data['hypothesis']
        raw_data = data["raw_text"]
        
        ref_offset = 0
        hyp_offset = 0
        raw_offset = 0
        while ref_offset < len(reference) or hyp_offset < len(hypothesis) or raw_offset < len(raw_data):
            ref_word = reference[ref_offset] if ref_offset < len(reference) else ""
            hyp_word = hypothesis[hyp_offset] if hyp_offset < len(hypothesis) else ""
            raw_word = raw_data[raw_offset] if raw_offset < len(raw_data) else ""

            if ref_word == raw_word:
                # 场景1:参考=原始,预测=原始 → 无标点标注,无需统计,所有偏移量递增
                if hyp_word == raw_word or hyp_word.upper() == raw_word or hyp_word.lower() == raw_word:
                    ref_offset += 1
                    hyp_offset += 1
                    raw_offset += 1
                else:
                    # 场景2:参考=原始,但预测≠原始 → 预测错误(FP,假阳性)
                    if len(hyp_word) > 0:
                        mark_key = hyp_word if hyp_word in stats_dict else "other"
                        stats_dict[mark_key]["fp"] += 1
                        stats_dict["all"]["fp"] += 1
                    hyp_offset += 1
            # 场景3:预测=原始,但参考≠原始 → 漏标(FN,假阴性)
            elif hyp_word == raw_word or hyp_word.upper() == raw_word or hyp_word.lower() == raw_word:
                if len(ref_word) > 0:
                    mark_key = ref_word if ref_word in stats_dict else "other"
                    stats_dict[mark_key]["fn"] += 1
                    stats_dict["all"]["fn"] += 1
                ref_offset += 1
            # 场景4:其他情况 → 要么预测正确(TP),要么预测错误(FP)
            else:
                if len(hyp_word) > 0:
                    # 场景4.1:预测=参考(但都≠原始)→ 标注正确(TP,真阳性)
                    if hyp_word == ref_word:
                        mark_key = hyp_word if hyp_word in stats_dict else "other"
                        stats_dict[mark_key]["tp"] += 1
                        stats_dict["all"]["tp"] += 1
                        hyp_offset += 1
                        ref_offset += 1
                    else:
                        # 场景4.2:预测≠参考 → 预测错误(FP)
                        mark_key = hyp_word if hyp_word in stats_dict else "other"
                        stats_dict[mark_key]["fp"] += 1
                        stats_dict["all"]["fp"] += 1
                        hyp_offset += 1
                        if len(ref_word) > 0:
                            ref_offset += 1
                else:
                    # 场景4.3:预测为空 → 漏标(FN)
                    if len(ref_word) > 0:
                        mark_key = ref_word if ref_word in stats_dict else "other"
                        stats_dict[mark_key]["fn"] += 1
                        stats_dict["all"]["fn"] += 1
                        ref_offset += 1
                    else:
                        raw_offset += 1
                        
    def cal_f1(tp, fp, fn):
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        return {'precision': round(precision, 4), 'recall': round(recall, 4), 'f1': round(f1, 4)}
    
    f1_results = {}
    for category, stats in stats_dict.items():
        f1_results[category] = cal_f1(stats['tp'], stats['fp'], stats['fn'])

    print(f'ALL\t\tF1:{f1_results["all"]["f1"]:.4f}\tPrecision:{f1_results["all"]["precision"]:.4f}\tRecall:{f1_results["all"]["recall"]:.4f}')
    print(f'Comma\t\tF1:{f1_results[","]["f1"]:.4f}\tPrecision:{f1_results[","]["precision"]:.4f}\tRecall:{f1_results[","]["recall"]:.4f}')
    print(f'Period\t\tF1:{f1_results["。"]["f1"]:.4f}\tPrecision:{f1_results["。"]["precision"]:.4f}\tRecall:{f1_results["。"]["recall"]:.4f}')
    print(f'Question\tF1:{f1_results["?"]["f1"]:.4f}\tPrecision:{f1_results["?"]["precision"]:.4f}\tRecall:{f1_results["?"]["recall"]:.4f}')
    print(f'exclamation\tF1:{f1_results["!"]["f1"]:.4f}\tPrecision:{f1_results["!"]["precision"]:.4f}\tRecall:{f1_results["!"]["recall"]:.4f}')

评测结果

IWSLT2012-zh测试对比结果

Method FirRedPunc Comma Period Question
F1 P R F1 P R F1 P R F1 P R
FireRedPunc 0.6428 0.5520 0.7693 0.6048 0.5003 0.7647 0.7462 0.6016 0.9822 0.8611 0.7750 0.9688
FunASR-Punc 0.5710 0.4698 0.7278 0.5425 0.4477 0.6882 0.7125 0.5616 0.9745 0.8936 0.8235 0.9767

CDCPP测试对比结果

Method FirRedPunc Comma Period Question Exclamation
F1 P R F1 P R F1 P R F1 P R F1 P R
FireRedPunc 0.7796 0.7178 0.8530 0.7538 0.6860 0.8365 0.8541 0.7730 0.9542 0.9037 0.8753 0.9339 0.5309 0.4640 0.6202
FunASR-Punc 0.6876 0.6202 0.7714 0.6796 0.6206 0.7510 0.7343 0.6191 0.9023 0.8812 0.8791 0.8833 - - -
CDCPP子数据集测试结果

FireRedPunc

DataSet ALL Comma Period Question Exclamation
F1 P R F1 P R F1 P R F1 P R F1 P R
QA 0.8866 0.9029 0.8709 0.8395 0.8730 0.8085 0.9677 0.9467 0.9897 0.9468 0.9468 0.9468 0.4800 0.3529 0.7500
weibo_long 0.7488 0.6632 0.8599 0.7394 0.6484 0.8601 0.8143 0.7186 0.9395 0.8451 0.7610 0.9502 0.5896 0.5034 0.7115
weibo_short 0.7228 0.6448 0.8223 0.7158 0.6329 0.8238 0.7795 0.6718 0.9283 0.8273 0.7845 0.8750 0.4649 0.4433 0.4886
ALL 0.7796 0.7178 0.8530 0.7538 0.6860 0.8365 0.8541 0.7730 0.9542 0.9037 0.8753 0.9339 0.5309 0.4640 0.6202

FunASR-Punc

DataSet ALL Comma Period Question Exclamation
F1 P R F1 P R F1 P R F1 P R F1 P R
QA__text 0.8312 0.7950 0.8708 0.7797 0.7480 0.8141 0.9127 0.8466 0.9900 0.9268 0.9184 0.9354 - - -
weibo_long 0.6340 0.5546 0.7399 0.6504 0.5793 0.7414 0.6680 0.5467 0.8585 0.8162 0.8251 0.8075 - - -
weibo_short 0.6311 0.5650 0.7146 0.6443 0.5884 0.7119 0.6649 0.5381 0.8698 0.7909 0.7970 0.7850 - - -
ALL 0.6876 0.6202 0.7714 0.6796 0.6206 0.7510 0.7343 0.6191 0.9023 0.8812 0.8791 0.8833 - - -
Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐