diff --git a/examples/common_voice/whisper_lid/demo_funasr.py b/examples/common_voice/whisper_lid/demo_funasr.py new file mode 100644 index 000000000..9af790e1a --- /dev/null +++ b/examples/common_voice/whisper_lid/demo_funasr.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) + +from funasr import AutoModel + +multilingual_wavs = [ + "example_zh-CN.mp3", + "example_en.mp3", + "example_ja.mp3", + "example_ko.mp3", +] + +model = AutoModel(model="iic/speech_whisper-large_lid_multilingual_pytorch", model_revision="v2.0.4") +for wav_id in multilingual_wavs: + wav_file = f"{model.model_path}/examples/{wav_id}" + res = model.generate(input=wav_file, data_type="sound", inference_clip_length=250) + print("detect sample {}: {}".format(wav_id, res)) \ No newline at end of file diff --git a/examples/common_voice/whisper_lid/demo_modelscope.py b/examples/common_voice/whisper_lid/demo_modelscope.py new file mode 100644 index 000000000..cce389a1c --- /dev/null +++ b/examples/common_voice/whisper_lid/demo_modelscope.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks + +multilingual_wavs=[ + "https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_zh-CN.mp3", + "https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_en.mp3", + "https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_ja.mp3", + "https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_ko.mp3", +] + +inference_pipeline = pipeline( + task=Tasks.auto_speech_recognition, + model='iic/speech_whisper-large_lid_multilingual_pytorch', model_revision="v2.0.4") + +for wav in multilingual_wavs: + rec_result = inference_pipeline(input=wav, inference_clip_length=250) + print(rec_result) \ No newline at end of file diff --git a/funasr/frontends/whisper_frontend.py b/funasr/frontends/whisper_frontend.py new file mode 100644 index 000000000..752fd20d2 --- /dev/null +++ b/funasr/frontends/whisper_frontend.py @@ -0,0 +1,102 @@ +from typing import Tuple +import torch +import torch.nn as nn +import whisper +from whisper.audio import HOP_LENGTH, N_FFT, N_SAMPLES +from funasr.register import tables +from torch.nn.utils.rnn import pad_sequence + + +@tables.register("frontend_classes", "WhisperFrontend") +class WhisperFrontend(nn.Module): + """Speech Representation Using Encoder Outputs from OpenAI's Whisper Model: + + URL: https://github.com/openai/whisper + """ + + def __init__( + self, + fs: int = 16000, + whisper_model: str = "large-v3", + do_pad_trim: bool = True, + ): + super().__init__() + assert fs == 16000 + self.fs = fs + + self.n_fft = N_FFT + self.win_length = N_FFT + self.hop_length = HOP_LENGTH + self.pad_samples = N_SAMPLES + self.frame_shift = self.hop_length + self.lfr_n = 1 + if whisper_model == "large-v3" or whisper_model == "large": + self.n_mels = 128 + else: + self.n_mels = 80 + + self.mel_filters = whisper.audio.mel_filters + self.do_pad_trim = do_pad_trim + if do_pad_trim: + self.pad_or_trim = whisper.pad_or_trim + + assert whisper_model in whisper.available_models() + + def output_size(self) -> int: + return self.n_mels + + def log_mel_spectrogram( + self, + audio: torch.Tensor, + ilens: torch.Tensor = None, + ) -> torch.Tensor: + window = torch.hann_window(self.win_length).to(audio.device) + stft = torch.stft( + audio, self.n_fft, self.hop_length, window=window, return_complex=True + ) + + # whisper deletes the last frame by default (Shih-Lun) + magnitudes = stft[..., :-1].abs() ** 2 + + filters = self.mel_filters(audio.device, self.n_mels) + mel_spec = filters @ magnitudes + + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + + if ilens is not None: + olens = ilens // self.hop_length + else: + olens = None + + log_spec = torch.maximum( + log_spec, + log_spec.view(audio.size(0), -1).max(dim=-1)[0][:, None, None] - 8.0, + ) + log_spec = (log_spec + 4.0) / 4.0 + + return log_spec, olens + + def forward( + self, input: torch.Tensor, input_lengths: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = input.size(0) + feats = [] + feats_lens = [] + for i in range(batch_size): + if self.do_pad_trim: + feat = self.pad_or_trim(input[i], self.pad_samples) + else: + feat = input[i] + feat, feat_len = self.log_mel_spectrogram(feat[None, :], input_lengths[0]) + feats.append(feat[0]) + feats_lens.append(feat_len) + feats_lens = torch.as_tensor(feats_lens) + + if batch_size == 1: + feats_pad = feats[0][None, :, :] + else: + feats_pad = pad_sequence(feats, + batch_first=True, + padding_value=0.0) + + return feats_pad, feats_lens \ No newline at end of file diff --git a/funasr/models/whisper_lid/__init__.py b/funasr/models/whisper_lid/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/funasr/models/whisper_lid/decoder.py b/funasr/models/whisper_lid/decoder.py new file mode 100644 index 000000000..4db920546 --- /dev/null +++ b/funasr/models/whisper_lid/decoder.py @@ -0,0 +1,167 @@ +# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved. +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +import copy +from typing import Any, List, Tuple + +import torch +from torch import nn +import whisper + +from funasr.models.transformer.utils.nets_utils import make_pad_mask +from funasr.register import tables + + +@tables.register("decoder_classes", "OpenAIWhisperDecoderWarp") +class OpenAIWhisperDecoderWarp(nn.Module): + """Transformer-based Speech-to-Text Decoder from OpenAI's Whisper Model: + + URL: https://github.com/openai/whisper + """ + + def __init__( + self, + dropout_rate: float = 0.0, + whisper_model: str = "small", + download_dir: str = None, + use_padmask: bool = False, + ): + super().__init__() + + assert whisper_model in whisper.available_models() + _model = whisper.load_model( + whisper_model, download_root=download_dir, device="cpu" + ) + self.decoders = copy.deepcopy(_model.decoder) + attention_dim = self.decoders.token_embedding.embedding_dim + + # note that originally Whisper doesn't use dropouts + self.dropout = torch.nn.Dropout(dropout_rate) + + self.decoders.train() + del _model + self.use_padmask = use_padmask + + def forward( + self, + hs_pad: torch.Tensor, + hlens: torch.Tensor, + ys_in_pad: torch.Tensor, + ys_in_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward decoder. + + Args: + hs_pad: encoded memory, float32 (batch, maxlen_in, feat) + hlens: (batch) + ys_in_pad: + input token ids, int64 (batch, maxlen_out) + if input_layer == "embed" + input tensor (batch, maxlen_out, #mels) in the other cases + ys_in_lens: (batch) + Returns: + (tuple): tuple containing: + + x: decoded token score before softmax (batch, maxlen_out, token) + if use_output_layer is True, + olens: (batch, ) + """ + tgt, memory = ys_in_pad, hs_pad + tgt = ( + self.decoders.token_embedding(tgt) + + self.decoders.positional_embedding[: tgt.size(1)] + ) + tgt = self.dropout(tgt) + + x = tgt.to(memory.dtype) + + if self.use_padmask: + memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device) + else: + memory_mask = None + + for layer, block in enumerate(self.decoders.blocks): + x = block(x, memory, mask=self.decoders.mask, memory_mask=memory_mask, is_pad_mask=False, is_pad_memory_mask=True) + + if layer < len(self.decoders.blocks) - 1: + x = self.dropout(x) + + x = self.decoders.ln(x) + x = ( + x @ torch.transpose(self.decoders.token_embedding.weight.to(x.dtype), 0, 1) + ).float() + + return x, ys_in_lens + + def forward_one_step( + self, + tgt: torch.Tensor, + tgt_mask: torch.Tensor, + memory: torch.Tensor, + cache: List[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """Forward one step. + + Args: + tgt: input token ids, int64 (batch, maxlen_out) + tgt_mask: input token mask, (batch, maxlen_out) + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (include 1.2) + memory: encoded memory, float32 (batch, maxlen_in, feat) + cache: cached output list of (batch, max_time_out-1, size) + Returns: + y, cache: NN output value and cache per `self.decoders`. + y.shape` is (batch, maxlen_out, token) + NOTE (Shih-Lun): + cache implementation is ignored for now + for simplicity & correctness + """ + x = ( + self.decoders.token_embedding(tgt) + + self.decoders.positional_embedding[: tgt.size(1)] + ) + x = self.dropout(x) + x = x.to(memory.dtype) + + for layer, block in enumerate(self.decoders.blocks): + x = block(x, memory, mask=self.decoders.mask) + if layer < len(self.decoders.blocks) - 1: + x = self.dropout(x) + + x = self.decoders.ln(x) + y = x[:, -1] + y = ( + y @ torch.transpose(self.decoders.token_embedding.weight.to(x.dtype), 0, 1) + ).float() + y = torch.log_softmax(y, dim=-1) + + return y, None + + def score(self, ys, state, x): + """Score.""" + logp, state = self.forward_one_step( + ys.unsqueeze(0), torch.empty(0), x.unsqueeze(0), cache=state # dummy mask + ) + return logp.squeeze(0), state + + def batch_score( + self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor + ) -> Tuple[torch.Tensor, List[Any]]: + """Score new token batch. + + Args: + ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). + states (List[Any]): Scorer states for prefix tokens. + xs (torch.Tensor): + The encoder feature that generates ys (n_batch, xlen, n_feat). + + Returns: + tuple[torch.Tensor, List[Any]]: Tuple of + batchfied scores for next token with shape of `(n_batch, n_vocab)` + and next state list for ys. + + """ + # batch decoding, dummy mask is passed + logp, states = self.forward_one_step(ys, torch.empty(0), xs, cache=None) + + return logp, None diff --git a/funasr/models/whisper_lid/encoder.py b/funasr/models/whisper_lid/encoder.py new file mode 100644 index 000000000..7eeb643d6 --- /dev/null +++ b/funasr/models/whisper_lid/encoder.py @@ -0,0 +1,119 @@ +# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved. +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +import copy +from typing import Optional, Tuple, Union + +import torch +from torch import nn +import torch.nn.functional as F +import whisper + +from funasr.models.transformer.utils.nets_utils import make_pad_mask +from funasr.models.specaug.specaug import SpecAug +from funasr.register import tables + + +@tables.register("encoder_classes", "OpenAIWhisperEncoderWarp") +class OpenAIWhisperEncoderWarp(nn.Module): + """Transformer-based Speech Encoder from OpenAI's Whisper Model: + + URL: https://github.com/openai/whisper + """ + + def __init__( + self, + dropout_rate: float = 0.0, + whisper_model: str = "small", + download_dir: str = None, + use_specaug: bool = False, + use_padmask: bool = False, + specaug_conf: Union[dict, None] = None, + ): + super().__init__() + + # note that originally Whisper doesn't use dropouts + self.dropout = torch.nn.Dropout(dropout_rate) + + assert whisper_model in whisper.available_models() + _model = whisper.load_model( + whisper_model, download_root=download_dir, device="cpu" + ) + self.encoders = copy.deepcopy(_model.encoder) + self.encoders.train() + + del _model + + if use_specaug: + self.specaug = SpecAug(**specaug_conf) + else: + self.specaug = None + self.use_padmask = use_padmask + + def whisper_encode( + self, + input: torch.Tensor, + ilens: torch.Tensor = None, + ) -> torch.Tensor: + x = F.gelu(self.encoders.conv1(input)) + x = F.gelu(self.encoders.conv2(x)) + x = x.permute(0, 2, 1) + + n_frames = x.size(1) + max_pos = self.encoders.positional_embedding.size(0) + if n_frames <= max_pos: + x = (x + self.encoders.positional_embedding[: x.size(1), :]).to(x.dtype) + else: + # due to positional encoding, audios >30 sec won't be accepted + x = x[:, :max_pos, :] + self.encoders.positional_embedding + + if ilens is not None: + olens = ( + 1 + + ( + ilens + - self.encoders.conv2.kernel_size[0] + + 2 * self.encoders.conv2.padding[0] + ) + // self.encoders.conv2.stride[0] + ) + olens = torch.clamp(olens, max=max_pos) + else: + olens = None + + if self.use_padmask: + padding_mask = (~make_pad_mask(olens)[:, None, :]).to(x.device) + else: + padding_mask = None + + x = self.dropout(x) + + for layer, block in enumerate(self.encoders.blocks): + x = block(x) + if layer < len(self.encoders.blocks) - 1: + x = self.dropout(x) + + x = self.encoders.ln_post(x) + + return x, olens + + def output_size(self) -> int: + # dummy output size + return self.encoders.conv2.weight.shape[0] + + def forward( + self, + xs_pad: torch.Tensor, + ilens: torch.Tensor, + prev_states: torch.Tensor = None, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + feats, feats_lens = xs_pad, ilens + + if self.specaug is not None and self.encoders.training: + feats = torch.transpose(feats, 1, 2) + feats, feats_lens = self.specaug(feats, feats_lens) + feats = torch.transpose(feats, 1, 2) + + xs_pad, olens = self.whisper_encode(feats, feats_lens) + + return xs_pad, olens, None diff --git a/funasr/models/whisper_lid/eres2net/ResNet.py b/funasr/models/whisper_lid/eres2net/ResNet.py new file mode 100644 index 000000000..25c79f583 --- /dev/null +++ b/funasr/models/whisper_lid/eres2net/ResNet.py @@ -0,0 +1,428 @@ +# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +""" Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker. + ERes2Net incorporates both local and global feature fusion techniques to improve the performance. + The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal. + The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal. + ERes2Net-Large is an upgraded version of ERes2Net that uses a larger number of parameters to achieve better + recognition performance. Parameters expansion, baseWidth, and scale can be modified to obtain optimal performance. +""" + +import torch +import math +import torch.nn as nn +import torch.nn.functional as F +import funasr.models.whisper_lid.eres2net.pooling_layers as pooling_layers +from funasr.models.whisper_lid.eres2net.fusion import AFF + + +class ReLU(nn.Hardtanh): + + def __init__(self, inplace=False): + super(ReLU, self).__init__(0, 20, inplace) + + def __repr__(self): + inplace_str = 'inplace' if self.inplace else '' + return self.__class__.__name__ + ' (' \ + + inplace_str + ')' + + +def conv1x1(in_planes, out_planes, stride=1): + "1x1 convolution without padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, + padding=0, bias=False) + + +def conv3x3(in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlockERes2Net(nn.Module): + expansion = 2 + + def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2): + super(BasicBlockERes2Net, self).__init__() + width = int(math.floor(planes * (baseWidth / 64.0))) + self.conv1 = conv1x1(in_planes, width * scale, stride) + self.bn1 = nn.BatchNorm2d(width * scale) + self.nums = scale + + convs = [] + bns = [] + for i in range(self.nums): + convs.append(conv3x3(width, width)) + bns.append(nn.BatchNorm2d(width)) + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + self.relu = ReLU(inplace=True) + + self.conv3 = conv1x1(width * scale, planes * self.expansion) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d(self.expansion * planes)) + self.stride = stride + self.width = width + self.scale = scale + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + spx = torch.split(out, self.width, 1) + for i in range(self.nums): + if i == 0: + sp = spx[i] + else: + sp = sp + spx[i] + sp = self.convs[i](sp) + sp = self.relu(self.bns[i](sp)) + if i == 0: + out = sp + else: + out = torch.cat((out, sp), 1) + + out = self.conv3(out) + out = self.bn3(out) + + residual = self.shortcut(x) + out += residual + out = self.relu(out) + + return out + + +class BasicBlockERes2Net_diff_AFF(nn.Module): + expansion = 2 + + def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2): + super(BasicBlockERes2Net_diff_AFF, self).__init__() + width = int(math.floor(planes * (baseWidth / 64.0))) + self.conv1 = conv1x1(in_planes, width * scale, stride) + self.bn1 = nn.BatchNorm2d(width * scale) + self.nums = scale + + convs = [] + fuse_models = [] + bns = [] + for i in range(self.nums): + convs.append(conv3x3(width, width)) + bns.append(nn.BatchNorm2d(width)) + for j in range(self.nums - 1): + fuse_models.append(AFF(channels=width)) + + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + self.fuse_models = nn.ModuleList(fuse_models) + self.relu = ReLU(inplace=True) + + self.conv3 = conv1x1(width * scale, planes * self.expansion) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d(self.expansion * planes)) + self.stride = stride + self.width = width + self.scale = scale + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + spx = torch.split(out, self.width, 1) + for i in range(self.nums): + if i == 0: + sp = spx[i] + else: + sp = self.fuse_models[i - 1](sp, spx[i]) + + sp = self.convs[i](sp) + sp = self.relu(self.bns[i](sp)) + if i == 0: + out = sp + else: + out = torch.cat((out, sp), 1) + + out = self.conv3(out) + out = self.bn3(out) + + residual = self.shortcut(x) + out += residual + out = self.relu(out) + + return out + + +class ERes2Net(nn.Module): + def __init__(self, + block=BasicBlockERes2Net, + block_fuse=BasicBlockERes2Net_diff_AFF, + num_blocks=[3, 4, 6, 3], + m_channels=32, + feat_dim=80, + embedding_size=192, + pooling_func='TSTP', + two_emb_layer=False): + super(ERes2Net, self).__init__() + self.in_planes = m_channels + self.feat_dim = feat_dim + self.embedding_size = embedding_size + self.stats_dim = int(feat_dim / 8) * m_channels * 8 + self.two_emb_layer = two_emb_layer + self._output_size = embedding_size + + self.conv1 = nn.Conv2d(1, + m_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(m_channels) + self.layer1 = self._make_layer(block, + m_channels, + num_blocks[0], + stride=1) + self.layer2 = self._make_layer(block, + m_channels * 2, + num_blocks[1], + stride=2) + self.layer3 = self._make_layer(block_fuse, + m_channels * 4, + num_blocks[2], + stride=2) + self.layer4 = self._make_layer(block_fuse, + m_channels * 8, + num_blocks[3], + stride=2) + + # Downsampling module for each layer + self.layer1_downsample = nn.Conv2d(m_channels * 2, m_channels * 4, kernel_size=3, stride=2, padding=1, + bias=False) + self.layer2_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, + bias=False) + self.layer3_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, + bias=False) + + # Bottom-up fusion module + self.fuse_mode12 = AFF(channels=m_channels * 4) + self.fuse_mode123 = AFF(channels=m_channels * 8) + self.fuse_mode1234 = AFF(channels=m_channels * 16) + + self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2 + self.pool = getattr(pooling_layers, pooling_func)( + in_dim=self.stats_dim * block.expansion) + self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, + embedding_size) + if self.two_emb_layer: + self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False) + self.seg_2 = nn.Linear(embedding_size, embedding_size) + else: + self.seg_bn_1 = nn.Identity() + self.seg_2 = nn.Identity() + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def output_size(self) -> int: + return self._output_size + + def forward(self, x, ilens): + # assert x.shape[1] == ilens.max() + x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) + x = x.unsqueeze_(1) + out = F.relu(self.bn1(self.conv1(x))) + out1 = self.layer1(out) + out2 = self.layer2(out1) + out1_downsample = self.layer1_downsample(out1) + fuse_out12 = self.fuse_mode12(out2, out1_downsample) + out3 = self.layer3(out2) + fuse_out12_downsample = self.layer2_downsample(fuse_out12) + fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample) + out4 = self.layer4(out3) + fuse_out123_downsample = self.layer3_downsample(fuse_out123) + fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample) + olens = (((((ilens - 1) // 2 + 1) - 1) // 2 + 1) - 1) // 2 + 1 + stats = self.pool(fuse_out1234, olens) + + embed_a = self.seg_1(stats) + if self.two_emb_layer: + out = F.relu(embed_a) + out = self.seg_bn_1(out) + embed_b = self.seg_2(out) + return embed_b + else: + return embed_a + + +class BasicBlockRes2Net(nn.Module): + expansion = 2 + + def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2): + super(BasicBlockRes2Net, self).__init__() + width = int(math.floor(planes * (baseWidth / 64.0))) + self.conv1 = conv1x1(in_planes, width * scale, stride) + self.bn1 = nn.BatchNorm2d(width * scale) + self.nums = scale - 1 + convs = [] + bns = [] + for i in range(self.nums): + convs.append(conv3x3(width, width)) + bns.append(nn.BatchNorm2d(width)) + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + self.relu = ReLU(inplace=True) + + self.conv3 = conv1x1(width * scale, planes * self.expansion) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d(self.expansion * planes)) + self.stride = stride + self.width = width + self.scale = scale + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + spx = torch.split(out, self.width, 1) + for i in range(self.nums): + if i == 0: + sp = spx[i] + else: + sp = sp + spx[i] + sp = self.convs[i](sp) + sp = self.relu(self.bns[i](sp)) + if i == 0: + out = sp + else: + out = torch.cat((out, sp), 1) + + out = torch.cat((out, spx[self.nums]), 1) + + out = self.conv3(out) + out = self.bn3(out) + + residual = self.shortcut(x) + out += residual + out = self.relu(out) + + return out + + +class Res2Net(nn.Module): + def __init__(self, + block=BasicBlockRes2Net, + num_blocks=[3, 4, 6, 3], + m_channels=32, + feat_dim=80, + embedding_size=192, + pooling_func='TSTP', + two_emb_layer=False): + super(Res2Net, self).__init__() + self.in_planes = m_channels + self.feat_dim = feat_dim + self.embedding_size = embedding_size + self.stats_dim = int(feat_dim / 8) * m_channels * 8 + self.two_emb_layer = two_emb_layer + + self.conv1 = nn.Conv2d(1, + m_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(m_channels) + self.layer1 = self._make_layer(block, + m_channels, + num_blocks[0], + stride=1) + self.layer2 = self._make_layer(block, + m_channels * 2, + num_blocks[1], + stride=2) + self.layer3 = self._make_layer(block, + m_channels * 4, + num_blocks[2], + stride=2) + self.layer4 = self._make_layer(block, + m_channels * 8, + num_blocks[3], + stride=2) + + self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2 + self.pool = getattr(pooling_layers, pooling_func)( + in_dim=self.stats_dim * block.expansion) + self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, + embedding_size) + if self.two_emb_layer: + self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False) + self.seg_2 = nn.Linear(embedding_size, embedding_size) + else: + self.seg_bn_1 = nn.Identity() + self.seg_2 = nn.Identity() + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) + + x = x.unsqueeze_(1) + out = F.relu(self.bn1(self.conv1(x))) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + + stats = self.pool(out) + + embed_a = self.seg_1(stats) + if self.two_emb_layer: + out = F.relu(embed_a) + out = self.seg_bn_1(out) + embed_b = self.seg_2(out) + return embed_b + else: + return embed_a + + + + diff --git a/funasr/models/whisper_lid/eres2net/__init__.py b/funasr/models/whisper_lid/eres2net/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/funasr/models/whisper_lid/eres2net/fusion.py b/funasr/models/whisper_lid/eres2net/fusion.py new file mode 100644 index 000000000..2aff7a721 --- /dev/null +++ b/funasr/models/whisper_lid/eres2net/fusion.py @@ -0,0 +1,29 @@ +# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +import torch +import torch.nn as nn + + +class AFF(nn.Module): + + def __init__(self, channels=64, r=4): + super(AFF, self).__init__() + inter_channels = int(channels // r) + + self.local_att = nn.Sequential( + nn.Conv2d(channels * 2, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.SiLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + + def forward(self, x, ds_y): + xa = torch.cat((x, ds_y), dim=1) + x_att = self.local_att(xa) + x_att = 1.0 + torch.tanh(x_att) + xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0-x_att) + + return xo + diff --git a/funasr/models/whisper_lid/eres2net/pooling_layers.py b/funasr/models/whisper_lid/eres2net/pooling_layers.py new file mode 100644 index 000000000..f756ac89d --- /dev/null +++ b/funasr/models/whisper_lid/eres2net/pooling_layers.py @@ -0,0 +1,118 @@ +# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +""" This implementation is adapted from https://github.com/wenet-e2e/wespeaker.""" + +import torch +import torch.nn as nn + +from funasr.models.transformer.utils.nets_utils import make_pad_mask + + +class TAP(nn.Module): + """ + Temporal average pooling, only first-order mean is considered + """ + + def __init__(self, **kwargs): + super(TAP, self).__init__() + + def forward(self, x): + pooling_mean = x.mean(dim=-1) + # To be compatable with 2D input + pooling_mean = pooling_mean.flatten(start_dim=1) + return pooling_mean + + +class TSDP(nn.Module): + """ + Temporal standard deviation pooling, only second-order std is considered + """ + + def __init__(self, **kwargs): + super(TSDP, self).__init__() + + def forward(self, x): + # The last dimension is the temporal axis + pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8) + pooling_std = pooling_std.flatten(start_dim=1) + return pooling_std + + +class TSTP(nn.Module): + """ + Temporal statistics pooling, concatenate mean and std, which is used in + x-vector + Comment: simple concatenation can not make full use of both statistics + """ + + def __init__(self, **kwargs): + super(TSTP, self).__init__() + + def forward(self, x, olens): + # The last dimension is the temporal axis + masks = (~make_pad_mask(olens, maxlen=x.shape[-1])[:, None, None, :]).to(x.device) + x_masked = x * masks + sum_without_padding = torch.sum(x_masked, axis=-1) + count_without_padding = torch.sum(masks, axis=-1) + mean_without_padding = sum_without_padding / count_without_padding + + var_without_padding = ((x_masked - mean_without_padding.unsqueeze(-1)) ** 2 * masks).sum(-1) / count_without_padding + + pooling_mean = mean_without_padding + pooling_std = torch.sqrt(var_without_padding + 1e-8) + pooling_mean = pooling_mean.flatten(start_dim=1) + pooling_std = pooling_std.flatten(start_dim=1) + + stats = torch.cat((pooling_mean, pooling_std), 1) + return stats + + +class ASTP(nn.Module): + """ Attentive statistics pooling: Channel- and context-dependent + statistics pooling, first used in ECAPA_TDNN. + """ + + def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False): + super(ASTP, self).__init__() + self.global_context_att = global_context_att + + # Use Conv1d with stride == 1 rather than Linear, then we don't + # need to transpose inputs. + if global_context_att: + self.linear1 = nn.Conv1d( + in_dim * 3, bottleneck_dim, + kernel_size=1) # equals W and b in the paper + else: + self.linear1 = nn.Conv1d( + in_dim, bottleneck_dim, + kernel_size=1) # equals W and b in the paper + self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, + kernel_size=1) # equals V and k in the paper + + def forward(self, x): + """ + x: a 3-dimensional tensor in tdnn-based architecture (B,F,T) + or a 4-dimensional tensor in resnet architecture (B,C,F,T) + 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension) + """ + if len(x.shape) == 4: + x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3]) + assert len(x.shape) == 3 + + if self.global_context_att: + context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x) + context_std = torch.sqrt( + torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x) + x_in = torch.cat((x, context_mean, context_std), dim=1) + else: + x_in = x + + # DON'T use ReLU here! ReLU may be hard to converge. + alpha = torch.tanh( + self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in)) + alpha = torch.softmax(self.linear2(alpha), dim=2) + mean = torch.sum(alpha * x, dim=2) + var = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2 + std = torch.sqrt(var.clamp(min=1e-10)) + return torch.cat([mean, std], dim=1) diff --git a/funasr/models/whisper_lid/eres2net/simple_avg.py b/funasr/models/whisper_lid/eres2net/simple_avg.py new file mode 100644 index 000000000..4fb4c0aa9 --- /dev/null +++ b/funasr/models/whisper_lid/eres2net/simple_avg.py @@ -0,0 +1,17 @@ +import torch + +from funasr.models.encoder.abs_encoder import AbsEncoder +from funasr.modules.nets_utils import make_pad_mask + +class SimpleAvg(AbsEncoder): + def __init__(self, feat_dim): + super(SimpleAvg, self).__init__() + self.feat_dim = feat_dim + + def forward(self, x, ilens): + mask = ~make_pad_mask(ilens, maxlen=x.shape[1]).to(x.device) + avg_x = (x * mask[:, :, None]).sum(1) / mask.sum(-1)[:, None] + return avg_x + + def output_size(self) -> int: + return self.feat_dim \ No newline at end of file diff --git a/funasr/models/whisper_lid/lid_predictor.py b/funasr/models/whisper_lid/lid_predictor.py new file mode 100644 index 000000000..5e042d245 --- /dev/null +++ b/funasr/models/whisper_lid/lid_predictor.py @@ -0,0 +1,25 @@ +from funasr.register import tables +from funasr.models.whisper_lid.eres2net.ResNet import ERes2Net, BasicBlockERes2Net, BasicBlockERes2Net_diff_AFF + + +@tables.register("lid_predictor_classes", "LidPredictor") +class LidPredictor(ERes2Net): + def __init__(self, + block=BasicBlockERes2Net, + block_fuse=BasicBlockERes2Net_diff_AFF, + num_blocks=[3, 4, 6, 3], + m_channels=32, + feat_dim=80, + embedding_size=192, + pooling_func='TSTP', + two_emb_layer=False): + super(LidPredictor, self).__init__( + block=block, + block_fuse=block_fuse, + num_blocks=num_blocks, + m_channels=m_channels, + feat_dim=feat_dim, + embedding_size=embedding_size, + pooling_func=pooling_func, + two_emb_layer=two_emb_layer + ) \ No newline at end of file diff --git a/funasr/models/whisper_lid/model.py b/funasr/models/whisper_lid/model.py new file mode 100644 index 000000000..6ffb43a6a --- /dev/null +++ b/funasr/models/whisper_lid/model.py @@ -0,0 +1,665 @@ +import logging +from typing import Union, Dict, List, Tuple, Optional + +import time +import torch +import numpy as np +import torch.nn as nn +from torch.cuda.amp import autocast + +from funasr.losses.label_smoothing_loss import LabelSmoothingLoss +from funasr.models.ctc.ctc import CTC +from funasr.models.transformer.utils.add_sos_eos import add_sos_eos +from funasr.metrics.compute_acc import th_accuracy +from funasr.train_utils.device_funcs import force_gatherable +from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank +from funasr.utils import postprocess_utils +from funasr.utils.datadir_writer import DatadirWriter +from funasr.register import tables + + +@tables.register("model_classes", "OpenAIWhisperModel") +class OpenAIWhisperModel(nn.Module): + """CTC-attention hybrid Encoder-Decoder model""" + + + def __init__( + self, + specaug: str = None, + specaug_conf: dict = None, + normalize: str = None, + normalize_conf: dict = None, + encoder: str = None, + encoder_conf: dict = None, + decoder: str = None, + decoder_conf: dict = None, + ctc: str = None, + ctc_conf: dict = None, + ctc_weight: float = 0.5, + interctc_weight: float = 0.0, + input_size: int = 80, + vocab_size: int = -1, + ignore_id: int = -1, + blank_id: int = 0, + sos: int = 1, + eos: int = 2, + lsm_weight: float = 0.0, + length_normalized_loss: bool = False, + report_cer: bool = True, + report_wer: bool = True, + sym_space: str = "", + sym_blank: str = "", + # extract_feats_in_collect_stats: bool = True, + share_embedding: bool = False, + # preencoder: Optional[AbsPreEncoder] = None, + # postencoder: Optional[AbsPostEncoder] = None, + **kwargs, + ): + + super().__init__() + + if specaug is not None: + specaug_class = tables.specaug_classes.get(specaug) + specaug = specaug_class(**specaug_conf) + if normalize is not None: + normalize_class = tables.normalize_classes.get(normalize) + normalize = normalize_class(**normalize_conf) + encoder_class = tables.encoder_classes.get(encoder) + encoder = encoder_class(input_size=input_size, **encoder_conf) + encoder_output_size = encoder.output_size() + if decoder is not None: + decoder_class = tables.decoder_classes.get(decoder) + decoder = decoder_class(decoder_conf) + if ctc_weight > 0.0: + + if ctc_conf is None: + ctc_conf = {} + + ctc = CTC( + odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf + ) + + self.blank_id = blank_id + self.sos = sos if sos is not None else vocab_size - 1 + self.eos = eos if eos is not None else vocab_size - 1 + self.vocab_size = vocab_size + self.ignore_id = ignore_id + self.ctc_weight = ctc_weight + self.specaug = specaug + self.normalize = normalize + self.encoder = encoder + + if not hasattr(self.encoder, "interctc_use_conditioning"): + self.encoder.interctc_use_conditioning = False + if self.encoder.interctc_use_conditioning: + self.encoder.conditioning_layer = torch.nn.Linear( + vocab_size, self.encoder.output_size() + ) + self.interctc_weight = interctc_weight + + # self.error_calculator = None + if ctc_weight == 1.0: + self.decoder = None + else: + self.decoder = decoder + + self.criterion_att = LabelSmoothingLoss( + size=vocab_size, + padding_idx=ignore_id, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + # + # if report_cer or report_wer: + # self.error_calculator = ErrorCalculator( + # token_list, sym_space, sym_blank, report_cer, report_wer + # ) + # + self.error_calculator = None + if ctc_weight == 0.0: + self.ctc = None + else: + self.ctc = ctc + + self.share_embedding = share_embedding + if self.share_embedding: + self.decoder.embed = None + + self.length_normalized_loss = length_normalized_loss + self.beam_search = None + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Encoder + Decoder + Calc loss + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + text: (Batch, Length) + text_lengths: (Batch,) + """ + # import pdb; + # pdb.set_trace() + if len(text_lengths.size()) > 1: + text_lengths = text_lengths[:, 0] + if len(speech_lengths.size()) > 1: + speech_lengths = speech_lengths[:, 0] + + batch_size = speech.shape[0] + + # 1. Encoder + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + intermediate_outs = None + if isinstance(encoder_out, tuple): + intermediate_outs = encoder_out[1] + encoder_out = encoder_out[0] + + loss_att, acc_att, cer_att, wer_att = None, None, None, None + loss_ctc, cer_ctc = None, None + stats = dict() + + # decoder: CTC branch + if self.ctc_weight != 0.0: + loss_ctc, cer_ctc = self._calc_ctc_loss( + encoder_out, encoder_out_lens, text, text_lengths + ) + + # Collect CTC branch stats + stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None + stats["cer_ctc"] = cer_ctc + + # Intermediate CTC (optional) + loss_interctc = 0.0 + if self.interctc_weight != 0.0 and intermediate_outs is not None: + for layer_idx, intermediate_out in intermediate_outs: + # we assume intermediate_out has the same length & padding + # as those of encoder_out + loss_ic, cer_ic = self._calc_ctc_loss( + intermediate_out, encoder_out_lens, text, text_lengths + ) + loss_interctc = loss_interctc + loss_ic + + # Collect Intermedaite CTC stats + stats["loss_interctc_layer{}".format(layer_idx)] = ( + loss_ic.detach() if loss_ic is not None else None + ) + stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic + + loss_interctc = loss_interctc / len(intermediate_outs) + + # calculate whole encoder loss + loss_ctc = ( + 1 - self.interctc_weight + ) * loss_ctc + self.interctc_weight * loss_interctc + + # decoder: Attention decoder branch + loss_att, acc_att, cer_att, wer_att = self._calc_att_loss( + encoder_out, encoder_out_lens, text, text_lengths + ) + + # 3. CTC-Att loss definition + if self.ctc_weight == 0.0: + loss = loss_att + elif self.ctc_weight == 1.0: + loss = loss_ctc + else: + loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + + # Collect Attn branch stats + stats["loss_att"] = loss_att.detach() if loss_att is not None else None + stats["acc"] = acc_att + stats["cer"] = cer_att + stats["wer"] = wer_att + + # Collect total loss stats + stats["loss"] = torch.clone(loss.detach()) + + # force_gatherable: to-device and to-tensor if scalar for DataParallel + if self.length_normalized_loss: + batch_size = int((text_lengths + 1).sum()) + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + return loss, stats, weight + + + def encode( + self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Frontend + Encoder. Note that this method is used by asr_inference.py + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + ind: int + """ + with autocast(False): + + # Data augmentation + if self.specaug is not None and self.training: + speech, speech_lengths = self.specaug(speech, speech_lengths) + + # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN + if self.normalize is not None: + speech, speech_lengths = self.normalize(speech, speech_lengths) + + # Forward encoder + # feats: (Batch, Length, Dim) + # -> encoder_out: (Batch, Length2, Dim2) + if self.encoder.interctc_use_conditioning: + encoder_out, encoder_out_lens, _ = self.encoder( + speech, speech_lengths, ctc=self.ctc + ) + else: + encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths) + intermediate_outs = None + if isinstance(encoder_out, tuple): + intermediate_outs = encoder_out[1] + encoder_out = encoder_out[0] + + if intermediate_outs is not None: + return (encoder_out, intermediate_outs), encoder_out_lens + + return encoder_out, encoder_out_lens + + def _calc_att_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ): + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) + ys_in_lens = ys_pad_lens + 1 + + # 1. Forward decoder + decoder_out, _ = self.decoder( + encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens + ) + + # 2. Compute attention loss + loss_att = self.criterion_att(decoder_out, ys_out_pad) + acc_att = th_accuracy( + decoder_out.view(-1, self.vocab_size), + ys_out_pad, + ignore_label=self.ignore_id, + ) + + # Compute cer/wer using attention-decoder + if self.training or self.error_calculator is None: + cer_att, wer_att = None, None + else: + ys_hat = decoder_out.argmax(dim=-1) + cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) + + return loss_att, acc_att, cer_att, wer_att + + def _calc_ctc_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ): + # Calc CTC loss + loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) + + # Calc CER using CTC + cer_ctc = None + if not self.training and self.error_calculator is not None: + ys_hat = self.ctc.argmax(encoder_out).data + cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True) + return loss_ctc, cer_ctc + + def init_beam_search(self, + **kwargs, + ): + from funasr.models.transformer.search import BeamSearch + from funasr.models.transformer.scorers.ctc import CTCPrefixScorer + from funasr.models.transformer.scorers.length_bonus import LengthBonus + + # 1. Build ASR model + scorers = {} + + if self.ctc != None: + ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos) + scorers.update( + ctc=ctc + ) + token_list = kwargs.get("token_list") + scorers.update( + decoder=self.decoder, + length_bonus=LengthBonus(len(token_list)), + ) + + + # 3. Build ngram model + # ngram is not supported now + ngram = None + scorers["ngram"] = ngram + + weights = dict( + decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.5), + ctc=kwargs.get("decoding_ctc_weight", 0.5), + lm=kwargs.get("lm_weight", 0.0), + ngram=kwargs.get("ngram_weight", 0.0), + length_bonus=kwargs.get("penalty", 0.0), + ) + beam_search = BeamSearch( + beam_size=kwargs.get("beam_size", 10), + weights=weights, + scorers=scorers, + sos=self.sos, + eos=self.eos, + vocab_size=len(token_list), + token_list=token_list, + pre_beam_score_key=None if self.ctc_weight == 1.0 else "full", + ) + + self.beam_search = beam_search + + def inference(self, + data_in, + data_lengths=None, + key: list=None, + tokenizer=None, + frontend=None, + **kwargs, + ): + + if kwargs.get("batch_size", 1) > 1: + raise NotImplementedError("batch decoding is not implemented") + + # init beamsearch + if self.beam_search is None: + logging.info("enable beam_search") + self.init_beam_search(**kwargs) + self.nbest = kwargs.get("nbest", 1) + + meta_data = {} + if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank + speech, speech_lengths = data_in, data_lengths + if len(speech.shape) < 3: + speech = speech[None, :, :] + if speech_lengths is None: + speech_lengths = speech.shape[1] + else: + # extract fbank feats + time1 = time.perf_counter() + audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000), + data_type=kwargs.get("data_type", "sound"), + tokenizer=tokenizer) + time2 = time.perf_counter() + meta_data["load_data"] = f"{time2 - time1:0.3f}" + speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), + frontend=frontend) + time3 = time.perf_counter() + meta_data["extract_feat"] = f"{time3 - time2:0.3f}" + meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 + + speech = speech.to(device=kwargs["device"]) + speech_lengths = speech_lengths.to(device=kwargs["device"]) + # Encoder + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + if isinstance(encoder_out, tuple): + encoder_out = encoder_out[0] + + # c. Passed the encoder result and the beam search + nbest_hyps = self.beam_search( + x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0) + ) + + nbest_hyps = nbest_hyps[: self.nbest] + + + results = [] + b, n, d = encoder_out.size() + for i in range(b): + + for nbest_idx, hyp in enumerate(nbest_hyps): + ibest_writer = None + if kwargs.get("output_dir") is not None: + if not hasattr(self, "writer"): + self.writer = DatadirWriter(kwargs.get("output_dir")) + ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"] + + # remove sos/eos and get results + last_pos = -1 + if isinstance(hyp.yseq, list): + token_int = hyp.yseq[1:last_pos] + else: + token_int = hyp.yseq[1:last_pos].tolist() + + # remove blank symbol id, which is assumed to be 0 + token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int)) + + # Change integer-ids to tokens + token = tokenizer.ids2tokens(token_int) + text = tokenizer.tokens2text(token) + + text_postprocessed, _ = postprocess_utils.sentence_postprocess(token) + result_i = {"key": key[i], "token": token, "text": text_postprocessed} + results.append(result_i) + + if ibest_writer is not None: + ibest_writer["token"][key[i]] = " ".join(token) + ibest_writer["text"][key[i]] = text_postprocessed + + return results, meta_data + + +@tables.register("model_classes", "OpenAIWhisperLIDModel") +class OpenAIWhisperLIDModel(nn.Module): + """WhisperEncoder and EResNet based LID Model""" + + def __init__( + self, + vocab_size: int, + specaug: str = None, + specaug_conf: dict = None, + encoder: str = None, + encoder_conf: dict = None, + lid_predictor: str = None, + lid_predictor_conf: dict = None, + proj_dim: int = None, + clip_frames: int = None, + random_clip: bool = False, + **kwargs, + ): + super().__init__() + if specaug is not None: + specaug_class = tables.specaug_classes.get(specaug) + specaug = specaug_class(**specaug_conf) + encoder_class = tables.encoder_classes.get(encoder) + encoder = encoder_class(**encoder_conf) + lid_predictor_class = tables.lid_predictor_classes.get(lid_predictor) + lid_predictor = lid_predictor_class(**lid_predictor_conf) + if encoder.output_size() != proj_dim: + self.proj_layer = torch.nn.Linear(encoder.output_size(), proj_dim) + else: + self.proj_layer = None + self.output_layer = torch.nn.Linear(lid_predictor.output_size(), vocab_size) + self.criterion_lid = LabelSmoothingLoss( + size=vocab_size, + padding_idx=-1, + smoothing=0.0, + normalize_length=False, + ) + + self.specaug = specaug + self.encoder = encoder + self.lid_predictor = lid_predictor + self.clip_frames = clip_frames + self.random_clip = random_clip + self.normalize = None + self.beam_search = None + if not hasattr(self.encoder, "interctc_use_conditioning"): + self.encoder.interctc_use_conditioning = False + + def forward(self, + speech: torch.Tensor, # may be padding + speech_lengths: torch.Tensor, # actual length + lid: torch.Tensor, # lid label, (batch_size, 1) + lid_lengths: torch.Tensor, + ): + assert lid.shape[1] == 1 + batch_size = speech.shape[0] + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + + # re-generate encoder_out + if self.clip_frames is None: + reduced_encoder_out = torch.zeros(batch_size, encoder_out_lens.max(), encoder_out.shape[-1]).to(encoder_out.dtype).to(encoder_out.device) + for i, enc_length in enumerate(encoder_out_lens): + reduced_encoder_out[i, :enc_length] = encoder_out[i, :enc_length] + else: + reduced_encoder_out = torch.zeros(batch_size, self.clip_frames, encoder_out.shape[-1]).to(encoder_out.dtype).to(encoder_out.device) + if self.random_clip: + for i, enc_length in enumerate(encoder_out_lens): + if enc_length <= self.clip_frames: + reduced_encoder_out[i, :enc_length] = encoder_out[i, :enc_length] + encoder_out_lens[i] = enc_length + else: + max_start_index = enc_length.item() - self.clip_frames + start_index = np.random.randint(0, max_start_index + 1) + reduced_encoder_out[i, :self.clip_frames] = encoder_out[i, start_index:start_index + self.clip_frames] + encoder_out_lens[i] = self.clip_frames + else: + for i, enc_length in enumerate(encoder_out_lens): + enc_length = self.clip_frames if enc_length >= self.clip_frames else enc_length + reduced_encoder_out[i, :enc_length] = encoder_out[i, :enc_length] + encoder_out_lens[i] = enc_length + if self.proj_layer is not None: + reduced_encoder_out = self.proj_layer(reduced_encoder_out) + lid_output = self.lid_predictor(reduced_encoder_out, encoder_out_lens) # (B, D) + lid_logits = self.output_layer(lid_output) # (B, num_classes) + loss = self.criterion_lid(lid_logits[:, None, :], lid) + with torch.no_grad(): + _, predicted_lid = torch.max(lid_logits, 1) + correct = (predicted_lid == lid[:, 0]).sum().item() + lid_acc = correct * 1.0 / lid_logits.shape[0] + stats = dict() + stats["batch_size"] = batch_size + stats["loss"] = torch.clone(loss.detach()) + stats["acc"] = lid_acc + stats["token_length"] = speech_lengths.max() + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + return loss, stats, weight + + def encode( + self, speech: torch.Tensor, speech_lengths: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Frontend + Encoder. Note that this method is used by asr_inference.py + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + """ + with autocast(False): + + # Data augmentation + if self.specaug is not None and self.training: + speech = speech.permute(0, 2, 1) + # suit for whisper padding + padded_speech_lengths = torch.ones_like(speech_lengths) * speech.shape[1] + speech, padded_speech_lengths = self.specaug(speech, padded_speech_lengths) + speech = speech.permute(0, 2, 1) + + # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN + if self.normalize is not None: + speech, speech_lengths = self.normalize(speech, speech_lengths) + + # Forward encoder + # feats: (Batch, Length, Dim) + # -> encoder_out: (Batch, Length2, Dim2) + if self.encoder.interctc_use_conditioning: + encoder_out, encoder_out_lens, _ = self.encoder( + speech, speech_lengths, ctc=self.ctc + ) + else: + encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths) + intermediate_outs = None + if isinstance(encoder_out, tuple): + intermediate_outs = encoder_out[1] + encoder_out = encoder_out[0] + + if intermediate_outs is not None: + return (encoder_out, intermediate_outs), encoder_out_lens + + return encoder_out, encoder_out_lens + + def inference(self, + data_in, + data_lengths=None, + key: list = None, + tokenizer=None, + frontend=None, + **kwargs, + ): + + if kwargs.get("batch_size", 1) > 1: + raise NotImplementedError("batch decoding is not implemented") + + meta_data = {} + if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank + speech, speech_lengths = data_in, data_lengths + if len(speech.shape) < 3: + speech = speech[None, :, :] + if speech_lengths is None: + speech_lengths = speech.shape[1] + else: + # extract fbank feats + time1 = time.perf_counter() + audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000), + data_type=kwargs.get("data_type", "sound"), + tokenizer=tokenizer) + time2 = time.perf_counter() + meta_data["load_data"] = f"{time2 - time1:0.3f}" + speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), + frontend=frontend) + time3 = time.perf_counter() + meta_data["extract_feat"] = f"{time3 - time2:0.3f}" + meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 + + speech = speech.to(device=kwargs["device"]) + speech_lengths = speech_lengths.to(device=kwargs["device"]) + # Encoder + enc, enc_out_lens = self.encode(speech, speech_lengths) + + inference_clip_length = kwargs.get("inference_clip_length", None) + if self.clip_frames is not None: + if inference_clip_length is None: + reduced_enc = torch.zeros(enc.shape[0], self.clip_frames, enc.shape[-1]).to(enc.dtype).to(enc.device) + for i, enc_length in enumerate(enc_out_lens): + enc_length = self.clip_frames if enc_length >= self.clip_frames else enc_length + reduced_enc[i, :enc_length] = enc[i, :enc_length] + enc_out_lens[i] = enc_length + else: + assert inference_clip_length > 0, "inference_clip_length must be larger than 0" + reduced_enc = torch.zeros(enc.shape[0], inference_clip_length, enc.shape[-1]).to(enc.dtype).to(enc.device) + for i, enc_length in enumerate(enc_out_lens): + enc_length = inference_clip_length if enc_length >= inference_clip_length else enc_length + reduced_enc[i, :enc_length] = enc[i, :enc_length] + enc_out_lens[i] = enc_length + else: + reduced_enc = torch.zeros(enc.shape[0], enc_out_lens.max(), enc.shape[-1]).to(enc.dtype).to(enc.device) + for i, enc_length in enumerate(enc_out_lens): + reduced_enc[i, :enc_length] = enc[i, :enc_length] + + if self.proj_layer is not None: + reduced_enc = self.proj_layer(reduced_enc) + lid_output = self.lid_predictor(reduced_enc, enc_out_lens) # (B, D) + lid_logits = self.output_layer(lid_output) # (B, num_classes) + + _, predicted_lid_index = torch.max(lid_logits, 1) + predicted_lid = tokenizer.ids2tokens([predicted_lid_index[0].cpu()])[0] + + if kwargs.get("output_dir") is not None: + if not hasattr(self, "writer"): + self.writer = DatadirWriter(kwargs.get("output_dir")) + lid_writer = self.writer["lid"] + lid_writer[key[0]] = predicted_lid + + results = [{"key": key[0], "lid": predicted_lid}] + + return results, meta_data \ No newline at end of file