Triplet loss keyword spotting 代码

以 Google Speech Commands 数据集为例

Vygon, Roman, and Nikolay Mikhaylovskiy. “Learning efficient representations for keyword spotting with triplet loss.” International Conference on Speech and Computer. Springer, Cham, 2021. (citations: 12)

github: Learning Efficient Representations for Keyword Spotting with Triplet Loss

GitHub (NeMo Speech Commands 示例 notebook): https://github.com/NVIDIA/NeMo/blob/v0.10.1/examples/asr/notebooks/3_Speech_Commands_using_NeMo.ipynb

https://www.codeleading.com/article/61624664033/

https://bindog.github.io/blog/2019/10/23/why-triplet-loss-works/

GitHub 人脸检测 (DFace): https://github.com/kuaikuaikim/DFace

facenet:https://github.com/davidsandberg/facenet

代码文件：/home/data/yelong/triplet_loss_kws/loss/utils.py
def RandomNegativeTripletSelector(margin, cpu=False):
    """Build a triplet selector that picks a random hard negative per anchor-positive pair.

    Thin factory around ``FunctionNegativeTripletSelector`` with
    ``negative_selection_fn=random_hard_negative``.

    Args:
        margin: triplet margin; should match the margin of the triplet loss.
        cpu: forwarded to the selector (note: defaults to False here).

    Returns:
        A configured ``FunctionNegativeTripletSelector``.
    """
    selector = FunctionNegativeTripletSelector(
        margin=margin,
        negative_selection_fn=random_hard_negative,
        cpu=cpu,
    )
    return selector
class FunctionNegativeTripletSelector(TripletSelector):
    """
    For each anchor-positive pair, select one negative sample with ``negative_selection_fn`` to create a triplet.
    Margin should match the margin used in triplet loss.
    negative_selection_fn takes the array of triplet-loss values for a given anchor-positive pair against all
    negative samples and returns the index of the chosen negative for that pair.
    """
    # Excerpt only: the complete definition of this class is given further down in this file.

寻找 triplet 三元组：

方法：先找出 batch 内所有同类样本组成的 (a, p) 对；对每个 (a, p)，再按某种策略从不同类的样本中选出负样本 n。这里使用随机策略：在所有满足 d(a,p) - d(a,n) + margin > 0 的负样本中随机选一个，作为三元组 (a, p, n) 中 n 的 index（d 为平方欧氏距离）。

import torch
import numpy as np
from itertools import combinations
def pdist(vectors):
    """Return the matrix of pairwise squared Euclidean distances between the rows of ``vectors``.

    Uses the expansion ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2 * x_i . x_j,
    which needs only one matrix product. Floating-point cancellation can make
    that expression slightly negative for identical or nearly identical rows,
    so the result is clamped at zero.

    Args:
        vectors: 2-D tensor of shape [N, D] (``mm`` requires a matrix).

    Returns:
        Tensor of shape [N, N]; entry (i, j) is the squared distance between rows i and j.
    """
    squared_norms = vectors.pow(2).sum(dim=1)  # computed once, used for both broadcasts
    distance_matrix = (
        -2 * vectors.mm(torch.t(vectors))
        + squared_norms.view(1, -1)
        + squared_norms.view(-1, 1)
    )
    # Squared distances are never negative; remove cancellation noise.
    return distance_matrix.clamp(min=0)

class TripletSelector:
    """Base interface for online triplet mining.

    A concrete selector implements ``get_triplets(embeddings, labels)`` and
    returns the (anchor, positive, negative) sample indices as an array of
    shape [N_triplets x 3].
    """

    def __init__(self):
        """No state to initialise in the base interface."""

    def get_triplets(self, embeddings, labels):
        """Return an [N_triplets x 3] array of (anchor, positive, negative) indices; subclasses must override."""
        raise NotImplementedError

def random_hard_negative(loss_values):
    """Pick a random negative whose triplet-loss value is positive.

    Args:
        loss_values: numpy array with one triplet-loss value per negative candidate.

    Returns:
        The index (into ``loss_values``) of a uniformly chosen candidate whose
        loss is > 0, or ``None`` when no candidate violates the margin.
    """
    candidates = np.nonzero(loss_values > 0)[0]
    if len(candidates) == 0:
        return None
    return np.random.choice(candidates)


class FunctionNegativeTripletSelector(TripletSelector):
    """Build (anchor, positive, negative) index triplets by online mining within a batch.

    Every pair of same-label samples is used as an (anchor, positive) pair. For
    each pair, the triplet-loss value ``d(a, p) - d(a, n) + margin`` is computed
    against every sample ``n`` with a different label (``d`` is the squared
    Euclidean distance from ``pdist``). ``negative_selection_fn`` then picks the
    negative from that vector; for example, ``random_hard_negative`` picks a
    random negative with a positive loss value.

    Margin should match the margin used in the triplet loss itself.
    ``negative_selection_fn`` takes a 1-D numpy array of loss values (one per
    negative candidate). It returns an index into that array, or ``None`` to
    skip the pair.
    """

    def __init__(self, margin, negative_selection_fn, cpu=True):
        super().__init__()
        self.cpu = cpu  # if True, move embeddings to CPU before computing distances
        self.margin = margin
        self.negative_selection_fn = negative_selection_fn

    def get_triplets(self, embeddings, labels):
        """Return a LongTensor of shape [N_triplets, 3] of (anchor, positive, negative) indices.

        If ``negative_selection_fn`` rejects every pair, a single fallback
        triplet is returned: the last anchor-positive pair of the last class
        processed, with the first negative of that class. This way the loss
        still has something to average.

        If the batch admits no triplet at all, an empty ``[0, 3]`` tensor is
        returned. This happens when no label occurs twice, or when every sample
        shares one label.
        """
        if self.cpu:
            embeddings = embeddings.cpu()
        # Triplet selection is pure index bookkeeping: no gradient is needed here.
        distance_matrix = pdist(embeddings).detach().cpu().numpy()

        labels = labels.cpu().data.numpy()
        triplets = []
        fallback = None  # last (anchor, positive, first negative) seen; used only if nothing is mined

        for label in set(labels):
            label_mask = labels == label
            label_indices = np.where(label_mask)[0]
            if len(label_indices) < 2:
                continue  # a class needs at least two samples to form a positive pair
            negative_indices = np.where(np.logical_not(label_mask))[0]
            anchor_positives = np.array(list(combinations(label_indices, 2)))
            ap_distances = distance_matrix[anchor_positives[:, 0], anchor_positives[:, 1]]

            for (anchor, positive), ap_distance in zip(anchor_positives, ap_distances):
                # Triplet-loss value of this (a, p) pair against every negative candidate.
                loss_values = ap_distance - distance_matrix[anchor, negative_indices] + self.margin
                hard_negative = self.negative_selection_fn(loss_values)
                if hard_negative is not None:
                    triplets.append([anchor, positive, negative_indices[hard_negative]])

            if len(negative_indices) > 0:
                fallback = [anchor_positives[-1, 0], anchor_positives[-1, 1], negative_indices[0]]

        if not triplets and fallback is not None:
            triplets.append(fallback)

        if not triplets:
            return torch.empty((0, 3), dtype=torch.long)
        return torch.LongTensor(np.array(triplets))

# Quick demo: mine triplets for 6 random 3-d embeddings drawn from two classes.
a = FunctionNegativeTripletSelector(
    margin=0.5,
    negative_selection_fn=random_hard_negative,
    cpu=True,
)
embeddings = torch.randn(6, 3)
labels = torch.Tensor([1, 0, 0, 0, 1, 1])
# LongTensor with one (anchor, positive, negative) index row per triplet.
b = a.get_triplets(embeddings=embeddings, labels=labels)

计算 triplet loss：对每个三元组 (a, p, n)，loss = max(d(a,p) - d(a,n) + margin, 0)，其中 d 为平方欧氏距离。
    def __init__(self, margin, triplet_selector):
        """Online triplet loss with in-batch triplet mining.

        Args:
            margin: margin added to d(a, p) - d(a, n) before the hinge at zero.
            triplet_selector: object whose ``get_triplets(embeddings, labels)``
                returns a LongTensor of (anchor, positive, negative) index rows.
        """
        super().__init__()
        self.margin = margin
        self.triplet_selector = triplet_selector

    def _loss(self, embeddings, target):
        """Mine triplets from the batch and return the mean hinge triplet loss.

        Args:
            embeddings: batch of embeddings. Every dimension after the first
                (batch) dimension is flattened into a single feature vector.
            target: per-sample class labels, passed to the triplet selector.

        Returns:
            Scalar tensor equal to the mean, over mined triplets, of
            ``max(d(a, p) - d(a, n) + margin, 0)``, where ``d`` is the squared
            Euclidean distance. If the selector yields no triplets, a zero that
            stays connected to ``embeddings`` is returned. (``mean()`` of an
            empty tensor would be NaN.)
        """
        # start_dim=1 keeps the batch dimension. For 3-D input it equals
        # start_dim=-2, and it also handles already-flat 2-D input, which
        # start_dim=-2 would collapse into a 1-D vector.
        embeddings = torch.flatten(embeddings, start_dim=1)
        triplets = self.triplet_selector.get_triplets(embeddings, target)
        # Follow the embeddings to whatever device they are on (not just the default CUDA device).
        triplets = triplets.to(embeddings.device)

        if triplets.numel() == 0:
            return embeddings.sum() * 0.0

        anchors = embeddings[triplets[:, 0]]
        ap_distances = (anchors - embeddings[triplets[:, 1]]).pow(2).sum(1)
        an_distances = (anchors - embeddings[triplets[:, 2]]).pow(2).sum(1)
        losses = torch.clamp(ap_distances - an_distances + self.margin, min=0)
        return losses.mean()

# Training-step fragment: `args`, `encoded`, `commands` and `l2_regularizer` come from the
# surrounding training script, which is not shown here.
# The triplet selector is given the same margin as the loss (args.margin).
triplet_loss = OnlineTripletLoss(args.margin, RandomNegativeTripletSelector(args.margin))
# NOTE(review): presumably l2_regularizer L2-normalises the embeddings before mining - confirm.
encoded = l2_regularizer(embeds=encoded)
# NOTE(review): assumes calling the loss with embeds/targets dispatches to _loss(embeddings, target) - confirm.
train_loss = triplet_loss(embeds=encoded, targets=commands)