diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py index 332013612..0538f6623 100644 --- a/funasr/auto/auto_model.py +++ b/funasr/auto/auto_model.py @@ -132,7 +132,8 @@ class AutoModel: self.punc_kwargs = punc_kwargs self.spk_model = spk_model self.spk_kwargs = spk_kwargs - self.model_path = kwargs.get("model_path", "./") + self.model_path = kwargs.get("model_path") + def build_model(self, **kwargs): diff --git a/funasr/datasets/audio_datasets/datasets.py b/funasr/datasets/audio_datasets/datasets.py index 5af33fca4..ebb72a327 100644 --- a/funasr/datasets/audio_datasets/datasets.py +++ b/funasr/datasets/audio_datasets/datasets.py @@ -58,7 +58,7 @@ class AudioDataset(torch.utils.data.Dataset): data_src = load_audio_text_image_video(source, fs=self.fs) if self.preprocessor_speech: data_src = self.preprocessor_speech(data_src) - speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend) # speech: [b, T, d] + speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend, is_final=True) # speech: [b, T, d] target = item["target"] if self.preprocessor_text: diff --git a/funasr/frontends/wav_frontend.py b/funasr/frontends/wav_frontend.py index 9c896f118..c6e03e86e 100644 --- a/funasr/frontends/wav_frontend.py +++ b/funasr/frontends/wav_frontend.py @@ -399,9 +399,10 @@ class WavFrontendOnline(nn.Module): return feats_pad, feats_lens, lfr_splice_frame_idxs def forward( - self, input: torch.Tensor, input_lengths: torch.Tensor, cache: dict = {}, **kwargs + self, input: torch.Tensor, input_lengths: torch.Tensor, **kwargs ): is_final = kwargs.get("is_final", False) + cache = kwargs.get("cache", {}) if len(cache) == 0: self.init_cache(cache) diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py index becfd56e3..76eee8189 100644 --- a/funasr/models/fsmn_vad_streaming/model.py +++ b/funasr/models/fsmn_vad_streaming/model.py @@ -15,7 +15,7 @@ from funasr.register import tables from typing import List, Tuple, Dict, Any, Optional from funasr.utils.datadir_writer import DatadirWriter -from funasr.utils.load_utils import load_audio_text_image_video,extract_fbank +from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank class VadStateMachine(Enum): @@ -23,11 +23,13 @@ class VadStateMachine(Enum): kVadInStateInSpeechSegment = 2 kVadInStateEndPointDetected = 3 + class FrameState(Enum): kFrameStateInvalid = -1 kFrameStateSpeech = 1 kFrameStateSil = 0 + # final voice/unvoice state per frame class AudioChangeState(Enum): kChangeStateSpeech2Speech = 0 @@ -37,16 +39,19 @@ class AudioChangeState(Enum): kChangeStateNoBegin = 4 kChangeStateInvalid = 5 + class VadDetectMode(Enum): kVadSingleUtteranceDetectMode = 0 kVadMutipleUtteranceDetectMode = 1 + class VADXOptions: """ Author: Speech Lab of DAMO Academy, Alibaba Group Deep-FSMN for Large Vocabulary Continuous Speech Recognition https://arxiv.org/abs/1803.05030 """ + def __init__( self, sample_rate: int = 16000, @@ -117,6 +122,7 @@ class E2EVadSpeechBufWithDoa(object): Deep-FSMN for Large Vocabulary Continuous Speech Recognition https://arxiv.org/abs/1803.05030 """ + def __init__(self): self.start_ms = 0 self.end_ms = 0 @@ -140,6 +146,7 @@ class E2EVadFrameProb(object): Deep-FSMN for Large Vocabulary Continuous Speech Recognition https://arxiv.org/abs/1803.05030 """ + def __init__(self): self.noise_prob = 0.0 self.speech_prob = 0.0 @@ -154,6 +161,7 @@ class WindowDetector(object): Deep-FSMN for Large Vocabulary 
Continuous Speech Recognition https://arxiv.org/abs/1803.05030 """ + def __init__(self, window_size_ms: int, sil_to_speech_time: int, speech_to_sil_time: int, @@ -190,7 +198,7 @@ class WindowDetector(object): def GetWinSize(self) -> int: return int(self.win_size_frame) - def DetectOneFrame(self, frameState: FrameState, frame_count: int, cache: dict={}) -> AudioChangeState: + def DetectOneFrame(self, frameState: FrameState, frame_count: int, cache: dict = {}) -> AudioChangeState: cur_frame_state = FrameState.kFrameStateSil if frameState == FrameState.kFrameStateSpeech: cur_frame_state = 1 @@ -220,13 +228,13 @@ class WindowDetector(object): def FrameSizeMs(self) -> int: return int(self.frame_size_ms) + class Stats(object): def __init__(self, sil_pdf_ids, max_end_sil_frame_cnt_thresh, speech_noise_thres, ): - self.data_buf_start_frame = 0 self.frm_cnt = 0 self.latest_confirmed_speech_frame = 0 @@ -255,6 +263,7 @@ class Stats(object): self.waveform = None self.last_drop_frames = 0 + @tables.register("model_classes", "FsmnVADStreaming") class FsmnVADStreaming(nn.Module): """ @@ -262,6 +271,7 @@ class FsmnVADStreaming(nn.Module): Deep-FSMN for Large Vocabulary Continuous Speech Recognition https://arxiv.org/abs/1803.05030 """ + def __init__(self, encoder: str = None, encoder_conf: Optional[Dict] = None, @@ -275,7 +285,6 @@ class FsmnVADStreaming(nn.Module): encoder = encoder_class(**encoder_conf) self.encoder = encoder - def ResetDetection(self, cache: dict = {}): cache["stats"].continous_silence_frame_count = 0 cache["stats"].latest_confirmed_speech_frame = 0 @@ -292,7 +301,8 @@ class FsmnVADStreaming(nn.Module): drop_frames = int(cache["stats"].output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms) real_drop_frames = drop_frames - cache["stats"].last_drop_frames cache["stats"].last_drop_frames = drop_frames - cache["stats"].data_buf_all = cache["stats"].data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):] + cache["stats"].data_buf_all = cache["stats"].data_buf_all[real_drop_frames * int( + self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):] cache["stats"].decibel = cache["stats"].decibel[real_drop_frames:] cache["stats"].scores = cache["stats"].scores[:, real_drop_frames:, :] @@ -300,7 +310,8 @@ class FsmnVADStreaming(nn.Module): frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000) frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000) if cache["stats"].data_buf_all is None: - cache["stats"].data_buf_all = cache["stats"].waveform[0] # cache["stats"].data_buf is pointed to cache["stats"].waveform[0] + cache["stats"].data_buf_all = cache["stats"].waveform[ + 0] # cache["stats"].data_buf is pointed to cache["stats"].waveform[0] cache["stats"].data_buf = cache["stats"].data_buf_all else: cache["stats"].data_buf_all = torch.cat((cache["stats"].data_buf_all, cache["stats"].waveform[0])) @@ -319,15 +330,16 @@ class FsmnVADStreaming(nn.Module): else: cache["stats"].scores = torch.cat((cache["stats"].scores, scores), dim=1) - def PopDataBufTillFrame(self, frame_idx: int, cache: dict={}) -> None: # need check again + def PopDataBufTillFrame(self, frame_idx: int, cache: dict = {}) -> None: # need check again while cache["stats"].data_buf_start_frame < frame_idx: if len(cache["stats"].data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000): cache["stats"].data_buf_start_frame += 1 - cache["stats"].data_buf = 
cache["stats"].data_buf_all[(cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames) * int( - self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):] + cache["stats"].data_buf = cache["stats"].data_buf_all[ + (cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames) * int( + self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):] def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool, - last_frm_is_end_point: bool, end_point_is_sent_end: bool, cache: dict={}) -> None: + last_frm_is_end_point: bool, end_point_is_sent_end: bool, cache: dict = {}) -> None: self.PopDataBufTillFrame(start_frm, cache=cache) expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000) if last_frm_is_end_point: @@ -379,14 +391,15 @@ class FsmnVADStreaming(nn.Module): cache["stats"].lastest_confirmed_silence_frame = valid_frame if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected: self.PopDataBufTillFrame(valid_frame, cache=cache) - # silence_detected_callback_ - # pass - def OnVoiceDetected(self, valid_frame: int, cache:dict={}) -> None: + # silence_detected_callback_ + # pass + + def OnVoiceDetected(self, valid_frame: int, cache: dict = {}) -> None: cache["stats"].latest_confirmed_speech_frame = valid_frame self.PopDataToOutputBuf(valid_frame, 1, False, False, False, cache=cache) - def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache:dict={}) -> None: + def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache: dict = {}) -> None: if self.vad_opts.do_start_point_detection: pass if cache["stats"].confirmed_start_frame != -1: @@ -397,7 +410,7 @@ class FsmnVADStreaming(nn.Module): if not fake_result and cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected: self.PopDataToOutputBuf(cache["stats"].confirmed_start_frame, 1, True, False, False, cache=cache) - def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool, cache:dict={}) -> None: + def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool, cache: dict = {}) -> None: for t in range(cache["stats"].latest_confirmed_speech_frame + 1, end_frame): self.OnVoiceDetected(t, cache=cache) if self.vad_opts.do_end_point_detection: @@ -487,7 +500,8 @@ class FsmnVADStreaming(nn.Module): segment_batch = [] if len(cache["stats"].output_data_buf) > 0: for i in range(cache["stats"].output_data_buf_offset, len(cache["stats"].output_data_buf)): - if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point or not cache["stats"].output_data_buf[ + if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point or not + cache["stats"].output_data_buf[ i].contain_seg_end_point): continue segment = [cache["stats"].output_data_buf[i].start_ms, cache["stats"].output_data_buf[i].end_ms] @@ -499,9 +513,9 @@ class FsmnVADStreaming(nn.Module): # # reset class variables and clear the dict for the next query # self.AllResetDetection() return segments - + def init_cache(self, cache: dict = {}, **kwargs): - + cache["frontend"] = {} cache["prev_samples"] = torch.empty(0) cache["encoder"] = {} @@ -528,12 +542,12 @@ class FsmnVADStreaming(nn.Module): cache: dict = {}, **kwargs, ): - + if len(cache) == 0: self.init_cache(cache, **kwargs) meta_data = {} - chunk_size = kwargs.get("chunk_size", 60000) # 50ms + chunk_size = kwargs.get("chunk_size", 60000) # 50ms chunk_stride_samples = int(chunk_size * 
frontend.fs / 1000) time1 = time.perf_counter() @@ -580,7 +594,6 @@ class FsmnVADStreaming(nn.Module): if len(segments_i) > 0: segments.extend(*segments_i) - cache["prev_samples"] = audio_sample[:-m] if _is_final: self.init_cache(cache) @@ -600,16 +613,15 @@ class FsmnVADStreaming(nn.Module): if ibest_writer is not None: ibest_writer["text"][key[0]] = segments - return results, meta_data - def DetectCommonFrames(self, cache: dict = {}) -> int: if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected: return 0 for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1): frame_state = FrameState.kFrameStateInvalid - frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache) + frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, + cache=cache) self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache) return 0 @@ -619,7 +631,8 @@ class FsmnVADStreaming(nn.Module): return 0 for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1): frame_state = FrameState.kFrameStateInvalid - frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache) + frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, + cache=cache) if i != 0: self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache) else: @@ -627,7 +640,8 @@ class FsmnVADStreaming(nn.Module): return 0 - def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool, cache: dict = {}) -> None: + def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool, + cache: dict = {}) -> None: tmp_cur_frm_state = FrameState.kFrameStateInvalid if cur_frm_state == FrameState.kFrameStateSpeech: if math.fabs(1.0) > self.vad_opts.fe_prior_thres: @@ -644,7 +658,8 @@ class FsmnVADStreaming(nn.Module): cache["stats"].pre_end_silence_detected = False start_frame = 0 if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected: - start_frame = max(cache["stats"].data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache)) + start_frame = max(cache["stats"].data_buf_start_frame, + cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache)) self.OnVoiceStart(start_frame, cache=cache) cache["stats"].vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment for t in range(start_frame + 1, cur_frm_idx + 1): @@ -696,7 +711,8 @@ class FsmnVADStreaming(nn.Module): if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected: # silence timeout, return zero length decision if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and ( - cache["stats"].continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \ + cache[ + "stats"].continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \ or (is_final_frame and cache["stats"].number_end_time_detected == 0): for t in range(cache["stats"].lastest_confirmed_silence_frame + 1, cur_frm_idx): self.OnSilenceDetected(t, cache=cache) @@ -707,7 +723,8 @@ class FsmnVADStreaming(nn.Module): if cur_frm_idx >= self.LatencyFrmNumAtStartPoint(cache=cache): self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache), cache=cache) elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment: - if 
cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= cache["stats"].max_end_sil_frame_cnt_thresh: + if cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= cache[ + "stats"].max_end_sil_frame_cnt_thresh: lookback_frame = int(cache["stats"].max_end_sil_frame_cnt_thresh / frm_shift_in_ms) if self.vad_opts.do_extend: lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms) @@ -733,4 +750,3 @@ class FsmnVADStreaming(nn.Module): self.ResetDetection(cache=cache) - diff --git a/funasr/models/scama/beam_search.py b/funasr/models/scama/beam_search.py index 8f0d751e5..b8aa876b5 100644 --- a/funasr/models/scama/beam_search.py +++ b/funasr/models/scama/beam_search.py @@ -11,7 +11,7 @@ from typing import Union import torch -from funasr.metrics import end_detect +from funasr.metrics.common import end_detect from funasr.models.transformer.scorers.scorer_interface import PartialScorerInterface from funasr.models.transformer.scorers.scorer_interface import ScorerInterface @@ -494,3 +494,468 @@ class BeamSearchScama(torch.nn.Module): else: remained_hyps.append(hyp) return remained_hyps + +class BeamSearchScamaStreaming(torch.nn.Module): + """Beam search implementation.""" + + def __init__( + self, + scorers: Dict[str, ScorerInterface], + weights: Dict[str, float], + beam_size: int, + vocab_size: int, + sos: int, + eos: int, + token_list: List[str] = None, + pre_beam_ratio: float = 1.5, + pre_beam_score_key: str = None, + ): + """Initialize beam search. + + Args: + scorers (dict[str, ScorerInterface]): Dict of decoder modules + e.g., Decoder, CTCPrefixScorer, LM + The scorer will be ignored if it is `None` + weights (dict[str, float]): Dict of weights for each scorers + The scorer will be ignored if its weight is 0 + beam_size (int): The number of hypotheses kept during search + vocab_size (int): The number of vocabulary + sos (int): Start of sequence id + eos (int): End of sequence id + token_list (list[str]): List of tokens for debug log + pre_beam_score_key (str): key of scores to perform pre-beam search + pre_beam_ratio (float): beam size in the pre-beam search + will be `int(pre_beam_ratio * beam_size)` + + """ + super().__init__() + # set scorers + self.weights = weights + self.scorers = dict() + self.full_scorers = dict() + self.part_scorers = dict() + # this module dict is required for recursive cast + # `self.to(device, dtype)` in `recog.py` + self.nn_dict = torch.nn.ModuleDict() + for k, v in scorers.items(): + w = weights.get(k, 0) + if w == 0 or v is None: + continue + assert isinstance( + v, ScorerInterface + ), f"{k} ({type(v)}) does not implement ScorerInterface" + self.scorers[k] = v + if isinstance(v, PartialScorerInterface): + self.part_scorers[k] = v + else: + self.full_scorers[k] = v + if isinstance(v, torch.nn.Module): + self.nn_dict[k] = v + + # set configurations + self.sos = sos + self.eos = eos + self.token_list = token_list + self.pre_beam_size = int(pre_beam_ratio * beam_size) + self.beam_size = beam_size + self.n_vocab = vocab_size + if ( + pre_beam_score_key is not None + and pre_beam_score_key != "full" + and pre_beam_score_key not in self.full_scorers + ): + raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}") + self.pre_beam_score_key = pre_beam_score_key + self.do_pre_beam = ( + self.pre_beam_score_key is not None + and self.pre_beam_size < self.n_vocab + and len(self.part_scorers) > 0 + ) + + def init_hyp(self, x) -> List[Hypothesis]: + """Get an initial hypothesis data. 
+ + Args: + x (torch.Tensor): The encoder output feature + + Returns: + Hypothesis: The initial hypothesis. + + """ + init_states = dict() + init_scores = dict() + for k, d in self.scorers.items(): + init_states[k] = d.init_state(x) + init_scores[k] = 0.0 + return [ + Hypothesis( + score=0.0, + scores=init_scores, + states=init_states, + yseq=torch.tensor([self.sos], device=x.device), + ) + ] + + @staticmethod + def append_token(xs: torch.Tensor, x: int) -> torch.Tensor: + """Append new token to prefix tokens. + + Args: + xs (torch.Tensor): The prefix token + x (int): The new token to append + + Returns: + torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device + + """ + x = torch.tensor([x], dtype=xs.dtype, device=xs.device) + return torch.cat((xs, x)) + + def score_full( + self, hyp: Hypothesis, + x: torch.Tensor, + x_mask: torch.Tensor = None, + pre_acoustic_embeds: torch.Tensor = None, + cache: dict={}, + ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: + """Score new hypothesis by `self.full_scorers`. + + Args: + hyp (Hypothesis): Hypothesis with prefix tokens to score + x (torch.Tensor): Corresponding input feature + + Returns: + Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of + score dict of `hyp` that has string keys of `self.full_scorers` + and tensor score values of shape: `(self.n_vocab,)`, + and state dict that has string keys + and state values of `self.full_scorers` + + """ + scores = dict() + states = dict() + for k, d in self.full_scorers.items(): + scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds, cache=cache) + return scores, states + + def score_partial( + self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor + ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: + """Score new hypothesis by `self.part_scorers`. + + Args: + hyp (Hypothesis): Hypothesis with prefix tokens to score + ids (torch.Tensor): 1D tensor of new partial tokens to score + x (torch.Tensor): Corresponding input feature + + Returns: + Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of + score dict of `hyp` that has string keys of `self.part_scorers` + and tensor score values of shape: `(len(ids),)`, + and state dict that has string keys + and state values of `self.part_scorers` + + """ + scores = dict() + states = dict() + for k, d in self.part_scorers.items(): + scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x) + return scores, states + + def beam( + self, weighted_scores: torch.Tensor, ids: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute topk full token ids and partial token ids. + + Args: + weighted_scores (torch.Tensor): The weighted sum scores for each tokens. + Its shape is `(self.n_vocab,)`. + ids (torch.Tensor): The partial token ids to compute topk + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + The topk full token ids and partial token ids. 
+ Their shapes are `(self.beam_size,)` + + """ + # no pre beam performed + if weighted_scores.size(0) == ids.size(0): + top_ids = weighted_scores.topk(self.beam_size)[1] + return top_ids, top_ids + + # mask pruned in pre-beam not to select in topk + tmp = weighted_scores[ids] + weighted_scores[:] = -float("inf") + weighted_scores[ids] = tmp + top_ids = weighted_scores.topk(self.beam_size)[1] + local_ids = weighted_scores[ids].topk(self.beam_size)[1] + return top_ids, local_ids + + @staticmethod + def merge_scores( + prev_scores: Dict[str, float], + next_full_scores: Dict[str, torch.Tensor], + full_idx: int, + next_part_scores: Dict[str, torch.Tensor], + part_idx: int, + ) -> Dict[str, torch.Tensor]: + """Merge scores for new hypothesis. + + Args: + prev_scores (Dict[str, float]): + The previous hypothesis scores by `self.scorers` + next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers` + full_idx (int): The next token id for `next_full_scores` + next_part_scores (Dict[str, torch.Tensor]): + scores of partial tokens by `self.part_scorers` + part_idx (int): The new token id for `next_part_scores` + + Returns: + Dict[str, torch.Tensor]: The new score dict. + Its keys are names of `self.full_scorers` and `self.part_scorers`. + Its values are scalar tensors by the scorers. + + """ + new_scores = dict() + for k, v in next_full_scores.items(): + new_scores[k] = prev_scores[k] + v[full_idx] + for k, v in next_part_scores.items(): + new_scores[k] = prev_scores[k] + v[part_idx] + return new_scores + + def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any: + """Merge states for new hypothesis. + + Args: + states: states of `self.full_scorers` + part_states: states of `self.part_scorers` + part_idx (int): The new token id for `part_scores` + + Returns: + Dict[str, torch.Tensor]: The new score dict. + Its keys are names of `self.full_scorers` and `self.part_scorers`. + Its values are states of the scorers. + + """ + new_states = dict() + for k, v in states.items(): + new_states[k] = v + for k, d in self.part_scorers.items(): + new_states[k] = d.select_state(part_states[k], part_idx) + return new_states + + def search( + self, running_hyps: List[Hypothesis], + x: torch.Tensor, + x_mask: torch.Tensor = None, + pre_acoustic_embeds: torch.Tensor = None, + cache: dict={}, + ) -> List[Hypothesis]: + """Search new tokens for running hypotheses and encoded speech x. 
+ + Args: + running_hyps (List[Hypothesis]): Running hypotheses on beam + x (torch.Tensor): Encoded speech feature (T, D) + + Returns: + List[Hypotheses]: Best sorted hypotheses + + """ + best_hyps = [] + part_ids = torch.arange(self.n_vocab, device=x.device) # no pre-beam + for hyp in running_hyps: + # scoring + weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device) + scores, states = self.score_full(hyp, x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds, cache=cache) + for k in self.full_scorers: + weighted_scores += self.weights[k] * scores[k] + # partial scoring + if self.do_pre_beam: + pre_beam_scores = ( + weighted_scores + if self.pre_beam_score_key == "full" + else scores[self.pre_beam_score_key] + ) + part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1] + part_scores, part_states = self.score_partial(hyp, part_ids, x) + for k in self.part_scorers: + weighted_scores[part_ids] += self.weights[k] * part_scores[k] + # add previous hyp score + weighted_scores += hyp.score + + # update hyps + for j, part_j in zip(*self.beam(weighted_scores, part_ids)): + # will be (2 x beam at most) + best_hyps.append( + Hypothesis( + score=weighted_scores[j], + yseq=self.append_token(hyp.yseq, j), + scores=self.merge_scores( + hyp.scores, scores, j, part_scores, part_j + ), + states=self.merge_states(states, part_states, part_j), + ) + ) + + # sort and prune 2 x beam -> beam + best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[ + : min(len(best_hyps), self.beam_size) + ] + return best_hyps + + def forward( + self, x: torch.Tensor, + scama_mask: torch.Tensor = None, + pre_acoustic_embeds: torch.Tensor = None, + maxlenratio: float = 0.0, + minlenratio: float = 0.0, + maxlen: int = None, + minlen: int = 0, + cache:dict={}, + ) -> List[Hypothesis]: + """Perform beam search. + + Args: + x (torch.Tensor): Encoded speech feature (T, D) + maxlenratio (float): Input length ratio to obtain max output length. + If maxlenratio=0.0 (default), it uses a end-detect function + to automatically find maximum hypothesis lengths + If maxlenratio<0.0, its absolute value is interpreted + as a constant max output length. + minlenratio (float): Input length ratio to obtain min output length. 
+ + Returns: + list[Hypothesis]: N-best decoding results + + """ + if maxlen is None: + # set length bounds + if maxlenratio == 0: + maxlen = x.shape[0] + elif maxlenratio < 0: + maxlen = -1 * int(maxlenratio) + else: + maxlen = max(1, int(maxlenratio * x.size(0))) + minlen = int(minlenratio * x.size(0)) + + logging.info("decoder input length: " + str(x.shape[0])) + logging.info("max output length: " + str(maxlen)) + logging.info("min output length: " + str(minlen)) + + # main loop of prefix search + # running_hyps = self.init_hyp(x) + running_hyps = cache["running_hyps"] + ended_hyps = [] + for i in range(maxlen): + logging.debug("position " + str(i)) + mask_enc = None + # if scama_mask is not None: + # token_num_predictor = scama_mask.size(1) + # token_id_slice = min(i, token_num_predictor-1) + # mask_enc = scama_mask[:, token_id_slice:token_id_slice+1, :] + # # if mask_enc.size(1) == 0: + # # mask_enc = scama_mask[:, -2:-1, :] + # # # mask_enc = torch.zeros_like(mask_enc) + pre_acoustic_embeds_cur = None + if pre_acoustic_embeds is not None: + b, t, d = pre_acoustic_embeds.size() + pad = torch.zeros((b, 1, d), dtype=pre_acoustic_embeds.dtype).to(device=pre_acoustic_embeds.device) + pre_acoustic_embeds = torch.cat((pre_acoustic_embeds, pad), dim=1) + token_id_slice = min(i, t) + pre_acoustic_embeds_cur = pre_acoustic_embeds[:, token_id_slice:token_id_slice+1, :] + + best = self.search(running_hyps, x, x_mask=mask_enc, pre_acoustic_embeds=pre_acoustic_embeds_cur, cache=cache["decoder"]) + # post process of one iteration + running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps) + # end detection + if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i): + logging.info(f"end detected at {i}") + break + if len(running_hyps) == 0: + logging.info("no hypothesis. Finish decoding.") + break + else: + logging.debug(f"remained hypotheses: {len(running_hyps)}") + + nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True) + # check the number of hypotheses reaching to eos + if len(nbest_hyps) == 0: + logging.warning( + "there is no N-best results, perform recognition " + "again with smaller minlenratio." + ) + return ( + [] + if minlenratio < 0.1 + else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1)) + ) + + # report the best result + for x in nbest_hyps: + yseq = "".join([self.token_list[x] for x in x.yseq]) + logging.debug("nbest: y: {}, yseq: {}, score: {}".format(x.yseq, yseq, x.score)) + best = nbest_hyps[0] + for k, v in best.scores.items(): + logging.info( + f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}" + ) + logging.info(f"total log probability: {best.score:.2f}") + logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}") + logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}") + if self.token_list is not None: + logging.info( + "best hypo: " + + "".join([self.token_list[x] for x in best.yseq[1:-1]]) + + "\n" + ) + return nbest_hyps + + def post_process( + self, + i: int, + maxlen: int, + maxlenratio: float, + running_hyps: List[Hypothesis], + ended_hyps: List[Hypothesis], + ) -> List[Hypothesis]: + """Perform post-processing of beam search iterations. + + Args: + i (int): The length of hypothesis tokens. + maxlen (int): The maximum length of tokens in beam search. + maxlenratio (int): The maximum length ratio in beam search. + running_hyps (List[Hypothesis]): The running hypotheses in beam search. + ended_hyps (List[Hypothesis]): The ended hypotheses in beam search. 
+ + Returns: + List[Hypothesis]: The new running hypotheses. + + """ + logging.debug(f"the number of running hypotheses: {len(running_hyps)}") + if self.token_list is not None: + logging.debug( + "best hypo: " + + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]]) + ) + # add eos in the final loop to avoid that there are no ended hyps + if i == maxlen - 1: + logging.info("adding in the last position in the loop") + running_hyps = [ + h._replace(yseq=self.append_token(h.yseq, self.eos)) + for h in running_hyps + ] + + # add ended hypotheses to a final list, and removed them from current hypotheses + # (this will be a problem, number of hyps < beam) + remained_hyps = [] + for hyp in running_hyps: + if hyp.yseq[-1] == self.eos: + # e.g., Word LM needs to add final score + for k, d in chain(self.full_scorers.items(), self.part_scorers.items()): + s = d.final_score(hyp.states[k]) + hyp.scores[k] += s + hyp = hyp._replace(score=hyp.score + self.weights[k] * s) + ended_hyps.append(hyp) + else: + remained_hyps.append(hyp) + return remained_hyps diff --git a/funasr/models/scama/model.py b/funasr/models/scama/model.py index aec6fe329..32e16bded 100644 --- a/funasr/models/scama/model.py +++ b/funasr/models/scama/model.py @@ -436,7 +436,10 @@ class SCAMA(nn.Module): def init_beam_search(self, **kwargs, ): - from funasr.models.scama.beam_search import BeamSearchScama + + from funasr.models.scama.beam_search import BeamSearchScamaStreaming + + from funasr.models.transformer.scorers.ctc import CTCPrefixScorer from funasr.models.transformer.scorers.length_bonus import LengthBonus @@ -460,13 +463,14 @@ class SCAMA(nn.Module): scorers["ngram"] = ngram weights = dict( - decoder=1.0 - kwargs.get("decoding_ctc_weight"), + decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.0), ctc=kwargs.get("decoding_ctc_weight", 0.0), lm=kwargs.get("lm_weight", 0.0), ngram=kwargs.get("ngram_weight", 0.0), length_bonus=kwargs.get("penalty", 0.0), ) - beam_search = BeamSearchScama( + + beam_search = BeamSearchScamaStreaming( beam_size=kwargs.get("beam_size", 2), weights=weights, scorers=scorers, @@ -499,7 +503,11 @@ class SCAMA(nn.Module): is_final=kwargs.get("is_final", False)) if isinstance(encoder_out, tuple): encoder_out = encoder_out[0] - + if "running_hyps" not in cache: + running_hyps = self.beam_search.init_hyp(encoder_out) + cache["running_hyps"] = running_hyps + + # predictor predictor_outs = self.calc_predictor_chunk(encoder_out, encoder_out_lens, @@ -513,47 +521,30 @@ class SCAMA(nn.Module): if torch.max(pre_token_length) < 1: return [] - decoder_outs = self.cal_decoder_with_predictor_chunk(encoder_out, - encoder_out_lens, - pre_acoustic_embeds, - pre_token_length, - cache=cache - ) - decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] - + maxlen = minlen = pre_token_length + if kwargs.get("is_final", False): + maxlen += kwargs.get("token_num_relax", 5) + minlen = max(0, minlen - kwargs.get("token_num_relax", 5)) + # c. 
Passed the encoder result and the beam search + nbest_hyps = self.beam_search( + x=encoder_out[0], scama_mask=None, pre_acoustic_embeds=pre_acoustic_embeds, maxlen=int(maxlen), minlen=int(minlen), cache=cache, + ) + + cache["running_hyps"] = nbest_hyps + nbest_hyps = nbest_hyps[: self.nbest] + results = [] - b, n, d = decoder_out.size() - if isinstance(key[0], (list, tuple)): - key = key[0] - for i in range(b): - x = encoder_out[i, :encoder_out_lens[i], :] - am_scores = decoder_out[i, :pre_token_length[i], :] - if self.beam_search is not None: - nbest_hyps = self.beam_search( - x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0), - minlenratio=kwargs.get("minlenratio", 0.0) - ) - - nbest_hyps = nbest_hyps[: self.nbest] + for hyp in nbest_hyps: + # assert isinstance(hyp, (Hypothesis)), type(hyp) + + # remove sos/eos and get results + last_pos = -1 + if isinstance(hyp.yseq, list): + token_int = hyp.yseq[1:last_pos] else: - - yseq = am_scores.argmax(dim=-1) - score = am_scores.max(dim=-1)[0] - score = torch.sum(score, dim=-1) - # pad with mask tokens to ensure compatibility with sos/eos tokens - yseq = torch.tensor( - [self.sos] + yseq.tolist() + [self.eos], device=yseq.device - ) - nbest_hyps = [Hypothesis(yseq=yseq, score=score)] - for nbest_idx, hyp in enumerate(nbest_hyps): - - # remove sos/eos and get results - last_pos = -1 - if isinstance(hyp.yseq, list): - token_int = hyp.yseq[1:last_pos] - else: - token_int = hyp.yseq[1:last_pos].tolist() - + token_int = hyp.yseq[1:last_pos].tolist() + + # remove blank symbol id, which is assumed to be 0 token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int)) @@ -568,6 +559,8 @@ class SCAMA(nn.Module): return results def init_cache(self, cache: dict = {}, **kwargs): + device = kwargs.get("device", "cuda") + chunk_size = kwargs.get("chunk_size", [0, 10, 5]) encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0) decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0) @@ -575,10 +568,11 @@ class SCAMA(nn.Module): enc_output_size = kwargs["encoder_conf"]["output_size"] feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"] - cache_encoder = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)), - "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, + + cache_encoder = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)).to(device=device), + "cif_alphas": torch.zeros((batch_size, 1)).to(device=device), "chunk_size": chunk_size, "encoder_chunk_look_back": encoder_chunk_look_back, "last_chunk": False, "opt": None, - "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), + "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)).to(device=device), "tail_chunk": False} cache["encoder"] = cache_encoder @@ -586,8 +580,10 @@ class SCAMA(nn.Module): "chunk_size": chunk_size} cache["decoder"] = cache_decoder cache["frontend"] = {} - cache["prev_samples"] = torch.empty(0) - + + + cache["prev_samples"] = torch.empty(0).to(device=device) + return cache def inference(self, @@ -603,7 +599,10 @@ class SCAMA(nn.Module): # init beamsearch is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None - if self.beam_search is None and (is_use_lm or is_use_ctc): + + if self.beam_search is None: + + logging.info("enable beam_search") 
self.init_beam_search(**kwargs) self.nbest = kwargs.get("nbest", 1) diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py index 62d6be80b..414c0d7ca 100644 --- a/funasr/train_utils/trainer.py +++ b/funasr/train_utils/trainer.py @@ -148,6 +148,7 @@ class Trainer: self._train_epoch(epoch) + if self.use_ddp or self.use_fsdp: dist.barrier() @@ -156,8 +157,8 @@ class Trainer: if self.use_ddp or self.use_fsdp: dist.barrier() - - + + if self.rank == 0: self._save_checkpoint(epoch) @@ -172,7 +173,8 @@ class Trainer: if self.use_ddp or self.use_fsdp: dist.barrier() - + + if self.writer: self.writer.close()
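
Usage sketch for the streaming changes above, assuming the standard FunASR AutoModel streaming API; the model directory and wav path are placeholders, and the chunk_size / look-back values follow the usual FunASR streaming convention (each chunk_size unit corresponds to 60 ms at 16 kHz) rather than anything fixed by this diff.

    import soundfile
    from funasr import AutoModel

    # Placeholder model directory for a streaming SCAMA/online model built on this patch.
    model = AutoModel(model="./exp/scama_streaming_model")

    chunk_size = [0, 10, 5]        # 10 * 60 ms = 600 ms per chunk, 5 * 60 ms look-ahead
    encoder_chunk_look_back = 4    # number of past encoder chunks kept in the cache
    decoder_chunk_look_back = 1

    speech, sample_rate = soundfile.read("example.wav")  # 16 kHz mono assumed
    chunk_stride = chunk_size[1] * 960                   # samples per chunk at 16 kHz

    cache = {}
    total_chunk_num = (len(speech) - 1) // chunk_stride + 1
    for i in range(total_chunk_num):
        chunk = speech[i * chunk_stride:(i + 1) * chunk_stride]
        is_final = i == total_chunk_num - 1
        res = model.generate(
            input=chunk,
            cache=cache,          # carries frontend / encoder / decoder state between calls
            is_final=is_final,    # flushes the tail chunk and resets the cache
            chunk_size=chunk_size,
            encoder_chunk_look_back=encoder_chunk_look_back,
            decoder_chunk_look_back=decoder_chunk_look_back,
        )
        print(res)

The cache dict is what holds the frontend, encoder, decoder, and running-hypothesis state that init_cache and BeamSearchScamaStreaming populate in this patch; passing is_final=True on the last chunk drains the remaining samples and re-initializes that state for the next stream.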