From 40d1f80030d38b3377a95ead8837e82c67aa59f6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?=
Date: Fri, 12 Jan 2024 12:05:08 +0800
Subject: [PATCH] funasr1.0 streaming demo

---
 .../fsmn_vad_streaming/demo.py              | 31 ++++++-
 funasr/models/fsmn_vad/model.py             |  7 +-
 funasr/models/fsmn_vad_streaming/model.py   | 93 ++++++-------
 funasr/models/paraformer_streaming/model.py |  8 +-
 funasr/utils/load_utils.py                  |  3 +-
 5 files changed, 67 insertions(+), 75 deletions(-)

diff --git a/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
index 2a157ee23..01c6c39fb 100644
--- a/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
+++ b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
@@ -4,8 +4,33 @@
 # MIT License (https://opensource.org/licenses/MIT)
 
 from funasr import AutoModel
+wav_file = "/Users/zhifu/funasr_github/test_local/asr_example.wav"
+chunk_size = 60000 # ms
+model = AutoModel(model="/Users/zhifu/Downloads/modelscope_models/speech_fsmn_vad_zh-cn-16k-common-streaming", model_revision="v2.0.0")
 
-model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.0")
+res = model(input=wav_file,
+            chunk_size=chunk_size,
+            )
+print(res)
 
-res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav")
-print(res)
\ No newline at end of file
+
+
+#
+# import soundfile
+# import os
+#
+# # wav_file = os.path.join(model.model_path, "example/vad_example.wav")
+# speech, sample_rate = soundfile.read(wav_file)
+#
+# chunk_stride = int(chunk_size * 16000 / 1000)
+#
+# cache = {}
+#
+# for i in range(int(len((speech)-1)/chunk_stride+1)):
+#     speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
+#     is_final = i == int(len((speech)-1)/chunk_stride+1) - 1  # True only for the last chunk
+#     res = model(input=speech_chunk,
+#                 cache=cache,
+#                 is_final=is_final,
+#                 chunk_size=chunk_size,
+#                 )
+#     print(res)
diff --git a/funasr/models/fsmn_vad/model.py b/funasr/models/fsmn_vad/model.py
index 1ed077394..75c6f4a93 100644
--- a/funasr/models/fsmn_vad/model.py
+++ b/funasr/models/fsmn_vad/model.py
@@ -593,15 +593,16 @@ class FsmnVAD(nn.Module):
 
         results = []
         for i in range(batch_size):
-            if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
-                results[i] = json.dumps(results[i])
 
             if ibest_writer is not None:
                 ibest_writer["text"][key[i]] = segments[i]
 
             result_i = {"key": key[i], "value": segments[i]}
             results.append(result_i)
-
+
+            if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
+                results[i] = json.dumps(results[i])
+
         return results, meta_data
 
     def DetectCommonFrames(self) -> int:
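
Note: the commented-out block in demo.py above sketches chunk-by-chunk streaming; a runnable variant might look like the following. This is an illustrative sketch rather than part of the patch: the wav path and the 200 ms chunk size are placeholders, the model id is the one used by the pre-patch demo, and the loop assumes the input/cache/is_final/chunk_size keyword arguments shown elsewhere in this commit.

import soundfile

from funasr import AutoModel

# Placeholders: point wav_file at a local 16 kHz mono file; 200 ms is an arbitrary chunk size.
wav_file = "asr_example.wav"
chunk_size = 200  # ms of audio handed to the VAD per call

# Model id taken from the pre-patch demo.
model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.0")

speech, sample_rate = soundfile.read(wav_file)
chunk_stride = int(chunk_size * sample_rate / 1000)  # samples per chunk

cache = {}
total_chunk_num = int((len(speech) - 1) / chunk_stride + 1)
for i in range(total_chunk_num):
    speech_chunk = speech[i * chunk_stride:(i + 1) * chunk_stride]
    is_final = i == total_chunk_num - 1  # only the last chunk flushes the detector
    res = model(input=speech_chunk,
                cache=cache,
                is_final=is_final,
                chunk_size=chunk_size,
                )
    if res:
        print(res)

Passing is_final=True on the last chunk lets the VAD close any speech segment that is still open when the audio ends.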
cache["prev_samples"] = torch.empty(0) + cache["encoder"] = {} return cache + def generate(self, data_in, data_lengths=None, @@ -543,7 +545,7 @@ class FsmnVADStreaming(nn.Module): meta_data = {} chunk_size = kwargs.get("chunk_size", 50) # 50ms - chunk_stride_samples = chunk_size * 16 + chunk_stride_samples = int(chunk_size * frontend.fs / 1000) time1 = time.perf_counter() cfg = {"is_final": kwargs.get("is_final", False)} @@ -552,7 +554,7 @@ class FsmnVADStreaming(nn.Module): audio_fs=kwargs.get("fs", 16000), data_type=kwargs.get("data_type", "sound"), tokenizer=tokenizer, - **cfg, + cache=cfg, ) _is_final = cfg["is_final"] # if data_in is a file or url, set is_final=True @@ -562,9 +564,9 @@ class FsmnVADStreaming(nn.Module): audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0])) - n = len(audio_sample) // chunk_stride_samples + int(_is_final) - m = len(audio_sample) % chunk_stride_samples * (1 - int(_is_final)) - tokens = [] + n = int(len(audio_sample) // chunk_stride_samples + int(_is_final)) + m = int(len(audio_sample) % chunk_stride_samples * (1 - int(_is_final))) + segments = [] for i in range(n): kwargs["is_final"] = _is_final and i == n - 1 audio_sample_i = audio_sample[i * chunk_stride_samples:(i + 1) * chunk_stride_samples] @@ -576,58 +578,22 @@ class FsmnVADStreaming(nn.Module): time3 = time.perf_counter() meta_data["extract_feat"] = f"{time3 - time2:0.3f}" meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 - - meta_data = {} - audio_sample_list = [data_in] - if isinstance(data_in, torch.Tensor): # fbank - speech, speech_lengths = data_in, data_lengths - if len(speech.shape) < 3: - speech = speech[None, :, :] - if speech_lengths is None: - speech_lengths = speech.shape[1] - else: - # extract fbank feats - time1 = time.perf_counter() - audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000)) - time2 = time.perf_counter() - meta_data["load_data"] = f"{time2 - time1:0.3f}" - speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), - frontend=frontend) - time3 = time.perf_counter() - meta_data["extract_feat"] = f"{time3 - time2:0.3f}" - meta_data[ - "batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 - - speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"]) - - # b. 
diff --git a/funasr/models/paraformer_streaming/model.py b/funasr/models/paraformer_streaming/model.py
index fdc0c9312..b736aa9b4 100644
--- a/funasr/models/paraformer_streaming/model.py
+++ b/funasr/models/paraformer_streaming/model.py
@@ -523,7 +523,7 @@ class ParaformerStreaming(Paraformer):
 
         meta_data = {}
         chunk_size = kwargs.get("chunk_size", [0, 10, 5])
-        chunk_stride_samples = chunk_size[1] * 960  # 600ms
+        chunk_stride_samples = int(chunk_size[1] * 960)  # 600ms
 
         time1 = time.perf_counter()
         cfg = {"is_final": kwargs.get("is_final", False)}
@@ -532,7 +532,7 @@ class ParaformerStreaming(Paraformer):
                                                         audio_fs=kwargs.get("fs", 16000),
                                                         data_type=kwargs.get("data_type", "sound"),
                                                         tokenizer=tokenizer,
-                                                        **cfg,
+                                                        cache=cfg,
                                                         )
 
         _is_final = cfg["is_final"]  # if data_in is a file or url, set is_final=True
@@ -542,8 +542,8 @@ class ParaformerStreaming(Paraformer):
 
         audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))
 
-        n = len(audio_sample) // chunk_stride_samples + int(_is_final)
-        m = len(audio_sample) % chunk_stride_samples * (1-int(_is_final))
+        n = int(len(audio_sample) // chunk_stride_samples + int(_is_final))
+        m = int(len(audio_sample) % chunk_stride_samples * (1-int(_is_final)))
         tokens = []
         for i in range(n):
             kwargs["is_final"] = _is_final and i == n -1
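
Note: besides the cache=cfg change, the only functional edit in the streaming Paraformer is wrapping the chunk arithmetic in int(). The stride itself is unchanged: 960 samples is 60 ms at 16 kHz, so the default chunk_size = [0, 10, 5] advances by 10 x 60 ms = 600 ms of new audio per decoding step (the other entries configure context around the chunk and do not enter this stride computation). A quick check of that arithmetic:

chunk_size = [0, 10, 5]               # default from the diff above
samples_per_frame = 960               # 60 ms at 16 kHz
chunk_stride_samples = int(chunk_size[1] * samples_per_frame)
assert chunk_stride_samples == 9600   # 600 ms of audio consumed per streaming call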
diff --git a/funasr/utils/load_utils.py b/funasr/utils/load_utils.py
index 638e0ac4f..4e131a84e 100644
--- a/funasr/utils/load_utils.py
+++ b/funasr/utils/load_utils.py
@@ -48,7 +48,8 @@ def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs:
             pass
 
         # if data_in is a file or url, set is_final=True
-        kwargs["is_final"] = True
+        if "cache" in kwargs:
+            kwargs["cache"]["is_final"] = True
     elif isinstance(data_or_path_or_list, str) and data_type == "text" and tokenizer is not None:
         data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
     elif isinstance(data_or_path_or_list, np.ndarray):  # audio sample point
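
Note: this load_utils.py change pairs with the cache=cfg change in both streaming models. Previously the flag was written into the helper's own **kwargs dict, which the caller never sees, because keyword expansion copies the values into a fresh dict inside the callee. Passing the dict explicitly as cache=cfg lets the helper mutate the caller's object, so generate() can read cfg["is_final"] after the call and treat a file or URL input as a complete, final utterance. A tiny sketch of that round trip, using a hypothetical stand-in for the loader:

def load_input(data, cache=None):
    # Stand-in for load_audio_text_image_video: a path/URL means the whole
    # utterance is available, so the loader marks the request as final.
    if isinstance(data, str) and cache is not None:
        cache["is_final"] = True
    return [data]

cfg = {"is_final": False}                 # what the streaming generate() builds
audio_list = load_input("vad_example.wav", cache=cfg)
assert cfg["is_final"] is True            # the caller now observes the flag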