mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Funasr1.0 (#1277)
* funasr1.0 finetune
* funasr1.0 pbar
* update with main (#1260): Update websocket_protocol_zh.md; update
* update with main (#1264): Funasr1.0 (#1261): funasr1.0 finetune; funasr1.0 pbar; update with main (#1260): Update websocket_protocol_zh.md; update
* bug fix
* funasr1.0 sanm scama
* funasr1.0 infer_after_finetune
* funasr1.0 fsmn-vad bug fix
* funasr1.0 fsmn-vad bug fix
* funasr1.0 fsmn-vad bug fix
* funasr1.0 finetune
* funasr1.0 finetune
* funasr1.0 finetune
* funasr1.0 finetune

Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com>
Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>
This commit is contained in:
parent
2cca8104d2
commit
37d7764ecf
@@ -132,7 +132,8 @@ class AutoModel:
         self.punc_kwargs = punc_kwargs
         self.spk_model = spk_model
         self.spk_kwargs = spk_kwargs
-        self.model_path = kwargs.get("model_path", "./")
+        self.model_path = kwargs.get("model_path")
+
 
 
     def build_model(self, **kwargs):
@@ -58,7 +58,7 @@ class AudioDataset(torch.utils.data.Dataset):
         data_src = load_audio_text_image_video(source, fs=self.fs)
         if self.preprocessor_speech:
             data_src = self.preprocessor_speech(data_src)
-        speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend)  # speech: [b, T, d]
+        speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend, is_final=True)  # speech: [b, T, d]
 
         target = item["target"]
         if self.preprocessor_text:
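Note: the dataset now passes is_final=True into extract_fbank. In a streaming-capable frontend that flag marks the last chunk so internally buffered frames are flushed; a training dataset always sees whole utterances, so extraction runs in final mode. A minimal sketch of the call, reusing only the signature visible in this hunk (the surrounding frontend/waveform setup is assumed):

    # assumed setup: a frontend object and a loaded waveform `data_src`, as in AudioDataset
    speech, speech_lengths = extract_fbank(
        data_src,
        data_type="sound",      # assumption: an illustrative data_type value
        frontend=frontend,
        is_final=True,          # whole utterance: flush any internal buffering
    )
    # speech: [b, T, d]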
@@ -399,9 +399,10 @@ class WavFrontendOnline(nn.Module):
         return feats_pad, feats_lens, lfr_splice_frame_idxs
 
     def forward(
-        self, input: torch.Tensor, input_lengths: torch.Tensor, cache: dict = {}, **kwargs
+        self, input: torch.Tensor, input_lengths: torch.Tensor, **kwargs
     ):
         is_final = kwargs.get("is_final", False)
+        cache = kwargs.get("cache", {})
         if len(cache) == 0:
             self.init_cache(cache)
 
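Note: one plausible motivation for moving `cache` out of the signature default and into kwargs.get is Python's shared-mutable-default pitfall: a `dict = {}` default is created once and reused by every call that omits the argument, so state silently leaks between unrelated calls, while kwargs.get builds a fresh dict per call. (Several other methods in this commit keep `cache: dict = {}`, but they are always called with an explicit cache.) A self-contained illustration with hypothetical functions, not FunASR code:

    def detect(frame, cache: dict = {}):       # one dict object shared by all callers
        cache["n"] = cache.get("n", 0) + 1
        return cache["n"]

    def detect_kw(frame, **kwargs):
        cache = kwargs.get("cache", {})        # fresh dict whenever the caller omits it
        cache["n"] = cache.get("n", 0) + 1
        return cache["n"]

    print(detect(0), detect(1))        # 1 2  -- state leaks between calls
    print(detect_kw(0), detect_kw(1))  # 1 1  -- no hidden shared state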
@@ -15,7 +15,7 @@ from funasr.register import tables
 from typing import List, Tuple, Dict, Any, Optional
 
 from funasr.utils.datadir_writer import DatadirWriter
-from funasr.utils.load_utils import load_audio_text_image_video,extract_fbank
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
 
 
 class VadStateMachine(Enum):
@@ -23,11 +23,13 @@ class VadStateMachine(Enum):
     kVadInStateInSpeechSegment = 2
     kVadInStateEndPointDetected = 3
 
+
 class FrameState(Enum):
     kFrameStateInvalid = -1
     kFrameStateSpeech = 1
     kFrameStateSil = 0
 
+
 # final voice/unvoice state per frame
 class AudioChangeState(Enum):
     kChangeStateSpeech2Speech = 0
@@ -37,16 +39,19 @@ class AudioChangeState(Enum):
     kChangeStateNoBegin = 4
     kChangeStateInvalid = 5
 
+
 class VadDetectMode(Enum):
     kVadSingleUtteranceDetectMode = 0
     kVadMutipleUtteranceDetectMode = 1
 
+
 class VADXOptions:
     """
     Author: Speech Lab of DAMO Academy, Alibaba Group
     Deep-FSMN for Large Vocabulary Continuous Speech Recognition
     https://arxiv.org/abs/1803.05030
     """
+
     def __init__(
             self,
             sample_rate: int = 16000,
@@ -117,6 +122,7 @@ class E2EVadSpeechBufWithDoa(object):
     Deep-FSMN for Large Vocabulary Continuous Speech Recognition
     https://arxiv.org/abs/1803.05030
     """
+
     def __init__(self):
         self.start_ms = 0
         self.end_ms = 0
@@ -140,6 +146,7 @@ class E2EVadFrameProb(object):
     Deep-FSMN for Large Vocabulary Continuous Speech Recognition
     https://arxiv.org/abs/1803.05030
     """
+
     def __init__(self):
         self.noise_prob = 0.0
         self.speech_prob = 0.0
@@ -154,6 +161,7 @@ class WindowDetector(object):
     Deep-FSMN for Large Vocabulary Continuous Speech Recognition
     https://arxiv.org/abs/1803.05030
     """
+
     def __init__(self, window_size_ms: int,
                  sil_to_speech_time: int,
                  speech_to_sil_time: int,
@@ -190,7 +198,7 @@ class WindowDetector(object):
     def GetWinSize(self) -> int:
         return int(self.win_size_frame)
 
-    def DetectOneFrame(self, frameState: FrameState, frame_count: int, cache: dict={}) -> AudioChangeState:
+    def DetectOneFrame(self, frameState: FrameState, frame_count: int, cache: dict = {}) -> AudioChangeState:
         cur_frame_state = FrameState.kFrameStateSil
         if frameState == FrameState.kFrameStateSpeech:
             cur_frame_state = 1
@@ -220,13 +228,13 @@ class WindowDetector(object):
     def FrameSizeMs(self) -> int:
         return int(self.frame_size_ms)
 
 
 class Stats(object):
     def __init__(self,
                  sil_pdf_ids,
                  max_end_sil_frame_cnt_thresh,
                  speech_noise_thres,
                  ):
 
         self.data_buf_start_frame = 0
         self.frm_cnt = 0
         self.latest_confirmed_speech_frame = 0
@@ -255,6 +263,7 @@ class Stats(object):
         self.waveform = None
         self.last_drop_frames = 0
 
+
 @tables.register("model_classes", "FsmnVADStreaming")
 class FsmnVADStreaming(nn.Module):
     """
@@ -262,6 +271,7 @@ class FsmnVADStreaming(nn.Module):
     Deep-FSMN for Large Vocabulary Continuous Speech Recognition
     https://arxiv.org/abs/1803.05030
     """
+
     def __init__(self,
                  encoder: str = None,
                  encoder_conf: Optional[Dict] = None,
@@ -275,7 +285,6 @@ class FsmnVADStreaming(nn.Module):
         encoder = encoder_class(**encoder_conf)
         self.encoder = encoder
 
-
     def ResetDetection(self, cache: dict = {}):
         cache["stats"].continous_silence_frame_count = 0
         cache["stats"].latest_confirmed_speech_frame = 0
@@ -292,7 +301,8 @@ class FsmnVADStreaming(nn.Module):
             drop_frames = int(cache["stats"].output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
             real_drop_frames = drop_frames - cache["stats"].last_drop_frames
             cache["stats"].last_drop_frames = drop_frames
-            cache["stats"].data_buf_all = cache["stats"].data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
+            cache["stats"].data_buf_all = cache["stats"].data_buf_all[real_drop_frames * int(
+                self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
             cache["stats"].decibel = cache["stats"].decibel[real_drop_frames:]
             cache["stats"].scores = cache["stats"].scores[:, real_drop_frames:, :]
 
@@ -300,7 +310,8 @@ class FsmnVADStreaming(nn.Module):
         frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
         frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
         if cache["stats"].data_buf_all is None:
-            cache["stats"].data_buf_all = cache["stats"].waveform[0]  # cache["stats"].data_buf is pointed to cache["stats"].waveform[0]
+            cache["stats"].data_buf_all = cache["stats"].waveform[
+                0]  # cache["stats"].data_buf is pointed to cache["stats"].waveform[0]
             cache["stats"].data_buf = cache["stats"].data_buf_all
         else:
             cache["stats"].data_buf_all = torch.cat((cache["stats"].data_buf_all, cache["stats"].waveform[0]))
@@ -319,15 +330,16 @@ class FsmnVADStreaming(nn.Module):
         else:
             cache["stats"].scores = torch.cat((cache["stats"].scores, scores), dim=1)
 
-    def PopDataBufTillFrame(self, frame_idx: int, cache: dict={}) -> None:  # need check again
+    def PopDataBufTillFrame(self, frame_idx: int, cache: dict = {}) -> None:  # need check again
         while cache["stats"].data_buf_start_frame < frame_idx:
             if len(cache["stats"].data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
                 cache["stats"].data_buf_start_frame += 1
-                cache["stats"].data_buf = cache["stats"].data_buf_all[(cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames) * int(
-                    self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
+                cache["stats"].data_buf = cache["stats"].data_buf_all[
+                    (cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames) * int(
+                        self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
 
     def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
-                           last_frm_is_end_point: bool, end_point_is_sent_end: bool, cache: dict={}) -> None:
+                           last_frm_is_end_point: bool, end_point_is_sent_end: bool, cache: dict = {}) -> None:
         self.PopDataBufTillFrame(start_frm, cache=cache)
         expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
         if last_frm_is_end_point:
@@ -379,14 +391,15 @@ class FsmnVADStreaming(nn.Module):
             cache["stats"].lastest_confirmed_silence_frame = valid_frame
             if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
                 self.PopDataBufTillFrame(valid_frame, cache=cache)
-            # silence_detected_callback_
-            # pass
-
-    def OnVoiceDetected(self, valid_frame: int, cache:dict={}) -> None:
+        # silence_detected_callback_
+        # pass
+
+    def OnVoiceDetected(self, valid_frame: int, cache: dict = {}) -> None:
         cache["stats"].latest_confirmed_speech_frame = valid_frame
         self.PopDataToOutputBuf(valid_frame, 1, False, False, False, cache=cache)
 
-    def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache:dict={}) -> None:
+    def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache: dict = {}) -> None:
         if self.vad_opts.do_start_point_detection:
             pass
         if cache["stats"].confirmed_start_frame != -1:
@@ -397,7 +410,7 @@ class FsmnVADStreaming(nn.Module):
         if not fake_result and cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
             self.PopDataToOutputBuf(cache["stats"].confirmed_start_frame, 1, True, False, False, cache=cache)
 
-    def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool, cache:dict={}) -> None:
+    def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool, cache: dict = {}) -> None:
         for t in range(cache["stats"].latest_confirmed_speech_frame + 1, end_frame):
             self.OnVoiceDetected(t, cache=cache)
         if self.vad_opts.do_end_point_detection:
@@ -487,7 +500,8 @@ class FsmnVADStreaming(nn.Module):
         segment_batch = []
         if len(cache["stats"].output_data_buf) > 0:
             for i in range(cache["stats"].output_data_buf_offset, len(cache["stats"].output_data_buf)):
-                if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point or not cache["stats"].output_data_buf[
-                    i].contain_seg_end_point):
+                if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point or not
+                cache["stats"].output_data_buf[
+                    i].contain_seg_end_point):
                     continue
                 segment = [cache["stats"].output_data_buf[i].start_ms, cache["stats"].output_data_buf[i].end_ms]
@@ -499,9 +513,9 @@ class FsmnVADStreaming(nn.Module):
         # # reset class variables and clear the dict for the next query
         # self.AllResetDetection()
         return segments
 
     def init_cache(self, cache: dict = {}, **kwargs):
 
         cache["frontend"] = {}
         cache["prev_samples"] = torch.empty(0)
         cache["encoder"] = {}
@@ -528,12 +542,12 @@ class FsmnVADStreaming(nn.Module):
                   cache: dict = {},
                   **kwargs,
                   ):
 
         if len(cache) == 0:
             self.init_cache(cache, **kwargs)
 
         meta_data = {}
         chunk_size = kwargs.get("chunk_size", 60000)  # 50ms
         chunk_stride_samples = int(chunk_size * frontend.fs / 1000)
 
         time1 = time.perf_counter()
@@ -580,7 +594,6 @@ class FsmnVADStreaming(nn.Module):
                 if len(segments_i) > 0:
                     segments.extend(*segments_i)
 
-
         cache["prev_samples"] = audio_sample[:-m]
         if _is_final:
             self.init_cache(cache)
@@ -600,16 +613,15 @@ class FsmnVADStreaming(nn.Module):
             if ibest_writer is not None:
                 ibest_writer["text"][key[0]] = segments
 
-
         return results, meta_data
 
     def DetectCommonFrames(self, cache: dict = {}) -> int:
         if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
             return 0
         for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
             frame_state = FrameState.kFrameStateInvalid
-            frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache)
+            frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames,
+                                             cache=cache)
             self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
 
         return 0
@@ -619,7 +631,8 @@ class FsmnVADStreaming(nn.Module):
             return 0
         for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
             frame_state = FrameState.kFrameStateInvalid
-            frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache)
+            frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames,
+                                             cache=cache)
             if i != 0:
                 self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
             else:
@@ -627,7 +640,8 @@ class FsmnVADStreaming(nn.Module):
 
         return 0
 
-    def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool, cache: dict = {}) -> None:
+    def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool,
+                       cache: dict = {}) -> None:
         tmp_cur_frm_state = FrameState.kFrameStateInvalid
         if cur_frm_state == FrameState.kFrameStateSpeech:
             if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
@@ -644,7 +658,8 @@ class FsmnVADStreaming(nn.Module):
                 cache["stats"].pre_end_silence_detected = False
                 start_frame = 0
                 if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
-                    start_frame = max(cache["stats"].data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache))
+                    start_frame = max(cache["stats"].data_buf_start_frame,
+                                      cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache))
                     self.OnVoiceStart(start_frame, cache=cache)
                     cache["stats"].vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
                     for t in range(start_frame + 1, cur_frm_idx + 1):
@@ -696,7 +711,8 @@ class FsmnVADStreaming(nn.Module):
         if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
             # silence timeout, return zero length decision
             if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and (
-                    cache["stats"].continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
+                    cache[
+                        "stats"].continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
                     or (is_final_frame and cache["stats"].number_end_time_detected == 0):
                 for t in range(cache["stats"].lastest_confirmed_silence_frame + 1, cur_frm_idx):
                     self.OnSilenceDetected(t, cache=cache)
@@ -707,7 +723,8 @@ class FsmnVADStreaming(nn.Module):
             if cur_frm_idx >= self.LatencyFrmNumAtStartPoint(cache=cache):
                 self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache), cache=cache)
         elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
-            if cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= cache["stats"].max_end_sil_frame_cnt_thresh:
+            if cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= cache[
+                "stats"].max_end_sil_frame_cnt_thresh:
                 lookback_frame = int(cache["stats"].max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
                 if self.vad_opts.do_extend:
                     lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms)
@@ -733,4 +750,3 @@ class FsmnVADStreaming(nn.Module):
             self.ResetDetection(cache=cache)
 
 
-
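Note: taken together, the FsmnVADStreaming hunks above keep all per-stream state in a caller-owned cache dict (re-initialized when is_final is seen) and split overlong lines. A rough usage sketch of the streaming VAD through the funasr-1.x AutoModel front door, modeled on the FunASR README; the model name, chunking, and result fields are illustrative assumptions:

    from funasr import AutoModel

    chunk_size = 200  # ms per chunk, illustrative
    model = AutoModel(model="fsmn-vad")

    cache = {}  # owned by the caller; FsmnVADStreaming.init_cache fills it on first use
    for i, speech_chunk in enumerate(chunks):   # chunks: successive waveform slices
        is_final = i == len(chunks) - 1
        res = model.generate(input=speech_chunk, cache=cache,
                             is_final=is_final, chunk_size=chunk_size)
        # res carries [start_ms, end_ms] segments as they are confirmed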
@@ -11,7 +11,7 @@ from typing import Union
 
 import torch
 
-from funasr.metrics import end_detect
+from funasr.metrics.common import end_detect
 from funasr.models.transformer.scorers.scorer_interface import PartialScorerInterface
 from funasr.models.transformer.scorers.scorer_interface import ScorerInterface
 
@@ -494,3 +494,468 @@ class BeamSearchScama(torch.nn.Module):
             else:
                 remained_hyps.append(hyp)
         return remained_hyps
+
+
+class BeamSearchScamaStreaming(torch.nn.Module):
+    """Beam search implementation."""
+
+    def __init__(
+            self,
+            scorers: Dict[str, ScorerInterface],
+            weights: Dict[str, float],
+            beam_size: int,
+            vocab_size: int,
+            sos: int,
+            eos: int,
+            token_list: List[str] = None,
+            pre_beam_ratio: float = 1.5,
+            pre_beam_score_key: str = None,
+    ):
+        """Initialize beam search.
+
+        Args:
+            scorers (dict[str, ScorerInterface]): Dict of decoder modules
+                e.g., Decoder, CTCPrefixScorer, LM
+                The scorer will be ignored if it is `None`
+            weights (dict[str, float]): Dict of weights for each scorers
+                The scorer will be ignored if its weight is 0
+            beam_size (int): The number of hypotheses kept during search
+            vocab_size (int): The number of vocabulary
+            sos (int): Start of sequence id
+            eos (int): End of sequence id
+            token_list (list[str]): List of tokens for debug log
+            pre_beam_score_key (str): key of scores to perform pre-beam search
+            pre_beam_ratio (float): beam size in the pre-beam search
+                will be `int(pre_beam_ratio * beam_size)`
+
+        """
+        super().__init__()
+        # set scorers
+        self.weights = weights
+        self.scorers = dict()
+        self.full_scorers = dict()
+        self.part_scorers = dict()
+        # this module dict is required for recursive cast
+        # `self.to(device, dtype)` in `recog.py`
+        self.nn_dict = torch.nn.ModuleDict()
+        for k, v in scorers.items():
+            w = weights.get(k, 0)
+            if w == 0 or v is None:
+                continue
+            assert isinstance(
+                v, ScorerInterface
+            ), f"{k} ({type(v)}) does not implement ScorerInterface"
+            self.scorers[k] = v
+            if isinstance(v, PartialScorerInterface):
+                self.part_scorers[k] = v
+            else:
+                self.full_scorers[k] = v
+            if isinstance(v, torch.nn.Module):
+                self.nn_dict[k] = v
+
+        # set configurations
+        self.sos = sos
+        self.eos = eos
+        self.token_list = token_list
+        self.pre_beam_size = int(pre_beam_ratio * beam_size)
+        self.beam_size = beam_size
+        self.n_vocab = vocab_size
+        if (
+                pre_beam_score_key is not None
+                and pre_beam_score_key != "full"
+                and pre_beam_score_key not in self.full_scorers
+        ):
+            raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
+        self.pre_beam_score_key = pre_beam_score_key
+        self.do_pre_beam = (
+                self.pre_beam_score_key is not None
+                and self.pre_beam_size < self.n_vocab
+                and len(self.part_scorers) > 0
+        )
+
+    def init_hyp(self, x) -> List[Hypothesis]:
+        """Get an initial hypothesis data.
+
+        Args:
+            x (torch.Tensor): The encoder output feature
+
+        Returns:
+            Hypothesis: The initial hypothesis.
+
+        """
+        init_states = dict()
+        init_scores = dict()
+        for k, d in self.scorers.items():
+            init_states[k] = d.init_state(x)
+            init_scores[k] = 0.0
+        return [
+            Hypothesis(
+                score=0.0,
+                scores=init_scores,
+                states=init_states,
+                yseq=torch.tensor([self.sos], device=x.device),
+            )
+        ]
+
+    @staticmethod
+    def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
+        """Append new token to prefix tokens.
+
+        Args:
+            xs (torch.Tensor): The prefix token
+            x (int): The new token to append
+
+        Returns:
+            torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device
+
+        """
+        x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
+        return torch.cat((xs, x))
+
+    def score_full(
+            self, hyp: Hypothesis,
+            x: torch.Tensor,
+            x_mask: torch.Tensor = None,
+            pre_acoustic_embeds: torch.Tensor = None,
+            cache: dict={},
+    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
+        """Score new hypothesis by `self.full_scorers`.
+
+        Args:
+            hyp (Hypothesis): Hypothesis with prefix tokens to score
+            x (torch.Tensor): Corresponding input feature
+
+        Returns:
+            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
+                score dict of `hyp` that has string keys of `self.full_scorers`
+                and tensor score values of shape: `(self.n_vocab,)`,
+                and state dict that has string keys
+                and state values of `self.full_scorers`
+
+        """
+        scores = dict()
+        states = dict()
+        for k, d in self.full_scorers.items():
+            scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds, cache=cache)
+        return scores, states
+
+    def score_partial(
+            self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
+    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
+        """Score new hypothesis by `self.part_scorers`.
+
+        Args:
+            hyp (Hypothesis): Hypothesis with prefix tokens to score
+            ids (torch.Tensor): 1D tensor of new partial tokens to score
+            x (torch.Tensor): Corresponding input feature
+
+        Returns:
+            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
+                score dict of `hyp` that has string keys of `self.part_scorers`
+                and tensor score values of shape: `(len(ids),)`,
+                and state dict that has string keys
+                and state values of `self.part_scorers`
+
+        """
+        scores = dict()
+        states = dict()
+        for k, d in self.part_scorers.items():
+            scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x)
+        return scores, states
+
+    def beam(
+            self, weighted_scores: torch.Tensor, ids: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Compute topk full token ids and partial token ids.
+
+        Args:
+            weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
+                Its shape is `(self.n_vocab,)`.
+            ids (torch.Tensor): The partial token ids to compute topk
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]:
+                The topk full token ids and partial token ids.
+                Their shapes are `(self.beam_size,)`
+
+        """
+        # no pre beam performed
+        if weighted_scores.size(0) == ids.size(0):
+            top_ids = weighted_scores.topk(self.beam_size)[1]
+            return top_ids, top_ids
+
+        # mask pruned in pre-beam not to select in topk
+        tmp = weighted_scores[ids]
+        weighted_scores[:] = -float("inf")
+        weighted_scores[ids] = tmp
+        top_ids = weighted_scores.topk(self.beam_size)[1]
+        local_ids = weighted_scores[ids].topk(self.beam_size)[1]
+        return top_ids, local_ids
+
+    @staticmethod
+    def merge_scores(
+            prev_scores: Dict[str, float],
+            next_full_scores: Dict[str, torch.Tensor],
+            full_idx: int,
+            next_part_scores: Dict[str, torch.Tensor],
+            part_idx: int,
+    ) -> Dict[str, torch.Tensor]:
+        """Merge scores for new hypothesis.
+
+        Args:
+            prev_scores (Dict[str, float]):
+                The previous hypothesis scores by `self.scorers`
+            next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
+            full_idx (int): The next token id for `next_full_scores`
+            next_part_scores (Dict[str, torch.Tensor]):
+                scores of partial tokens by `self.part_scorers`
+            part_idx (int): The new token id for `next_part_scores`
+
+        Returns:
+            Dict[str, torch.Tensor]: The new score dict.
+                Its keys are names of `self.full_scorers` and `self.part_scorers`.
+                Its values are scalar tensors by the scorers.
+
+        """
+        new_scores = dict()
+        for k, v in next_full_scores.items():
+            new_scores[k] = prev_scores[k] + v[full_idx]
+        for k, v in next_part_scores.items():
+            new_scores[k] = prev_scores[k] + v[part_idx]
+        return new_scores
+
+    def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
+        """Merge states for new hypothesis.
+
+        Args:
+            states: states of `self.full_scorers`
+            part_states: states of `self.part_scorers`
+            part_idx (int): The new token id for `part_scores`
+
+        Returns:
+            Dict[str, torch.Tensor]: The new score dict.
+                Its keys are names of `self.full_scorers` and `self.part_scorers`.
+                Its values are states of the scorers.
+
+        """
+        new_states = dict()
+        for k, v in states.items():
+            new_states[k] = v
+        for k, d in self.part_scorers.items():
+            new_states[k] = d.select_state(part_states[k], part_idx)
+        return new_states
+
+    def search(
+            self, running_hyps: List[Hypothesis],
+            x: torch.Tensor,
+            x_mask: torch.Tensor = None,
+            pre_acoustic_embeds: torch.Tensor = None,
+            cache: dict={},
+    ) -> List[Hypothesis]:
+        """Search new tokens for running hypotheses and encoded speech x.
+
+        Args:
+            running_hyps (List[Hypothesis]): Running hypotheses on beam
+            x (torch.Tensor): Encoded speech feature (T, D)
+
+        Returns:
+            List[Hypotheses]: Best sorted hypotheses
+
+        """
+        best_hyps = []
+        part_ids = torch.arange(self.n_vocab, device=x.device)  # no pre-beam
+        for hyp in running_hyps:
+            # scoring
+            weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
+            scores, states = self.score_full(hyp, x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds, cache=cache)
+            for k in self.full_scorers:
+                weighted_scores += self.weights[k] * scores[k]
+            # partial scoring
+            if self.do_pre_beam:
+                pre_beam_scores = (
+                    weighted_scores
+                    if self.pre_beam_score_key == "full"
+                    else scores[self.pre_beam_score_key]
+                )
+                part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
+            part_scores, part_states = self.score_partial(hyp, part_ids, x)
+            for k in self.part_scorers:
+                weighted_scores[part_ids] += self.weights[k] * part_scores[k]
+            # add previous hyp score
+            weighted_scores += hyp.score
+
+            # update hyps
+            for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
+                # will be (2 x beam at most)
+                best_hyps.append(
+                    Hypothesis(
+                        score=weighted_scores[j],
+                        yseq=self.append_token(hyp.yseq, j),
+                        scores=self.merge_scores(
+                            hyp.scores, scores, j, part_scores, part_j
+                        ),
+                        states=self.merge_states(states, part_states, part_j),
+                    )
+                )
+
+            # sort and prune 2 x beam -> beam
+            best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
+                : min(len(best_hyps), self.beam_size)
+            ]
+        return best_hyps
+
+    def forward(
+            self, x: torch.Tensor,
+            scama_mask: torch.Tensor = None,
+            pre_acoustic_embeds: torch.Tensor = None,
+            maxlenratio: float = 0.0,
+            minlenratio: float = 0.0,
+            maxlen: int = None,
+            minlen: int = 0,
+            cache:dict={},
+    ) -> List[Hypothesis]:
+        """Perform beam search.
+
+        Args:
+            x (torch.Tensor): Encoded speech feature (T, D)
+            maxlenratio (float): Input length ratio to obtain max output length.
+                If maxlenratio=0.0 (default), it uses a end-detect function
+                to automatically find maximum hypothesis lengths
+                If maxlenratio<0.0, its absolute value is interpreted
+                as a constant max output length.
+            minlenratio (float): Input length ratio to obtain min output length.
+
+        Returns:
+            list[Hypothesis]: N-best decoding results
+
+        """
+        if maxlen is None:
+            # set length bounds
+            if maxlenratio == 0:
+                maxlen = x.shape[0]
+            elif maxlenratio < 0:
+                maxlen = -1 * int(maxlenratio)
+            else:
+                maxlen = max(1, int(maxlenratio * x.size(0)))
+            minlen = int(minlenratio * x.size(0))
+
+        logging.info("decoder input length: " + str(x.shape[0]))
+        logging.info("max output length: " + str(maxlen))
+        logging.info("min output length: " + str(minlen))
+
+        # main loop of prefix search
+        # running_hyps = self.init_hyp(x)
+        running_hyps = cache["running_hyps"]
+        ended_hyps = []
+        for i in range(maxlen):
+            logging.debug("position " + str(i))
+            mask_enc = None
+            # if scama_mask is not None:
+            #     token_num_predictor = scama_mask.size(1)
+            #     token_id_slice = min(i, token_num_predictor-1)
+            #     mask_enc = scama_mask[:, token_id_slice:token_id_slice+1, :]
+            #     # if mask_enc.size(1) == 0:
+            #     #     mask_enc = scama_mask[:, -2:-1, :]
+            #     #     # mask_enc = torch.zeros_like(mask_enc)
+            pre_acoustic_embeds_cur = None
+            if pre_acoustic_embeds is not None:
+                b, t, d = pre_acoustic_embeds.size()
+                pad = torch.zeros((b, 1, d), dtype=pre_acoustic_embeds.dtype).to(device=pre_acoustic_embeds.device)
+                pre_acoustic_embeds = torch.cat((pre_acoustic_embeds, pad), dim=1)
+                token_id_slice = min(i, t)
+                pre_acoustic_embeds_cur = pre_acoustic_embeds[:, token_id_slice:token_id_slice+1, :]
+
+            best = self.search(running_hyps, x, x_mask=mask_enc, pre_acoustic_embeds=pre_acoustic_embeds_cur, cache=cache["decoder"])
+            # post process of one iteration
+            running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
+            # end detection
+            if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
+                logging.info(f"end detected at {i}")
+                break
+            if len(running_hyps) == 0:
+                logging.info("no hypothesis. Finish decoding.")
+                break
+            else:
+                logging.debug(f"remained hypotheses: {len(running_hyps)}")
+
+        nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
+        # check the number of hypotheses reaching to eos
+        if len(nbest_hyps) == 0:
+            logging.warning(
+                "there is no N-best results, perform recognition "
+                "again with smaller minlenratio."
+            )
+            return (
+                []
+                if minlenratio < 0.1
+                else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
+            )
+
+        # report the best result
+        for x in nbest_hyps:
+            yseq = "".join([self.token_list[x] for x in x.yseq])
+            logging.debug("nbest: y: {}, yseq: {}, score: {}".format(x.yseq, yseq, x.score))
+        best = nbest_hyps[0]
+        for k, v in best.scores.items():
+            logging.info(
+                f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
+            )
+        logging.info(f"total log probability: {best.score:.2f}")
+        logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
+        logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
+        if self.token_list is not None:
+            logging.info(
+                "best hypo: "
+                + "".join([self.token_list[x] for x in best.yseq[1:-1]])
+                + "\n"
+            )
+        return nbest_hyps
+
+    def post_process(
+            self,
+            i: int,
+            maxlen: int,
+            maxlenratio: float,
+            running_hyps: List[Hypothesis],
+            ended_hyps: List[Hypothesis],
+    ) -> List[Hypothesis]:
+        """Perform post-processing of beam search iterations.
+
+        Args:
+            i (int): The length of hypothesis tokens.
+            maxlen (int): The maximum length of tokens in beam search.
+            maxlenratio (int): The maximum length ratio in beam search.
+            running_hyps (List[Hypothesis]): The running hypotheses in beam search.
+            ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
+
+        Returns:
+            List[Hypothesis]: The new running hypotheses.
+
+        """
+        logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
+        if self.token_list is not None:
+            logging.debug(
+                "best hypo: "
+                + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
+            )
+        # add eos in the final loop to avoid that there are no ended hyps
+        if i == maxlen - 1:
+            logging.info("adding <eos> in the last position in the loop")
+            running_hyps = [
+                h._replace(yseq=self.append_token(h.yseq, self.eos))
+                for h in running_hyps
+            ]
+
+        # add ended hypotheses to a final list, and removed them from current hypotheses
+        # (this will be a problem, number of hyps < beam)
+        remained_hyps = []
+        for hyp in running_hyps:
+            if hyp.yseq[-1] == self.eos:
+                # e.g., Word LM needs to add final <eos> score
+                for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
+                    s = d.final_score(hyp.states[k])
+                    hyp.scores[k] += s
+                    hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
+                ended_hyps.append(hyp)
+            else:
+                remained_hyps.append(hyp)
+        return remained_hyps
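Note: the new class is essentially BeamSearchScama with the per-utterance state externalized. `forward` reads `running_hyps` from the caller-owned `cache` instead of calling `init_hyp` itself, and `search`/`score_full` thread the same cache into the decoder, so chunk N+1 resumes exactly where chunk N stopped. A condensed sketch of the intended calling pattern; the driver names (`stream`, `token_num`) are illustrative, and the real driver is `SCAMA.inference` in a later hunk:

    # assumed: beam_search = BeamSearchScamaStreaming(...) built as in SCAMA.init_beam_search
    cache = {"decoder": {}}                       # decoder state persists across chunks

    for enc_chunk, acoustic_embeds, token_num in stream:
        if "running_hyps" not in cache:           # first chunk: seed hypotheses with <sos>
            cache["running_hyps"] = beam_search.init_hyp(enc_chunk)
        nbest = beam_search(
            x=enc_chunk,
            pre_acoustic_embeds=acoustic_embeds,  # one acoustic embedding per predicted token
            maxlen=int(token_num),                # token budget for this chunk
            minlen=0,
            cache=cache,                          # carries running_hyps + decoder cache
        )
        cache["running_hyps"] = nbest             # next chunk resumes from these hypotheses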
@@ -436,7 +436,10 @@ class SCAMA(nn.Module):
     def init_beam_search(self,
                          **kwargs,
                          ):
-        from funasr.models.scama.beam_search import BeamSearchScama
+        from funasr.models.scama.beam_search import BeamSearchScamaStreaming
+
+
         from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
         from funasr.models.transformer.scorers.length_bonus import LengthBonus
 
@@ -460,13 +463,14 @@ class SCAMA(nn.Module):
         scorers["ngram"] = ngram
 
         weights = dict(
-            decoder=1.0 - kwargs.get("decoding_ctc_weight"),
+            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.0),
             ctc=kwargs.get("decoding_ctc_weight", 0.0),
             lm=kwargs.get("lm_weight", 0.0),
             ngram=kwargs.get("ngram_weight", 0.0),
             length_bonus=kwargs.get("penalty", 0.0),
         )
-        beam_search = BeamSearchScama(
+        beam_search = BeamSearchScamaStreaming(
             beam_size=kwargs.get("beam_size", 2),
             weights=weights,
             scorers=scorers,
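Note: the `decoding_ctc_weight` fix matters because dict.get without a default returns None when the key is absent, and arithmetic on None raises before the search is even built. A minimal illustration:

    kwargs = {}
    # old: kwargs.get("decoding_ctc_weight") -> None, so `1.0 - None` raises TypeError
    weight = 1.0 - kwargs.get("decoding_ctc_weight", 0.0)
    assert weight == 1.0   # new: pure-decoder weighting by default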
@@ -499,7 +503,11 @@ class SCAMA(nn.Module):
                                            is_final=kwargs.get("is_final", False))
         if isinstance(encoder_out, tuple):
             encoder_out = encoder_out[0]
+        if "running_hyps" not in cache:
+            running_hyps = self.beam_search.init_hyp(encoder_out)
+            cache["running_hyps"] = running_hyps
+
 
         # predictor
         predictor_outs = self.calc_predictor_chunk(encoder_out,
                                                    encoder_out_lens,
@@ -513,47 +521,30 @@ class SCAMA(nn.Module):
 
         if torch.max(pre_token_length) < 1:
             return []
-        decoder_outs = self.cal_decoder_with_predictor_chunk(encoder_out,
-                                                             encoder_out_lens,
-                                                             pre_acoustic_embeds,
-                                                             pre_token_length,
-                                                             cache=cache
-                                                             )
-        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+        maxlen = minlen = pre_token_length
+        if kwargs.get("is_final", False):
+            maxlen += kwargs.get("token_num_relax", 5)
+            minlen = max(0, minlen - kwargs.get("token_num_relax", 5))
+        # c. Passed the encoder result and the beam search
+        nbest_hyps = self.beam_search(
+            x=encoder_out[0], scama_mask=None, pre_acoustic_embeds=pre_acoustic_embeds, maxlen=int(maxlen), minlen=int(minlen), cache=cache,
+        )
+
+        cache["running_hyps"] = nbest_hyps
+        nbest_hyps = nbest_hyps[: self.nbest]
 
         results = []
-        b, n, d = decoder_out.size()
-        if isinstance(key[0], (list, tuple)):
-            key = key[0]
-        for i in range(b):
-            x = encoder_out[i, :encoder_out_lens[i], :]
-            am_scores = decoder_out[i, :pre_token_length[i], :]
-            if self.beam_search is not None:
-                nbest_hyps = self.beam_search(
-                    x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
-                    minlenratio=kwargs.get("minlenratio", 0.0)
-                )
-
-                nbest_hyps = nbest_hyps[: self.nbest]
-            else:
-
-                yseq = am_scores.argmax(dim=-1)
-                score = am_scores.max(dim=-1)[0]
-                score = torch.sum(score, dim=-1)
-                # pad with mask tokens to ensure compatibility with sos/eos tokens
-                yseq = torch.tensor(
-                    [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
-                )
-                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
-        for nbest_idx, hyp in enumerate(nbest_hyps):
-
-            # remove sos/eos and get results
-            last_pos = -1
-            if isinstance(hyp.yseq, list):
-                token_int = hyp.yseq[1:last_pos]
-            else:
-                token_int = hyp.yseq[1:last_pos].tolist()
+        for hyp in nbest_hyps:
+            # assert isinstance(hyp, (Hypothesis)), type(hyp)
+
+            # remove sos/eos and get results
+            last_pos = -1
+            if isinstance(hyp.yseq, list):
+                token_int = hyp.yseq[1:last_pos]
+            else:
+                token_int = hyp.yseq[1:last_pos].tolist()
 
             # remove blank symbol id, which is assumed to be 0
             token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
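Note: the rewritten `inference` no longer argmax-decodes AM scores; it drives the streaming beam search with a token budget from the predictor, widened by `token_num_relax` on the final chunk so the search may emit a few tokens more or fewer than the predictor estimated. With the default relax of 5, for example:

    pre_token_length = 12                                # predictor estimate, last chunk
    token_num_relax = 5                                  # default used by this hunk
    maxlen = pre_token_length + token_num_relax          # 17: allow a few extra tokens
    minlen = max(0, pre_token_length - token_num_relax)  # 7: or a few fewer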
@@ -568,6 +559,8 @@ class SCAMA(nn.Module):
         return results
 
     def init_cache(self, cache: dict = {}, **kwargs):
+        device = kwargs.get("device", "cuda")
+
         chunk_size = kwargs.get("chunk_size", [0, 10, 5])
         encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
         decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
@@ -575,10 +568,11 @@ class SCAMA(nn.Module):
 
         enc_output_size = kwargs["encoder_conf"]["output_size"]
         feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
-        cache_encoder = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
-                         "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size,
+        cache_encoder = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)).to(device=device),
+                         "cif_alphas": torch.zeros((batch_size, 1)).to(device=device), "chunk_size": chunk_size,
                          "encoder_chunk_look_back": encoder_chunk_look_back, "last_chunk": False, "opt": None,
-                         "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
+                         "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)).to(device=device),
                          "tail_chunk": False}
         cache["encoder"] = cache_encoder
 
@@ -586,8 +580,10 @@ class SCAMA(nn.Module):
                          "chunk_size": chunk_size}
         cache["decoder"] = cache_decoder
         cache["frontend"] = {}
-        cache["prev_samples"] = torch.empty(0)
+
+        cache["prev_samples"] = torch.empty(0).to(device=device)
+
         return cache
 
     def inference(self,
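Note: these init_cache hunks thread a `device` kwarg through so the persistent streaming caches start on the same device as the model instead of defaulting to CPU. That avoids a host-to-device copy on the first chunk, and it avoids outright failure when a later torch.cat mixes a CPU-resident cache tensor with CUDA activations. The pattern in isolation, with illustrative sizes:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"   # as passed via kwargs["device"]
    cif_hidden = torch.zeros((1, 1, 512)).to(device=device)   # sizes illustrative
    prev_samples = torch.empty(0).to(device=device)
    # torch.cat((cif_hidden, cuda_activations), dim=1) now stays on one device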
@@ -603,7 +599,10 @@ class SCAMA(nn.Module):
         # init beamsearch
         is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
         is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
-        if self.beam_search is None and (is_use_lm or is_use_ctc):
+
+        if self.beam_search is None:
+
+
             logging.info("enable beam_search")
             self.init_beam_search(**kwargs)
             self.nbest = kwargs.get("nbest", 1)
@@ -148,6 +148,7 @@ class Trainer:
 
             self._train_epoch(epoch)
 
+
             if self.use_ddp or self.use_fsdp:
                 dist.barrier()
@@ -156,8 +157,8 @@ class Trainer:
 
             if self.use_ddp or self.use_fsdp:
                 dist.barrier()
 
             if self.rank == 0:
                 self._save_checkpoint(epoch)
 
@@ -172,7 +173,8 @@ class Trainer:
 
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
 
+
         if self.writer:
             self.writer.close()
 
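Note: the Trainer hunks keep `dist.barrier()` synchronization points at epoch boundaries so that, under DDP or FSDP, non-zero ranks wait while rank 0 writes the checkpoint. A skeletal sketch of that loop; the method names follow the hunks, everything else is an assumed harness:

    import torch.distributed as dist

    for epoch in range(start_epoch, max_epoch):
        trainer._train_epoch(epoch)

        if use_ddp or use_fsdp:
            dist.barrier()          # all ranks finish the epoch before saving

        if rank == 0:
            trainer._save_checkpoint(epoch)   # only rank 0 touches the filesystem

        if use_ddp or use_fsdp:
            dist.barrier()          # other ranks wait until the checkpoint is complete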