diff --git a/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml b/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml
index 59f993636..a1f27a3dd 100644
--- a/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml
+++ b/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml
@@ -6,7 +6,7 @@ encoder_conf:
     unified_model_training: true
     default_chunk_size: 16
     jitter_range: 4
-    left_chunk_size: 0
+    left_chunk_size: 1
     embed_vgg_like: false
     subsampling_factor: 4
     linear_units: 2048
@@ -51,7 +51,7 @@ use_amp: true
 # optimization related
 accum_grad: 1
 grad_clip: 5
-max_epoch: 200
+max_epoch: 120
 val_scheduler_criterion:
     - valid
     - loss
diff --git a/funasr/fileio/sound_scp.py b/funasr/fileio/sound_scp.py
index c752fe618..9b25fe577 100644
--- a/funasr/fileio/sound_scp.py
+++ b/funasr/fileio/sound_scp.py
@@ -1,6 +1,6 @@
 import collections.abc
 from pathlib import Path
-from typing import Union
+from typing import List, Tuple, Union
 import random
 
 import numpy as np
@@ -13,6 +13,74 @@ import torchaudio
 from funasr.fileio.read_text import read_2column_text
 
 
+def soundfile_read(
+    wavs: Union[str, List[str]],
+    dtype=None,
+    always_2d: bool = False,
+    concat_axis: int = 1,
+    start: int = 0,
+    end: int = None,
+    return_subtype: bool = False,
+) -> Tuple[np.array, int]:
+    if isinstance(wavs, str):
+        wavs = [wavs]
+
+    arrays = []
+    subtypes = []
+    prev_rate = None
+    prev_wav = None
+    for wav in wavs:
+        with soundfile.SoundFile(wav) as f:
+            f.seek(start)
+            if end is not None:
+                frames = end - start
+            else:
+                frames = -1
+            if dtype == "float16":
+                array = f.read(
+                    frames,
+                    dtype="float32",
+                    always_2d=always_2d,
+                ).astype(dtype)
+            else:
+                array = f.read(frames, dtype=dtype, always_2d=always_2d)
+            rate = f.samplerate
+            subtype = f.subtype
+            subtypes.append(subtype)
+
+        if len(wavs) > 1 and array.ndim == 1 and concat_axis == 1:
+            # array: (Time, Channel)
+            array = array[:, None]
+
+        if prev_wav is not None:
+            if prev_rate != rate:
+                raise RuntimeError(
+                    f"'{prev_wav}' and '{wav}' have mismatched sampling rate: "
+                    f"{prev_rate} != {rate}"
+                )
+
+            dim1 = arrays[0].shape[1 - concat_axis]
+            dim2 = array.shape[1 - concat_axis]
+            if dim1 != dim2:
+                raise RuntimeError(
+                    "Shapes must match with "
+                    f"{1 - concat_axis} axis, but got {dim1} and {dim2}"
+                )
+
+        prev_rate = rate
+        prev_wav = wav
+        arrays.append(array)
+
+    if len(arrays) == 1:
+        array = arrays[0]
+    else:
+        array = np.concatenate(arrays, axis=concat_axis)
+
+    if return_subtype:
+        return array, rate, subtypes
+    else:
+        return array, rate
+
 
 class SoundScpReader(collections.abc.Mapping):
     """Reader class for 'wav.scp'.
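Reviewer note: a minimal usage sketch of the new soundfile_read helper added above (the wav paths are placeholders, not files from the repo). A single path returns (samples, rate); a list of paths is concatenated along concat_axis after checking that sampling rates and the non-concatenation dimension agree.

    from funasr.fileio.sound_scp import soundfile_read

    # single file: returns (samples, sampling_rate); path is hypothetical
    array, rate = soundfile_read("utt1.wav", dtype="float32")

    # two files stacked as channels (concat_axis=1); they must share the
    # same sampling rate and the same number of frames
    stereo, rate = soundfile_read(
        ["left.wav", "right.wav"], dtype="float32", always_2d=True, concat_axis=1
    )

    # read the sample range [start, end) and also return the soundfile subtypes
    seg, rate, subtypes = soundfile_read(
        "utt1.wav", start=1600, end=16000, return_subtype=True
    )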
diff --git a/funasr/models/encoder/conformer_encoder.py b/funasr/models/encoder/conformer_encoder.py
index 5f20deec4..994607fb7 100644
--- a/funasr/models/encoder/conformer_encoder.py
+++ b/funasr/models/encoder/conformer_encoder.py
@@ -1081,7 +1081,10 @@ class ConformerChunkEncoder(AbsEncoder):
         mask = make_source_mask(x_len).to(x.device)
 
         if self.unified_model_training:
-            chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
+            if self.training:
+                chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
+            else:
+                chunk_size = self.default_chunk_size
             x, mask = self.embed(x, mask, chunk_size)
             pos_enc = self.pos_enc(x)
             chunk_mask = make_chunk_mask(
@@ -1113,12 +1116,15 @@ class ConformerChunkEncoder(AbsEncoder):
         elif self.dynamic_chunk_training:
             max_len = x.size(1)
 
-            chunk_size = torch.randint(1, max_len, (1,)).item()
+            if self.training:
+                chunk_size = torch.randint(1, max_len, (1,)).item()
 
-            if chunk_size > (max_len * self.short_chunk_threshold):
-                chunk_size = max_len
+                if chunk_size > (max_len * self.short_chunk_threshold):
+                    chunk_size = max_len
+                else:
+                    chunk_size = (chunk_size % self.short_chunk_size) + 1
             else:
-                chunk_size = (chunk_size % self.short_chunk_size) + 1
+                chunk_size = self.default_chunk_size
 
             x, mask = self.embed(x, mask, chunk_size)
             pos_enc = self.pos_enc(x)
@@ -1147,6 +1153,45 @@ class ConformerChunkEncoder(AbsEncoder):
 
         return x, olens, None
 
+    def full_utt_forward(
+        self,
+        x: torch.Tensor,
+        x_len: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encode input sequences.
+        Args:
+            x: Encoder input features. (B, T_in, F)
+            x_len: Encoder input features lengths. (B,)
+        Returns:
+            x: Encoder outputs. (B, T_out, D_enc)
+            x_len: Encoder outputs lengths. (B,)
+        """
+        short_status, limit_size = check_short_utt(
+            self.embed.subsampling_factor, x.size(1)
+        )
+
+        if short_status:
+            raise TooShortUttError(
+                f"has {x.size(1)} frames and is too short for subsampling "
+                + f"(it needs more than {limit_size} frames), return empty results",
+                x.size(1),
+                limit_size,
+            )
+
+        mask = make_source_mask(x_len).to(x.device)
+        x, mask = self.embed(x, mask, None)
+        pos_enc = self.pos_enc(x)
+        x_utt = self.encoders(
+            x,
+            pos_enc,
+            mask,
+            chunk_mask=None,
+        )
+
+        if self.time_reduction_factor > 1:
+            x_utt = x_utt[:,::self.time_reduction_factor,:]
+        return x_utt
+
     def simu_chunk_forward(
         self,
         x: torch.Tensor,
diff --git a/funasr/modules/subsampling.py b/funasr/modules/subsampling.py
index a2b91a7f8..77aa422a2 100644
--- a/funasr/modules/subsampling.py
+++ b/funasr/modules/subsampling.py
@@ -427,6 +427,7 @@ class StreamingConvInput(torch.nn.Module):
         conv_size: Union[int, Tuple],
         subsampling_factor: int = 4,
         vgg_like: bool = True,
+        conv_kernel_size: int = 3,
         output_size: Optional[int] = None,
     ) -> None:
         """Construct a ConvInput object."""
@@ -436,14 +437,14 @@ class StreamingConvInput(torch.nn.Module):
             conv_size1, conv_size2 = conv_size
 
             self.conv = torch.nn.Sequential(
-                torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
+                torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                 torch.nn.ReLU(),
-                torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
+                torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                 torch.nn.ReLU(),
                 torch.nn.MaxPool2d((1, 2)),
-                torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
+                torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                 torch.nn.ReLU(),
-                torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
+                torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                 torch.nn.ReLU(),
                 torch.nn.MaxPool2d((1, 2)),
             )
@@ -462,14 +463,14 @@ class StreamingConvInput(torch.nn.Module):
                 kernel_1 = int(subsampling_factor / 2)
 
                 self.conv = torch.nn.Sequential(
-                    torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
+                    torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                     torch.nn.ReLU(),
-                    torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
+                    torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                     torch.nn.ReLU(),
                     torch.nn.MaxPool2d((kernel_1, 2)),
-                    torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
+                    torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                     torch.nn.ReLU(),
-                    torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
+                    torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                     torch.nn.ReLU(),
                     torch.nn.MaxPool2d((2, 2)),
                 )
@@ -487,14 +488,14 @@ class StreamingConvInput(torch.nn.Module):
             self.conv = torch.nn.Sequential(
                 torch.nn.Conv2d(1, conv_size, 3, [1,2], [1,0]),
                 torch.nn.ReLU(),
-                torch.nn.Conv2d(conv_size, conv_size, 3, [1,2], [1,0]),
+                torch.nn.Conv2d(conv_size, conv_size, conv_kernel_size, [1,2], [1,0]),
                 torch.nn.ReLU(),
             )
 
             output_proj = conv_size * (((input_size - 1) // 2 - 1) // 2)
 
         self.subsampling_factor = subsampling_factor
-        self.kernel_2 = 3
+        self.kernel_2 = conv_kernel_size
         self.stride_2 = 1
 
         self.create_new_mask = self.create_new_conv2d_mask
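Reviewer note: a minimal sketch of the chunk-size behaviour this patch introduces in ConformerChunkEncoder.forward for unified model training. The standalone helper name pick_chunk_size is illustrative only; the patch inlines this logic in the encoder. During training the chunk size is jittered around default_chunk_size, while in eval mode the configured default is used so streaming decoding is deterministic.

    import torch

    def pick_chunk_size(default_chunk_size: int, jitter_range: int, training: bool) -> int:
        # hypothetical standalone version of the logic added to the encoder
        if training:
            # jitter uniformly in [default - jitter_range, default + jitter_range]
            return default_chunk_size + torch.randint(
                -jitter_range, jitter_range + 1, (1,)
            ).item()
        # evaluation/inference: fixed chunk size matching the config default
        return default_chunk_size

    # with the config above (default_chunk_size: 16, jitter_range: 4),
    # training sees chunk sizes in [12, 20]; eval always uses 16
    print(pick_chunk_size(16, 4, training=True))
    print(pick_chunk_size(16, 4, training=False))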