Pruned RNN-T for fast, memory-efficient ASR training. Pruned RNN-T loss

Pruned RNN-T for fast, memory-efficient ASR training. Pruned RNN-T loss

Kuang, Fangjun, et al. “Pruned RNN-T for fast, memory-efficient ASR training.” arXiv preprint arXiv:2206.13236 (2022). citations:9

公众号 新一代Kaldi:多快好省的 RNN-T 训练

公众号 新一代Kaldi:Pruned RNN-T 何以又快又好

开源代码 rnnt_loss.py:https://github.com/k2-fsa/k2/blob/master/k2/python/k2/rnnt_loss.py

Bechmark 的开源代码 transducer-loss-benchmarking:https://github.com/csukuangfj/transducer-loss-benchmarking

  • 简介

新一代 Kaldi 针对 RNN-T 损失函数的创新性改进。原始的 RNN-T 损失函数在处理长文本或者长语音的时候显存占用较大,训练时间较长。对此,我们提出对 RNN-T 的 log-probability lattice 进行裁剪,能够在不损失训练精度的前提下大幅缩短 RNN-T 损失函数的计算时间。该工作已经被Interspeech2022 收录。

  • 解决什么问题

在 CTC 模型、RNN-T 模型和 attention-based 模型中,RNN-T 是最适合用于产线部署的流式解码模型。但 RNN-T 具有以下几个痛点:

  1. 训练时,与其他模型相比,占用的内存至少高一个数量级
  2. 解码时,很难通过以 batch 的方式实现并行解码
  • 用了什么方法

    1. 提出 pruned RNN-T,改变了 RNN-T 模型中最后一层的输出维度:从 (N, T, U, C) 变成了 (N, T, S, C),其中 $U \gg S$ 。矩阵变小了,内存使用就降低了。
    2. 矩阵变小了,Pruned RNN-T 进入 joiner 网络的是更小得多的 (N,T,S,V)向量 ,所以 joiner 网络里面的非线性层和 Linear 层的计算量大大减小,从而大大提高了计算速度。
    3. Pruned RNN-T 的快部分来自高效的代码实现
  • 效果如何

    具有以下特点:

    “多”

    • 可以用更大的 batch size
    • 可以用更大的 vocabulary size

    “快”

    • 目前为止,所有常用开源实现中,训练速度最快
    • 训练出来的模型,在不降低性能的前提下,解码速度更快

    “好”

    • 训练出来的模型,在 GigaSpeech 和 WenetSpeech 的 Leaderboard 上,排名第一

    “省”

    • 省内存
    • 目前为止,所有常用开源实现中,所需内存最少

针对问题1,训练部分,展开介绍。

存在问题:RNNT训练时占用内存高

RNN-T 模型最后一层的输出是一个 4-D 的 tensor,维度是 (N, T, U, C), 其中

  • N: batch size。数值大小: 一般是几十
  • T: encoder 的输出帧数。数值大小:一般是好几百(CTC network,encoder)
  • U: decoder 的输出帧数。数值大小:几十至上百(prediction network,decoder)
  • C: vocabulary size。数值大小:几百至上千

所以,RNN-T 训练时,所需的内存(显存)正比于 N, T , U, C 这 4 个数的乘积 NTUC。训练 CTC 或者 attention-based 模型时,所需的内存(显存)一般与 NTC 或者 NUC 成正比。相比较之下,RNN-T 模型的训练,对内存的要求高了一个数量级。

这样一个大的向量需要占据很大的显存,导致没法使用大的 batch size 来训练,另外,如此大的向量也造成 joiner 网络的计算量非常大,从而增加单次迭代的时间。

为了避免训练时出现 out-of-memory (OOM) 错误,通常的做法是:

  • 减少 N,使用一个小的 batch size
  • 减少 C,使用一个较小的 vocabulary size
  • 降低模型参数量

但是,使用小的 batch size 会增加模型训练所需的时间;而使用小的 vocabulary size, 可能会影响模型的性能。例如,若以单个汉字为建模单元,vocabulary size 一般是 4000 到 7000 之间。如果使用一个很小的 vocabulary size,那么对于 out-of-vocabulary (OOV) 这种问题,就会更加常见。而降低模型的参数量,也会影响模型的性能。)

那么如何在不降低模型性能的前提下,做到以下几点呢?

  • 降低训练时所需的内存
  • 降低训练所需的时间

相似工作

微软 Improving RNN Transducer Modeling for End-to-End Speech Recognition:通过移除一个 batch 中所需的 padding 来减少内存的占用量。作者对这个方法提供了一个开源的实现,链接如下: optimized_transducer。相关 benchmark 数据表明,这种方法在所有标准的 RNN-T 开源实现中,所需内存最低。

解决方法

提出 pruned RNN-T,pruned RNN-T 改变了 RNN-T 模型中最后一层的输出维度:

  • (N, T, U, C) 变成了 (N, T, S, C)

其中,S 是用户指定的一个参数。我们所做的实验中,一般选用 5;而 U 则一般是几十甚至上百。

结果就是,所需的内存与 NTSC 成正比,而不再是 NTUC

image-20230313144523612

图1(a)是针对标准的 RNN-T,图中所有的结点都参与了RNN-T loss 的计算。

作者提出一个很有意思的问题:是不是所有的结点都应该参与计算?然后做出论证,如图2所示。

图2 显示了图 1(a)中每个结点在训练时某一时刻的梯度。我们可以看到,==随着训练的进行,靠近对角线上的结点对计算起到的作用最大。也就是说,不同位置的结点,在训练中起到的作用不同。==

图1(b)则是针对 pruned RNN-T,图中只有部分结点参与了 RNN-T loss 的计算。

参与计算的结点数量越少,所需的计算量则越少、计算速度就越快,并且所需的内存也越少。

  • 哪些结点应该被选出来参与计算呢?
  • 又如何选择这些结点呢?

音频和文本的单调对应特性决定了 lattice 中的大多数节点对最终的 Loss 几乎没有贡献。

trivial joiner 平凡联合网络

为了确定剪裁的边界,我们提出了一个“平凡联合网络”trivial joiner)的概念,这个 trivial joiner 是 encoder 和 predictor 的简单相加,即 am + lm。使用这样一个简单的 joiner 网络是为了在不生成四维向量的情况下得到一个 lattice,以便在这个 lattice 上求得剪裁边界。下图是 Pruned RNN-T 计算的流程图,我们实际上计算了两次损失函数,一次是在上述的 trivial joiner 上,一次是在正常的包含非线性层的 joiner 上(下图中的 s_range 就是上面提到的 S)。

![640 (1)](640 (1).png)

在一个 lattice 中,每一个节点包含了两个概率,即 $y(t, u)$ 和 $\varnothing(t, u) $, $ y(t, u)$ 表示在第 $t$ 帧给定 $y_{0 . . u}$ 的情况下发射 $y_{u+1}$ 的对数概率, $\varnothing(t, u)$ 则代表在第 $t$ 帧给定 $y_{0 . . u}$ 的情况下 发射 blank 的对数概率。由于 trivial joiner 是个简单的相加,所以我们不需要在相加之 后的向量中来获取这两个概率,只需分别在 $a m$ 和 $l m$ 中获得这两个概率,然后将 $a m$ 和 $l m$ 中得到的概率分别加起来就行。获取概率的操作就是个简单的查询,在代码中使用 torc h. gather 来实现,这个过程和乘法分配律非常相似。

注:两个 shape 不一样的向量相加得先统一 shape,即 logit = am.unsqueeze(2) + lm.unsqueeze(1),所以如果相加之后再获取概率,我们就不得不生成一个四维向量。

剪裁边界的确定

TODO。。。。。。

实验及结果

pruned RNN-T 在以下几方面的结果:

  • 训练速度
  • 训练所需内存
  • 在LibriSpeech test-clean 测试集上的 WER 及 RTF

在进行 benchmark 时,我们采用 LibriSpeech test-clean 测试集来生成 RNN-T 模型训练时所需的维度信息, 而不是针对特定的维度进行 benchmark。这样可以考虑每个 batch 中 padding 所造成的影响,尽量还原真实的应用场景。

我们设置了两种 bechmark 模式:

  • (1) 以随机的方式组成 batch。batch size 为 30。
  • (2) 按照样本时长进行排序的方式组成 batch。每个 batch 中最多包含 1万 帧特征。

在论文中,我们对比了 pruned RNN-T 和 常用的开源 RNN-T loss 实现的性能。结果如表 1 和 表 2 所示。我们可以看出,不管是在训练时间还是在内存使用量上, pruned RNN-T 与其他实现相比,都有很大的优势。