Merge pull request #328 from alibaba-damo-academy/dev_cmz2

Dev cmz2
This commit is contained in:
zhifu gao 2023-04-07 15:54:21 +08:00 committed by GitHub
commit 0cc7af2c89
40 changed files with 2283 additions and 826 deletions

View File

@ -37,9 +37,6 @@ from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
header_colors = '\033[95m'
end_colors = '\033[0m'
class Speech2Text:
"""Speech2Text class

View File

@ -23,7 +23,7 @@ from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.punctuation.text_preprocessor import split_to_mini_sentence
from funasr.datasets.preprocessor import split_to_mini_sentence
class Text2Punc:

View File

@ -23,7 +23,7 @@ from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.punctuation.text_preprocessor import split_to_mini_sentence
from funasr.datasets.preprocessor import split_to_mini_sentence
class Text2Punc:

View File

@ -800,3 +800,17 @@ class PuncTrainTokenizerCommonPreprocessor(CommonPreprocessor):
data[self.vad_name] = np.array([vad], dtype=np.int64)
text_ints = self.token_id_converter[i].tokens2ids(tokens)
data[text_name] = np.array(text_ints, dtype=np.int64)
def split_to_mini_sentence(words: list, word_limit: int = 20):
assert word_limit > 1
if len(words) <= word_limit:
return [words]
sentences = []
length = len(words)
sentence_len = length // word_limit
for i in range(sentence_len):
sentences.append(words[i * word_limit:(i + 1) * word_limit])
if length % word_limit > 0:
sentences.append(words[sentence_len * word_limit:])
return sentences
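
For illustration, a minimal usage sketch of split_to_mini_sentence; the token list and limit below are made-up values:

words = ["we", "have", "a", "very", "long", "token", "list"]
print(split_to_mini_sentence(words, word_limit=3))
# -> [['we', 'have', 'a'], ['very', 'long', 'token'], ['list']]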

View File

@ -167,31 +167,57 @@ class ModelExport:
def export(self,
tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
mode: str = 'paraformer',
mode: str = None,
):
model_dir = tag_name
if model_dir.startswith('damo/'):
if model_dir.startswith('damo'):
from modelscope.hub.snapshot_download import snapshot_download
model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir)
asr_train_config = os.path.join(model_dir, 'config.yaml')
asr_model_file = os.path.join(model_dir, 'model.pb')
cmvn_file = os.path.join(model_dir, 'am.mvn')
json_file = os.path.join(model_dir, 'configuration.json')
if mode is None:
import json
json_file = os.path.join(model_dir, 'configuration.json')
with open(json_file, 'r') as f:
config_data = json.load(f)
mode = config_data['model']['model_config']['mode']
if config_data['task'] == "punctuation":
mode = config_data['model']['punc_model_config']['mode']
else:
mode = config_data['model']['model_config']['mode']
if mode.startswith('paraformer'):
from funasr.tasks.asr import ASRTaskParaformer as ASRTask
elif mode.startswith('uniasr'):
from funasr.tasks.asr import ASRTaskUniASR as ASRTask
config = os.path.join(model_dir, 'config.yaml')
model_file = os.path.join(model_dir, 'model.pb')
cmvn_file = os.path.join(model_dir, 'am.mvn')
model, asr_train_args = ASRTask.build_model_from_file(
config, model_file, cmvn_file, 'cpu'
)
self.frontend = model.frontend
elif mode.startswith('offline'):
from funasr.tasks.vad import VADTask
config = os.path.join(model_dir, 'vad.yaml')
model_file = os.path.join(model_dir, 'vad.pb')
cmvn_file = os.path.join(model_dir, 'vad.mvn')
model, asr_train_args = ASRTask.build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, 'cpu'
)
self.frontend = model.frontend
model, vad_infer_args = VADTask.build_model_from_file(
config, model_file, cmvn_file=cmvn_file, device='cpu'
)
self.export_config["feats_dim"] = 400
self.frontend = model.frontend
elif mode.startswith('punc'):
from funasr.tasks.punctuation import PunctuationTask as PUNCTask
punc_train_config = os.path.join(model_dir, 'config.yaml')
punc_model_file = os.path.join(model_dir, 'punc.pb')
model, punc_train_args = PUNCTask.build_model_from_file(
punc_train_config, punc_model_file, 'cpu'
)
elif mode.startswith('punc_VadRealtime'):
from funasr.tasks.punctuation import PunctuationTask as PUNCTask
punc_train_config = os.path.join(model_dir, 'config.yaml')
punc_model_file = os.path.join(model_dir, 'punc.pb')
model, punc_train_args = PUNCTask.build_model_from_file(
punc_train_config, punc_model_file, 'cpu'
)
self._export(model, tag_name)
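
For reference, a small standalone sketch of the mode-resolution logic above; the helper name is hypothetical, and it assumes the configuration.json shipped with the ModelScope model directory uses the keys read in this hunk:

import json
import os

def resolve_export_mode(model_dir: str) -> str:
    # Hypothetical helper: read the export mode from configuration.json
    # when no mode is passed explicitly.
    with open(os.path.join(model_dir, 'configuration.json'), 'r') as f:
        config_data = json.load(f)
    if config_data['task'] == "punctuation":
        return config_data['model']['punc_model_config']['mode']
    return config_data['model']['model_config']['mode']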

View File

@ -1,13 +1,25 @@
from funasr.models.e2e_asr_paraformer import Paraformer, BiCifParaformer
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.e2e_vad import E2EVadModel
from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export
from funasr.models.target_delay_transformer import TargetDelayTransformer
from funasr.export.models.target_delay_transformer import CT_Transformer as CT_Transformer_export
from funasr.train.abs_model import PunctuationModel
from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
from funasr.export.models.target_delay_transformer import CT_Transformer_VadRealtime as CT_Transformer_VadRealtime_export
def get_model(model, export_config=None):
if isinstance(model, BiCifParaformer):
return BiCifParaformer_export(model, **export_config)
elif isinstance(model, Paraformer):
return Paraformer_export(model, **export_config)
elif isinstance(model, E2EVadModel):
return E2EVadModel_export(model, **export_config)
elif isinstance(model, PunctuationModel):
if isinstance(model.punc_model, TargetDelayTransformer):
return CT_Transformer_export(model.punc_model, **export_config)
elif isinstance(model.punc_model, VadRealtimeTransformer):
return CT_Transformer_VadRealtime_export(model.punc_model, **export_config)
else:
raise "Funasr does not support the given model type currently."
raise "Funasr does not support the given model type currently."

View File

@ -0,0 +1,60 @@
from enum import Enum
from typing import List, Tuple, Dict, Any
import torch
from torch import nn
import math
from funasr.models.encoder.fsmn_encoder import FSMN
from funasr.export.models.encoder.fsmn_encoder import FSMN as FSMN_export
class E2EVadModel(nn.Module):
def __init__(self, model,
max_seq_len=512,
feats_dim=400,
model_name='model',
**kwargs,):
super(E2EVadModel, self).__init__()
self.feats_dim = feats_dim
self.max_seq_len = max_seq_len
self.model_name = model_name
if isinstance(model.encoder, FSMN):
self.encoder = FSMN_export(model.encoder)
else:
raise "unsupported encoder"
def forward(self, feats: torch.Tensor, *args, ):
scores, out_caches = self.encoder(feats, *args)
return scores, out_caches
def get_dummy_inputs(self, frame=30):
speech = torch.randn(1, frame, self.feats_dim)
in_cache0 = torch.randn(1, 128, 19, 1)
in_cache1 = torch.randn(1, 128, 19, 1)
in_cache2 = torch.randn(1, 128, 19, 1)
in_cache3 = torch.randn(1, 128, 19, 1)
return (speech, in_cache0, in_cache1, in_cache2, in_cache3)
# def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
# import numpy as np
# fbank = np.loadtxt(txt_file)
# fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
# speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
# speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
# return (speech, speech_lengths)
def get_input_names(self):
return ['speech', 'in_cache0', 'in_cache1', 'in_cache2', 'in_cache3']
def get_output_names(self):
return ['logits', 'out_cache0', 'out_cache1', 'out_cache2', 'out_cache3']
def get_dynamic_axes(self):
return {
'speech': {
1: 'feats_length'
},
}
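
As a hedged sketch, the helper methods above can feed a standard torch.onnx.export call; the output path and opset version are assumptions, not taken from this commit:

import torch

def export_vad_onnx(wrapped_model, output_path="vad_model.onnx"):
    # wrapped_model: an instance of the E2EVadModel export wrapper above.
    dummy_inputs = wrapped_model.get_dummy_inputs()
    torch.onnx.export(
        wrapped_model,
        dummy_inputs,
        output_path,
        input_names=wrapped_model.get_input_names(),
        output_names=wrapped_model.get_output_names(),
        dynamic_axes=wrapped_model.get_dynamic_axes(),
        opset_version=13,
    )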

View File

@ -0,0 +1,296 @@
from typing import Tuple, Dict
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from funasr.models.encoder.fsmn_encoder import BasicBlock
class LinearTransform(nn.Module):
def __init__(self, input_dim, output_dim):
super(LinearTransform, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.linear = nn.Linear(input_dim, output_dim, bias=False)
def forward(self, input):
output = self.linear(input)
return output
class AffineTransform(nn.Module):
def __init__(self, input_dim, output_dim):
super(AffineTransform, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.linear = nn.Linear(input_dim, output_dim)
def forward(self, input):
output = self.linear(input)
return output
class RectifiedLinear(nn.Module):
def __init__(self, input_dim, output_dim):
super(RectifiedLinear, self).__init__()
self.dim = input_dim
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.1)
def forward(self, input):
out = self.relu(input)
return out
class FSMNBlock(nn.Module):
def __init__(
self,
input_dim: int,
output_dim: int,
lorder=None,
rorder=None,
lstride=1,
rstride=1,
):
super(FSMNBlock, self).__init__()
self.dim = input_dim
if lorder is None:
return
self.lorder = lorder
self.rorder = rorder
self.lstride = lstride
self.rstride = rstride
self.conv_left = nn.Conv2d(
self.dim, self.dim, [lorder, 1], dilation=[lstride, 1], groups=self.dim, bias=False)
if self.rorder > 0:
self.conv_right = nn.Conv2d(
self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False)
else:
self.conv_right = None
def forward(self, input: torch.Tensor, cache: torch.Tensor):
x = torch.unsqueeze(input, 1)
x_per = x.permute(0, 3, 2, 1) # B D T C
cache = cache.to(x_per.device)
y_left = torch.cat((cache, x_per), dim=2)
cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
y_left = self.conv_left(y_left)
out = x_per + y_left
if self.conv_right is not None:
# maybe need to check
y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride])
y_right = y_right[:, :, self.rstride:, :]
y_right = self.conv_right(y_right)
out += y_right
out_per = out.permute(0, 3, 2, 1)
output = out_per.squeeze(1)
return output, cache
class BasicBlock_export(nn.Module):
def __init__(self,
model,
):
super(BasicBlock_export, self).__init__()
self.linear = model.linear
self.fsmn_block = model.fsmn_block
self.affine = model.affine
self.relu = model.relu
def forward(self, input: torch.Tensor, in_cache: torch.Tensor):
x = self.linear(input) # B T D
# cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
# if cache_layer_name not in in_cache:
# in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
x, out_cache = self.fsmn_block(x, in_cache)
x = self.affine(x)
x = self.relu(x)
return x, out_cache
# class FsmnStack(nn.Sequential):
# def __init__(self, *args):
# super(FsmnStack, self).__init__(*args)
#
# def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
# x = input
# for module in self._modules.values():
# x = module(x, in_cache)
# return x
'''
FSMN net for keyword spotting
input_dim: input dimension
linear_dim: fsmn input dimension
proj_dim: fsmn projection dimension
lorder: fsmn left order
rorder: fsmn right order
num_syn: output dimension
fsmn_layers: no. of sequential fsmn layers
'''
class FSMN(nn.Module):
def __init__(
self, model,
):
super(FSMN, self).__init__()
# self.input_dim = input_dim
# self.input_affine_dim = input_affine_dim
# self.fsmn_layers = fsmn_layers
# self.linear_dim = linear_dim
# self.proj_dim = proj_dim
# self.output_affine_dim = output_affine_dim
# self.output_dim = output_dim
#
# self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
# self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
# self.relu = RectifiedLinear(linear_dim, linear_dim)
# self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in
# range(fsmn_layers)])
# self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
# self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
# self.softmax = nn.Softmax(dim=-1)
self.in_linear1 = model.in_linear1
self.in_linear2 = model.in_linear2
self.relu = model.relu
# self.fsmn = model.fsmn
self.out_linear1 = model.out_linear1
self.out_linear2 = model.out_linear2
self.softmax = model.softmax
self.fsmn = model.fsmn
for i, d in enumerate(model.fsmn):
if isinstance(d, BasicBlock):
self.fsmn[i] = BasicBlock_export(d)
def fuse_modules(self):
pass
def forward(
self,
input: torch.Tensor,
*args,
):
"""
Args:
input (torch.Tensor): Input tensor (B, T, D)
in_cache: when in_cache is not None, the forward pass runs in streaming mode. in_cache is a dict, e.g.,
{'cache_layer_1': torch.Tensor(B, T1, D)}, where T1 equals self.lorder. It is {} for the first frame.
"""
x = self.in_linear1(input)
x = self.in_linear2(x)
x = self.relu(x)
# x4 = self.fsmn(x3, in_cache) # self.in_cache will update automatically in self.fsmn
out_caches = list()
for i, d in enumerate(self.fsmn):
in_cache = args[i]
x, out_cache = d(x, in_cache)
out_caches.append(out_cache)
x = self.out_linear1(x)
x = self.out_linear2(x)
x = self.softmax(x)
return x, out_caches
'''
one deep fsmn layer
dimproj: projection dimension, input and output dimension of memory blocks
dimlinear: dimension of mapping layer
lorder: left order
rorder: right order
lstride: left stride
rstride: right stride
'''
class DFSMN(nn.Module):
def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1):
super(DFSMN, self).__init__()
self.lorder = lorder
self.rorder = rorder
self.lstride = lstride
self.rstride = rstride
self.expand = AffineTransform(dimproj, dimlinear)
self.shrink = LinearTransform(dimlinear, dimproj)
self.conv_left = nn.Conv2d(
dimproj, dimproj, [lorder, 1], dilation=[lstride, 1], groups=dimproj, bias=False)
if rorder > 0:
self.conv_right = nn.Conv2d(
dimproj, dimproj, [rorder, 1], dilation=[rstride, 1], groups=dimproj, bias=False)
else:
self.conv_right = None
def forward(self, input):
f1 = F.relu(self.expand(input))
p1 = self.shrink(f1)
x = torch.unsqueeze(p1, 1)
x_per = x.permute(0, 3, 2, 1)
y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
if self.conv_right is not None:
y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
y_right = y_right[:, :, self.rstride:, :]
out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
else:
out = x_per + self.conv_left(y_left)
out1 = out.permute(0, 3, 2, 1)
output = input + out1.squeeze(1)
return output
'''
build stacked dfsmn layers
'''
def buildDFSMNRepeats(linear_dim=128, proj_dim=64, lorder=20, rorder=1, fsmn_layers=6):
repeats = [
nn.Sequential(
DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1))
for i in range(fsmn_layers)
]
return nn.Sequential(*repeats)
if __name__ == '__main__':
fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
print(fsmn)
num_params = sum(p.numel() for p in fsmn.parameters())
print('the number of model params: {}'.format(num_params))
x = torch.zeros(128, 200, 400) # batch-size * time * dim
y, _ = fsmn(x) # batch-size * time * dim
print('input shape: {}'.format(x.shape))
print('output shape: {}'.format(y.shape))
print(fsmn.to_kaldi_net())

View File

@ -9,6 +9,7 @@ from funasr.export.models.modules.encoder_layer import EncoderLayerSANM as Encod
from funasr.modules.positionwise_feed_forward import PositionwiseFeedForward
from funasr.export.models.modules.feedforward import PositionwiseFeedForward as PositionwiseFeedForward_export
class SANMEncoder(nn.Module):
def __init__(
self,
@ -107,3 +108,106 @@ class SANMEncoder(nn.Module):
}
}
class SANMVadEncoder(nn.Module):
def __init__(
self,
model,
max_seq_len=512,
feats_dim=560,
model_name='encoder',
onnx: bool = True,
):
super().__init__()
self.embed = model.embed
self.model = model
self.feats_dim = feats_dim
self._output_size = model._output_size
if onnx:
self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
else:
self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
if hasattr(model, 'encoders0'):
for i, d in enumerate(self.model.encoders0):
if isinstance(d.self_attn, MultiHeadedAttentionSANM):
d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
if isinstance(d.feed_forward, PositionwiseFeedForward):
d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
self.model.encoders0[i] = EncoderLayerSANM_export(d)
for i, d in enumerate(self.model.encoders):
if isinstance(d.self_attn, MultiHeadedAttentionSANM):
d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
if isinstance(d.feed_forward, PositionwiseFeedForward):
d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
self.model.encoders[i] = EncoderLayerSANM_export(d)
self.model_name = model_name
self.num_heads = model.encoders[0].self_attn.h
self.hidden_size = model.encoders[0].self_attn.linear_out.out_features
def prepare_mask(self, mask, sub_masks):
mask_3d_btd = mask[:, :, None]
mask_4d_bhlt = (1 - sub_masks) * -10000.0
return mask_3d_btd, mask_4d_bhlt
def forward(self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
vad_masks: torch.Tensor,
sub_masks: torch.Tensor,
):
speech = speech * self._output_size ** 0.5
mask = self.make_pad_mask(speech_lengths)
vad_masks = self.prepare_mask(mask, vad_masks)
mask = self.prepare_mask(mask, sub_masks)
if self.embed is None:
xs_pad = speech
else:
xs_pad = self.embed(speech)
encoder_outs = self.model.encoders0(xs_pad, mask)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
# encoder_outs = self.model.encoders(xs_pad, mask)
for layer_idx, encoder_layer in enumerate(self.model.encoders):
if layer_idx == len(self.model.encoders) - 1:
mask = vad_masks
encoder_outs = encoder_layer(xs_pad, mask)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
xs_pad = self.model.after_norm(xs_pad)
return xs_pad, speech_lengths
def get_output_size(self):
return self.model.encoders[0].size
# def get_dummy_inputs(self):
# feats = torch.randn(1, 100, self.feats_dim)
# return (feats)
#
# def get_input_names(self):
# return ['feats']
#
# def get_output_names(self):
# return ['encoder_out', 'encoder_out_lens', 'predictor_weight']
#
# def get_dynamic_axes(self):
# return {
# 'feats': {
# 1: 'feats_length'
# },
# 'encoder_out': {
# 1: 'enc_out_length'
# },
# 'predictor_weight': {
# 1: 'pre_out_length'
# }
#
# }

View File

@ -0,0 +1,154 @@
from typing import Tuple
import torch
import torch.nn as nn
from funasr.models.encoder.sanm_encoder import SANMEncoder
from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
from funasr.models.encoder.sanm_encoder import SANMVadEncoder
from funasr.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadEncoder_export
class CT_Transformer(nn.Module):
def __init__(
self,
model,
max_seq_len=512,
model_name='punc_model',
**kwargs,
):
super().__init__()
onnx = False
if "onnx" in kwargs:
onnx = kwargs["onnx"]
self.embed = model.embed
self.decoder = model.decoder
# self.model = model
self.feats_dim = self.embed.embedding_dim
self.num_embeddings = self.embed.num_embeddings
self.model_name = model_name
if isinstance(model.encoder, SANMEncoder):
self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
else:
assert False, "Only the SANM encoder is supported."
def forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
"""Compute loss value from buffer sequences.
Args:
input (torch.Tensor): Input ids. (batch, len)
hidden (torch.Tensor): Target ids. (batch, len)
"""
x = self.embed(inputs)
# mask = self._target_mask(input)
h, _ = self.encoder(x, text_lengths)
y = self.decoder(h)
return y
def get_dummy_inputs(self):
length = 120
text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length))
text_lengths = torch.tensor([length-20, length], dtype=torch.int32)
return (text_indexes, text_lengths)
def get_input_names(self):
return ['inputs', 'text_lengths']
def get_output_names(self):
return ['logits']
def get_dynamic_axes(self):
return {
'inputs': {
0: 'batch_size',
1: 'feats_length'
},
'text_lengths': {
0: 'batch_size',
},
'logits': {
0: 'batch_size',
1: 'logits_length'
},
}
class CT_Transformer_VadRealtime(nn.Module):
def __init__(
self,
model,
max_seq_len=512,
model_name='punc_model',
**kwargs,
):
super().__init__()
onnx = False
if "onnx" in kwargs:
onnx = kwargs["onnx"]
self.embed = model.embed
if isinstance(model.encoder, SANMVadEncoder):
self.encoder = SANMVadEncoder_export(model.encoder, onnx=onnx)
else:
assert False, "Only the SANM encoder is supported."
self.decoder = model.decoder
self.model_name = model_name
def forward(self, inputs: torch.Tensor,
text_lengths: torch.Tensor,
vad_indexes: torch.Tensor,
sub_masks: torch.Tensor,
) -> Tuple[torch.Tensor, None]:
"""Compute loss value from buffer sequences.
Args:
input (torch.Tensor): Input ids. (batch, len)
hidden (torch.Tensor): Target ids. (batch, len)
"""
x = self.embed(inputs)
# mask = self._target_mask(input)
h, _ = self.encoder(x, text_lengths, vad_indexes, sub_masks)
y = self.decoder(h)
return y
def with_vad(self):
return True
def get_dummy_inputs(self):
length = 120
text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length))
text_lengths = torch.tensor([length], dtype=torch.int32)
vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :]
sub_masks = torch.ones(length, length, dtype=torch.float32)
sub_masks = torch.tril(sub_masks).type(torch.float32)
return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :])
def get_input_names(self):
return ['inputs', 'text_lengths', 'vad_masks', 'sub_masks']
def get_output_names(self):
return ['logits']
def get_dynamic_axes(self):
return {
'inputs': {
1: 'feats_length'
},
'vad_masks': {
2: 'feats_length1',
3: 'feats_length2'
},
'sub_masks': {
2: 'feats_length1',
3: 'feats_length2'
},
'logits': {
1: 'logits_length'
},
}

View File

@ -0,0 +1,18 @@
import onnxruntime
import numpy as np
if __name__ == '__main__':
onnx_path = "../damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/model.onnx"
sess = onnxruntime.InferenceSession(onnx_path)
input_name = [nd.name for nd in sess.get_inputs()]
output_name = [nd.name for nd in sess.get_outputs()]
def _get_feed_dict(text_length):
return {'inputs': np.ones((1, text_length), dtype=np.int64), 'text_lengths': np.array([text_length,], dtype=np.int32)}
def _run(feed_dict):
output = sess.run(output_name, input_feed=feed_dict)
for name, value in zip(output_name, output):
print('{}: {}'.format(name, value))
_run(_get_feed_dict(10))

View File

@ -0,0 +1,22 @@
import onnxruntime
import numpy as np
if __name__ == '__main__':
onnx_path = "./export/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727/model.onnx"
sess = onnxruntime.InferenceSession(onnx_path)
input_name = [nd.name for nd in sess.get_inputs()]
output_name = [nd.name for nd in sess.get_outputs()]
def _get_feed_dict(text_length):
return {'inputs': np.ones((1, text_length), dtype=np.int64),
'text_lengths': np.array([text_length,], dtype=np.int32),
'vad_masks': np.ones((1, 1, text_length, text_length), dtype=np.float32),
'sub_masks': np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32)
}
def _run(feed_dict):
output = sess.run(output_name, input_feed=feed_dict)
for name, value in zip(output_name, output):
print('{}: {}'.format(name, value))
_run(_get_feed_dict(10))

View File

@ -0,0 +1,26 @@
import onnxruntime
import numpy as np
if __name__ == '__main__':
onnx_path = "/mnt/workspace/export/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/model.onnx"
sess = onnxruntime.InferenceSession(onnx_path)
input_name = [nd.name for nd in sess.get_inputs()]
output_name = [nd.name for nd in sess.get_outputs()]
def _get_feed_dict(feats_length):
return {'speech': np.random.rand(1, feats_length, 400).astype(np.float32),
'in_cache0': np.random.rand(1, 128, 19, 1).astype(np.float32),
'in_cache1': np.random.rand(1, 128, 19, 1).astype(np.float32),
'in_cache2': np.random.rand(1, 128, 19, 1).astype(np.float32),
'in_cache3': np.random.rand(1, 128, 19, 1).astype(np.float32),
}
def _run(feed_dict):
output = sess.run(output_name, input_feed=feed_dict)
for name, value in zip(output_name, output):
print('{}: {}'.format(name, value.shape))
_run(_get_feed_dict(100))
_run(_get_feed_dict(200))
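
The script above feeds fresh random caches on every call; for streaming use, a sketch of carrying the returned caches into the next call (relying on the input/output names listed above):

outputs = sess.run(output_name, input_feed=_get_feed_dict(100))
next_feed = _get_feed_dict(100)
for i in range(4):
    # outputs[0] is 'logits'; outputs[1:] are 'out_cache0'..'out_cache3'.
    next_feed['in_cache{}'.format(i)] = outputs[i + 1]
sess.run(output_name, input_feed=next_feed)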

View File

@ -5,7 +5,17 @@ from typing import Tuple
import torch
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
from typing import Dict
from typing import Optional
from typing import Tuple
import torch
import torch.nn.functional as F
from typeguard import check_argument_types
from funasr.modules.nets_utils import make_pad_mask
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
class AbsLM(torch.nn.Module, BatchScorerInterface, ABC):
"""The abstract LM class
@ -27,3 +37,122 @@ class AbsLM(torch.nn.Module, BatchScorerInterface, ABC):
self, input: torch.Tensor, hidden: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
class LanguageModel(AbsESPnetModel):
def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0):
assert check_argument_types()
super().__init__()
self.lm = lm
self.sos = 1
self.eos = 2
# ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
self.ignore_id = ignore_id
def nll(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
max_length: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute negative log likelihood(nll)
Normally, this function is called in batchify_nll.
Args:
text: (Batch, Length)
text_lengths: (Batch,)
max_length: int
"""
batch_size = text.size(0)
# For data parallel
if max_length is None:
text = text[:, : text_lengths.max()]
else:
text = text[:, :max_length]
# 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
# text: (Batch, Length) -> x, y: (Batch, Length + 1)
x = F.pad(text, [1, 0], "constant", self.sos)
t = F.pad(text, [0, 1], "constant", self.ignore_id)
for i, l in enumerate(text_lengths):
t[i, l] = self.eos
x_lengths = text_lengths + 1
# 2. Forward Language model
# x: (Batch, Length) -> y: (Batch, Length, NVocab)
y, _ = self.lm(x, None)
# 3. Calc negative log likelihood
# nll: (BxL,)
nll = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
# nll: (BxL,) -> (BxL,)
if max_length is None:
nll.masked_fill_(make_pad_mask(x_lengths).to(nll.device).view(-1), 0.0)
else:
nll.masked_fill_(
make_pad_mask(x_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
0.0,
)
# nll: (BxL,) -> (B, L)
nll = nll.view(batch_size, -1)
return nll, x_lengths
def batchify_nll(
self, text: torch.Tensor, text_lengths: torch.Tensor, batch_size: int = 100
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute negative log likelihood(nll) from transformer language model
To avoid OOM, this fuction seperate the input into batches.
Then call nll for each batch and combine and return results.
Args:
text: (Batch, Length)
text_lengths: (Batch,)
batch_size: int, samples each batch contain when computing nll,
you may change this to avoid OOM or increase
"""
total_num = text.size(0)
if total_num <= batch_size:
nll, x_lengths = self.nll(text, text_lengths)
else:
nlls = []
x_lengths = []
max_length = text_lengths.max()
start_idx = 0
while True:
end_idx = min(start_idx + batch_size, total_num)
batch_text = text[start_idx:end_idx, :]
batch_text_lengths = text_lengths[start_idx:end_idx]
# batch_nll: [B * T]
batch_nll, batch_x_lengths = self.nll(
batch_text, batch_text_lengths, max_length=max_length
)
nlls.append(batch_nll)
x_lengths.append(batch_x_lengths)
start_idx = end_idx
if start_idx == total_num:
break
nll = torch.cat(nlls)
x_lengths = torch.cat(x_lengths)
assert nll.size(0) == total_num
assert x_lengths.size(0) == total_num
return nll, x_lengths
def forward(
self, text: torch.Tensor, text_lengths: torch.Tensor
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
nll, y_lengths = self.nll(text, text_lengths)
ntokens = y_lengths.sum()
loss = nll.sum() / ntokens
stats = dict(loss=loss.detach())
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
return loss, stats, weight
def collect_feats(
self, text: torch.Tensor, text_lengths: torch.Tensor
) -> Dict[str, torch.Tensor]:
return {}
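
A small illustration of the <sos>/<eos> padding performed in nll(); the token ids below are made up:

import torch
import torch.nn.functional as F

text = torch.tensor([[4, 5, 6]])           # (Batch=1, Length=3)
sos, eos, ignore_id = 1, 2, 0
x = F.pad(text, [1, 0], "constant", sos)    # tensor([[1, 4, 5, 6]])
t = F.pad(text, [0, 1], "constant", ignore_id)
t[0, 3] = eos                               # tensor([[4, 5, 6, 2]])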

View File

@ -1,131 +0,0 @@
from typing import Dict
from typing import Optional
from typing import Tuple
import torch
import torch.nn.functional as F
from typeguard import check_argument_types
from funasr.modules.nets_utils import make_pad_mask
from funasr.lm.abs_model import AbsLM
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
class ESPnetLanguageModel(AbsESPnetModel):
def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0):
assert check_argument_types()
super().__init__()
self.lm = lm
self.sos = 1
self.eos = 2
# ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
self.ignore_id = ignore_id
def nll(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
max_length: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute negative log likelihood(nll)
Normally, this function is called in batchify_nll.
Args:
text: (Batch, Length)
text_lengths: (Batch,)
max_length: int
"""
batch_size = text.size(0)
# For data parallel
if max_length is None:
text = text[:, : text_lengths.max()]
else:
text = text[:, :max_length]
# 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
# text: (Batch, Length) -> x, y: (Batch, Length + 1)
x = F.pad(text, [1, 0], "constant", self.sos)
t = F.pad(text, [0, 1], "constant", self.ignore_id)
for i, l in enumerate(text_lengths):
t[i, l] = self.eos
x_lengths = text_lengths + 1
# 2. Forward Language model
# x: (Batch, Length) -> y: (Batch, Length, NVocab)
y, _ = self.lm(x, None)
# 3. Calc negative log likelihood
# nll: (BxL,)
nll = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
# nll: (BxL,) -> (BxL,)
if max_length is None:
nll.masked_fill_(make_pad_mask(x_lengths).to(nll.device).view(-1), 0.0)
else:
nll.masked_fill_(
make_pad_mask(x_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
0.0,
)
# nll: (BxL,) -> (B, L)
nll = nll.view(batch_size, -1)
return nll, x_lengths
def batchify_nll(
self, text: torch.Tensor, text_lengths: torch.Tensor, batch_size: int = 100
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute negative log likelihood(nll) from transformer language model
To avoid OOM, this fuction seperate the input into batches.
Then call nll for each batch and combine and return results.
Args:
text: (Batch, Length)
text_lengths: (Batch,)
batch_size: int, samples each batch contain when computing nll,
you may change this to avoid OOM or increase
"""
total_num = text.size(0)
if total_num <= batch_size:
nll, x_lengths = self.nll(text, text_lengths)
else:
nlls = []
x_lengths = []
max_length = text_lengths.max()
start_idx = 0
while True:
end_idx = min(start_idx + batch_size, total_num)
batch_text = text[start_idx:end_idx, :]
batch_text_lengths = text_lengths[start_idx:end_idx]
# batch_nll: [B * T]
batch_nll, batch_x_lengths = self.nll(
batch_text, batch_text_lengths, max_length=max_length
)
nlls.append(batch_nll)
x_lengths.append(batch_x_lengths)
start_idx = end_idx
if start_idx == total_num:
break
nll = torch.cat(nlls)
x_lengths = torch.cat(x_lengths)
assert nll.size(0) == total_num
assert x_lengths.size(0) == total_num
return nll, x_lengths
def forward(
self, text: torch.Tensor, text_lengths: torch.Tensor
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
nll, y_lengths = self.nll(text, text_lengths)
ntokens = y_lengths.sum()
loss = nll.sum() / ntokens
stats = dict(loss=loss.detach())
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
return loss, stats, weight
def collect_feats(
self, text: torch.Tensor, text_lengths: torch.Tensor
) -> Dict[str, torch.Tensor]:
return {}

View File

@ -192,7 +192,7 @@ class WindowDetector(object):
class E2EVadModel(nn.Module):
def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any]):
def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any], frontend=None):
super(E2EVadModel, self).__init__()
self.vad_opts = VADXOptions(**vad_post_args)
self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
@ -229,6 +229,7 @@ class E2EVadModel(nn.Module):
self.data_buf_all = None
self.waveform = None
self.ResetDetection()
self.frontend = frontend
def AllResetDetection(self):
self.is_final = False
@ -477,8 +478,9 @@ class E2EVadModel(nn.Module):
) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
self.waveform = waveform # compute decibel for each frame
self.ComputeDecibel()
self.ComputeScores(feats, in_cache)
self.ComputeDecibel()
if not is_final:
self.DetectCommonFrames()
else:

View File

@ -10,7 +10,7 @@ from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
from typeguard import check_argument_types
import numpy as np
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask
from funasr.modules.embedding import SinusoidalPositionEncoder
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.multi_layer_conv import Conv1dLinear
@ -27,7 +27,7 @@ from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.mask import subsequent_mask, vad_mask
class EncoderLayerSANM(nn.Module):
def __init__(
@ -958,3 +958,231 @@ class SANMEncoderChunkOpt(AbsEncoder):
var_dict_tf[name_tf].shape))
return var_dict_torch_update
class SANMVadEncoder(AbsEncoder):
"""
author: Speech Lab, Alibaba Group, China
"""
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: Optional[str] = "conv2d",
pos_enc_class=SinusoidalPositionEncoder,
normalize_before: bool = True,
concat_after: bool = False,
positionwise_layer_type: str = "linear",
positionwise_conv_kernel_size: int = 1,
padding_idx: int = -1,
interctc_layer_idx: List[int] = [],
interctc_use_conditioning: bool = False,
kernel_size : int = 11,
sanm_shfit : int = 0,
selfattention_layer_type: str = "sanm",
):
assert check_argument_types()
super().__init__()
self._output_size = output_size
if input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(input_size, output_size),
torch.nn.LayerNorm(output_size),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == "conv2d":
self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
elif input_layer == "conv2d2":
self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
elif input_layer == "conv2d6":
self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
elif input_layer == "conv2d8":
self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
elif input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
SinusoidalPositionEncoder(),
)
elif input_layer is None:
if input_size == output_size:
self.embed = None
else:
self.embed = torch.nn.Linear(input_size, output_size)
elif input_layer == "pe":
self.embed = SinusoidalPositionEncoder()
else:
raise ValueError("unknown input_layer: " + input_layer)
self.normalize_before = normalize_before
if positionwise_layer_type == "linear":
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
output_size,
linear_units,
dropout_rate,
)
elif positionwise_layer_type == "conv1d":
positionwise_layer = MultiLayeredConv1d
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
elif positionwise_layer_type == "conv1d-linear":
positionwise_layer = Conv1dLinear
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
else:
raise NotImplementedError("Support only linear or conv1d.")
if selfattention_layer_type == "selfattn":
encoder_selfattn_layer = MultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
)
elif selfattention_layer_type == "sanm":
self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
encoder_selfattn_layer_args0 = (
attention_heads,
input_size,
output_size,
attention_dropout_rate,
kernel_size,
sanm_shfit,
)
encoder_selfattn_layer_args = (
attention_heads,
output_size,
output_size,
attention_dropout_rate,
kernel_size,
sanm_shfit,
)
self.encoders0 = repeat(
1,
lambda lnum: EncoderLayerSANM(
input_size,
output_size,
self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
),
)
self.encoders = repeat(
num_blocks-1,
lambda lnum: EncoderLayerSANM(
output_size,
output_size,
self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
),
)
if self.normalize_before:
self.after_norm = LayerNorm(output_size)
self.interctc_layer_idx = interctc_layer_idx
if len(interctc_layer_idx) > 0:
assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
self.interctc_use_conditioning = interctc_use_conditioning
self.conditioning_layer = None
self.dropout = nn.Dropout(dropout_rate)
def output_size(self) -> int:
return self._output_size
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
vad_indexes: torch.Tensor,
prev_states: torch.Tensor = None,
ctc: CTC = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Embed positions in tensor.
Args:
xs_pad: input tensor (B, L, D)
ilens: input length (B)
prev_states: Not to be used now.
Returns:
position embedded tensor and mask
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
no_future_masks = masks & sub_masks
xs_pad *= self.output_size()**0.5
if self.embed is None:
xs_pad = xs_pad
elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
raise TooShortUttError(
f"has {xs_pad.size(1)} frames and is too short for subsampling " +
f"(it needs more than {limit_size} frames), return empty results",
xs_pad.size(1),
limit_size,
)
xs_pad, masks = self.embed(xs_pad, masks)
else:
xs_pad = self.embed(xs_pad)
# xs_pad = self.dropout(xs_pad)
mask_tup0 = [masks, no_future_masks]
encoder_outs = self.encoders0(xs_pad, mask_tup0)
xs_pad, _ = encoder_outs[0], encoder_outs[1]
intermediate_outs = []
for layer_idx, encoder_layer in enumerate(self.encoders):
if layer_idx + 1 == len(self.encoders):
# This is last layer.
coner_mask = torch.ones(masks.size(0),
masks.size(-1),
masks.size(-1),
device=xs_pad.device,
dtype=torch.bool)
for word_index, length in enumerate(ilens):
coner_mask[word_index, :, :] = vad_mask(masks.size(-1),
vad_indexes[word_index],
device=xs_pad.device)
layer_mask = masks & coner_mask
else:
layer_mask = no_future_masks
mask_tup1 = [masks, layer_mask]
encoder_outs = encoder_layer(xs_pad, mask_tup1)
xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
olens = masks.squeeze(1).sum(1)
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None

View File

@ -5,12 +5,11 @@ from typing import Tuple
import torch
import torch.nn as nn
from funasr.modules.embedding import PositionalEncoding
from funasr.modules.embedding import SinusoidalPositionEncoder
#from funasr.models.encoder.transformer_encoder import TransformerEncoder as Encoder
from funasr.punctuation.sanm_encoder import SANMEncoder as Encoder
from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder
#from funasr.modules.mask import subsequent_n_mask
from funasr.punctuation.abs_model import AbsPunctuation
from funasr.train.abs_model import AbsPunctuation
class TargetDelayTransformer(AbsPunctuation):

View File

@ -6,8 +6,8 @@ import torch
import torch.nn as nn
from funasr.modules.embedding import SinusoidalPositionEncoder
from funasr.punctuation.sanm_encoder import SANMVadEncoder as Encoder
from funasr.punctuation.abs_model import AbsPunctuation
from funasr.models.encoder.sanm_encoder import SANMVadEncoder as Encoder
from funasr.train.abs_model import AbsPunctuation
class VadRealtimeTransformer(AbsPunctuation):

View File

@ -1,31 +0,0 @@
from abc import ABC
from abc import abstractmethod
from typing import Tuple
import torch
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC):
"""The abstract class
To share the loss calculation among different models,
we use the delegate pattern here:
The instance of this class should be passed to "LanguageModel"
>>> from funasr.punctuation.abs_model import AbsPunctuation
>>> punc = AbsPunctuation()
>>> model = ESPnetPunctuationModel(punc=punc)
This "model" is one of mediator objects for "Task" class.
"""
@abstractmethod
def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
@abstractmethod
def with_vad(self) -> bool:
raise NotImplementedError

View File

@ -1,590 +0,0 @@
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import logging
import torch
import torch.nn as nn
from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
from typeguard import check_argument_types
import numpy as np
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask
from funasr.modules.embedding import SinusoidalPositionEncoder
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.multi_layer_conv import Conv1dLinear
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
from funasr.modules.positionwise_feed_forward import (
PositionwiseFeedForward, # noqa: H301
)
from funasr.modules.repeat import repeat
from funasr.modules.subsampling import Conv2dSubsampling
from funasr.modules.subsampling import Conv2dSubsampling2
from funasr.modules.subsampling import Conv2dSubsampling6
from funasr.modules.subsampling import Conv2dSubsampling8
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.mask import subsequent_mask, vad_mask
class EncoderLayerSANM(nn.Module):
def __init__(
self,
in_size,
size,
self_attn,
feed_forward,
dropout_rate,
normalize_before=True,
concat_after=False,
stochastic_depth_rate=0.0,
):
"""Construct an EncoderLayer object."""
super(EncoderLayerSANM, self).__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(in_size)
self.norm2 = LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.in_size = in_size
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear = nn.Linear(size + size, size)
self.stochastic_depth_rate = stochastic_depth_rate
self.dropout_rate = dropout_rate
def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
"""Compute encoded features.
Args:
x_input (torch.Tensor): Input tensor (#batch, time, size).
mask (torch.Tensor): Mask tensor for the input (#batch, time).
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
Returns:
torch.Tensor: Output tensor (#batch, time, size).
torch.Tensor: Mask tensor (#batch, time).
"""
skip_layer = False
# with stochastic depth, residual connection `x + f(x)` becomes
# `x <- x + 1 / (1 - p) * f(x)` at training time.
stoch_layer_coeff = 1.0
if self.training and self.stochastic_depth_rate > 0:
skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
if skip_layer:
if cache is not None:
x = torch.cat([cache, x], dim=1)
return x, mask
residual = x
if self.normalize_before:
x = self.norm1(x)
if self.concat_after:
x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
if self.in_size == self.size:
x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
else:
x = stoch_layer_coeff * self.concat_linear(x_concat)
else:
if self.in_size == self.size:
x = residual + stoch_layer_coeff * self.dropout(
self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
)
else:
x = stoch_layer_coeff * self.dropout(
self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
)
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm2(x)
return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
class SANMEncoder(AbsEncoder):
"""
author: Speech Lab, Alibaba Group, China
"""
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: Optional[str] = "conv2d",
pos_enc_class=SinusoidalPositionEncoder,
normalize_before: bool = True,
concat_after: bool = False,
positionwise_layer_type: str = "linear",
positionwise_conv_kernel_size: int = 1,
padding_idx: int = -1,
interctc_layer_idx: List[int] = [],
interctc_use_conditioning: bool = False,
kernel_size : int = 11,
sanm_shfit : int = 0,
selfattention_layer_type: str = "sanm",
):
assert check_argument_types()
super().__init__()
self._output_size = output_size
if input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(input_size, output_size),
torch.nn.LayerNorm(output_size),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == "conv2d":
self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
elif input_layer == "conv2d2":
self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
elif input_layer == "conv2d6":
self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
elif input_layer == "conv2d8":
self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
elif input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
SinusoidalPositionEncoder(),
)
elif input_layer is None:
if input_size == output_size:
self.embed = None
else:
self.embed = torch.nn.Linear(input_size, output_size)
elif input_layer == "pe":
self.embed = SinusoidalPositionEncoder()
else:
raise ValueError("unknown input_layer: " + input_layer)
self.normalize_before = normalize_before
if positionwise_layer_type == "linear":
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
output_size,
linear_units,
dropout_rate,
)
elif positionwise_layer_type == "conv1d":
positionwise_layer = MultiLayeredConv1d
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
elif positionwise_layer_type == "conv1d-linear":
positionwise_layer = Conv1dLinear
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
else:
raise NotImplementedError("Support only linear or conv1d.")
if selfattention_layer_type == "selfattn":
encoder_selfattn_layer = MultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
)
elif selfattention_layer_type == "sanm":
self.encoder_selfattn_layer = MultiHeadedAttentionSANM
encoder_selfattn_layer_args0 = (
attention_heads,
input_size,
output_size,
attention_dropout_rate,
kernel_size,
sanm_shfit,
)
encoder_selfattn_layer_args = (
attention_heads,
output_size,
output_size,
attention_dropout_rate,
kernel_size,
sanm_shfit,
)
self.encoders0 = repeat(
1,
lambda lnum: EncoderLayerSANM(
input_size,
output_size,
self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
),
)
self.encoders = repeat(
num_blocks-1,
lambda lnum: EncoderLayerSANM(
output_size,
output_size,
self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
),
)
if self.normalize_before:
self.after_norm = LayerNorm(output_size)
self.interctc_layer_idx = interctc_layer_idx
if len(interctc_layer_idx) > 0:
assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
self.interctc_use_conditioning = interctc_use_conditioning
self.conditioning_layer = None
self.dropout = nn.Dropout(dropout_rate)
def output_size(self) -> int:
return self._output_size
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
ctc: CTC = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Embed positions in tensor.
Args:
xs_pad: input tensor (B, L, D)
ilens: input length (B)
prev_states: Not to be used now.
Returns:
position embedded tensor and mask
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
xs_pad *= self.output_size()**0.5
if self.embed is None:
xs_pad = xs_pad
elif (
isinstance(self.embed, Conv2dSubsampling)
or isinstance(self.embed, Conv2dSubsampling2)
or isinstance(self.embed, Conv2dSubsampling6)
or isinstance(self.embed, Conv2dSubsampling8)
):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
raise TooShortUttError(
f"has {xs_pad.size(1)} frames and is too short for subsampling "
+ f"(it needs more than {limit_size} frames), return empty results",
xs_pad.size(1),
limit_size,
)
xs_pad, masks = self.embed(xs_pad, masks)
else:
xs_pad = self.embed(xs_pad)
# xs_pad = self.dropout(xs_pad)
encoder_outs = self.encoders0(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
intermediate_outs = []
if len(self.interctc_layer_idx) == 0:
encoder_outs = self.encoders(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
else:
for layer_idx, encoder_layer in enumerate(self.encoders):
encoder_outs = encoder_layer(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
if layer_idx + 1 in self.interctc_layer_idx:
encoder_out = xs_pad
# intermediate outputs are also normalized
if self.normalize_before:
encoder_out = self.after_norm(encoder_out)
intermediate_outs.append((layer_idx + 1, encoder_out))
if self.interctc_use_conditioning:
ctc_out = ctc.softmax(encoder_out)
xs_pad = xs_pad + self.conditioning_layer(ctc_out)
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
olens = masks.squeeze(1).sum(1)
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None
class SANMVadEncoder(AbsEncoder):
"""
author: Speech Lab, Alibaba Group, China
"""
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: Optional[str] = "conv2d",
pos_enc_class=SinusoidalPositionEncoder,
normalize_before: bool = True,
concat_after: bool = False,
positionwise_layer_type: str = "linear",
positionwise_conv_kernel_size: int = 1,
padding_idx: int = -1,
interctc_layer_idx: List[int] = [],
interctc_use_conditioning: bool = False,
kernel_size : int = 11,
sanm_shfit : int = 0,
selfattention_layer_type: str = "sanm",
):
assert check_argument_types()
super().__init__()
self._output_size = output_size
if input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(input_size, output_size),
torch.nn.LayerNorm(output_size),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == "conv2d":
self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
elif input_layer == "conv2d2":
self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
elif input_layer == "conv2d6":
self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
elif input_layer == "conv2d8":
self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
elif input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
SinusoidalPositionEncoder(),
)
elif input_layer is None:
if input_size == output_size:
self.embed = None
else:
self.embed = torch.nn.Linear(input_size, output_size)
elif input_layer == "pe":
self.embed = SinusoidalPositionEncoder()
else:
raise ValueError("unknown input_layer: " + input_layer)
self.normalize_before = normalize_before
if positionwise_layer_type == "linear":
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
output_size,
linear_units,
dropout_rate,
)
elif positionwise_layer_type == "conv1d":
positionwise_layer = MultiLayeredConv1d
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
elif positionwise_layer_type == "conv1d-linear":
positionwise_layer = Conv1dLinear
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
else:
raise NotImplementedError("Support only linear or conv1d.")
if selfattention_layer_type == "selfattn":
encoder_selfattn_layer = MultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
)
elif selfattention_layer_type == "sanm":
self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
encoder_selfattn_layer_args0 = (
attention_heads,
input_size,
output_size,
attention_dropout_rate,
kernel_size,
sanm_shfit,
)
encoder_selfattn_layer_args = (
attention_heads,
output_size,
output_size,
attention_dropout_rate,
kernel_size,
sanm_shfit,
)
self.encoders0 = repeat(
1,
lambda lnum: EncoderLayerSANM(
input_size,
output_size,
self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
),
)
self.encoders = repeat(
num_blocks-1,
lambda lnum: EncoderLayerSANM(
output_size,
output_size,
self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
),
)
if self.normalize_before:
self.after_norm = LayerNorm(output_size)
self.interctc_layer_idx = interctc_layer_idx
if len(interctc_layer_idx) > 0:
assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
self.interctc_use_conditioning = interctc_use_conditioning
self.conditioning_layer = None
self.dropout = nn.Dropout(dropout_rate)
def output_size(self) -> int:
return self._output_size
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
vad_indexes: torch.Tensor,
prev_states: torch.Tensor = None,
ctc: CTC = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Embed positions in tensor.
Args:
xs_pad: input tensor (B, L, D)
ilens: input length (B)
prev_states: Not to be used now.
Returns:
position embedded tensor and mask
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
no_future_masks = masks & sub_masks
xs_pad *= self.output_size()**0.5
if self.embed is None:
xs_pad = xs_pad
elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
raise TooShortUttError(
f"has {xs_pad.size(1)} frames and is too short for subsampling " +
f"(it needs more than {limit_size} frames), return empty results",
xs_pad.size(1),
limit_size,
)
xs_pad, masks = self.embed(xs_pad, masks)
else:
xs_pad = self.embed(xs_pad)
# xs_pad = self.dropout(xs_pad)
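# Two masks are handed to the first SANM block: the padding mask and a causal
# (no-future) attention mask, so the first layer never attends to later frames.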
mask_tup0 = [masks, no_future_masks]
encoder_outs = self.encoders0(xs_pad, mask_tup0)
xs_pad, _ = encoder_outs[0], encoder_outs[1]
intermediate_outs = []
#if len(self.interctc_layer_idx) == 0:
if False:
# Disabled branch: the repeat module would apply the same mask to every layer,
# but the last layer needs a per-utterance VAD mask (handled in the else branch below).
encoder_outs = self.encoders(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
else:
for layer_idx, encoder_layer in enumerate(self.encoders):
if layer_idx + 1 == len(self.encoders):
# This is last layer.
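# The final layer swaps the causal mask for a per-utterance VAD mask built from
# vad_indexes: frames before the VAD position cannot attend to anything at or after it.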
corner_mask = torch.ones(masks.size(0),
masks.size(-1),
masks.size(-1),
device=xs_pad.device,
dtype=torch.bool)
for word_index, length in enumerate(ilens):
corner_mask[word_index, :, :] = vad_mask(masks.size(-1),
vad_indexes[word_index],
device=xs_pad.device)
layer_mask = masks & corner_mask
else:
layer_mask = no_future_masks
mask_tup1 = [masks, layer_mask]
encoder_outs = encoder_layer(xs_pad, mask_tup1)
xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
if layer_idx + 1 in self.interctc_layer_idx:
encoder_out = xs_pad
# intermediate outputs are also normalized
if self.normalize_before:
encoder_out = self.after_norm(encoder_out)
intermediate_outs.append((layer_idx + 1, encoder_out))
if self.interctc_use_conditioning:
ctc_out = ctc.softmax(encoder_out)
xs_pad = xs_pad + self.conditioning_layer(ctc_out)
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
olens = masks.squeeze(1).sum(1)
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None

View File

@ -1,12 +0,0 @@
def split_to_mini_sentence(words: list, word_limit: int = 20):
assert word_limit > 1
if len(words) <= word_limit:
return [words]
sentences = []
length = len(words)
sentence_len = length // word_limit
for i in range(sentence_len):
sentences.append(words[i * word_limit:(i + 1) * word_limit])
if length % word_limit > 0:
sentences.append(words[sentence_len * word_limit:])
return sentences

View File

@ -134,7 +134,7 @@ def read_yaml(yaml_path: Union[str, Path]) -> Dict:
@functools.lru_cache()
def get_logger(name='torch_paraformer'):
def get_logger(name='funasr_torch'):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will

View File

@ -1,5 +1,6 @@
from funasr_onnx import Paraformer
model_dir = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
model = Paraformer(model_dir, batch_size=2, plot_timestamp_to="./", pred_bias=0) # cpu
@ -8,7 +9,6 @@ model = Paraformer(model_dir, batch_size=2, plot_timestamp_to="./", pred_bias=0)
# when using paraformer-large-vad-punc model, you can set plot_timestamp_to="./xx.png" to get figure of alignment besides timestamps
# model = Paraformer(model_dir, batch_size=1, plot_timestamp_to="test.png")
wav_path = "YourPath/xx.wav"
result = model(wav_path)

View File

@ -0,0 +1,8 @@
from funasr_onnx import CT_Transformer
model_dir = "../../../export/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
model = CT_Transformer(model_dir)
text_in="跨境河流是养育沿岸人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流问题上的关切愿意进一步完善双方联合工作机制凡是中方能做的我们都会去做而且会做得更好我请印度朋友们放心中国在上游的任何开发利用都会经过科学规划和论证兼顾上下游的利益"
result = model(text_in)
print(result[0])

View File

@ -0,0 +1,15 @@
from funasr_onnx import CT_Transformer_VadRealtime
model_dir = "../../../export/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727"
model = CT_Transformer_VadRealtime(model_dir)
text_in = "跨境河流是养育沿岸|人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员|在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险|向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流>问题上的关切|愿意进一步完善双方联合工作机制|凡是|中方能做的我们|都会去做而且会做得更好我请印度朋友们放心中国在上游的|任何开发利用都会经过科学|规划和论证兼顾上下游的利益"
vads = text_in.split("|")
rec_result_all=""
param_dict = {"cache": []}
for vad in vads:
result = model(vad, param_dict=param_dict)
rec_result_all += result[0]
print(rec_result_all)

View File

@ -0,0 +1,30 @@
import soundfile
from funasr_onnx import Fsmn_vad
model_dir = "/Users/zhifu/Downloads/speech_fsmn_vad_zh-cn-16k-common-pytorch"
wav_path = "/Users/zhifu/Downloads/speech_fsmn_vad_zh-cn-16k-common-pytorch/example/vad_example.wav"
model = Fsmn_vad(model_dir)
#offline vad
# result = model(wav_path)
# print(result)
#online vad
speech, sample_rate = soundfile.read(wav_path)
speech_length = speech.shape[0]
sample_offset = 0
step = 160 * 10
param_dict = {'in_cache': []}
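# Stream the waveform in chunks of `step` samples (160 * 10 = 1600 samples, i.e. 100 ms at 16 kHz);
# the final chunk is flagged with is_final=True so the VAD can flush its state.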
for sample_offset in range(0, speech_length, min(step, speech_length - sample_offset)):
if sample_offset + step >= speech_length - 1:
step = speech_length - sample_offset
is_final = True
else:
is_final = False
param_dict['is_final'] = is_final
segments_result = model(audio_in=speech[sample_offset: sample_offset + step],
param_dict=param_dict)
print(segments_result)

View File

@ -1,2 +1,5 @@
# -*- encoding: utf-8 -*-
from .paraformer_bin import Paraformer
from .vad_bin import Fsmn_vad
from .punc_bin import CT_Transformer
from .punc_bin import CT_Transformer_VadRealtime

View File

@ -0,0 +1,249 @@
# -*- encoding: utf-8 -*-
import os.path
from pathlib import Path
from typing import List, Union, Tuple
import numpy as np
from .utils.utils import (ONNXRuntimeError,
OrtInferSession, get_logger,
read_yaml)
from .utils.utils import (TokenIDConverter, split_to_mini_sentence,code_mix_split_words)
logging = get_logger()
class CT_Transformer():
def __init__(self, model_dir: Union[str, Path] = None,
batch_size: int = 1,
device_id: Union[str, int] = "-1",
quantize: bool = False,
intra_op_num_threads: int = 4
):
if not Path(model_dir).exists():
raise FileNotFoundError(f'{model_dir} does not exist.')
model_file = os.path.join(model_dir, 'model.onnx')
if quantize:
model_file = os.path.join(model_dir, 'model_quant.onnx')
config_file = os.path.join(model_dir, 'punc.yaml')
config = read_yaml(config_file)
self.converter = TokenIDConverter(config['token_list'])
self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
self.batch_size = 1
self.punc_list = config['punc_list']
self.period = 0
for i in range(len(self.punc_list)):
if self.punc_list[i] == ",":
self.punc_list[i] = "，"
elif self.punc_list[i] == "?":
self.punc_list[i] = "？"
elif self.punc_list[i] == "。":
self.period = i
def __call__(self, text: Union[list, str], split_size=20):
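# Split the input into mini-sentences of at most `split_size` tokens, punctuate each one,
# and carry the words after the last sentence-final punctuation over to the next
# mini-sentence as cache.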
split_text = code_mix_split_words(text)
split_text_id = self.converter.tokens2ids(split_text)
mini_sentences = split_to_mini_sentence(split_text, split_size)
mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
assert len(mini_sentences) == len(mini_sentences_id)
cache_sent = []
cache_sent_id = []
new_mini_sentence = ""
new_mini_sentence_punc = []
cache_pop_trigger_limit = 200
for mini_sentence_i in range(len(mini_sentences)):
mini_sentence = mini_sentences[mini_sentence_i]
mini_sentence_id = mini_sentences_id[mini_sentence_i]
mini_sentence = cache_sent + mini_sentence
mini_sentence_id = np.array(cache_sent_id + mini_sentence_id, dtype='int64')
data = {
"text": mini_sentence_id[None,:],
"text_lengths": np.array([len(mini_sentence_id)], dtype='int32'),
}
try:
outputs = self.infer(data['text'], data['text_lengths'])
y = outputs[0]
punctuations = np.argmax(y,axis=-1)[0]
assert punctuations.size == len(mini_sentence)
except ONNXRuntimeError:
logging.warning("error")
# Search for the last Period/QuestionMark as cache
if mini_sentence_i < len(mini_sentences) - 1:
sentenceEnd = -1
last_comma_index = -1
for i in range(len(punctuations) - 2, 1, -1):
if self.punc_list[punctuations[i]] == "" or self.punc_list[punctuations[i]] == "":
sentenceEnd = i
break
if last_comma_index < 0 and self.punc_list[punctuations[i]] == "，":
last_comma_index = i
if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
# The sentence is too long; cut it off at a comma.
sentenceEnd = last_comma_index
punctuations[sentenceEnd] = self.period
cache_sent = mini_sentence[sentenceEnd + 1:]
cache_sent_id = mini_sentence_id[sentenceEnd + 1:].tolist()
mini_sentence = mini_sentence[0:sentenceEnd + 1]
punctuations = punctuations[0:sentenceEnd + 1]
new_mini_sentence_punc += [int(x) for x in punctuations]
words_with_punc = []
for i in range(len(mini_sentence)):
if i > 0:
if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
mini_sentence[i] = " " + mini_sentence[i]
words_with_punc.append(mini_sentence[i])
if self.punc_list[punctuations[i]] != "_":
words_with_punc.append(self.punc_list[punctuations[i]])
new_mini_sentence += "".join(words_with_punc)
# Add Period for the end of the sentence
new_mini_sentence_out = new_mini_sentence
new_mini_sentence_punc_out = new_mini_sentence_punc
if mini_sentence_i == len(mini_sentences) - 1:
if new_mini_sentence[-1] == "" or new_mini_sentence[-1] == "":
new_mini_sentence_out = new_mini_sentence[:-1] + ""
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
elif new_mini_sentence[-1] != "" and new_mini_sentence[-1] != "":
new_mini_sentence_out = new_mini_sentence + ""
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
return new_mini_sentence_out, new_mini_sentence_punc_out
def infer(self, feats: np.ndarray,
feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
outputs = self.ort_infer([feats, feats_len])
return outputs
class CT_Transformer_VadRealtime(CT_Transformer):
def __init__(self, model_dir: Union[str, Path] = None,
batch_size: int = 1,
device_id: Union[str, int] = "-1",
quantize: bool = False,
intra_op_num_threads: int = 4
):
super(CT_Transformer_VadRealtime, self).__init__(model_dir, batch_size, device_id, quantize, intra_op_num_threads)
def __call__(self, text: str, param_dict: map, split_size=20):
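# param_dict["cache"] holds the words carried over from the previous call: they are
# prepended to the new text for context, skipped when assembling the output, and the
# updated cache is written back to param_dict before returning.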
cache_key = "cache"
assert cache_key in param_dict
cache = param_dict[cache_key]
if cache is not None and len(cache) > 0:
precache = "".join(cache)
else:
precache = ""
cache = []
full_text = precache + text
split_text = code_mix_split_words(full_text)
split_text_id = self.converter.tokens2ids(split_text)
mini_sentences = split_to_mini_sentence(split_text, split_size)
mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
new_mini_sentence_punc = []
assert len(mini_sentences) == len(mini_sentences_id)
cache_sent = []
cache_sent_id = np.array([], dtype='int32')
sentence_punc_list = []
sentence_words_list = []
cache_pop_trigger_limit = 200
skip_num = 0
for mini_sentence_i in range(len(mini_sentences)):
mini_sentence = mini_sentences[mini_sentence_i]
mini_sentence_id = mini_sentences_id[mini_sentence_i]
mini_sentence = cache_sent + mini_sentence
mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
text_length = len(mini_sentence_id)
data = {
"input": mini_sentence_id[None,:],
"text_lengths": np.array([text_length], dtype='int32'),
"vad_mask": self.vad_mask(text_length, len(cache) - 1)[None, None, :, :].astype(np.float32),
"sub_masks": np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32)
}
try:
outputs = self.infer(data['input'], data['text_lengths'], data['vad_mask'], data["sub_masks"])
y = outputs[0]
punctuations = np.argmax(y,axis=-1)[0]
assert punctuations.size == len(mini_sentence)
except ONNXRuntimeError:
logging.warning("error")
# Search for the last Period/QuestionMark as cache
if mini_sentence_i < len(mini_sentences) - 1:
sentenceEnd = -1
last_comma_index = -1
for i in range(len(punctuations) - 2, 1, -1):
if self.punc_list[punctuations[i]] == "" or self.punc_list[punctuations[i]] == "":
sentenceEnd = i
break
if last_comma_index < 0 and self.punc_list[punctuations[i]] == "，":
last_comma_index = i
if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
# The sentence is too long; cut it off at a comma.
sentenceEnd = last_comma_index
punctuations[sentenceEnd] = self.period
cache_sent = mini_sentence[sentenceEnd + 1:]
cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
mini_sentence = mini_sentence[0:sentenceEnd + 1]
punctuations = punctuations[0:sentenceEnd + 1]
punctuations_np = [int(x) for x in punctuations]
new_mini_sentence_punc += punctuations_np
sentence_punc_list += [self.punc_list[int(x)] for x in punctuations_np]
sentence_words_list += mini_sentence
assert len(sentence_punc_list) == len(sentence_words_list)
words_with_punc = []
sentence_punc_list_out = []
for i in range(0, len(sentence_words_list)):
if i > 0:
if len(sentence_words_list[i][0].encode()) == 1 and len(sentence_words_list[i - 1][-1].encode()) == 1:
sentence_words_list[i] = " " + sentence_words_list[i]
if skip_num < len(cache):
skip_num += 1
else:
words_with_punc.append(sentence_words_list[i])
if skip_num >= len(cache):
sentence_punc_list_out.append(sentence_punc_list[i])
if sentence_punc_list[i] != "_":
words_with_punc.append(sentence_punc_list[i])
sentence_out = "".join(words_with_punc)
sentenceEnd = -1
for i in range(len(sentence_punc_list) - 2, 1, -1):
if sentence_punc_list[i] == "" or sentence_punc_list[i] == "":
sentenceEnd = i
break
cache_out = sentence_words_list[sentenceEnd + 1:]
if sentence_out[-1] in self.punc_list:
sentence_out = sentence_out[:-1]
sentence_punc_list_out[-1] = "_"
param_dict[cache_key] = cache_out
return sentence_out, sentence_punc_list_out, cache_out
def vad_mask(self, size, vad_pos, dtype=bool):
"""Create mask for decoder self-attention.
:param int size: size of mask
:param int vad_pos: index of the VAD position
:param np.dtype dtype: result dtype
:rtype: np.ndarray (size, size)
"""
ret = np.ones((size, size), dtype=dtype)
if vad_pos <= 0 or vad_pos >= size:
return ret
sub_corner = np.zeros(
(vad_pos - 1, size - vad_pos), dtype=dtype)
ret[0:vad_pos - 1, vad_pos:] = sub_corner
return ret
def infer(self, feats: np.ndarray,
feats_len: np.ndarray,
vad_mask: np.ndarray,
sub_masks: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
outputs = self.ort_infer([feats, feats_len, vad_mask, sub_masks])
return outputs

View File

@ -0,0 +1,607 @@
from enum import Enum
from typing import List, Tuple, Dict, Any
import math
import numpy as np
class VadStateMachine(Enum):
kVadInStateStartPointNotDetected = 1
kVadInStateInSpeechSegment = 2
kVadInStateEndPointDetected = 3
class FrameState(Enum):
kFrameStateInvalid = -1
kFrameStateSpeech = 1
kFrameStateSil = 0
# final voice/unvoice state per frame
class AudioChangeState(Enum):
kChangeStateSpeech2Speech = 0
kChangeStateSpeech2Sil = 1
kChangeStateSil2Sil = 2
kChangeStateSil2Speech = 3
kChangeStateNoBegin = 4
kChangeStateInvalid = 5
class VadDetectMode(Enum):
kVadSingleUtteranceDetectMode = 0
kVadMutipleUtteranceDetectMode = 1
class VADXOptions:
def __init__(
self,
sample_rate: int = 16000,
detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
snr_mode: int = 0,
max_end_silence_time: int = 800,
max_start_silence_time: int = 3000,
do_start_point_detection: bool = True,
do_end_point_detection: bool = True,
window_size_ms: int = 200,
sil_to_speech_time_thres: int = 150,
speech_to_sil_time_thres: int = 150,
speech_2_noise_ratio: float = 1.0,
do_extend: int = 1,
lookback_time_start_point: int = 200,
lookahead_time_end_point: int = 100,
max_single_segment_time: int = 60000,
nn_eval_block_size: int = 8,
dcd_block_size: int = 4,
snr_thres: float = -100.0,
noise_frame_num_used_for_snr: int = 100,
decibel_thres: float = -100.0,
speech_noise_thres: float = 0.6,
fe_prior_thres: float = 1e-4,
silence_pdf_num: int = 1,
sil_pdf_ids: List[int] = [0],
speech_noise_thresh_low: float = -0.1,
speech_noise_thresh_high: float = 0.3,
output_frame_probs: bool = False,
frame_in_ms: int = 10,
frame_length_ms: int = 25,
):
self.sample_rate = sample_rate
self.detect_mode = detect_mode
self.snr_mode = snr_mode
self.max_end_silence_time = max_end_silence_time
self.max_start_silence_time = max_start_silence_time
self.do_start_point_detection = do_start_point_detection
self.do_end_point_detection = do_end_point_detection
self.window_size_ms = window_size_ms
self.sil_to_speech_time_thres = sil_to_speech_time_thres
self.speech_to_sil_time_thres = speech_to_sil_time_thres
self.speech_2_noise_ratio = speech_2_noise_ratio
self.do_extend = do_extend
self.lookback_time_start_point = lookback_time_start_point
self.lookahead_time_end_point = lookahead_time_end_point
self.max_single_segment_time = max_single_segment_time
self.nn_eval_block_size = nn_eval_block_size
self.dcd_block_size = dcd_block_size
self.snr_thres = snr_thres
self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
self.decibel_thres = decibel_thres
self.speech_noise_thres = speech_noise_thres
self.fe_prior_thres = fe_prior_thres
self.silence_pdf_num = silence_pdf_num
self.sil_pdf_ids = sil_pdf_ids
self.speech_noise_thresh_low = speech_noise_thresh_low
self.speech_noise_thresh_high = speech_noise_thresh_high
self.output_frame_probs = output_frame_probs
self.frame_in_ms = frame_in_ms
self.frame_length_ms = frame_length_ms
class E2EVadSpeechBufWithDoa(object):
def __init__(self):
self.start_ms = 0
self.end_ms = 0
self.buffer = []
self.contain_seg_start_point = False
self.contain_seg_end_point = False
self.doa = 0
def Reset(self):
self.start_ms = 0
self.end_ms = 0
self.buffer = []
self.contain_seg_start_point = False
self.contain_seg_end_point = False
self.doa = 0
class E2EVadFrameProb(object):
def __init__(self):
self.noise_prob = 0.0
self.speech_prob = 0.0
self.score = 0.0
self.frame_id = 0
self.frm_state = 0
class WindowDetector(object):
def __init__(self, window_size_ms: int, sil_to_speech_time: int,
speech_to_sil_time: int, frame_size_ms: int):
self.window_size_ms = window_size_ms
self.sil_to_speech_time = sil_to_speech_time
self.speech_to_sil_time = speech_to_sil_time
self.frame_size_ms = frame_size_ms
self.win_size_frame = int(window_size_ms / frame_size_ms)
self.win_sum = 0
self.win_state = [0] * self.win_size_frame  # initialize the voting window
self.cur_win_pos = 0
self.pre_frame_state = FrameState.kFrameStateSil
self.cur_frame_state = FrameState.kFrameStateSil
self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
self.voice_last_frame_count = 0
self.noise_last_frame_count = 0
self.hydre_frame_count = 0
def Reset(self) -> None:
self.cur_win_pos = 0
self.win_sum = 0
self.win_state = [0] * self.win_size_frame
self.pre_frame_state = FrameState.kFrameStateSil
self.cur_frame_state = FrameState.kFrameStateSil
self.voice_last_frame_count = 0
self.noise_last_frame_count = 0
self.hydre_frame_count = 0
def GetWinSize(self) -> int:
return int(self.win_size_frame)
def DetectOneFrame(self, frameState: FrameState, frame_count: int) -> AudioChangeState:
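# Sliding-window vote: win_sum counts the speech frames currently in the window; a
# Sil->Speech change is reported once win_sum reaches sil_to_speech_frmcnt_thres, and a
# Speech->Sil change once it falls to speech_to_sil_frmcnt_thres or below.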
cur_frame_state = FrameState.kFrameStateSil
if frameState == FrameState.kFrameStateSpeech:
cur_frame_state = 1
elif frameState == FrameState.kFrameStateSil:
cur_frame_state = 0
else:
return AudioChangeState.kChangeStateInvalid
self.win_sum -= self.win_state[self.cur_win_pos]
self.win_sum += cur_frame_state
self.win_state[self.cur_win_pos] = cur_frame_state
self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame
if self.pre_frame_state == FrameState.kFrameStateSil and self.win_sum >= self.sil_to_speech_frmcnt_thres:
self.pre_frame_state = FrameState.kFrameStateSpeech
return AudioChangeState.kChangeStateSil2Speech
if self.pre_frame_state == FrameState.kFrameStateSpeech and self.win_sum <= self.speech_to_sil_frmcnt_thres:
self.pre_frame_state = FrameState.kFrameStateSil
return AudioChangeState.kChangeStateSpeech2Sil
if self.pre_frame_state == FrameState.kFrameStateSil:
return AudioChangeState.kChangeStateSil2Sil
if self.pre_frame_state == FrameState.kFrameStateSpeech:
return AudioChangeState.kChangeStateSpeech2Speech
return AudioChangeState.kChangeStateInvalid
def FrameSizeMs(self) -> int:
return int(self.frame_size_ms)
class E2EVadModel():
def __init__(self, vad_post_args: Dict[str, Any]):
super(E2EVadModel, self).__init__()
self.vad_opts = VADXOptions(**vad_post_args)
self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
self.vad_opts.sil_to_speech_time_thres,
self.vad_opts.speech_to_sil_time_thres,
self.vad_opts.frame_in_ms)
# self.encoder = encoder
# init variables
self.is_final = False
self.data_buf_start_frame = 0
self.frm_cnt = 0
self.latest_confirmed_speech_frame = 0
self.lastest_confirmed_silence_frame = -1
self.continous_silence_frame_count = 0
self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
self.confirmed_start_frame = -1
self.confirmed_end_frame = -1
self.number_end_time_detected = 0
self.sil_frame = 0
self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
self.noise_average_decibel = -100.0
self.pre_end_silence_detected = False
self.next_seg = True
self.output_data_buf = []
self.output_data_buf_offset = 0
self.frame_probs = []
self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
self.speech_noise_thres = self.vad_opts.speech_noise_thres
self.scores = None
self.max_time_out = False
self.decibel = []
self.data_buf = None
self.data_buf_all = None
self.waveform = None
self.ResetDetection()
def AllResetDetection(self):
self.is_final = False
self.data_buf_start_frame = 0
self.frm_cnt = 0
self.latest_confirmed_speech_frame = 0
self.lastest_confirmed_silence_frame = -1
self.continous_silence_frame_count = 0
self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
self.confirmed_start_frame = -1
self.confirmed_end_frame = -1
self.number_end_time_detected = 0
self.sil_frame = 0
self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
self.noise_average_decibel = -100.0
self.pre_end_silence_detected = False
self.next_seg = True
self.output_data_buf = []
self.output_data_buf_offset = 0
self.frame_probs = []
self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
self.speech_noise_thres = self.vad_opts.speech_noise_thres
self.scores = None
self.max_time_out = False
self.decibel = []
self.data_buf = None
self.data_buf_all = None
self.waveform = None
self.ResetDetection()
def ResetDetection(self):
self.continous_silence_frame_count = 0
self.latest_confirmed_speech_frame = 0
self.lastest_confirmed_silence_frame = -1
self.confirmed_start_frame = -1
self.confirmed_end_frame = -1
self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
self.windows_detector.Reset()
self.sil_frame = 0
self.frame_probs = []
def ComputeDecibel(self) -> None:
frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
if self.data_buf_all is None:
self.data_buf_all = self.waveform[0] # self.data_buf is pointed to self.waveform[0]
self.data_buf = self.data_buf_all
else:
self.data_buf_all = np.concatenate((self.data_buf_all, self.waveform[0]))
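# Frame-level energy in dB: 10 * log10(sum of squared samples + 1e-6), computed over
# frame_length_ms windows with a frame_in_ms hop.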
for offset in range(0, self.waveform.shape[1] - frame_sample_length + 1, frame_shift_length):
self.decibel.append(
10 * math.log10(np.square((self.waveform[0][offset: offset + frame_sample_length])).sum() + \
0.000001))
def ComputeScores(self, scores: np.ndarray) -> None:
# scores = self.encoder(feats, in_cache) # return B * T * D
self.vad_opts.nn_eval_block_size = scores.shape[1]
self.frm_cnt += scores.shape[1] # count total frames
if self.scores is None:
self.scores = scores # the first calculation
else:
self.scores = np.concatenate((self.scores, scores), axis=1)
def PopDataBufTillFrame(self, frame_idx: int) -> None: # need check again
while self.data_buf_start_frame < frame_idx:
if len(self.data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
self.data_buf_start_frame += 1
self.data_buf = self.data_buf_all[self.data_buf_start_frame * int(
self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
last_frm_is_end_point: bool, end_point_is_sent_end: bool) -> None:
self.PopDataBufTillFrame(start_frm)
expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
if last_frm_is_end_point:
extra_sample = max(0, int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 - \
self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
expected_sample_number += int(extra_sample)
if end_point_is_sent_end:
expected_sample_number = max(expected_sample_number, len(self.data_buf))
if len(self.data_buf) < expected_sample_number:
print('error in calling pop data_buf\n')
if len(self.output_data_buf) == 0 or first_frm_is_start_point:
self.output_data_buf.append(E2EVadSpeechBufWithDoa())
self.output_data_buf[-1].Reset()
self.output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms
self.output_data_buf[-1].end_ms = self.output_data_buf[-1].start_ms
self.output_data_buf[-1].doa = 0
cur_seg = self.output_data_buf[-1]
if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
print('warning\n')
out_pos = len(cur_seg.buffer)  # cur_seg.buffer is left untouched here for now
data_to_pop = 0
if end_point_is_sent_end:
data_to_pop = expected_sample_number
else:
data_to_pop = int(frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
if data_to_pop > len(self.data_buf):
print('VAD data_to_pop is bigger than self.data_buf.size()!!!\n')
data_to_pop = len(self.data_buf)
expected_sample_number = len(self.data_buf)
cur_seg.doa = 0
for sample_cpy_out in range(0, data_to_pop):
# cur_seg.buffer[out_pos ++] = data_buf_.back();
out_pos += 1
for sample_cpy_out in range(data_to_pop, expected_sample_number):
# cur_seg.buffer[out_pos++] = data_buf_.back()
out_pos += 1
if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
print('Something wrong with the VAD algorithm\n')
self.data_buf_start_frame += frm_cnt
cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
if first_frm_is_start_point:
cur_seg.contain_seg_start_point = True
if last_frm_is_end_point:
cur_seg.contain_seg_end_point = True
def OnSilenceDetected(self, valid_frame: int):
self.lastest_confirmed_silence_frame = valid_frame
if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
self.PopDataBufTillFrame(valid_frame)
# silence_detected_callback_
# pass
def OnVoiceDetected(self, valid_frame: int) -> None:
self.latest_confirmed_speech_frame = valid_frame
self.PopDataToOutputBuf(valid_frame, 1, False, False, False)
def OnVoiceStart(self, start_frame: int, fake_result: bool = False) -> None:
if self.vad_opts.do_start_point_detection:
pass
if self.confirmed_start_frame != -1:
print('not reset vad properly\n')
else:
self.confirmed_start_frame = start_frame
if not fake_result and self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
self.PopDataToOutputBuf(self.confirmed_start_frame, 1, True, False, False)
def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool) -> None:
for t in range(self.latest_confirmed_speech_frame + 1, end_frame):
self.OnVoiceDetected(t)
if self.vad_opts.do_end_point_detection:
pass
if self.confirmed_end_frame != -1:
print('not reset vad properly\n')
else:
self.confirmed_end_frame = end_frame
if not fake_result:
self.sil_frame = 0
self.PopDataToOutputBuf(self.confirmed_end_frame, 1, False, True, is_last_frame)
self.number_end_time_detected += 1
def MaybeOnVoiceEndIfLastFrame(self, is_final_frame: bool, cur_frm_idx: int) -> None:
if is_final_frame:
self.OnVoiceEnd(cur_frm_idx, False, True)
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
def GetLatency(self) -> int:
return int(self.LatencyFrmNumAtStartPoint() * self.vad_opts.frame_in_ms)
def LatencyFrmNumAtStartPoint(self) -> int:
vad_latency = self.windows_detector.GetWinSize()
if self.vad_opts.do_extend:
vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
return vad_latency
def GetFrameState(self, t: int) -> FrameState:
frame_state = FrameState.kFrameStateInvalid
cur_decibel = self.decibel[t]
cur_snr = cur_decibel - self.noise_average_decibel
# for each frame, calc log posterior probability of each state
if cur_decibel < self.vad_opts.decibel_thres:
frame_state = FrameState.kFrameStateSil
self.DetectOneFrame(frame_state, t, False)
return frame_state
sum_score = 0.0
noise_prob = 0.0
assert len(self.sil_pdf_ids) == self.vad_opts.silence_pdf_num
if len(self.sil_pdf_ids) > 0:
assert len(self.scores) == 1  # only batch_size = 1 is supported
sil_pdf_scores = [self.scores[0][t][sil_pdf_id] for sil_pdf_id in self.sil_pdf_ids]
sum_score = sum(sil_pdf_scores)
noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
total_score = 1.0
sum_score = total_score - sum_score
speech_prob = math.log(sum_score)
if self.vad_opts.output_frame_probs:
frame_prob = E2EVadFrameProb()
frame_prob.noise_prob = noise_prob
frame_prob.speech_prob = speech_prob
frame_prob.score = sum_score
frame_prob.frame_id = t
self.frame_probs.append(frame_prob)
if math.exp(speech_prob) >= math.exp(noise_prob) + self.speech_noise_thres:
if cur_snr >= self.vad_opts.snr_thres and cur_decibel >= self.vad_opts.decibel_thres:
frame_state = FrameState.kFrameStateSpeech
else:
frame_state = FrameState.kFrameStateSil
else:
frame_state = FrameState.kFrameStateSil
if self.noise_average_decibel < -99.9:
self.noise_average_decibel = cur_decibel
else:
self.noise_average_decibel = (cur_decibel + self.noise_average_decibel * (
self.vad_opts.noise_frame_num_used_for_snr
- 1)) / self.vad_opts.noise_frame_num_used_for_snr
return frame_state
def __call__(self, score: np.ndarray, waveform: np.ndarray,
is_final: bool = False, max_end_sil: int = 800
):
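# One streaming step: accumulate frame energies and acoustic scores, run the frame-level
# state machine, and return detected segments as [start_ms, end_ms] pairs; -1 marks a
# boundary that has not been decided yet in streaming mode.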
self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
self.waveform = waveform # compute decibel for each frame
self.ComputeDecibel()
self.ComputeScores(score)
if not is_final:
self.DetectCommonFrames()
else:
self.DetectLastFrames()
segments = []
for batch_num in range(0, score.shape[0]): # only support batch_size = 1 now
segment_batch = []
if len(self.output_data_buf) > 0:
for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
if not self.output_data_buf[i].contain_seg_start_point:
continue
if not self.next_seg and not self.output_data_buf[i].contain_seg_end_point:
continue
start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1
if self.output_data_buf[i].contain_seg_end_point:
end_ms = self.output_data_buf[i].end_ms
self.next_seg = True
self.output_data_buf_offset += 1
else:
end_ms = -1
self.next_seg = False
segment = [start_ms, end_ms]
segment_batch.append(segment)
if segment_batch:
segments.append(segment_batch)
if is_final:
# reset class variables and clear the dict for the next query
self.AllResetDetection()
return segments
def DetectCommonFrames(self) -> int:
if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
return 0
for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
frame_state = FrameState.kFrameStateInvalid
frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
return 0
def DetectLastFrames(self) -> int:
if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
return 0
for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
frame_state = FrameState.kFrameStateInvalid
frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
if i != 0:
self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
else:
self.DetectOneFrame(frame_state, self.frm_cnt - 1, True)
return 0
def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool) -> None:
tmp_cur_frm_state = FrameState.kFrameStateInvalid
if cur_frm_state == FrameState.kFrameStateSpeech:
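# Note: math.fabs(1.0) is constant, so this comparison is always true for the default
# fe_prior_thres of 1e-4.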
if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
tmp_cur_frm_state = FrameState.kFrameStateSpeech
else:
tmp_cur_frm_state = FrameState.kFrameStateSil
elif cur_frm_state == FrameState.kFrameStateSil:
tmp_cur_frm_state = FrameState.kFrameStateSil
state_change = self.windows_detector.DetectOneFrame(tmp_cur_frm_state, cur_frm_idx)
frm_shift_in_ms = self.vad_opts.frame_in_ms
if AudioChangeState.kChangeStateSil2Speech == state_change:
silence_frame_count = self.continous_silence_frame_count
self.continous_silence_frame_count = 0
self.pre_end_silence_detected = False
start_frame = 0
if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
start_frame = max(self.data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint())
self.OnVoiceStart(start_frame)
self.vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
for t in range(start_frame + 1, cur_frm_idx + 1):
self.OnVoiceDetected(t)
elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
for t in range(self.latest_confirmed_speech_frame + 1, cur_frm_idx):
self.OnVoiceDetected(t)
if cur_frm_idx - self.confirmed_start_frame + 1 > \
self.vad_opts.max_single_segment_time / frm_shift_in_ms:
self.OnVoiceEnd(cur_frm_idx, False, False)
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
elif not is_final_frame:
self.OnVoiceDetected(cur_frm_idx)
else:
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
else:
pass
elif AudioChangeState.kChangeStateSpeech2Sil == state_change:
self.continous_silence_frame_count = 0
if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
pass
elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
if cur_frm_idx - self.confirmed_start_frame + 1 > \
self.vad_opts.max_single_segment_time / frm_shift_in_ms:
self.OnVoiceEnd(cur_frm_idx, False, False)
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
elif not is_final_frame:
self.OnVoiceDetected(cur_frm_idx)
else:
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
else:
pass
elif AudioChangeState.kChangeStateSpeech2Speech == state_change:
self.continous_silence_frame_count = 0
if self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
if cur_frm_idx - self.confirmed_start_frame + 1 > \
self.vad_opts.max_single_segment_time / frm_shift_in_ms:
self.max_time_out = True
self.OnVoiceEnd(cur_frm_idx, False, False)
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
elif not is_final_frame:
self.OnVoiceDetected(cur_frm_idx)
else:
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
else:
pass
elif AudioChangeState.kChangeStateSil2Sil == state_change:
self.continous_silence_frame_count += 1
if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
# silence timeout, return zero length decision
if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and (
self.continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
or (is_final_frame and self.number_end_time_detected == 0):
for t in range(self.lastest_confirmed_silence_frame + 1, cur_frm_idx):
self.OnSilenceDetected(t)
self.OnVoiceStart(0, True)
self.OnVoiceEnd(0, True, False)
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
else:
if cur_frm_idx >= self.LatencyFrmNumAtStartPoint():
self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint())
elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
if self.continous_silence_frame_count * frm_shift_in_ms >= self.max_end_sil_frame_cnt_thresh:
lookback_frame = int(self.max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
if self.vad_opts.do_extend:
lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms)
lookback_frame -= 1
lookback_frame = max(0, lookback_frame)
self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False)
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
elif cur_frm_idx - self.confirmed_start_frame + 1 > \
self.vad_opts.max_single_segment_time / frm_shift_in_ms:
self.OnVoiceEnd(cur_frm_idx, False, False)
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
elif self.vad_opts.do_extend and not is_final_frame:
if self.continous_silence_frame_count <= int(
self.vad_opts.lookahead_time_end_point / frm_shift_in_ms):
self.OnVoiceDetected(cur_frm_idx)
else:
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
else:
pass
if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
self.ResetDetection()

View File

@ -188,7 +188,7 @@ class OrtInferSession():
input_content: List[Union[np.ndarray, np.ndarray]]) -> np.ndarray:
input_dict = dict(zip(self.get_input_names(), input_content))
try:
return self.session.run(None, input_dict)
return self.session.run(self.get_output_names(), input_dict)
except Exception as e:
raise ONNXRuntimeError('ONNXRuntime inference failed.') from e
@ -215,6 +215,38 @@ class OrtInferSession():
if not model_path.is_file():
raise FileExistsError(f'{model_path} is not a file.')
def split_to_mini_sentence(words: list, word_limit: int = 20):
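# Example: 45 tokens with word_limit=20 -> three mini-sentences of 20, 20 and 5 tokens.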
assert word_limit > 1
if len(words) <= word_limit:
return [words]
sentences = []
length = len(words)
sentence_len = length // word_limit
for i in range(sentence_len):
sentences.append(words[i * word_limit:(i + 1) * word_limit])
if length % word_limit > 0:
sentences.append(words[sentence_len * word_limit:])
return sentences
def code_mix_split_words(text: str):
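# Example: "hello世界ASR" -> ["hello", "世", "界", "ASR"]; ASCII runs stay whole words,
# multi-byte (e.g. Chinese) characters become one word each.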
words = []
segs = text.split()
for seg in segs:
# There is no space in seg.
current_word = ""
for c in seg:
if len(c.encode()) == 1:
# This is an ASCII char.
current_word += c
else:
# This is a Chinese char.
if len(current_word) > 0:
words.append(current_word)
current_word = ""
words.append(c)
if len(current_word) > 0:
words.append(current_word)
return words
def read_yaml(yaml_path: Union[str, Path]) -> Dict:
if not Path(yaml_path).exists():
@ -226,7 +258,7 @@ def read_yaml(yaml_path: Union[str, Path]) -> Dict:
@functools.lru_cache()
def get_logger(name='rapdi_paraformer'):
def get_logger(name='funasr_onnx'):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will

View File

@ -0,0 +1,143 @@
# -*- encoding: utf-8 -*-
import os.path
from pathlib import Path
from typing import List, Union, Tuple
import copy
import traceback
import librosa
import numpy as np
from .utils.utils import (ONNXRuntimeError,
OrtInferSession, get_logger,
read_yaml)
from .utils.frontend import WavFrontend
from .utils.e2e_vad import E2EVadModel
logging = get_logger()
class Fsmn_vad():
def __init__(self, model_dir: Union[str, Path] = None,
batch_size: int = 1,
device_id: Union[str, int] = "-1",
quantize: bool = False,
intra_op_num_threads: int = 4,
max_end_sil: int = None,
):
if not Path(model_dir).exists():
raise FileNotFoundError(f'{model_dir} does not exist.')
model_file = os.path.join(model_dir, 'model.onnx')
if quantize:
model_file = os.path.join(model_dir, 'model_quant.onnx')
config_file = os.path.join(model_dir, 'vad.yaml')
cmvn_file = os.path.join(model_dir, 'vad.mvn')
config = read_yaml(config_file)
self.frontend = WavFrontend(
cmvn_file=cmvn_file,
**config['frontend_conf']
)
self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
self.batch_size = batch_size
self.vad_scorer = E2EVadModel(config["vad_post_conf"])
self.max_end_sil = max_end_sil if max_end_sil is not None else config["vad_post_conf"]["max_end_silence_time"]
self.encoder_conf = config["encoder_conf"]
def prepare_cache(self, in_cache: list = []):
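# FSMN streaming cache: one zero array of shape (1, proj_dim, lorder - 1, 1) per layer on
# the first call; later calls reuse the caches returned by the previous inference.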
if len(in_cache) > 0:
return in_cache
fsmn_layers = self.encoder_conf["fsmn_layers"]
proj_dim = self.encoder_conf["proj_dim"]
lorder = self.encoder_conf["lorder"]
for i in range(fsmn_layers):
cache = np.zeros((1, proj_dim, lorder-1, 1)).astype(np.float32)
in_cache.append(cache)
return in_cache
def __call__(self, audio_in: Union[str, np.ndarray, List[str]], **kwargs) -> List:
# waveform_list = self.load_data(audio_in, self.frontend.opts.frame_opts.samp_freq)
param_dict = kwargs.get('param_dict', dict())
is_final = param_dict.get('is_final', False)
audio_in_cache = param_dict.get('audio_in_cache', None)
audio_in_cum = audio_in
if audio_in_cache is not None:
audio_in_cum = np.concatenate((audio_in_cache, audio_in_cum))
param_dict['audio_in_cache'] = audio_in_cum
feats, feats_len = self.extract_feat([audio_in_cum])
in_cache = param_dict.get('in_cache', list())
in_cache = self.prepare_cache(in_cache)
beg_idx = param_dict.get('beg_idx',0)
feats = feats[:, beg_idx:beg_idx+8, :]
param_dict['beg_idx'] = beg_idx + feats.shape[1]
try:
inputs = [feats]
inputs.extend(in_cache)
scores, out_caches = self.infer(inputs)
param_dict['in_cache'] = out_caches
segments = self.vad_scorer(scores, audio_in[None, :], is_final=is_final, max_end_sil=self.max_end_sil)
# print(segments)
if len(segments) == 1 and segments[0][0][1] != -1:
self.frontend.reset_status()
except ONNXRuntimeError:
logging.warning(traceback.format_exc())
logging.warning("input wav is silence or noise")
segments = []
return segments
def load_data(self,
wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
def load_wav(path: str) -> np.ndarray:
waveform, _ = librosa.load(path, sr=fs)
return waveform
if isinstance(wav_content, np.ndarray):
return [wav_content]
if isinstance(wav_content, str):
return [load_wav(wav_content)]
if isinstance(wav_content, list):
return [load_wav(path) for path in wav_content]
raise TypeError(
f'The type of {wav_content} is not in [str, np.ndarray, list]')
def extract_feat(self,
waveform_list: List[np.ndarray]
) -> Tuple[np.ndarray, np.ndarray]:
feats, feats_len = [], []
for waveform in waveform_list:
speech, _ = self.frontend.fbank(waveform)
feat, feat_len = self.frontend.lfr_cmvn(speech)
feats.append(feat)
feats_len.append(feat_len)
feats = self.pad_feats(feats, np.max(feats_len))
feats_len = np.array(feats_len).astype(np.int32)
return feats, feats_len
@staticmethod
def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
pad_width = ((0, max_feat_len - cur_len), (0, 0))
return np.pad(feat, pad_width, 'constant', constant_values=0)
feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
feats = np.array(feat_res).astype(np.float32)
return feats
def infer(self, feats: List) -> Tuple[np.ndarray, np.ndarray]:
outputs = self.ort_infer(feats)
scores, out_caches = outputs[0], outputs[1:]
return scores, out_caches

View File

@ -13,7 +13,7 @@ def get_readme():
MODULE_NAME = 'funasr_onnx'
VERSION_NUM = '0.0.2'
VERSION_NUM = '0.0.3'
setuptools.setup(
name=MODULE_NAME,

View File

@ -15,7 +15,7 @@ from typeguard import check_return_type
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.lm.abs_model import AbsLM
from funasr.lm.espnet_model import ESPnetLanguageModel
from funasr.lm.abs_model import LanguageModel
from funasr.lm.seq_rnn_lm import SequentialRNNLM
from funasr.lm.transformer_lm import TransformerLM
from funasr.tasks.abs_task import AbsTask
@ -83,7 +83,7 @@ class LMTask(AbsTask):
group.add_argument(
"--model_conf",
action=NestedDictAction,
default=get_default_kwargs(ESPnetLanguageModel),
default=get_default_kwargs(LanguageModel),
help="The keyword arguments for model class.",
)
@ -178,7 +178,7 @@ class LMTask(AbsTask):
return retval
@classmethod
def build_model(cls, args: argparse.Namespace) -> ESPnetLanguageModel:
def build_model(cls, args: argparse.Namespace) -> LanguageModel:
assert check_argument_types()
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
@ -201,7 +201,7 @@ class LMTask(AbsTask):
# 2. Build ESPnetModel
# Assume the last-id is sos_and_eos
model = ESPnetLanguageModel(lm=lm, vocab_size=vocab_size, **args.model_conf)
model = LanguageModel(lm=lm, vocab_size=vocab_size, **args.model_conf)
# 3. Initialize
if args.init is not None:

View File

@ -14,10 +14,10 @@ from typeguard import check_return_type
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import PuncTrainTokenizerCommonPreprocessor
from funasr.punctuation.abs_model import AbsPunctuation
from funasr.punctuation.espnet_model import ESPnetPunctuationModel
from funasr.punctuation.target_delay_transformer import TargetDelayTransformer
from funasr.punctuation.vad_realtime_transformer import VadRealtimeTransformer
from funasr.train.abs_model import AbsPunctuation
from funasr.train.abs_model import PunctuationModel
from funasr.models.target_delay_transformer import TargetDelayTransformer
from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
from funasr.tasks.abs_task import AbsTask
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
@ -79,7 +79,7 @@ class PunctuationTask(AbsTask):
group.add_argument(
"--model_conf",
action=NestedDictAction,
default=get_default_kwargs(ESPnetPunctuationModel),
default=get_default_kwargs(PunctuationModel),
help="The keyword arguments for model class.",
)
@ -183,7 +183,7 @@ class PunctuationTask(AbsTask):
return retval
@classmethod
def build_model(cls, args: argparse.Namespace) -> ESPnetPunctuationModel:
def build_model(cls, args: argparse.Namespace) -> PunctuationModel:
assert check_argument_types()
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
@ -218,7 +218,7 @@ class PunctuationTask(AbsTask):
# Assume the last-id is sos_and_eos
if "punc_weight" in args.model_conf:
args.model_conf.pop("punc_weight")
model = ESPnetPunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)
model = PunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)
# FIXME(kamo): Should be done in model?
# 3. Initialize

View File

@ -40,7 +40,7 @@ from funasr.models.encoder.transformer_encoder import TransformerEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
@ -81,6 +81,7 @@ frontend_choices = ClassChoices(
s3prl=S3prlFrontend,
fused=FusedFrontends,
wav_frontend=WavFrontend,
wav_frontend_online=WavFrontendOnline,
),
type_check=AbsFrontend,
default="default",
@ -291,7 +292,24 @@ class VADTask(AbsTask):
model_class = model_choices.get_class(args.model)
except AttributeError:
model_class = model_choices.get_class("e2evad")
model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf)
# 1. frontend
if args.input_size is None:
# Extract features in the model
frontend_class = frontend_choices.get_class(args.frontend)
if args.frontend == 'wav_frontend':
frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
else:
frontend = frontend_class(**args.frontend_conf)
input_size = frontend.output_size()
else:
# Give features from data-loader
args.frontend = None
args.frontend_conf = {}
frontend = None
input_size = args.input_size
model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf, frontend=frontend)
return model
@ -301,6 +319,7 @@ class VADTask(AbsTask):
cls,
config_file: Union[Path, str] = None,
model_file: Union[Path, str] = None,
cmvn_file: Union[Path, str] = None,
device: str = "cpu",
):
"""Build model from the files.
@ -325,6 +344,8 @@ class VADTask(AbsTask):
with config_file.open("r", encoding="utf-8") as f:
args = yaml.safe_load(f)
if cmvn_file is not None:
args["cmvn_file"] = cmvn_file
args = argparse.Namespace(**args)
model = cls.build_model(args)
model.to(device)

View File

@ -1,3 +1,7 @@
from abc import ABC
from abc import abstractmethod
from typing import Dict
from typing import Optional
from typing import Tuple
@ -7,13 +11,34 @@ import torch.nn.functional as F
from typeguard import check_argument_types
from funasr.modules.nets_utils import make_pad_mask
from funasr.punctuation.abs_model import AbsPunctuation
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
class ESPnetPunctuationModel(AbsESPnetModel):
class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC):
"""The abstract class
To share the loss calculation across different models,
we use the delegate pattern here:
an instance of this class should be passed to "PunctuationModel".
This "model" is one of the mediator objects for the "Task" class.
"""
@abstractmethod
def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
@abstractmethod
def with_vad(self) -> bool:
raise NotImplementedError
class PunctuationModel(AbsESPnetModel):
def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0, punc_weight: list = None):
assert check_argument_types()
super().__init__()
@ -21,12 +46,12 @@ class ESPnetPunctuationModel(AbsESPnetModel):
self.punc_weight = torch.Tensor(punc_weight)
self.sos = 1
self.eos = 2
# ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
self.ignore_id = ignore_id
#if self.punc_model.with_vad():
# if self.punc_model.with_vad():
# print("This is a vad puncuation model.")
def nll(
self,
text: torch.Tensor,
@ -54,7 +79,7 @@ class ESPnetPunctuationModel(AbsESPnetModel):
else:
text = text[:, :max_length]
punc = punc[:, :max_length]
if self.punc_model.with_vad():
# Should be VadRealtimeTransformer
assert vad_indexes is not None
@ -62,7 +87,7 @@ class ESPnetPunctuationModel(AbsESPnetModel):
else:
# Should be TargetDelayTransformer,
y, _ = self.punc_model(text, text_lengths)
# Calc negative log likelihood
# nll: (BxL,)
if self.training == False:
@ -75,7 +100,8 @@ class ESPnetPunctuationModel(AbsESPnetModel):
return nll, text_lengths
else:
self.punc_weight = self.punc_weight.to(punc.device)
nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none", ignore_index=self.ignore_id)
nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none",
ignore_index=self.ignore_id)
# nll: (BxL,) -> (BxL,)
if max_length is None:
nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
@ -87,7 +113,7 @@ class ESPnetPunctuationModel(AbsESPnetModel):
# nll: (BxL,) -> (B, L)
nll = nll.view(batch_size, -1)
return nll, text_lengths
def batchify_nll(self,
text: torch.Tensor,
punc: torch.Tensor,
@ -113,7 +139,7 @@ class ESPnetPunctuationModel(AbsESPnetModel):
nlls = []
x_lengths = []
max_length = text_lengths.max()
start_idx = 0
while True:
end_idx = min(start_idx + batch_size, total_num)
@ -132,7 +158,7 @@ class ESPnetPunctuationModel(AbsESPnetModel):
assert nll.size(0) == total_num
assert x_lengths.size(0) == total_num
return nll, x_lengths
def forward(
self,
text: torch.Tensor,
@ -146,15 +172,15 @@ class ESPnetPunctuationModel(AbsESPnetModel):
ntokens = y_lengths.sum()
loss = nll.sum() / ntokens
stats = dict(loss=loss.detach())
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
return loss, stats, weight
def collect_feats(self, text: torch.Tensor, punc: torch.Tensor,
text_lengths: torch.Tensor) -> Dict[str, torch.Tensor]:
return {}
def inference(self,
text: torch.Tensor,
text_lengths: torch.Tensor,