mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update
This commit is contained in:
parent
6997763bf6
commit
607073619c
@ -12,7 +12,11 @@ from typing import Tuple
|
||||
import torch
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from funasr.layers.abs_normalize import AbsNormalize
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
from funasr.models.frontend.abs_frontend import AbsFrontend
|
||||
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
|
||||
from funasr.models.specaug.abs_specaug import AbsSpecAug
|
||||
from funasr.torch_utils.device_funcs import force_gatherable
|
||||
from funasr.models.base_model import FunASRModel
|
||||
|
||||
@ -30,11 +34,11 @@ class Data2VecPretrainModel(FunASRModel):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
frontend: Optional[torch.nn.Module],
|
||||
specaug: Optional[torch.nn.Module],
|
||||
normalize: Optional[torch.nn.Module],
|
||||
frontend: Optional[AbsFrontend],
|
||||
specaug: Optional[AbsSpecAug],
|
||||
normalize: Optional[AbsNormalize],
|
||||
preencoder: Optional[AbsPreEncoder],
|
||||
encoder: torch.nn.Module,
|
||||
encoder: AbsEncoder,
|
||||
):
|
||||
assert check_argument_types()
|
||||
|
||||
@ -53,7 +57,6 @@ class Data2VecPretrainModel(FunASRModel):
|
||||
speech_lengths: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
"""Frontend + Encoder + Calc loss
|
||||
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
@ -102,7 +105,6 @@ class Data2VecPretrainModel(FunASRModel):
|
||||
speech_lengths: torch.Tensor,
|
||||
):
|
||||
"""Frontend + Encoder.
|
||||
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
|
||||
@ -13,18 +13,22 @@ from typing import Union
|
||||
import torch
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from funasr.layers.abs_normalize import AbsNormalize
|
||||
from funasr.losses.label_smoothing_loss import (
|
||||
LabelSmoothingLoss, # noqa: H301
|
||||
)
|
||||
from funasr.models.ctc import CTC
|
||||
from funasr.models.frontend.abs_frontend import AbsFrontend
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
from funasr.models.decoder.abs_decoder import AbsDecoder
|
||||
from funasr.models.base_model import FunASRModel
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
from funasr.models.frontend.abs_frontend import AbsFrontend
|
||||
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
|
||||
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
|
||||
from funasr.models.specaug.abs_specaug import AbsSpecAug
|
||||
from funasr.modules.add_sos_eos import add_sos_eos
|
||||
from funasr.modules.e2e_asr_common import ErrorCalculator
|
||||
from funasr.modules.nets_utils import th_accuracy
|
||||
from funasr.torch_utils.device_funcs import force_gatherable
|
||||
from funasr.models.base_model import FunASRModel
|
||||
|
||||
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
|
||||
from torch.cuda.amp import autocast
|
||||
@ -43,9 +47,11 @@ class ESPnetASRModel(FunASRModel):
|
||||
vocab_size: int,
|
||||
token_list: Union[Tuple[str, ...], List[str]],
|
||||
frontend: Optional[AbsFrontend],
|
||||
specaug: Optional[torch.nn.Module],
|
||||
normalize: Optional[torch.nn.Module],
|
||||
specaug: Optional[AbsSpecAug],
|
||||
normalize: Optional[AbsNormalize],
|
||||
preencoder: Optional[AbsPreEncoder],
|
||||
encoder: AbsEncoder,
|
||||
postencoder: Optional[AbsPostEncoder],
|
||||
decoder: AbsDecoder,
|
||||
ctc: CTC,
|
||||
ctc_weight: float = 0.5,
|
||||
@ -127,7 +133,6 @@ class ESPnetASRModel(FunASRModel):
|
||||
text_lengths: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
"""Frontend + Encoder + Decoder + Calc loss
|
||||
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
@ -243,7 +248,6 @@ class ESPnetASRModel(FunASRModel):
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Frontend + Encoder. Note that this method is used by asr_inference.py
|
||||
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
@ -325,9 +329,7 @@ class ESPnetASRModel(FunASRModel):
|
||||
ys_pad_lens: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Compute negative log likelihood(nll) from transformer-decoder
|
||||
|
||||
Normally, this function is called in batchify_nll.
|
||||
|
||||
Args:
|
||||
encoder_out: (Batch, Length, Dim)
|
||||
encoder_out_lens: (Batch,)
|
||||
@ -364,7 +366,6 @@ class ESPnetASRModel(FunASRModel):
|
||||
batch_size: int = 100,
|
||||
):
|
||||
"""Compute negative log likelihood(nll) from transformer-decoder
|
||||
|
||||
To avoid OOM, this fuction seperate the input into batches.
|
||||
Then call nll for each batch and combine and return results.
|
||||
Args:
|
||||
|
||||
@ -17,10 +17,13 @@ from funasr.losses.label_smoothing_loss import (
|
||||
)
|
||||
from funasr.models.ctc import CTC
|
||||
from funasr.models.decoder.abs_decoder import AbsDecoder
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
from funasr.models.frontend.abs_frontend import AbsFrontend
|
||||
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
|
||||
from funasr.models.base_model import FunASRModel
|
||||
from funasr.models.specaug.abs_specaug import AbsSpecAug
|
||||
from funasr.layers.abs_normalize import AbsNormalize
|
||||
from funasr.torch_utils.device_funcs import force_gatherable
|
||||
|
||||
from funasr.models.base_model import FunASRModel
|
||||
|
||||
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
|
||||
from torch.cuda.amp import autocast
|
||||
@ -32,30 +35,36 @@ else:
|
||||
import pdb
|
||||
import random
|
||||
import math
|
||||
|
||||
|
||||
class MFCCA(FunASRModel):
|
||||
"""CTC-attention hybrid Encoder-Decoder model"""
|
||||
"""
|
||||
Author: Audio, Speech and Language Processing Group (ASLP@NPU), Northwestern Polytechnical University
|
||||
MFCCA:Multi-Frame Cross-Channel attention for multi-speaker ASR in Multi-party meeting scenario
|
||||
https://arxiv.org/abs/2210.05265
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
token_list: Union[Tuple[str, ...], List[str]],
|
||||
frontend: Optional[torch.nn.Module],
|
||||
specaug: Optional[torch.nn.Module],
|
||||
normalize: Optional[torch.nn.Module],
|
||||
preencoder: Optional[AbsPreEncoder],
|
||||
encoder: torch.nn.Module,
|
||||
decoder: AbsDecoder,
|
||||
ctc: CTC,
|
||||
rnnt_decoder: None,
|
||||
ctc_weight: float = 0.5,
|
||||
ignore_id: int = -1,
|
||||
lsm_weight: float = 0.0,
|
||||
mask_ratio: float = 0.0,
|
||||
length_normalized_loss: bool = False,
|
||||
report_cer: bool = True,
|
||||
report_wer: bool = True,
|
||||
sym_space: str = "<space>",
|
||||
sym_blank: str = "<blank>",
|
||||
self,
|
||||
vocab_size: int,
|
||||
token_list: Union[Tuple[str, ...], List[str]],
|
||||
frontend: Optional[AbsFrontend],
|
||||
specaug: Optional[AbsSpecAug],
|
||||
normalize: Optional[AbsNormalize],
|
||||
preencoder: Optional[AbsPreEncoder],
|
||||
encoder: AbsEncoder,
|
||||
decoder: AbsDecoder,
|
||||
ctc: CTC,
|
||||
rnnt_decoder: None,
|
||||
ctc_weight: float = 0.5,
|
||||
ignore_id: int = -1,
|
||||
lsm_weight: float = 0.0,
|
||||
mask_ratio: float = 0.0,
|
||||
length_normalized_loss: bool = False,
|
||||
report_cer: bool = True,
|
||||
report_wer: bool = True,
|
||||
sym_space: str = "<space>",
|
||||
sym_blank: str = "<blank>",
|
||||
):
|
||||
assert check_argument_types()
|
||||
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
|
||||
@ -69,10 +78,9 @@ class MFCCA(FunASRModel):
|
||||
self.ignore_id = ignore_id
|
||||
self.ctc_weight = ctc_weight
|
||||
self.token_list = token_list.copy()
|
||||
|
||||
|
||||
self.mask_ratio = mask_ratio
|
||||
|
||||
|
||||
self.frontend = frontend
|
||||
self.specaug = specaug
|
||||
self.normalize = normalize
|
||||
@ -106,14 +114,13 @@ class MFCCA(FunASRModel):
|
||||
self.error_calculator = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
"""Frontend + Encoder + Decoder + Calc loss
|
||||
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
@ -123,22 +130,22 @@ class MFCCA(FunASRModel):
|
||||
assert text_lengths.dim() == 1, text_lengths.shape
|
||||
# Check that batch_size is unified
|
||||
assert (
|
||||
speech.shape[0]
|
||||
== speech_lengths.shape[0]
|
||||
== text.shape[0]
|
||||
== text_lengths.shape[0]
|
||||
speech.shape[0]
|
||||
== speech_lengths.shape[0]
|
||||
== text.shape[0]
|
||||
== text_lengths.shape[0]
|
||||
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
|
||||
#pdb.set_trace()
|
||||
if(speech.dim()==3 and speech.size(2)==8 and self.mask_ratio !=0):
|
||||
# pdb.set_trace()
|
||||
if (speech.dim() == 3 and speech.size(2) == 8 and self.mask_ratio != 0):
|
||||
rate_num = random.random()
|
||||
#rate_num = 0.1
|
||||
if(rate_num<=self.mask_ratio):
|
||||
retain_channel = math.ceil(random.random() *8)
|
||||
if(retain_channel>1):
|
||||
speech = speech[:,:,torch.randperm(8)[0:retain_channel].sort().values]
|
||||
# rate_num = 0.1
|
||||
if (rate_num <= self.mask_ratio):
|
||||
retain_channel = math.ceil(random.random() * 8)
|
||||
if (retain_channel > 1):
|
||||
speech = speech[:, :, torch.randperm(8)[0:retain_channel].sort().values]
|
||||
else:
|
||||
speech = speech[:,:,torch.randperm(8)[0]]
|
||||
#pdb.set_trace()
|
||||
speech = speech[:, :, torch.randperm(8)[0]]
|
||||
# pdb.set_trace()
|
||||
batch_size = speech.shape[0]
|
||||
# for data-parallel
|
||||
text = text[:, : text_lengths.max()]
|
||||
@ -188,20 +195,19 @@ class MFCCA(FunASRModel):
|
||||
return loss, stats, weight
|
||||
|
||||
def collect_feats(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
feats, feats_lengths, channel_size = self._extract_feats(speech, speech_lengths)
|
||||
return {"feats": feats, "feats_lengths": feats_lengths}
|
||||
|
||||
def encode(
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Frontend + Encoder. Note that this method is used by asr_inference.py
|
||||
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
@ -220,14 +226,14 @@ class MFCCA(FunASRModel):
|
||||
# Pre-encoder, e.g. used for raw input data
|
||||
if self.preencoder is not None:
|
||||
feats, feats_lengths = self.preencoder(feats, feats_lengths)
|
||||
#pdb.set_trace()
|
||||
# pdb.set_trace()
|
||||
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, channel_size)
|
||||
|
||||
assert encoder_out.size(0) == speech.size(0), (
|
||||
encoder_out.size(),
|
||||
speech.size(0),
|
||||
)
|
||||
if(encoder_out.dim()==4):
|
||||
if (encoder_out.dim() == 4):
|
||||
assert encoder_out.size(2) <= encoder_out_lens.max(), (
|
||||
encoder_out.size(),
|
||||
encoder_out_lens.max(),
|
||||
@ -241,7 +247,7 @@ class MFCCA(FunASRModel):
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
def _extract_feats(
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert speech_lengths.dim() == 1, speech_lengths.shape
|
||||
# for data-parallel
|
||||
@ -259,11 +265,11 @@ class MFCCA(FunASRModel):
|
||||
return feats, feats_lengths, channel_size
|
||||
|
||||
def _calc_att_loss(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
ys_pad: torch.Tensor,
|
||||
ys_pad_lens: torch.Tensor,
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
ys_pad: torch.Tensor,
|
||||
ys_pad_lens: torch.Tensor,
|
||||
):
|
||||
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
|
||||
ys_in_lens = ys_pad_lens + 1
|
||||
@ -291,14 +297,14 @@ class MFCCA(FunASRModel):
|
||||
return loss_att, acc_att, cer_att, wer_att
|
||||
|
||||
def _calc_ctc_loss(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
ys_pad: torch.Tensor,
|
||||
ys_pad_lens: torch.Tensor,
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
ys_pad: torch.Tensor,
|
||||
ys_pad_lens: torch.Tensor,
|
||||
):
|
||||
# Calc CTC loss
|
||||
if(encoder_out.dim()==4):
|
||||
if (encoder_out.dim() == 4):
|
||||
encoder_out = encoder_out.mean(1)
|
||||
loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
|
||||
|
||||
@ -310,10 +316,10 @@ class MFCCA(FunASRModel):
|
||||
return loss_ctc, cer_ctc
|
||||
|
||||
def _calc_rnnt_loss(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
ys_pad: torch.Tensor,
|
||||
ys_pad_lens: torch.Tensor,
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
ys_pad: torch.Tensor,
|
||||
ys_pad_lens: torch.Tensor,
|
||||
):
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError
|
||||
@ -12,23 +12,26 @@ import random
|
||||
import numpy as np
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from funasr.layers.abs_normalize import AbsNormalize
|
||||
from funasr.losses.label_smoothing_loss import (
|
||||
LabelSmoothingLoss, # noqa: H301
|
||||
)
|
||||
from funasr.models.ctc import CTC
|
||||
from funasr.models.decoder.abs_decoder import AbsDecoder
|
||||
from funasr.models.e2e_asr_common import ErrorCalculator
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
from funasr.models.frontend.abs_frontend import AbsFrontend
|
||||
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
|
||||
from funasr.models.predictor.cif import mae_loss
|
||||
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
|
||||
from funasr.models.base_model import FunASRModel
|
||||
from funasr.models.specaug.abs_specaug import AbsSpecAug
|
||||
from funasr.modules.add_sos_eos import add_sos_eos
|
||||
from funasr.modules.nets_utils import make_pad_mask, pad_list
|
||||
from funasr.modules.nets_utils import th_accuracy
|
||||
from funasr.torch_utils.device_funcs import force_gatherable
|
||||
from funasr.models.base_model import FunASRModel
|
||||
from funasr.models.predictor.cif import CifPredictorV3
|
||||
|
||||
|
||||
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
|
||||
from torch.cuda.amp import autocast
|
||||
else:
|
||||
@ -40,7 +43,7 @@ else:
|
||||
|
||||
class Paraformer(FunASRModel):
|
||||
"""
|
||||
Author: Speech Lab, Alibaba Group, China
|
||||
Author: Speech Lab of DAMO Academy, Alibaba Group
|
||||
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
|
||||
https://arxiv.org/abs/2206.08317
|
||||
"""
|
||||
@ -49,10 +52,12 @@ class Paraformer(FunASRModel):
|
||||
self,
|
||||
vocab_size: int,
|
||||
token_list: Union[Tuple[str, ...], List[str]],
|
||||
frontend: Optional[torch.nn.Module],
|
||||
specaug: Optional[torch.nn.Module],
|
||||
normalize: Optional[torch.nn.Module],
|
||||
encoder: torch.nn.Module,
|
||||
frontend: Optional[AbsFrontend],
|
||||
specaug: Optional[AbsSpecAug],
|
||||
normalize: Optional[AbsNormalize],
|
||||
preencoder: Optional[AbsPreEncoder],
|
||||
encoder: AbsEncoder,
|
||||
postencoder: Optional[AbsPostEncoder],
|
||||
decoder: AbsDecoder,
|
||||
ctc: CTC,
|
||||
ctc_weight: float = 0.5,
|
||||
@ -92,8 +97,17 @@ class Paraformer(FunASRModel):
|
||||
self.frontend = frontend
|
||||
self.specaug = specaug
|
||||
self.normalize = normalize
|
||||
self.preencoder = preencoder
|
||||
self.postencoder = postencoder
|
||||
self.encoder = encoder
|
||||
|
||||
if not hasattr(self.encoder, "interctc_use_conditioning"):
|
||||
self.encoder.interctc_use_conditioning = False
|
||||
if self.encoder.interctc_use_conditioning:
|
||||
self.encoder.conditioning_layer = torch.nn.Linear(
|
||||
vocab_size, self.encoder.output_size()
|
||||
)
|
||||
|
||||
self.error_calculator = None
|
||||
|
||||
if ctc_weight == 1.0:
|
||||
@ -138,7 +152,6 @@ class Paraformer(FunASRModel):
|
||||
text_lengths: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
"""Frontend + Encoder + Decoder + Calc loss
|
||||
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
@ -161,7 +174,9 @@ class Paraformer(FunASRModel):
|
||||
|
||||
# 1. Encoder
|
||||
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
|
||||
intermediate_outs = None
|
||||
if isinstance(encoder_out, tuple):
|
||||
intermediate_outs = encoder_out[1]
|
||||
encoder_out = encoder_out[0]
|
||||
|
||||
loss_att, acc_att, cer_att, wer_att = None, None, None, None
|
||||
@ -179,6 +194,30 @@ class Paraformer(FunASRModel):
|
||||
stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
|
||||
stats["cer_ctc"] = cer_ctc
|
||||
|
||||
# Intermediate CTC (optional)
|
||||
loss_interctc = 0.0
|
||||
if self.interctc_weight != 0.0 and intermediate_outs is not None:
|
||||
for layer_idx, intermediate_out in intermediate_outs:
|
||||
# we assume intermediate_out has the same length & padding
|
||||
# as those of encoder_out
|
||||
loss_ic, cer_ic = self._calc_ctc_loss(
|
||||
intermediate_out, encoder_out_lens, text, text_lengths
|
||||
)
|
||||
loss_interctc = loss_interctc + loss_ic
|
||||
|
||||
# Collect Intermedaite CTC stats
|
||||
stats["loss_interctc_layer{}".format(layer_idx)] = (
|
||||
loss_ic.detach() if loss_ic is not None else None
|
||||
)
|
||||
stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
|
||||
|
||||
loss_interctc = loss_interctc / len(intermediate_outs)
|
||||
|
||||
# calculate whole encoder loss
|
||||
loss_ctc = (
|
||||
1 - self.interctc_weight
|
||||
) * loss_ctc + self.interctc_weight * loss_interctc
|
||||
|
||||
# 2b. Attention decoder branch
|
||||
if self.ctc_weight != 1.0:
|
||||
loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
|
||||
@ -229,7 +268,6 @@ class Paraformer(FunASRModel):
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Frontend + Encoder. Note that this method is used by asr_inference.py
|
||||
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
@ -246,8 +284,29 @@ class Paraformer(FunASRModel):
|
||||
if self.normalize is not None:
|
||||
feats, feats_lengths = self.normalize(feats, feats_lengths)
|
||||
|
||||
# Pre-encoder, e.g. used for raw input data
|
||||
if self.preencoder is not None:
|
||||
feats, feats_lengths = self.preencoder(feats, feats_lengths)
|
||||
|
||||
# 4. Forward encoder
|
||||
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
|
||||
# feats: (Batch, Length, Dim)
|
||||
# -> encoder_out: (Batch, Length2, Dim2)
|
||||
if self.encoder.interctc_use_conditioning:
|
||||
encoder_out, encoder_out_lens, _ = self.encoder(
|
||||
feats, feats_lengths, ctc=self.ctc
|
||||
)
|
||||
else:
|
||||
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
|
||||
intermediate_outs = None
|
||||
if isinstance(encoder_out, tuple):
|
||||
intermediate_outs = encoder_out[1]
|
||||
encoder_out = encoder_out[0]
|
||||
|
||||
# Post-encoder, e.g. NLU
|
||||
if self.postencoder is not None:
|
||||
encoder_out, encoder_out_lens = self.postencoder(
|
||||
encoder_out, encoder_out_lens
|
||||
)
|
||||
|
||||
assert encoder_out.size(0) == speech.size(0), (
|
||||
encoder_out.size(),
|
||||
@ -258,45 +317,18 @@ class Paraformer(FunASRModel):
|
||||
encoder_out_lens.max(),
|
||||
)
|
||||
|
||||
if intermediate_outs is not None:
|
||||
return (encoder_out, intermediate_outs), encoder_out_lens
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
def encode_chunk(
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Frontend + Encoder. Note that this method is used by asr_inference.py
|
||||
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
"""
|
||||
with autocast(False):
|
||||
# 1. Extract feats
|
||||
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
|
||||
|
||||
# 2. Data augmentation
|
||||
if self.specaug is not None and self.training:
|
||||
feats, feats_lengths = self.specaug(feats, feats_lengths)
|
||||
|
||||
# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
|
||||
if self.normalize is not None:
|
||||
feats, feats_lengths = self.normalize(feats, feats_lengths)
|
||||
|
||||
# 4. Forward encoder
|
||||
encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"])
|
||||
|
||||
return encoder_out, torch.tensor([encoder_out.size(1)])
|
||||
|
||||
def calc_predictor(self, encoder_out, encoder_out_lens):
|
||||
|
||||
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
|
||||
encoder_out.device)
|
||||
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None, encoder_out_mask,
|
||||
ignore_id=self.ignore_id)
|
||||
return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
|
||||
|
||||
def calc_predictor_chunk(self, encoder_out, cache=None):
|
||||
|
||||
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor.forward_chunk(encoder_out, cache["encoder"])
|
||||
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None,
|
||||
encoder_out_mask,
|
||||
ignore_id=self.ignore_id)
|
||||
return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
|
||||
|
||||
def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
|
||||
@ -308,14 +340,6 @@ class Paraformer(FunASRModel):
|
||||
decoder_out = torch.log_softmax(decoder_out, dim=-1)
|
||||
return decoder_out, ys_pad_lens
|
||||
|
||||
def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
|
||||
decoder_outs = self.decoder.forward_chunk(
|
||||
encoder_out, sematic_embeds, cache["decoder"]
|
||||
)
|
||||
decoder_out = decoder_outs
|
||||
decoder_out = torch.log_softmax(decoder_out, dim=-1)
|
||||
return decoder_out
|
||||
|
||||
def _extract_feats(
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
@ -342,9 +366,7 @@ class Paraformer(FunASRModel):
|
||||
ys_pad_lens: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Compute negative log likelihood(nll) from transformer-decoder
|
||||
|
||||
Normally, this function is called in batchify_nll.
|
||||
|
||||
Args:
|
||||
encoder_out: (Batch, Length, Dim)
|
||||
encoder_out_lens: (Batch,)
|
||||
@ -381,7 +403,6 @@ class Paraformer(FunASRModel):
|
||||
batch_size: int = 100,
|
||||
):
|
||||
"""Compute negative log likelihood(nll) from transformer-decoder
|
||||
|
||||
To avoid OOM, this fuction seperate the input into batches.
|
||||
Then call nll for each batch and combine and return results.
|
||||
Args:
|
||||
@ -521,9 +542,186 @@ class Paraformer(FunASRModel):
|
||||
return loss_ctc, cer_ctc
|
||||
|
||||
|
||||
class ParaformerBert(Paraformer):
|
||||
class ParaformerOnline(Paraformer):
|
||||
"""
|
||||
Author: Speech Lab, Alibaba Group, China
|
||||
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
|
||||
https://arxiv.org/abs/2206.08317
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, *args, **kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
"""Frontend + Encoder + Decoder + Calc loss
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
text: (Batch, Length)
|
||||
text_lengths: (Batch,)
|
||||
"""
|
||||
assert text_lengths.dim() == 1, text_lengths.shape
|
||||
# Check that batch_size is unified
|
||||
assert (
|
||||
speech.shape[0]
|
||||
== speech_lengths.shape[0]
|
||||
== text.shape[0]
|
||||
== text_lengths.shape[0]
|
||||
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
|
||||
batch_size = speech.shape[0]
|
||||
self.step_cur += 1
|
||||
# for data-parallel
|
||||
text = text[:, : text_lengths.max()]
|
||||
speech = speech[:, :speech_lengths.max()]
|
||||
|
||||
# 1. Encoder
|
||||
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
|
||||
intermediate_outs = None
|
||||
if isinstance(encoder_out, tuple):
|
||||
intermediate_outs = encoder_out[1]
|
||||
encoder_out = encoder_out[0]
|
||||
|
||||
loss_att, acc_att, cer_att, wer_att = None, None, None, None
|
||||
loss_ctc, cer_ctc = None, None
|
||||
loss_pre = None
|
||||
stats = dict()
|
||||
|
||||
# 1. CTC branch
|
||||
if self.ctc_weight != 0.0:
|
||||
loss_ctc, cer_ctc = self._calc_ctc_loss(
|
||||
encoder_out, encoder_out_lens, text, text_lengths
|
||||
)
|
||||
|
||||
# Collect CTC branch stats
|
||||
stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
|
||||
stats["cer_ctc"] = cer_ctc
|
||||
|
||||
# Intermediate CTC (optional)
|
||||
loss_interctc = 0.0
|
||||
if self.interctc_weight != 0.0 and intermediate_outs is not None:
|
||||
for layer_idx, intermediate_out in intermediate_outs:
|
||||
# we assume intermediate_out has the same length & padding
|
||||
# as those of encoder_out
|
||||
loss_ic, cer_ic = self._calc_ctc_loss(
|
||||
intermediate_out, encoder_out_lens, text, text_lengths
|
||||
)
|
||||
loss_interctc = loss_interctc + loss_ic
|
||||
|
||||
# Collect Intermedaite CTC stats
|
||||
stats["loss_interctc_layer{}".format(layer_idx)] = (
|
||||
loss_ic.detach() if loss_ic is not None else None
|
||||
)
|
||||
stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
|
||||
|
||||
loss_interctc = loss_interctc / len(intermediate_outs)
|
||||
|
||||
# calculate whole encoder loss
|
||||
loss_ctc = (
|
||||
1 - self.interctc_weight
|
||||
) * loss_ctc + self.interctc_weight * loss_interctc
|
||||
|
||||
# 2b. Attention decoder branch
|
||||
if self.ctc_weight != 1.0:
|
||||
loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
|
||||
encoder_out, encoder_out_lens, text, text_lengths
|
||||
)
|
||||
|
||||
# 3. CTC-Att loss definition
|
||||
if self.ctc_weight == 0.0:
|
||||
loss = loss_att + loss_pre * self.predictor_weight
|
||||
elif self.ctc_weight == 1.0:
|
||||
loss = loss_ctc
|
||||
else:
|
||||
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
|
||||
|
||||
# Collect Attn branch stats
|
||||
stats["loss_att"] = loss_att.detach() if loss_att is not None else None
|
||||
stats["acc"] = acc_att
|
||||
stats["cer"] = cer_att
|
||||
stats["wer"] = wer_att
|
||||
stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
|
||||
|
||||
stats["loss"] = torch.clone(loss.detach())
|
||||
|
||||
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
||||
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
||||
return loss, stats, weight
|
||||
|
||||
def encode_chunk(
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Frontend + Encoder. Note that this method is used by asr_inference.py
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
"""
|
||||
with autocast(False):
|
||||
# 1. Extract feats
|
||||
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
|
||||
|
||||
# 2. Data augmentation
|
||||
if self.specaug is not None and self.training:
|
||||
feats, feats_lengths = self.specaug(feats, feats_lengths)
|
||||
|
||||
# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
|
||||
if self.normalize is not None:
|
||||
feats, feats_lengths = self.normalize(feats, feats_lengths)
|
||||
|
||||
# Pre-encoder, e.g. used for raw input data
|
||||
if self.preencoder is not None:
|
||||
feats, feats_lengths = self.preencoder(feats, feats_lengths)
|
||||
|
||||
# 4. Forward encoder
|
||||
# feats: (Batch, Length, Dim)
|
||||
# -> encoder_out: (Batch, Length2, Dim2)
|
||||
if self.encoder.interctc_use_conditioning:
|
||||
encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(
|
||||
feats, feats_lengths, cache=cache["encoder"], ctc=self.ctc
|
||||
)
|
||||
else:
|
||||
encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"])
|
||||
intermediate_outs = None
|
||||
if isinstance(encoder_out, tuple):
|
||||
intermediate_outs = encoder_out[1]
|
||||
encoder_out = encoder_out[0]
|
||||
|
||||
# Post-encoder, e.g. NLU
|
||||
if self.postencoder is not None:
|
||||
encoder_out, encoder_out_lens = self.postencoder(
|
||||
encoder_out, encoder_out_lens
|
||||
)
|
||||
|
||||
if intermediate_outs is not None:
|
||||
return (encoder_out, intermediate_outs), encoder_out_lens
|
||||
|
||||
return encoder_out, torch.tensor([encoder_out.size(1)])
|
||||
|
||||
def calc_predictor_chunk(self, encoder_out, cache=None):
|
||||
|
||||
pre_acoustic_embeds, pre_token_length = \
|
||||
self.predictor.forward_chunk(encoder_out, cache["encoder"])
|
||||
return pre_acoustic_embeds, pre_token_length
|
||||
|
||||
def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
|
||||
decoder_outs = self.decoder.forward_chunk(
|
||||
encoder_out, sematic_embeds, cache["decoder"]
|
||||
)
|
||||
decoder_out = decoder_outs
|
||||
decoder_out = torch.log_softmax(decoder_out, dim=-1)
|
||||
return decoder_out
|
||||
|
||||
|
||||
class ParaformerBert(Paraformer):
|
||||
"""
|
||||
Author: Speech Lab of DAMO Academy, Alibaba Group
|
||||
Paraformer2: advanced paraformer with LFMMI and bert for non-autoregressive end-to-end speech recognition
|
||||
"""
|
||||
|
||||
@ -531,11 +729,11 @@ class ParaformerBert(Paraformer):
|
||||
self,
|
||||
vocab_size: int,
|
||||
token_list: Union[Tuple[str, ...], List[str]],
|
||||
frontend: Optional[torch.nn.Module],
|
||||
specaug: Optional[torch.nn.Module],
|
||||
normalize: Optional[torch.nn.Module],
|
||||
frontend: Optional[AbsFrontend],
|
||||
specaug: Optional[AbsSpecAug],
|
||||
normalize: Optional[AbsNormalize],
|
||||
preencoder: Optional[AbsPreEncoder],
|
||||
encoder: torch.nn.Module,
|
||||
encoder: AbsEncoder,
|
||||
postencoder: Optional[AbsPostEncoder],
|
||||
decoder: AbsDecoder,
|
||||
ctc: CTC,
|
||||
@ -690,7 +888,6 @@ class ParaformerBert(Paraformer):
|
||||
embed_lengths: torch.Tensor = None,
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
"""Frontend + Encoder + Decoder + Calc loss
|
||||
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
@ -799,74 +996,73 @@ class ParaformerBert(Paraformer):
|
||||
|
||||
|
||||
class BiCifParaformer(Paraformer):
|
||||
|
||||
"""
|
||||
Paraformer model with an extra cif predictor
|
||||
to conduct accurate timestamp prediction
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
token_list: Union[Tuple[str, ...], List[str]],
|
||||
frontend: Optional[torch.nn.Module],
|
||||
specaug: Optional[torch.nn.Module],
|
||||
normalize: Optional[torch.nn.Module],
|
||||
preencoder: Optional[AbsPreEncoder],
|
||||
encoder: torch.nn.Module,
|
||||
postencoder: Optional[AbsPostEncoder],
|
||||
decoder: AbsDecoder,
|
||||
ctc: CTC,
|
||||
ctc_weight: float = 0.5,
|
||||
interctc_weight: float = 0.0,
|
||||
ignore_id: int = -1,
|
||||
blank_id: int = 0,
|
||||
sos: int = 1,
|
||||
eos: int = 2,
|
||||
lsm_weight: float = 0.0,
|
||||
length_normalized_loss: bool = False,
|
||||
report_cer: bool = True,
|
||||
report_wer: bool = True,
|
||||
sym_space: str = "<space>",
|
||||
sym_blank: str = "<blank>",
|
||||
extract_feats_in_collect_stats: bool = True,
|
||||
predictor = None,
|
||||
predictor_weight: float = 0.0,
|
||||
predictor_bias: int = 0,
|
||||
sampling_ratio: float = 0.2,
|
||||
self,
|
||||
vocab_size: int,
|
||||
token_list: Union[Tuple[str, ...], List[str]],
|
||||
frontend: Optional[AbsFrontend],
|
||||
specaug: Optional[AbsSpecAug],
|
||||
normalize: Optional[AbsNormalize],
|
||||
preencoder: Optional[AbsPreEncoder],
|
||||
encoder: AbsEncoder,
|
||||
postencoder: Optional[AbsPostEncoder],
|
||||
decoder: AbsDecoder,
|
||||
ctc: CTC,
|
||||
ctc_weight: float = 0.5,
|
||||
interctc_weight: float = 0.0,
|
||||
ignore_id: int = -1,
|
||||
blank_id: int = 0,
|
||||
sos: int = 1,
|
||||
eos: int = 2,
|
||||
lsm_weight: float = 0.0,
|
||||
length_normalized_loss: bool = False,
|
||||
report_cer: bool = True,
|
||||
report_wer: bool = True,
|
||||
sym_space: str = "<space>",
|
||||
sym_blank: str = "<blank>",
|
||||
extract_feats_in_collect_stats: bool = True,
|
||||
predictor=None,
|
||||
predictor_weight: float = 0.0,
|
||||
predictor_bias: int = 0,
|
||||
sampling_ratio: float = 0.2,
|
||||
):
|
||||
assert check_argument_types()
|
||||
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
|
||||
assert 0.0 <= interctc_weight < 1.0, interctc_weight
|
||||
|
||||
super().__init__(
|
||||
vocab_size=vocab_size,
|
||||
token_list=token_list,
|
||||
frontend=frontend,
|
||||
specaug=specaug,
|
||||
normalize=normalize,
|
||||
preencoder=preencoder,
|
||||
encoder=encoder,
|
||||
postencoder=postencoder,
|
||||
decoder=decoder,
|
||||
ctc=ctc,
|
||||
ctc_weight=ctc_weight,
|
||||
interctc_weight=interctc_weight,
|
||||
ignore_id=ignore_id,
|
||||
blank_id=blank_id,
|
||||
sos=sos,
|
||||
eos=eos,
|
||||
lsm_weight=lsm_weight,
|
||||
length_normalized_loss=length_normalized_loss,
|
||||
report_cer=report_cer,
|
||||
report_wer=report_wer,
|
||||
sym_space=sym_space,
|
||||
sym_blank=sym_blank,
|
||||
extract_feats_in_collect_stats=extract_feats_in_collect_stats,
|
||||
predictor=predictor,
|
||||
predictor_weight=predictor_weight,
|
||||
predictor_bias=predictor_bias,
|
||||
sampling_ratio=sampling_ratio,
|
||||
vocab_size=vocab_size,
|
||||
token_list=token_list,
|
||||
frontend=frontend,
|
||||
specaug=specaug,
|
||||
normalize=normalize,
|
||||
preencoder=preencoder,
|
||||
encoder=encoder,
|
||||
postencoder=postencoder,
|
||||
decoder=decoder,
|
||||
ctc=ctc,
|
||||
ctc_weight=ctc_weight,
|
||||
interctc_weight=interctc_weight,
|
||||
ignore_id=ignore_id,
|
||||
blank_id=blank_id,
|
||||
sos=sos,
|
||||
eos=eos,
|
||||
lsm_weight=lsm_weight,
|
||||
length_normalized_loss=length_normalized_loss,
|
||||
report_cer=report_cer,
|
||||
report_wer=report_wer,
|
||||
sym_space=sym_space,
|
||||
sym_blank=sym_blank,
|
||||
extract_feats_in_collect_stats=extract_feats_in_collect_stats,
|
||||
predictor=predictor,
|
||||
predictor_weight=predictor_weight,
|
||||
predictor_bias=predictor_bias,
|
||||
sampling_ratio=sampling_ratio,
|
||||
)
|
||||
assert isinstance(self.predictor, CifPredictorV3), "BiCifParaformer should use CIFPredictorV3"
|
||||
|
||||
@ -888,21 +1084,77 @@ class BiCifParaformer(Paraformer):
|
||||
loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length2), pre_token_length2)
|
||||
|
||||
return loss_pre2
|
||||
|
||||
|
||||
def _calc_att_loss(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
ys_pad: torch.Tensor,
|
||||
ys_pad_lens: torch.Tensor,
|
||||
):
|
||||
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
|
||||
encoder_out.device)
|
||||
if self.predictor_bias == 1:
|
||||
_, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
|
||||
ys_pad_lens = ys_pad_lens + self.predictor_bias
|
||||
pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad,
|
||||
encoder_out_mask,
|
||||
ignore_id=self.ignore_id)
|
||||
|
||||
# 0. sampler
|
||||
decoder_out_1st = None
|
||||
if self.sampling_ratio > 0.0:
|
||||
if self.step_cur < 2:
|
||||
logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
|
||||
sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
|
||||
pre_acoustic_embeds)
|
||||
else:
|
||||
if self.step_cur < 2:
|
||||
logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
|
||||
sematic_embeds = pre_acoustic_embeds
|
||||
|
||||
# 1. Forward decoder
|
||||
decoder_outs = self.decoder(
|
||||
encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
|
||||
)
|
||||
decoder_out, _ = decoder_outs[0], decoder_outs[1]
|
||||
|
||||
if decoder_out_1st is None:
|
||||
decoder_out_1st = decoder_out
|
||||
# 2. Compute attention loss
|
||||
loss_att = self.criterion_att(decoder_out, ys_pad)
|
||||
acc_att = th_accuracy(
|
||||
decoder_out_1st.view(-1, self.vocab_size),
|
||||
ys_pad,
|
||||
ignore_label=self.ignore_id,
|
||||
)
|
||||
loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
|
||||
|
||||
# Compute cer/wer using attention-decoder
|
||||
if self.training or self.error_calculator is None:
|
||||
cer_att, wer_att = None, None
|
||||
else:
|
||||
ys_hat = decoder_out_1st.argmax(dim=-1)
|
||||
cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
|
||||
|
||||
return loss_att, acc_att, cer_att, wer_att, loss_pre
|
||||
|
||||
def calc_predictor(self, encoder_out, encoder_out_lens):
|
||||
|
||||
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
|
||||
encoder_out.device)
|
||||
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(encoder_out, None, encoder_out_mask,
|
||||
ignore_id=self.ignore_id)
|
||||
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(encoder_out,
|
||||
None,
|
||||
encoder_out_mask,
|
||||
ignore_id=self.ignore_id)
|
||||
return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
|
||||
|
||||
|
||||
def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
|
||||
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
|
||||
encoder_out.device)
|
||||
ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
|
||||
encoder_out_mask,
|
||||
token_num)
|
||||
encoder_out_mask,
|
||||
token_num)
|
||||
return ds_alphas, ds_cif_peak, us_alphas, us_peaks
|
||||
|
||||
def forward(
|
||||
@ -913,7 +1165,6 @@ class BiCifParaformer(Paraformer):
|
||||
text_lengths: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
"""Frontend + Encoder + Decoder + Calc loss
|
||||
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
@ -996,7 +1247,8 @@ class BiCifParaformer(Paraformer):
|
||||
elif self.ctc_weight == 1.0:
|
||||
loss = loss_ctc
|
||||
else:
|
||||
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
|
||||
loss = self.ctc_weight * loss_ctc + (
|
||||
1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
|
||||
|
||||
# Collect Attn branch stats
|
||||
stats["loss_att"] = loss_att.detach() if loss_att is not None else None
|
||||
@ -1022,11 +1274,11 @@ class ContextualParaformer(Paraformer):
|
||||
self,
|
||||
vocab_size: int,
|
||||
token_list: Union[Tuple[str, ...], List[str]],
|
||||
frontend: Optional[torch.nn.Module],
|
||||
specaug: Optional[torch.nn.Module],
|
||||
normalize: Optional[torch.nn.Module],
|
||||
frontend: Optional[AbsFrontend],
|
||||
specaug: Optional[AbsSpecAug],
|
||||
normalize: Optional[AbsNormalize],
|
||||
preencoder: Optional[AbsPreEncoder],
|
||||
encoder: torch.nn.Module,
|
||||
encoder: AbsEncoder,
|
||||
postencoder: Optional[AbsPostEncoder],
|
||||
decoder: AbsDecoder,
|
||||
ctc: CTC,
|
||||
@ -1120,7 +1372,6 @@ class ContextualParaformer(Paraformer):
|
||||
text_lengths: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
"""Frontend + Encoder + Decoder + Calc loss
|
||||
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
@ -1504,4 +1755,4 @@ class ContextualParaformer(Paraformer):
|
||||
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
|
||||
var_dict_tf[name_tf].shape))
|
||||
|
||||
return var_dict_torch_update
|
||||
return var_dict_torch_update
|
||||
@ -15,8 +15,8 @@ from funasr.models.frontend.wav_frontend import WavFrontendMel23
|
||||
from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
|
||||
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
|
||||
from funasr.modules.eend_ola.utils.power import generate_mapping_dict
|
||||
from funasr.models.base_model import FunASRModel
|
||||
from funasr.torch_utils.device_funcs import force_gatherable
|
||||
from funasr.models.base_model import FunASRModel
|
||||
|
||||
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
|
||||
pass
|
||||
@ -91,7 +91,6 @@ class DiarEENDOLAModel(FunASRModel):
|
||||
text_lengths: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
"""Frontend + Encoder + Decoder + Calc loss
|
||||
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
|
||||
@ -14,9 +14,15 @@ import torch
|
||||
from torch.nn import functional as F
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from funasr.modules.nets_utils import to_device
|
||||
from funasr.modules.nets_utils import make_pad_mask
|
||||
from funasr.models.base_model import FunASRModel
|
||||
from funasr.models.decoder.abs_decoder import AbsDecoder
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
from funasr.models.frontend.abs_frontend import AbsFrontend
|
||||
from funasr.models.specaug.abs_specaug import AbsSpecAug
|
||||
from funasr.layers.abs_normalize import AbsNormalize
|
||||
from funasr.torch_utils.device_funcs import force_gatherable
|
||||
from funasr.models.base_model import FunASRModel
|
||||
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss, SequenceBinaryCrossEntropy
|
||||
from funasr.utils.misc import int2vec
|
||||
|
||||
@ -30,16 +36,20 @@ else:
|
||||
|
||||
|
||||
class DiarSondModel(FunASRModel):
|
||||
"""Speaker overlap-aware neural diarization model
|
||||
reference: https://arxiv.org/abs/2211.10243
|
||||
"""
|
||||
Author: Speech Lab, Alibaba Group, China
|
||||
SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
|
||||
https://arxiv.org/abs/2211.10243
|
||||
TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
|
||||
https://arxiv.org/abs/2303.05397
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
frontend: Optional[torch.nn.Module],
|
||||
specaug: Optional[torch.nn.Module],
|
||||
normalize: Optional[torch.nn.Module],
|
||||
frontend: Optional[AbsFrontend],
|
||||
specaug: Optional[AbsSpecAug],
|
||||
normalize: Optional[AbsNormalize],
|
||||
encoder: torch.nn.Module,
|
||||
speaker_encoder: Optional[torch.nn.Module],
|
||||
ci_scorer: torch.nn.Module,
|
||||
@ -105,7 +115,6 @@ class DiarSondModel(FunASRModel):
|
||||
binary_labels_lengths: torch.Tensor = None,
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
"""Frontend + Encoder + Speaker Encoder + CI Scorer + CD Scorer + Decoder + Calc loss
|
||||
|
||||
Args:
|
||||
speech: (Batch, samples) or (Batch, frames, input_size)
|
||||
speech_lengths: (Batch,) default None for chunk interator,
|
||||
@ -342,7 +351,7 @@ class DiarSondModel(FunASRModel):
|
||||
cd_simi = torch.reshape(cd_simi, [bb, self.max_spk_num, tt, 1])
|
||||
cd_simi = cd_simi.squeeze(dim=3).permute([0, 2, 1])
|
||||
|
||||
if isinstance(self.ci_scorer, torch.nn.Module):
|
||||
if isinstance(self.ci_scorer, AbsEncoder):
|
||||
ci_simi = self.ci_scorer(ge_in, ge_len)[0]
|
||||
ci_simi = torch.reshape(ci_simi, [bb, self.max_spk_num, tt]).permute([0, 2, 1])
|
||||
else:
|
||||
@ -381,7 +390,6 @@ class DiarSondModel(FunASRModel):
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Frontend + Encoder
|
||||
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch,)
|
||||
@ -481,4 +489,4 @@ class DiarSondModel(FunASRModel):
|
||||
speaker_miss,
|
||||
speaker_falarm,
|
||||
speaker_error,
|
||||
)
|
||||
)
|
||||
@ -1,3 +1,8 @@
|
||||
|
||||
"""
|
||||
Author: Speech Lab, Alibaba Group, China
|
||||
"""
|
||||
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from distutils.version import LooseVersion
|
||||
@ -10,11 +15,22 @@ from typing import Union
|
||||
import torch
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from funasr.layers.abs_normalize import AbsNormalize
|
||||
from funasr.losses.label_smoothing_loss import (
|
||||
LabelSmoothingLoss, # noqa: H301
|
||||
)
|
||||
from funasr.models.ctc import CTC
|
||||
from funasr.models.decoder.abs_decoder import AbsDecoder
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
from funasr.models.frontend.abs_frontend import AbsFrontend
|
||||
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
|
||||
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
|
||||
from funasr.models.base_model import FunASRModel
|
||||
from funasr.models.specaug.abs_specaug import AbsSpecAug
|
||||
from funasr.modules.add_sos_eos import add_sos_eos
|
||||
from funasr.modules.e2e_asr_common import ErrorCalculator
|
||||
from funasr.modules.nets_utils import th_accuracy
|
||||
from funasr.torch_utils.device_funcs import force_gatherable
|
||||
from funasr.models.base_model import FunASRModel
|
||||
|
||||
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
|
||||
from torch.cuda.amp import autocast
|
||||
@ -32,11 +48,11 @@ class ESPnetSVModel(FunASRModel):
|
||||
self,
|
||||
vocab_size: int,
|
||||
token_list: Union[Tuple[str, ...], List[str]],
|
||||
frontend: Optional[torch.nn.Module],
|
||||
specaug: Optional[torch.nn.Module],
|
||||
normalize: Optional[torch.nn.Module],
|
||||
frontend: Optional[AbsFrontend],
|
||||
specaug: Optional[AbsSpecAug],
|
||||
normalize: Optional[AbsNormalize],
|
||||
preencoder: Optional[AbsPreEncoder],
|
||||
encoder: torch.nn.Module,
|
||||
encoder: AbsEncoder,
|
||||
postencoder: Optional[AbsPostEncoder],
|
||||
pooling_layer: torch.nn.Module,
|
||||
decoder: AbsDecoder,
|
||||
@ -65,7 +81,6 @@ class ESPnetSVModel(FunASRModel):
|
||||
text_lengths: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
"""Frontend + Encoder + Decoder + Calc loss
|
||||
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
@ -206,7 +221,6 @@ class ESPnetSVModel(FunASRModel):
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Frontend + Encoder. Note that this method is used by asr_inference.py
|
||||
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
@ -256,4 +270,4 @@ class ESPnetSVModel(FunASRModel):
|
||||
else:
|
||||
# No frontend and no feature extract
|
||||
feats, feats_lengths = speech, speech_lengths
|
||||
return feats, feats_lengths
|
||||
return feats, feats_lengths
|
||||
@ -2,20 +2,24 @@ import logging
|
||||
from contextlib import contextmanager
|
||||
from distutils.version import LooseVersion
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
from funasr.models.frontend.abs_frontend import AbsFrontend
|
||||
from funasr.models.predictor.cif import mae_loss
|
||||
from funasr.models.base_model import FunASRModel
|
||||
from funasr.modules.add_sos_eos import add_sos_eos
|
||||
from funasr.modules.nets_utils import make_pad_mask, pad_list
|
||||
from funasr.torch_utils.device_funcs import force_gatherable
|
||||
from funasr.models.base_model import FunASRModel
|
||||
from funasr.models.predictor.cif import CifPredictorV3
|
||||
|
||||
|
||||
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
|
||||
from torch.cuda.amp import autocast
|
||||
else:
|
||||
@ -25,15 +29,15 @@ else:
|
||||
yield
|
||||
|
||||
|
||||
class TimestampPredictor(FunASRModel):
|
||||
class TimestampPredictor(AbsESPnetModel):
|
||||
"""
|
||||
Author: Speech Lab, Alibaba Group, China
|
||||
Author: Speech Lab of DAMO Academy, Alibaba Group
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
frontend: Optional[torch.nn.Module],
|
||||
encoder: torch.nn.Module,
|
||||
frontend: Optional[AbsFrontend],
|
||||
encoder: AbsEncoder,
|
||||
predictor: CifPredictorV3,
|
||||
predictor_bias: int = 0,
|
||||
token_list=None,
|
||||
@ -51,7 +55,7 @@ class TimestampPredictor(FunASRModel):
|
||||
self.predictor_bias = predictor_bias
|
||||
self.criterion_pre = mae_loss()
|
||||
self.token_list = token_list
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
@ -60,7 +64,6 @@ class TimestampPredictor(FunASRModel):
|
||||
text_lengths: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
"""Frontend + Encoder + Decoder + Calc loss
|
||||
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
@ -108,7 +111,6 @@ class TimestampPredictor(FunASRModel):
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Frontend + Encoder. Note that this method is used by asr_inference.py
|
||||
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
@ -123,7 +125,7 @@ class TimestampPredictor(FunASRModel):
|
||||
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
|
||||
def _extract_feats(
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
@ -146,8 +148,8 @@ class TimestampPredictor(FunASRModel):
|
||||
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
|
||||
encoder_out.device)
|
||||
ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
|
||||
encoder_out_mask,
|
||||
token_num)
|
||||
encoder_out_mask,
|
||||
token_num)
|
||||
return ds_alphas, ds_cif_peak, us_alphas, us_peaks
|
||||
|
||||
def collect_feats(
|
||||
|
||||
@ -17,10 +17,13 @@ from funasr.losses.label_smoothing_loss import (
|
||||
LabelSmoothingLoss, # noqa: H301
|
||||
)
|
||||
from funasr.models.ctc import CTC
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
from funasr.models.decoder.abs_decoder import AbsDecoder
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
from funasr.models.frontend.abs_frontend import AbsFrontend
|
||||
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
|
||||
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
|
||||
from funasr.models.specaug.abs_specaug import AbsSpecAug
|
||||
from funasr.layers.abs_normalize import AbsNormalize
|
||||
from funasr.torch_utils.device_funcs import force_gatherable
|
||||
from funasr.models.base_model import FunASRModel
|
||||
from funasr.modules.streaming_utils.chunk_utilis import sequence_mask
|
||||
@ -37,18 +40,18 @@ else:
|
||||
|
||||
class UniASR(FunASRModel):
|
||||
"""
|
||||
Author: Speech Lab, Alibaba Group, China
|
||||
Author: Speech Lab of DAMO Academy, Alibaba Group
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
token_list: Union[Tuple[str, ...], List[str]],
|
||||
frontend: Optional[torch.nn.Module],
|
||||
specaug: Optional[torch.nn.Module],
|
||||
normalize: Optional[torch.nn.Module],
|
||||
frontend: Optional[AbsFrontend],
|
||||
specaug: Optional[AbsSpecAug],
|
||||
normalize: Optional[AbsNormalize],
|
||||
preencoder: Optional[AbsPreEncoder],
|
||||
encoder: torch.nn.Module,
|
||||
encoder: AbsEncoder,
|
||||
postencoder: Optional[AbsPostEncoder],
|
||||
decoder: AbsDecoder,
|
||||
ctc: CTC,
|
||||
@ -176,7 +179,6 @@ class UniASR(FunASRModel):
|
||||
decoding_ind: int = None,
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
"""Frontend + Encoder + Decoder + Calc loss
|
||||
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
@ -466,7 +468,6 @@ class UniASR(FunASRModel):
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Frontend + Encoder. Note that this method is used by asr_inference.py
|
||||
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
@ -530,7 +531,6 @@ class UniASR(FunASRModel):
|
||||
ind: int = 0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Frontend + Encoder. Note that this method is used by asr_inference.py
|
||||
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
@ -624,9 +624,7 @@ class UniASR(FunASRModel):
|
||||
ys_pad_lens: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Compute negative log likelihood(nll) from transformer-decoder
|
||||
|
||||
Normally, this function is called in batchify_nll.
|
||||
|
||||
Args:
|
||||
encoder_out: (Batch, Length, Dim)
|
||||
encoder_out_lens: (Batch,)
|
||||
@ -663,7 +661,6 @@ class UniASR(FunASRModel):
|
||||
batch_size: int = 100,
|
||||
):
|
||||
"""Compute negative log likelihood(nll) from transformer-decoder
|
||||
|
||||
To avoid OOM, this fuction seperate the input into batches.
|
||||
Then call nll for each batch and combine and return results.
|
||||
Args:
|
||||
@ -1069,4 +1066,3 @@ class UniASR(FunASRModel):
|
||||
ys_hat = self.ctc2.argmax(encoder_out).data
|
||||
cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
|
||||
return loss_ctc, cer_ctc
|
||||
|
||||
|
||||
@ -35,6 +35,12 @@ class VadDetectMode(Enum):
|
||||
|
||||
|
||||
class VADXOptions:
|
||||
"""
|
||||
Author: Speech Lab of DAMO Academy, Alibaba Group
|
||||
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
|
||||
https://arxiv.org/abs/1803.05030
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sample_rate: int = 16000,
|
||||
@ -99,6 +105,12 @@ class VADXOptions:
|
||||
|
||||
|
||||
class E2EVadSpeechBufWithDoa(object):
|
||||
"""
|
||||
Author: Speech Lab of DAMO Academy, Alibaba Group
|
||||
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
|
||||
https://arxiv.org/abs/1803.05030
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.start_ms = 0
|
||||
self.end_ms = 0
|
||||
@ -117,6 +129,12 @@ class E2EVadSpeechBufWithDoa(object):
|
||||
|
||||
|
||||
class E2EVadFrameProb(object):
|
||||
"""
|
||||
Author: Speech Lab of DAMO Academy, Alibaba Group
|
||||
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
|
||||
https://arxiv.org/abs/1803.05030
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.noise_prob = 0.0
|
||||
self.speech_prob = 0.0
|
||||
@ -126,6 +144,12 @@ class E2EVadFrameProb(object):
|
||||
|
||||
|
||||
class WindowDetector(object):
|
||||
"""
|
||||
Author: Speech Lab of DAMO Academy, Alibaba Group
|
||||
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
|
||||
https://arxiv.org/abs/1803.05030
|
||||
"""
|
||||
|
||||
def __init__(self, window_size_ms: int, sil_to_speech_time: int,
|
||||
speech_to_sil_time: int, frame_size_ms: int):
|
||||
self.window_size_ms = window_size_ms
|
||||
@ -192,6 +216,12 @@ class WindowDetector(object):
|
||||
|
||||
|
||||
class E2EVadModel(nn.Module):
|
||||
"""
|
||||
Author: Speech Lab of DAMO Academy, Alibaba Group
|
||||
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
|
||||
https://arxiv.org/abs/1803.05030
|
||||
"""
|
||||
|
||||
def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any], frontend=None):
|
||||
super(E2EVadModel, self).__init__()
|
||||
self.vad_opts = VADXOptions(**vad_post_args)
|
||||
@ -286,7 +316,7 @@ class E2EVadModel(nn.Module):
|
||||
0.000001))
|
||||
|
||||
def ComputeScores(self, feats: torch.Tensor, in_cache: Dict[str, torch.Tensor]) -> None:
|
||||
scores = self.encoder(feats, in_cache) # return B * T * D
|
||||
scores = self.encoder(feats, in_cache).to('cpu') # return B * T * D
|
||||
assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
|
||||
self.vad_opts.nn_eval_block_size = scores.shape[1]
|
||||
self.frm_cnt += scores.shape[1] # count total frames
|
||||
@ -444,7 +474,7 @@ class E2EVadModel(nn.Module):
|
||||
- 1)) / self.vad_opts.noise_frame_num_used_for_snr
|
||||
|
||||
return frame_state
|
||||
|
||||
|
||||
def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
|
||||
is_final: bool = False
|
||||
) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
|
||||
@ -460,8 +490,9 @@ class E2EVadModel(nn.Module):
|
||||
segment_batch = []
|
||||
if len(self.output_data_buf) > 0:
|
||||
for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
|
||||
if not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
|
||||
i].contain_seg_end_point:
|
||||
if not is_final and (
|
||||
not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
|
||||
i].contain_seg_end_point):
|
||||
continue
|
||||
segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
|
||||
segment_batch.append(segment)
|
||||
@ -474,11 +505,11 @@ class E2EVadModel(nn.Module):
|
||||
return segments, in_cache
|
||||
|
||||
def forward_online(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
|
||||
is_final: bool = False, max_end_sil: int = 800
|
||||
) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
|
||||
is_final: bool = False, max_end_sil: int = 800
|
||||
) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
|
||||
self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
|
||||
self.waveform = waveform # compute decibel for each frame
|
||||
|
||||
|
||||
self.ComputeScores(feats, in_cache)
|
||||
self.ComputeDecibel()
|
||||
if not is_final:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user