一、背景

        接着上一篇代码解读 | Hybrid Transformers for Music Source Separation[02]文章,继续对Hybrid Transformer Demucs 代码进行解读。

        解读目标:明确数据从进入算法,在算法内部,以及在算法输出 这三个阶段中 数据的大小是如何变换的。例如:算法输入数据大小为[BatchSize,Channels,Length],算法内部的数据大小为[BatchSize,Channels,Freqency,Time],算法输出的数据大小为[BatchSize,Channels,Length]。

二、解读

        在htdemucs.py文件中编写测试代码print,把每个模块输出的数据大小打印出来。控制台打印的结果如下所示。例如:时域[1, 2, 258602] 可以理解成[BatchSize,Channels,Time];频域[1, 4, 2048, 253]可以理解成[BatchSize,Channels,Frequency,Time]

        再具体一些,不管是时域还是频域第一个维度数字1表示批大小(Batch_Size)第二个数字表示通道数(Channels)频域第三个数字可以理解成频域维度属性(对应模型图中的freq)频域第四个数字理解成时间维度属性(对应模型图中的time steps)时域的第三个数字可以理解成时间维度属性(对应模型图中的time steps)。根据这个先知条件,我们再和算法模型图中的Cin、Cout、xxx freq对比便一目了然了。

算法输入 torch.Size([1, 2, 258602])
频域输入(STFT输出) torch.Size([1, 4, 2048, 253])
频域输入(归一化) torch.Size([1, 4, 2048, 253])
时域输入(归一化) torch.Size([1, 2, 258602])
时域第1个编码层输出torch.Size([1, 48, 64651])
频域第1个编码层输出torch.Size([1, 48, 512, 253])
时域第2个编码层输出torch.Size([1, 96, 16163])
频域第2个编码层输出torch.Size([1, 96, 128, 253])
时域第3个编码层输出torch.Size([1, 192, 4041])
频域第3个编码层输出torch.Size([1, 192, 32, 253])
时域第4个编码层输出torch.Size([1, 384, 1011])
频域第4个编码层输出torch.Size([1, 384, 8, 253])
crosstransformer输出频域:torch.Size([1, 384, 8, 253]),时域:torch.Size([1, 384, 1011])
频域第1个解码层输出torch.Size([1, 192, 32, 253])
时域第1个解码层输出torch.Size([1, 192, 4041])
频域第2个解码层输出torch.Size([1, 96, 128, 253])
时域第2个解码层输出torch.Size([1, 96, 16163])
频域第3个解码层输出torch.Size([1, 48, 512, 253])
时域第3个解码层输出torch.Size([1, 48, 64651])
频域第4个解码层输出torch.Size([1, 16, 2048, 253])
时域第4个解码层输出torch.Size([1, 8, 258602])
频域输出(STFT输出) torch.Size([1, 4, 2, 258602])
时域输出(归一化) torch.Size([1, 4, 2, 258602])
算法输出(时域输出+频域输出) torch.Size([1, 4, 2, 258602])

     总结:打印出每个环节的数据大小,这样就能和算法模型图中的参数对应上。至于各个模块更具体的细节还需仔细阅读源码进行理解。


        感谢阅读,最近开始写公众号(分享好用的AI工具),欢迎大家一起见证我的成长(桂圆学AI)

GitHub 加速计划 / tra / transformers
130.24 K
25.88 K
下载
huggingface/transformers: 是一个基于 Python 的自然语言处理库,它使用了 PostgreSQL 数据库存储数据。适合用于自然语言处理任务的开发和实现,特别是对于需要使用 Python 和 PostgreSQL 数据库的场景。特点是自然语言处理库、Python、PostgreSQL 数据库。
最近提交(Master分支:3 个月前 )
13493215 * remove v4.44 deprecations * PR comments * deprecations scheduled for v4.50 * hub version update * make fiuxp --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> 5 天前
8d50fda6 * Remove FSDP wrapping from sub-models. * solve conflict trainer.py * make fixup * add unit test for fsdp_auto_wrap_policy when using auto_find_batch_size * put back extract_model_from_parallel * use transformers unwrap_model 5 天前
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐