From 3762d21300e1f3fa3e0cb1e67545227e6dcec3de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BB=81=E8=BF=B7?= Date: Mon, 13 Mar 2023 22:02:54 +0800 Subject: [PATCH 1/4] add streaming paraformer code --- funasr/bin/asr_inference_launch.py | 3 + .../bin/asr_inference_paraformer_streaming.py | 907 ++++++++++++++++++ funasr/models/decoder/sanm_decoder.py | 59 ++ funasr/models/e2e_asr_paraformer.py | 74 +- funasr/models/encoder/sanm_encoder.py | 42 + funasr/models/predictor/cif.py | 57 ++ funasr/modules/attention.py | 10 +- funasr/modules/embedding.py | 11 +- 8 files changed, 1157 insertions(+), 6 deletions(-) create mode 100644 funasr/bin/asr_inference_paraformer_streaming.py diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py index 1fae766ea..da1241a66 100644 --- a/funasr/bin/asr_inference_launch.py +++ b/funasr/bin/asr_inference_launch.py @@ -216,6 +216,9 @@ def inference_launch(**kwargs): elif mode == "paraformer": from funasr.bin.asr_inference_paraformer import inference_modelscope return inference_modelscope(**kwargs) + elif mode == "paraformer_streaming": + from funasr.bin.asr_inference_paraformer_streaming import inference_modelscope + return inference_modelscope(**kwargs) elif mode == "paraformer_vad": from funasr.bin.asr_inference_paraformer_vad import inference_modelscope return inference_modelscope(**kwargs) diff --git a/funasr/bin/asr_inference_paraformer_streaming.py b/funasr/bin/asr_inference_paraformer_streaming.py new file mode 100644 index 000000000..9b572a0af --- /dev/null +++ b/funasr/bin/asr_inference_paraformer_streaming.py @@ -0,0 +1,907 @@ +#!/usr/bin/env python3 +import argparse +import logging +import sys +import time +import copy +import os +import codecs +import tempfile +import requests +from pathlib import Path +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union +from typing import Dict +from typing import Any +from typing import List + +import numpy as np +import torch +from typeguard import check_argument_types + +from funasr.fileio.datadir_writer import DatadirWriter +from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch +from funasr.modules.beam_search.beam_search import Hypothesis +from funasr.modules.scorers.ctc import CTCPrefixScorer +from funasr.modules.scorers.length_bonus import LengthBonus +from funasr.modules.subsampling import TooShortUttError +from funasr.tasks.asr import ASRTaskParaformer as ASRTask +from funasr.tasks.lm import LMTask +from funasr.text.build_tokenizer import build_tokenizer +from funasr.text.token_id_converter import TokenIDConverter +from funasr.torch_utils.device_funcs import to_device +from funasr.torch_utils.set_all_random_seed import set_all_random_seed +from funasr.utils import config_argparse +from funasr.utils.cli_utils import get_commandline_args +from funasr.utils.types import str2bool +from funasr.utils.types import str2triple_str +from funasr.utils.types import str_or_none +from funasr.utils import asr_utils, wav_utils, postprocess_utils +from funasr.models.frontend.wav_frontend import WavFrontend +from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer +from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export + +class Speech2Text: + """Speech2Text class + + Examples: + >>> import soundfile + >>> speech2text = Speech2Text("asr_config.yml", "asr.pth") + >>> audio, rate = soundfile.read("speech.wav") + >>> speech2text(audio) + [(text, token, token_int, 
hypothesis object), ...] + + """ + + def __init__( + self, + asr_train_config: Union[Path, str] = None, + asr_model_file: Union[Path, str] = None, + cmvn_file: Union[Path, str] = None, + lm_train_config: Union[Path, str] = None, + lm_file: Union[Path, str] = None, + token_type: str = None, + bpemodel: str = None, + device: str = "cpu", + maxlenratio: float = 0.0, + minlenratio: float = 0.0, + dtype: str = "float32", + beam_size: int = 20, + ctc_weight: float = 0.5, + lm_weight: float = 1.0, + ngram_weight: float = 0.9, + penalty: float = 0.0, + nbest: int = 1, + frontend_conf: dict = None, + hotword_list_or_file: str = None, + **kwargs, + ): + assert check_argument_types() + + # 1. Build ASR model + scorers = {} + asr_model, asr_train_args = ASRTask.build_model_from_file( + asr_train_config, asr_model_file, cmvn_file, device + ) + frontend = None + if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: + frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf) + + logging.info("asr_model: {}".format(asr_model)) + logging.info("asr_train_args: {}".format(asr_train_args)) + asr_model.to(dtype=getattr(torch, dtype)).eval() + + if asr_model.ctc != None: + ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) + scorers.update( + ctc=ctc + ) + token_list = asr_model.token_list + scorers.update( + length_bonus=LengthBonus(len(token_list)), + ) + + # 2. Build Language model + if lm_train_config is not None: + lm, lm_train_args = LMTask.build_model_from_file( + lm_train_config, lm_file, device + ) + scorers["lm"] = lm.lm + + # 3. Build ngram model + # ngram is not supported now + ngram = None + scorers["ngram"] = ngram + + # 4. Build BeamSearch object + # transducer is not supported now + beam_search_transducer = None + + weights = dict( + decoder=1.0 - ctc_weight, + ctc=ctc_weight, + lm=lm_weight, + ngram=ngram_weight, + length_bonus=penalty, + ) + beam_search = BeamSearch( + beam_size=beam_size, + weights=weights, + scorers=scorers, + sos=asr_model.sos, + eos=asr_model.eos, + vocab_size=len(token_list), + token_list=token_list, + pre_beam_score_key=None if ctc_weight == 1.0 else "full", + ) + + beam_search.to(device=device, dtype=getattr(torch, dtype)).eval() + for scorer in scorers.values(): + if isinstance(scorer, torch.nn.Module): + scorer.to(device=device, dtype=getattr(torch, dtype)).eval() + + logging.info(f"Decoding device={device}, dtype={dtype}") + + # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text + if token_type is None: + token_type = asr_train_args.token_type + if bpemodel is None: + bpemodel = asr_train_args.bpemodel + + if token_type is None: + tokenizer = None + elif token_type == "bpe": + if bpemodel is not None: + tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel) + else: + tokenizer = None + else: + tokenizer = build_tokenizer(token_type=token_type) + converter = TokenIDConverter(token_list=token_list) + logging.info(f"Text tokenizer: {tokenizer}") + + self.asr_model = asr_model + self.asr_train_args = asr_train_args + self.converter = converter + self.tokenizer = tokenizer + + # 6. 
[Optional] Build hotword list from str, local file or url + + is_use_lm = lm_weight != 0.0 and lm_file is not None + if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm: + beam_search = None + self.beam_search = beam_search + logging.info(f"Beam_search: {self.beam_search}") + self.beam_search_transducer = beam_search_transducer + self.maxlenratio = maxlenratio + self.minlenratio = minlenratio + self.device = device + self.dtype = dtype + self.nbest = nbest + self.frontend = frontend + self.encoder_downsampling_factor = 1 + if asr_train_args.encoder == "data2vec_encoder" or asr_train_args.encoder_conf["input_layer"] == "conv2d": + self.encoder_downsampling_factor = 4 + + @torch.no_grad() + def __call__( + self, cache: dict, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None, + begin_time: int = 0, end_time: int = None, + ): + """Inference + + Args: + speech: Input speech data + Returns: + text, token, token_int, hyp + + """ + assert check_argument_types() + + # Input as audio signal + if isinstance(speech, np.ndarray): + speech = torch.tensor(speech) + + if self.frontend is not None: + feats, feats_len = self.frontend.forward(speech, speech_lengths) + feats = to_device(feats, device=self.device) + feats_len = feats_len.int() + self.asr_model.frontend = None + else: + feats = speech + feats_len = speech_lengths + lfr_factor = max(1, (feats.size()[-1] // 80) - 1) + batch = {"speech": feats, "speech_lengths": feats_len, "cache": cache} + + # a. To device + batch = to_device(batch, device=self.device) + + # b. Forward Encoder + enc, enc_len = self.asr_model.encode_chunk(**batch) + if isinstance(enc, tuple): + enc = enc[0] + # assert len(enc) == 1, len(enc) + enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor + + predictor_outs = self.asr_model.calc_predictor_chunk(enc, cache) + pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \ + predictor_outs[2], predictor_outs[3] + pre_token_length = pre_token_length.floor().long() + if torch.max(pre_token_length) < 1: + return [] + decoder_outs = self.asr_model.cal_decoder_with_predictor_chunk(enc, pre_acoustic_embeds, cache) + decoder_out = decoder_outs + + results = [] + b, n, d = decoder_out.size() + for i in range(b): + x = enc[i, :enc_len[i], :] + am_scores = decoder_out[i, :pre_token_length[i], :] + if self.beam_search is not None: + nbest_hyps = self.beam_search( + x=x, am_scores=am_scores, maxlenratio=self.maxlenratio, minlenratio=self.minlenratio + ) + + nbest_hyps = nbest_hyps[: self.nbest] + else: + yseq = am_scores.argmax(dim=-1) + score = am_scores.max(dim=-1)[0] + score = torch.sum(score, dim=-1) + # pad with mask tokens to ensure compatibility with sos/eos tokens + yseq = torch.tensor( + [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device + ) + nbest_hyps = [Hypothesis(yseq=yseq, score=score)] + + for hyp in nbest_hyps: + assert isinstance(hyp, (Hypothesis)), type(hyp) + + # remove sos/eos and get results + last_pos = -1 + if isinstance(hyp.yseq, list): + token_int = hyp.yseq[1:last_pos] + else: + token_int = hyp.yseq[1:last_pos].tolist() + + # remove blank symbol id, which is assumed to be 0 + token_int = list(filter(lambda x: x != 0 and x != 2, token_int)) + + # Change integer-ids to tokens + token = self.converter.ids2tokens(token_int) + + if self.tokenizer is not None: + text = self.tokenizer.tokens2text(token) + else: + text = None + + results.append((text, token, 
token_int, hyp, enc_len_batch_total, lfr_factor)) + + # assert check_return_type(results) + return results + + +class Speech2TextExport: + """Speech2TextExport class + + """ + + def __init__( + self, + asr_train_config: Union[Path, str] = None, + asr_model_file: Union[Path, str] = None, + cmvn_file: Union[Path, str] = None, + lm_train_config: Union[Path, str] = None, + lm_file: Union[Path, str] = None, + token_type: str = None, + bpemodel: str = None, + device: str = "cpu", + maxlenratio: float = 0.0, + minlenratio: float = 0.0, + dtype: str = "float32", + beam_size: int = 20, + ctc_weight: float = 0.5, + lm_weight: float = 1.0, + ngram_weight: float = 0.9, + penalty: float = 0.0, + nbest: int = 1, + frontend_conf: dict = None, + hotword_list_or_file: str = None, + **kwargs, + ): + + # 1. Build ASR model + asr_model, asr_train_args = ASRTask.build_model_from_file( + asr_train_config, asr_model_file, cmvn_file, device + ) + frontend = None + if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: + frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf) + + logging.info("asr_model: {}".format(asr_model)) + logging.info("asr_train_args: {}".format(asr_train_args)) + asr_model.to(dtype=getattr(torch, dtype)).eval() + + token_list = asr_model.token_list + + logging.info(f"Decoding device={device}, dtype={dtype}") + + # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text + if token_type is None: + token_type = asr_train_args.token_type + if bpemodel is None: + bpemodel = asr_train_args.bpemodel + + if token_type is None: + tokenizer = None + elif token_type == "bpe": + if bpemodel is not None: + tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel) + else: + tokenizer = None + else: + tokenizer = build_tokenizer(token_type=token_type) + converter = TokenIDConverter(token_list=token_list) + logging.info(f"Text tokenizer: {tokenizer}") + + # self.asr_model = asr_model + self.asr_train_args = asr_train_args + self.converter = converter + self.tokenizer = tokenizer + + self.device = device + self.dtype = dtype + self.nbest = nbest + self.frontend = frontend + + model = Paraformer_export(asr_model, onnx=False) + self.asr_model = model + + @torch.no_grad() + def __call__( + self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None + ): + """Inference + + Args: + speech: Input speech data + Returns: + text, token, token_int, hyp + + """ + assert check_argument_types() + + # Input as audio signal + if isinstance(speech, np.ndarray): + speech = torch.tensor(speech) + + if self.frontend is not None: + feats, feats_len = self.frontend.forward(speech, speech_lengths) + feats = to_device(feats, device=self.device) + feats_len = feats_len.int() + self.asr_model.frontend = None + else: + feats = speech + feats_len = speech_lengths + + enc_len_batch_total = feats_len.sum() + lfr_factor = max(1, (feats.size()[-1] // 80) - 1) + batch = {"speech": feats, "speech_lengths": feats_len} + + # a. 
To device + batch = to_device(batch, device=self.device) + + decoder_outs = self.asr_model(**batch) + decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] + + results = [] + b, n, d = decoder_out.size() + for i in range(b): + am_scores = decoder_out[i, :ys_pad_lens[i], :] + + yseq = am_scores.argmax(dim=-1) + score = am_scores.max(dim=-1)[0] + score = torch.sum(score, dim=-1) + # pad with mask tokens to ensure compatibility with sos/eos tokens + yseq = torch.tensor( + yseq.tolist(), device=yseq.device + ) + nbest_hyps = [Hypothesis(yseq=yseq, score=score)] + + for hyp in nbest_hyps: + assert isinstance(hyp, (Hypothesis)), type(hyp) + + # remove sos/eos and get results + last_pos = -1 + if isinstance(hyp.yseq, list): + token_int = hyp.yseq[1:last_pos] + else: + token_int = hyp.yseq[1:last_pos].tolist() + + # remove blank symbol id, which is assumed to be 0 + token_int = list(filter(lambda x: x != 0 and x != 2, token_int)) + + # Change integer-ids to tokens + token = self.converter.ids2tokens(token_int) + + if self.tokenizer is not None: + text = self.tokenizer.tokens2text(token) + else: + text = None + + results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor)) + + return results + + +def inference( + maxlenratio: float, + minlenratio: float, + batch_size: int, + beam_size: int, + ngpu: int, + ctc_weight: float, + lm_weight: float, + penalty: float, + log_level: Union[int, str], + data_path_and_name_and_type, + asr_train_config: Optional[str], + asr_model_file: Optional[str], + cmvn_file: Optional[str] = None, + raw_inputs: Union[np.ndarray, torch.Tensor] = None, + lm_train_config: Optional[str] = None, + lm_file: Optional[str] = None, + token_type: Optional[str] = None, + key_file: Optional[str] = None, + word_lm_train_config: Optional[str] = None, + bpemodel: Optional[str] = None, + allow_variable_data_keys: bool = False, + streaming: bool = False, + output_dir: Optional[str] = None, + dtype: str = "float32", + seed: int = 0, + ngram_weight: float = 0.9, + nbest: int = 1, + num_workers: int = 1, + + **kwargs, +): + inference_pipeline = inference_modelscope( + maxlenratio=maxlenratio, + minlenratio=minlenratio, + batch_size=batch_size, + beam_size=beam_size, + ngpu=ngpu, + ctc_weight=ctc_weight, + lm_weight=lm_weight, + penalty=penalty, + log_level=log_level, + asr_train_config=asr_train_config, + asr_model_file=asr_model_file, + cmvn_file=cmvn_file, + raw_inputs=raw_inputs, + lm_train_config=lm_train_config, + lm_file=lm_file, + token_type=token_type, + key_file=key_file, + word_lm_train_config=word_lm_train_config, + bpemodel=bpemodel, + allow_variable_data_keys=allow_variable_data_keys, + streaming=streaming, + output_dir=output_dir, + dtype=dtype, + seed=seed, + ngram_weight=ngram_weight, + nbest=nbest, + num_workers=num_workers, + + **kwargs, + ) + return inference_pipeline(data_path_and_name_and_type, raw_inputs) + + +def inference_modelscope( + maxlenratio: float, + minlenratio: float, + batch_size: int, + beam_size: int, + ngpu: int, + ctc_weight: float, + lm_weight: float, + penalty: float, + log_level: Union[int, str], + # data_path_and_name_and_type, + asr_train_config: Optional[str], + asr_model_file: Optional[str], + cmvn_file: Optional[str] = None, + lm_train_config: Optional[str] = None, + lm_file: Optional[str] = None, + token_type: Optional[str] = None, + key_file: Optional[str] = None, + word_lm_train_config: Optional[str] = None, + bpemodel: Optional[str] = None, + allow_variable_data_keys: bool = False, + dtype: str = "float32", + seed: int = 
0, + ngram_weight: float = 0.9, + nbest: int = 1, + num_workers: int = 1, + output_dir: Optional[str] = None, + param_dict: dict = None, + **kwargs, +): + assert check_argument_types() + + if word_lm_train_config is not None: + raise NotImplementedError("Word LM is not implemented") + if ngpu > 1: + raise NotImplementedError("only single GPU decoding is supported") + + logging.basicConfig( + level=log_level, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + + export_mode = False + if param_dict is not None: + hotword_list_or_file = param_dict.get('hotword') + export_mode = param_dict.get("export_mode", False) + else: + hotword_list_or_file = None + + if ngpu >= 1 and torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" + batch_size = 1 + + # 1. Set random-seed + set_all_random_seed(seed) + + # 2. Build speech2text + speech2text_kwargs = dict( + asr_train_config=asr_train_config, + asr_model_file=asr_model_file, + cmvn_file=cmvn_file, + lm_train_config=lm_train_config, + lm_file=lm_file, + token_type=token_type, + bpemodel=bpemodel, + device=device, + maxlenratio=maxlenratio, + minlenratio=minlenratio, + dtype=dtype, + beam_size=beam_size, + ctc_weight=ctc_weight, + lm_weight=lm_weight, + ngram_weight=ngram_weight, + penalty=penalty, + nbest=nbest, + hotword_list_or_file=hotword_list_or_file, + ) + if export_mode: + speech2text = Speech2TextExport(**speech2text_kwargs) + else: + speech2text = Speech2Text(**speech2text_kwargs) + + def _forward( + data_path_and_name_and_type, + raw_inputs: Union[np.ndarray, torch.Tensor] = None, + output_dir_v2: Optional[str] = None, + fs: dict = None, + param_dict: dict = None, + **kwargs, + ): + + hotword_list_or_file = None + if param_dict is not None: + hotword_list_or_file = param_dict.get('hotword') + if 'hotword' in kwargs: + hotword_list_or_file = kwargs['hotword'] + if hotword_list_or_file is not None or 'hotword' in kwargs: + speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file) + + # 3. 
Build data-iterator + if data_path_and_name_and_type is None and raw_inputs is not None: + if isinstance(raw_inputs, torch.Tensor): + raw_inputs = raw_inputs.numpy() + data_path_and_name_and_type = [raw_inputs, "speech", "waveform"] + loader = ASRTask.build_streaming_iterator( + data_path_and_name_and_type, + dtype=dtype, + fs=fs, + batch_size=batch_size, + key_file=key_file, + num_workers=num_workers, + preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False), + collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False), + allow_variable_data_keys=allow_variable_data_keys, + inference=True, + ) + + if param_dict is not None: + use_timestamp = param_dict.get('use_timestamp', True) + else: + use_timestamp = True + + forward_time_total = 0.0 + length_total = 0.0 + finish_count = 0 + file_count = 1 + cache = None + # 7 .Start for-loop + # FIXME(kamo): The output format should be discussed about + asr_result_list = [] + output_path = output_dir_v2 if output_dir_v2 is not None else output_dir + if output_path is not None: + writer = DatadirWriter(output_path) + else: + writer = None + if param_dict is not None and "cache" in param_dict: + cache = param_dict["cache"] + for keys, batch in loader: + assert isinstance(batch, dict), type(batch) + assert all(isinstance(s, str) for s in keys), keys + _bs = len(next(iter(batch.values()))) + assert len(keys) == _bs, f"{len(keys)} != {_bs}" + # batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")} + logging.info("decoding, utt_id: {}".format(keys)) + # N-best list of (text, token, token_int, hyp_object) + + time_beg = time.time() + results = speech2text(cache=cache, **batch) + if len(results) < 1: + hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[]) + results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest + time_end = time.time() + forward_time = time_end - time_beg + lfr_factor = results[0][-1] + length = results[0][-2] + forward_time_total += forward_time + length_total += length + rtf_cur = "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".format(length, forward_time, + 100 * forward_time / ( + length * lfr_factor)) + logging.info(rtf_cur) + + for batch_id in range(_bs): + result = [results[batch_id][:-2]] + + key = keys[batch_id] + for n, result in zip(range(1, nbest + 1), result): + text, token, token_int, hyp = result[0], result[1], result[2], result[3] + time_stamp = None if len(result) < 5 else result[4] + # Create a directory: outdir/{n}best_recog + if writer is not None: + ibest_writer = writer[f"{n}best_recog"] + + # Write the result to each file + ibest_writer["token"][key] = " ".join(token) + # ibest_writer["token_int"][key] = " ".join(map(str, token_int)) + ibest_writer["score"][key] = str(hyp.score) + ibest_writer["rtf"][key] = rtf_cur + + if text is not None: + if use_timestamp and time_stamp is not None: + postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp) + else: + postprocessed_result = postprocess_utils.sentence_postprocess(token) + time_stamp_postprocessed = "" + if len(postprocessed_result) == 3: + text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \ + postprocessed_result[1], \ + postprocessed_result[2] + else: + text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1] + item = {'key': key, 'value': text_postprocessed} + if time_stamp_postprocessed != "": + item['time_stamp'] = time_stamp_postprocessed + asr_result_list.append(item) + finish_count += 1 + # 
asr_utils.print_progress(finish_count / file_count) + if writer is not None: + ibest_writer["text"][key] = text_postprocessed + + logging.info("decoding, utt: {}, predictions: {}".format(key, text)) + rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total, + forward_time_total, + 100 * forward_time_total / ( + length_total * lfr_factor)) + logging.info(rtf_avg) + if writer is not None: + ibest_writer["rtf"]["rtf_avf"] = rtf_avg + return asr_result_list + + return _forward + + +def get_parser(): + parser = config_argparse.ArgumentParser( + description="ASR Decoding", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Note(kamo): Use '_' instead of '-' as separator. + # '-' is confusing if written in yaml. + parser.add_argument( + "--log_level", + type=lambda x: x.upper(), + default="INFO", + choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), + help="The verbose level of logging", + ) + + parser.add_argument("--output_dir", type=str, required=True) + parser.add_argument( + "--ngpu", + type=int, + default=0, + help="The number of gpus. 0 indicates CPU mode", + ) + parser.add_argument("--seed", type=int, default=0, help="Random seed") + parser.add_argument( + "--dtype", + default="float32", + choices=["float16", "float32", "float64"], + help="Data type", + ) + parser.add_argument( + "--num_workers", + type=int, + default=1, + help="The number of workers used for DataLoader", + ) + parser.add_argument( + "--hotword", + type=str_or_none, + default=None, + help="hotword file path or hotwords seperated by space" + ) + group = parser.add_argument_group("Input data related") + group.add_argument( + "--data_path_and_name_and_type", + type=str2triple_str, + required=False, + action="append", + ) + group.add_argument("--key_file", type=str_or_none) + group.add_argument("--allow_variable_data_keys", type=str2bool, default=False) + + group = parser.add_argument_group("The model configuration related") + group.add_argument( + "--asr_train_config", + type=str, + help="ASR training configuration", + ) + group.add_argument( + "--asr_model_file", + type=str, + help="ASR model parameter file", + ) + group.add_argument( + "--cmvn_file", + type=str, + help="Global cmvn file", + ) + group.add_argument( + "--lm_train_config", + type=str, + help="LM training configuration", + ) + group.add_argument( + "--lm_file", + type=str, + help="LM parameter file", + ) + group.add_argument( + "--word_lm_train_config", + type=str, + help="Word LM training configuration", + ) + group.add_argument( + "--word_lm_file", + type=str, + help="Word LM parameter file", + ) + group.add_argument( + "--ngram_file", + type=str, + help="N-gram parameter file", + ) + group.add_argument( + "--model_tag", + type=str, + help="Pretrained model tag. If specify this option, *_train_config and " + "*_file will be overwritten", + ) + + group = parser.add_argument_group("Beam-search related") + group.add_argument( + "--batch_size", + type=int, + default=1, + help="The batch size for inference", + ) + group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses") + group.add_argument("--beam_size", type=int, default=20, help="Beam size") + group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty") + group.add_argument( + "--maxlenratio", + type=float, + default=0.0, + help="Input length ratio to obtain max output length. 
" + "If maxlenratio=0.0 (default), it uses a end-detect " + "function " + "to automatically find maximum hypothesis lengths." + "If maxlenratio<0.0, its absolute value is interpreted" + "as a constant max output length", + ) + group.add_argument( + "--minlenratio", + type=float, + default=0.0, + help="Input length ratio to obtain min output length", + ) + group.add_argument( + "--ctc_weight", + type=float, + default=0.5, + help="CTC weight in joint decoding", + ) + group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight") + group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight") + group.add_argument("--streaming", type=str2bool, default=False) + + group.add_argument( + "--frontend_conf", + default=None, + help="", + ) + group.add_argument("--raw_inputs", type=list, default=None) + # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}]) + + group = parser.add_argument_group("Text converter related") + group.add_argument( + "--token_type", + type=str_or_none, + default=None, + choices=["char", "bpe", None], + help="The token type for ASR model. " + "If not given, refers from the training args", + ) + group.add_argument( + "--bpemodel", + type=str_or_none, + default=None, + help="The model path of sentencepiece. " + "If not given, refers from the training args", + ) + + return parser + + +def main(cmd=None): + print(get_commandline_args(), file=sys.stderr) + parser = get_parser() + args = parser.parse_args(cmd) + param_dict = {'hotword': args.hotword} + kwargs = vars(args) + kwargs.pop("config", None) + kwargs['param_dict'] = param_dict + inference(**kwargs) + + +if __name__ == "__main__": + main() + + # from modelscope.pipelines import pipeline + # from modelscope.utils.constant import Tasks + # + # inference_16k_pipline = pipeline( + # task=Tasks.auto_speech_recognition, + # model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch') + # + # rec_result = inference_16k_pipline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav') + # print(rec_result) + diff --git a/funasr/models/decoder/sanm_decoder.py b/funasr/models/decoder/sanm_decoder.py index ab03f0b61..011743022 100644 --- a/funasr/models/decoder/sanm_decoder.py +++ b/funasr/models/decoder/sanm_decoder.py @@ -947,6 +947,65 @@ class ParaformerSANMDecoder(BaseTransformerDecoder): ) return logp.squeeze(0), state + def forward_chunk( + self, + memory: torch.Tensor, + tgt: torch.Tensor, + cache: dict = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward decoder. 
+ + Args: + hs_pad: encoded memory, float32 (batch, maxlen_in, feat) + hlens: (batch) + ys_in_pad: + input token ids, int64 (batch, maxlen_out) + if input_layer == "embed" + input tensor (batch, maxlen_out, #mels) in the other cases + ys_in_lens: (batch) + Returns: + (tuple): tuple containing: + + x: decoded token score before softmax (batch, maxlen_out, token) + if use_output_layer is True, + olens: (batch, ) + """ + x = tgt + if cache["decode_fsmn"] is None: + cache_layer_num = len(self.decoders) + if self.decoders2 is not None: + cache_layer_num += len(self.decoders2) + new_cache = [None] * cache_layer_num + else: + new_cache = cache["decode_fsmn"] + for i in range(self.att_layer_num): + decoder = self.decoders[i] + x, tgt_mask, memory, memory_mask, c_ret = decoder( + x, None, memory, None, cache=new_cache[i] + ) + new_cache[i] = c_ret + + if self.num_blocks - self.att_layer_num > 1: + for i in range(self.num_blocks - self.att_layer_num): + j = i + self.att_layer_num + decoder = self.decoders2[i] + x, tgt_mask, memory, memory_mask, c_ret = decoder( + x, None, memory, None, cache=new_cache[j] + ) + new_cache[j] = c_ret + + for decoder in self.decoders3: + + x, tgt_mask, memory, memory_mask, _ = decoder( + x, None, memory, None, cache=None + ) + if self.normalize_before: + x = self.after_norm(x) + if self.output_layer is not None: + x = self.output_layer(x) + cache["decode_fsmn"] = new_cache + return x + def forward_one_step( self, tgt: torch.Tensor, diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py index 44c9de3af..02f60af22 100644 --- a/funasr/models/e2e_asr_paraformer.py +++ b/funasr/models/e2e_asr_paraformer.py @@ -325,6 +325,65 @@ class Paraformer(AbsESPnetModel): return encoder_out, encoder_out_lens + def encode_chunk( + self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Frontend + Encoder. Note that this method is used by asr_inference.py + + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + """ + with autocast(False): + # 1. Extract feats + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + + # 2. Data augmentation + if self.specaug is not None and self.training: + feats, feats_lengths = self.specaug(feats, feats_lengths) + + # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN + if self.normalize is not None: + feats, feats_lengths = self.normalize(feats, feats_lengths) + + # Pre-encoder, e.g. used for raw input data + if self.preencoder is not None: + feats, feats_lengths = self.preencoder(feats, feats_lengths) + + # 4. Forward encoder + # feats: (Batch, Length, Dim) + # -> encoder_out: (Batch, Length2, Dim2) + if self.encoder.interctc_use_conditioning: + encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk( + feats, feats_lengths, cache=cache["encoder"], ctc=self.ctc + ) + else: + encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"]) + intermediate_outs = None + if isinstance(encoder_out, tuple): + intermediate_outs = encoder_out[1] + encoder_out = encoder_out[0] + + # Post-encoder, e.g. 
NLU + if self.postencoder is not None: + encoder_out, encoder_out_lens = self.postencoder( + encoder_out, encoder_out_lens + ) + + assert encoder_out.size(0) == speech.size(0), ( + encoder_out.size(), + speech.size(0), + ) + assert encoder_out.size(1) <= encoder_out_lens.max(), ( + encoder_out.size(), + encoder_out_lens.max(), + ) + + if intermediate_outs is not None: + return (encoder_out, intermediate_outs), encoder_out_lens + + return encoder_out, encoder_out_lens + def calc_predictor(self, encoder_out, encoder_out_lens): encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( @@ -333,6 +392,11 @@ class Paraformer(AbsESPnetModel): ignore_id=self.ignore_id) return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index + def calc_predictor_chunk(self, encoder_out, cache=None): + + pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor.forward_chunk(encoder_out, cache["encoder"]) + return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index + def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens): decoder_outs = self.decoder( @@ -342,6 +406,14 @@ class Paraformer(AbsESPnetModel): decoder_out = torch.log_softmax(decoder_out, dim=-1) return decoder_out, ys_pad_lens + def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None): + decoder_outs = self.decoder.forward_chunk( + encoder_out, sematic_embeds, cache["decoder"] + ) + decoder_out = decoder_outs + decoder_out = torch.log_softmax(decoder_out, dim=-1) + return decoder_out + def _extract_feats( self, speech: torch.Tensor, speech_lengths: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -1459,4 +1531,4 @@ class ContextualParaformer(Paraformer): "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape)) - return var_dict_torch_update \ No newline at end of file + return var_dict_torch_update diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py index 0751a1020..57890efe6 100644 --- a/funasr/models/encoder/sanm_encoder.py +++ b/funasr/models/encoder/sanm_encoder.py @@ -347,6 +347,48 @@ class SANMEncoder(AbsEncoder): return (xs_pad, intermediate_outs), olens, None return xs_pad, olens, None + def forward_chunk(self, + xs_pad: torch.Tensor, + ilens: torch.Tensor, + cache: dict = None, + ctc: CTC = None, + ): + xs_pad *= self.output_size() ** 0.5 + if self.embed is None: + xs_pad = xs_pad + else: + xs_pad = self.embed.forward_chunk(xs_pad, cache) + + encoder_outs = self.encoders0(xs_pad, None, None, None, None) + xs_pad, masks = encoder_outs[0], encoder_outs[1] + intermediate_outs = [] + if len(self.interctc_layer_idx) == 0: + encoder_outs = self.encoders(xs_pad, None, None, None, None) + xs_pad, masks = encoder_outs[0], encoder_outs[1] + else: + for layer_idx, encoder_layer in enumerate(self.encoders): + encoder_outs = encoder_layer(xs_pad, None, None, None, None) + xs_pad, masks = encoder_outs[0], encoder_outs[1] + if layer_idx + 1 in self.interctc_layer_idx: + encoder_out = xs_pad + + # intermediate outputs are also normalized + if self.normalize_before: + encoder_out = self.after_norm(encoder_out) + + intermediate_outs.append((layer_idx + 1, encoder_out)) + + if self.interctc_use_conditioning: + ctc_out = ctc.softmax(encoder_out) + xs_pad = xs_pad + self.conditioning_layer(ctc_out) + + if self.normalize_before: + xs_pad = self.after_norm(xs_pad) + + if len(intermediate_outs) > 0: + return 
(xs_pad, intermediate_outs), None, None + return xs_pad, ilens, None + def gen_tf2torch_map_dict(self): tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py index 561537323..74f3e68a9 100644 --- a/funasr/models/predictor/cif.py +++ b/funasr/models/predictor/cif.py @@ -199,6 +199,63 @@ class CifPredictorV2(nn.Module): return acoustic_embeds, token_num, alphas, cif_peak + def forward_chunk(self, hidden, cache=None): + h = hidden + context = h.transpose(1, 2) + queries = self.pad(context) + output = torch.relu(self.cif_conv1d(queries)) + output = output.transpose(1, 2) + output = self.cif_output(output) + alphas = torch.sigmoid(output) + alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold) + + alphas = alphas.squeeze(-1) + mask_chunk_predictor = None + if cache is not None: + mask_chunk_predictor = None + mask_chunk_predictor = torch.zeros_like(alphas) + mask_chunk_predictor[:, cache["pad_left"]:cache["stride"] + cache["pad_left"]] = 1.0 + + if mask_chunk_predictor is not None: + alphas = alphas * mask_chunk_predictor + + if cache is not None: + if cache["cif_hidden"] is not None: + hidden = torch.cat((cache["cif_hidden"], hidden), 1) + if cache["cif_alphas"] is not None: + alphas = torch.cat((cache["cif_alphas"], alphas), -1) + + token_num = alphas.sum(-1) + acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold) + len_time = alphas.size(-1) + last_fire_place = len_time - 1 + last_fire_remainds = 0.0 + pre_alphas_length = 0 + + mask_chunk_peak_predictor = None + if cache is not None: + mask_chunk_peak_predictor = None + mask_chunk_peak_predictor = torch.zeros_like(cif_peak) + if cache["cif_alphas"] is not None: + pre_alphas_length = cache["cif_alphas"].size(-1) + mask_chunk_peak_predictor[:, :pre_alphas_length] = 1.0 + mask_chunk_peak_predictor[:, pre_alphas_length + cache["pad_left"]:pre_alphas_length + cache["stride"] + cache["pad_left"]] = 1.0 + + + if mask_chunk_peak_predictor is not None: + cif_peak = cif_peak * mask_chunk_peak_predictor.squeeze(-1) + + for i in range(len_time): + if cif_peak[0][len_time - 1 - i] > self.threshold or cif_peak[0][len_time - 1 - i] == self.threshold: + last_fire_place = len_time - 1 - i + last_fire_remainds = cif_peak[0][len_time - 1 - i] - self.threshold + break + last_fire_remainds = torch.tensor([last_fire_remainds], dtype=alphas.dtype).to(alphas.device) + cache["cif_hidden"] = hidden[:, last_fire_place:, :] + cache["cif_alphas"] = torch.cat((last_fire_remainds.unsqueeze(0), alphas[:, last_fire_place+1:]), -1) + token_num_int = token_num.floor().type(torch.int32).item() + return acoustic_embeds[:, 0:token_num_int, :], token_num, alphas, cif_peak + def tail_process_fn(self, hidden, alphas, token_num=None, mask=None): b, t, d = hidden.size() tail_threshold = self.tail_threshold diff --git a/funasr/modules/attention.py b/funasr/modules/attention.py index 627700524..31d5a8775 100644 --- a/funasr/modules/attention.py +++ b/funasr/modules/attention.py @@ -347,15 +347,17 @@ class MultiHeadedAttentionSANM(nn.Module): mask = torch.reshape(mask, (b, -1, 1)) if mask_shfit_chunk is not None: mask = mask * mask_shfit_chunk + inputs = inputs * mask - inputs = inputs * mask x = inputs.transpose(1, 2) x = self.pad_fn(x) x = self.fsmn_block(x) x = x.transpose(1, 2) x += inputs x = self.dropout(x) - return x * mask + if mask is not None: + x = x * mask + return x def 
forward_qkv(self, x): """Transform query, key and value. @@ -505,7 +507,7 @@ class MultiHeadedAttentionSANMDecoder(nn.Module): # print("in fsmn, cache is None, x", x.size()) x = self.pad_fn(x) - if not self.training and t <= 1: + if not self.training: cache = x else: # print("in fsmn, cache is not None, x", x.size()) @@ -513,7 +515,7 @@ class MultiHeadedAttentionSANMDecoder(nn.Module): # if t < self.kernel_size: # x = self.pad_fn(x) x = torch.cat((cache[:, :, 1:], x), dim=2) - x = x[:, :, -self.kernel_size:] + x = x[:, :, -(self.kernel_size+t-1):] # print("in fsmn, cache is not None, x_cat", x.size()) cache = x x = self.fsmn_block(x) diff --git a/funasr/modules/embedding.py b/funasr/modules/embedding.py index b61a61a88..e4f9bff03 100644 --- a/funasr/modules/embedding.py +++ b/funasr/modules/embedding.py @@ -405,4 +405,13 @@ class SinusoidalPositionEncoder(torch.nn.Module): positions = torch.arange(1, timesteps+1)[None, :] position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device) - return x + position_encoding \ No newline at end of file + return x + position_encoding + + def forward_chunk(self, x, cache=None): + start_idx = 0 + batch_size, timesteps, input_dim = x.size() + if cache is not None: + start_idx = cache["start_idx"] + positions = torch.arange(1, timesteps+start_idx+1)[None, :] + position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device) + return x + position_encoding[:, start_idx: start_idx + timesteps] From 62f88ea941e0c7904954e9936cf8fc462fecbcd9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BB=81=E8=BF=B7?= Date: Wed, 15 Mar 2023 14:57:09 +0800 Subject: [PATCH 2/4] fix decoder cache bug --- funasr/models/decoder/sanm_decoder.py | 48 ++++++++++++++++++++++++--- 1 file changed, 44 insertions(+), 4 deletions(-) diff --git a/funasr/models/decoder/sanm_decoder.py b/funasr/models/decoder/sanm_decoder.py index 011743022..3bfcffc3f 100644 --- a/funasr/models/decoder/sanm_decoder.py +++ b/funasr/models/decoder/sanm_decoder.py @@ -90,6 +90,47 @@ class DecoderLayerSANM(nn.Module): tgt = self.norm1(tgt) tgt = self.feed_forward(tgt) + x = tgt + if self.self_attn: + if self.normalize_before: + tgt = self.norm2(tgt) + x, _ = self.self_attn(tgt, tgt_mask) + x = residual + self.dropout(x) + + if self.src_attn is not None: + residual = x + if self.normalize_before: + x = self.norm3(x) + + x = residual + self.dropout(self.src_attn(x, memory, memory_mask)) + + + return x, tgt_mask, memory, memory_mask, cache + + def forward_chunk(self, tgt, tgt_mask, memory, memory_mask=None, cache=None): + """Compute decoded features. + + Args: + tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size). + tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out). + memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size). + memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in). + cache (List[torch.Tensor]): List of cached tensors. + Each tensor shape should be (#batch, maxlen_out - 1, size). + + Returns: + torch.Tensor: Output tensor(#batch, maxlen_out, size). + torch.Tensor: Mask for output tensor (#batch, maxlen_out). + torch.Tensor: Encoded memory (#batch, maxlen_in, size). + torch.Tensor: Encoded memory mask (#batch, maxlen_in). 
+ + """ + # tgt = self.dropout(tgt) + residual = tgt + if self.normalize_before: + tgt = self.norm1(tgt) + tgt = self.feed_forward(tgt) + x = tgt if self.self_attn: if self.normalize_before: @@ -109,7 +150,6 @@ class DecoderLayerSANM(nn.Module): return x, tgt_mask, memory, memory_mask, cache - class FsmnDecoderSCAMAOpt(BaseTransformerDecoder): """ author: Speech Lab, Alibaba Group, China @@ -980,7 +1020,7 @@ class ParaformerSANMDecoder(BaseTransformerDecoder): new_cache = cache["decode_fsmn"] for i in range(self.att_layer_num): decoder = self.decoders[i] - x, tgt_mask, memory, memory_mask, c_ret = decoder( + x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk( x, None, memory, None, cache=new_cache[i] ) new_cache[i] = c_ret @@ -989,14 +1029,14 @@ class ParaformerSANMDecoder(BaseTransformerDecoder): for i in range(self.num_blocks - self.att_layer_num): j = i + self.att_layer_num decoder = self.decoders2[i] - x, tgt_mask, memory, memory_mask, c_ret = decoder( + x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk( x, None, memory, None, cache=new_cache[j] ) new_cache[j] = c_ret for decoder in self.decoders3: - x, tgt_mask, memory, memory_mask, _ = decoder( + x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk( x, None, memory, None, cache=None ) if self.normalize_before: From 49ded3a686daa816a58376fa67b8df782ffba312 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BF=97=E6=B5=A9?= Date: Thu, 16 Mar 2023 19:18:03 +0800 Subject: [PATCH 3/4] modify diar pipeline --- funasr/bin/sond_inference.py | 30 +++++++++++++++++++++++++----- funasr/models/e2e_diar_sond.py | 26 +++++++++++++++++--------- funasr/tasks/diar.py | 17 ++++++++++++++++- 3 files changed, 58 insertions(+), 15 deletions(-) diff --git a/funasr/bin/sond_inference.py b/funasr/bin/sond_inference.py index ab6d26f45..e2e29c701 100755 --- a/funasr/bin/sond_inference.py +++ b/funasr/bin/sond_inference.py @@ -54,7 +54,7 @@ class Speech2Diarization: self, diar_train_config: Union[Path, str] = None, diar_model_file: Union[Path, str] = None, - device: str = "cpu", + device: Union[str, torch.device] = "cpu", batch_size: int = 1, dtype: str = "float32", streaming: bool = False, @@ -114,9 +114,19 @@ class Speech2Diarization: # little-endian order: lower bit first return (np.array(list(b)[::-1]) == '1').astype(dtype) - return np.row_stack([int2vec(int(x), vec_dim) for x in seq]) + # process oov + seq = np.array([int(x) for x in seq]) + new_seq = [] + for i, x in enumerate(seq): + if x < 2 ** vec_dim: + new_seq.append(x) + else: + idx_list = np.where(seq < 2 ** vec_dim)[0] + idx = np.abs(idx_list - i).argmin() + new_seq.append(seq[idx_list[idx]]) + return np.row_stack([int2vec(x, vec_dim) for x in new_seq]) - def post_processing(self, raw_logits: torch.Tensor, spk_num: int): + def post_processing(self, raw_logits: torch.Tensor, spk_num: int, output_format: str = "speaker_turn"): logits_idx = raw_logits.argmax(-1) # B, T, vocab_size -> B, T # upsampling outputs to match inputs ut = logits_idx.shape[1] * self.diar_model.encoder.time_ds_ratio @@ -127,8 +137,14 @@ class Speech2Diarization: ).squeeze(1).long() logits_idx = logits_idx[0].tolist() pse_labels = [self.token_list[x] for x in logits_idx] + if output_format == "pse_labels": + return pse_labels, None + multi_labels = self.seq2arr(pse_labels, spk_num)[:, :spk_num] # remove padding speakers multi_labels = self.smooth_multi_labels(multi_labels) + if output_format == "binary_labels": + return multi_labels, None + spk_list = ["spk{}".format(i + 1) for i in range(spk_num)] 
spk_turns = self.calc_spk_turns(multi_labels, spk_list) results = OrderedDict() @@ -149,6 +165,7 @@ class Speech2Diarization: self, speech: Union[torch.Tensor, np.ndarray], profile: Union[torch.Tensor, np.ndarray], + output_format: str = "speaker_turn" ): """Inference @@ -178,7 +195,7 @@ class Speech2Diarization: batch = to_device(batch, device=self.device) logits = self.diar_model.prediction_forward(**batch) - results, pse_labels = self.post_processing(logits, profile.shape[1]) + results, pse_labels = self.post_processing(logits, profile.shape[1], output_format) return results, pse_labels @@ -367,7 +384,7 @@ def inference_modelscope( pse_label_writer = open("{}/labels.txt".format(output_path), "w") logging.info("Start to diarize...") result_list = [] - for keys, batch in loader: + for idx, (keys, batch) in enumerate(loader): assert isinstance(batch, dict), type(batch) assert all(isinstance(s, str) for s in keys), keys _bs = len(next(iter(batch.values()))) @@ -385,6 +402,9 @@ def inference_modelscope( pse_label_writer.write("{} {}\n".format(key, " ".join(pse_labels))) pse_label_writer.flush() + if idx % 100 == 0: + logging.info("Processing {:5d}: {}".format(idx, key)) + if output_path is not None: output_writer.close() pse_label_writer.close() diff --git a/funasr/models/e2e_diar_sond.py b/funasr/models/e2e_diar_sond.py index 258d78080..de669f2ee 100644 --- a/funasr/models/e2e_diar_sond.py +++ b/funasr/models/e2e_diar_sond.py @@ -59,7 +59,8 @@ class DiarSondModel(AbsESPnetModel): normalize_speech_speaker: bool = False, ignore_id: int = -1, speaker_discrimination_loss_weight: float = 1.0, - inter_score_loss_weight: float = 0.0 + inter_score_loss_weight: float = 0.0, + inputs_type: str = "raw", ): assert check_argument_types() @@ -86,14 +87,12 @@ class DiarSondModel(AbsESPnetModel): ) self.criterion_bce = SequenceBinaryCrossEntropy(normalize_length=length_normalized_loss) self.pse_embedding = self.generate_pse_embedding() - # self.register_buffer("pse_embedding", pse_embedding) self.power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float() - # self.register_buffer("power_weight", power_weight) self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int() - # self.register_buffer("int_token_arr", int_token_arr) self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight self.inter_score_loss_weight = inter_score_loss_weight self.forward_steps = 0 + self.inputs_type = inputs_type def generate_pse_embedding(self): embedding = np.zeros((len(self.token_list), self.max_spk_num), dtype=np.float) @@ -125,9 +124,14 @@ class DiarSondModel(AbsESPnetModel): binary_labels: (Batch, frames, max_spk_num) binary_labels_lengths: (Batch,) """ - assert speech.shape[0] == binary_labels.shape[0], (speech.shape, binary_labels.shape) + assert speech.shape[0] <= binary_labels.shape[0], (speech.shape, binary_labels.shape) batch_size = speech.shape[0] self.forward_steps = self.forward_steps + 1 + if self.pse_embedding.device != speech.device: + self.pse_embedding = self.pse_embedding.to(speech.device) + self.power_weight = self.power_weight.to(speech.device) + self.int_token_arr = self.int_token_arr.to(speech.device) + # 1. Network forward pred, inter_outputs = self.prediction_forward( speech, speech_lengths, @@ -149,9 +153,13 @@ class DiarSondModel(AbsESPnetModel): # the sequence length of 'pred' might be slightly less than the # length of 'spk_labels'. Here we force them to be equal. 
length_diff_tolerance = 2 - length_diff = pse_labels.shape[1] - pred.shape[1] - if 0 < length_diff <= length_diff_tolerance: - pse_labels = pse_labels[:, 0: pred.shape[1]] + length_diff = abs(pse_labels.shape[1] - pred.shape[1]) + if length_diff <= length_diff_tolerance: + min_len = min(pred.shape[1], pse_labels.shape[1]) + pse_labels = pse_labels[:, :min_len] + pred = pred[:, :min_len] + cd_score = cd_score[:, :min_len] + ci_score = ci_score[:, :min_len] loss_diar = self.classification_loss(pred, pse_labels, binary_labels_lengths) loss_spk_dis = self.speaker_discrimination_loss(profile, profile_lengths) @@ -299,7 +307,7 @@ class DiarSondModel(AbsESPnetModel): speech: torch.Tensor, speech_lengths: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: - if self.encoder is not None: + if self.encoder is not None and self.inputs_type == "raw": speech, speech_lengths = self.encode(speech, speech_lengths) speech_mask = ~make_pad_mask(speech_lengths, maxlen=speech.shape[1]) speech_mask = speech_mask.to(speech.device).unsqueeze(-1).float() diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py index e699dccb0..bf3ae414b 100644 --- a/funasr/tasks/diar.py +++ b/funasr/tasks/diar.py @@ -499,7 +499,7 @@ class DiarTask(AbsTask): config_file: Union[Path, str] = None, model_file: Union[Path, str] = None, cmvn_file: Union[Path, str] = None, - device: str = "cpu", + device: Union[str, torch.device] = "cpu", ): """Build model from the files. @@ -554,6 +554,7 @@ class DiarTask(AbsTask): model.load_state_dict(model_dict) else: model_dict = torch.load(model_file, map_location=device) + model_dict = cls.fileter_model_dict(model_dict, model.state_dict()) model.load_state_dict(model_dict) if model_name_pth is not None and not os.path.exists(model_name_pth): torch.save(model_dict, model_name_pth) @@ -561,6 +562,20 @@ class DiarTask(AbsTask): return model, args + @classmethod + def fileter_model_dict(cls, src_dict: dict, dest_dict: dict): + from collections import OrderedDict + new_dict = OrderedDict() + for key, value in src_dict.items(): + if key in dest_dict: + new_dict[key] = value + else: + logging.info("{} is no longer needed in this model.".format(key)) + for key, value in dest_dict.items(): + if key not in new_dict: + logging.warning("{} is missed in checkpoint.".format(key)) + return new_dict + @classmethod def convert_tf2torch( cls, From 0ac06c029edb57e2dcacd64da2a05869a2f7364d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BF=97=E6=B5=A9?= Date: Thu, 16 Mar 2023 19:24:15 +0800 Subject: [PATCH 4/4] fixbug path_name_type_list can [[any,str,str],[any,str,str]] --- funasr/datasets/iterable_dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/funasr/datasets/iterable_dataset.py b/funasr/datasets/iterable_dataset.py index fa0adeb28..3798280b6 100644 --- a/funasr/datasets/iterable_dataset.py +++ b/funasr/datasets/iterable_dataset.py @@ -8,6 +8,7 @@ from typing import Dict from typing import Iterator from typing import Tuple from typing import Union +from typing import List import kaldiio import numpy as np @@ -127,7 +128,7 @@ class IterableESPnetDataset(IterableDataset): non_iterable_list = [] self.path_name_type_list = [] - if not isinstance(path_name_type_list[0], Tuple): + if not isinstance(path_name_type_list[0], (Tuple, List)): path = path_name_type_list[0] name = path_name_type_list[1] _type = path_name_type_list[2]
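
Illustrative usage sketch (not part of the patches above): PATCH 1/4 exposes the new streaming path through inference_launch with mode "paraformer_streaming", and the _forward closure it returns threads a caller-supplied cache dict (param_dict["cache"]) through encode_chunk, calc_predictor_chunk and cal_decoder_with_predictor_chunk. The sketch below shows one way a caller might drive that entry point chunk by chunk. The model/config paths, the wav file, the chunk size and the initial "pad_left"/"stride"/"start_idx" values are placeholders, and the cache layout is inferred from the forward_chunk() accesses in these patches rather than from a documented API, so treat the whole thing as an assumption, not a reference implementation.

#!/usr/bin/env python3
# Hypothetical chunk-by-chunk driver for the streaming Paraformer entry point
# added in PATCH 1/4. Paths and cache field values are placeholders; the cache
# layout is inferred from the forward_chunk() implementations in this series.
import numpy as np
import soundfile

from funasr.bin.asr_inference_launch import inference_launch

# Build the inference closure (_forward) returned by inference_modelscope.
pipeline = inference_launch(
    mode="paraformer_streaming",
    asr_train_config="exp/streaming_paraformer/config.yaml",  # placeholder
    asr_model_file="exp/streaming_paraformer/model.pb",       # placeholder
    cmvn_file="exp/streaming_paraformer/am.mvn",               # placeholder
    output_dir=None,
    log_level="INFO",
    batch_size=1,
    beam_size=1,
    ngpu=0,
    ctc_weight=0.0,   # with ctc_weight=0.0 and no LM, the greedy path is used
    lm_weight=0.0,
    penalty=0.0,
    maxlenratio=0.0,
    minlenratio=0.0,
)

# Streaming state reused across calls: encode_chunk and calc_predictor_chunk
# read cache["encoder"], cal_decoder_with_predictor_chunk reads cache["decoder"].
# The initial values (pad_left/stride in encoder frames) are assumptions.
cache = {
    "encoder": {"start_idx": 0, "pad_left": 5, "stride": 10,
                "cif_hidden": None, "cif_alphas": None},
    "decoder": {"decode_fsmn": None},
}

speech, sample_rate = soundfile.read("speech.wav")  # placeholder input
assert sample_rate == 16000  # the sketch assumes 16 kHz audio
chunk_samples = 9600  # 600 ms at 16 kHz, an arbitrary chunk size for the sketch
for offset in range(0, len(speech), chunk_samples):
    chunk = np.asarray(speech[offset:offset + chunk_samples], dtype=np.float32)
    partial = pipeline(
        data_path_and_name_and_type=None,
        raw_inputs=chunk,
        param_dict={"cache": cache},
    )
    print(partial)  # list of {"key": ..., "value": ...} items for this chunk

Because the chunk methods mutate the nested dicts in place (cache["decode_fsmn"], cache["cif_hidden"], cache["cif_alphas"]), passing the same cache object on every call is what carries encoder, predictor and decoder state from one chunk to the next.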