From a75bbb028e5966ddf02aae5bea05909be9a99826 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?=
Date: Thu, 11 Jan 2024 17:36:30 +0800
Subject: [PATCH] funasr1.0 paraformer_streaming

---
 .../paraformer_streaming/demo.py              | 50 ++++++-----
 .../paraformer_streaming/finetune.sh          | 14 ----
 .../paraformer_streaming/infer.sh             |  2 +-
 funasr/models/paraformer/cif_predictor.py     | 11 +--
 funasr/models/paraformer_streaming/model.py   | 82 +++++++++++--------
 funasr/models/scama/sanm_encoder.py           |  2 +
 funasr/utils/load_utils.py                    |  2 +-
 7 files changed, 85 insertions(+), 78 deletions(-)
 delete mode 100644 examples/industrial_data_pretraining/paraformer_streaming/finetune.sh

diff --git a/examples/industrial_data_pretraining/paraformer_streaming/demo.py b/examples/industrial_data_pretraining/paraformer_streaming/demo.py
index 0036e77e1..9923a0445 100644
--- a/examples/industrial_data_pretraining/paraformer_streaming/demo.py
+++ b/examples/industrial_data_pretraining/paraformer_streaming/demo.py
@@ -3,36 +3,44 @@
 # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 # MIT License (https://opensource.org/licenses/MIT)

-# from funasr import AutoModel
-#
-# model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revison="v2.0.0")
-#
-# res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
-# print(res)
+from funasr import AutoModel

+chunk_size = [0, 10, 5] #[0, 10, 5] 600ms, [0, 8, 4] 480ms
+encoder_chunk_look_back = 4 #number of chunks to lookback for encoder self-attention
+decoder_chunk_look_back = 1 #number of encoder chunks to lookback for decoder cross-attention

-from funasr import AutoFrontend
-
-frontend = AutoFrontend(model="/Users/zhifu/Downloads/modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", model_revison="v2.0.0")
-
+model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", model_revision="v2.0.0")

+cache = {}
+res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
+            cache=cache,
+            is_final=True,
+            chunk_size=chunk_size,
+            encoder_chunk_look_back=encoder_chunk_look_back,
+            decoder_chunk_look_back=decoder_chunk_look_back,
+            )
+print(res)

 import soundfile
-speech, sample_rate = soundfile.read("/Users/zhifu/Downloads/modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/example/asr_example.wav")
+import os
+
+speech, sample_rate = soundfile.read(os.path.expanduser('~')+
+                                     "/.cache/modelscope/hub/damo/"+
+                                     "speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/"+
+                                     "example/asr_example.wav")

-chunk_size = [0, 10, 5] #[0, 10, 5] 600ms, [0, 8, 4] 480ms
 chunk_stride = chunk_size[1] * 960 # 600ms、480ms

-# first chunk, 600ms
 cache = {}
 for i in range(int(len((speech)-1)/chunk_stride+1)):
     speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
-    fbanks = frontend(input=speech_chunk,
-                      batch_size=2,
-                      cache=cache)
-
-
-# for batch_idx, fbank_dict in enumerate(fbanks):
-#     res = model(**fbank_dict)
-#     print(res)
\ No newline at end of file
+    is_final = i == int(len((speech)-1)/chunk_stride+1) - 1
+    res = model(input=speech_chunk,
+                cache=cache,
+                is_final=is_final,
+                chunk_size=chunk_size,
+                encoder_chunk_look_back=encoder_chunk_look_back,
+                decoder_chunk_look_back=decoder_chunk_look_back,
+                )
+    print(res)
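Note on the numbers in the new demo: the streaming stride is expressed in encoder frames, and with 16 kHz input, a 10 ms fbank shift and LFR factor 6 (the same frontend values the patch multiplies in meta_data["batch_data_time"]), one encoder frame covers 960 samples. chunk_size = [0, 10, 5] therefore means 10 frames, i.e. 9600 samples or 600 ms per chunk, and [0, 8, 4] gives 480 ms. A quick check of that arithmetic, a sketch assuming those frontend values:

    sample_rate = 16000
    frame_shift_ms = 10                                   # fbank frame shift
    lfr_n = 6                                             # low-frame-rate subsampling factor
    samples_per_frame = sample_rate * frame_shift_ms * lfr_n // 1000   # 960

    chunk_size = [0, 10, 5]
    chunk_stride = chunk_size[1] * samples_per_frame      # 9600 samples per streaming chunk
    print(chunk_stride * 1000 // sample_rate)             # 600 ms of audio per chunk

    encoder_chunk_look_back = 4
    print(encoder_chunk_look_back * chunk_stride * 1000 // sample_rate)  # 2400 ms of encoder left context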
diff --git a/examples/industrial_data_pretraining/paraformer_streaming/finetune.sh b/examples/industrial_data_pretraining/paraformer_streaming/finetune.sh
deleted file mode 100644
index 6dca09f83..000000000
--- a/examples/industrial_data_pretraining/paraformer_streaming/finetune.sh
+++ /dev/null
@@ -1,14 +0,0 @@
-
-# download model
-local_path_root=../modelscope_models
-mkdir -p ${local_path_root}
-local_path=${local_path_root}/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
-git clone https://www.modelscope.cn/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch.git ${local_path}
-
-
-python funasr/bin/train.py \
-+model="../modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
-+token_list="../modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/tokens.txt" \
-+train_data_set_list="data/list/audio_datasets.jsonl" \
-+output_dir="outputs/debug/ckpt/funasr2/exp2" \
-+device="cpu"
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/paraformer_streaming/infer.sh b/examples/industrial_data_pretraining/paraformer_streaming/infer.sh
index 9436628b7..77e839b66 100644
--- a/examples/industrial_data_pretraining/paraformer_streaming/infer.sh
+++ b/examples/industrial_data_pretraining/paraformer_streaming/infer.sh
@@ -1,5 +1,5 @@
-model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"
 model_revision="v2.0.0"

 python funasr/bin/inference.py \
diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py
index 383d9ca7c..b06fa43eb 100644
--- a/funasr/models/paraformer/cif_predictor.py
+++ b/funasr/models/paraformer/cif_predictor.py
@@ -205,7 +205,8 @@ class CifPredictorV2(nn.Module):

         return acoustic_embeds, token_num, alphas, cif_peak

-    def forward_chunk(self, hidden, cache=None):
+    def forward_chunk(self, hidden, cache=None, **kwargs):
+        is_final = kwargs.get("is_final", False)
         batch_size, len_time, hidden_size = hidden.shape
         h = hidden
         context = h.transpose(1, 2)
@@ -226,14 +227,14 @@ class CifPredictorV2(nn.Module):

         if cache is not None and "chunk_size" in cache:
             alphas[:, :cache["chunk_size"][0]] = 0.0
-            if "is_final" in cache and not cache["is_final"]:
+            if not is_final:
                 alphas[:, sum(cache["chunk_size"][:2]):] = 0.0
         if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
             cache["cif_hidden"] = to_device(cache["cif_hidden"], device=hidden.device)
             cache["cif_alphas"] = to_device(cache["cif_alphas"], device=alphas.device)
             hidden = torch.cat((cache["cif_hidden"], hidden), dim=1)
             alphas = torch.cat((cache["cif_alphas"], alphas), dim=1)
-        if cache is not None and "is_final" in cache and cache["is_final"]:
+        if cache is not None and is_final:
             tail_hidden = torch.zeros((batch_size, 1, hidden_size), device=hidden.device)
             tail_alphas = torch.tensor([[self.tail_threshold]], device=alphas.device)
             tail_alphas = torch.tile(tail_alphas, (batch_size, 1))
@@ -277,7 +278,7 @@ class CifPredictorV2(nn.Module):

         max_token_len = max(token_length)
         if max_token_len == 0:
-            return hidden, torch.stack(token_length, 0)
+            return hidden, torch.stack(token_length, 0), None, None
         list_ls = []
         for b in range(batch_size):
             pad_frames = torch.zeros((max_token_len - token_length[b], hidden_size), device=alphas.device)
@@ -291,7 +292,7 @@ class CifPredictorV2(nn.Module):

             cache["cif_alphas"] = torch.unsqueeze(cache["cif_alphas"], axis=0)
             cache["cif_hidden"] = torch.stack(cache_hiddens, axis=0)
             cache["cif_hidden"] = torch.unsqueeze(cache["cif_hidden"], axis=0)
-        return torch.stack(list_ls, 0), torch.stack(token_length, 0)
+        return torch.stack(list_ls, 0), torch.stack(token_length, 0), None, None

     def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
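Background on the predictor change: forward_chunk() now receives is_final as an explicit keyword instead of reading it out of the cache dict. The mechanism it gates is continuous integrate-and-fire (CIF): per-frame weights (alphas) are accumulated until they cross a threshold, at which point one acoustic token embedding fires; cif_alphas / cif_hidden carry the partially integrated state across chunks, and the tail_threshold frame appended when is_final is true forces the last partial token to fire. A toy illustration of that loop, not the FunASR implementation (which works batched on padded tensors):

    import torch

    def toy_cif(hidden, alphas, threshold=1.0):
        # hidden: [T, D] frame embeddings, alphas: [T] per-frame firing weights
        fired, integrate = [], 0.0
        frame = torch.zeros(hidden.shape[-1])
        for h, a in zip(hidden, alphas):
            a = float(a)
            if integrate + a < threshold:
                integrate += a
                frame = frame + a * h
            else:
                remain = threshold - integrate    # weight still needed to reach the threshold
                fired.append(frame + remain * h)  # one token embedding fires
                integrate = a - remain            # leftover weight starts the next token
                frame = integrate * h
        # (integrate, frame) is the partial state a streaming predictor caches between chunks;
        # appending a tail frame with weight tail_threshold flushes it on the final chunk.
        return fired, integrate, frame

    tokens, _, _ = toy_cif(torch.randn(20, 8), torch.full((20,), 0.5))
    print(len(tokens))  # 10: with a constant weight of 0.5, a token fires every second frame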
diff --git a/funasr/models/paraformer_streaming/model.py b/funasr/models/paraformer_streaming/model.py
index 304c0f7e7..927b09199 100644
--- a/funasr/models/paraformer_streaming/model.py
+++ b/funasr/models/paraformer_streaming/model.py
@@ -64,8 +64,8 @@ class ParaformerStreaming(Paraformer):

         super().__init__(*args, **kwargs)

-        import pdb;
-        pdb.set_trace()
+        # import pdb;
+        # pdb.set_trace()

         self.sampling_ratio = kwargs.get("sampling_ratio", 0.2)

@@ -375,11 +375,10 @@ class ParaformerStreaming(Paraformer):

         return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index

-    def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None):
-
-        pre_acoustic_embeds, pre_token_length = \
-            self.predictor.forward_chunk(encoder_out, cache["encoder"])
-        return pre_acoustic_embeds, pre_token_length
+    def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None, **kwargs):
+        is_final = kwargs.get("is_final", False)
+
+        return self.predictor.forward_chunk(encoder_out, cache["encoder"], is_final=is_final)

     def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
         decoder_outs = self.decoder(
@@ -416,7 +415,7 @@ class ParaformerStreaming(Paraformer):
                          "chunk_size": chunk_size}
         cache["decoder"] = cache_decoder
         cache["frontend"] = {}
-        cache["prev_samples"] = []
+        cache["prev_samples"] = torch.empty(0)

         return cache

@@ -432,12 +431,12 @@ class ParaformerStreaming(Paraformer):
             speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])

         # Encoder
-        encoder_out, encoder_out_lens = self.encode_chunk(speech, speech_lengths, cache=cache)
+        encoder_out, encoder_out_lens = self.encode_chunk(speech, speech_lengths, cache=cache, is_final=kwargs.get("is_final", False))
         if isinstance(encoder_out, tuple):
             encoder_out = encoder_out[0]

         # predictor
-        predictor_outs = self.calc_predictor_chunk(encoder_out, encoder_out_lens, cache=cache)
+        predictor_outs = self.calc_predictor_chunk(encoder_out, encoder_out_lens, cache=cache, is_final=kwargs.get("is_final", False))
         pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
                                                                         predictor_outs[2], predictor_outs[3]
         pre_token_length = pre_token_length.round().long()
@@ -476,10 +475,7 @@ class ParaformerStreaming(Paraformer):
             )
             nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
         for nbest_idx, hyp in enumerate(nbest_hyps):
-            ibest_writer = None
-            if ibest_writer is None and kwargs.get("output_dir") is not None:
-                writer = DatadirWriter(kwargs.get("output_dir"))
-                ibest_writer = writer[f"{nbest_idx + 1}best_recog"]
+
             # remove sos/eos and get results
             last_pos = -1
             if isinstance(hyp.yseq, list):
@@ -490,22 +486,15 @@ class ParaformerStreaming(Paraformer):

             # remove blank symbol id, which is assumed to be 0
             token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))

-            if tokenizer is not None:
-                # Change integer-ids to tokens
-                token = tokenizer.ids2tokens(token_int)
-                text = tokenizer.tokens2text(token)
-
-                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
-
-                result_i = {"key": key[i], "text": text_postprocessed}
-
-                if ibest_writer is not None:
-                    ibest_writer["token"][key[i]] = " ".join(token)
-                    # ibest_writer["text"][key[i]] = text
-                    ibest_writer["text"][key[i]] = text_postprocessed
-            else:
-                result_i = {"key": key[i], "token_int": token_int}
-            results.append(result_i)
+
+            # Change integer-ids to tokens
+            token = tokenizer.ids2tokens(token_int)
+            # text = tokenizer.tokens2text(token)
+
+            result_i = token
+
+
+            results.extend(result_i)

         return results

@@ -515,6 +504,7 @@ class ParaformerStreaming(Paraformer):
                   key: list = None,
                   tokenizer=None,
                   frontend=None,
+                  cache: dict={},
                   **kwargs,
                   ):

@@ -526,9 +516,10 @@ class ParaformerStreaming(Paraformer):
             self.init_beam_search(**kwargs)
             self.nbest = kwargs.get("nbest", 1)

-        cache = kwargs.get("cache", {})
+
         if len(cache) == 0:
             self.init_cache(cache, **kwargs)
+        _is_final = kwargs.get("is_final", False)

         meta_data = {}
         chunk_size = kwargs.get("chunk_size", [0, 10, 5])
@@ -542,22 +533,41 @@ class ParaformerStreaming(Paraformer):
         meta_data["load_data"] = f"{time2 - time1:0.3f}"
         assert len(audio_sample_list) == 1, "batch_size must be set 1"

-        audio_sample = cache["prev_samples"] + audio_sample_list[0]
+        audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))

-        n = len(audio_sample) // chunk_stride_samples
-        m = len(audio_sample) % chunk_stride_samples
+        n = len(audio_sample) // chunk_stride_samples + int(_is_final)
+        m = len(audio_sample) % chunk_stride_samples * (1-int(_is_final))
+        tokens = []
         for i in range(n):
+            kwargs["is_final"] = _is_final and i == n -1
             audio_sample_i = audio_sample[i*chunk_stride_samples:(i+1)*chunk_stride_samples]

             # extract fbank feats
             speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
-                                                   frontend=frontend, cache=cache["frontend"])
+                                                   frontend=frontend, cache=cache["frontend"], is_final=kwargs["is_final"])
             time3 = time.perf_counter()
             meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
             meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000

-            result_i = self.generate_chunk(speech, speech_lengths, **kwargs)
+            tokens_i = self.generate_chunk(speech, speech_lengths, key=key, tokenizer=tokenizer, cache=cache, frontend=frontend, **kwargs)
+            tokens.extend(tokens_i)
+
+        text_postprocessed, _ = postprocess_utils.sentence_postprocess(tokens)
+
+        result_i = {"key": key[0], "text": text_postprocessed}
+        result = [result_i]
+
+        cache["prev_samples"] = audio_sample[:-m]
+        if _is_final:
+            self.init_cache(cache, **kwargs)
+
+        if kwargs.get("output_dir"):
+            writer = DatadirWriter(kwargs.get("output_dir"))
+            ibest_writer = writer[f"{1}best_recog"]
+            ibest_writer["token"][key[0]] = " ".join(tokens)
+            ibest_writer["text"][key[0]] = text_postprocessed
+
+        return result, meta_data
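The rewritten inference() buffers raw samples between calls: leftover audio from the previous call is prepended to the new input, whole chunks are consumed (plus the partial tail once is_final is set), and the remainder is kept in cache["prev_samples"] for the next call. A standalone sketch of just that bookkeeping, with made-up sample counts; the sketch keeps the unconsumed tail with audio[len(audio) - m:], while the patch writes audio_sample[:-m], which gives the same empty result whenever the caller feeds exact multiples of the chunk stride (m == 0), as the demo does:

    import torch

    chunk_stride_samples = 9600  # 600 ms at 16 kHz, matching chunk_size = [0, 10, 5]
    cache = {"prev_samples": torch.empty(0)}

    def feed(new_samples, is_final=False):
        audio = torch.cat((cache["prev_samples"], new_samples))
        n = len(audio) // chunk_stride_samples + int(is_final)       # chunks to run now (+ tail chunk when final)
        m = len(audio) % chunk_stride_samples * (1 - int(is_final))  # samples to carry over to the next call
        for i in range(n):
            chunk = audio[i * chunk_stride_samples:(i + 1) * chunk_stride_samples]
            last = is_final and i == n - 1
            print(f"chunk {i}: {len(chunk)} samples, is_final={last}")  # generate_chunk() would run here
        cache["prev_samples"] = audio[len(audio) - m:]               # empty when m == 0

    feed(torch.zeros(12000))                # runs one 9600-sample chunk, buffers 2400 samples
    feed(torch.zeros(8000), is_final=True)  # runs 9600 + 800 samples, leaves nothing buffered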
diff --git a/funasr/models/scama/sanm_encoder.py b/funasr/models/scama/sanm_encoder.py
index 4bf6ef0ed..5e28db7df 100644
--- a/funasr/models/scama/sanm_encoder.py
+++ b/funasr/models/scama/sanm_encoder.py
@@ -423,7 +423,9 @@ class SANMEncoderChunkOpt(nn.Module):
         xs_pad: torch.Tensor,
         ilens: torch.Tensor,
         cache: dict = None,
+        **kwargs,
     ):
+        is_final = kwargs.get("is_final", False)
         xs_pad *= self.output_size() ** 0.5
         if self.embed is None:
             xs_pad = xs_pad
diff --git a/funasr/utils/load_utils.py b/funasr/utils/load_utils.py
index 39b708a68..bb9cf01b9 100644
--- a/funasr/utils/load_utils.py
+++ b/funasr/utils/load_utils.py
@@ -43,7 +43,7 @@ def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs:
     elif isinstance(data_or_path_or_list, str) and data_type == "text" and tokenizer is not None:
         data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
     elif isinstance(data_or_path_or_list, np.ndarray): # audio sample point
-        data_or_path_or_list = np.squeeze(data_or_path_or_list) # [n_samples,]
+        data_or_path_or_list = torch.from_numpy(data_or_path_or_list).squeeze() # [n_samples,]
     else:
         pass
         # print(f"unsupport data type: {data_or_path_or_list}, return raw data")
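The load_utils change is what keeps the new sample buffering type-consistent: ndarray input is converted to a torch.Tensor up front, so cache["prev_samples"] (initialised as torch.empty(0)) can be concatenated with incoming audio directly. A minimal illustration, with a made-up array shape:

    import numpy as np
    import torch

    samples = np.zeros((1, 16000), dtype=np.float32)  # 1 s of 16 kHz audio as an ndarray
    samples = torch.from_numpy(samples).squeeze()     # torch.Tensor of shape [16000]

    prev = torch.empty(0)
    print(torch.cat((prev, samples)).shape)           # torch.Size([16000]); torch.cat() requires tensors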