kaldi chain


Generate the phone language model (LM), which is needed to build the denominator graph (an HCLG-like lattice):

chain_lib.create_phone_lm(args.dir, args.tree_dir, run_opts, lm_opts=args.lm_opts)

exp7/chain/tdnn/log/make_phone_lm.log:

gunzip -c exp7/chain/tri5_tree/ali.1.gz exp7/chain/tri5_tree/ali.2.gz exp7/chain/tri5_tree/ali.3.gz exp7/chain/tri5_tree/ali.4.gz exp7/chain/tri5_tree/ali.5.gz exp7/chain/tri5_tree/ali.6.gz exp7/chain/tri5_tree/ali.7.gz exp7/chain/tri5_tree/ali.8.gz exp7/chain/tri5_tree/ali.9.gz exp7/chain/tri5_tree/ali.10.gz exp7/chain/tri5_tree/ali.11.gz exp7/chain/tri5_tree/ali.12.gz exp7/chain/tri5_tree/ali.13.gz exp7/chain/tri5_tree/ali.14.gz exp7/chain/tri5_tree/ali.15.gz exp7/chain/tri5_tree/ali.16.gz exp7/chain/tri5_tree/ali.17.gz exp7/chain/tri5_tree/ali.18.gz exp7/chain/tri5_tree/ali.19.gz exp7/chain/tri5_tree/ali.20.gz exp7/chain/tri5_tree/ali.21.gz exp7/chain/tri5_tree/ali.22.gz exp7/chain/tri5_tree/ali.23.gz exp7/chain/tri5_tree/ali.24.gz exp7/chain/tri5_tree/ali.25.gz exp7/chain/tri5_tree/ali.26.gz exp7/chain/tri5_tree/ali.27.gz exp7/chain/tri5_tree/ali.28.gz exp7/chain/tri5_tree/ali.29.gz exp7/chain/tri5_tree/ali.30.gz | ali-to-phones exp7/chain/tri5_tree/final.mdl ark:- ark:- | chain-est-phone-lm --num-extra-lm-states=2000 ark:- exp7/chain/tdnn/phone_lm.fst

Generate the denominator lattice (den.fst):

exp7/chain/tdnn/log/make_den_fst.log:

chain-make-den-fst exp7/chain/tdnn/tree exp7/chain/tdnn/0.trans_mdl exp7/chain/tdnn/phone_lm.fst exp7/chain/tdnn/den.fst exp7/chain/tdnn/normalization.fst

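den.fst:

  • Column 1: source state (a node, also called an FST state), a state of the 3-gram phone LM graph G (0-3475, 3476 in total);
  • Column 2: destination state, also a state of the 3-gram phone LM graph G (0-3474, 3475 in total); chain/chain-den-graph.cc stores it as hmm_state for outgoing transitions;
  • Column 3: input label, the pdf-id from H plus one (1-2392, 2392 in total; the pdf-ids themselves are 0-2391 and include both forward-pdfs and self-loop-pdfs). The examples below use the tdnn_3 and tdnn_4 models, which have 1976 pdfs;
  • Column 4: output label, also the pdf-id from H plus one;
  • The input label and the output label are identical;
  • The states of den.fst live in G space. Is the score on each arc an acoustic score or an LM score? Judging from the comment on transition_prob in chain/chain-datastruct.h, it is the language-model part.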
3475 1516 1 1 8.70407677
3475 3011 1 1 15.0414305
3475 198 1 1 9.42041969
3475 318 1 1 9.59932804
3475 271 1 1 10.1355295
3475 99 1 1 8.69728088
3475 689 1 1 10.1195202
3475 3368 1 1 16.4325008
3475 2406 1 1 13.7536497
3475 914 1 1 10.7855644
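
The column layout above can be checked with a few lines of code. The following is only a minimal sketch: `parse_den_fst_arc` is a hypothetical helper (not part of Kaldi) that assumes the five-field arc lines shown above, and it applies the same two conversions that `DenominatorGraph::SetTransitions` uses, namely pdf-id = ilabel - 1 and probability = exp(-weight).

```python
import math


def parse_den_fst_arc(line):
    """Parse one text-format arc line: src, dst, ilabel, olabel, weight.

    Hypothetical helper for illustration; final-state lines (which have
    fewer fields) are rejected rather than handled.
    """
    fields = line.split()
    if len(fields) != 5:
        raise ValueError("expected 5 fields, got %d" % len(fields))
    src, dst, ilabel, olabel = (int(f) for f in fields[:4])
    weight = float(fields[4])
    return {
        "src": src,                 # column 1: source state in G space
        "dst": dst,                 # column 2: destination state in G space
        "pdf_id": ilabel - 1,       # labels are pdf-id + 1
        "olabel": olabel,           # same as ilabel in den.fst
        "prob": math.exp(-weight),  # weight is a negative log-probability
    }


arc = parse_den_fst_arc("3475 1516 1 1 8.70407677")
print(arc["src"], arc["dst"], arc["pdf_id"], arc["prob"])
```

For the first arc listed above this gives pdf-id 0 and a small probability (exp(-8.70407677)), which is consistent with the arc weight being a cost rather than a probability.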

Script for computing the loss:

exp7/chain/tdnn/log/compute_prob_valid.1.log:

nnet3-chain-compute-prob --l2-regularize=0.0 --leaky-hmm-coefficient=0.1 --xent-regularize=0.1 exp7/chain/tdnn_3/final.mdl exp7/chain/tdnn_3/den.fst 'ark,bg:nnet3-chain-copy-egs ark:exp7/chain/tdnn_4/egs/train_diagnostic.cegs ark:- | nnet3-chain-merge-egs --minibatch-size=1:64 ark:- ark:- |'

nnet3-chain-merge-egs --minibatch-size=1:64 ark:- ark:-
nnet3-chain-copy-egs ark:exp7/chain/tdnn/egs/valid_diagnostic.cegs ark:-
LOG (nnet3-chain-copy-egs[5.5.591~43-fa0934]:main():nnet3-chain-copy-egs.cc:395) Read 400 neural-network training examples, wrote 400
LOG (nnet3-chain-merge-egs[5.5.591~43-fa0934]:PrintSpecificStats():nnet-example-utils.cc:1143) Merged specific eg types as follows [format: <eg-size1>={<mb-size1>-><num-minibatches1>,<mbsize2>-><num-minibatches2>.../d=<num-discarded>},<egs-size2>={...},... (note,egs-size == number of input frames including context).
LOG (nnet3-chain-merge-egs[5.5.591~43-fa0934]:PrintSpecificStats():nnet-example-utils.cc:1173) 152={44->1,d=0},173={40->1,d=0},212={60->1,64->4,d=0}
LOG (nnet3-chain-merge-egs[5.5.591~43-fa0934]:PrintAggregateStats():nnet-example-utils.cc:1139) Processed 400 egs of avg. size 201.5 into 7 minibatches, discarding 0% of egs. Avg minibatch size was 57.14, #distinct types of egs/minibatches was 3/4
LOG (nnet3-chain-compute-prob[5.5.591~43-fa0934]:PrintTotalStats():nnet-chain-diagnostics.cc:194) Overall log-probability for 'output-xent' is -3.05123 per frame, over 18600 frames.
LOG (nnet3-chain-compute-prob[5.5.591~43-fa0934]:PrintTotalStats():nnet-chain-diagnostics.cc:194) Overall log-probability for 'output' is -0.328318 per frame, over 18600 frames.
LOG (nnet3-chain-compute-prob[5.5.591~43-fa0934]:~CachingOptimizingCompiler():nnet-optimize.cc:710) 0.323 seconds taken in nnet3 compilation total (breakdown: 0.139 compilation, 0.00658 optimization, 0.17 shortcut expansion, 0.00139 checking, 3.81e-05 computing indexes, 0.00598 misc.) + 0 I/O.
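
The merge statistics in the log can be cross-checked against the aggregate line. The sketch below is an illustration only: `parse_merge_stats` is a hypothetical helper that assumes the `<eg-size>={<mb-size>-><num-minibatches>,...,d=<num-discarded>}` format described in the log message itself.

```python
import re


def parse_merge_stats(stats):
    """Return {eg_size: {minibatch_size: num_minibatches}} from the log string."""
    result = {}
    for eg_size, body in re.findall(r"(\d+)=\{([^}]*)\}", stats):
        sizes = {}
        # The "d=<num-discarded>" entry has no "->", so it is skipped here.
        for mb_size, count in re.findall(r"(\d+)->(\d+)", body):
            sizes[int(mb_size)] = int(count)
        result[int(eg_size)] = sizes
    return result


stats = parse_merge_stats("152={44->1,d=0},173={40->1,d=0},212={60->1,64->4,d=0}")
num_egs = sum(mb * n for sizes in stats.values() for mb, n in sizes.items())
num_minibatches = sum(n for sizes in stats.values() for n in sizes.values())
avg_eg_size = sum(
    eg * mb * n for eg, sizes in stats.items() for mb, n in sizes.items()
) / num_egs
print(num_egs, num_minibatches, avg_eg_size, round(num_egs / num_minibatches, 2))
```

The totals agree with the log: 400 egs, 7 minibatches, an average eg size of 201.5 and an average minibatch size of 57.14.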

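Reference: "kaldi语音识别 chain模型的数据准备" (Kaldi speech recognition: data preparation for chain models)

normalization.fst is derived from the denominator FST den.fst by modifying the initial and final probabilities.

Inspect the contents of the cegs: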

nnet3-chain-copy-egs ark:exp7/chain/tdnn_4/egs/train_diagnostic.cegs ark,t:1

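  • Because chain models use frame subsampling, the egs store the output indexes after 3x downsampling.
  • Numerator lattice, with a WFSA structure, used to compute MMI. Columns 3 and 4 are pdf ids (1-1975).
  • Forced-alignment result, presumably for the cross-entropy term. Is it actually used? In the code the cross-entropy term is also computed with forward-backward and does not use this.

Code for computing the loss: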

nnet3-chain-compute-prob --l2-regularize=0.0 --leaky-hmm-coefficient=0.1 --xent-regularize=0.1 exp7/chain/tdnn_3/final.mdl exp7/chain/tdnn_3/den.fst ark:1

Reference: "kaldi chain模型的序列鉴别性训练代码分析" (code analysis of sequence-discriminative training for Kaldi chain models), jarvanWang, cnblogs

Reference: "Chain训练准则的计算" (computing the chain training criterion), jarvanWang, cnblogs

Call path:

chainbin/nnet3-chain-compute-prob.cc: chain_prob_computer.Compute(example_reader.Value());

nnet3/nnet-chain-diagnostics.cc: this->ProcessOutputs(chain_eg, &computer);

nnet3/nnet-chain-diagnostics.cc: ComputeChainObjfAndDeriv(chain_config_, den_graph_,
    sup.supervision, nnet_output,
    &tot_like, &tot_l2_term, &tot_weight,
    (nnet_config_.compute_deriv ? &nnet_output_deriv : NULL),
    (use_xent ? &xent_deriv : NULL));
chain/chain-training.cc: ComputeChainObjfAndDerivE2e
    DenominatorComputation denominator(opts, den_graph, supervision.num_sequences, nnet_output);  // used to obtain the objf (loss)
    // denominator computation (implemented in chain/chain-denominator.cc, BaseFloat DenominatorComputation::Forward()):
    den_logprob_weighted = supervision.weight * denominator.Forward();
    // numerator computation:
    numerator_ok = numerator.ForwardBackward(&num_logprob_weighted,
                                             xent_output_deriv);

nnet_output: the DNN output

chain/chain-training.cc

Supervision &supervision
DenominatorGraph &den_graph

Initialization:

NnetChainComputeProb chain_prob_computer(nnet_opts, chain_opts, den_fst,nnet);

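chain/chain-supervision.h: contains the full definition of the Supervision class and explanations of its members.

  • Supervision class:
    • label_dim: number of pdfs
    • fst: sorted by frame index

chain/chain-den-graph.h: contains the member definitions of the DenominatorGraph class.

  • DenominatorGraph class: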
    • ForwardTransitions
    • BackwardTransitions
    • Transitions

chain/chain-den-graph.cc:

void DenominatorGraph::SetTransitions(const fst::StdVectorFst &fst,
                                      int32 num_pdfs) {
  int32 num_states = fst.NumStates();

  // Per-state lists of outgoing and incoming arcs.
  std::vector<std::vector<DenominatorGraphTransition> >
      transitions_out(num_states),
      transitions_in(num_states);
  for (int32 s = 0; s < num_states; s++) {
    for (fst::ArcIterator<fst::StdVectorFst> aiter(fst, s); !aiter.Done();
         aiter.Next()) {
      const fst::StdArc &arc = aiter.Value();
      DenominatorGraphTransition transition;
      transition.transition_prob = exp(-arc.weight.Value());  // cost -> probability
      transition.pdf_id = arc.ilabel - 1;  // labels are pdf-id + 1
      transition.hmm_state = arc.nextstate;
      KALDI_ASSERT(transition.pdf_id >= 0 && transition.pdf_id < num_pdfs);
      transitions_out[s].push_back(transition);
      // now the reverse transition.
      transition.hmm_state = s;
      transitions_in[arc.nextstate].push_back(transition);
    }
  }

  std::vector<Int32Pair> forward_transitions(num_states);
  std::vector<Int32Pair> backward_transitions(num_states);
  std::vector<DenominatorGraphTransition> transitions;

  // Flatten the outgoing arcs, recording each state's [first, second) range.
  for (int32 s = 0; s < num_states; s++) {
    forward_transitions[s].first = static_cast<int32>(transitions.size());
    transitions.insert(transitions.end(), transitions_out[s].begin(),
                       transitions_out[s].end());
    forward_transitions[s].second = static_cast<int32>(transitions.size());
  }
  // Then append the incoming arcs in the same way.
  for (int32 s = 0; s < num_states; s++) {
    backward_transitions[s].first = static_cast<int32>(transitions.size());
    transitions.insert(transitions.end(), transitions_in[s].begin(),
                       transitions_in[s].end());
    backward_transitions[s].second = static_cast<int32>(transitions.size());
  }

  forward_transitions_ = forward_transitions;
  backward_transitions_ = backward_transitions;
  transitions_ = transitions;
}
  • den.fst with the FST weights converted to probabilities via exp(-weight): fstprint den.fst | awk -F'\t' '{$NF=exp(-$NF);print$0}' | tr ' ' '\t'
  • transitions_out: indexed by the first column of den.fst (the source state), so transitions_out[s] lists the arcs leaving state s.
  • transitions_in: indexed by the second column of den.fst (the destination state), so transitions_in[s] lists the arcs entering state s.
  • transitions: the contents of transitions_out followed by the contents of transitions_in, concatenated in state order.
  • forward_transitions: for each state, the begin and end offsets into transitions of its outgoing arcs (end minus begin is the number of outgoing arcs).
  • backward_transitions: the same for incoming arcs; these offsets start after all of the transitions_out entries.
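
The flattening done by SetTransitions can be mimicked on a toy graph. This is a rough sketch, not Kaldi code: `set_transitions` is a hypothetical Python function, the arcs are made-up numbers rather than values from den.fst, and each transition is represented as a plain tuple (transition_prob, pdf_id, hmm_state).

```python
def set_transitions(arcs, num_states):
    """arcs: list of (src, dst, pdf_id, prob) tuples.

    Returns (forward, backward, transitions), where forward[s] and
    backward[s] are (begin, end) offsets into the flat transitions list.
    """
    out_arcs = [[] for _ in range(num_states)]
    in_arcs = [[] for _ in range(num_states)]
    for src, dst, pdf_id, prob in arcs:
        out_arcs[src].append((prob, pdf_id, dst))  # hmm_state = destination
        in_arcs[dst].append((prob, pdf_id, src))   # hmm_state = source

    transitions, forward, backward = [], [], []
    for s in range(num_states):       # outgoing arcs first
        begin = len(transitions)
        transitions.extend(out_arcs[s])
        forward.append((begin, len(transitions)))
    for s in range(num_states):       # then incoming arcs
        begin = len(transitions)
        transitions.extend(in_arcs[s])
        backward.append((begin, len(transitions)))
    return forward, backward, transitions


# Toy 3-state graph with made-up probabilities.
arcs = [(0, 1, 0, 0.5), (0, 2, 1, 0.5), (1, 2, 0, 1.0), (2, 0, 1, 1.0)]
forward, backward, transitions = set_transitions(arcs, 3)
print(forward)
print(backward)
```

Every arc appears twice in the flat list, once in the outgoing half and once in the incoming half, so the list is twice as long as the number of arcs.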

Even when only the loss is needed, the forward-backward machinery is still used (chain/chain-training.cc: denominator.Forward()).

The $\alpha$ probabilities depend on both the acoustic scores (nnet_output, pdf count × frame count) and the LM scores, because the states are G-space states.

chain/chain-denominator.cc: assigning the $\alpha$ values

BaseFloat DenominatorComputation::Forward() {
  NVTX_RANGE(__func__);
  AlphaFirstFrame();  // initial weights
  AlphaDash(0);
  for (int32 t = 1; t <= frames_per_sequence_; t++) {
    AlphaGeneralFrame(t);
    AlphaDash(t);
  }
  return ComputeTotLogLike();
}

frames_per_sequence_ = 30
den_graph_.NumStates() = 3221
alpha_(frames_per_sequence_ + 1,den_graph_.NumStates() * num_sequences_ + num_sequences_,kUndefined),

alpha_: num_cols_ = 3222, num_rows_ = 31, stride_ = 3224
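
The reported dimensions follow directly from the constructor arguments of alpha_. The sketch below only reproduces that arithmetic; `alpha_shape` and `alpha_column` are hypothetical helper names, and the value num_sequences = 1 is an assumption inferred from the numbers above (it is the value that yields 3222 columns). The column index h * num_sequences + s is the one used in the alpha computation in chain/chain-denominator.cc.

```python
def alpha_shape(frames_per_sequence, num_hmm_states, num_sequences):
    """Rows and columns of alpha_, following its constructor arguments."""
    rows = frames_per_sequence + 1
    # One column per (state, sequence) pair, plus one extra column per
    # sequence that holds the alpha-sum used for rescaling.
    cols = num_hmm_states * num_sequences + num_sequences
    return rows, cols


def alpha_column(hmm_state, seq, num_sequences):
    """Column of the alpha value for a given state and sequence."""
    return hmm_state * num_sequences + seq


rows, cols = alpha_shape(30, 3221, 1)  # num_sequences = 1 is assumed
print(rows, cols)
print(alpha_column(3221, 0, 1))  # slot of the alpha-sum for sequence 0
```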

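alpha(t, i) corresponds to $\alpha$ in the forward-backward algorithm. $\alpha_t(i)$ is the forward probability: the probability of having observed the partial sequence $o_1, o_2, \ldots, o_t$ and being in state $i$ at time $t$.

beta(t, i) corresponds to $\beta$ in the forward-backward algorithm. $\beta_t(i)$ is the backward probability: the probability of the partial observation sequence $o_{t+1}, o_{t+2}, \ldots, o_T$ given state $i$ at time $t$.

gamma(t, i): $\gamma_t(i)$ is the probability of being in state $i$ at time $t$, given the model and the observation sequence: $\gamma_t(i)=\frac{\alpha_t(i)\beta_t(i)}{\sum\limits_{j=1}^N\alpha_t(j)\beta_t(j)}$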
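
These three definitions can be tried out on a tiny HMM. The following is a textbook-style sketch with made-up numbers, where emissions are attached to states; it is not the Kaldi implementation, in which the pdf sits on the arc and the alphas are rescaled every frame. The function name `forward_backward` and its arguments are chosen here for illustration.

```python
def forward_backward(init, trans, emit):
    """init[i], trans[i][j], emit[t][j] -> (alpha, beta, gamma)."""
    T, N = len(emit), len(init)
    alpha = [[0.0] * N for _ in range(T)]
    beta = [[1.0] * N for _ in range(T)]  # beta at the last frame is 1
    for i in range(N):
        alpha[0][i] = init[i] * emit[0][i]
    for t in range(1, T):
        for j in range(N):
            alpha[t][j] = emit[t][j] * sum(
                alpha[t - 1][i] * trans[i][j] for i in range(N))
    for t in range(T - 2, -1, -1):
        for i in range(N):
            beta[t][i] = sum(
                trans[i][j] * emit[t + 1][j] * beta[t + 1][j] for j in range(N))
    gamma = []
    for t in range(T):
        norm = sum(alpha[t][j] * beta[t][j] for j in range(N))
        gamma.append([alpha[t][i] * beta[t][i] / norm for i in range(N)])
    return alpha, beta, gamma


# Made-up 2-state, 3-frame example.
init = [0.6, 0.4]
trans = [[0.7, 0.3], [0.4, 0.6]]
emit = [[0.5, 0.1], [0.4, 0.3], [0.2, 0.9]]
alpha, beta, gamma = forward_backward(init, trans, emit)
total_likelihood = sum(alpha[-1])
print(total_likelihood)
print(gamma)
```

The denominator of $\gamma_t(i)$ is the same at every frame and equals the total likelihood of the observation sequence, which is why each row of gamma sums to one.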

chain/chain-datastruct.h:

The alpha computation below uses this triple: hmm_state identifies the state node to visit, pdf_id selects the emission probability, and transition_prob is the language-model probability.

struct DenominatorGraphTransition {
  BaseFloat transition_prob;  // language-model part of the probability (not
                              // in log)
  int32_cuda pdf_id;          // pdf-id on the transition.
  int32_cuda hmm_state;       // source, or destination, HMM state.
};

The transition probabilities in the forward-backward algorithm are the language-model probabilities.

  • Loss: the loss value is obtained with the forward-backward algorithm.

Denominator forward (alpha) computation: chain/chain-denominator.cc

int32 prob_stride = probs.Stride();  // yl: probs.Stride() is the number of frames
for (int32 h = 0; h < num_hmm_states; h++)
{
  for (int32 s = 0; s < num_sequences; s++)
  {
    double this_tot_alpha = 0.0;
    const DenominatorGraphTransition
        *trans_iter = transitions + backward_transitions[h].first,
        *trans_end = transitions + backward_transitions[h].second;
    for (; trans_iter != trans_end; ++trans_iter)
    {
      BaseFloat transition_prob = trans_iter->transition_prob;
      int32 pdf_id = trans_iter->pdf_id,
            prev_hmm_state = trans_iter->hmm_state;
      // prob_data is indexed by pdf. The matrix was stored earlier as
      // frames (rows) x pdfs (columns), so when reading the value of a pdf
      // for the current frame the stride is the number of frames: jumping
      // that many entries lands on the next pdf at the same time step,
      // which gives the emission probability for that time step.
      BaseFloat prob = prob_data[pdf_id * prob_stride + s],
                this_prev_alpha = prev_alpha_dash[prev_hmm_state * num_sequences + s];
      this_tot_alpha += this_prev_alpha * transition_prob * prob;
    }
    // Let arbitrary_scale be the inverse of the alpha-sum value that we
    // store in the same place we'd store the alpha for the state numbered
    // 'num_hmm_states'. We multiply this into all the
    // transition-probabilities from the previous frame to this frame, in
    // both the forward and backward passes, in order to keep the alphas in
    // a good numeric range. This won't affect the posteriors, but when
    // computing the total likelihood we'll need to compensate for it later
    // on.
    BaseFloat arbitrary_scale =
        1.0 / prev_alpha_dash[num_hmm_states * num_sequences + s];
    KALDI_ASSERT(this_tot_alpha - this_tot_alpha == 0);
    this_alpha[h * num_sequences + s] = this_tot_alpha * arbitrary_scale;
  }
}
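
The loop above can be restated in a few lines to make the indexing and the rescaling easier to follow. This is only a sketch under simplifying assumptions: `alpha_general_frame` is a hypothetical Python function, transitions are tuples (transition_prob, pdf_id, prev_hmm_state), `probs[pdf_id][s]` replaces the strided pointer arithmetic, the numbers are made up, and the alpha-sum slots of the new row are left at zero (in the original they are filled in by a separate step).

```python
def alpha_general_frame(prev_alpha_dash, probs, backward, transitions,
                        num_states, num_seq):
    """Compute one frame of scaled alphas from the previous frame.

    prev_alpha_dash has num_states * num_seq + num_seq entries; the last
    num_seq entries hold the per-sequence alpha-sums of the previous frame.
    """
    this_alpha = [0.0] * (num_states * num_seq + num_seq)
    for h in range(num_states):
        begin, end = backward[h]  # incoming arcs of state h
        for s in range(num_seq):
            tot = 0.0
            for trans_prob, pdf_id, prev_state in transitions[begin:end]:
                tot += (prev_alpha_dash[prev_state * num_seq + s]
                        * trans_prob * probs[pdf_id][s])
            # Divide by the previous alpha-sum to keep values in range.
            scale = 1.0 / prev_alpha_dash[num_states * num_seq + s]
            this_alpha[h * num_seq + s] = tot * scale
    return this_alpha


# Made-up 2-state graph; only the incoming-arc ranges are needed here.
transitions = [(0.6, 0, 0), (0.3, 0, 1), (0.4, 1, 0), (0.7, 1, 1)]
backward = [(0, 2), (2, 4)]
probs = [[0.2], [0.8]]             # probs[pdf_id][sequence]
prev_alpha_dash = [0.5, 0.5, 1.0]  # two states + alpha-sum slot
this_alpha = alpha_general_frame(prev_alpha_dash, probs, backward,
                                 transitions, num_states=2, num_seq=1)
print(this_alpha)
```

Doubling the stored alpha-sum halves every alpha of the next frame, which shows why the scale has to be compensated for when the total likelihood is computed.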

Numerator

Numerator loss computation: chain/chain-generic-numerator.cc

Because the pdf on each arc is already known, the triple makes it unnecessary to loop over all pdfs, but all G-space states still have to be visited.

// chain/chain-training.cc:
GenericNumeratorComputation numerator(opts.numerator_opts,
                                      supervision, nnet_output);
// jump to chain/chain-generic-numerator.cc
GenericNumeratorComputation::GenericNumeratorComputation(
    supervision_.e2e_fsts[i].NumStates()  // number of nodes (states) of the numerator lattice FST, i.e. how many states the graph expanded for this sequence has

// back in chain/chain-training.cc:
numerator_ok = numerator.ForwardBackward(&num_logprob_weighted,
                                         xent_output_deriv);
// jump to chain/chain-generic-numerator.cc
bool GenericNumeratorComputation::ForwardBackward(  // relies on out_transitions_ and in_transitions_
// still in chain/chain-generic-numerator.cc
// Forward part
AlphaFirstFrame(seq, &alpha[thread]);
partial_loglike_mt[thread] += AlphaRemainingFrames(seq, probs, &alpha[thread]);
// Backward part
BetaLastFrame(seq, alpha[thread], &beta[thread]);
BetaRemainingFrames(seq, probs, alpha[thread], &beta[thread], &derivs);

chain/chain-generic-numerator.cc:

  • out_transitions_: a member of the GenericNumeratorComputation class. transition_prob is the arc weight in the current lattice minus the maximum weight for that sequence, with the sign flipped (so it is a log-domain value).
  • in_transitions_: a member of the GenericNumeratorComputation class.
  • Computing $\alpha$: the probabilities of all lattice arcs entering a state are accumulated (each term also includes the state transition and the pdf emission).

for (int t = 1; t <= num_frames; ++t)
{
  const BaseFloat *probs_tm1 = probs.RowData(t - 1);
  BaseFloat *alpha_t = alpha->RowData(t);
  const BaseFloat *alpha_tm1 = alpha->RowData(t - 1);

  for (int32 h = 0; h < supervision_.e2e_fsts[seq].NumStates(); h++)
  {
    for (auto tr = in_transitions_[seq][h].begin();
         tr != in_transitions_[seq][h].end(); ++tr)
    {
      BaseFloat transition_prob = tr->transition_prob;
      int32 pdf_id = tr->pdf_id,
            prev_hmm_state = tr->hmm_state;
      BaseFloat prob = probs_tm1[pdf_id];
      // Everything is in the log domain, so products become sums.
      alpha_t[h] = LogAdd(alpha_t[h],
                          alpha_tm1[prev_hmm_state] + transition_prob + prob);
    }
  }
  // Renormalize this frame by the previous frame's log-sum, then store the
  // new log-sum in the last column.
  double sum = alpha_tm1[alpha->NumCols() - 1];
  SubMatrix<BaseFloat> alpha_t_mat(*alpha, t, 1, 0,
                                   alpha->NumCols() - 1);
  alpha_t_mat.Add(-sum);
  sum = alpha_t_mat.LogSumExp();

  alpha_t[alpha->NumCols() - 1] = sum;
  log_scale_product += sum;
}
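
The inner part of this loop is the same recursion as in the denominator, only carried out in the log domain. The sketch below illustrates that with made-up numbers; `log_add` and `alpha_step_log` are hypothetical Python helpers, and the per-frame renormalization (subtracting the previous log-sum and storing the new one) is deliberately left out.

```python
import math

NEG_INF = float("-inf")


def log_add(a, b):
    """Return log(exp(a) + exp(b)) without overflow."""
    if a == NEG_INF:
        return b
    if b == NEG_INF:
        return a
    hi, lo = max(a, b), min(a, b)
    return hi + math.log1p(math.exp(lo - hi))


def alpha_step_log(alpha_prev, log_probs_prev, in_transitions):
    """One frame of the log-domain alpha recursion.

    in_transitions[h] lists (log_transition_prob, pdf_id, prev_state) for
    every arc entering state h.
    """
    alpha_t = [NEG_INF] * len(alpha_prev)
    for h, arcs in enumerate(in_transitions):
        for log_trans, pdf_id, prev_state in arcs:
            alpha_t[h] = log_add(
                alpha_t[h],
                alpha_prev[prev_state] + log_trans + log_probs_prev[pdf_id])
    return alpha_t


# Same made-up numbers as a linear-domain 2-state example, in logs.
alpha_prev = [math.log(0.5), math.log(0.5)]
log_probs_prev = [math.log(0.2), math.log(0.8)]
in_transitions = [
    [(math.log(0.6), 0, 0), (math.log(0.3), 0, 1)],
    [(math.log(0.4), 1, 0), (math.log(0.7), 1, 1)],
]
alpha_t = alpha_step_log(alpha_prev, log_probs_prev, in_transitions)
print([math.exp(a) for a in alpha_t])
```

Exponentiating the result gives the same values as the product-and-sum form of the recursion (about 0.09 and 0.44 here), and a state with no incoming arcs keeps the value minus infinity.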