# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
from typeguard import check_argument_types

from funasr.modules.eend_ola.encoder import TransformerEncoder
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
from funasr.modules.eend_ola.utils.power import generate_mapping_dict
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
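
# NOTE: `pad_attractor` is called in `estimate_sequential` below but is neither
# imported nor defined in this excerpt. A minimal sketch under the assumption that
# it zero-pads the attractor matrix along the speaker axis up to `max_n_speakers`
# (the width the PostNet LSTM expects); the actual FunASR helper may differ.
def pad_attractor(att: torch.Tensor, max_n_speakers: int) -> torch.Tensor:
    n_speakers, n_units = att.shape
    if n_speakers < max_n_speakers:
        padding = torch.zeros(
            max_n_speakers - n_speakers, n_units, dtype=att.dtype, device=att.device
        )
        att = torch.cat([att, padding], dim=0)
    return att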

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    # native AMP is available; use the real autocast
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0: provide a no-op stand-in so callers still work
    @contextmanager
    def autocast(enabled=True):
        yield


class DiarEENDOLAModel(AbsESPnetModel):
    """EEND-OLA speaker diarization model."""

    def __init__(
            self,
            encoder: TransformerEncoder,
            eda: EncoderDecoderAttractor,
            n_units: int = 256,
            max_n_speaker: int = 8,
            attractor_loss_weight: float = 1.0,
            mapping_dict=None,
            **kwargs,
    ):
        assert check_argument_types()

        super().__init__()
        self.encoder = encoder
        self.eda = eda
        self.attractor_loss_weight = attractor_loss_weight
        self.max_n_speaker = max_n_speaker
        if mapping_dict is None:
            mapping_dict = generate_mapping_dict(max_speaker_num=self.max_n_speaker)
        self.mapping_dict = mapping_dict
        # PostNet
        self.PostNet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
        self.output_layer = nn.Linear(n_units, mapping_dict['oov'] + 1)
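        # The output layer predicts a power-set-encoded (PSE) class per frame: one
        # class for each allowed combination of simultaneously active speakers,
        # plus one out-of-vocabulary (OOV) class, hence `mapping_dict['oov'] + 1`
        # logits. Worked example (illustrative): a frame where speakers 0 and 2
        # overlap maps to the power label 2**0 + 2**2 = 5, which
        # `recover_y_from_powerlabel` decodes back into a per-speaker binary
        # pattern via bin(5) -> '101'.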

    def forward(
            self,
            speech: torch.Tensor,
            speech_lengths: torch.Tensor,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch,)
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
                speech.shape[0]
                == speech_lengths.shape[0]
                == text.shape[0]
                == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]

        # for data-parallel
        text = text[:, : text_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        stats = dict()

        # 2a. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic

                # Collect intermediate CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic

            loss_interctc = loss_interctc / len(intermediate_outs)

            # calculate whole encoder loss
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc

        # 2b. Attention decoder branch
        if self.ctc_weight != 1.0:
            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
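            # e.g. with ctc_weight = 0.3, loss_ctc = 2.0 and loss_att = 1.0:
            # loss = 0.3 * 2.0 + 0.7 * 1.0 = 1.3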

        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att

        # Collect total loss stats
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def estimate_sequential(self,
                            speech: torch.Tensor,
                            speech_lengths: torch.Tensor,
                            n_speakers: int,
                            shuffle: bool,
                            threshold: float,
                            **kwargs):
        speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
        emb = self.forward_core(speech)  # list, [(T1, C), ..., (TB, C)]
        if shuffle:
            # shuffle each sample along the time axis before EDA, since the
            # attractors should not depend on frame order
            orders = [np.arange(e.shape[0]) for e in emb]
            for order in orders:
                np.random.shuffle(order)
            # e[order]: shuffled embeddings, list, [(T1, C), ..., (TB, C)];
            #           the time dimension of each sample is randomly permuted
            # attractors: list of hts (the attractors "a_s" in the paper),
            #             [(max_n_speaker, n_units), ..., (max_n_speaker, n_units)]
            # probs: list, [(max_n_speaker,), ..., (max_n_speaker,)]
            attractors, probs = self.eda.estimate(
                [e[torch.from_numpy(order).to(torch.long).to(e.device)] for e, order in zip(emb, orders)])
        else:
            attractors, probs = self.eda.estimate(emb)
        attractors_active = []
        for p, att, e in zip(probs, attractors, emb):
            if n_speakers and n_speakers >= 0:
                # keep as many attractors as the specified number of speakers
                # TODO: when evaluating datasets with varying speaker counts, infer
                # the speaker number per sample instead of fixing it in advance
                att = att[:n_speakers]
                attractors_active.append(att)
            elif threshold is not None:
                # the first index whose existence probability drops below the
                # threshold marks the end of the active attractors, so that index
                # equals the estimated number of speakers
                silence = torch.nonzero(p < threshold).flatten()
                n_spk = silence[0].item() if silence.numel() else None
                att = att[:n_spk]
                attractors_active.append(att)
            else:
                raise NotImplementedError('n_speakers or threshold has to be given.')
        raw_n_speakers = [att.shape[0] for att in attractors_active]  # [C1, C2, ..., CB]
        attractors = [
            pad_attractor(att, self.max_n_speaker) if att.shape[0] <= self.max_n_speaker else att[:self.max_n_speaker]
            for att in attractors_active]
        ys = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(emb, attractors)]
        logits = self.cal_postnet(ys, self.max_n_speaker)
        ys = [self.recover_y_from_powerlabel(logit, raw_n_speaker)
              for logit, raw_n_speaker in zip(logits, raw_n_speakers)]

        return ys, emb, attractors, raw_n_speakers
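
    # A hedged usage sketch (argument values and the 0.5 threshold are illustrative;
    # `model` and `feats` are assumed to come from the FunASR EEND-OLA recipe):
    #
    #     ys, emb, attractors, n_spk = model.estimate_sequential(
    #         feats, feats_lengths, n_speakers=None, shuffle=True, threshold=0.5)
    #     # ys[b] is a (T_b, n_spk[b]) matrix of per-frame speaker activity decisions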

    def recover_y_from_powerlabel(self, logit, n_speaker):
        pred = torch.argmax(torch.softmax(logit, dim=-1), dim=-1)  # (T, )
        oov_index = torch.where(pred == self.mapping_dict['oov'])[0]
        # replace out-of-vocabulary power labels with the previous frame's label
        for i in oov_index:
            if i > 0:
                pred[i] = pred[i - 1]
            else:
                pred[i] = 0
        pred = [self.reporter.inv_mapping_func(i, self.mapping_dict) for i in pred]
        # decode each integer power label into a per-speaker binary activity vector:
        # bit k (after reversal) is 1 iff speaker k is active in that frame
        decisions = [bin(num)[2:].zfill(self.max_n_speaker)[::-1] for num in pred]
        decisions = torch.from_numpy(
            np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)).to(logit.device).to(
            torch.float32)
        decisions = decisions[:, :n_speaker]
        return decisions
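

# A minimal standalone sketch of the power-label decoding used above, assuming the
# inverse mapping has already turned class indices back into integer power labels.
# All names below (demo_pred, demo_decisions) are illustrative only.
if __name__ == "__main__":
    max_n_speaker = 8
    demo_pred = [5, 1, 0]  # power labels: {spk0, spk2}, {spk0}, silence
    demo_decisions = [bin(num)[2:].zfill(max_n_speaker)[::-1] for num in demo_pred]
    # bit k (after reversal) is 1 iff speaker k was active in that frame
    assert demo_decisions[0][0] == "1" and demo_decisions[0][2] == "1"
    print(demo_decisions)  # ['10100000', '10000000', '00000000']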