游雁 2023-12-21 13:29:37 +08:00
parent 00ea1186f9
commit c8bae0ec85
12 changed files with 552 additions and 359 deletions

View File

@@ -0,0 +1,8 @@
cmd="funasr/bin/inference.py"
python $cmd \
+model="/Users/zhifu/Downloads/modelscope_models/speech_fsmn_vad_zh-cn-16k-common-pytorch" \
+input="/Users/zhifu/Downloads/asr_example.wav" \
+output_dir="/Users/zhifu/Downloads/ckpt/funasr2/exp2_vad" \
+device="cpu" \
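For reference, the same hydra-style key=value overrides can be issued from Python by shelling out to the entry point shown above. This is only a sketch: the model, input and output paths are placeholders standing in for the local paths used in the script.

import subprocess

# Sketch only: replace the placeholder paths with a real VAD model directory and audio file.
cmd = [
    "python", "funasr/bin/inference.py",
    "model=/path/to/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    "input=/path/to/asr_example.wav",
    "output_dir=/path/to/exp_vad",
    "device=cpu",
]
subprocess.run(cmd, check=True)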

View File

@@ -101,6 +101,7 @@ class AutoModel:
tokenizer_class = registry_tables.tokenizer_classes.get(tokenizer.lower())
tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
kwargs["tokenizer"] = tokenizer
kwargs["token_list"] = tokenizer.token_list
# build frontend
frontend = kwargs.get("frontend", None)
@@ -112,12 +113,10 @@ class AutoModel:
# build model
model_class = registry_tables.model_classes.get(kwargs["model"].lower())
model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list) if tokenizer is not None else -1)
model.eval()
model.to(device)
kwargs["token_list"] = tokenizer.token_list
# init_param
init_param = kwargs.get("init_param", None)
if init_param is not None:
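The guard added above lets AutoModel build models that carry no tokenizer (for example the VAD model). A minimal sketch of the pattern follows; build_model is a hypothetical helper, while in AutoModel itself the class and confs come from registry_tables and kwargs.

def build_model(model_class, model_conf, tokenizer=None, **kwargs):
    # mirrors the None-check introduced in the diff above
    vocab_size = len(tokenizer.token_list) if tokenizer is not None else -1
    model = model_class(**kwargs, **model_conf, vocab_size=vocab_size)
    model.eval()
    return model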

View File

@@ -145,7 +145,8 @@ def main(**kwargs):
# dataloader
batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
batch_sampler_class = registry_tables.batch_sampler_classes.get(batch_sampler.lower())
batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
if batch_sampler is not None:
batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
collate_fn=dataset_tr.collator,
batch_sampler=batch_sampler,
@@ -153,7 +154,6 @@ def main(**kwargs):
pin_memory=True)
trainer = Trainer(
model=model,
optim=optim,
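The new check means a config may omit the batch sampler entirely and still get a working DataLoader, since batch_sampler=None is the DataLoader default. A small self-contained sketch of that fallback; TensorDataset here is only a stand-in for dataset_tr and its collator.

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10).float().unsqueeze(-1))  # stand-in for dataset_tr
batch_sampler = None  # e.g. the config names no sampler
# A sampler object, when present, takes over batching entirely; None falls back
# to the default batch_size/sampler behaviour.
loader = DataLoader(dataset, batch_sampler=batch_sampler)
for batch in loader:
    pass  # one training step per batch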

View File

@@ -24,6 +24,17 @@ class AudioDataset(torch.utils.data.Dataset):
super().__init__()
index_ds_class = registry_tables.index_ds_classes.get(index_ds.lower())
self.index_ds = index_ds_class(path)
preprocessor_speech = kwargs.get("preprocessor_speech", None)
if preprocessor_speech:
preprocessor_speech_class = registry_tables.preprocessor_speech_classes.get(preprocessor_speech.lower())
preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
self.preprocessor_speech = preprocessor_speech
preprocessor_text = kwargs.get("preprocessor_text", None)
if preprocessor_text:
preprocessor_text_class = registry_tables.preprocessor_text_classes.get(preprocessor_text.lower())
preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
self.preprocessor_text = preprocessor_text
self.frontend = frontend
self.fs = 16000 if frontend is None else frontend.fs
self.data_type = "sound"
@@ -49,8 +60,13 @@
# pdb.set_trace()
source = item["source"]
data_src = load_audio(source, fs=self.fs)
if self.preprocessor_speech:
data_src = self.preprocessor_speech(data_src)
speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend) # speech: [b, T, d]
target = item["target"]
if self.preprocessor_text:
target = self.preprocessor_text(target)
ids = self.tokenizer.encode(target)
ids_lengths = len(ids)
text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32)
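The speech/text preprocessors added above are optional callables resolved from the registry; when absent, samples pass through unchanged. A sketch of the hook pattern with hypothetical preprocessor classes (not the ones registered in FunASR):

class AddNoise:
    # hypothetical speech preprocessor; real ones come from
    # registry_tables.preprocessor_speech_classes
    def __call__(self, waveform):
        return waveform  # perturb the raw samples here

class Lowercase:
    # hypothetical text preprocessor
    def __call__(self, text):
        return text.lower()

def prepare_sample(waveform, target, preprocessor_speech=None, preprocessor_text=None):
    if preprocessor_speech:
        waveform = preprocessor_speech(waveform)
    if preprocessor_text:
        target = preprocessor_text(target)
    return waveform, target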

View File

@@ -0,0 +1,212 @@
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from funasr.utils.register import register_class, registry_tables
# Note: nll() and forward() below also rely on make_pad_mask and force_gatherable;
# their import paths vary across FunASR versions and are omitted here.
@register_class("model_classes", "CTTransformer")
class CTTransformer(nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
https://arxiv.org/pdf/2003.01309.pdf
"""
def __init__(
self,
encoder: str = None,
encoder_conf: dict = None,
vocab_size: int = -1,
punc_list: list = None,
punc_weight: list = None,
embed_unit: int = 128,
att_unit: int = 256,
dropout_rate: float = 0.5,
ignore_id: int = -1,
sos: int = 1,
eos: int = 2,
**kwargs,
):
super().__init__()
punc_size = len(punc_list)
if punc_weight is None:
punc_weight = [1] * punc_size
self.embed = nn.Embedding(vocab_size, embed_unit)
encoder_class = registry_tables.encoder_classes.get(encoder.lower())
encoder = encoder_class(**encoder_conf)
self.decoder = nn.Linear(att_unit, punc_size)
self.encoder = encoder
self.punc_list = punc_list
self.punc_weight = punc_weight
self.ignore_id = ignore_id
self.sos = sos
self.eos = eos
def punc_forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
"""Compute loss value from buffer sequences.
Args:
input (torch.Tensor): Input ids. (batch, len)
hidden (torch.Tensor): Target ids. (batch, len)
"""
x = self.embed(input)
# mask = self._target_mask(input)
h, _, _ = self.encoder(x, text_lengths)
y = self.decoder(h)
return y, None
def with_vad(self):
return False
def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
"""Score new token.
Args:
y (torch.Tensor): 1D torch.int64 prefix tokens.
state: Scorer state for prefix tokens
x (torch.Tensor): encoder feature that generates ys.
Returns:
tuple[torch.Tensor, Any]: Tuple of
torch.float32 scores for next token (vocab_size)
and next state for ys
"""
y = y.unsqueeze(0)
h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
h = self.decoder(h[:, -1])
logp = h.log_softmax(dim=-1).squeeze(0)
return logp, cache
def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
"""Score new token batch.
Args:
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
states (List[Any]): Scorer states for prefix tokens.
xs (torch.Tensor):
The encoder feature that generates ys (n_batch, xlen, n_feat).
Returns:
tuple[torch.Tensor, List[Any]]: Tuple of
batchfied scores for next token with shape of `(n_batch, vocab_size)`
and next state list for ys.
"""
# merge states
n_batch = len(ys)
n_layers = len(self.encoder.encoders)
if states[0] is None:
batch_state = None
else:
# transpose state of [batch, layer] into [layer, batch]
batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]
# batch decoding
h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
h = self.decoder(h[:, -1])
logp = h.log_softmax(dim=-1)
# transpose state of [layer, batch] into [batch, layer]
state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
return logp, state_list
def nll(
self,
text: torch.Tensor,
punc: torch.Tensor,
text_lengths: torch.Tensor,
punc_lengths: torch.Tensor,
max_length: Optional[int] = None,
vad_indexes: Optional[torch.Tensor] = None,
vad_indexes_lengths: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute negative log likelihood(nll)
Normally, this function is called in batchify_nll.
Args:
text: (Batch, Length)
punc: (Batch, Length)
text_lengths: (Batch,)
max_length: int
"""
batch_size = text.size(0)
# For data parallel
if max_length is None:
text = text[:, :text_lengths.max()]
punc = punc[:, :text_lengths.max()]
else:
text = text[:, :max_length]
punc = punc[:, :max_length]
if self.with_vad():
# Should be VadRealtimeTransformer
assert vad_indexes is not None
y, _ = self.punc_forward(text, text_lengths, vad_indexes)
else:
# Should be TargetDelayTransformer,
y, _ = self.punc_forward(text, text_lengths)
# Calc negative log likelihood
# nll: (BxL,)
if self.training == False:
_, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
from sklearn.metrics import f1_score
f1_score = f1_score(punc.view(-1).detach().cpu().numpy(),
indices.squeeze(-1).detach().cpu().numpy(),
average='micro')
nll = torch.Tensor([f1_score]).repeat(text_lengths.sum())
return nll, text_lengths
else:
self.punc_weight = self.punc_weight.to(punc.device)
nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none",
ignore_index=self.ignore_id)
# nll: (BxL,) -> (BxL,)
if max_length is None:
nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
else:
nll.masked_fill_(
make_pad_mask(text_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
0.0,
)
# nll: (BxL,) -> (B, L)
nll = nll.view(batch_size, -1)
return nll, text_lengths
def forward(
self,
text: torch.Tensor,
punc: torch.Tensor,
text_lengths: torch.Tensor,
punc_lengths: torch.Tensor,
vad_indexes: Optional[torch.Tensor] = None,
vad_indexes_lengths: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes)
ntokens = y_lengths.sum()
loss = nll.sum() / ntokens
stats = dict(loss=loss.detach())
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
return loss, stats, weight
def generate(self,
text: torch.Tensor,
text_lengths: torch.Tensor,
vad_indexes: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, None]:
if self.with_vad():
assert vad_indexes is not None
return self.punc_forward(text, text_lengths, vad_indexes)
else:
return self.punc_forward(text, text_lengths)
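For orientation, the training branch of nll() boils down to weighted, padding-masked token cross-entropy. Below is a self-contained sketch of that computation in plain PyTorch; the shapes are illustrative and the inline comparison stands in for make_pad_mask.

import torch
import torch.nn.functional as F

batch, length, punc_size = 2, 5, 4
logits = torch.randn(batch, length, punc_size)        # y from punc_forward
punc = torch.randint(0, punc_size, (batch, length))   # punctuation targets
text_lengths = torch.tensor([5, 3])
punc_weight = torch.ones(punc_size)

nll = F.cross_entropy(logits.view(-1, punc_size), punc.view(-1),
                      punc_weight, reduction="none")
pad_mask = torch.arange(length)[None, :] >= text_lengths[:, None]  # stands in for make_pad_mask
nll = nll.masked_fill(pad_mask.reshape(-1), 0.0)
loss = nll.sum() / text_lengths.sum()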

View File

@@ -1,130 +0,0 @@
from typing import Any
from typing import List
from typing import Tuple
import torch
import torch.nn as nn
from funasr.models.transformer.embedding import SinusoidalPositionEncoder
from funasr.models.sanm.encoder import SANMEncoder as Encoder
class TargetDelayTransformer(torch.nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
https://arxiv.org/pdf/2003.01309.pdf
"""
def __init__(
self,
vocab_size: int,
punc_size: int,
pos_enc: str = None,
embed_unit: int = 128,
att_unit: int = 256,
head: int = 2,
unit: int = 1024,
layer: int = 4,
dropout_rate: float = 0.5,
):
super().__init__()
if pos_enc == "sinusoidal":
# pos_enc_class = PositionalEncoding
pos_enc_class = SinusoidalPositionEncoder
elif pos_enc is None:
def pos_enc_class(*args, **kwargs):
return nn.Sequential()  # identity
else:
raise ValueError(f"unknown pos-enc option: {pos_enc}")
self.embed = nn.Embedding(vocab_size, embed_unit)
self.encoder = Encoder(
input_size=embed_unit,
output_size=att_unit,
attention_heads=head,
linear_units=unit,
num_blocks=layer,
dropout_rate=dropout_rate,
input_layer="pe",
# pos_enc_class=pos_enc_class,
padding_idx=0,
)
self.decoder = nn.Linear(att_unit, punc_size)
# def _target_mask(self, ys_in_pad):
# ys_mask = ys_in_pad != 0
# m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0)
# return ys_mask.unsqueeze(-2) & m
def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
"""Compute loss value from buffer sequences.
Args:
input (torch.Tensor): Input ids. (batch, len)
hidden (torch.Tensor): Target ids. (batch, len)
"""
x = self.embed(input)
# mask = self._target_mask(input)
h, _, _ = self.encoder(x, text_lengths)
y = self.decoder(h)
return y, None
def with_vad(self):
return False
def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
"""Score new token.
Args:
y (torch.Tensor): 1D torch.int64 prefix tokens.
state: Scorer state for prefix tokens
x (torch.Tensor): encoder feature that generates ys.
Returns:
tuple[torch.Tensor, Any]: Tuple of
torch.float32 scores for next token (vocab_size)
and next state for ys
"""
y = y.unsqueeze(0)
h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
h = self.decoder(h[:, -1])
logp = h.log_softmax(dim=-1).squeeze(0)
return logp, cache
def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
"""Score new token batch.
Args:
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
states (List[Any]): Scorer states for prefix tokens.
xs (torch.Tensor):
The encoder feature that generates ys (n_batch, xlen, n_feat).
Returns:
tuple[torch.Tensor, List[Any]]: Tuple of
batchfied scores for next token with shape of `(n_batch, vocab_size)`
and next state list for ys.
"""
# merge states
n_batch = len(ys)
n_layers = len(self.encoder.encoders)
if states[0] is None:
batch_state = None
else:
# transpose state of [batch, layer] into [layer, batch]
batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]
# batch decoding
h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
h = self.decoder(h[:, -1])
logp = h.log_softmax(dim=-1)
# transpose state of [layer, batch] into [batch, layer]
state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
return logp, state_list
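The deleted class above hard-wired a SANMEncoder; its registry-based replacement moves the encoder choice into configuration. A sketch of that conf-driven construction with hypothetical conf values:

conf = {
    "encoder": "SANMEncoder",  # hypothetical conf values
    "encoder_conf": {"input_size": 128, "output_size": 256, "attention_heads": 2,
                     "linear_units": 1024, "num_blocks": 4},
}

def build_encoder(encoder_classes, conf):
    # mirrors `encoder_class = registry_tables.encoder_classes.get(encoder.lower())`
    encoder_class = encoder_classes.get(conf["encoder"].lower())
    return encoder_class(**conf["encoder_conf"])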

View File

@@ -6,6 +6,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from funasr.utils.register import register_class, registry_tables
class LinearTransform(nn.Module):
def __init__(self, input_dim, output_dim):
@@ -156,7 +158,7 @@ num_syn: output dimension
fsmn_layers: no. of sequential fsmn layers
'''
@register_class("encoder_classes", "FSMN")
class FSMN(nn.Module):
def __init__(
self,
@@ -227,7 +229,7 @@ lstride: left stride
rstride: right stride
'''
@register_class("encoder_classes", "DFSMN")
class DFSMN(nn.Module):
def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1):
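The @register_class decorator and registry_tables lookups used throughout this commit follow a standard name-to-class registry pattern. A minimal sketch of such a registry is shown below; it is not FunASR's actual implementation, and the FSMN class here is only a stub.

from types import SimpleNamespace

registry_tables = SimpleNamespace(model_classes={}, encoder_classes={},
                                  tokenizer_classes={}, batch_sampler_classes={})

def register_class(table_name, name):
    def decorator(cls):
        # case-insensitive lookup, matching the `.lower()` calls in the diff
        getattr(registry_tables, table_name)[name.lower()] = cls
        return cls
    return decorator

@register_class("encoder_classes", "FSMN")
class FSMN:  # stub standing in for the real encoder
    def __init__(self, **conf):
        self.conf = conf

encoder_class = registry_tables.encoder_classes.get("FSMN".lower())
encoder = encoder_class(input_dim=400, output_dim=140)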

View File

@@ -1,33 +1,244 @@
from enum import Enum
from typing import List, Tuple, Dict, Any
import logging
import os
import json
import torch
from torch import nn
import math
from typing import Optional
from funasr.models.encoder.fsmn_encoder import FSMN
from funasr.models.base_model import FunASRModel
from funasr.models.model_class_factory import *
import time
from funasr.utils.register import register_class, registry_tables
from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank
from funasr.utils.datadir_writer import DatadirWriter
from torch.nn.utils.rnn import pad_sequence
class VadStateMachine(Enum):
kVadInStateStartPointNotDetected = 1
kVadInStateInSpeechSegment = 2
kVadInStateEndPointDetected = 3
class FrameState(Enum):
kFrameStateInvalid = -1
kFrameStateSpeech = 1
kFrameStateSil = 0
# final voice/unvoice state per frame
class AudioChangeState(Enum):
kChangeStateSpeech2Speech = 0
kChangeStateSpeech2Sil = 1
kChangeStateSil2Sil = 2
kChangeStateSil2Speech = 3
kChangeStateNoBegin = 4
kChangeStateInvalid = 5
class VadDetectMode(Enum):
kVadSingleUtteranceDetectMode = 0
kVadMutipleUtteranceDetectMode = 1
class VADXOptions:
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(
self,
sample_rate: int = 16000,
detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
snr_mode: int = 0,
max_end_silence_time: int = 800,
max_start_silence_time: int = 3000,
do_start_point_detection: bool = True,
do_end_point_detection: bool = True,
window_size_ms: int = 200,
sil_to_speech_time_thres: int = 150,
speech_to_sil_time_thres: int = 150,
speech_2_noise_ratio: float = 1.0,
do_extend: int = 1,
lookback_time_start_point: int = 200,
lookahead_time_end_point: int = 100,
max_single_segment_time: int = 60000,
nn_eval_block_size: int = 8,
dcd_block_size: int = 4,
snr_thres: int = -100.0,
noise_frame_num_used_for_snr: int = 100,
decibel_thres: int = -100.0,
speech_noise_thres: float = 0.6,
fe_prior_thres: float = 1e-4,
silence_pdf_num: int = 1,
sil_pdf_ids: List[int] = [0],
speech_noise_thresh_low: float = -0.1,
speech_noise_thresh_high: float = 0.3,
output_frame_probs: bool = False,
frame_in_ms: int = 10,
frame_length_ms: int = 25,
**kwargs,
):
self.sample_rate = sample_rate
self.detect_mode = detect_mode
self.snr_mode = snr_mode
self.max_end_silence_time = max_end_silence_time
self.max_start_silence_time = max_start_silence_time
self.do_start_point_detection = do_start_point_detection
self.do_end_point_detection = do_end_point_detection
self.window_size_ms = window_size_ms
self.sil_to_speech_time_thres = sil_to_speech_time_thres
self.speech_to_sil_time_thres = speech_to_sil_time_thres
self.speech_2_noise_ratio = speech_2_noise_ratio
self.do_extend = do_extend
self.lookback_time_start_point = lookback_time_start_point
self.lookahead_time_end_point = lookahead_time_end_point
self.max_single_segment_time = max_single_segment_time
self.nn_eval_block_size = nn_eval_block_size
self.dcd_block_size = dcd_block_size
self.snr_thres = snr_thres
self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
self.decibel_thres = decibel_thres
self.speech_noise_thres = speech_noise_thres
self.fe_prior_thres = fe_prior_thres
self.silence_pdf_num = silence_pdf_num
self.sil_pdf_ids = sil_pdf_ids
self.speech_noise_thresh_low = speech_noise_thresh_low
self.speech_noise_thresh_high = speech_noise_thresh_high
self.output_frame_probs = output_frame_probs
self.frame_in_ms = frame_in_ms
self.frame_length_ms = frame_length_ms
class E2EVadSpeechBufWithDoa(object):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(self):
self.start_ms = 0
self.end_ms = 0
self.buffer = []
self.contain_seg_start_point = False
self.contain_seg_end_point = False
self.doa = 0
def Reset(self):
self.start_ms = 0
self.end_ms = 0
self.buffer = []
self.contain_seg_start_point = False
self.contain_seg_end_point = False
self.doa = 0
class E2EVadFrameProb(object):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(self):
self.noise_prob = 0.0
self.speech_prob = 0.0
self.score = 0.0
self.frame_id = 0
self.frm_state = 0
class WindowDetector(object):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(self, window_size_ms: int, sil_to_speech_time: int,
speech_to_sil_time: int, frame_size_ms: int):
self.window_size_ms = window_size_ms
self.sil_to_speech_time = sil_to_speech_time
self.speech_to_sil_time = speech_to_sil_time
self.frame_size_ms = frame_size_ms
self.win_size_frame = int(window_size_ms / frame_size_ms)
self.win_sum = 0
self.win_state = [0] * self.win_size_frame  # initialize the window
self.cur_win_pos = 0
self.pre_frame_state = FrameState.kFrameStateSil
self.cur_frame_state = FrameState.kFrameStateSil
self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
self.voice_last_frame_count = 0
self.noise_last_frame_count = 0
self.hydre_frame_count = 0
def Reset(self) -> None:
self.cur_win_pos = 0
self.win_sum = 0
self.win_state = [0] * self.win_size_frame
self.pre_frame_state = FrameState.kFrameStateSil
self.cur_frame_state = FrameState.kFrameStateSil
self.voice_last_frame_count = 0
self.noise_last_frame_count = 0
self.hydre_frame_count = 0
def GetWinSize(self) -> int:
return int(self.win_size_frame)
def DetectOneFrame(self, frameState: FrameState, frame_count: int) -> AudioChangeState:
cur_frame_state = FrameState.kFrameStateSil
if frameState == FrameState.kFrameStateSpeech:
cur_frame_state = 1
elif frameState == FrameState.kFrameStateSil:
cur_frame_state = 0
else:
return AudioChangeState.kChangeStateInvalid
self.win_sum -= self.win_state[self.cur_win_pos]
self.win_sum += cur_frame_state
self.win_state[self.cur_win_pos] = cur_frame_state
self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame
if self.pre_frame_state == FrameState.kFrameStateSil and self.win_sum >= self.sil_to_speech_frmcnt_thres:
self.pre_frame_state = FrameState.kFrameStateSpeech
return AudioChangeState.kChangeStateSil2Speech
if self.pre_frame_state == FrameState.kFrameStateSpeech and self.win_sum <= self.speech_to_sil_frmcnt_thres:
self.pre_frame_state = FrameState.kFrameStateSil
return AudioChangeState.kChangeStateSpeech2Sil
if self.pre_frame_state == FrameState.kFrameStateSil:
return AudioChangeState.kChangeStateSil2Sil
if self.pre_frame_state == FrameState.kFrameStateSpeech:
return AudioChangeState.kChangeStateSpeech2Speech
return AudioChangeState.kChangeStateInvalid
def FrameSizeMs(self) -> int:
return int(self.frame_size_ms)
@register_class("model_classes", "FsmnVAD")
class FsmnVAD(nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(self, encoder: str = None,
def __init__(self,
encoder: str = None,
encoder_conf: Optional[Dict] = None,
vad_post_args: Dict[str, Any] = None,
frontend=None):
**kwargs,
):
super().__init__()
self.vad_opts = VADXOptions(**vad_post_args)
self.vad_opts = VADXOptions(**kwargs)
self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
self.vad_opts.sil_to_speech_time_thres,
self.vad_opts.speech_to_sil_time_thres,
self.vad_opts.frame_in_ms)
encoder_class = encoder_classes.get_class(encoder)
encoder_class = registry_tables.encoder_classes.get(encoder.lower())
encoder = encoder_class(**encoder_conf)
self.encoder = encoder
# init variables
@@ -57,7 +268,6 @@ class FsmnVAD(nn.Module):
self.data_buf = None
self.data_buf_all = None
self.waveform = None
self.frontend = frontend
self.last_drop_frames = 0
def AllResetDetection(self):
@@ -239,7 +449,7 @@ class FsmnVAD(nn.Module):
vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
return vad_latency
def GetFrameState(self, t: int) -> FrameState:
def GetFrameState(self, t: int):
frame_state = FrameState.kFrameStateInvalid
cur_decibel = self.decibel[t]
cur_snr = cur_decibel - self.noise_average_decibel
@@ -285,7 +495,7 @@ class FsmnVAD(nn.Module):
def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
is_final: bool = False
) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
):
if not in_cache:
self.AllResetDetection()
self.waveform = waveform # compute decibel for each frame
@@ -313,6 +523,87 @@ class FsmnVAD(nn.Module):
self.AllResetDetection()
return segments, in_cache
def generate(self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
meta_data = {}
audio_sample_list = [data_in]
if isinstance(data_in, torch.Tensor): # fbank
speech, speech_lengths = data_in, data_lengths
if len(speech.shape) < 3:
speech = speech[None, :, :]
if speech_lengths is None:
speech_lengths = speech.shape[1]
else:
# extract fbank feats
time1 = time.perf_counter()
audio_sample_list = load_audio(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
frontend=frontend)
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
# .to() is not in-place, so the results must be reassigned
speech = speech.to(device=kwargs["device"])
speech_lengths = speech_lengths.to(device=kwargs["device"])
# b. Forward Encoder streaming
t_offset = 0
feats = speech
feats_len = speech_lengths.max().item()
waveform = pad_sequence(audio_sample_list, batch_first=True).to(device=kwargs["device"]) # data: [batch, N]
in_cache = kwargs.get("in_cache", {})
batch_size = kwargs.get("batch_size", 1)
step = min(feats_len, 6000)
segments = [[] for _ in range(batch_size)]  # independent list per utterance in the batch
for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
if t_offset + step >= feats_len - 1:
step = feats_len - t_offset
is_final = True
else:
is_final = False
batch = {
"feats": feats[:, t_offset:t_offset + step, :],
"waveform": waveform[:, t_offset * 160:min(waveform.shape[-1], (t_offset + step - 1) * 160 + 400)],
"is_final": is_final,
"in_cache": in_cache
}
segments_part, in_cache = self.forward(**batch)
if segments_part:
for batch_num in range(0, batch_size):
segments[batch_num] += segments_part[batch_num]
ibest_writer = None
if ibest_writer is None and kwargs.get("output_dir") is not None:
writer = DatadirWriter(kwargs.get("output_dir"))
ibest_writer = writer[f"{1}best_recog"]
results = []
for i in range(batch_size):
    # build the result first, then optionally serialize it for the EAS runtime
    result_i = {"key": key[i], "value": segments[i]}
    if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
        result_i = json.dumps(result_i)
    if ibest_writer is not None:
        ibest_writer["text"][key[i]] = segments[i]
    results.append(result_i)
return results, meta_data
def forward_online(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
is_final: bool = False, max_end_sil: int = 800
) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
@@ -483,207 +774,3 @@ class FsmnVAD(nn.Module):
class VadStateMachine(Enum):
kVadInStateStartPointNotDetected = 1
kVadInStateInSpeechSegment = 2
kVadInStateEndPointDetected = 3
class FrameState(Enum):
kFrameStateInvalid = -1
kFrameStateSpeech = 1
kFrameStateSil = 0
# final voice/unvoice state per frame
class AudioChangeState(Enum):
kChangeStateSpeech2Speech = 0
kChangeStateSpeech2Sil = 1
kChangeStateSil2Sil = 2
kChangeStateSil2Speech = 3
kChangeStateNoBegin = 4
kChangeStateInvalid = 5
class VadDetectMode(Enum):
kVadSingleUtteranceDetectMode = 0
kVadMutipleUtteranceDetectMode = 1
class VADXOptions:
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(
self,
sample_rate: int = 16000,
detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
snr_mode: int = 0,
max_end_silence_time: int = 800,
max_start_silence_time: int = 3000,
do_start_point_detection: bool = True,
do_end_point_detection: bool = True,
window_size_ms: int = 200,
sil_to_speech_time_thres: int = 150,
speech_to_sil_time_thres: int = 150,
speech_2_noise_ratio: float = 1.0,
do_extend: int = 1,
lookback_time_start_point: int = 200,
lookahead_time_end_point: int = 100,
max_single_segment_time: int = 60000,
nn_eval_block_size: int = 8,
dcd_block_size: int = 4,
snr_thres: int = -100.0,
noise_frame_num_used_for_snr: int = 100,
decibel_thres: int = -100.0,
speech_noise_thres: float = 0.6,
fe_prior_thres: float = 1e-4,
silence_pdf_num: int = 1,
sil_pdf_ids: List[int] = [0],
speech_noise_thresh_low: float = -0.1,
speech_noise_thresh_high: float = 0.3,
output_frame_probs: bool = False,
frame_in_ms: int = 10,
frame_length_ms: int = 25,
):
self.sample_rate = sample_rate
self.detect_mode = detect_mode
self.snr_mode = snr_mode
self.max_end_silence_time = max_end_silence_time
self.max_start_silence_time = max_start_silence_time
self.do_start_point_detection = do_start_point_detection
self.do_end_point_detection = do_end_point_detection
self.window_size_ms = window_size_ms
self.sil_to_speech_time_thres = sil_to_speech_time_thres
self.speech_to_sil_time_thres = speech_to_sil_time_thres
self.speech_2_noise_ratio = speech_2_noise_ratio
self.do_extend = do_extend
self.lookback_time_start_point = lookback_time_start_point
self.lookahead_time_end_point = lookahead_time_end_point
self.max_single_segment_time = max_single_segment_time
self.nn_eval_block_size = nn_eval_block_size
self.dcd_block_size = dcd_block_size
self.snr_thres = snr_thres
self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
self.decibel_thres = decibel_thres
self.speech_noise_thres = speech_noise_thres
self.fe_prior_thres = fe_prior_thres
self.silence_pdf_num = silence_pdf_num
self.sil_pdf_ids = sil_pdf_ids
self.speech_noise_thresh_low = speech_noise_thresh_low
self.speech_noise_thresh_high = speech_noise_thresh_high
self.output_frame_probs = output_frame_probs
self.frame_in_ms = frame_in_ms
self.frame_length_ms = frame_length_ms
class E2EVadSpeechBufWithDoa(object):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(self):
self.start_ms = 0
self.end_ms = 0
self.buffer = []
self.contain_seg_start_point = False
self.contain_seg_end_point = False
self.doa = 0
def Reset(self):
self.start_ms = 0
self.end_ms = 0
self.buffer = []
self.contain_seg_start_point = False
self.contain_seg_end_point = False
self.doa = 0
class E2EVadFrameProb(object):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(self):
self.noise_prob = 0.0
self.speech_prob = 0.0
self.score = 0.0
self.frame_id = 0
self.frm_state = 0
class WindowDetector(object):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(self, window_size_ms: int, sil_to_speech_time: int,
speech_to_sil_time: int, frame_size_ms: int):
self.window_size_ms = window_size_ms
self.sil_to_speech_time = sil_to_speech_time
self.speech_to_sil_time = speech_to_sil_time
self.frame_size_ms = frame_size_ms
self.win_size_frame = int(window_size_ms / frame_size_ms)
self.win_sum = 0
self.win_state = [0] * self.win_size_frame  # initialize the window
self.cur_win_pos = 0
self.pre_frame_state = FrameState.kFrameStateSil
self.cur_frame_state = FrameState.kFrameStateSil
self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
self.voice_last_frame_count = 0
self.noise_last_frame_count = 0
self.hydre_frame_count = 0
def Reset(self) -> None:
self.cur_win_pos = 0
self.win_sum = 0
self.win_state = [0] * self.win_size_frame
self.pre_frame_state = FrameState.kFrameStateSil
self.cur_frame_state = FrameState.kFrameStateSil
self.voice_last_frame_count = 0
self.noise_last_frame_count = 0
self.hydre_frame_count = 0
def GetWinSize(self) -> int:
return int(self.win_size_frame)
def DetectOneFrame(self, frameState: FrameState, frame_count: int) -> AudioChangeState:
cur_frame_state = FrameState.kFrameStateSil
if frameState == FrameState.kFrameStateSpeech:
cur_frame_state = 1
elif frameState == FrameState.kFrameStateSil:
cur_frame_state = 0
else:
return AudioChangeState.kChangeStateInvalid
self.win_sum -= self.win_state[self.cur_win_pos]
self.win_sum += cur_frame_state
self.win_state[self.cur_win_pos] = cur_frame_state
self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame
if self.pre_frame_state == FrameState.kFrameStateSil and self.win_sum >= self.sil_to_speech_frmcnt_thres:
self.pre_frame_state = FrameState.kFrameStateSpeech
return AudioChangeState.kChangeStateSil2Speech
if self.pre_frame_state == FrameState.kFrameStateSpeech and self.win_sum <= self.speech_to_sil_frmcnt_thres:
self.pre_frame_state = FrameState.kFrameStateSil
return AudioChangeState.kChangeStateSpeech2Sil
if self.pre_frame_state == FrameState.kFrameStateSil:
return AudioChangeState.kChangeStateSil2Sil
if self.pre_frame_state == FrameState.kFrameStateSpeech:
return AudioChangeState.kChangeStateSpeech2Speech
return AudioChangeState.kChangeStateInvalid
def FrameSizeMs(self) -> int:
return int(self.frame_size_ms)
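For orientation, WindowDetector smooths per-frame speech/silence decisions over a 200 ms window before the state machine acts on them. The usage sketch below assumes it runs in the same module as the FrameState, AudioChangeState and WindowDetector classes shown above; the frame sequence is made up.

detector = WindowDetector(window_size_ms=200, sil_to_speech_time=150,
                          speech_to_sil_time=150, frame_size_ms=10)
frame_states = [FrameState.kFrameStateSil] * 30 + [FrameState.kFrameStateSpeech] * 30
for t, state in enumerate(frame_states):
    change = detector.DetectOneFrame(state, t)
    if change == AudioChangeState.kChangeStateSil2Speech:
        # fires once roughly 150 ms of speech has accumulated in the 200 ms window
        print(f"speech onset detected around frame {t}")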

View File

@@ -42,8 +42,9 @@ class BaseTokenizer(ABC):
self.token_list_repr = str(token_list)
self.token_list: List[str] = []
with open('data.json', 'r', encoding='utf-8') as f:
self.token_list = json.loads(f.read())
with open(token_list, 'r', encoding='utf-8') as f:
self.token_list = json.load(f)
else:
self.token_list: List[str] = list(token_list)
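The fix above makes the tokenizer read the JSON file actually passed as token_list rather than a hard-coded data.json. A self-contained sketch of that load; the file name and contents are made up.

import json

token_list_path = "tokens.json"  # stand-in for the path passed as `token_list`
with open(token_list_path, "w", encoding="utf-8") as f:
    json.dump(["<blank>", "<s>", "</s>", "hello", "world"], f)

with open(token_list_path, "r", encoding="utf-8") as f:
    token_list = json.load(f)
token2id = {tok: i for i, tok in enumerate(token_list)}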

View File

@@ -120,6 +120,7 @@ def load_pretrained_model(
if ignore_init_mismatch:
src_state = filter_state_dict(dst_state, src_state)
# logging.info("Loaded src_state keys: {}".format(src_state.keys()))
logging.debug("Loaded src_state keys: {}".format(src_state.keys()))
logging.debug("Loaded dst_state keys: {}".format(dst_state.keys()))
dst_state.update(src_state)
obj.load_state_dict(dst_state)
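For context, filter_state_dict followed by dst_state.update(src_state) amounts to a shape-checked partial load when ignore_init_mismatch is set. Below is a sketch of that idea in plain PyTorch; load_matching and the toy checkpoint are hypothetical, not FunASR's helper.

import torch
import torch.nn as nn

def load_matching(model: nn.Module, src_state: dict) -> None:
    # keep only checkpoint entries whose name and shape match the model
    dst_state = model.state_dict()
    filtered = {k: v for k, v in src_state.items()
                if k in dst_state and dst_state[k].shape == v.shape}
    dst_state.update(filtered)
    model.load_state_dict(dst_state)

model = nn.Linear(4, 2)
ckpt = {"weight": torch.zeros(2, 4), "bias": torch.zeros(3)}  # mismatched bias is skipped
load_matching(model, ckpt)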

View File

@@ -10,14 +10,11 @@ from setuptools import setup
requirements = {
"install": [
# "setuptools>=38.5.1",
"humanfriendly",
"scipy>=1.4.1",
"librosa",
"jamo", # For kss
"PyYAML>=5.1.2",
# "soundfile>=0.12.1",
# "h5py>=3.1.0",
"kaldiio>=2.17.0",
"torch_complex",
# "nltk>=3.4.5",
@@ -32,7 +29,6 @@ requirements = {
# ENH
"pytorch_wpe",
"editdistance>=0.5.2",
"tensorboard",
# "g2p",
# "nara_wpe",
# PAI
@@ -44,6 +40,7 @@ requirements = {
"hdbscan",
"umap",
"jaconv",
"hydra-core",
],
# train: The modules invoked when training only.
"train": [