环境配置

https://github.com/wenet-e2e/wenet

git clone https://github.com/wenet-e2e/wenet.git # 克隆源码

AIShell 教程

我们提供了example/aishell/s0/run.sh关于 aishell-1 数据的配方

配方很简单,我们建议您手动逐个运行每个阶段并检查结果以了解整个过程。

cd example/aishell/s0
bash run.sh --stage -1 --stop-stage -1
bash run.sh --stage 0 --stop-stage 0
bash run.sh --stage 1 --stop-stage 1
bash run.sh --stage 2 --stop-stage 2
bash run.sh --stage 3 --stop-stage 3
bash run.sh --stage 4 --stop-stage 4
bash run.sh --stage 5 --stop-stage 5
bash run.sh --stage 6 --stop-stage 6

您也可以只运行整个脚本

bash run.sh --stage -1 --stop-stage 6

阶段-1:下载数据

此阶段将 aishell-1 数据下载到本地路径$data。这可能需要几个小时。

如果您已经下载了数据,请更改$data变量run.sh并从.--stage 0

阶段 0:准备训练数据

# 准备训练数据
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
  # Data preparation
  local/aishell_data_prep.sh ${data}/data_aishell/wav \
    ${data}/data_aishell/transcript
fi

在这个阶段,local/aishell_data_prep.sh将原始的 aishell-1 数据组织成两个文件:

  • wav.scp每行记录两个制表符分隔的列:wav_idwav_path
  • text 每行记录两个制表符分隔的列: wav_idtext_label

wav.scp

BAC009S0002W0122 /export/data/asr-data/OpenSLR/33/data_aishell/wav/train/S0002/BAC009S0002W0122.wav
BAC009S0002W0123 /export/data/asr-data/OpenSLR/33/data_aishell/wav/train/S0002/BAC009S0002W0123.wav
BAC009S0002W0124 /export/data/asr-data/OpenSLR/33/data_aishell/wav/train/S0002/BAC009S0002W0124.wav
BAC009S0002W0125 /export/data/asr-data/OpenSLR/33/data_aishell/wav/train/S0002/BAC009S0002W0125.wav
...

text

BAC009S0002W0122 而对楼市成交抑制作用最大的限购
BAC009S0002W0123 也成为地方政府的眼中钉
BAC009S0002W0124 自六月底呼和浩特市率先宣布取消限购后
BAC009S0002W0125 各地政府便纷纷跟进
...

如果您想使用自定义数据进行训练,只需将数据组织成两个文件wav.scptext,然后从.stage 1

第 1 阶段:提取可选 cmvn 特征

example/aishell/s0使用原始 wav 作为输入,使用TorchAudio在数据加载器中实时提取特征。所以在这一步中,我们只需将训练 wav.scp 和文本文件复制到raw_wav/train/目录中。

tools/compute_cmvn_stats.py用于提取全局 cmvn(倒谱均值和方差归一化)统计信息。这些统计数据将用于标准化声学特征。设置cmvn=false将跳过此步骤。

# 提取可选 cmvn 特征
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  # remove the space between the text labels for Mandarin dataset
  for x in train dev test; do
    cp data/${x}/text data/${x}/text.org
    paste -d " " <(cut -f 1 -d" " data/${x}/text.org) \
      <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
      > data/${x}/text
    rm data/${x}/text.org
  done

  tools/compute_cmvn_stats.py --num_workers 8 --train_config $train_config \
    --in_scp data/${train_set}/wav.scp \
    --out_cmvn data/$train_set/global_cmvn
fi

第 2 阶段:生成标签令牌字典

# 生成标签令牌字典
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  echo "Make a dictionary"
  mkdir -p $(dirname $dict)
  echo "<blank> 0" > ${dict}  # 0 is for "blank" in CTC
  echo "<unk> 1"  >> ${dict}  # <unk> must be 1
  tools/text2token.py -s 1 -n 1 data/train/text | cut -f 2- -d" " \
    | tr " " "\n" | sort | uniq | grep -a -v -e '^\s*$' | \
    awk '{print $0 " " NR+1}' >> ${dict}
  num_token=$(cat $dict | wc -l)
  echo "<sos/eos> $num_token" >> $dict
fi

dict 是标签标记(我们为 Aishell-1 使用字符)和整数索引之间的映射。

一个示例字典如下

<blank> 0
<unk> 1
一 2
丁 3
...
龚 4230
龟 4231
<sos/eos> 4232
  • <blank>表示 CTC 的空白符号。
  • <unk>表示未知标记,任何词汇表外的标记都将映射到其中。
  • <sos/eos>表示用于基于注意力的编码器解码器训练的语音开始和语音结束符号,并且它们共享相同的 id。

第 3 阶段:准备 WeNet 数据格式

此阶段生成 WeNet 所需的格式文件data.list。中的每一行data.list都是 json 格式,其中包含以下字段。

  1. key: 话语的关键
  2. wav: 话语的音频文件路径
  3. txt:话语的标准化转录,转录将在训练阶段即时标记为模型单元。

这是一个示例data.list,请参阅生成的训练特征文件data/train/data.list

{"key": "BAC009S0002W0122", "wav": "/export/data/asr-data/OpenSLR/33//data_aishell/wav/train/S0002/BAC009S0002W0122.wav", "txt": "而对楼市成交抑制作用最大的限购"}
{"key": "BAC009S0002W0123", "wav": "/export/data/asr-data/OpenSLR/33//data_aishell/wav/train/S0002/BAC009S0002W0123.wav", "txt": "也成为地方政府的眼中钉"}
{"key": "BAC009S0002W0124", "wav": "/export/data/asr-data/OpenSLR/33//data_aishell/wav/train/S0002/BAC009S0002W0124.wav", "txt": "自六月底呼和浩特市率先宣布取消限购后"}

我们还设计了另一种data.list命名格式,shard用于大数据训练。如果您想在大数据集(超过 5k)上应用 WeNet,请参阅gigaspeech(10k 小时)或 wenetspeech(10k 小时),了解如何使用shard样式data.list

第 4 阶段:神经网络训练

NN 模型在此步骤中进行训练。

  • 多 GPU 模式

如果对多 GPU 使用 DDP 模式,我们建议使用dist_backend="nccl". 如果 NCCL 不起作用,请尝试使用gloo或使用torch==1.6.0 Set the GPU ids in CUDA_VISIBLE_DEVICES。例如,设置为使用卡 0,1,2,3,6,7。export CUDA_VISIBLE_DEVICES="0,1,2,3,6,7"

  • 恢复训练

如果您的实验在运行几个 epoch 后由于某些原因(例如 GPU 被其他人意外使用并且内存不足)而终止,您可以从检查点模型继续训练。只需找出完成的 epoch exp/your_exp/,设置 checkpoint=exp/your_exp/$n.pt并运行. 然后训练将从 $n+1.pt 继续run.sh --stage 4

  • 配置

神经网络结构、优化参数、损失参数和数据集的配置可以在 YAML 格式文件中设置。

conf/中,我们提供了几种模型,例如变压器和构象器。见conf/train_conformer.yaml参考。

  • 使用张量板

培训需要几个小时。实际时间取决于 GPU 卡的数量和类型。在一台 8 卡 2080 Ti 机器中,50 个 epoch 大约需要不到一天的时间。您可以使用 tensorboard 来监控损失。

tensorboard --logdir tensorboard/$your_exp_name/ --port 12598 --bind_all
dir=exp/conformer
cmvn_opts="--cmvn ${dir}/global_cmvn"
train_config=conf/train_conformer.yaml
data_type=raw
dict=data/dict/lang_char.txt
train_set=train

python3 train.py \
	--config $train_config \
	--data_type $data_type \
	--symbol_table $dict \
	--train_data data/$train_set/data.list \
	--model_dir $dir \
	--cv_data data/dev/data.list \
	--num_workers 1 \
	$cmvn_opts \
	--pin_memory

第 5 阶段:使用经过训练的模型识别 wav

需要文件:

dict词典文件:words.txt
model:final.pt
训练模型用的配置文件:/train.yaml
cmvn文件:在配置文件里面配置路径

待识别语言列表:data.list # 格式{key,wavscp,text}

解码过程

dir=/root/data/aizm/wenet/pre_modle/20210618_u2pp_conformer_exp

data_type=raw
dict=${dir}/words.txt
decode_checkpoint=${dir}/final.pt

decoding_chunk_size=
ctc_weight=0.5
reverse_weight=0.0

test_dir=$dir/test_attention_rescoring
# 测试的语音内容{key,wavscp,text}
list_name=vad_test
data_list_dir=${list_name}.list
mkdir -p $test_dir
python recognize.py \
  --mode "attention_rescoring" \
  --config $dir/train.yaml \
  --data_type $data_type \
  --test_data ${data_list_dir} \
  --checkpoint $decode_checkpoint \
  --beam_size 10 \
  --batch_size 1 \
  --penalty 0.0 \
  --dict $dict \
  --ctc_weight $ctc_weight \
  --reverse_weight $reverse_weight \
  --result_file $test_dir/text_${list_name} \
  ${decoding_chunk_size:+--decoding_chunk_size $decoding_chunk_size}
python tools/compute-wer.py --char=1 --v=1 \
  data/test/text $test_dir/text > $test_dir/wer_${list_name}

这个阶段展示了如何将一组 wav 识别为文本。它还展示了如何进行模型平均。

  • 平均模型

如果${average_checkpoint}设置为true,则交叉验证集上的最佳${average_num}模型将被平均以生成增强模型并用于识别。

  • 解码

识别也称为解码或推理。NN的功能将应用于输入的声学特征序列以输出文本序列。

WeNet 中提供了四种解码方法:

  • ctc_greedy_search: encoder + CTC 贪婪搜索
  • ctc_prefix_beam_search:encoder + CTC 前缀波束搜索
  • attention`:encoder + attention-based decoder 解码
  • attention_rescoring:在基于注意力的解码器上使用编码器输出从 ctc 前缀波束搜索中重新评估 ctc 候选者。

一般来说,attention_rescoring 是最好的方法。有关这些算法的详细信息,请参阅U2 论文

--beam_size是一个可调参数,较大的光束尺寸可能会获得更好的结果,但也会导致更高的计算成本。

--batch_size“ctc_greedy_search”和“attention”解码模式可以大于1,“ctc_prefix_beam_search”和“attention_rescoring”解码模式必须为1。

  • wer评价

tools/compute-wer.py将计算结果的单词(或字符)错误率。如果您在没有任何更改的情况下运行配方,您可能会得到 WER ~= 5%。

第 6 阶段:导出训练好的模型

wenet/bin/export_jit.py将使用 Libtorch 导出经过训练的模型。导出的模型文件可轻松用于其他编程语言(如 C++)的推理。

Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐