diff --git a/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
index 01c6c39fb..6761a807a 100644
--- a/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
+++ b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
@@ -4,7 +4,8 @@
 # MIT License (https://opensource.org/licenses/MIT)
 
 from funasr import AutoModel
 
-wav_file = "/Users/zhifu/funasr_github/test_local/asr_example.wav"
+wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
+
 chunk_size = 60000 # ms
 model = AutoModel(model="/Users/zhifu/Downloads/modelscope_models/speech_fsmn_vad_zh-cn-16k-common-streaming", model_revision="v2.0.0")
@@ -14,23 +15,23 @@ res = model(input=wav_file,
 print(res)
 
-#
-# import soundfile
-# import os
-#
-# # wav_file = os.path.join(model.model_path, "example/vad_example.wav")
-# speech, sample_rate = soundfile.read(wav_file)
-#
-# chunk_stride = int(chunk_size * 16000 / 1000)
-#
-# cache = {}
-#
-# for i in range(int(len((speech)-1)/chunk_stride+1)):
-#     speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
-#     is_final = i == int(len((speech)-1)/chunk_stride+1)
-#     res = model(input=speech_chunk,
-#                 cache=cache,
-#                 is_final=is_final,
-#                 chunk_size=chunk_size,
-#                 )
-#     print(res)
+
+import soundfile
+import os
+
+wav_file = os.path.join(model.model_path, "example/vad_example.wav")
+speech, sample_rate = soundfile.read(wav_file)
+
+chunk_stride = int(chunk_size * 16000 / 1000)
+
+cache = {}
+
+for i in range(int(len((speech)-1)/chunk_stride+1)):
+    speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
+    is_final = i == int(len((speech)-1)/chunk_stride+1) - 1
+    res = model(input=speech_chunk,
+                cache=cache,
+                is_final=is_final,
+                chunk_size=chunk_size,
+                )
+    print(res)
diff --git a/examples/industrial_data_pretraining/paraformer_streaming/demo.py b/examples/industrial_data_pretraining/paraformer_streaming/demo.py
index d4dd34ed3..872843b8e 100644
--- a/examples/industrial_data_pretraining/paraformer_streaming/demo.py
+++ b/examples/industrial_data_pretraining/paraformer_streaming/demo.py
@@ -31,7 +31,7 @@ cache = {}
 
 for i in range(int(len((speech)-1)/chunk_stride+1)):
     speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
-    is_final = i == int(len((speech)-1)/chunk_stride+1)
+    is_final = i == int(len((speech)-1)/chunk_stride+1) - 1
     res = model(input=speech_chunk,
                 cache=cache,
                 is_final=is_final,
diff --git a/funasr/frontends/wav_frontend.py b/funasr/frontends/wav_frontend.py
index fe223357f..f4100859e 100644
--- a/funasr/frontends/wav_frontend.py
+++ b/funasr/frontends/wav_frontend.py
@@ -448,8 +448,8 @@ class WavFrontendOnline(nn.Module):
             feats = torch.stack(cache["lfr_splice_cache"])
             feats_lengths = torch.zeros(batch_size, dtype=torch.int) + feats.shape[1]
             feats, feats_lengths, _ = self.forward_lfr_cmvn(feats, feats_lengths, is_final, cache=cache)
-        if is_final:
-            self.init_cache(cache)
+        # if is_final:
+        #     self.init_cache(cache)
 
         return feats, feats_lengths
diff --git a/funasr/models/fsmn_vad/model.py b/funasr/models/fsmn_vad/model.py
index 75c6f4a93..15d2af52e 100644
--- a/funasr/models/fsmn_vad/model.py
+++ b/funasr/models/fsmn_vad/model.py
@@ -12,6 +12,7 @@ from funasr.register import tables
 from funasr.utils.load_utils import load_audio_text_image_video,extract_fbank
 from funasr.utils.datadir_writer import DatadirWriter
 from torch.nn.utils.rnn import pad_sequence
+from funasr.train_utils.device_funcs import to_device
 
 class VadStateMachine(Enum):
     kVadInStateStartPointNotDetected = 1
@@ -579,7 +580,8 @@ class FsmnVAD(nn.Module):
                      "cache": cache
                      }
-
+
+            batch = to_device(batch, device=kwargs["device"])
             segments_part, cache = self.forward(**batch)
             if segments_part:
                 for batch_num in range(0, batch_size):
diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py
index 13b3f3a4c..e0d104a51 100644
--- a/funasr/models/fsmn_vad_streaming/model.py
+++ b/funasr/models/fsmn_vad_streaming/model.py
@@ -587,7 +587,6 @@ class FsmnVADStreaming(nn.Module):
                      "cache": cache["encoder"]
                      }
             segments_i = self.forward(**batch)
-            print(segments_i)
             segments.extend(segments_i)
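
Note on the is_final change shared by both demo.py hunks: the loop visits i = 0 .. n-1 for n = int(len((speech)-1)/chunk_stride+1) chunks, so the old comparison i == n could never be true and the models never received an end-of-stream signal on the last chunk. (Incidentally, len((speech)-1) parses as len(speech - 1), which for a NumPy array equals len(speech); it was presumably meant as len(speech) - 1, though both give the same chunk count except when chunk_stride divides len(speech) exactly.) Below is a minimal sketch of the corrected loop; it assumes the model and chunk_size objects from the demo above and a 16 kHz mono input, and the wav path is a placeholder, not a file shipped with the repo:

import soundfile

speech, sample_rate = soundfile.read("vad_example.wav")  # placeholder path; 16 kHz mono assumed
chunk_stride = int(chunk_size * 16000 / 1000)            # chunk_size is in ms, audio is 16 kHz

total_chunk_num = int((len(speech) - 1) / chunk_stride) + 1
cache = {}
for i in range(total_chunk_num):
    speech_chunk = speech[i * chunk_stride:(i + 1) * chunk_stride]
    is_final = i == total_chunk_num - 1  # true only on the last chunk
    res = model(input=speech_chunk,
                cache=cache,
                is_final=is_final,
                chunk_size=chunk_size)
    print(res)

Worked numbers: with the demo's chunk_size of 60000 ms and a 70 s clip at 16 kHz, len(speech) = 1120000 and chunk_stride = 960000, so total_chunk_num = 2 and the loop runs i = 0, 1; the old code compared i against 2 and never set is_final, while the fixed code flags i == 1 as final.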
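The remaining hunks are supporting changes: wav_frontend.py stops re-initializing the frontend cache when is_final is set (cache lifetime appears to be left to the caller now), fsmn_vad/model.py moves the assembled batch onto the model's device via to_device before calling forward (the device arrives through kwargs, as the hunk shows), and fsmn_vad_streaming/model.py drops a leftover debug print. As an illustration of the behavior the to_device call is relied on for, here is a minimal re-implementation sketch; it is not the FunASR source, which lives at funasr.train_utils.device_funcs:

import torch

def to_device_sketch(data, device):
    # Recursively move every tensor inside nested dicts/lists/tuples to `device`;
    # non-tensor leaves (ints, strings, cache objects) pass through unchanged.
    # Illustration only -- not FunASR's actual implementation.
    if isinstance(data, dict):
        return {k: to_device_sketch(v, device) for k, v in data.items()}
    if isinstance(data, list):
        return [to_device_sketch(v, device) for v in data]
    if isinstance(data, tuple):
        return tuple(to_device_sketch(v, device) for v in data)
    if isinstance(data, torch.Tensor):
        return data.to(device)
    return data

Without this step, a batch built on CPU would hit a device mismatch as soon as the model's weights live on GPU, which is presumably what the fsmn_vad/model.py hunk is fixing.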