This commit is contained in:
嘉渊 2023-04-27 19:27:49 +08:00
parent 6997763bf6
commit 607073619c
10 changed files with 579 additions and 269 deletions

View File

@@ -12,7 +12,11 @@ from typing import Tuple
import torch
from typeguard import check_argument_types
from funasr.layers.abs_normalize import AbsNormalize
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.models.base_model import FunASRModel
@@ -30,11 +34,11 @@ class Data2VecPretrainModel(FunASRModel):
def __init__(
self,
frontend: Optional[torch.nn.Module],
specaug: Optional[torch.nn.Module],
normalize: Optional[torch.nn.Module],
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
encoder: torch.nn.Module,
encoder: AbsEncoder,
):
assert check_argument_types()
@@ -53,7 +57,6 @@ class Data2VecPretrainModel(FunASRModel):
speech_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -102,7 +105,6 @@ class Data2VecPretrainModel(FunASRModel):
speech_lengths: torch.Tensor,
):
"""Frontend + Encoder.
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
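A note on the pattern running through this commit: the constructor annotations are narrowed from torch.nn.Module to the abstract interfaces (AbsFrontend, AbsSpecAug, AbsNormalize, AbsEncoder) because check_argument_types() from typeguard validates arguments against the declared annotations at call time, so the narrower types are actually enforced. A minimal standalone sketch of the mechanism (AbsThing and GoodThing are illustrative stand-ins, not FunASR classes):

from abc import ABC, abstractmethod
from typing import Optional

import torch
from typeguard import check_argument_types


class AbsThing(torch.nn.Module, ABC):
    # stand-in for interfaces such as AbsFrontend / AbsEncoder
    @abstractmethod
    def output_size(self) -> int:
        ...


class GoodThing(AbsThing):
    def output_size(self) -> int:
        return 80


def build(component: Optional[AbsThing]):
    # typeguard checks `component` against Optional[AbsThing];
    # a plain torch.nn.Module would raise TypeError here
    assert check_argument_types()
    return component


build(GoodThing())  # passes
# build(torch.nn.Linear(2, 2))  # would raise TypeError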

View File

@@ -13,18 +13,22 @@ from typing import Union
import torch
from typeguard import check_argument_types
from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
LabelSmoothingLoss, # noqa: H301
)
from funasr.models.ctc import CTC
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.base_model import FunASRModel
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.modules.add_sos_eos import add_sos_eos
from funasr.modules.e2e_asr_common import ErrorCalculator
from funasr.modules.nets_utils import th_accuracy
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.models.base_model import FunASRModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@@ -43,9 +47,11 @@ class ESPnetASRModel(FunASRModel):
vocab_size: int,
token_list: Union[Tuple[str, ...], List[str]],
frontend: Optional[AbsFrontend],
specaug: Optional[torch.nn.Module],
normalize: Optional[torch.nn.Module],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
ctc_weight: float = 0.5,
@@ -127,7 +133,6 @@ class ESPnetASRModel(FunASRModel):
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -243,7 +248,6 @@ class ESPnetASRModel(FunASRModel):
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -325,9 +329,7 @@ class ESPnetASRModel(FunASRModel):
ys_pad_lens: torch.Tensor,
) -> torch.Tensor:
"""Compute negative log likelihood(nll) from transformer-decoder
Normally, this function is called in batchify_nll.
Args:
encoder_out: (Batch, Length, Dim)
encoder_out_lens: (Batch,)
@@ -364,7 +366,6 @@ class ESPnetASRModel(FunASRModel):
batch_size: int = 100,
):
"""Compute negative log likelihood(nll) from transformer-decoder
To avoid OOM, this fuction seperate the input into batches.
Then call nll for each batch and combine and return results.
Args:
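A sketch of the batching that this docstring describes, with the per-batch scorer stubbed out as nll_fn (names and shapes here are assumptions for illustration, not the FunASR signature):

import torch


def batchify_nll_sketch(nll_fn, encoder_out, encoder_out_lens,
                        ys_pad, ys_pad_lens, batch_size: int = 100):
    # Score the batch in chunks of `batch_size` and concatenate,
    # trading a little speed for a smaller peak memory footprint.
    results = []
    total = encoder_out.size(0)
    for start in range(0, total, batch_size):
        end = min(start + batch_size, total)
        results.append(
            nll_fn(encoder_out[start:end], encoder_out_lens[start:end],
                   ys_pad[start:end], ys_pad_lens[start:end])
        )
    return torch.cat(results, dim=0)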

View File

@@ -17,10 +17,13 @@ from funasr.losses.label_smoothing_loss import (
)
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.base_model import FunASRModel
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.models.base_model import FunASRModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@@ -32,30 +35,36 @@ else:
import pdb
import random
import math
class MFCCA(FunASRModel):
"""CTC-attention hybrid Encoder-Decoder model"""
"""
Author: Audio, Speech and Language Processing Group (ASLP@NPU), Northwestern Polytechnical University
MFCCA: Multi-Frame Cross-Channel attention for multi-speaker ASR in Multi-party meeting scenario
https://arxiv.org/abs/2210.05265
"""
def __init__(
self,
vocab_size: int,
token_list: Union[Tuple[str, ...], List[str]],
frontend: Optional[torch.nn.Module],
specaug: Optional[torch.nn.Module],
normalize: Optional[torch.nn.Module],
preencoder: Optional[AbsPreEncoder],
encoder: torch.nn.Module,
decoder: AbsDecoder,
ctc: CTC,
rnnt_decoder: None,
ctc_weight: float = 0.5,
ignore_id: int = -1,
lsm_weight: float = 0.0,
mask_ratio: float = 0.0,
length_normalized_loss: bool = False,
report_cer: bool = True,
report_wer: bool = True,
sym_space: str = "<space>",
sym_blank: str = "<blank>",
self,
vocab_size: int,
token_list: Union[Tuple[str, ...], List[str]],
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
decoder: AbsDecoder,
ctc: CTC,
rnnt_decoder: None,
ctc_weight: float = 0.5,
ignore_id: int = -1,
lsm_weight: float = 0.0,
mask_ratio: float = 0.0,
length_normalized_loss: bool = False,
report_cer: bool = True,
report_wer: bool = True,
sym_space: str = "<space>",
sym_blank: str = "<blank>",
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@@ -69,10 +78,9 @@ class MFCCA(FunASRModel):
self.ignore_id = ignore_id
self.ctc_weight = ctc_weight
self.token_list = token_list.copy()
self.mask_ratio = mask_ratio
self.frontend = frontend
self.specaug = specaug
self.normalize = normalize
@@ -106,14 +114,13 @@ class MFCCA(FunASRModel):
self.error_calculator = None
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -123,22 +130,22 @@ class MFCCA(FunASRModel):
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
assert (
speech.shape[0]
== speech_lengths.shape[0]
== text.shape[0]
== text_lengths.shape[0]
speech.shape[0]
== speech_lengths.shape[0]
== text.shape[0]
== text_lengths.shape[0]
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
#pdb.set_trace()
if(speech.dim()==3 and speech.size(2)==8 and self.mask_ratio !=0):
# pdb.set_trace()
if (speech.dim() == 3 and speech.size(2) == 8 and self.mask_ratio != 0):
rate_num = random.random()
#rate_num = 0.1
if(rate_num<=self.mask_ratio):
retain_channel = math.ceil(random.random() *8)
if(retain_channel>1):
speech = speech[:,:,torch.randperm(8)[0:retain_channel].sort().values]
# rate_num = 0.1
if (rate_num <= self.mask_ratio):
retain_channel = math.ceil(random.random() * 8)
if (retain_channel > 1):
speech = speech[:, :, torch.randperm(8)[0:retain_channel].sort().values]
else:
speech = speech[:,:,torch.randperm(8)[0]]
#pdb.set_trace()
speech = speech[:, :, torch.randperm(8)[0]]
# pdb.set_trace()
batch_size = speech.shape[0]
# for data-parallel
text = text[:, : text_lengths.max()]
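The reformatted branch above is MFCCA's channel dropout: with probability mask_ratio, training keeps a random subset of the 8 microphone channels (indices sorted to preserve array order, and the channel axis squeezed away when only one channel survives). The same augmentation as a self-contained sketch:

import math
import random

import torch


def mask_channels(speech: torch.Tensor, mask_ratio: float) -> torch.Tensor:
    # speech: (Batch, Length, 8) multi-channel features
    if speech.dim() == 3 and speech.size(2) == 8 and mask_ratio != 0:
        if random.random() <= mask_ratio:
            retain_channel = math.ceil(random.random() * 8)
            if retain_channel > 1:
                keep = torch.randperm(8)[:retain_channel].sort().values
                speech = speech[:, :, keep]
            else:
                speech = speech[:, :, torch.randperm(8)[0]]  # -> (Batch, Length)
    return speech


# e.g. mask_channels(torch.randn(4, 100, 8), mask_ratio=0.5)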
@@ -188,20 +195,19 @@ class MFCCA(FunASRModel):
return loss, stats, weight
def collect_feats(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
) -> Dict[str, torch.Tensor]:
feats, feats_lengths, channel_size = self._extract_feats(speech, speech_lengths)
return {"feats": feats, "feats_lengths": feats_lengths}
def encode(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -220,14 +226,14 @@ class MFCCA(FunASRModel):
# Pre-encoder, e.g. used for raw input data
if self.preencoder is not None:
feats, feats_lengths = self.preencoder(feats, feats_lengths)
#pdb.set_trace()
# pdb.set_trace()
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, channel_size)
assert encoder_out.size(0) == speech.size(0), (
encoder_out.size(),
speech.size(0),
)
if(encoder_out.dim()==4):
if (encoder_out.dim() == 4):
assert encoder_out.size(2) <= encoder_out_lens.max(), (
encoder_out.size(),
encoder_out_lens.max(),
@@ -241,7 +247,7 @@ class MFCCA(FunASRModel):
return encoder_out, encoder_out_lens
def _extract_feats(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
assert speech_lengths.dim() == 1, speech_lengths.shape
# for data-parallel
@@ -259,11 +265,11 @@ class MFCCA(FunASRModel):
return feats, feats_lengths, channel_size
def _calc_att_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_in_lens = ys_pad_lens + 1
@@ -291,14 +297,14 @@ class MFCCA(FunASRModel):
return loss_att, acc_att, cer_att, wer_att
def _calc_ctc_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
# Calc CTC loss
if(encoder_out.dim()==4):
if (encoder_out.dim() == 4):
encoder_out = encoder_out.mean(1)
loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
@@ -310,10 +316,10 @@ class MFCCA(FunASRModel):
return loss_ctc, cer_ctc
def _calc_rnnt_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
raise NotImplementedError
raise NotImplementedError
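For reference, the 4-dim branches above (in encode and _calc_ctc_loss) exist because the MFCCA encoder can keep a per-channel axis; before CTC, an output of shape (Batch, Channel, Length2, Dim) is averaged over the channel axis down to the (Batch, Length2, Dim) layout CTC expects. A shape-level sketch:

import torch

encoder_out = torch.randn(2, 8, 50, 256)  # (Batch, Channel, Length2, Dim)
if encoder_out.dim() == 4:
    encoder_out = encoder_out.mean(1)     # -> (Batch, Length2, Dim)
assert encoder_out.shape == (2, 50, 256)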

View File

@@ -12,23 +12,26 @@ import random
import numpy as np
from typeguard import check_argument_types
from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
LabelSmoothingLoss, # noqa: H301
)
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.e2e_asr_common import ErrorCalculator
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.predictor.cif import mae_loss
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.base_model import FunASRModel
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.modules.add_sos_eos import add_sos_eos
from funasr.modules.nets_utils import make_pad_mask, pad_list
from funasr.modules.nets_utils import th_accuracy
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.models.base_model import FunASRModel
from funasr.models.predictor.cif import CifPredictorV3
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
else:
@@ -40,7 +43,7 @@ else:
class Paraformer(FunASRModel):
"""
Author: Speech Lab, Alibaba Group, China
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2206.08317
"""
@@ -49,10 +52,12 @@ class Paraformer(FunASRModel):
self,
vocab_size: int,
token_list: Union[Tuple[str, ...], List[str]],
frontend: Optional[torch.nn.Module],
specaug: Optional[torch.nn.Module],
normalize: Optional[torch.nn.Module],
encoder: torch.nn.Module,
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
ctc_weight: float = 0.5,
@@ -92,8 +97,17 @@ class Paraformer(FunASRModel):
self.frontend = frontend
self.specaug = specaug
self.normalize = normalize
self.preencoder = preencoder
self.postencoder = postencoder
self.encoder = encoder
if not hasattr(self.encoder, "interctc_use_conditioning"):
self.encoder.interctc_use_conditioning = False
if self.encoder.interctc_use_conditioning:
self.encoder.conditioning_layer = torch.nn.Linear(
vocab_size, self.encoder.output_size()
)
self.error_calculator = None
if ctc_weight == 1.0:
@@ -138,7 +152,6 @@ class Paraformer(FunASRModel):
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -161,7 +174,9 @@ class Paraformer(FunASRModel):
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
loss_att, acc_att, cer_att, wer_att = None, None, None, None
@@ -179,6 +194,30 @@ class Paraformer(FunASRModel):
stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
stats["cer_ctc"] = cer_ctc
# Intermediate CTC (optional)
loss_interctc = 0.0
if self.interctc_weight != 0.0 and intermediate_outs is not None:
for layer_idx, intermediate_out in intermediate_outs:
# we assume intermediate_out has the same length & padding
# as those of encoder_out
loss_ic, cer_ic = self._calc_ctc_loss(
intermediate_out, encoder_out_lens, text, text_lengths
)
loss_interctc = loss_interctc + loss_ic
# Collect Intermediate CTC stats
stats["loss_interctc_layer{}".format(layer_idx)] = (
loss_ic.detach() if loss_ic is not None else None
)
stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
loss_interctc = loss_interctc / len(intermediate_outs)
# calculate whole encoder loss
loss_ctc = (
1 - self.interctc_weight
) * loss_ctc + self.interctc_weight * loss_interctc
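The intermediate-CTC block ported in above averages the per-layer losses and then interpolates with the final-layer CTC loss. A numeric sketch of the combination, assuming interctc_weight=0.3 and two intermediate layers:

import torch

interctc_weight = 0.3
loss_ctc = torch.tensor(2.0)                            # final-layer CTC loss
layer_losses = [torch.tensor(2.6), torch.tensor(2.2)]   # intermediate layers

loss_interctc = sum(layer_losses) / len(layer_losses)   # 2.4
loss_ctc = (1 - interctc_weight) * loss_ctc + interctc_weight * loss_interctc
# 0.7 * 2.0 + 0.3 * 2.4 = 2.12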
# 2b. Attention decoder branch
if self.ctc_weight != 1.0:
loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
@@ -229,7 +268,6 @@ class Paraformer(FunASRModel):
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -246,8 +284,29 @@ class Paraformer(FunASRModel):
if self.normalize is not None:
feats, feats_lengths = self.normalize(feats, feats_lengths)
# Pre-encoder, e.g. used for raw input data
if self.preencoder is not None:
feats, feats_lengths = self.preencoder(feats, feats_lengths)
# 4. Forward encoder
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
if self.encoder.interctc_use_conditioning:
encoder_out, encoder_out_lens, _ = self.encoder(
feats, feats_lengths, ctc=self.ctc
)
else:
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
# Post-encoder, e.g. NLU
if self.postencoder is not None:
encoder_out, encoder_out_lens = self.postencoder(
encoder_out, encoder_out_lens
)
assert encoder_out.size(0) == speech.size(0), (
encoder_out.size(),
@@ -258,45 +317,18 @@ class Paraformer(FunASRModel):
encoder_out_lens.max(),
)
if intermediate_outs is not None:
return (encoder_out, intermediate_outs), encoder_out_lens
return encoder_out, encoder_out_lens
def encode_chunk(
self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
"""
with autocast(False):
# 1. Extract feats
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
# 2. Data augmentation
if self.specaug is not None and self.training:
feats, feats_lengths = self.specaug(feats, feats_lengths)
# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
feats, feats_lengths = self.normalize(feats, feats_lengths)
# 4. Forward encoder
encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"])
return encoder_out, torch.tensor([encoder_out.size(1)])
def calc_predictor(self, encoder_out, encoder_out_lens):
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None, encoder_out_mask,
ignore_id=self.ignore_id)
return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
def calc_predictor_chunk(self, encoder_out, cache=None):
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor.forward_chunk(encoder_out, cache["encoder"])
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None,
encoder_out_mask,
ignore_id=self.ignore_id)
return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
@@ -308,14 +340,6 @@ class Paraformer(FunASRModel):
decoder_out = torch.log_softmax(decoder_out, dim=-1)
return decoder_out, ys_pad_lens
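Downstream, inference consumes the log-probabilities that cal_decoder_with_predictor returns with a position-wise argmax, since Paraformer emits all tokens in parallel. A sketch of that non-autoregressive readout (the vocabulary size and lengths are made up for illustration):

import torch

decoder_out = torch.log_softmax(torch.randn(1, 6, 4233), dim=-1)
ys_pad_lens = torch.tensor([6])

token_ids = decoder_out.argmax(dim=-1)           # (Batch, MaxTokenLen)
hyp = token_ids[0, : ys_pad_lens[0]].tolist()    # trim to the predicted length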
def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
decoder_outs = self.decoder.forward_chunk(
encoder_out, sematic_embeds, cache["decoder"]
)
decoder_out = decoder_outs
decoder_out = torch.log_softmax(decoder_out, dim=-1)
return decoder_out
def _extract_feats(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -342,9 +366,7 @@ class Paraformer(FunASRModel):
ys_pad_lens: torch.Tensor,
) -> torch.Tensor:
"""Compute negative log likelihood(nll) from transformer-decoder
Normally, this function is called in batchify_nll.
Args:
encoder_out: (Batch, Length, Dim)
encoder_out_lens: (Batch,)
@@ -381,7 +403,6 @@ class Paraformer(FunASRModel):
batch_size: int = 100,
):
"""Compute negative log likelihood(nll) from transformer-decoder
To avoid OOM, this fuction seperate the input into batches.
Then call nll for each batch and combine and return results.
Args:
@@ -521,9 +542,186 @@ class Paraformer(FunASRModel):
return loss_ctc, cer_ctc
class ParaformerBert(Paraformer):
class ParaformerOnline(Paraformer):
"""
Author: Speech Lab, Alibaba Group, China
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2206.08317
"""
def __init__(
self, *args, **kwargs,
):
super().__init__(*args, **kwargs)
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
assert (
speech.shape[0]
== speech_lengths.shape[0]
== text.shape[0]
== text_lengths.shape[0]
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
batch_size = speech.shape[0]
self.step_cur += 1
# for data-parallel
text = text[:, : text_lengths.max()]
speech = speech[:, :speech_lengths.max()]
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
loss_att, acc_att, cer_att, wer_att = None, None, None, None
loss_ctc, cer_ctc = None, None
loss_pre = None
stats = dict()
# 1. CTC branch
if self.ctc_weight != 0.0:
loss_ctc, cer_ctc = self._calc_ctc_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
# Collect CTC branch stats
stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
stats["cer_ctc"] = cer_ctc
# Intermediate CTC (optional)
loss_interctc = 0.0
if self.interctc_weight != 0.0 and intermediate_outs is not None:
for layer_idx, intermediate_out in intermediate_outs:
# we assume intermediate_out has the same length & padding
# as those of encoder_out
loss_ic, cer_ic = self._calc_ctc_loss(
intermediate_out, encoder_out_lens, text, text_lengths
)
loss_interctc = loss_interctc + loss_ic
# Collect Intermediate CTC stats
stats["loss_interctc_layer{}".format(layer_idx)] = (
loss_ic.detach() if loss_ic is not None else None
)
stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
loss_interctc = loss_interctc / len(intermediate_outs)
# calculate whole encoder loss
loss_ctc = (
1 - self.interctc_weight
) * loss_ctc + self.interctc_weight * loss_interctc
# 2b. Attention decoder branch
if self.ctc_weight != 1.0:
loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
# 3. CTC-Att loss definition
if self.ctc_weight == 0.0:
loss = loss_att + loss_pre * self.predictor_weight
elif self.ctc_weight == 1.0:
loss = loss_ctc
else:
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
# Collect Attn branch stats
stats["loss_att"] = loss_att.detach() if loss_att is not None else None
stats["acc"] = acc_att
stats["cer"] = cer_att
stats["wer"] = wer_att
stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
stats["loss"] = torch.clone(loss.detach())
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
def encode_chunk(
self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
"""
with autocast(False):
# 1. Extract feats
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
# 2. Data augmentation
if self.specaug is not None and self.training:
feats, feats_lengths = self.specaug(feats, feats_lengths)
# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
feats, feats_lengths = self.normalize(feats, feats_lengths)
# Pre-encoder, e.g. used for raw input data
if self.preencoder is not None:
feats, feats_lengths = self.preencoder(feats, feats_lengths)
# 4. Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
if self.encoder.interctc_use_conditioning:
encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(
feats, feats_lengths, cache=cache["encoder"], ctc=self.ctc
)
else:
encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"])
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
# Post-encoder, e.g. NLU
if self.postencoder is not None:
encoder_out, encoder_out_lens = self.postencoder(
encoder_out, encoder_out_lens
)
if intermediate_outs is not None:
return (encoder_out, intermediate_outs), encoder_out_lens
return encoder_out, torch.tensor([encoder_out.size(1)])
def calc_predictor_chunk(self, encoder_out, cache=None):
pre_acoustic_embeds, pre_token_length = \
self.predictor.forward_chunk(encoder_out, cache["encoder"])
return pre_acoustic_embeds, pre_token_length
def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
decoder_outs = self.decoder.forward_chunk(
encoder_out, sematic_embeds, cache["decoder"]
)
decoder_out = decoder_outs
decoder_out = torch.log_softmax(decoder_out, dim=-1)
return decoder_out
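Taken together, the three *_chunk methods above support a streaming loop in which a cache dict with "encoder" and "decoder" entries carries state between chunks. A hedged sketch of the driving loop (the cache initialization and the greedy readout are assumptions; the real inference pipeline manages both elsewhere):

import torch


def stream_sketch(model, chunks):
    cache = {"encoder": {}, "decoder": {}}  # layout assumed from the signatures above
    hyps = []
    for speech in chunks:  # each chunk: (1, chunk_len, feat_dim)
        lengths = torch.tensor([speech.size(1)])
        enc, enc_lens = model.encode_chunk(speech, lengths, cache=cache)
        embeds, token_len = model.calc_predictor_chunk(enc, cache=cache)
        if int(token_len) > 0:
            log_probs = model.cal_decoder_with_predictor_chunk(
                enc, embeds, cache=cache)
            hyps.append(log_probs.argmax(dim=-1))
    return hyps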
class ParaformerBert(Paraformer):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer2: advanced paraformer with LFMMI and bert for non-autoregressive end-to-end speech recognition
"""
@@ -531,11 +729,11 @@ class ParaformerBert(Paraformer):
self,
vocab_size: int,
token_list: Union[Tuple[str, ...], List[str]],
frontend: Optional[torch.nn.Module],
specaug: Optional[torch.nn.Module],
normalize: Optional[torch.nn.Module],
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
encoder: torch.nn.Module,
encoder: AbsEncoder,
postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
@@ -690,7 +888,6 @@ class ParaformerBert(Paraformer):
embed_lengths: torch.Tensor = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -799,74 +996,73 @@ class ParaformerBert(Paraformer):
class BiCifParaformer(Paraformer):
"""
Paraformer model with an extra CIF predictor
for accurate timestamp prediction
"""
def __init__(
self,
vocab_size: int,
token_list: Union[Tuple[str, ...], List[str]],
frontend: Optional[torch.nn.Module],
specaug: Optional[torch.nn.Module],
normalize: Optional[torch.nn.Module],
preencoder: Optional[AbsPreEncoder],
encoder: torch.nn.Module,
postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
ctc_weight: float = 0.5,
interctc_weight: float = 0.0,
ignore_id: int = -1,
blank_id: int = 0,
sos: int = 1,
eos: int = 2,
lsm_weight: float = 0.0,
length_normalized_loss: bool = False,
report_cer: bool = True,
report_wer: bool = True,
sym_space: str = "<space>",
sym_blank: str = "<blank>",
extract_feats_in_collect_stats: bool = True,
predictor = None,
predictor_weight: float = 0.0,
predictor_bias: int = 0,
sampling_ratio: float = 0.2,
self,
vocab_size: int,
token_list: Union[Tuple[str, ...], List[str]],
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
ctc_weight: float = 0.5,
interctc_weight: float = 0.0,
ignore_id: int = -1,
blank_id: int = 0,
sos: int = 1,
eos: int = 2,
lsm_weight: float = 0.0,
length_normalized_loss: bool = False,
report_cer: bool = True,
report_wer: bool = True,
sym_space: str = "<space>",
sym_blank: str = "<blank>",
extract_feats_in_collect_stats: bool = True,
predictor=None,
predictor_weight: float = 0.0,
predictor_bias: int = 0,
sampling_ratio: float = 0.2,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight
super().__init__(
vocab_size=vocab_size,
token_list=token_list,
frontend=frontend,
specaug=specaug,
normalize=normalize,
preencoder=preencoder,
encoder=encoder,
postencoder=postencoder,
decoder=decoder,
ctc=ctc,
ctc_weight=ctc_weight,
interctc_weight=interctc_weight,
ignore_id=ignore_id,
blank_id=blank_id,
sos=sos,
eos=eos,
lsm_weight=lsm_weight,
length_normalized_loss=length_normalized_loss,
report_cer=report_cer,
report_wer=report_wer,
sym_space=sym_space,
sym_blank=sym_blank,
extract_feats_in_collect_stats=extract_feats_in_collect_stats,
predictor=predictor,
predictor_weight=predictor_weight,
predictor_bias=predictor_bias,
sampling_ratio=sampling_ratio,
vocab_size=vocab_size,
token_list=token_list,
frontend=frontend,
specaug=specaug,
normalize=normalize,
preencoder=preencoder,
encoder=encoder,
postencoder=postencoder,
decoder=decoder,
ctc=ctc,
ctc_weight=ctc_weight,
interctc_weight=interctc_weight,
ignore_id=ignore_id,
blank_id=blank_id,
sos=sos,
eos=eos,
lsm_weight=lsm_weight,
length_normalized_loss=length_normalized_loss,
report_cer=report_cer,
report_wer=report_wer,
sym_space=sym_space,
sym_blank=sym_blank,
extract_feats_in_collect_stats=extract_feats_in_collect_stats,
predictor=predictor,
predictor_weight=predictor_weight,
predictor_bias=predictor_bias,
sampling_ratio=sampling_ratio,
)
assert isinstance(self.predictor, CifPredictorV3), "BiCifParaformer should use CifPredictorV3"
@@ -888,21 +1084,77 @@ class BiCifParaformer(Paraformer):
loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length2), pre_token_length2)
return loss_pre2
def _calc_att_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
if self.predictor_bias == 1:
_, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_pad_lens = ys_pad_lens + self.predictor_bias
pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad,
encoder_out_mask,
ignore_id=self.ignore_id)
# 0. sampler
decoder_out_1st = None
if self.sampling_ratio > 0.0:
if self.step_cur < 2:
logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
pre_acoustic_embeds)
else:
if self.step_cur < 2:
logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
sematic_embeds = pre_acoustic_embeds
# 1. Forward decoder
decoder_outs = self.decoder(
encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
)
decoder_out, _ = decoder_outs[0], decoder_outs[1]
if decoder_out_1st is None:
decoder_out_1st = decoder_out
# 2. Compute attention loss
loss_att = self.criterion_att(decoder_out, ys_pad)
acc_att = th_accuracy(
decoder_out_1st.view(-1, self.vocab_size),
ys_pad,
ignore_label=self.ignore_id,
)
loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
# Compute cer/wer using attention-decoder
if self.training or self.error_calculator is None:
cer_att, wer_att = None, None
else:
ys_hat = decoder_out_1st.argmax(dim=-1)
cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
return loss_att, acc_att, cer_att, wer_att, loss_pre
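The sampler call above is Paraformer's glancing-style training: a first decoder pass produces decoder_out_1st, and a fraction of the CIF acoustic embeddings is replaced with ground-truth token embeddings before the scoring pass. A conceptual sketch of the mixing step only (the mask construction and embedding lookup are assumptions; the actual sampler is defined on the Paraformer base class):

import torch


def mix_embeddings(pre_acoustic_embeds, char_embeds, mask):
    # mask: (Batch, TokenLen) bool, True where the ground-truth token
    # embedding replaces the CIF acoustic embedding
    mask = mask.unsqueeze(-1).to(pre_acoustic_embeds.dtype)
    return mask * char_embeds + (1.0 - mask) * pre_acoustic_embeds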
def calc_predictor(self, encoder_out, encoder_out_lens):
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(encoder_out, None, encoder_out_mask,
ignore_id=self.ignore_id)
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(encoder_out,
None,
encoder_out_mask,
ignore_id=self.ignore_id)
return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
encoder_out_mask,
token_num)
encoder_out_mask,
token_num)
return ds_alphas, ds_cif_peak, us_alphas, us_peaks
def forward(
@@ -913,7 +1165,6 @@ class BiCifParaformer(Paraformer):
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -996,7 +1247,8 @@ class BiCifParaformer(Paraformer):
elif self.ctc_weight == 1.0:
loss = loss_ctc
else:
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
loss = self.ctc_weight * loss_ctc + (
1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
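Written out, the reformatted objective above is loss = w_ctc * loss_ctc + (1 - w_ctc) * loss_att + w_pred * loss_pre + 0.5 * w_pred * loss_pre2, where loss_pre2 is the second CIF predictor term BiCifParaformer adds for timestamps. A quick numeric check:

ctc_weight, predictor_weight = 0.5, 1.0
loss_ctc, loss_att, loss_pre, loss_pre2 = 3.0, 2.0, 0.4, 0.2

loss = (ctc_weight * loss_ctc + (1 - ctc_weight) * loss_att
        + loss_pre * predictor_weight + loss_pre2 * predictor_weight * 0.5)
assert abs(loss - 3.0) < 1e-9  # 1.5 + 1.0 + 0.4 + 0.1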
# Collect Attn branch stats
stats["loss_att"] = loss_att.detach() if loss_att is not None else None
@@ -1022,11 +1274,11 @@ class ContextualParaformer(Paraformer):
self,
vocab_size: int,
token_list: Union[Tuple[str, ...], List[str]],
frontend: Optional[torch.nn.Module],
specaug: Optional[torch.nn.Module],
normalize: Optional[torch.nn.Module],
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
encoder: torch.nn.Module,
encoder: AbsEncoder,
postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
@@ -1120,7 +1372,6 @@ class ContextualParaformer(Paraformer):
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -1504,4 +1755,4 @@ class ContextualParaformer(Paraformer):
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
var_dict_tf[name_tf].shape))
return var_dict_torch_update
return var_dict_torch_update

View File

@@ -15,8 +15,8 @@ from funasr.models.frontend.wav_frontend import WavFrontendMel23
from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
from funasr.modules.eend_ola.utils.power import generate_mapping_dict
from funasr.models.base_model import FunASRModel
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.models.base_model import FunASRModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
pass
@@ -91,7 +91,6 @@ class DiarEENDOLAModel(FunASRModel):
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )

View File

@@ -14,9 +14,15 @@ import torch
from torch.nn import functional as F
from typeguard import check_argument_types
from funasr.modules.nets_utils import to_device
from funasr.modules.nets_utils import make_pad_mask
from funasr.models.base_model import FunASRModel
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.models.base_model import FunASRModel
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss, SequenceBinaryCrossEntropy
from funasr.utils.misc import int2vec
@@ -30,16 +36,20 @@ else:
class DiarSondModel(FunASRModel):
"""Speaker overlap-aware neural diarization model
reference: https://arxiv.org/abs/2211.10243
"""
Author: Speech Lab, Alibaba Group, China
SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
https://arxiv.org/abs/2211.10243
TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
https://arxiv.org/abs/2303.05397
"""
def __init__(
self,
vocab_size: int,
frontend: Optional[torch.nn.Module],
specaug: Optional[torch.nn.Module],
normalize: Optional[torch.nn.Module],
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
encoder: torch.nn.Module,
speaker_encoder: Optional[torch.nn.Module],
ci_scorer: torch.nn.Module,
@@ -105,7 +115,6 @@ class DiarSondModel(FunASRModel):
binary_labels_lengths: torch.Tensor = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Speaker Encoder + CI Scorer + CD Scorer + Decoder + Calc loss
Args:
speech: (Batch, samples) or (Batch, frames, input_size)
speech_lengths: (Batch,) default None for chunk iterator,
@@ -342,7 +351,7 @@ class DiarSondModel(FunASRModel):
cd_simi = torch.reshape(cd_simi, [bb, self.max_spk_num, tt, 1])
cd_simi = cd_simi.squeeze(dim=3).permute([0, 2, 1])
if isinstance(self.ci_scorer, torch.nn.Module):
if isinstance(self.ci_scorer, AbsEncoder):
ci_simi = self.ci_scorer(ge_in, ge_len)[0]
ci_simi = torch.reshape(ci_simi, [bb, self.max_spk_num, tt]).permute([0, 2, 1])
else:
@@ -381,7 +390,6 @@ class DiarSondModel(FunASRModel):
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch,)
@@ -481,4 +489,4 @@ class DiarSondModel(FunASRModel):
speaker_miss,
speaker_falarm,
speaker_error,
)
)
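One behavioral fix in this file: since the constructor declares ci_scorer as a torch.nn.Module, the old isinstance(self.ci_scorer, torch.nn.Module) test was always true and the else branch was dead; testing against AbsEncoder actually distinguishes encoder-style CI scorers from other scorer modules. A sketch of the dispatch pattern (AbsEncoderLike is an illustrative stand-in for AbsEncoder):

from abc import ABC, abstractmethod

import torch


class AbsEncoderLike(torch.nn.Module, ABC):
    @abstractmethod
    def forward(self, x, lengths):
        ...


def score(ci_scorer, x, lengths):
    if isinstance(ci_scorer, AbsEncoderLike):
        # encoder-style scorer: returns (output, output_lengths, ...)
        return ci_scorer(x, lengths)[0]
    # any other nn.Module scorer is applied directly
    return ci_scorer(x)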

View File

@@ -1,3 +1,8 @@
"""
Author: Speech Lab, Alibaba Group, China
"""
import logging
from contextlib import contextmanager
from distutils.version import LooseVersion
@@ -10,11 +15,22 @@ from typing import Union
import torch
from typeguard import check_argument_types
from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
LabelSmoothingLoss, # noqa: H301
)
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.base_model import FunASRModel
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.modules.add_sos_eos import add_sos_eos
from funasr.modules.e2e_asr_common import ErrorCalculator
from funasr.modules.nets_utils import th_accuracy
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.models.base_model import FunASRModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@@ -32,11 +48,11 @@ class ESPnetSVModel(FunASRModel):
self,
vocab_size: int,
token_list: Union[Tuple[str, ...], List[str]],
frontend: Optional[torch.nn.Module],
specaug: Optional[torch.nn.Module],
normalize: Optional[torch.nn.Module],
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
encoder: torch.nn.Module,
encoder: AbsEncoder,
postencoder: Optional[AbsPostEncoder],
pooling_layer: torch.nn.Module,
decoder: AbsDecoder,
@@ -65,7 +81,6 @@ class ESPnetSVModel(FunASRModel):
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -206,7 +221,6 @@ class ESPnetSVModel(FunASRModel):
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -256,4 +270,4 @@ class ESPnetSVModel(FunASRModel):
else:
# No frontend and no feature extract
feats, feats_lengths = speech, speech_lengths
return feats, feats_lengths
return feats, feats_lengths

View File

@@ -2,20 +2,24 @@ import logging
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import torch
import numpy as np
from typeguard import check_argument_types
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.predictor.cif import mae_loss
from funasr.models.base_model import FunASRModel
from funasr.modules.add_sos_eos import add_sos_eos
from funasr.modules.nets_utils import make_pad_mask, pad_list
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.models.base_model import FunASRModel
from funasr.models.predictor.cif import CifPredictorV3
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
else:
@@ -25,15 +29,15 @@ else:
yield
class TimestampPredictor(FunASRModel):
class TimestampPredictor(FunASRModel):
"""
Author: Speech Lab, Alibaba Group, China
Author: Speech Lab of DAMO Academy, Alibaba Group
"""
def __init__(
self,
frontend: Optional[torch.nn.Module],
encoder: torch.nn.Module,
frontend: Optional[AbsFrontend],
encoder: AbsEncoder,
predictor: CifPredictorV3,
predictor_bias: int = 0,
token_list=None,
@@ -51,7 +55,7 @@ class TimestampPredictor(FunASRModel):
self.predictor_bias = predictor_bias
self.criterion_pre = mae_loss()
self.token_list = token_list
def forward(
self,
speech: torch.Tensor,
@@ -60,7 +64,6 @@ class TimestampPredictor(FunASRModel):
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -108,7 +111,6 @@ class TimestampPredictor(FunASRModel):
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -123,7 +125,7 @@ class TimestampPredictor(FunASRModel):
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
return encoder_out, encoder_out_lens
def _extract_feats(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -146,8 +148,8 @@ class TimestampPredictor(FunASRModel):
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
encoder_out_mask,
token_num)
encoder_out_mask,
token_num)
return ds_alphas, ds_cif_peak, us_alphas, us_peaks
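The us_peaks returned above are CIF peak positions on the predictor's upsampled frame grid; turning them into times is a multiplication by the frame shift. A sketch under the assumption of a 10 ms upsampled frame shift (the real shift depends on the frontend and the predictor's upsample rate):

import torch

us_peaks = torch.tensor([12, 48, 95])  # peak frame indices (made-up values)
frame_shift_ms = 10.0                  # assumption, see above
timestamps_ms = (us_peaks.float() * frame_shift_ms).tolist()
# [120.0, 480.0, 950.0]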
def collect_feats(

View File

@@ -17,10 +17,13 @@ from funasr.losses.label_smoothing_loss import (
LabelSmoothingLoss, # noqa: H301
)
from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.models.base_model import FunASRModel
from funasr.modules.streaming_utils.chunk_utilis import sequence_mask
@@ -37,18 +40,18 @@ else:
class UniASR(FunASRModel):
"""
Author: Speech Lab, Alibaba Group, China
Author: Speech Lab of DAMO Academy, Alibaba Group
"""
def __init__(
self,
vocab_size: int,
token_list: Union[Tuple[str, ...], List[str]],
frontend: Optional[torch.nn.Module],
specaug: Optional[torch.nn.Module],
normalize: Optional[torch.nn.Module],
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
encoder: torch.nn.Module,
encoder: AbsEncoder,
postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
@@ -176,7 +179,6 @@ class UniASR(FunASRModel):
decoding_ind: int = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -466,7 +468,6 @@ class UniASR(FunASRModel):
self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -530,7 +531,6 @@ class UniASR(FunASRModel):
ind: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -624,9 +624,7 @@ class UniASR(FunASRModel):
ys_pad_lens: torch.Tensor,
) -> torch.Tensor:
"""Compute negative log likelihood(nll) from transformer-decoder
Normally, this function is called in batchify_nll.
Args:
encoder_out: (Batch, Length, Dim)
encoder_out_lens: (Batch,)
@@ -663,7 +661,6 @@ class UniASR(FunASRModel):
batch_size: int = 100,
):
"""Compute negative log likelihood(nll) from transformer-decoder
To avoid OOM, this fuction seperate the input into batches.
Then call nll for each batch and combine and return results.
Args:
@@ -1069,4 +1066,3 @@ class UniASR(FunASRModel):
ys_hat = self.ctc2.argmax(encoder_out).data
cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
return loss_ctc, cer_ctc

View File

@@ -35,6 +35,12 @@ class VadDetectMode(Enum):
class VADXOptions:
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(
self,
sample_rate: int = 16000,
@@ -99,6 +105,12 @@ class VADXOptions:
class E2EVadSpeechBufWithDoa(object):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(self):
self.start_ms = 0
self.end_ms = 0
@@ -117,6 +129,12 @@ class E2EVadSpeechBufWithDoa(object):
class E2EVadFrameProb(object):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(self):
self.noise_prob = 0.0
self.speech_prob = 0.0
@@ -126,6 +144,12 @@ class E2EVadFrameProb(object):
class WindowDetector(object):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(self, window_size_ms: int, sil_to_speech_time: int,
speech_to_sil_time: int, frame_size_ms: int):
self.window_size_ms = window_size_ms
@@ -192,6 +216,12 @@ class WindowDetector(object):
class E2EVadModel(nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any], frontend=None):
super(E2EVadModel, self).__init__()
self.vad_opts = VADXOptions(**vad_post_args)
@@ -286,7 +316,7 @@ class E2EVadModel(nn.Module):
0.000001))
def ComputeScores(self, feats: torch.Tensor, in_cache: Dict[str, torch.Tensor]) -> None:
scores = self.encoder(feats, in_cache) # return B * T * D
scores = self.encoder(feats, in_cache).to('cpu') # return B * T * D
assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
self.vad_opts.nn_eval_block_size = scores.shape[1]
self.frm_cnt += scores.shape[1] # count total frames
@@ -444,7 +474,7 @@ class E2EVadModel(nn.Module):
- 1)) / self.vad_opts.noise_frame_num_used_for_snr
return frame_state
def forward(self, feats: torch.Tensor, waveform: torch.Tensor, in_cache: Dict[str, torch.Tensor] = dict(),
is_final: bool = False
) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
@@ -460,8 +490,9 @@ class E2EVadModel(nn.Module):
segment_batch = []
if len(self.output_data_buf) > 0:
for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
if not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
i].contain_seg_end_point:
if not is_final and (
not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
i].contain_seg_end_point):
continue
segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
segment_batch.append(segment)
@@ -474,11 +505,11 @@ class E2EVadModel(nn.Module):
return segments, in_cache
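The is_final guard added above holds back segments that are still missing a start or end point during streaming, and flushes them on the last call. In use, the caller simply marks the final chunk; a usage sketch with feature extraction and cache details elided:

def run_vad_sketch(model, feats_chunks, waveform_chunks):
    cache = {}
    all_segments = []
    last = len(feats_chunks) - 1
    for i, (feats, wav) in enumerate(zip(feats_chunks, waveform_chunks)):
        segments, cache = model(feats, wav, in_cache=cache,
                                is_final=(i == last))
        all_segments.extend(segments)
    return all_segments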
def forward_online(self, feats: torch.Tensor, waveform: torch.Tensor, in_cache: Dict[str, torch.Tensor] = dict(),
is_final: bool = False, max_end_sil: int = 800
) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
is_final: bool = False, max_end_sil: int = 800
) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
self.waveform = waveform # compute decibel for each frame
self.ComputeScores(feats, in_cache)
self.ComputeDecibel()
if not is_final: