WeKws

https://github.com/wenet-e2e/wekws

Mining Effective Negative Training Samples for Keyword Spotting (github, paper)

Max-pooling Loss Training of Long Short-term Memory Networks for Small-footprint Keyword Spotting (paper)

A depthwise separable convolutional neural network for keyword spotting on an embedded system (github, paper)

Hello Edge: Keyword Spotting on Microcontrollers (github, paper)

An Empirical Evaluation of Generic Convolutional and Recurrent Networks for Sequence Modeling (github, paper)

Code structure walkthrough

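train_dataset = Dataset(args.train_data, train_conf)
# batch_size=None disables DataLoader's automatic batching: batches are
# already built inside the Dataset pipeline (processor.batch)
train_data_loader = DataLoader(train_dataset,
                               batch_size=None,
                               pin_memory=args.pin_memory,
                               num_workers=args.num_workers,
                               prefetch_factor=args.prefetch)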

The Dataset function in kws/dataset/dataset.py:

def Dataset(data_list_file, conf, partition=True):
    """ Construct dataset from arguments

        We have two shuffle stage in the Dataset. The first is global
        shuffle at shards tar/raw file level. The second is global shuffle
        at training samples level.

        Args:
            data_type(str): raw/shard
            partition(bool): whether to do data partition in terms of rank
    """
    lists = read_lists(data_list_file)
    shuffle = conf.get('shuffle', True)
    dataset = DataList(lists, shuffle=shuffle, partition=partition)
    dataset = Processor(dataset, processor.parse_raw)
    filter_conf = conf.get('filter_conf', {})
    dataset = Processor(dataset, processor.filter, **filter_conf)

    resample_conf = conf.get('resample_conf', {})
    dataset = Processor(dataset, processor.resample, **resample_conf)

    speed_perturb = conf.get('speed_perturb', False)
    if speed_perturb:
        dataset = Processor(dataset, processor.speed_perturb)
    feature_extraction_conf = conf.get('feature_extraction_conf', {})
    if feature_extraction_conf['feature_type'] == 'mfcc':
        dataset = Processor(dataset, processor.compute_mfcc,
                            **feature_extraction_conf)
    elif feature_extraction_conf['feature_type'] == 'fbank':
        dataset = Processor(dataset, processor.compute_fbank,
                            **feature_extraction_conf)
    spec_aug = conf.get('spec_aug', True)
    if spec_aug:
        spec_aug_conf = conf.get('spec_aug_conf', {})
        dataset = Processor(dataset, processor.spec_aug, **spec_aug_conf)

    if shuffle:
        shuffle_conf = conf.get('shuffle_conf', {})
        dataset = Processor(dataset, processor.shuffle, **shuffle_conf)

    batch_conf = conf.get('batch_conf', {})
    dataset = Processor(dataset, processor.batch, **batch_conf)
    dataset = Processor(dataset, processor.padding)
    return dataset

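Processing stages are appended to the dataset one after another, each optional stage guarded by an if on the config; the code is cleanly organized.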

examples/hi_xiaowen/s0/kws/dataset/processor.py: at first I found it hard to follow how execution jumps between these functions.

def shuffle(data, shuffle_size=1000):
    """ Local shuffle the data

        Args:
            data: Iterable[{key, feat, label}]
            shuffle_size: buffer size for shuffle

        Returns:
            Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= shuffle_size:
            random.shuffle(buf)
            for x in buf:
                yield x
            buf = []
    # The sample left over
    random.shuffle(buf)
    for x in buf:
        yield x


def batch(data, batch_size=16):
    """ Static batch the data by `batch_size`

        Args:
            data: Iterable[{key, feat, label}]
            batch_size: batch size

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= batch_size:
            yield buf
            buf = []
    if len(buf) > 0:
        yield buf
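The "jump" comes from lazy generator chaining. Below is a minimal stand-in, assuming Processor just stores the upstream iterable plus a generator function and calls that function on iter(source) inside __iter__ (the real class is built on torch's IterableDataset; this sketch drops torch so it runs standalone). shuffle and batch are condensed copies of the functions above.

```python
import random


class Processor:
    """Minimal stand-in for the Processor wrapper (assumption: the real one
    subclasses torch.utils.data.IterableDataset)."""

    def __init__(self, source, f, *args, **kw):
        self.source = source  # upstream stage (any iterable)
        self.f = f            # generator function: (iterator, ...) -> iterator
        self.args = args
        self.kw = kw

    def __iter__(self):
        # Nothing runs until iteration starts. Each next() on the returned
        # generator runs self.f until it needs input, which pulls from the
        # upstream stage: that is the "jump" seen in a debugger.
        return self.f(iter(self.source), *self.args, **self.kw)


def shuffle(data, shuffle_size=1000):
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= shuffle_size:
            random.shuffle(buf)
            yield from buf
            buf = []
    random.shuffle(buf)
    yield from buf


def batch(data, batch_size=16):
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= batch_size:
            yield buf
            buf = []
    if buf:
        yield buf


samples = [{'key': f'utt{i}', 'label': i % 2} for i in range(10)]
dataset = Processor(samples, shuffle, shuffle_size=4)
dataset = Processor(dataset, batch, batch_size=3)
batches = list(dataset)
print([len(b) for b in batches])  # [3, 3, 3, 1]
```

Building the Processor objects executes nothing. list(dataset) calls batch, whose for loop calls next() on the shuffle generator, which in turn pulls from the sample list, so control moves back and forth between the stages one sample at a time.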

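Max-pooling loss

The max-pooling loss works as follows: for a positive (keyword) sample, it takes the frame with the highest keyword probability and pushes that probability as high as possible; for a negative sample, it takes the frame with the lowest filler (non-keyword) probability and pushes that probability as high as possible.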

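  • Binary case:

    When target = filler: $loss=\min\limits_T(1-P_{keyword})$ (min pooling)

    When target = keyword: $loss=\max\limits_T P_{keyword}$ (max pooling)

  • Three-class case (two keywords):

    When target = filler: $loss=\min\limits_T(1-P_{keyword1})+\min\limits_T(1-P_{keyword2})$ (min pooling)

    When target = keyword1: $loss=\max\limits_T P_{keyword1}+\min\limits_T(1-P_{keyword2})$ (max pooling)

    When target = keyword2: $loss=\max\limits_T P_{keyword2}+\min\limits_T(1-P_{keyword1})$ (max pooling)

  • Objective: $\max_W loss$, where $W$ are the model weights.

  • A misconception I had at first: $\min\limits_T(1-P_{keyword})$ is not equivalent to $\max\limits_T(P_{keyword}-1)$; instead, $\min\limits_T(1-P_{keyword})=-\max\limits_T(P_{keyword}-1)$.

  • Sign convention: the code negates the pooled probability (it minimizes $-\log$ of it); here I explain it without the negation, i.e. as a quantity to maximize. For target = keyword we want $P_{keyword}$ to be large. Because the keyword is triggered as soon as any single frame exceeds the threshold, max pooling picks the frame with the highest keyword probability. At the same time, on non-keyword (filler) utterances the keyword probability should be small even on the hardest frame. The hardest frame is the one selected by min pooling over $1-P_{keyword}$, so $\min\limits_T(1-P_{keyword})$ is made as large as possible; as training proceeds this value increases, meaning the keyword probability on the hardest non-keyword frame keeps decreasing.

  • [TODO] Once training is stable, try focal loss? To push the probability of all keyword frames up, try making $\min\limits_T P_{keyword}$ as large as possible.

    At inference time, only whether a keyword frame exceeds the threshold matters (the device wakes up as soon as any single frame exceeds it).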
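To make the binary case concrete, here is a small worked example in plain Python (no torch) for a single utterance. The posteriors are made-up toy values, and the helper name max_pooling_terms is hypothetical; the clamp to [1e-8, 1] mirrors the torch.clamp call in the code below.

```python
import math


def max_pooling_terms(p_keyword, is_keyword, eps=1e-8):
    """Binary case for one utterance (hypothetical helper).

    p_keyword: per-frame keyword posteriors (sigmoid outputs).
    Returns (selected_frame, pooled_value, code_loss)."""
    if is_keyword:
        # Max pooling: the frame with the highest keyword posterior
        t = max(range(len(p_keyword)), key=lambda i: p_keyword[i])
        pooled = p_keyword[t]
    else:
        # Min pooling over 1 - P_keyword: the hardest filler frame
        filler = [1.0 - p for p in p_keyword]
        t = min(range(len(filler)), key=lambda i: filler[i])
        pooled = filler[t]
    # The code minimizes -log(pooled), clamped to [eps, 1]
    code_loss = -math.log(min(max(pooled, eps), 1.0))
    return t, pooled, code_loss


kw = [0.1, 0.3, 0.9, 0.6, 0.2]   # keyword utterance: frame 2 is selected
fl = [0.05, 0.4, 0.1, 0.2, 0.3]  # filler utterance: frame 1 is the hardest
print(max_pooling_terms(kw, True))
print(max_pooling_terms(fl, False))

# The identity from above: min_T(1 - P) == -max_T(P - 1); both print the same value
print(min(1 - p for p in fl), -max(p - 1 for p in fl))
```

Only the selected frame receives a gradient in each case, which is why a filler utterance is penalized only through its single most keyword-like frame.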

import torch

# padding_mask is defined elsewhere in WeKws (not shown here); it returns
# True at the padded frames of each utterance.


def max_pooling_loss(logits: torch.Tensor,
                     target: torch.Tensor,
                     lengths: torch.Tensor,
                     min_duration: int = 0):
    ''' Max-pooling loss
        For keyword, select the frame with the highest posterior.
            The keyword is triggered when any of the frames is triggered.
        For none keyword, select the hardest frame, namely the frame
            with lowest filler posterior(highest keyword posterior).
            the keyword is not triggered when all frames are not triggered.

    Attributes:
        logits: (B, T, D), D is the number of keywords
        target: (B)
        lengths: (B)
        min_duration: min duration of the keyword
    Returns:
        (float): loss of current batch
        (float): accuracy of current batch
    '''
    mask = padding_mask(lengths)
    num_utts = logits.size(0)
    num_keywords = logits.size(2)

    target = target.cpu()
    loss = 0.0
    for i in range(num_utts):
        for j in range(num_keywords):
            # Add entropy loss CE = -(t * log(p) + (1 - t) * log(1 - p))
            if target[i] == j:
                # For the keyword, do max-pooling
                prob = logits[i, :, j]
                m = mask[i].clone().detach()
                m[:min_duration] = True
                prob = prob.masked_fill(m, 0.0)
                prob = torch.clamp(prob, 1e-8, 1.0)
                max_prob = prob.max()
                loss += -torch.log(max_prob)
            else:
                # For other keywords or filler, do min-pooling
                prob = 1 - logits[i, :, j]
                prob = prob.masked_fill(mask[i], 1.0)
                prob = torch.clamp(prob, 1e-8, 1.0)
                min_prob = prob.min()
                loss += -torch.log(min_prob)
    loss = loss / num_utts
    # (the accuracy computation and return statement are not shown in this excerpt)

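  • kws/bin/average_model.py: sums the parameters of the last N saved models and averages them.
  • kws/bin/score.py: computes the acoustic model outputs and saves them to a file.
  • kws/bin/compute_det.py: computes FRR/FAR. For a given class, an utterance counts as a wake-up if its score exceeds the threshold and as no wake-up otherwise. The classes are not compared against each other to decide which one fired, because the output has no filler class, only keyword classes.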
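A minimal sketch of how one DET point could be computed for a single keyword class, assuming each utterance is summarized by one score (for example its maximum frame posterior for that keyword) and that FAR is a plain fraction of negative utterances. The actual compute_det.py may normalize false alarms differently (for example per hour of negative audio); det_point and the toy data are hypothetical.

```python
def det_point(scores, labels, threshold):
    """One (FRR, FAR) point for one keyword class (hypothetical helper).

    scores: one score per utterance for this keyword.
    labels: True if the utterance contains the keyword.
    Wake-up rule from the notes: score > threshold."""
    positives = [s for s, y in zip(scores, labels) if y]
    negatives = [s for s, y in zip(scores, labels) if not y]
    # False reject: a keyword utterance that does not wake the device
    frr = sum(s <= threshold for s in positives) / max(len(positives), 1)
    # False alarm: a non-keyword utterance that does wake the device
    far = sum(s > threshold for s in negatives) / max(len(negatives), 1)
    return frr, far


scores = [0.95, 0.80, 0.40, 0.10, 0.30, 0.85]
labels = [True, True, True, False, False, False]
for th in (0.2, 0.5, 0.9):
    print(th, det_point(scores, labels, th))
```

Sweeping the threshold trades FRR against FAR: raising it rejects more keyword utterances but fires on fewer negatives, which is the curve a DET plot draws.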

Replacing the hi_xiaowen dataset with your own

...

Generating data.list:

I generated it directly from the existing Kaldi-style files with shell commands:

# paste joins the files line by line, so feats.scp, text and utt2dur
# must list the same utterances in the same order.
awk '{print"{\"key\": \""$1"\","}' feats.scp > 1
awk '{print"\"txt\": "$2","}' text > 2
awk '{print"\"duration\": "$2","}' utt2dur > 3
awk '{print"\"wav\": \""$2"\"}"}' feats.scp > 4
paste -d ' ' 1 2 3 4 > data.list
rm 1 2 3 4

awk '{print"{\"key\": \""$1"\","}' feats_offline_cmvn.scp > 1
awk '{print"\"txt\": "$2","}' text > 2
awk '{print"\"duration\": "$2","}' utt2dur > 3
awk '{print"\"wav\": \""$2"\"}"}' feats_offline_cmvn.scp > 4
paste -d ' ' 1 2 3 4 > data_after_offline_cmvn.list
rm 1 2 3 4
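The awk/paste approach relies on all files sharing the same line order. A sketch of the same conversion in Python that joins on the utterance key instead and lets json handle the quoting, assuming (as the unquoted txt field above implies) that text holds an integer label per utterance; read_kv and write_data_list are hypothetical helpers.

```python
import json


def read_kv(path):
    """Read a Kaldi-style 'key value' file into a dict (value = rest of line)."""
    table = {}
    with open(path, encoding='utf-8') as f:
        for line in f:
            parts = line.strip().split(maxsplit=1)
            if len(parts) == 2:
                table[parts[0]] = parts[1]
    return table


def write_data_list(scp, text, utt2dur, out_path):
    """Write one JSON object per line; returns the number of entries written."""
    feats = read_kv(scp)
    labels = read_kv(text)
    durs = read_kv(utt2dur)
    n = 0
    with open(out_path, 'w', encoding='utf-8') as out:
        for key, path in feats.items():
            # Join on the utterance key; skip entries missing from text or utt2dur
            if key not in labels or key not in durs:
                continue
            entry = {'key': key, 'txt': int(labels[key]),  # assumes integer labels
                     'duration': float(durs[key]), 'wav': path}
            out.write(json.dumps(entry, ensure_ascii=False) + '\n')
            n += 1
    return n
```

Usage: write_data_list('feats.scp', 'text', 'utt2dur', 'data.list'), or pass feats_offline_cmvn.scp for the second list.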

Instead of computing the global CMVN with the repo's code, I used Kaldi's tools:

matrix-sum --binary=false scp:data/train_p400h_n4000h/cmvn.scp - > data/train_p400h_n4000h/global_cmvn.stats
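For reference, the summed file is a Kaldi CMVN stats matrix of shape 2 x (dim + 1): row 0 holds the per-dimension sums with the frame count in the last column, and row 1 holds the per-dimension sums of squares. A sketch of turning the text-format file into a mean and inverse standard deviation, assuming that standard layout; the helper names are hypothetical and this is not the loader WeKws itself uses.

```python
import math


def load_kaldi_cmvn_text(text):
    """Parse a text-format Kaldi CMVN stats matrix (2 x (dim + 1))."""
    rows = []
    for line in text.replace('[', ' ').replace(']', ' ').splitlines():
        values = [float(v) for v in line.split()]
        if values:
            rows.append(values)
    if len(rows) != 2:
        raise ValueError('expected a 2-row CMVN stats matrix')
    return rows


def cmvn_mean_istd(rows, floor=1e-20):
    sums, sq_sums = rows
    count = sums[-1]  # last column of row 0 is the frame count
    mean = [s / count for s in sums[:-1]]
    # var = E[x^2] - E[x]^2, floored to avoid division by zero
    var = [max(q / count - m * m, floor) for q, m in zip(sq_sums[:-1], mean)]
    istd = [1.0 / math.sqrt(v) for v in var]
    return mean, istd


# Two frames of 2-dim features, [1, 2] and [3, 6]:
# sums = [4, 8], count = 2, sums of squares = [10, 40]
stats_text = " [\n  4 8 2\n  10 40 0 ]\n"
mean, istd = cmvn_mean_istd(load_kaldi_cmvn_text(stats_text))
print(mean, istd)  # [2.0, 4.0] [1.0, 0.5]
```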

Training

num_worker=1

Checkpoint saving interval

Previously one model was saved per epoch; I changed it to save one model every 1000 iterations.
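A sketch of the change, with hypothetical names (train, make_batches, train_step, save_checkpoint) standing in for the real training script: the key point is a global step counter that is not reset at epoch boundaries.

```python
def train(num_epochs, make_batches, train_step, save_checkpoint, save_interval=1000):
    """Save a checkpoint every `save_interval` iterations instead of per epoch."""
    global_step = 0  # counts iterations across epochs; never reset per epoch
    for epoch in range(num_epochs):
        for batch in make_batches(epoch):
            train_step(batch)
            global_step += 1
            if global_step % save_interval == 0:
                # In real code this would e.g. torch.save(model.state_dict(), path)
                save_checkpoint(f'step_{global_step}.pt')
    return global_step


saved = []
steps = train(3, lambda epoch: range(700), lambda batch: None, saved.append)
print(steps, saved)  # 2100 ['step_1000.pt', 'step_2000.pt']
```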

mdtc_small:

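dilation: increases from the lower layers to the higher layers: 1, 2, 4, 8.

self.receptive_fields is the receptive field of a block (not the kernel size) and depends on the dilation. It is used for zero-padding the convolution input: it records how many padding frames are needed in total.

If causal=True, the convolution must not use frames after the current frame, so all the padding goes on the past side, enough to make the computation possible.

If causal=False, the padding is split between past and future frames, $\left\lfloor\frac{\text{len(pad)}}{2}\right\rfloor$ on each side.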

self.receptive_fields = dilation * (kernel_size - 1)
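A worked check of the formula, assuming the model-level receptive field is simply the sum of each block's receptive_fields: one preprocessor block with dilation 1, then stack_num = 3 stacks with dilations 1, 2, 4, 8, all with kernel_size 5 (as in the printed model structure). The arithmetic reproduces the "Receptive Fields: 184" line in the log.

```python
def block_receptive_field(kernel_size, dilation):
    # Same formula as self.receptive_fields
    return dilation * (kernel_size - 1)


kernel_size = 5
dilations = [1, 2, 4, 8]  # one TCN stack of mdtc_small
stack_num = 3

preprocessor_rf = block_receptive_field(kernel_size, 1)
stack_rf = sum(block_receptive_field(kernel_size, d) for d in dilations)
total_rf = preprocessor_rf + stack_num * stack_rf
print(preprocessor_rf, stack_rf, total_rf)  # 4 60 184

# Padding split for one block, e.g. dilation 8 -> receptive_fields = 32
rf = block_receptive_field(kernel_size, 8)
causal_pad = (rf, 0)                 # (past, future): all padding on the past side
non_causal_pad = (rf // 2, rf // 2)  # split evenly between both sides
print(causal_pad, non_causal_pad)    # (32, 0) (16, 16)
```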

Are the zero-padded frames excluded from the computation? The input is cropped like this:

if self.causal:
    inputs = inputs[:, :, self.receptive_fields:]
else:
    inputs = inputs[:, :, self.half_receptive_fields:-self.half_receptive_fields]
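One way to read the crop: a convolution run without padding ("valid" mode) returns receptive_fields fewer frames than its input, so the residual branch must drop the same number of frames before the two are added. The sketch below assumes that reading (padding applied once up front for the whole network, each block convolving without padding); dilated_conv1d_valid is a hypothetical single-channel helper in plain Python.

```python
def dilated_conv1d_valid(x, weights, dilation):
    """Single-channel dilated 1-D convolution without padding ('valid' mode)."""
    rf = dilation * (len(weights) - 1)
    return [sum(w * x[t + k * dilation] for k, w in enumerate(weights))
            for t in range(len(x) - rf)]


x = list(range(20))                   # T = 20 input frames
dilation = 2
last = [0.0, 0.0, 0.0, 0.0, 1.0]      # picks the newest frame in each window
rf = dilation * (len(last) - 1)       # receptive_fields = 8

# Causal branch: crop the first rf frames, like inputs[:, :, self.receptive_fields:]
y = dilated_conv1d_valid(x, last, dilation)
residual_causal = x[rf:]
print(len(y), y[:3], residual_causal[:3])  # 12 [8.0, 9.0, 10.0] [8, 9, 10]

# Non-causal branch: crop rf // 2 on each side; a centred kernel then lines up
center = [0.0, 0.0, 1.0, 0.0, 0.0]
half = rf // 2
y_center = dilated_conv1d_valid(x, center, dilation)
residual_non_causal = x[half:-half]
print(y_center == [float(v) for v in residual_non_causal])  # True
```

With these indicator kernels, each output frame equals the input frame it is aligned to, which shows that the cropped residual and the convolution output refer to the same time steps.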

[TODO] Could the zero padding be replaced by replicating frames?

The features produced by the preprocessor pass through each of the stack_num TCN stacks, yielding stack_num outputs.

Structure:

the number of model params: 33201
Receptive Fields: 184
KWSModel(
(global_cmvn): GlobalCMVN()
(preprocessing): NoSubsampling()
(backbone): MDTC(
(preprocessor): TCNBlock(
(conv1): DSDilatedConv1d(
(conv): Conv1d(50, 50, kernel_size=(5,), stride=(1,), groups=50)
(bn): BatchNorm1d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(pointwise): Conv1d(50, 32, kernel_size=(1,), stride=(1,))
)
(bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu1): ReLU()
(conv2): Conv1d(32, 32, kernel_size=(1,), stride=(1,))
(bn2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu2): ReLU()
)
(relu): ReLU()
(blocks): ModuleList(
(0): TCNStack(
(res_blocks): Sequential(
(0): TCNBlock(
(conv1): DSDilatedConv1d(
(conv): Conv1d(32, 32, kernel_size=(5,), stride=(1,), groups=32)
(bn): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(pointwise): Conv1d(32, 32, kernel_size=(1,), stride=(1,))
)
(bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu1): ReLU()
(conv2): Conv1d(32, 32, kernel_size=(1,), stride=(1,))
(bn2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu2): ReLU()
)
(1): TCNBlock(
(conv1): DSDilatedConv1d(
(conv): Conv1d(32, 32, kernel_size=(5,), stride=(1,), dilation=(2,), groups=32)
(bn): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(pointwise): Conv1d(32, 32, kernel_size=(1,), stride=(1,))
)
(bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu1): ReLU()
(conv2): Conv1d(32, 32, kernel_size=(1,), stride=(1,))
(bn2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu2): ReLU()
)
(2): TCNBlock(
(conv1): DSDilatedConv1d(
(conv): Conv1d(32, 32, kernel_size=(5,), stride=(1,), dilation=(4,), groups=32)
(bn): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(pointwise): Conv1d(32, 32, kernel_size=(1,), stride=(1,))
)
(bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu1): ReLU()
(conv2): Conv1d(32, 32, kernel_size=(1,), stride=(1,))
(bn2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu2): ReLU()
)
(3): TCNBlock(
(conv1): DSDilatedConv1d(
(conv): Conv1d(32, 32, kernel_size=(5,), stride=(1,), dilation=(8,), groups=32)
(bn): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(pointwise): Conv1d(32, 32, kernel_size=(1,), stride=(1,))
)
(bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu1): ReLU()
(conv2): Conv1d(32, 32, kernel_size=(1,), stride=(1,))
(bn2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu2): ReLU()
)
)
)
(1): TCNStack(
(res_blocks): Sequential(
(0): TCNBlock(
(conv1): DSDilatedConv1d(
(conv): Conv1d(32, 32, kernel_size=(5,), stride=(1,), groups=32)
(bn): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(pointwise): Conv1d(32, 32, kernel_size=(1,), stride=(1,))
)
(bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu1): ReLU()
(conv2): Conv1d(32, 32, kernel_size=(1,), stride=(1,))
(bn2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu2): ReLU()
)
(1): TCNBlock(
(conv1): DSDilatedConv1d(
(conv): Conv1d(32, 32, kernel_size=(5,), stride=(1,), dilation=(2,), groups=32)
(bn): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(pointwise): Conv1d(32, 32, kernel_size=(1,), stride=(1,))
)
(bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu1): ReLU()
(conv2): Conv1d(32, 32, kernel_size=(1,), stride=(1,))
(bn2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu2): ReLU()
)
(2): TCNBlock(
(conv1): DSDilatedConv1d(
(conv): Conv1d(32, 32, kernel_size=(5,), stride=(1,), dilation=(4,), groups=32)
(bn): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(pointwise): Conv1d(32, 32, kernel_size=(1,), stride=(1,))
)
(bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu1): ReLU()
(conv2): Conv1d(32, 32, kernel_size=(1,), stride=(1,))
(bn2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu2): ReLU()
)
(3): TCNBlock(
(conv1): DSDilatedConv1d(
(conv): Conv1d(32, 32, kernel_size=(5,), stride=(1,), dilation=(8,), groups=32)
(bn): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(pointwise): Conv1d(32, 32, kernel_size=(1,), stride=(1,))
)
(bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu1): ReLU()
(conv2): Conv1d(32, 32, kernel_size=(1,), stride=(1,))
(bn2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu2): ReLU()
)
)
)
(2): TCNStack(
(res_blocks): Sequential(
(0): TCNBlock(
(conv1): DSDilatedConv1d(
(conv): Conv1d(32, 32, kernel_size=(5,), stride=(1,), groups=32)
(bn): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(pointwise): Conv1d(32, 32, kernel_size=(1,), stride=(1,))
)
(bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu1): ReLU()
(conv2): Conv1d(32, 32, kernel_size=(1,), stride=(1,))
(bn2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu2): ReLU()
)
(1): TCNBlock(
(conv1): DSDilatedConv1d(
(conv): Conv1d(32, 32, kernel_size=(5,), stride=(1,), dilation=(2,), groups=32)
(bn): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(pointwise): Conv1d(32, 32, kernel_size=(1,), stride=(1,))
)
(bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu1): ReLU()
(conv2): Conv1d(32, 32, kernel_size=(1,), stride=(1,))
(bn2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu2): ReLU()
)
(2): TCNBlock(
(conv1): DSDilatedConv1d(
(conv): Conv1d(32, 32, kernel_size=(5,), stride=(1,), dilation=(4,), groups=32)
(bn): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(pointwise): Conv1d(32, 32, kernel_size=(1,), stride=(1,))
)
(bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu1): ReLU()
(conv2): Conv1d(32, 32, kernel_size=(1,), stride=(1,))
(bn2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu2): ReLU()
)
(3): TCNBlock(
(conv1): DSDilatedConv1d(
(conv): Conv1d(32, 32, kernel_size=(5,), stride=(1,), dilation=(8,), groups=32)
(bn): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(pointwise): Conv1d(32, 32, kernel_size=(1,), stride=(1,))
)
(bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu1): ReLU()
(conv2): Conv1d(32, 32, kernel_size=(1,), stride=(1,))
(bn2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu2): ReLU()
)
)
)
)
)
(classifier): LinearClassifier(
(linear): Linear(in_features=32, out_features=1, bias=True)
(quant): QuantStub()
(dequant): DeQuantStub()
)
(activation): Sigmoid()
)
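The "number of model params: 33201" line can be reproduced from the printed structure, assuming every Conv1d has a bias (the printout never shows bias=False), each BatchNorm1d contributes only its affine weight and bias (running statistics are buffers), GlobalCMVN stores its statistics as buffers, and QuantStub/DeQuantStub have no parameters.

```python
def conv1d_params(in_ch, out_ch, kernel_size, groups=1, bias=True):
    return out_ch * (in_ch // groups) * kernel_size + (out_ch if bias else 0)


def bn_params(ch):
    return 2 * ch  # affine weight + bias; running_mean/var are buffers


def tcn_block_params(in_ch, res_ch, kernel_size=5):
    return (conv1d_params(in_ch, in_ch, kernel_size, groups=in_ch)  # depthwise conv
            + bn_params(in_ch)                                     # DSDilatedConv1d.bn
            + conv1d_params(in_ch, res_ch, 1)                      # pointwise
            + bn_params(res_ch)                                    # bn1
            + conv1d_params(res_ch, res_ch, 1)                     # conv2
            + bn_params(res_ch))                                   # bn2


preprocessor = tcn_block_params(50, 32)
stacks = 3 * 4 * tcn_block_params(32, 32)  # 3 TCNStacks x 4 TCNBlocks
classifier = 32 * 1 + 1                     # Linear(32, 1) with bias
total = preprocessor + stacks + classifier
print(preprocessor, tcn_block_params(32, 32), total)  # 3216 2496 33201
```

The two 1x1 convolutions dominate each block (1056 parameters each), while the depthwise dilated convolution costs only 192, which is what keeps the model this small.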

Parameter figures:

image-20220112174511263

image-20220112174541188

Code modifications

Print the training learning rate in the log as well

In kws/utils/executor.py, change the logging code to:

logging.debug(
    'TRAIN Batch {}/{} loss {:.8f} acc {:.8f} lr {:.8f}'.format(
        epoch, batch_idx, loss.item(), acc, optimizer.param_groups[0]['lr']))
# logging.debug(
#     'TRAIN Batch {}/{} loss {:.8f} acc {:.8f}'.format(
#         epoch, batch_idx, loss.item(), acc))

Printing/visualizing the model structure

Method 1: TensorBoard (the visualization is not very clear)

In kws/model/kws_model.py, add the following code:

kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
                     preprocessing, backbone, classifier, activation)
# Added:
from tensorboardX import SummaryWriter
dummy_input = torch.rand(128, 100, 40)  # (batch, frames, feature_dim)
with SummaryWriter(comment='KWSModel') as w:
    w.add_graph(kws_model, (dummy_input,))

Then open TensorBoard on the runs/ directory (tensorboard --logdir runs):

image-20220111101914418

Nodes can be expanded to show details, but the overall visualization is weak: TensorBoard is not very friendly for viewing model structure.

Method 2: torchviz

In kws/model/kws_model.py, add the following code:

from torchviz import make_dot

kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
                     preprocessing, backbone, classifier, activation)
# Added:
x = torch.rand(128, 100, 40)  # (batch, frames, feature_dim)
y = kws_model(x)
# g = make_dot(y)
g = make_dot(y, params=dict(list(kws_model.named_parameters()) + [('x', x)]))
g.view()
g.render('espnet_model', view=False)

image-20220111105457343

Method 3: tensorwatch

It supports only a limited range of network types.

In kws/model/kws_model.py, add the following code:

import tensorwatch as tw

# Found while debugging: modify KWSModel.forward as follows
def forward(self, x: torch.Tensor) -> torch.Tensor:
    if self.global_cmvn is not None:
        x = self.global_cmvn(x)
    x = self.preprocessing(x)
    # Added: unsqueeze to insert a size-1 dimension at dim 0
    x = torch.unsqueeze(x, dim=0)
    x, _ = self.backbone(x)
    # Added: squeeze to remove that size-1 dimension again
    x = torch.squeeze(x, dim=0)
    x = self.classifier(x)
    x = self.activation(x)
    return x


kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
                     preprocessing, backbone, classifier, activation)

# Added:
# Inspect parameters, compute cost, FLOPs, etc.:
# module name, input shape, output shape, parameters, infer memory, MAdd, Flops, MemRead, MemWrite
a = tw.model_stats(kws_model, [100, 50])
print(a)
# Save the network structure diagram
img = tw.draw_model(kws_model, [100, 50])  # Buggy: fails because the model contains quantization stubs
img.save('./ds_tcn.jpg')
  • ds_tcn: tensorwatch.model_stats output for batch size 1, 1 s of audio and 40-dim features

image-20220111141741028

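Method 4: fvcore

Reference: CSDN article 详解Transformer中Self-Attention以及Multi-Head Attention (a detailed explanation of Self-Attention and Multi-Head Attention in Transformers)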

from fvcore.nn import FlopCountAnalysis

# mdtc is the model and (x, lengths) its inputs, defined elsewhere
flops = FlopCountAnalysis(mdtc, (x, lengths))
print("MDTC FLOPs:", flops.total())  # fvcore counts one fused multiply-add as one FLOP