diff --git a/examples/industrial_data_pretraining/fsmn-vad/infer.sh b/examples/industrial_data_pretraining/fsmn-vad/infer.sh
new file mode 100644
index 000000000..9bfd8ba1a
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn-vad/infer.sh
@@ -0,0 +1,8 @@
+
+cmd="funasr/bin/inference.py"
+
+python $cmd \
++model="/Users/zhifu/Downloads/modelscope_models/speech_fsmn_vad_zh-cn-16k-common-pytorch" \
++input="/Users/zhifu/Downloads/asr_example.wav" \
++output_dir="/Users/zhifu/Downloads/ckpt/funasr2/exp2_vad" \
++device="cpu" \
diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py
index fd884cdf6..50ea4d4b4 100644
--- a/funasr/bin/inference.py
+++ b/funasr/bin/inference.py
@@ -101,6 +101,7 @@ class AutoModel:
         tokenizer_class = registry_tables.tokenizer_classes.get(tokenizer.lower())
         tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
         kwargs["tokenizer"] = tokenizer
+        kwargs["token_list"] = tokenizer.token_list
 
         # build frontend
         frontend = kwargs.get("frontend", None)
@@ -112,12 +113,10 @@ class AutoModel:
 
         # build model
         model_class = registry_tables.model_classes.get(kwargs["model"].lower())
-        model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
+        model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list) if tokenizer is not None else -1)
         model.eval()
         model.to(device)
 
-        kwargs["token_list"] = tokenizer.token_list
-
         # init_param
         init_param = kwargs.get("init_param", None)
         if init_param is not None:
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 8112002b3..1e06c5037 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -145,7 +145,8 @@ def main(**kwargs):
     # dataloader
     batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
     batch_sampler_class = registry_tables.batch_sampler_classes.get(batch_sampler.lower())
-    batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
+    if batch_sampler is not None:
+        batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
     dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
                                                 collate_fn=dataset_tr.collator,
                                                 batch_sampler=batch_sampler,
@@ -153,7 +154,6 @@ def main(**kwargs):
                                                 pin_memory=True)
 
 
-
     trainer = Trainer(
         model=model,
         optim=optim,
diff --git a/funasr/datasets/audio_datasets/datasets.py b/funasr/datasets/audio_datasets/datasets.py
index 353a3a0b9..d69d0b573 100644
--- a/funasr/datasets/audio_datasets/datasets.py
+++ b/funasr/datasets/audio_datasets/datasets.py
@@ -24,6 +24,17 @@ class AudioDataset(torch.utils.data.Dataset):
         super().__init__()
         index_ds_class = registry_tables.index_ds_classes.get(index_ds.lower())
         self.index_ds = index_ds_class(path)
+        preprocessor_speech = kwargs.get("preprocessor_speech", None)
+        if preprocessor_speech:
+            preprocessor_speech_class = registry_tables.preprocessor_speech_classes.get(preprocessor_speech.lower())
+            preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
+        self.preprocessor_speech = preprocessor_speech
+        preprocessor_text = kwargs.get("preprocessor_text", None)
+        if preprocessor_text:
+            preprocessor_text_class = registry_tables.preprocessor_text_classes.get(preprocessor_text.lower())
+            preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
+        self.preprocessor_text = preprocessor_text
+
         self.frontend = frontend
         self.fs = 16000 if frontend is None else frontend.fs
         self.data_type = "sound"
@@ -49,8 +60,13 @@ class AudioDataset(torch.utils.data.Dataset):
         # 
pdb.set_trace() source = item["source"] data_src = load_audio(source, fs=self.fs) + if self.preprocessor_speech: + data_src = self.preprocessor_speech(data_src) speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend) # speech: [b, T, d] + target = item["target"] + if self.preprocessor_text: + target = self.preprocessor_text(target) ids = self.tokenizer.encode(target) ids_lengths = len(ids) text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32) diff --git a/funasr/models/ct_transformer/sanm_encoder.py b/funasr/models/ct_transformer/encoder.py similarity index 100% rename from funasr/models/ct_transformer/sanm_encoder.py rename to funasr/models/ct_transformer/encoder.py diff --git a/funasr/models/ct_transformer/model.py b/funasr/models/ct_transformer/model.py new file mode 100644 index 000000000..31b2af2aa --- /dev/null +++ b/funasr/models/ct_transformer/model.py @@ -0,0 +1,212 @@ +from typing import Any +from typing import List +from typing import Tuple + +import torch +import torch.nn as nn + +from funasr.utils.register import register_class, registry_tables + +@register_class("model_classes", "CTTransformer") +class CTTransformer(nn.Module): + """ + Author: Speech Lab of DAMO Academy, Alibaba Group + CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection + https://arxiv.org/pdf/2003.01309.pdf + """ + def __init__( + self, + encoder: str = None, + encoder_conf: str = None, + vocab_size: int = -1, + punc_list: list = None, + punc_weight: list = None, + embed_unit: int = 128, + att_unit: int = 256, + dropout_rate: float = 0.5, + ignore_id: int = -1, + sos: int = 1, + eos: int = 2, + **kwargs, + ): + super().__init__() + + punc_size = len(punc_list) + if punc_weight is None: + punc_weight = [1] * punc_size + + + self.embed = nn.Embedding(vocab_size, embed_unit) + encoder_class = registry_tables.encoder_classes.get(encoder.lower()) + encoder = encoder_class(**encoder_conf) + + self.decoder = nn.Linear(att_unit, punc_size) + self.encoder = encoder + self.punc_list = punc_list + self.punc_weight = punc_weight + self.ignore_id = ignore_id + self.sos = sos + self.eos = eos + + + + def punc_forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]: + """Compute loss value from buffer sequences. + + Args: + input (torch.Tensor): Input ids. (batch, len) + hidden (torch.Tensor): Target ids. (batch, len) + + """ + x = self.embed(input) + # mask = self._target_mask(input) + h, _, _ = self.encoder(x, text_lengths) + y = self.decoder(h) + return y, None + + def with_vad(self): + return False + + def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]: + """Score new token. + + Args: + y (torch.Tensor): 1D torch.int64 prefix tokens. + state: Scorer state for prefix tokens + x (torch.Tensor): encoder feature that generates ys. + + Returns: + tuple[torch.Tensor, Any]: Tuple of + torch.float32 scores for next token (vocab_size) + and next state for ys + + """ + y = y.unsqueeze(0) + h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state) + h = self.decoder(h[:, -1]) + logp = h.log_softmax(dim=-1).squeeze(0) + return logp, cache + + def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]: + """Score new token batch. + + Args: + ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). 
+ states (List[Any]): Scorer states for prefix tokens. + xs (torch.Tensor): + The encoder feature that generates ys (n_batch, xlen, n_feat). + + Returns: + tuple[torch.Tensor, List[Any]]: Tuple of + batchfied scores for next token with shape of `(n_batch, vocab_size)` + and next state list for ys. + + """ + # merge states + n_batch = len(ys) + n_layers = len(self.encoder.encoders) + if states[0] is None: + batch_state = None + else: + # transpose state of [batch, layer] into [layer, batch] + batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)] + + # batch decoding + h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state) + h = self.decoder(h[:, -1]) + logp = h.log_softmax(dim=-1) + + # transpose state of [layer, batch] into [batch, layer] + state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)] + return logp, state_list + + def nll( + self, + text: torch.Tensor, + punc: torch.Tensor, + text_lengths: torch.Tensor, + punc_lengths: torch.Tensor, + max_length: Optional[int] = None, + vad_indexes: Optional[torch.Tensor] = None, + vad_indexes_lengths: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute negative log likelihood(nll) + + Normally, this function is called in batchify_nll. + Args: + text: (Batch, Length) + punc: (Batch, Length) + text_lengths: (Batch,) + max_lengths: int + """ + batch_size = text.size(0) + # For data parallel + if max_length is None: + text = text[:, :text_lengths.max()] + punc = punc[:, :text_lengths.max()] + else: + text = text[:, :max_length] + punc = punc[:, :max_length] + + if self.with_vad(): + # Should be VadRealtimeTransformer + assert vad_indexes is not None + y, _ = self.punc_forward(text, text_lengths, vad_indexes) + else: + # Should be TargetDelayTransformer, + y, _ = self.punc_forward(text, text_lengths) + + # Calc negative log likelihood + # nll: (BxL,) + if self.training == False: + _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1) + from sklearn.metrics import f1_score + f1_score = f1_score(punc.view(-1).detach().cpu().numpy(), + indices.squeeze(-1).detach().cpu().numpy(), + average='micro') + nll = torch.Tensor([f1_score]).repeat(text_lengths.sum()) + return nll, text_lengths + else: + self.punc_weight = self.punc_weight.to(punc.device) + nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none", + ignore_index=self.ignore_id) + # nll: (BxL,) -> (BxL,) + if max_length is None: + nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0) + else: + nll.masked_fill_( + make_pad_mask(text_lengths, maxlen=max_length + 1).to(nll.device).view(-1), + 0.0, + ) + # nll: (BxL,) -> (B, L) + nll = nll.view(batch_size, -1) + return nll, text_lengths + + + def forward( + self, + text: torch.Tensor, + punc: torch.Tensor, + text_lengths: torch.Tensor, + punc_lengths: torch.Tensor, + vad_indexes: Optional[torch.Tensor] = None, + vad_indexes_lengths: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes) + ntokens = y_lengths.sum() + loss = nll.sum() / ntokens + stats = dict(loss=loss.detach()) + + # force_gatherable: to-device and to-tensor if scalar for DataParallel + loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device) + return loss, stats, weight + + def generate(self, + text: torch.Tensor, + text_lengths: 
torch.Tensor, + vad_indexes: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, None]: + if self.with_vad(): + assert vad_indexes is not None + return self.punc_forward(text, text_lengths, vad_indexes) + else: + return self.punc_forward(text, text_lengths) \ No newline at end of file diff --git a/funasr/models/ct_transformer/target_delay_transformer.py b/funasr/models/ct_transformer/target_delay_transformer.py deleted file mode 100644 index 59884a3a9..000000000 --- a/funasr/models/ct_transformer/target_delay_transformer.py +++ /dev/null @@ -1,130 +0,0 @@ -from typing import Any -from typing import List -from typing import Tuple - -import torch -import torch.nn as nn - -from funasr.models.transformer.embedding import SinusoidalPositionEncoder -from funasr.models.sanm.encoder import SANMEncoder as Encoder - - -class TargetDelayTransformer(torch.nn.Module): - """ - Author: Speech Lab of DAMO Academy, Alibaba Group - CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection - https://arxiv.org/pdf/2003.01309.pdf - """ - def __init__( - self, - vocab_size: int, - punc_size: int, - pos_enc: str = None, - embed_unit: int = 128, - att_unit: int = 256, - head: int = 2, - unit: int = 1024, - layer: int = 4, - dropout_rate: float = 0.5, - ): - super().__init__() - if pos_enc == "sinusoidal": - # pos_enc_class = PositionalEncoding - pos_enc_class = SinusoidalPositionEncoder - elif pos_enc is None: - - def pos_enc_class(*args, **kwargs): - return nn.Sequential() # indentity - - else: - raise ValueError(f"unknown pos-enc option: {pos_enc}") - - self.embed = nn.Embedding(vocab_size, embed_unit) - self.encoder = Encoder( - input_size=embed_unit, - output_size=att_unit, - attention_heads=head, - linear_units=unit, - num_blocks=layer, - dropout_rate=dropout_rate, - input_layer="pe", - # pos_enc_class=pos_enc_class, - padding_idx=0, - ) - self.decoder = nn.Linear(att_unit, punc_size) - - -# def _target_mask(self, ys_in_pad): -# ys_mask = ys_in_pad != 0 -# m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0) -# return ys_mask.unsqueeze(-2) & m - - def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]: - """Compute loss value from buffer sequences. - - Args: - input (torch.Tensor): Input ids. (batch, len) - hidden (torch.Tensor): Target ids. (batch, len) - - """ - x = self.embed(input) - # mask = self._target_mask(input) - h, _, _ = self.encoder(x, text_lengths) - y = self.decoder(h) - return y, None - - def with_vad(self): - return False - - def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]: - """Score new token. - - Args: - y (torch.Tensor): 1D torch.int64 prefix tokens. - state: Scorer state for prefix tokens - x (torch.Tensor): encoder feature that generates ys. - - Returns: - tuple[torch.Tensor, Any]: Tuple of - torch.float32 scores for next token (vocab_size) - and next state for ys - - """ - y = y.unsqueeze(0) - h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state) - h = self.decoder(h[:, -1]) - logp = h.log_softmax(dim=-1).squeeze(0) - return logp, cache - - def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]: - """Score new token batch. - - Args: - ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). - states (List[Any]): Scorer states for prefix tokens. 
- xs (torch.Tensor): - The encoder feature that generates ys (n_batch, xlen, n_feat). - - Returns: - tuple[torch.Tensor, List[Any]]: Tuple of - batchfied scores for next token with shape of `(n_batch, vocab_size)` - and next state list for ys. - - """ - # merge states - n_batch = len(ys) - n_layers = len(self.encoder.encoders) - if states[0] is None: - batch_state = None - else: - # transpose state of [batch, layer] into [layer, batch] - batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)] - - # batch decoding - h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state) - h = self.decoder(h[:, -1]) - logp = h.log_softmax(dim=-1) - - # transpose state of [layer, batch] into [batch, layer] - state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)] - return logp, state_list diff --git a/funasr/models/fsmn_vad/fsmn_encoder.py b/funasr/models/fsmn_vad/encoder.py similarity index 98% rename from funasr/models/fsmn_vad/fsmn_encoder.py rename to funasr/models/fsmn_vad/encoder.py index 38d164dfe..50e31fc32 100755 --- a/funasr/models/fsmn_vad/fsmn_encoder.py +++ b/funasr/models/fsmn_vad/encoder.py @@ -6,6 +6,8 @@ import torch import torch.nn as nn import torch.nn.functional as F +from funasr.utils.register import register_class, registry_tables + class LinearTransform(nn.Module): def __init__(self, input_dim, output_dim): @@ -156,7 +158,7 @@ num_syn: output dimension fsmn_layers: no. of sequential fsmn layers ''' - +@register_class("encoder_classes", "FSMN") class FSMN(nn.Module): def __init__( self, @@ -227,7 +229,7 @@ lstride: left stride rstride: right stride ''' - +@register_class("encoder_classes", "DFSMN") class DFSMN(nn.Module): def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1): diff --git a/funasr/models/fsmn_vad/model.py b/funasr/models/fsmn_vad/model.py index cc3c87e3f..16f21dca9 100644 --- a/funasr/models/fsmn_vad/model.py +++ b/funasr/models/fsmn_vad/model.py @@ -1,33 +1,244 @@ from enum import Enum from typing import List, Tuple, Dict, Any - +import logging +import os +import json import torch from torch import nn import math from typing import Optional -from funasr.models.encoder.fsmn_encoder import FSMN -from funasr.models.base_model import FunASRModel -from funasr.models.model_class_factory import * +import time +from funasr.utils.register import register_class, registry_tables +from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio,extract_fbank +from funasr.utils.datadir_writer import DatadirWriter +from torch.nn.utils.rnn import pad_sequence + +class VadStateMachine(Enum): + kVadInStateStartPointNotDetected = 1 + kVadInStateInSpeechSegment = 2 + kVadInStateEndPointDetected = 3 +class FrameState(Enum): + kFrameStateInvalid = -1 + kFrameStateSpeech = 1 + kFrameStateSil = 0 + + +# final voice/unvoice state per frame +class AudioChangeState(Enum): + kChangeStateSpeech2Speech = 0 + kChangeStateSpeech2Sil = 1 + kChangeStateSil2Sil = 2 + kChangeStateSil2Speech = 3 + kChangeStateNoBegin = 4 + kChangeStateInvalid = 5 + + +class VadDetectMode(Enum): + kVadSingleUtteranceDetectMode = 0 + kVadMutipleUtteranceDetectMode = 1 + + +class VADXOptions: + """ + Author: Speech Lab of DAMO Academy, Alibaba Group + Deep-FSMN for Large Vocabulary Continuous Speech Recognition + https://arxiv.org/abs/1803.05030 + """ + def __init__( + self, + sample_rate: int = 16000, + detect_mode: int = 
VadDetectMode.kVadMutipleUtteranceDetectMode.value, + snr_mode: int = 0, + max_end_silence_time: int = 800, + max_start_silence_time: int = 3000, + do_start_point_detection: bool = True, + do_end_point_detection: bool = True, + window_size_ms: int = 200, + sil_to_speech_time_thres: int = 150, + speech_to_sil_time_thres: int = 150, + speech_2_noise_ratio: float = 1.0, + do_extend: int = 1, + lookback_time_start_point: int = 200, + lookahead_time_end_point: int = 100, + max_single_segment_time: int = 60000, + nn_eval_block_size: int = 8, + dcd_block_size: int = 4, + snr_thres: int = -100.0, + noise_frame_num_used_for_snr: int = 100, + decibel_thres: int = -100.0, + speech_noise_thres: float = 0.6, + fe_prior_thres: float = 1e-4, + silence_pdf_num: int = 1, + sil_pdf_ids: List[int] = [0], + speech_noise_thresh_low: float = -0.1, + speech_noise_thresh_high: float = 0.3, + output_frame_probs: bool = False, + frame_in_ms: int = 10, + frame_length_ms: int = 25, + **kwargs, + ): + self.sample_rate = sample_rate + self.detect_mode = detect_mode + self.snr_mode = snr_mode + self.max_end_silence_time = max_end_silence_time + self.max_start_silence_time = max_start_silence_time + self.do_start_point_detection = do_start_point_detection + self.do_end_point_detection = do_end_point_detection + self.window_size_ms = window_size_ms + self.sil_to_speech_time_thres = sil_to_speech_time_thres + self.speech_to_sil_time_thres = speech_to_sil_time_thres + self.speech_2_noise_ratio = speech_2_noise_ratio + self.do_extend = do_extend + self.lookback_time_start_point = lookback_time_start_point + self.lookahead_time_end_point = lookahead_time_end_point + self.max_single_segment_time = max_single_segment_time + self.nn_eval_block_size = nn_eval_block_size + self.dcd_block_size = dcd_block_size + self.snr_thres = snr_thres + self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr + self.decibel_thres = decibel_thres + self.speech_noise_thres = speech_noise_thres + self.fe_prior_thres = fe_prior_thres + self.silence_pdf_num = silence_pdf_num + self.sil_pdf_ids = sil_pdf_ids + self.speech_noise_thresh_low = speech_noise_thresh_low + self.speech_noise_thresh_high = speech_noise_thresh_high + self.output_frame_probs = output_frame_probs + self.frame_in_ms = frame_in_ms + self.frame_length_ms = frame_length_ms + + +class E2EVadSpeechBufWithDoa(object): + """ + Author: Speech Lab of DAMO Academy, Alibaba Group + Deep-FSMN for Large Vocabulary Continuous Speech Recognition + https://arxiv.org/abs/1803.05030 + """ + def __init__(self): + self.start_ms = 0 + self.end_ms = 0 + self.buffer = [] + self.contain_seg_start_point = False + self.contain_seg_end_point = False + self.doa = 0 + + def Reset(self): + self.start_ms = 0 + self.end_ms = 0 + self.buffer = [] + self.contain_seg_start_point = False + self.contain_seg_end_point = False + self.doa = 0 + + +class E2EVadFrameProb(object): + """ + Author: Speech Lab of DAMO Academy, Alibaba Group + Deep-FSMN for Large Vocabulary Continuous Speech Recognition + https://arxiv.org/abs/1803.05030 + """ + def __init__(self): + self.noise_prob = 0.0 + self.speech_prob = 0.0 + self.score = 0.0 + self.frame_id = 0 + self.frm_state = 0 + + +class WindowDetector(object): + """ + Author: Speech Lab of DAMO Academy, Alibaba Group + Deep-FSMN for Large Vocabulary Continuous Speech Recognition + https://arxiv.org/abs/1803.05030 + """ + def __init__(self, window_size_ms: int, sil_to_speech_time: int, + speech_to_sil_time: int, frame_size_ms: int): + self.window_size_ms = window_size_ms + 
self.sil_to_speech_time = sil_to_speech_time + self.speech_to_sil_time = speech_to_sil_time + self.frame_size_ms = frame_size_ms + + self.win_size_frame = int(window_size_ms / frame_size_ms) + self.win_sum = 0 + self.win_state = [0] * self.win_size_frame # 初始化窗 + + self.cur_win_pos = 0 + self.pre_frame_state = FrameState.kFrameStateSil + self.cur_frame_state = FrameState.kFrameStateSil + self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms) + self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms) + + self.voice_last_frame_count = 0 + self.noise_last_frame_count = 0 + self.hydre_frame_count = 0 + + def Reset(self) -> None: + self.cur_win_pos = 0 + self.win_sum = 0 + self.win_state = [0] * self.win_size_frame + self.pre_frame_state = FrameState.kFrameStateSil + self.cur_frame_state = FrameState.kFrameStateSil + self.voice_last_frame_count = 0 + self.noise_last_frame_count = 0 + self.hydre_frame_count = 0 + + def GetWinSize(self) -> int: + return int(self.win_size_frame) + + def DetectOneFrame(self, frameState: FrameState, frame_count: int) -> AudioChangeState: + cur_frame_state = FrameState.kFrameStateSil + if frameState == FrameState.kFrameStateSpeech: + cur_frame_state = 1 + elif frameState == FrameState.kFrameStateSil: + cur_frame_state = 0 + else: + return AudioChangeState.kChangeStateInvalid + self.win_sum -= self.win_state[self.cur_win_pos] + self.win_sum += cur_frame_state + self.win_state[self.cur_win_pos] = cur_frame_state + self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame + + if self.pre_frame_state == FrameState.kFrameStateSil and self.win_sum >= self.sil_to_speech_frmcnt_thres: + self.pre_frame_state = FrameState.kFrameStateSpeech + return AudioChangeState.kChangeStateSil2Speech + + if self.pre_frame_state == FrameState.kFrameStateSpeech and self.win_sum <= self.speech_to_sil_frmcnt_thres: + self.pre_frame_state = FrameState.kFrameStateSil + return AudioChangeState.kChangeStateSpeech2Sil + + if self.pre_frame_state == FrameState.kFrameStateSil: + return AudioChangeState.kChangeStateSil2Sil + if self.pre_frame_state == FrameState.kFrameStateSpeech: + return AudioChangeState.kChangeStateSpeech2Speech + return AudioChangeState.kChangeStateInvalid + + def FrameSizeMs(self) -> int: + return int(self.frame_size_ms) + + +@register_class("model_classes", "FsmnVAD") class FsmnVAD(nn.Module): """ Author: Speech Lab of DAMO Academy, Alibaba Group Deep-FSMN for Large Vocabulary Continuous Speech Recognition https://arxiv.org/abs/1803.05030 """ - def __init__(self, encoder: str = None, + def __init__(self, + encoder: str = None, encoder_conf: Optional[Dict] = None, vad_post_args: Dict[str, Any] = None, - frontend=None): + **kwargs, + ): super().__init__() - self.vad_opts = VADXOptions(**vad_post_args) + self.vad_opts = VADXOptions(**kwargs) self.windows_detector = WindowDetector(self.vad_opts.window_size_ms, self.vad_opts.sil_to_speech_time_thres, self.vad_opts.speech_to_sil_time_thres, self.vad_opts.frame_in_ms) - encoder_class = encoder_classes.get_class(encoder) + encoder_class = registry_tables.encoder_classes.get(encoder.lower()) encoder = encoder_class(**encoder_conf) self.encoder = encoder # init variables @@ -57,7 +268,6 @@ class FsmnVAD(nn.Module): self.data_buf = None self.data_buf_all = None self.waveform = None - self.frontend = frontend self.last_drop_frames = 0 def AllResetDetection(self): @@ -239,7 +449,7 @@ class FsmnVAD(nn.Module): vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms) 
return vad_latency - def GetFrameState(self, t: int) -> FrameState: + def GetFrameState(self, t: int): frame_state = FrameState.kFrameStateInvalid cur_decibel = self.decibel[t] cur_snr = cur_decibel - self.noise_average_decibel @@ -285,7 +495,7 @@ class FsmnVAD(nn.Module): def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(), is_final: bool = False - ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]: + ): if not in_cache: self.AllResetDetection() self.waveform = waveform # compute decibel for each frame @@ -313,6 +523,87 @@ class FsmnVAD(nn.Module): self.AllResetDetection() return segments, in_cache + def generate(self, + data_in, + data_lengths=None, + key: list = None, + tokenizer=None, + frontend=None, + **kwargs, + ): + + + meta_data = {} + audio_sample_list = [data_in] + if isinstance(data_in, torch.Tensor): # fbank + speech, speech_lengths = data_in, data_lengths + if len(speech.shape) < 3: + speech = speech[None, :, :] + if speech_lengths is None: + speech_lengths = speech.shape[1] + else: + # extract fbank feats + time1 = time.perf_counter() + audio_sample_list = load_audio(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000)) + time2 = time.perf_counter() + meta_data["load_data"] = f"{time2 - time1:0.3f}" + speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), + frontend=frontend) + time3 = time.perf_counter() + meta_data["extract_feat"] = f"{time3 - time2:0.3f}" + meta_data[ + "batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 + + speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"]) + + # b. Forward Encoder streaming + t_offset = 0 + feats = speech + feats_len = speech_lengths.max().item() + waveform = pad_sequence(audio_sample_list, batch_first=True).to(device=kwargs["device"]) # data: [batch, N] + in_cache = kwargs.get("in_cache", {}) + batch_size = kwargs.get("batch_size", 1) + step = min(feats_len, 6000) + segments = [[]] * batch_size + + for t_offset in range(0, feats_len, min(step, feats_len - t_offset)): + if t_offset + step >= feats_len - 1: + step = feats_len - t_offset + is_final = True + else: + is_final = False + batch = { + "feats": feats[:, t_offset:t_offset + step, :], + "waveform": waveform[:, t_offset * 160:min(waveform.shape[-1], (t_offset + step - 1) * 160 + 400)], + "is_final": is_final, + "in_cache": in_cache + } + + + segments_part, in_cache = self.forward(**batch) + if segments_part: + for batch_num in range(0, batch_size): + segments[batch_num] += segments_part[batch_num] + + ibest_writer = None + if ibest_writer is None and kwargs.get("output_dir") is not None: + writer = DatadirWriter(kwargs.get("output_dir")) + ibest_writer = writer[f"{1}best_recog"] + + results = [] + for i in range(batch_size): + + if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas": + results[i] = json.dumps(results[i]) + + if ibest_writer is not None: + ibest_writer["text"][key[i]] = segments[i] + + result_i = {"key": key[i], "value": segments[i]} + results.append(result_i) + + return results, meta_data + def forward_online(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(), is_final: bool = False, max_end_sil: int = 800 ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]: @@ -483,207 +774,3 @@ class FsmnVAD(nn.Module): -class VadStateMachine(Enum): - kVadInStateStartPointNotDetected = 1 - 
kVadInStateInSpeechSegment = 2 - kVadInStateEndPointDetected = 3 - - -class FrameState(Enum): - kFrameStateInvalid = -1 - kFrameStateSpeech = 1 - kFrameStateSil = 0 - - -# final voice/unvoice state per frame -class AudioChangeState(Enum): - kChangeStateSpeech2Speech = 0 - kChangeStateSpeech2Sil = 1 - kChangeStateSil2Sil = 2 - kChangeStateSil2Speech = 3 - kChangeStateNoBegin = 4 - kChangeStateInvalid = 5 - - -class VadDetectMode(Enum): - kVadSingleUtteranceDetectMode = 0 - kVadMutipleUtteranceDetectMode = 1 - - -class VADXOptions: - """ - Author: Speech Lab of DAMO Academy, Alibaba Group - Deep-FSMN for Large Vocabulary Continuous Speech Recognition - https://arxiv.org/abs/1803.05030 - """ - def __init__( - self, - sample_rate: int = 16000, - detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value, - snr_mode: int = 0, - max_end_silence_time: int = 800, - max_start_silence_time: int = 3000, - do_start_point_detection: bool = True, - do_end_point_detection: bool = True, - window_size_ms: int = 200, - sil_to_speech_time_thres: int = 150, - speech_to_sil_time_thres: int = 150, - speech_2_noise_ratio: float = 1.0, - do_extend: int = 1, - lookback_time_start_point: int = 200, - lookahead_time_end_point: int = 100, - max_single_segment_time: int = 60000, - nn_eval_block_size: int = 8, - dcd_block_size: int = 4, - snr_thres: int = -100.0, - noise_frame_num_used_for_snr: int = 100, - decibel_thres: int = -100.0, - speech_noise_thres: float = 0.6, - fe_prior_thres: float = 1e-4, - silence_pdf_num: int = 1, - sil_pdf_ids: List[int] = [0], - speech_noise_thresh_low: float = -0.1, - speech_noise_thresh_high: float = 0.3, - output_frame_probs: bool = False, - frame_in_ms: int = 10, - frame_length_ms: int = 25, - ): - self.sample_rate = sample_rate - self.detect_mode = detect_mode - self.snr_mode = snr_mode - self.max_end_silence_time = max_end_silence_time - self.max_start_silence_time = max_start_silence_time - self.do_start_point_detection = do_start_point_detection - self.do_end_point_detection = do_end_point_detection - self.window_size_ms = window_size_ms - self.sil_to_speech_time_thres = sil_to_speech_time_thres - self.speech_to_sil_time_thres = speech_to_sil_time_thres - self.speech_2_noise_ratio = speech_2_noise_ratio - self.do_extend = do_extend - self.lookback_time_start_point = lookback_time_start_point - self.lookahead_time_end_point = lookahead_time_end_point - self.max_single_segment_time = max_single_segment_time - self.nn_eval_block_size = nn_eval_block_size - self.dcd_block_size = dcd_block_size - self.snr_thres = snr_thres - self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr - self.decibel_thres = decibel_thres - self.speech_noise_thres = speech_noise_thres - self.fe_prior_thres = fe_prior_thres - self.silence_pdf_num = silence_pdf_num - self.sil_pdf_ids = sil_pdf_ids - self.speech_noise_thresh_low = speech_noise_thresh_low - self.speech_noise_thresh_high = speech_noise_thresh_high - self.output_frame_probs = output_frame_probs - self.frame_in_ms = frame_in_ms - self.frame_length_ms = frame_length_ms - - -class E2EVadSpeechBufWithDoa(object): - """ - Author: Speech Lab of DAMO Academy, Alibaba Group - Deep-FSMN for Large Vocabulary Continuous Speech Recognition - https://arxiv.org/abs/1803.05030 - """ - def __init__(self): - self.start_ms = 0 - self.end_ms = 0 - self.buffer = [] - self.contain_seg_start_point = False - self.contain_seg_end_point = False - self.doa = 0 - - def Reset(self): - self.start_ms = 0 - self.end_ms = 0 - self.buffer = [] - 
self.contain_seg_start_point = False - self.contain_seg_end_point = False - self.doa = 0 - - -class E2EVadFrameProb(object): - """ - Author: Speech Lab of DAMO Academy, Alibaba Group - Deep-FSMN for Large Vocabulary Continuous Speech Recognition - https://arxiv.org/abs/1803.05030 - """ - def __init__(self): - self.noise_prob = 0.0 - self.speech_prob = 0.0 - self.score = 0.0 - self.frame_id = 0 - self.frm_state = 0 - - -class WindowDetector(object): - """ - Author: Speech Lab of DAMO Academy, Alibaba Group - Deep-FSMN for Large Vocabulary Continuous Speech Recognition - https://arxiv.org/abs/1803.05030 - """ - def __init__(self, window_size_ms: int, sil_to_speech_time: int, - speech_to_sil_time: int, frame_size_ms: int): - self.window_size_ms = window_size_ms - self.sil_to_speech_time = sil_to_speech_time - self.speech_to_sil_time = speech_to_sil_time - self.frame_size_ms = frame_size_ms - - self.win_size_frame = int(window_size_ms / frame_size_ms) - self.win_sum = 0 - self.win_state = [0] * self.win_size_frame # 初始化窗 - - self.cur_win_pos = 0 - self.pre_frame_state = FrameState.kFrameStateSil - self.cur_frame_state = FrameState.kFrameStateSil - self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms) - self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms) - - self.voice_last_frame_count = 0 - self.noise_last_frame_count = 0 - self.hydre_frame_count = 0 - - def Reset(self) -> None: - self.cur_win_pos = 0 - self.win_sum = 0 - self.win_state = [0] * self.win_size_frame - self.pre_frame_state = FrameState.kFrameStateSil - self.cur_frame_state = FrameState.kFrameStateSil - self.voice_last_frame_count = 0 - self.noise_last_frame_count = 0 - self.hydre_frame_count = 0 - - def GetWinSize(self) -> int: - return int(self.win_size_frame) - - def DetectOneFrame(self, frameState: FrameState, frame_count: int) -> AudioChangeState: - cur_frame_state = FrameState.kFrameStateSil - if frameState == FrameState.kFrameStateSpeech: - cur_frame_state = 1 - elif frameState == FrameState.kFrameStateSil: - cur_frame_state = 0 - else: - return AudioChangeState.kChangeStateInvalid - self.win_sum -= self.win_state[self.cur_win_pos] - self.win_sum += cur_frame_state - self.win_state[self.cur_win_pos] = cur_frame_state - self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame - - if self.pre_frame_state == FrameState.kFrameStateSil and self.win_sum >= self.sil_to_speech_frmcnt_thres: - self.pre_frame_state = FrameState.kFrameStateSpeech - return AudioChangeState.kChangeStateSil2Speech - - if self.pre_frame_state == FrameState.kFrameStateSpeech and self.win_sum <= self.speech_to_sil_frmcnt_thres: - self.pre_frame_state = FrameState.kFrameStateSil - return AudioChangeState.kChangeStateSpeech2Sil - - if self.pre_frame_state == FrameState.kFrameStateSil: - return AudioChangeState.kChangeStateSil2Sil - if self.pre_frame_state == FrameState.kFrameStateSpeech: - return AudioChangeState.kChangeStateSpeech2Speech - return AudioChangeState.kChangeStateInvalid - - def FrameSizeMs(self) -> int: - return int(self.frame_size_ms) - - diff --git a/funasr/tokenizer/abs_tokenizer.py b/funasr/tokenizer/abs_tokenizer.py index 349ebc0ed..d43d7b2b7 100644 --- a/funasr/tokenizer/abs_tokenizer.py +++ b/funasr/tokenizer/abs_tokenizer.py @@ -42,8 +42,9 @@ class BaseTokenizer(ABC): self.token_list_repr = str(token_list) self.token_list: List[str] = [] - with open('data.json', 'r', encoding='utf-8') as f: - self.token_list = json.loads(f.read()) + with open(token_list, 'r', encoding='utf-8') as f: + 
                self.token_list = json.load(f)
+
         else:
             self.token_list: List[str] = list(token_list)
diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py
index b54f7771f..963d734d3 100644
--- a/funasr/train_utils/load_pretrained_model.py
+++ b/funasr/train_utils/load_pretrained_model.py
@@ -120,6 +120,7 @@ def load_pretrained_model(
     if ignore_init_mismatch:
         src_state = filter_state_dict(dst_state, src_state)
 
-    # logging.info("Loaded src_state keys: {}".format(src_state.keys()))
+    logging.debug("Loaded src_state keys: {}".format(src_state.keys()))
+    logging.debug("Loaded dst_state keys: {}".format(dst_state.keys()))
     dst_state.update(src_state)
     obj.load_state_dict(dst_state)
diff --git a/setup.py b/setup.py
index a1e47af04..ecd3d3dc8 100644
--- a/setup.py
+++ b/setup.py
@@ -10,14 +10,11 @@ from setuptools import setup
 
 requirements = {
     "install": [
-        # "setuptools>=38.5.1",
-        "humanfriendly",
         "scipy>=1.4.1",
         "librosa",
         "jamo",  # For kss
         "PyYAML>=5.1.2",
         # "soundfile>=0.12.1",
-        # "h5py>=3.1.0",
         "kaldiio>=2.17.0",
         "torch_complex",
         # "nltk>=3.4.5",
@@ -32,7 +29,6 @@ requirements = {
         # ENH
         "pytorch_wpe",
         "editdistance>=0.5.2",
-        "tensorboard",
         # "g2p",
         # "nara_wpe",
         # PAI
@@ -44,6 +40,7 @@ requirements = {
         "hdbscan",
         "umap",
         "jaconv",
+        "hydra-core",
     ],
     # train: The modules invoked when training only.
     "train": [
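Note on the registry pattern used throughout this patch: every hunk above resolves models, encoders, batch samplers, and preprocessors through the register_class decorator and registry_tables lookup imported from funasr/utils/register.py. That module is not part of this diff, so the following is only a minimal sketch of how such a registry could behave, written to match the usage visible in these hunks (one table per component type, lower-cased lookup keys); the internal details are assumptions, not the actual FunASR implementation.

from types import SimpleNamespace


class _Registry(dict):
    """Case-insensitive name -> class table (assumed behaviour)."""
    def get(self, name, default=None):
        return super().get(name.lower() if isinstance(name, str) else name, default)


# One table per component type, mirroring the tables referenced in the diff.
registry_tables = SimpleNamespace(
    model_classes=_Registry(),
    encoder_classes=_Registry(),
)


def register_class(table_name, class_name):
    """Register a class at import time, as @register_class(...) does in the hunks above."""
    def decorator(cls):
        getattr(registry_tables, table_name)[class_name.lower()] = cls
        return cls
    return decorator


# Illustrative stand-ins (not the real FSMN/FsmnVAD classes) showing the lookup flow.
@register_class("encoder_classes", "FSMN")
class FSMN:
    def __init__(self, **encoder_conf):
        self.encoder_conf = encoder_conf


@register_class("model_classes", "FsmnVAD")
class FsmnVAD:
    def __init__(self, encoder=None, encoder_conf=None, **kwargs):
        # Same resolution pattern as FsmnVAD.__init__ in the diff.
        encoder_class = registry_tables.encoder_classes.get(encoder.lower())
        self.encoder = encoder_class(**(encoder_conf or {}))


# Resolution by configured name, as funasr/bin/inference.py does after this change.
model_class = registry_tables.model_classes.get("FsmnVAD".lower())
model = model_class(encoder="FSMN", encoder_conf={"input_dim": 400})

Under that assumption, a config value such as model: FsmnVAD (e.g. supplied via the hydra-style +model=... override in the new infer.sh) resolves directly to the registered class without any hand-written factory table.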