add training-related code for SOND

This commit is contained in:
志浩 2023-02-15 11:51:27 +08:00
parent 0f6296ff12
commit 5da92c1fa9
3 changed files with 136 additions and 52 deletions

View File

@@ -79,4 +79,4 @@ class LabelAggregate(torch.nn.Module):
else:
olens = None
return output, olens
return output.to(input.dtype), olens
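The new cast returns the aggregated labels in the same dtype as the network input, which keeps mixed-precision (AMP) training consistent when the features are float16 but the aggregator output was built as float32. A minimal sketch of what the cast does, with stand-in tensors:

import torch

x = torch.randn(2, 8, dtype=torch.float16)  # stand-in for a half-precision input
output = torch.zeros(2, 8)                  # stand-in for aggregated labels (float32)
output = output.to(x.dtype)                 # what the changed return line does
assert output.dtype == torch.float16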

View File

@@ -8,6 +8,7 @@
import torch
from torch import nn
from funasr.modules.nets_utils import make_pad_mask
class LabelSmoothingLoss(nn.Module):
@@ -61,3 +62,20 @@ class LabelSmoothingLoss(nn.Module):
kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
denom = total if self.normalize_length else batch_size
return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
class SequenceBinaryCrossEntropy(nn.Module):
def __init__(
self,
normalize_length=False,
criterion=nn.BCEWithLogitsLoss(reduction="none")
):
super().__init__()
self.normalize_length = normalize_length
self.criterion = criterion
def forward(self, pred, label, lengths):
# (B, T) -> (B, T, 1): True at padded frames, broadcastable over the speaker dim
pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1]).unsqueeze(-1).to(pred.device)
loss = self.criterion(pred, label)
denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0]
return loss.masked_fill(pad_mask, 0).sum() / denom
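A short usage sketch for the new loss; shapes are illustrative. Predictions are per-frame, per-speaker logits, labels are multi-hot speaker activities, and frames beyond each utterance's length are excluded:

import torch

bce = SequenceBinaryCrossEntropy(normalize_length=True)
pred = torch.randn(2, 100, 16)                     # (batch, frames, max_spk_num) logits
label = torch.randint(0, 2, (2, 100, 16)).float()  # multi-hot speaker activities
lengths = torch.tensor([100, 80])                  # valid frames per utterance
loss = bce(pred, label, lengths)                   # scalar averaged over valid frames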

View File

@@ -7,7 +7,7 @@ from distutils.version import LooseVersion
from itertools import permutations
from typing import Dict
from typing import Optional
from typing import Tuple
from typing import Tuple, List, Union
import numpy as np
import torch
@@ -23,6 +23,8 @@ from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss, SequenceBinaryCrossEntropy
from funasr.utils.misc import int2vec
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@@ -54,7 +56,10 @@ class DiarSondModel(AbsESPnetModel):
length_normalized_loss: bool = False,
max_spk_num: int = 16,
label_aggregator: Optional[torch.nn.Module] = None,
normlize_speech_speaker: bool = False,
normalize_speech_speaker: bool = False,
ignore_id: int = -1,
speaker_discrimination_loss_weight: float = 1.0,
inter_score_loss_weight: float = 0.0
):
assert check_argument_types()
@@ -71,7 +76,25 @@ class DiarSondModel(AbsESPnetModel):
self.decoder = decoder
self.token_list = token_list
self.max_spk_num = max_spk_num
self.normalize_speech_speaker = normlize_speech_speaker
self.normalize_speech_speaker = normalize_speech_speaker
self.ignore_id = ignore_id
self.criterion_diar = LabelSmoothingLoss(
size=vocab_size,
padding_idx=ignore_id,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
self.criterion_bce = SequenceBinaryCrossEntropy(normalize_length=length_normalized_loss)
# register as a buffer so the PSE lookup table moves with the model across devices
self.register_buffer("pse_embedding", self.generate_pse_embedding())
self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight
self.inter_score_loss_weight = inter_score_loss_weight
def generate_pse_embedding(self):
# np.float is deprecated since NumPy 1.20 (removed in 1.24); use an explicit float32
embedding = np.zeros((len(self.token_list), self.max_spk_num), dtype=np.float32)
for idx, pse_label in enumerate(self.token_list):
emb = int2vec(pse_label, vec_dim=self.max_spk_num, dtype=np.float32)
embedding[idx] = emb
return torch.from_numpy(embedding)
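# Example: with token_list ["0", "1", "2", "3"] and max_spk_num 2, and assuming
# int2vec returns the binary expansion of its integer argument, PSE label id 3
# maps to [1., 1.], i.e. both speakers active in that frame.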
def forward(
self,
@@ -85,7 +108,7 @@ class DiarSondModel(AbsESPnetModel):
"""Frontend + Encoder + Speaker Encoder + CI Scorer + CD Scorer + Decoder + Calc loss
Args:
speech: (Batch, samples)
speech: (Batch, samples) or (Batch, frames, input_size)
speech_lengths: (Batch,) default None for chunk iterator,
because the chunk-iterator does not
have the speech_lengths returned.
@@ -93,63 +116,42 @@ class DiarSondModel(AbsESPnetModel):
espnet2/iterators/chunk_iter_factory.py
profile: (Batch, N_spk, dim)
profile_lengths: (Batch,)
spk_labels: (Batch, )
spk_labels: (Batch, frames, input_size)
spk_labels_lengths: (Batch,)
"""
assert speech.shape[0] == spk_labels.shape[0], (speech.shape, spk_labels.shape)
batch_size = speech.shape[0]
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
# 1. Network forward
pred, inter_outputs = self.prediction_forward(
speech, speech_lengths,
profile, profile_lengths,
return_inter_outputs=True
)
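# inter_outputs: the encoded speech, the encoded speaker profiles, and the
# internal context-independent (CI) / context-dependent (CD) scorer outputs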
(speech, speech_lengths), (profile, profile_lengths), (ci_score, cd_score) = inter_outputs
if self.attractor is None:
# 2a. Decoder (basically a prediction layer after encoder_out)
pred = self.decoder(encoder_out, encoder_out_lens)
else:
# 2b. Encoder Decoder Attractors
# Shuffle the chronological order of encoder_out, then calculate attractor
encoder_out_shuffled = encoder_out.clone()
for i in range(len(encoder_out_lens)):
encoder_out_shuffled[i, : encoder_out_lens[i], :] = encoder_out[
i, torch.randperm(encoder_out_lens[i]), :
]
attractor, att_prob = self.attractor(
encoder_out_shuffled,
encoder_out_lens,
to_device(
self,
torch.zeros(
encoder_out.size(0), spk_labels.size(2) + 1, encoder_out.size(2)
),
),
)
# Remove the final attractor which does not correspond to a speaker
# Then multiply the attractors and encoder_out
pred = torch.bmm(encoder_out, attractor[:, :-1, :].permute(0, 2, 1))
# 3. Aggregate time-domain labels
# 2. Aggregate time-domain labels to match forward outputs
if self.label_aggregator is not None:
spk_labels, spk_labels_lengths = self.label_aggregator(
spk_labels, spk_labels_lengths
spk_labels.unsqueeze(2), spk_labels_lengths
)
spk_labels = spk_labels.squeeze(2)
# If encoder uses conv* as input_layer (i.e., subsampling),
# the sequence length of 'pred' might be slighly less than the
# the sequence length of 'pred' might be slightly less than the
# length of 'spk_labels'. Here we force them to be equal.
length_diff_tolerance = 2
length_diff = spk_labels.shape[1] - pred.shape[1]
if length_diff > 0 and length_diff <= length_diff_tolerance:
spk_labels = spk_labels[:, 0 : pred.shape[1], :]
if 0 < length_diff <= length_diff_tolerance:
spk_labels = spk_labels[:, 0: pred.shape[1], :]
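# total loss: diarization classification loss plus weighted speaker-discrimination
# and internal CI/CD score terms (weights set in __init__)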
loss_diar = self.classification_loss(pred, spk_labels, spk_labels_lengths)
loss_spk_dis = self.speaker_discrimination_loss(profile, profile_lengths)
loss_inter_ci, loss_inter_cd = self.internal_score_loss(cd_score, ci_score, spk_labels, spk_labels_lengths)
# valid-frame mask (1 where not padding), used to zero out padded ids below
label_mask = (~make_pad_mask(spk_labels_lengths, maxlen=spk_labels.shape[1])).to(spk_labels.device).long()
loss = (loss_diar + self.speaker_discrimination_loss_weight * loss_spk_dis
+ self.inter_score_loss_weight * (loss_inter_ci + loss_inter_cd))
if self.attractor is None:
loss_pit, loss_att = None, None
loss, perm_idx, perm_list, label_perm = self.pit_loss(
pred, spk_labels, encoder_out_lens
)
else:
loss_pit, perm_idx, perm_list, label_perm = self.pit_loss(
pred, spk_labels, encoder_out_lens
)
loss_att = self.attractor_loss(att_prob, spk_labels)
loss = loss_pit + self.attractor_weight * loss_att
(
correct,
num_frames,
@@ -160,7 +162,11 @@ class DiarSondModel(AbsESPnetModel):
speaker_miss,
speaker_falarm,
speaker_error,
) = self.calc_diarization_error(pred, label_perm, encoder_out_lens)
) = self.calc_diarization_error(
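# decode PSE label ids into multi-hot speaker activities before computing DER stats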
pred=F.embedding(pred.argmax(dim=2) * label_mask, self.pse_embedding),
label=F.embedding(spk_labels * label_mask, self.pse_embedding),
length=spk_labels_lengths
)
if speech_scored > 0 and num_frames > 0:
sad_mr, sad_fr, mi, fa, cf, acc, der = (
@@ -177,8 +183,10 @@ class DiarSondModel(AbsESPnetModel):
stats = dict(
loss=loss.detach(),
loss_att=loss_att.detach() if loss_att is not None else None,
loss_pit=loss_pit.detach() if loss_pit is not None else None,
loss_diar=loss_diar.detach() if loss_diar is not None else None,
loss_spk_dis=loss_spk_dis.detach() if loss_spk_dis is not None else None,
loss_inter_ci=loss_inter_ci.detach() if loss_inter_ci is not None else None,
loss_inter_cd=loss_inter_cd.detach() if loss_inter_cd is not None else None,
sad_mr=sad_mr,
sad_fr=sad_fr,
mi=mi,
@@ -191,6 +199,61 @@ class DiarSondModel(AbsESPnetModel):
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
def classification_loss(
self,
predictions: torch.Tensor,
labels: torch.Tensor,
prediction_lengths: torch.Tensor
) -> torch.Tensor:
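# frames beyond each utterance's length are filled with ignore_id so that
# LabelSmoothingLoss masks them out of the KL term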
pad_labels = labels.masked_fill(
make_pad_mask(prediction_lengths, maxlen=labels.shape[1]).to(labels.device),
value=self.ignore_id
)
loss = self.criterion_diar(predictions, pad_labels)
return loss
def speaker_discrimination_loss(
self,
profile: torch.Tensor,
profile_lengths: torch.Tensor
) -> torch.Tensor:
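# penalize cosine similarity between different enrolled speaker profiles;
# all-zero profile rows are empty slots and are masked out below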
profile_mask = (torch.linalg.norm(profile, ord=2, dim=2, keepdim=True) > 0).float()  # (B, N, 1)
mask = torch.matmul(profile_mask, profile_mask.transpose(1, 2))  # (B, N, N)
mask = mask * (1.0 - torch.eye(self.max_spk_num, device=profile.device).unsqueeze(0))
eps = 1e-12
coding_norm = torch.linalg.norm(
profile * profile_mask + (1 - profile_mask) * eps,
dim=2, keepdim=True
) * profile_mask  # (B, N, 1)
# pairwise cosine similarity between profiles: (B, N, 1, D) x (B, 1, N, D) -> (B, N, N)
cos_theta = F.cosine_similarity(profile.unsqueeze(2), profile.unsqueeze(1), dim=3, eps=eps) * mask
cos_theta = torch.clip(cos_theta, -1 + eps, 1 - eps)
# hinge on positive cross-speaker similarity (margin 0.0); clamp guards mask.sum() == 0
loss = F.relu(mask * coding_norm * (cos_theta - 0.0)).sum() / mask.sum().clamp(min=1.0)
return loss
def calculate_multi_labels(self, pse_labels, pse_labels_lengths):
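# map each PSE label id to its multi-hot speaker-activity vector with a fixed
# embedding lookup; padded frames are first forced to id 0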
padding_labels = pse_labels.masked_fill(
make_pad_mask(pse_labels_lengths, maxlen=pse_labels.shape[1]).to(pse_labels.device),
value=0
).to(pse_labels.dtype)
multi_labels = F.embedding(padding_labels, self.pse_embedding)
return multi_labels
def internal_score_loss(
self,
cd_score: torch.Tensor,
ci_score: torch.Tensor,
pse_labels: torch.Tensor,
pse_labels_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
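# supervise the internal CI/CD similarity scores directly with the multi-hot labels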
multi_labels = self.calculate_multi_labels(pse_labels, pse_labels_lengths)
ci_loss = self.criterion_bce(ci_score, multi_labels, pse_labels_lengths)
cd_loss = self.criterion_bce(cd_score, multi_labels, pse_labels_lengths)
return ci_loss, cd_loss
def collect_feats(
self,
speech: torch.Tensor,
@@ -282,7 +345,8 @@ class DiarSondModel(AbsESPnetModel):
speech_lengths: torch.Tensor,
profile: torch.Tensor,
profile_lengths: torch.Tensor,
) -> torch.Tensor:
return_inter_outputs: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]:
# speech encoding
speech, speech_lengths = self.encode_speech(speech, speech_lengths)
# speaker encoding
@@ -292,6 +356,8 @@ class DiarSondModel(AbsESPnetModel):
# post net forward
logits = self.post_net_forward(similarity, speech_lengths)
if return_inter_outputs:
# assumption: `similarity` holds the CI and CD scores concatenated along the
# last dim; split it in half to recover (ci_score, cd_score)
ci_cd_scores = torch.split(similarity, similarity.shape[-1] // 2, dim=-1)
return logits, [(speech, speech_lengths), (profile, profile_lengths), ci_cd_scores]
return logits
def encode(