Mirror of https://github.com/modelscope/FunASR, synced 2025-09-15 14:48:36 +08:00

Commit 5c8af9c7e5: Merge branch 'dev_gzf' of github.com:alibaba-damo-academy/FunASR into dev_gzf ("add")
examples/common_voice/whisper_lid/demo_funasr.py (new file, 19 lines added)
@@ -0,0 +1,19 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

from funasr import AutoModel

multilingual_wavs = [
    "example_zh-CN.mp3",
    "example_en.mp3",
    "example_ja.mp3",
    "example_ko.mp3",
]

model = AutoModel(model="iic/speech_whisper-large_lid_multilingual_pytorch", model_revision="v2.0.4")
for wav_id in multilingual_wavs:
    wav_file = f"{model.model_path}/examples/{wav_id}"
    res = model.generate(input=wav_file, data_type="sound", inference_clip_length=250)
    print("detect sample {}: {}".format(wav_id, res))
examples/common_voice/whisper_lid/demo_modelscope.py (new file, 22 lines added)
@@ -0,0 +1,22 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

multilingual_wavs = [
    "https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_zh-CN.mp3",
    "https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_en.mp3",
    "https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_ja.mp3",
    "https://www.modelscope.cn/api/v1/models/iic/speech_whisper-large_lid_multilingual_pytorch/repo?Revision=master&FilePath=examples/example_ko.mp3",
]

inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model='iic/speech_whisper-large_lid_multilingual_pytorch', model_revision="v2.0.4")

for wav in multilingual_wavs:
    rec_result = inference_pipeline(input=wav, inference_clip_length=250)
    print(rec_result)
funasr/frontends/whisper_frontend.py (new file, 102 lines added)
@@ -0,0 +1,102 @@
from typing import Tuple

import torch
import torch.nn as nn
import whisper
from whisper.audio import HOP_LENGTH, N_FFT, N_SAMPLES
from funasr.register import tables
from torch.nn.utils.rnn import pad_sequence


@tables.register("frontend_classes", "WhisperFrontend")
class WhisperFrontend(nn.Module):
    """Speech Representation Using Encoder Outputs from OpenAI's Whisper Model:

    URL: https://github.com/openai/whisper
    """

    def __init__(
        self,
        fs: int = 16000,
        whisper_model: str = "large-v3",
        do_pad_trim: bool = True,
    ):
        super().__init__()
        assert fs == 16000
        self.fs = fs

        self.n_fft = N_FFT
        self.win_length = N_FFT
        self.hop_length = HOP_LENGTH
        self.pad_samples = N_SAMPLES
        self.frame_shift = self.hop_length
        self.lfr_n = 1
        if whisper_model == "large-v3" or whisper_model == "large":
            self.n_mels = 128
        else:
            self.n_mels = 80

        self.mel_filters = whisper.audio.mel_filters
        self.do_pad_trim = do_pad_trim
        if do_pad_trim:
            self.pad_or_trim = whisper.pad_or_trim

        assert whisper_model in whisper.available_models()

    def output_size(self) -> int:
        return self.n_mels

    def log_mel_spectrogram(
        self,
        audio: torch.Tensor,
        ilens: torch.Tensor = None,
    ) -> torch.Tensor:
        window = torch.hann_window(self.win_length).to(audio.device)
        stft = torch.stft(
            audio, self.n_fft, self.hop_length, window=window, return_complex=True
        )

        # whisper deletes the last frame by default (Shih-Lun)
        magnitudes = stft[..., :-1].abs() ** 2

        filters = self.mel_filters(audio.device, self.n_mels)
        mel_spec = filters @ magnitudes

        log_spec = torch.clamp(mel_spec, min=1e-10).log10()

        if ilens is not None:
            olens = ilens // self.hop_length
        else:
            olens = None

        log_spec = torch.maximum(
            log_spec,
            log_spec.view(audio.size(0), -1).max(dim=-1)[0][:, None, None] - 8.0,
        )
        log_spec = (log_spec + 4.0) / 4.0

        return log_spec, olens

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = input.size(0)
        feats = []
        feats_lens = []
        for i in range(batch_size):
            if self.do_pad_trim:
                feat = self.pad_or_trim(input[i], self.pad_samples)
            else:
                feat = input[i]
            feat, feat_len = self.log_mel_spectrogram(feat[None, :], input_lengths[0])
            feats.append(feat[0])
            feats_lens.append(feat_len)
        feats_lens = torch.as_tensor(feats_lens)

        if batch_size == 1:
            feats_pad = feats[0][None, :, :]
        else:
            feats_pad = pad_sequence(feats,
                                     batch_first=True,
                                     padding_value=0.0)

        return feats_pad, feats_lens
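For reference, a minimal usage sketch of the frontend above (illustrative only, not part of the commit; it assumes the openai-whisper package with the "large-v3" assets is installed and feeds a dummy 1-second waveform):

import torch
from funasr.frontends.whisper_frontend import WhisperFrontend

frontend = WhisperFrontend(fs=16000, whisper_model="large-v3", do_pad_trim=True)
waveform = torch.randn(1, 16000)                 # 1 second of 16 kHz dummy audio
lengths = torch.tensor([waveform.shape[1]])
feats, feats_lens = frontend(waveform, lengths)  # feats: (1, n_mels=128, frames)
print(feats.shape, feats_lens)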
funasr/models/whisper_lid/__init__.py (new file, 0 lines added)
funasr/models/whisper_lid/decoder.py (new file, 167 lines added)
@@ -0,0 +1,167 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import copy
from typing import Any, List, Tuple

import torch
from torch import nn
import whisper

from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.register import tables


@tables.register("decoder_classes", "OpenAIWhisperDecoderWarp")
class OpenAIWhisperDecoderWarp(nn.Module):
    """Transformer-based Speech-to-Text Decoder from OpenAI's Whisper Model:

    URL: https://github.com/openai/whisper
    """

    def __init__(
        self,
        dropout_rate: float = 0.0,
        whisper_model: str = "small",
        download_dir: str = None,
        use_padmask: bool = False,
    ):
        super().__init__()

        assert whisper_model in whisper.available_models()
        _model = whisper.load_model(
            whisper_model, download_root=download_dir, device="cpu"
        )
        self.decoders = copy.deepcopy(_model.decoder)
        attention_dim = self.decoders.token_embedding.embedding_dim

        # note that originally Whisper doesn't use dropouts
        self.dropout = torch.nn.Dropout(dropout_rate)

        self.decoders.train()
        del _model
        self.use_padmask = use_padmask

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad:
                input token ids, int64 (batch, maxlen_out)
                if input_layer == "embed";
                input tensor (batch, maxlen_out, #mels) in the other cases
            ys_in_lens: (batch)
        Returns:
            (tuple): tuple containing:

            x: decoded token scores before softmax (batch, maxlen_out, token)
                if use_output_layer is True
            olens: (batch, )
        """
        tgt, memory = ys_in_pad, hs_pad
        tgt = (
            self.decoders.token_embedding(tgt)
            + self.decoders.positional_embedding[: tgt.size(1)]
        )
        tgt = self.dropout(tgt)

        x = tgt.to(memory.dtype)

        if self.use_padmask:
            memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)
        else:
            memory_mask = None

        for layer, block in enumerate(self.decoders.blocks):
            x = block(x, memory, mask=self.decoders.mask, memory_mask=memory_mask, is_pad_mask=False, is_pad_memory_mask=True)

            if layer < len(self.decoders.blocks) - 1:
                x = self.dropout(x)

        x = self.decoders.ln(x)
        x = (
            x @ torch.transpose(self.decoders.token_embedding.weight.to(x.dtype), 0, 1)
        ).float()

        return x, ys_in_lens

    def forward_one_step(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        memory: torch.Tensor,
        cache: List[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.

        Args:
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask, (batch, maxlen_out)
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            cache: cached output list of (batch, max_time_out-1, size)
        Returns:
            y, cache: NN output value and cache per `self.decoders`.
            `y.shape` is (batch, maxlen_out, token)
        NOTE (Shih-Lun):
            cache implementation is ignored for now
            for simplicity & correctness
        """
        x = (
            self.decoders.token_embedding(tgt)
            + self.decoders.positional_embedding[: tgt.size(1)]
        )
        x = self.dropout(x)
        x = x.to(memory.dtype)

        for layer, block in enumerate(self.decoders.blocks):
            x = block(x, memory, mask=self.decoders.mask)
            if layer < len(self.decoders.blocks) - 1:
                x = self.dropout(x)

        x = self.decoders.ln(x)
        y = x[:, -1]
        y = (
            y @ torch.transpose(self.decoders.token_embedding.weight.to(x.dtype), 0, 1)
        ).float()
        y = torch.log_softmax(y, dim=-1)

        return y, None

    def score(self, ys, state, x):
        """Score."""
        logp, state = self.forward_one_step(
            ys.unsqueeze(0), torch.empty(0), x.unsqueeze(0), cache=state  # dummy mask
        )
        return logp.squeeze(0), state

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batched scores for the next token with shape `(n_batch, n_vocab)`
                and the next state list for ys.

        """
        # batch decoding, dummy mask is passed
        logp, states = self.forward_one_step(ys, torch.empty(0), xs, cache=None)

        return logp, None
funasr/models/whisper_lid/encoder.py (new file, 119 lines added)
@@ -0,0 +1,119 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import copy
from typing import Optional, Tuple, Union

import torch
from torch import nn
import torch.nn.functional as F
import whisper

from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.specaug.specaug import SpecAug
from funasr.register import tables


@tables.register("encoder_classes", "OpenAIWhisperEncoderWarp")
class OpenAIWhisperEncoderWarp(nn.Module):
    """Transformer-based Speech Encoder from OpenAI's Whisper Model:

    URL: https://github.com/openai/whisper
    """

    def __init__(
        self,
        dropout_rate: float = 0.0,
        whisper_model: str = "small",
        download_dir: str = None,
        use_specaug: bool = False,
        use_padmask: bool = False,
        specaug_conf: Union[dict, None] = None,
    ):
        super().__init__()

        # note that originally Whisper doesn't use dropouts
        self.dropout = torch.nn.Dropout(dropout_rate)

        assert whisper_model in whisper.available_models()
        _model = whisper.load_model(
            whisper_model, download_root=download_dir, device="cpu"
        )
        self.encoders = copy.deepcopy(_model.encoder)
        self.encoders.train()

        del _model

        if use_specaug:
            self.specaug = SpecAug(**specaug_conf)
        else:
            self.specaug = None
        self.use_padmask = use_padmask

    def whisper_encode(
        self,
        input: torch.Tensor,
        ilens: torch.Tensor = None,
    ) -> torch.Tensor:
        x = F.gelu(self.encoders.conv1(input))
        x = F.gelu(self.encoders.conv2(x))
        x = x.permute(0, 2, 1)

        n_frames = x.size(1)
        max_pos = self.encoders.positional_embedding.size(0)
        if n_frames <= max_pos:
            x = (x + self.encoders.positional_embedding[: x.size(1), :]).to(x.dtype)
        else:
            # due to positional encoding, audios >30 sec won't be accepted
            x = x[:, :max_pos, :] + self.encoders.positional_embedding

        if ilens is not None:
            olens = (
                1
                + (
                    ilens
                    - self.encoders.conv2.kernel_size[0]
                    + 2 * self.encoders.conv2.padding[0]
                )
                // self.encoders.conv2.stride[0]
            )
            olens = torch.clamp(olens, max=max_pos)
        else:
            olens = None

        if self.use_padmask:
            padding_mask = (~make_pad_mask(olens)[:, None, :]).to(x.device)
        else:
            padding_mask = None

        x = self.dropout(x)

        for layer, block in enumerate(self.encoders.blocks):
            x = block(x)
            if layer < len(self.encoders.blocks) - 1:
                x = self.dropout(x)

        x = self.encoders.ln_post(x)

        return x, olens

    def output_size(self) -> int:
        # dummy output size
        return self.encoders.conv2.weight.shape[0]

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        feats, feats_lens = xs_pad, ilens

        if self.specaug is not None and self.encoders.training:
            feats = torch.transpose(feats, 1, 2)
            feats, feats_lens = self.specaug(feats, feats_lens)
            feats = torch.transpose(feats, 1, 2)

        xs_pad, olens = self.whisper_encode(feats, feats_lens)

        return xs_pad, olens, None
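The olens computation in whisper_encode is the standard 1-D convolution output-length formula applied to the encoder's second conv. A small illustrative check (not part of the commit; kernel_size=3, stride=2, padding=1 are assumed here to match Whisper's released conv2):

def conv1d_out_len(ilen: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    # 1 + (L - k + 2p) // s, the same arithmetic as the olens expression above
    return 1 + (ilen - kernel + 2 * padding) // stride

print(conv1d_out_len(3000))  # -> 1500 encoder frames for a 30 s / 3000-frame mel input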
funasr/models/whisper_lid/eres2net/ResNet.py (new file, 428 lines added)
@@ -0,0 +1,428 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

""" Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
ERes2Net incorporates both local and global feature fusion techniques to improve the performance.
The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal.
The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal.
ERes2Net-Large is an upgraded version of ERes2Net that uses a larger number of parameters to achieve better
recognition performance. Parameters expansion, baseWidth, and scale can be modified to obtain optimal performance.
"""

import torch
import math
import torch.nn as nn
import torch.nn.functional as F
import funasr.models.whisper_lid.eres2net.pooling_layers as pooling_layers
from funasr.models.whisper_lid.eres2net.fusion import AFF


class ReLU(nn.Hardtanh):

    def __init__(self, inplace=False):
        super(ReLU, self).__init__(0, 20, inplace)

    def __repr__(self):
        inplace_str = 'inplace' if self.inplace else ''
        return self.__class__.__name__ + ' (' \
            + inplace_str + ')'


def conv1x1(in_planes, out_planes, stride=1):
    "1x1 convolution without padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                     padding=0, bias=False)


def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlockERes2Net(nn.Module):
    expansion = 2

    def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
        super(BasicBlockERes2Net, self).__init__()
        width = int(math.floor(planes * (baseWidth / 64.0)))
        self.conv1 = conv1x1(in_planes, width * scale, stride)
        self.bn1 = nn.BatchNorm2d(width * scale)
        self.nums = scale

        convs = []
        bns = []
        for i in range(self.nums):
            convs.append(conv3x3(width, width))
            bns.append(nn.BatchNorm2d(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.relu = ReLU(inplace=True)

        self.conv3 = conv1x1(width * scale, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes,
                          self.expansion * planes,
                          kernel_size=1,
                          stride=stride,
                          bias=False),
                nn.BatchNorm2d(self.expansion * planes))
        self.stride = stride
        self.width = width
        self.scale = scale

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)

        out = self.conv3(out)
        out = self.bn3(out)

        residual = self.shortcut(x)
        out += residual
        out = self.relu(out)

        return out


class BasicBlockERes2Net_diff_AFF(nn.Module):
    expansion = 2

    def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
        super(BasicBlockERes2Net_diff_AFF, self).__init__()
        width = int(math.floor(planes * (baseWidth / 64.0)))
        self.conv1 = conv1x1(in_planes, width * scale, stride)
        self.bn1 = nn.BatchNorm2d(width * scale)
        self.nums = scale

        convs = []
        fuse_models = []
        bns = []
        for i in range(self.nums):
            convs.append(conv3x3(width, width))
            bns.append(nn.BatchNorm2d(width))
        for j in range(self.nums - 1):
            fuse_models.append(AFF(channels=width))

        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.fuse_models = nn.ModuleList(fuse_models)
        self.relu = ReLU(inplace=True)

        self.conv3 = conv1x1(width * scale, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes,
                          self.expansion * planes,
                          kernel_size=1,
                          stride=stride,
                          bias=False),
                nn.BatchNorm2d(self.expansion * planes))
        self.stride = stride
        self.width = width
        self.scale = scale

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = self.fuse_models[i - 1](sp, spx[i])

            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)

        out = self.conv3(out)
        out = self.bn3(out)

        residual = self.shortcut(x)
        out += residual
        out = self.relu(out)

        return out


class ERes2Net(nn.Module):
    def __init__(self,
                 block=BasicBlockERes2Net,
                 block_fuse=BasicBlockERes2Net_diff_AFF,
                 num_blocks=[3, 4, 6, 3],
                 m_channels=32,
                 feat_dim=80,
                 embedding_size=192,
                 pooling_func='TSTP',
                 two_emb_layer=False):
        super(ERes2Net, self).__init__()
        self.in_planes = m_channels
        self.feat_dim = feat_dim
        self.embedding_size = embedding_size
        self.stats_dim = int(feat_dim / 8) * m_channels * 8
        self.two_emb_layer = two_emb_layer
        self._output_size = embedding_size

        self.conv1 = nn.Conv2d(1,
                               m_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)
        self.layer1 = self._make_layer(block,
                                       m_channels,
                                       num_blocks[0],
                                       stride=1)
        self.layer2 = self._make_layer(block,
                                       m_channels * 2,
                                       num_blocks[1],
                                       stride=2)
        self.layer3 = self._make_layer(block_fuse,
                                       m_channels * 4,
                                       num_blocks[2],
                                       stride=2)
        self.layer4 = self._make_layer(block_fuse,
                                       m_channels * 8,
                                       num_blocks[3],
                                       stride=2)

        # Downsampling module for each layer
        self.layer1_downsample = nn.Conv2d(m_channels * 2, m_channels * 4, kernel_size=3, stride=2, padding=1,
                                           bias=False)
        self.layer2_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2,
                                           bias=False)
        self.layer3_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2,
                                           bias=False)

        # Bottom-up fusion module
        self.fuse_mode12 = AFF(channels=m_channels * 4)
        self.fuse_mode123 = AFF(channels=m_channels * 8)
        self.fuse_mode1234 = AFF(channels=m_channels * 16)

        self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
        self.pool = getattr(pooling_layers, pooling_func)(
            in_dim=self.stats_dim * block.expansion)
        self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
                               embedding_size)
        if self.two_emb_layer:
            self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
            self.seg_2 = nn.Linear(embedding_size, embedding_size)
        else:
            self.seg_bn_1 = nn.Identity()
            self.seg_2 = nn.Identity()

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def output_size(self) -> int:
        return self._output_size

    def forward(self, x, ilens):
        # assert x.shape[1] == ilens.max()
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = x.unsqueeze_(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out1_downsample = self.layer1_downsample(out1)
        fuse_out12 = self.fuse_mode12(out2, out1_downsample)
        out3 = self.layer3(out2)
        fuse_out12_downsample = self.layer2_downsample(fuse_out12)
        fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
        out4 = self.layer4(out3)
        fuse_out123_downsample = self.layer3_downsample(fuse_out123)
        fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
        olens = (((((ilens - 1) // 2 + 1) - 1) // 2 + 1) - 1) // 2 + 1
        stats = self.pool(fuse_out1234, olens)

        embed_a = self.seg_1(stats)
        if self.two_emb_layer:
            out = F.relu(embed_a)
            out = self.seg_bn_1(out)
            embed_b = self.seg_2(out)
            return embed_b
        else:
            return embed_a


class BasicBlockRes2Net(nn.Module):
    expansion = 2

    def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
        super(BasicBlockRes2Net, self).__init__()
        width = int(math.floor(planes * (baseWidth / 64.0)))
        self.conv1 = conv1x1(in_planes, width * scale, stride)
        self.bn1 = nn.BatchNorm2d(width * scale)
        self.nums = scale - 1
        convs = []
        bns = []
        for i in range(self.nums):
            convs.append(conv3x3(width, width))
            bns.append(nn.BatchNorm2d(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.relu = ReLU(inplace=True)

        self.conv3 = conv1x1(width * scale, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes,
                          self.expansion * planes,
                          kernel_size=1,
                          stride=stride,
                          bias=False),
                nn.BatchNorm2d(self.expansion * planes))
        self.stride = stride
        self.width = width
        self.scale = scale

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)

        out = torch.cat((out, spx[self.nums]), 1)

        out = self.conv3(out)
        out = self.bn3(out)

        residual = self.shortcut(x)
        out += residual
        out = self.relu(out)

        return out


class Res2Net(nn.Module):
    def __init__(self,
                 block=BasicBlockRes2Net,
                 num_blocks=[3, 4, 6, 3],
                 m_channels=32,
                 feat_dim=80,
                 embedding_size=192,
                 pooling_func='TSTP',
                 two_emb_layer=False):
        super(Res2Net, self).__init__()
        self.in_planes = m_channels
        self.feat_dim = feat_dim
        self.embedding_size = embedding_size
        self.stats_dim = int(feat_dim / 8) * m_channels * 8
        self.two_emb_layer = two_emb_layer

        self.conv1 = nn.Conv2d(1,
                               m_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)
        self.layer1 = self._make_layer(block,
                                       m_channels,
                                       num_blocks[0],
                                       stride=1)
        self.layer2 = self._make_layer(block,
                                       m_channels * 2,
                                       num_blocks[1],
                                       stride=2)
        self.layer3 = self._make_layer(block,
                                       m_channels * 4,
                                       num_blocks[2],
                                       stride=2)
        self.layer4 = self._make_layer(block,
                                       m_channels * 8,
                                       num_blocks[3],
                                       stride=2)

        self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
        self.pool = getattr(pooling_layers, pooling_func)(
            in_dim=self.stats_dim * block.expansion)
        self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
                               embedding_size)
        if self.two_emb_layer:
            self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
            self.seg_2 = nn.Linear(embedding_size, embedding_size)
        else:
            self.seg_bn_1 = nn.Identity()
            self.seg_2 = nn.Identity()

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)

        x = x.unsqueeze_(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)

        stats = self.pool(out)

        embed_a = self.seg_1(stats)
        if self.two_emb_layer:
            out = F.relu(embed_a)
            out = self.seg_bn_1(out)
            embed_b = self.seg_2(out)
            return embed_b
        else:
            return embed_a
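An illustrative shape check for the ERes2Net branch above (not part of the commit; dummy tensors with the default feat_dim=80 and embedding_size=192):

import torch
from funasr.models.whisper_lid.eres2net.ResNet import ERes2Net

net = ERes2Net(feat_dim=80, embedding_size=192)
x = torch.randn(2, 200, 80)        # (batch, frames, feat_dim)
ilens = torch.tensor([200, 150])   # true lengths before padding
emb = net(x, ilens)
print(emb.shape)                   # expected: torch.Size([2, 192])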
funasr/models/whisper_lid/eres2net/__init__.py (new file, 0 lines added)
funasr/models/whisper_lid/eres2net/fusion.py (new file, 29 lines added)
@@ -0,0 +1,29 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import torch
import torch.nn as nn


class AFF(nn.Module):

    def __init__(self, channels=64, r=4):
        super(AFF, self).__init__()
        inter_channels = int(channels // r)

        self.local_att = nn.Sequential(
            nn.Conv2d(channels * 2, inter_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(inter_channels),
            nn.SiLU(inplace=True),
            nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x, ds_y):
        xa = torch.cat((x, ds_y), dim=1)
        x_att = self.local_att(xa)
        x_att = 1.0 + torch.tanh(x_att)
        xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0 - x_att)

        return xo
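Because x_att = 1 + tanh(...) lies in (0, 2) and the second gate is 2 - x_att, the AFF above blends x and ds_y with complementary weights. A small illustrative call (not part of the commit; dummy tensors, default channels=64):

import torch
from funasr.models.whisper_lid.eres2net.fusion import AFF

aff = AFF(channels=64)
x = torch.randn(2, 64, 10, 25)     # (batch, channels, freq, time)
ds_y = torch.randn(2, 64, 10, 25)  # downsampled features to fuse with x
out = aff(x, ds_y)
print(out.shape)                   # torch.Size([2, 64, 10, 25])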
funasr/models/whisper_lid/eres2net/pooling_layers.py (new file, 118 lines added)
@@ -0,0 +1,118 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

""" This implementation is adapted from https://github.com/wenet-e2e/wespeaker."""

import torch
import torch.nn as nn

from funasr.models.transformer.utils.nets_utils import make_pad_mask


class TAP(nn.Module):
    """
    Temporal average pooling, only first-order mean is considered
    """

    def __init__(self, **kwargs):
        super(TAP, self).__init__()

    def forward(self, x):
        pooling_mean = x.mean(dim=-1)
        # To be compatible with 2D input
        pooling_mean = pooling_mean.flatten(start_dim=1)
        return pooling_mean


class TSDP(nn.Module):
    """
    Temporal standard deviation pooling, only second-order std is considered
    """

    def __init__(self, **kwargs):
        super(TSDP, self).__init__()

    def forward(self, x):
        # The last dimension is the temporal axis
        pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
        pooling_std = pooling_std.flatten(start_dim=1)
        return pooling_std


class TSTP(nn.Module):
    """
    Temporal statistics pooling, concatenating mean and std, which is used in
    x-vector
    Comment: simple concatenation can not make full use of both statistics
    """

    def __init__(self, **kwargs):
        super(TSTP, self).__init__()

    def forward(self, x, olens):
        # The last dimension is the temporal axis
        masks = (~make_pad_mask(olens, maxlen=x.shape[-1])[:, None, None, :]).to(x.device)
        x_masked = x * masks
        sum_without_padding = torch.sum(x_masked, axis=-1)
        count_without_padding = torch.sum(masks, axis=-1)
        mean_without_padding = sum_without_padding / count_without_padding

        var_without_padding = ((x_masked - mean_without_padding.unsqueeze(-1)) ** 2 * masks).sum(-1) / count_without_padding

        pooling_mean = mean_without_padding
        pooling_std = torch.sqrt(var_without_padding + 1e-8)
        pooling_mean = pooling_mean.flatten(start_dim=1)
        pooling_std = pooling_std.flatten(start_dim=1)

        stats = torch.cat((pooling_mean, pooling_std), 1)
        return stats


class ASTP(nn.Module):
    """ Attentive statistics pooling: Channel- and context-dependent
        statistics pooling, first used in ECAPA_TDNN.
    """

    def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False):
        super(ASTP, self).__init__()
        self.global_context_att = global_context_att

        # Use Conv1d with stride == 1 rather than Linear, then we don't
        # need to transpose inputs.
        if global_context_att:
            self.linear1 = nn.Conv1d(
                in_dim * 3, bottleneck_dim,
                kernel_size=1)  # equals W and b in the paper
        else:
            self.linear1 = nn.Conv1d(
                in_dim, bottleneck_dim,
                kernel_size=1)  # equals W and b in the paper
        self.linear2 = nn.Conv1d(bottleneck_dim, in_dim,
                                 kernel_size=1)  # equals V and k in the paper

    def forward(self, x):
        """
        x: a 3-dimensional tensor in tdnn-based architecture (B,F,T)
            or a 4-dimensional tensor in resnet architecture (B,C,F,T)
            0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
        """
        if len(x.shape) == 4:
            x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
        assert len(x.shape) == 3

        if self.global_context_att:
            context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
            context_std = torch.sqrt(
                torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
            x_in = torch.cat((x, context_mean, context_std), dim=1)
        else:
            x_in = x

        # DON'T use ReLU here! ReLU may be hard to converge.
        alpha = torch.tanh(
            self.linear1(x_in))  # alpha = F.relu(self.linear1(x_in))
        alpha = torch.softmax(self.linear2(alpha), dim=2)
        mean = torch.sum(alpha * x, dim=2)
        var = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
        std = torch.sqrt(var.clamp(min=1e-10))
        return torch.cat([mean, std], dim=1)
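A small illustrative call of the masked TSTP pooling above (not part of the commit; dummy tensors, a batch of two inputs whose true lengths are 4 and 2 frames):

import torch
from funasr.models.whisper_lid.eres2net.pooling_layers import TSTP

pool = TSTP(in_dim=8)
x = torch.randn(2, 8, 1, 4)      # (batch, channels, freq, time)
olens = torch.tensor([4, 2])     # valid frames per utterance
stats = pool(x, olens)           # concatenated [mean, std] over valid frames only
print(stats.shape)               # torch.Size([2, 16])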
funasr/models/whisper_lid/eres2net/simple_avg.py (new file, 17 lines added)
@@ -0,0 +1,17 @@
import torch

from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.nets_utils import make_pad_mask


class SimpleAvg(AbsEncoder):
    def __init__(self, feat_dim):
        super(SimpleAvg, self).__init__()
        self.feat_dim = feat_dim

    def forward(self, x, ilens):
        mask = ~make_pad_mask(ilens, maxlen=x.shape[1]).to(x.device)
        avg_x = (x * mask[:, :, None]).sum(1) / mask.sum(-1)[:, None]
        return avg_x

    def output_size(self) -> int:
        return self.feat_dim
funasr/models/whisper_lid/lid_predictor.py (new file, 25 lines added)
@@ -0,0 +1,25 @@
from funasr.register import tables
from funasr.models.whisper_lid.eres2net.ResNet import ERes2Net, BasicBlockERes2Net, BasicBlockERes2Net_diff_AFF


@tables.register("lid_predictor_classes", "LidPredictor")
class LidPredictor(ERes2Net):
    def __init__(self,
                 block=BasicBlockERes2Net,
                 block_fuse=BasicBlockERes2Net_diff_AFF,
                 num_blocks=[3, 4, 6, 3],
                 m_channels=32,
                 feat_dim=80,
                 embedding_size=192,
                 pooling_func='TSTP',
                 two_emb_layer=False):
        super(LidPredictor, self).__init__(
            block=block,
            block_fuse=block_fuse,
            num_blocks=num_blocks,
            m_channels=m_channels,
            feat_dim=feat_dim,
            embedding_size=embedding_size,
            pooling_func=pooling_func,
            two_emb_layer=two_emb_layer
        )
665
funasr/models/whisper_lid/model.py
Normal file
665
funasr/models/whisper_lid/model.py
Normal file
@ -0,0 +1,665 @@
|
||||
import logging
|
||||
from typing import Union, Dict, List, Tuple, Optional
|
||||
|
||||
import time
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
from torch.cuda.amp import autocast
|
||||
|
||||
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
|
||||
from funasr.models.ctc.ctc import CTC
|
||||
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
|
||||
from funasr.metrics.compute_acc import th_accuracy
|
||||
from funasr.train_utils.device_funcs import force_gatherable
|
||||
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
|
||||
from funasr.utils import postprocess_utils
|
||||
from funasr.utils.datadir_writer import DatadirWriter
|
||||
from funasr.register import tables
|
||||
|
||||
|
||||
@tables.register("model_classes", "OpenAIWhisperModel")
|
||||
class OpenAIWhisperModel(nn.Module):
|
||||
"""CTC-attention hybrid Encoder-Decoder model"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
specaug: str = None,
|
||||
specaug_conf: dict = None,
|
||||
normalize: str = None,
|
||||
normalize_conf: dict = None,
|
||||
encoder: str = None,
|
||||
encoder_conf: dict = None,
|
||||
decoder: str = None,
|
||||
decoder_conf: dict = None,
|
||||
ctc: str = None,
|
||||
ctc_conf: dict = None,
|
||||
ctc_weight: float = 0.5,
|
||||
interctc_weight: float = 0.0,
|
||||
input_size: int = 80,
|
||||
vocab_size: int = -1,
|
||||
ignore_id: int = -1,
|
||||
blank_id: int = 0,
|
||||
sos: int = 1,
|
||||
eos: int = 2,
|
||||
lsm_weight: float = 0.0,
|
||||
length_normalized_loss: bool = False,
|
||||
report_cer: bool = True,
|
||||
report_wer: bool = True,
|
||||
sym_space: str = "<space>",
|
||||
sym_blank: str = "<blank>",
|
||||
# extract_feats_in_collect_stats: bool = True,
|
||||
share_embedding: bool = False,
|
||||
# preencoder: Optional[AbsPreEncoder] = None,
|
||||
# postencoder: Optional[AbsPostEncoder] = None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
if specaug is not None:
|
||||
specaug_class = tables.specaug_classes.get(specaug)
|
||||
specaug = specaug_class(**specaug_conf)
|
||||
if normalize is not None:
|
||||
normalize_class = tables.normalize_classes.get(normalize)
|
||||
normalize = normalize_class(**normalize_conf)
|
||||
encoder_class = tables.encoder_classes.get(encoder)
|
||||
encoder = encoder_class(input_size=input_size, **encoder_conf)
|
||||
encoder_output_size = encoder.output_size()
|
||||
if decoder is not None:
|
||||
decoder_class = tables.decoder_classes.get(decoder)
|
||||
decoder = decoder_class(decoder_conf)
|
||||
if ctc_weight > 0.0:
|
||||
|
||||
if ctc_conf is None:
|
||||
ctc_conf = {}
|
||||
|
||||
ctc = CTC(
|
||||
odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
|
||||
)
|
||||
|
||||
self.blank_id = blank_id
|
||||
self.sos = sos if sos is not None else vocab_size - 1
|
||||
self.eos = eos if eos is not None else vocab_size - 1
|
||||
self.vocab_size = vocab_size
|
||||
self.ignore_id = ignore_id
|
||||
self.ctc_weight = ctc_weight
|
||||
self.specaug = specaug
|
||||
self.normalize = normalize
|
||||
self.encoder = encoder
|
||||
|
||||
if not hasattr(self.encoder, "interctc_use_conditioning"):
|
||||
self.encoder.interctc_use_conditioning = False
|
||||
if self.encoder.interctc_use_conditioning:
|
||||
self.encoder.conditioning_layer = torch.nn.Linear(
|
||||
vocab_size, self.encoder.output_size()
|
||||
)
|
||||
self.interctc_weight = interctc_weight
|
||||
|
||||
# self.error_calculator = None
|
||||
if ctc_weight == 1.0:
|
||||
self.decoder = None
|
||||
else:
|
||||
self.decoder = decoder
|
||||
|
||||
self.criterion_att = LabelSmoothingLoss(
|
||||
size=vocab_size,
|
||||
padding_idx=ignore_id,
|
||||
smoothing=lsm_weight,
|
||||
normalize_length=length_normalized_loss,
|
||||
)
|
||||
#
|
||||
# if report_cer or report_wer:
|
||||
# self.error_calculator = ErrorCalculator(
|
||||
# token_list, sym_space, sym_blank, report_cer, report_wer
|
||||
# )
|
||||
#
|
||||
self.error_calculator = None
|
||||
if ctc_weight == 0.0:
|
||||
self.ctc = None
|
||||
else:
|
||||
self.ctc = ctc
|
||||
|
||||
self.share_embedding = share_embedding
|
||||
if self.share_embedding:
|
||||
self.decoder.embed = None
|
||||
|
||||
self.length_normalized_loss = length_normalized_loss
|
||||
self.beam_search = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
"""Encoder + Decoder + Calc loss
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
text: (Batch, Length)
|
||||
text_lengths: (Batch,)
|
||||
"""
|
||||
# import pdb;
|
||||
# pdb.set_trace()
|
||||
if len(text_lengths.size()) > 1:
|
||||
text_lengths = text_lengths[:, 0]
|
||||
if len(speech_lengths.size()) > 1:
|
||||
speech_lengths = speech_lengths[:, 0]
|
||||
|
||||
batch_size = speech.shape[0]
|
||||
|
||||
# 1. Encoder
|
||||
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
|
||||
intermediate_outs = None
|
||||
if isinstance(encoder_out, tuple):
|
||||
intermediate_outs = encoder_out[1]
|
||||
encoder_out = encoder_out[0]
|
||||
|
||||
loss_att, acc_att, cer_att, wer_att = None, None, None, None
|
||||
loss_ctc, cer_ctc = None, None
|
||||
stats = dict()
|
||||
|
||||
# decoder: CTC branch
|
||||
if self.ctc_weight != 0.0:
|
||||
loss_ctc, cer_ctc = self._calc_ctc_loss(
|
||||
encoder_out, encoder_out_lens, text, text_lengths
|
||||
)
|
||||
|
||||
# Collect CTC branch stats
|
||||
stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
|
||||
stats["cer_ctc"] = cer_ctc
|
||||
|
||||
# Intermediate CTC (optional)
|
||||
loss_interctc = 0.0
|
||||
if self.interctc_weight != 0.0 and intermediate_outs is not None:
|
||||
for layer_idx, intermediate_out in intermediate_outs:
|
||||
# we assume intermediate_out has the same length & padding
|
||||
# as those of encoder_out
|
||||
loss_ic, cer_ic = self._calc_ctc_loss(
|
||||
intermediate_out, encoder_out_lens, text, text_lengths
|
||||
)
|
||||
loss_interctc = loss_interctc + loss_ic
|
||||
|
||||
# Collect Intermedaite CTC stats
|
||||
stats["loss_interctc_layer{}".format(layer_idx)] = (
|
||||
loss_ic.detach() if loss_ic is not None else None
|
||||
)
|
||||
stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
|
||||
|
||||
loss_interctc = loss_interctc / len(intermediate_outs)
|
||||
|
||||
# calculate whole encoder loss
|
||||
loss_ctc = (
|
||||
1 - self.interctc_weight
|
||||
) * loss_ctc + self.interctc_weight * loss_interctc
|
||||
|
||||
# decoder: Attention decoder branch
|
||||
loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
|
||||
encoder_out, encoder_out_lens, text, text_lengths
|
||||
)
|
||||
|
||||
# 3. CTC-Att loss definition
|
||||
if self.ctc_weight == 0.0:
|
||||
loss = loss_att
|
||||
elif self.ctc_weight == 1.0:
|
||||
loss = loss_ctc
|
||||
else:
|
||||
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
|
||||
|
||||
# Collect Attn branch stats
|
||||
stats["loss_att"] = loss_att.detach() if loss_att is not None else None
|
||||
stats["acc"] = acc_att
|
||||
stats["cer"] = cer_att
|
||||
stats["wer"] = wer_att
|
||||
|
||||
# Collect total loss stats
|
||||
stats["loss"] = torch.clone(loss.detach())
|
||||
|
||||
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
||||
if self.length_normalized_loss:
|
||||
batch_size = int((text_lengths + 1).sum())
|
||||
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
||||
return loss, stats, weight
|
||||
|
||||
|
||||
def encode(
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Frontend + Encoder. Note that this method is used by asr_inference.py
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
ind: int
|
||||
"""
|
||||
with autocast(False):
|
||||
|
||||
# Data augmentation
|
||||
if self.specaug is not None and self.training:
|
||||
speech, speech_lengths = self.specaug(speech, speech_lengths)
|
||||
|
||||
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
|
||||
if self.normalize is not None:
|
||||
speech, speech_lengths = self.normalize(speech, speech_lengths)
|
||||
|
||||
# Forward encoder
|
||||
# feats: (Batch, Length, Dim)
|
||||
# -> encoder_out: (Batch, Length2, Dim2)
|
||||
if self.encoder.interctc_use_conditioning:
|
||||
encoder_out, encoder_out_lens, _ = self.encoder(
|
||||
speech, speech_lengths, ctc=self.ctc
|
||||
)
|
||||
else:
|
||||
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
|
||||
intermediate_outs = None
|
||||
if isinstance(encoder_out, tuple):
|
||||
intermediate_outs = encoder_out[1]
|
||||
encoder_out = encoder_out[0]
|
||||
|
||||
if intermediate_outs is not None:
|
||||
return (encoder_out, intermediate_outs), encoder_out_lens
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
def _calc_att_loss(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
ys_pad: torch.Tensor,
|
||||
ys_pad_lens: torch.Tensor,
|
||||
):
|
||||
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
|
||||
ys_in_lens = ys_pad_lens + 1
|
||||
|
||||
# 1. Forward decoder
|
||||
decoder_out, _ = self.decoder(
|
||||
encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
|
||||
)
|
||||
|
||||
# 2. Compute attention loss
|
||||
loss_att = self.criterion_att(decoder_out, ys_out_pad)
|
||||
acc_att = th_accuracy(
|
||||
decoder_out.view(-1, self.vocab_size),
|
||||
ys_out_pad,
|
||||
ignore_label=self.ignore_id,
|
||||
)
|
||||
|
||||
# Compute cer/wer using attention-decoder
|
||||
if self.training or self.error_calculator is None:
|
||||
cer_att, wer_att = None, None
|
||||
else:
|
||||
ys_hat = decoder_out.argmax(dim=-1)
|
||||
cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
|
||||
|
||||
return loss_att, acc_att, cer_att, wer_att
|
||||
|
||||
def _calc_ctc_loss(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
ys_pad: torch.Tensor,
|
||||
ys_pad_lens: torch.Tensor,
|
||||
):
|
||||
# Calc CTC loss
|
||||
loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
|
||||
|
||||
# Calc CER using CTC
|
||||
cer_ctc = None
|
||||
if not self.training and self.error_calculator is not None:
|
||||
ys_hat = self.ctc.argmax(encoder_out).data
|
||||
cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
|
||||
return loss_ctc, cer_ctc
|
||||
|
||||
def init_beam_search(self,
|
||||
**kwargs,
|
||||
):
|
||||
from funasr.models.transformer.search import BeamSearch
|
||||
from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
|
||||
from funasr.models.transformer.scorers.length_bonus import LengthBonus
|
||||
|
||||
# 1. Build ASR model
|
||||
scorers = {}
|
||||
|
||||
if self.ctc != None:
|
||||
ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
|
||||
scorers.update(
|
||||
ctc=ctc
|
||||
)
|
||||
token_list = kwargs.get("token_list")
|
||||
scorers.update(
|
||||
decoder=self.decoder,
|
||||
length_bonus=LengthBonus(len(token_list)),
|
||||
)
|
||||
|
||||
|
||||
# 3. Build ngram model
|
||||
# ngram is not supported now
|
||||
ngram = None
|
||||
scorers["ngram"] = ngram
|
||||
|
||||
weights = dict(
|
||||
decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.5),
|
||||
ctc=kwargs.get("decoding_ctc_weight", 0.5),
|
||||
lm=kwargs.get("lm_weight", 0.0),
|
||||
ngram=kwargs.get("ngram_weight", 0.0),
|
||||
length_bonus=kwargs.get("penalty", 0.0),
|
||||
)
|
||||
beam_search = BeamSearch(
|
||||
beam_size=kwargs.get("beam_size", 10),
|
||||
weights=weights,
|
||||
scorers=scorers,
|
||||
sos=self.sos,
|
||||
eos=self.eos,
|
||||
vocab_size=len(token_list),
|
||||
token_list=token_list,
|
||||
pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
|
||||
)
|
||||
|
||||
self.beam_search = beam_search
|
||||
|
||||
def inference(self,
|
||||
data_in,
|
||||
data_lengths=None,
|
||||
key: list=None,
|
||||
tokenizer=None,
|
||||
frontend=None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
if kwargs.get("batch_size", 1) > 1:
|
||||
raise NotImplementedError("batch decoding is not implemented")
|
||||
|
||||
# init beamsearch
|
||||
if self.beam_search is None:
|
||||
logging.info("enable beam_search")
|
||||
self.init_beam_search(**kwargs)
|
||||
self.nbest = kwargs.get("nbest", 1)
|
||||
|
||||
meta_data = {}
|
||||
if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank
|
||||
speech, speech_lengths = data_in, data_lengths
|
||||
if len(speech.shape) < 3:
|
||||
speech = speech[None, :, :]
|
||||
if speech_lengths is None:
|
||||
speech_lengths = speech.shape[1]
|
||||
else:
|
||||
# extract fbank feats
|
||||
time1 = time.perf_counter()
|
||||
audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
|
||||
data_type=kwargs.get("data_type", "sound"),
|
||||
tokenizer=tokenizer)
|
||||
time2 = time.perf_counter()
|
||||
meta_data["load_data"] = f"{time2 - time1:0.3f}"
|
||||
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
|
||||
frontend=frontend)
|
||||
time3 = time.perf_counter()
|
||||
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
|
||||
meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
|
||||
|
||||
speech = speech.to(device=kwargs["device"])
|
||||
speech_lengths = speech_lengths.to(device=kwargs["device"])
|
||||
# Encoder
|
||||
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
|
||||
if isinstance(encoder_out, tuple):
|
||||
encoder_out = encoder_out[0]
|
||||
|
||||
# c. Passed the encoder result and the beam search
|
||||
nbest_hyps = self.beam_search(
|
||||
x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
|
||||
)
|
||||
|
||||
nbest_hyps = nbest_hyps[: self.nbest]
|
||||
|
||||
|
||||
results = []
|
||||
b, n, d = encoder_out.size()
|
||||
for i in range(b):
|
||||
|
||||
for nbest_idx, hyp in enumerate(nbest_hyps):
|
||||
ibest_writer = None
|
||||
if kwargs.get("output_dir") is not None:
|
||||
if not hasattr(self, "writer"):
|
||||
self.writer = DatadirWriter(kwargs.get("output_dir"))
|
||||
ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
|
||||
|
||||
# remove sos/eos and get results
|
||||
last_pos = -1
|
||||
if isinstance(hyp.yseq, list):
|
||||
token_int = hyp.yseq[1:last_pos]
|
||||
else:
|
||||
token_int = hyp.yseq[1:last_pos].tolist()
|
||||
|
||||
# remove blank symbol id, which is assumed to be 0
|
||||
token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
|
||||
|
||||
# Change integer-ids to tokens
|
||||
token = tokenizer.ids2tokens(token_int)
|
||||
text = tokenizer.tokens2text(token)
|
||||
|
||||
text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
|
||||
result_i = {"key": key[i], "token": token, "text": text_postprocessed}
|
||||
results.append(result_i)
|
||||
|
||||
if ibest_writer is not None:
|
||||
ibest_writer["token"][key[i]] = " ".join(token)
|
||||
ibest_writer["text"][key[i]] = text_postprocessed
|
||||
|
||||
return results, meta_data
|
||||
|
||||
|
||||
@tables.register("model_classes", "OpenAIWhisperLIDModel")
|
||||
class OpenAIWhisperLIDModel(nn.Module):
|
||||
"""WhisperEncoder and EResNet based LID Model"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
specaug: str = None,
|
||||
specaug_conf: dict = None,
|
||||
encoder: str = None,
|
||||
encoder_conf: dict = None,
|
||||
lid_predictor: str = None,
|
||||
lid_predictor_conf: dict = None,
|
||||
proj_dim: int = None,
|
||||
clip_frames: int = None,
|
||||
random_clip: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
if specaug is not None:
|
||||
specaug_class = tables.specaug_classes.get(specaug)
|
||||
specaug = specaug_class(**specaug_conf)
|
||||
encoder_class = tables.encoder_classes.get(encoder)
|
||||
encoder = encoder_class(**encoder_conf)
|
||||
lid_predictor_class = tables.lid_predictor_classes.get(lid_predictor)
|
||||
lid_predictor = lid_predictor_class(**lid_predictor_conf)
|
||||
if encoder.output_size() != proj_dim:
|
||||
self.proj_layer = torch.nn.Linear(encoder.output_size(), proj_dim)
|
||||
else:
|
||||
self.proj_layer = None
|
||||
self.output_layer = torch.nn.Linear(lid_predictor.output_size(), vocab_size)
|
||||
self.criterion_lid = LabelSmoothingLoss(
|
||||
size=vocab_size,
|
||||
padding_idx=-1,
|
||||
smoothing=0.0,
|
||||
normalize_length=False,
|
||||
)
|
||||
|
||||
self.specaug = specaug
|
||||
self.encoder = encoder
|
||||
self.lid_predictor = lid_predictor
|
||||
self.clip_frames = clip_frames
|
||||
self.random_clip = random_clip
|
||||
self.normalize = None
|
||||
self.beam_search = None
|
||||
if not hasattr(self.encoder, "interctc_use_conditioning"):
|
||||
self.encoder.interctc_use_conditioning = False
|
||||
|
||||
def forward(self,
|
||||
speech: torch.Tensor, # may be padding
|
||||
speech_lengths: torch.Tensor, # actual length
|
||||
lid: torch.Tensor, # lid label, (batch_size, 1)
|
||||
lid_lengths: torch.Tensor,
|
||||
):
|
||||
assert lid.shape[1] == 1
|
||||
batch_size = speech.shape[0]
|
||||
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
|
||||
|
||||
# re-generate encoder_out
        if self.clip_frames is None:
            reduced_encoder_out = torch.zeros(batch_size, encoder_out_lens.max(), encoder_out.shape[-1]).to(encoder_out.dtype).to(encoder_out.device)
            for i, enc_length in enumerate(encoder_out_lens):
                reduced_encoder_out[i, :enc_length] = encoder_out[i, :enc_length]
        else:
            reduced_encoder_out = torch.zeros(batch_size, self.clip_frames, encoder_out.shape[-1]).to(encoder_out.dtype).to(encoder_out.device)
            if self.random_clip:
                for i, enc_length in enumerate(encoder_out_lens):
                    if enc_length <= self.clip_frames:
                        reduced_encoder_out[i, :enc_length] = encoder_out[i, :enc_length]
                        encoder_out_lens[i] = enc_length
                    else:
                        max_start_index = enc_length.item() - self.clip_frames
                        start_index = np.random.randint(0, max_start_index + 1)
                        reduced_encoder_out[i, :self.clip_frames] = encoder_out[i, start_index:start_index + self.clip_frames]
                        encoder_out_lens[i] = self.clip_frames
            else:
                for i, enc_length in enumerate(encoder_out_lens):
                    enc_length = self.clip_frames if enc_length >= self.clip_frames else enc_length
                    reduced_encoder_out[i, :enc_length] = encoder_out[i, :enc_length]
                    encoder_out_lens[i] = enc_length
        if self.proj_layer is not None:
            reduced_encoder_out = self.proj_layer(reduced_encoder_out)
        lid_output = self.lid_predictor(reduced_encoder_out, encoder_out_lens)  # (B, D)
        lid_logits = self.output_layer(lid_output)  # (B, num_classes)
        loss = self.criterion_lid(lid_logits[:, None, :], lid)
        with torch.no_grad():
            _, predicted_lid = torch.max(lid_logits, 1)
            correct = (predicted_lid == lid[:, 0]).sum().item()
            lid_acc = correct * 1.0 / lid_logits.shape[0]
        stats = dict()
        stats["batch_size"] = batch_size
        stats["loss"] = torch.clone(loss.detach())
        stats["acc"] = lid_acc
        stats["token_length"] = speech_lengths.max()
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is also used by asr_inference.py.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
        """
        with autocast(False):

            # Data augmentation
            if self.specaug is not None and self.training:
                speech = speech.permute(0, 2, 1)
                # match the fixed-length padding used by the Whisper frontend
                padded_speech_lengths = torch.ones_like(speech_lengths) * speech.shape[1]
                speech, padded_speech_lengths = self.specaug(speech, padded_speech_lengths)
                speech = speech.permute(0, 2, 1)

            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

        # Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths, ctc=self.ctc)
        else:
            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens

        return encoder_out, encoder_out_lens

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        meta_data = {}
        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                speech_lengths = speech.shape[1]
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(
                data_in,
                fs=frontend.fs,
                audio_fs=kwargs.get("fs", 16000),
                data_type=kwargs.get("data_type", "sound"),
                tokenizer=tokenizer,
            )
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(
                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
            )
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000

        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])
        # Encoder
        enc, enc_out_lens = self.encode(speech, speech_lengths)

        inference_clip_length = kwargs.get("inference_clip_length", None)
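        # When clipping is enabled (self.clip_frames is not None), inference_clip_length
        # overrides the training-time clip length for how many encoder frames the LID
        # predictor sees at test time.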
        if self.clip_frames is not None:
            if inference_clip_length is None:
                reduced_enc = torch.zeros(enc.shape[0], self.clip_frames, enc.shape[-1]).to(enc.dtype).to(enc.device)
                for i, enc_length in enumerate(enc_out_lens):
                    enc_length = self.clip_frames if enc_length >= self.clip_frames else enc_length
                    reduced_enc[i, :enc_length] = enc[i, :enc_length]
                    enc_out_lens[i] = enc_length
            else:
                assert inference_clip_length > 0, "inference_clip_length must be larger than 0"
                reduced_enc = torch.zeros(enc.shape[0], inference_clip_length, enc.shape[-1]).to(enc.dtype).to(enc.device)
                for i, enc_length in enumerate(enc_out_lens):
                    enc_length = inference_clip_length if enc_length >= inference_clip_length else enc_length
                    reduced_enc[i, :enc_length] = enc[i, :enc_length]
                    enc_out_lens[i] = enc_length
        else:
            reduced_enc = torch.zeros(enc.shape[0], enc_out_lens.max(), enc.shape[-1]).to(enc.dtype).to(enc.device)
            for i, enc_length in enumerate(enc_out_lens):
                reduced_enc[i, :enc_length] = enc[i, :enc_length]

        if self.proj_layer is not None:
            reduced_enc = self.proj_layer(reduced_enc)
        lid_output = self.lid_predictor(reduced_enc, enc_out_lens)  # (B, D)
        lid_logits = self.output_layer(lid_output)  # (B, num_classes)

        _, predicted_lid_index = torch.max(lid_logits, 1)
        predicted_lid = tokenizer.ids2tokens([predicted_lid_index[0].cpu()])[0]

        if kwargs.get("output_dir") is not None:
            if not hasattr(self, "writer"):
                self.writer = DatadirWriter(kwargs.get("output_dir"))
            lid_writer = self.writer["lid"]
            lid_writer[key[0]] = predicted_lid

        results = [{"key": key[0], "lid": predicted_lid}]

        return results, meta_data
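The clip-then-classify flow used by this model can be summarized in a short standalone sketch. This is not the FunASR implementation: the function name, the mean pooling (standing in for the registered `lid_predictor` module, whose internals are not part of this diff), and the `proj`/`out` layers are illustrative stand-ins, and the shapes in the usage lines are made up purely for illustration.

```python
import torch

def clip_pool_classify(enc, enc_lens, proj, out, clip_frames=250):
    """Keep at most clip_frames leading encoder frames, mean-pool, and classify."""
    B, _, D = enc.shape
    clipped = enc.new_zeros(B, clip_frames, D)
    new_lens = torch.clamp(enc_lens, max=clip_frames)
    for i, n in enumerate(new_lens):
        clipped[i, :n] = enc[i, :n]  # leading-frame clip, as in the non-random branch above
    mask = (torch.arange(clip_frames, device=enc.device)[None, :] < new_lens[:, None]).to(enc.dtype)
    pooled = (clipped * mask.unsqueeze(-1)).sum(1) / new_lens.clamp(min=1).unsqueeze(-1)  # (B, D)
    return out(proj(pooled))  # (B, num_languages) logits


# Illustrative shapes only: 2 utterances, 300 encoder frames, 1280-dim features, 100 languages.
enc = torch.randn(2, 300, 1280)
enc_lens = torch.tensor([300, 120])
proj = torch.nn.Linear(1280, 512)
out = torch.nn.Linear(512, 100)
logits = clip_pool_classify(enc, enc_lens, proj, out)  # shape (2, 100)
```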
@@ -1380,7 +1380,7 @@ sampleClientRun(){
run_cmd="${client_exec} --server-ip ${server_ip} --port ${host_port} --wav-path ${wav_path}"
;;
Python)
client_exec="${PARAMS_FUNASR_SAMPLES_LOCAL_DIR}/python/wss_client_asr.py"
client_exec="${PARAMS_FUNASR_SAMPLES_LOCAL_DIR}/python/funasr_wss_client.py"
run_cmd="python3 ${client_exec} --host ${server_ip} --port ${host_port} --mode offline --audio_in ${wav_path} --send_without_sleep --output_dir ${PARAMS_FUNASR_SAMPLES_LOCAL_DIR}/python"
pre_cmd="pip3 install click>=8.0.4"
echo -e " Run ${BLUE}${pre_cmd}${PLAIN}"

@@ -1455,7 +1455,7 @@ sampleClientRun(){
run_cmd="${client_exec} --server-ip ${server_ip} --port ${host_port} --wav-path ${wav_path}"
;;
Python)
client_exec="${PARAMS_FUNASR_SAMPLES_LOCAL_DIR}/python/wss_client_asr.py"
client_exec="${PARAMS_FUNASR_SAMPLES_LOCAL_DIR}/python/funasr_wss_client.py"
run_cmd="python3 ${client_exec} --host ${server_ip} --port ${host_port} --mode offline --audio_in ${wav_path} --send_without_sleep --output_dir ${PARAMS_FUNASR_SAMPLES_LOCAL_DIR}/python"
pre_cmd="pip3 install click>=8.0.4"
echo -e " Run ${BLUE}${pre_cmd}${PLAIN}"

@@ -1380,7 +1380,7 @@ sampleClientRun(){
run_cmd="${client_exec} --server-ip ${server_ip} --port ${host_port} --wav-path ${wav_path}"
;;
Python)
client_exec="${PARAMS_FUNASR_SAMPLES_LOCAL_DIR}/python/wss_client_asr.py"
client_exec="${PARAMS_FUNASR_SAMPLES_LOCAL_DIR}/python/funasr_wss_client.py"
run_cmd="python3 ${client_exec} --host ${server_ip} --port ${host_port} --mode 2pass --audio_in ${wav_path} --send_without_sleep --output_dir ${PARAMS_FUNASR_SAMPLES_LOCAL_DIR}/python"
pre_cmd="pip3 install click>=8.0.4"
echo -e " Run ${BLUE}${pre_cmd}${PLAIN}"

@@ -59,7 +59,7 @@ wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/sample/funasr_sa
For illustration, we will use the Python client, which supports audio files (.wav, .pcm) as well as multi-file input via a wav.scp list.

```shell
python3 wss_client_asr.py --host "127.0.0.1" --port 10095 --mode 2pass
python3 funasr_wss_client.py --host "127.0.0.1" --port 10095 --mode 2pass
```
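For the multi-file case, the wav.scp list typically holds one entry per line in the form `utterance_id /path/to/audio.wav` (paths here are placeholders) and is passed to the client via the `--audio_in` option shown in the deployment scripts above.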

------------------

@@ -75,7 +75,7 @@ wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/sample/funasr_sa
```
We take the Python client as an example; it supports audio files (.wav, .pcm) as well as multi-file input via a wav.scp list. For other client versions, please refer to the documentation ([click here](#客户端用法详解)).
```shell
python3 wss_client_asr.py --host "127.0.0.1" --port 10095 --mode 2pass
python3 funasr_wss_client.py --host "127.0.0.1" --port 10095 --mode 2pass
```

------------------
