Wenet脚本 解码

解码

decode_main

decode_main参数:

  • rescoring_weight:如果为0,不需要rescore;不为0的任意值,要做rescore
  • ctc_weight:rescore中,最终某条n-best分数 = score_aed + ctc_weight * score_ctc

binding

prefix beam search

参考 https://github.com/awni/speech

之前的笔记中解码代码:

（以下为该笔记中粘贴的完整解码脚本，原编辑器的行号栏已省略。）
"""
Author: Awni Hannun


This is an example CTC decoder written in Python. The code is
intended to be a simple example and is not designed to be
especially efficient.


The algorithm is a prefix beam search for a model trained
with the CTC loss function.


For more details checkout either of these references:
https://distill.pub/2017/ctc/#inference
https://arxiv.org/abs/1408.2873


"""


import numpy as np
import math
import collections


# Log-space "zero": probability 0 corresponds to -inf in log space.
NEG_INF = -float("inf")


def make_new_beam():
    """Return a dict mapping any not-yet-seen prefix to (NEG_INF, NEG_INF).

    The pair is (log P(prefix ends in blank), log P(prefix ends in non-blank)),
    both initialized to log(0).
    """
    return collections.defaultdict(lambda: (NEG_INF, NEG_INF))


def logsumexp(*args):
    """Numerically stable log(sum(exp(a) for a in args))."""
    if all(a == NEG_INF for a in args):
        # Avoid max()/exp() on an all -inf input; the true answer is log(0).
        return NEG_INF
    a_max = max(args)
    total = sum(math.exp(a - a_max) for a in args)
    return a_max + math.log(total)


def decode(probs, beam_size=3, blank=0):
    """Run CTC prefix beam search over per-frame output probabilities.

    Arguments:
        probs: Post-softmax probabilities, shape (time, output_dim).
        beam_size (int): Number of prefixes kept after each time step.
        blank (int): Index of the CTC blank label.

    Returns:
        (labels, neg_log_likelihood): the best label prefix (a tuple of
        ints) and its negative log-probability as estimated by the search.
    """
    num_frames, vocab_size = probs.shape
    probs = np.log(probs)  # work entirely in log space

    # Beam entries are (prefix, (p_blank, p_no_blank)): log-probabilities of
    # the prefix given that it currently ends / does not end in a blank.
    # Start with the empty prefix: P(ends in blank) = 1, P(non-blank) = 0.
    beam = [(tuple(), (0.0, NEG_INF))]

    for t in range(num_frames):
        # Fresh candidate set; unseen prefixes default to (-inf, -inf).
        next_beam = make_new_beam()

        for label in range(vocab_size):
            log_p = probs[t, label]

            for prefix, (p_blank, p_non_blank) in beam:
                if label == blank:
                    # Emitting a blank leaves the prefix unchanged; only the
                    # "ends in blank" mass is updated (paths are merged here).
                    nb_blank, nb_non_blank = next_beam[prefix]
                    nb_blank = logsumexp(
                        nb_blank, p_blank + log_p, p_non_blank + log_p)
                    next_beam[prefix] = (nb_blank, nb_non_blank)
                    continue

                # Extend the prefix by `label`; only the "ends in non-blank"
                # mass of the extended prefix is updated.
                last = prefix[-1] if prefix else None
                extended = prefix + (label,)
                nb_blank, nb_non_blank = next_beam[extended]
                if label != last:
                    # Distinct symbol: both parent states may extend.
                    nb_non_blank = logsumexp(
                        nb_non_blank, p_blank + log_p, p_non_blank + log_p)
                else:
                    # Repeated symbol: a repeat only counts as a new output
                    # if separated by a blank, so exclude p_non_blank (CTC
                    # merges adjacent identical characters).
                    nb_non_blank = logsumexp(nb_non_blank, p_blank + log_p)

                # *NB* an LM score could be folded in at this point.
                next_beam[extended] = (nb_blank, nb_non_blank)

                if label == last:
                    # Repeated symbol also collapses back onto the unchanged
                    # prefix (e.g. "aa" without a blank still reads "a").
                    nb_blank, nb_non_blank = next_beam[prefix]
                    nb_non_blank = logsumexp(nb_non_blank, p_non_blank + log_p)
                    next_beam[prefix] = (nb_blank, nb_non_blank)

        # Keep only the beam_size most probable prefixes (total mass =
        # logsumexp of both end states) before the next frame.
        ranked = sorted(next_beam.items(),
                        key=lambda item: logsumexp(*item[1]),
                        reverse=True)
        beam = ranked[:beam_size]

    best_prefix, best_scores = beam[0]
    return best_prefix, -logsumexp(*best_scores)


if __name__ == "__main__":
    # Tiny hand-crafted example: the argmax frame path is 0 3 3 0 3 4,
    # which CTC-collapses (blank = 0) to the label sequence 3 3 4.
    probs = np.array([
        [0.8, 0.05, 0.05, 0.05, 0.05],
        [0.05, 0.05, 0.05, 0.8, 0.05],
        [0.05, 0.05, 0.05, 0.8, 0.05],
        [0.8, 0.05, 0.05, 0.05, 0.05],
        [0.05, 0.05, 0.05, 0.8, 0.05],
        [0.05, 0.05, 0.05, 0.05, 0.8],
    ])
    # Renormalize each frame so the rows are proper distributions.
    probs = probs / np.sum(probs, axis=1, keepdims=True)

    labels, score = decode(probs)
    print(labels)
    print("Score {:.3f}".format(score))