From 9d6aad2a442f96e3094f076f998766697eecd6bd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?=
Date: Thu, 11 May 2023 19:06:18 +0800
Subject: [PATCH] paraformer vad punc

Merge the standalone Paraformer VAD / VAD+punc inference scripts into
funasr/bin/asr_inference_paraformer.py as inference_modelscope_vad_punc(),
and route all paraformer_vad* decoding modes in asr_inference_launch.py to
it. The superseded asr_inference_paraformer_vad.py,
asr_inference_paraformer_vad_punc.py and asr_inference_uniasr_vad.py are
removed, the Paraformer demo now chains an FSMN VAD model, and a modelscope
symlink is added.

---
 .../demo.py                                   |   8 +-
 funasr/bin/asr_inference_launch.py            |  33 +-
 funasr/bin/asr_inference_paraformer.py        | 325 +++++--
 funasr/bin/asr_inference_paraformer_vad.py    | 549 -----------
 .../bin/asr_inference_paraformer_vad_punc.py  | 891 ------------------
 funasr/bin/asr_inference_uniasr_vad.py        | 695 --------------
 modelscope                                    |   1 +
 7 files changed, 272 insertions(+), 2230 deletions(-)
 delete mode 100644 funasr/bin/asr_inference_paraformer_vad.py
 delete mode 100644 funasr/bin/asr_inference_paraformer_vad_punc.py
 delete mode 100644 funasr/bin/asr_inference_uniasr_vad.py
 create mode 120000 modelscope

diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/demo.py b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/demo.py
index edc3a050e..1fa6b27b1 100644
--- a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/demo.py
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/demo.py
@@ -3,7 +3,9 @@ from modelscope.utils.constant import Tasks
 
 inference_pipeline = pipeline(
     task=Tasks.auto_speech_recognition,
-    model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
-
-rec_result = inference_pipeline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
+    model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
+    vad_model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',
+)
+audio_in = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav'
+rec_result = inference_pipeline(audio_in=audio_in)
 print(rec_result)
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index 9a1ffe5ee..db91ed29e 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -254,27 +254,15 @@ def inference_launch(**kwargs):
     elif mode == "uniasr":
         from funasr.bin.asr_inference_uniasr import inference_modelscope
         return inference_modelscope(**kwargs)
-    elif mode == "uniasr_vad":
-        from funasr.bin.asr_inference_uniasr_vad import inference_modelscope
-        return inference_modelscope(**kwargs)
     elif mode == "paraformer":
         from funasr.bin.asr_inference_paraformer import inference_modelscope
         return inference_modelscope(**kwargs)
     elif mode == "paraformer_streaming":
         from funasr.bin.asr_inference_paraformer_streaming import inference_modelscope
         return inference_modelscope(**kwargs)
-    elif mode == "paraformer_vad":
-        from funasr.bin.asr_inference_paraformer_vad import inference_modelscope
-        return inference_modelscope(**kwargs)
-    elif mode == "paraformer_punc":
-        logging.info("Unknown decoding mode: {}".format(mode))
-        return None
-    elif mode == "paraformer_vad_punc":
-        from funasr.bin.asr_inference_paraformer_vad_punc import inference_modelscope
-        return inference_modelscope(**kwargs)
-    elif mode == "vad":
-        from funasr.bin.vad_inference import inference_modelscope
-        return inference_modelscope(**kwargs)
+    elif mode.startswith("paraformer_vad"):
+        from funasr.bin.asr_inference_paraformer import inference_modelscope_vad_punc
+        return inference_modelscope_vad_punc(**kwargs)
     elif mode == "mfcca":
         from funasr.bin.asr_inference_mfcca import inference_modelscope
         return inference_modelscope(**kwargs)
@@ -301,14 +289,13 @@ def inference_launch_funasr(**kwargs):
         from funasr.bin.asr_inference_uniasr import inference
         return inference(**kwargs)
     elif mode == "paraformer":
-        from funasr.bin.asr_inference_paraformer import inference
-        return inference(**kwargs)
-    elif mode == "paraformer_vad_punc":
-        from funasr.bin.asr_inference_paraformer_vad_punc import inference
-        return inference(**kwargs)
-    elif mode == "vad":
-        from funasr.bin.vad_inference import inference
-        return inference(**kwargs)
+        from funasr.bin.asr_inference_paraformer import inference_modelscope
+        inference_pipeline = inference_modelscope(**kwargs)
+        return inference_pipeline(kwargs["data_path_and_name_and_type"])
+    elif mode.startswith("paraformer_vad"):
+        from funasr.bin.asr_inference_paraformer import inference_modelscope_vad_punc
+        inference_pipeline = inference_modelscope_vad_punc(**kwargs)
+        return inference_pipeline(kwargs["data_path_and_name_and_type"])
     elif mode == "mfcca":
         from funasr.bin.asr_inference_mfcca import inference_modelscope
         return inference_modelscope(**kwargs)
diff --git a/funasr/bin/asr_inference_paraformer.py b/funasr/bin/asr_inference_paraformer.py
index ce6e8f9e3..2a33bdf23 100644
--- a/funasr/bin/asr_inference_paraformer.py
+++ b/funasr/bin/asr_inference_paraformer.py
@@ -45,7 +45,9 @@ from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
 from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
 from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
 from funasr.bin.tp_inference import SpeechText2Timestamp
-
+from funasr.bin.vad_inference import Speech2VadSegment
+from funasr.bin.punctuation_infer import Text2Punc
+from funasr.utils.vad_utils import slice_padding_fbank
 
 class Speech2Text:
     """Speech2Text class
@@ -299,7 +301,7 @@ class Speech2Text:
                                                                vad_offset=begin_time)
                 results.append((text, token, token_int, hyp, timestamp, enc_len_batch_total, lfr_factor))
             else:
-                results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor))
+                results.append((text, token, token_int, hyp, [], enc_len_batch_total, lfr_factor))
 
         # assert check_return_type(results)
         return results
@@ -359,72 +361,6 @@ class Speech2Text:
         return hotword_list
 
 
-def inference(
-    maxlenratio: float,
-    minlenratio: float,
-    batch_size: int,
-    beam_size: int,
-    ngpu: int,
-    ctc_weight: float,
-    lm_weight: float,
-    penalty: float,
-    log_level: Union[int, str],
-    data_path_and_name_and_type,
-    asr_train_config: Optional[str],
-    asr_model_file: Optional[str],
-    cmvn_file: Optional[str] = None,
-    raw_inputs: Union[np.ndarray, torch.Tensor] = None,
-    lm_train_config: Optional[str] = None,
-    lm_file: Optional[str] = None,
-    token_type: Optional[str] = None,
-    key_file: Optional[str] = None,
-    word_lm_train_config: Optional[str] = None,
-    bpemodel: Optional[str] = None,
-    allow_variable_data_keys: bool = False,
-    streaming: bool = False,
-    output_dir: Optional[str] = None,
-    dtype: str = "float32",
-    seed: int = 0,
-    ngram_weight: float = 0.9,
-    nbest: int = 1,
-    num_workers: int = 1,
-    timestamp_infer_config: Union[Path, str] = None,
-    timestamp_model_file: Union[Path, str] = None,
-    **kwargs,
-):
-    inference_pipeline = inference_modelscope(
-        maxlenratio=maxlenratio,
-        minlenratio=minlenratio,
-        batch_size=batch_size,
-        beam_size=beam_size,
-        ngpu=ngpu,
-        ctc_weight=ctc_weight,
-        lm_weight=lm_weight,
-        penalty=penalty,
-        log_level=log_level,
-        asr_train_config=asr_train_config,
-        asr_model_file=asr_model_file,
-        cmvn_file=cmvn_file,
-        raw_inputs=raw_inputs,
-        lm_train_config=lm_train_config,
-        lm_file=lm_file,
-        token_type=token_type,
-        key_file=key_file,
-        word_lm_train_config=word_lm_train_config,
-        bpemodel=bpemodel,
-        allow_variable_data_keys=allow_variable_data_keys,
-        streaming=streaming,
-        output_dir=output_dir,
-        dtype=dtype,
-        seed=seed,
-        ngram_weight=ngram_weight,
-        nbest=nbest,
-        num_workers=num_workers,
-
-        **kwargs,
-    )
-    return inference_pipeline(data_path_and_name_and_type, raw_inputs)
-
 
 def inference_modelscope(
     maxlenratio: float,
@@ -606,7 +542,7 @@ def inference_modelscope(
                 key = keys[batch_id]
                 for n, result in zip(range(1, nbest + 1), result):
                     text, token, token_int, hyp = result[0], result[1], result[2], result[3]
-                    timestamp = None if len(result) < 5 else result[4]
+                    timestamp = result[4] if len(result[4]) > 0 else None
                     # conduct timestamp prediction here
                     # timestamp inference requires token length
                     # thus following inference cannot be conducted in batch
@@ -658,6 +594,257 @@ def inference_modelscope(
     return _forward
 
 
+def inference_modelscope_vad_punc(
+    maxlenratio: float,
+    minlenratio: float,
+    batch_size: int,
+    beam_size: int,
+    ngpu: int,
+    ctc_weight: float,
+    lm_weight: float,
+    penalty: float,
+    log_level: Union[int, str],
+    # data_path_and_name_and_type,
+    asr_train_config: Optional[str],
+    asr_model_file: Optional[str],
+    cmvn_file: Optional[str] = None,
+    lm_train_config: Optional[str] = None,
+    lm_file: Optional[str] = None,
+    token_type: Optional[str] = None,
+    key_file: Optional[str] = None,
+    word_lm_train_config: Optional[str] = None,
+    bpemodel: Optional[str] = None,
+    allow_variable_data_keys: bool = False,
+    output_dir: Optional[str] = None,
+    dtype: str = "float32",
+    seed: int = 0,
+    ngram_weight: float = 0.9,
+    nbest: int = 1,
+    num_workers: int = 1,
+    vad_infer_config: Optional[str] = None,
+    vad_model_file: Optional[str] = None,
+    vad_cmvn_file: Optional[str] = None,
+    time_stamp_writer: bool = True,
+    punc_infer_config: Optional[str] = None,
+    punc_model_file: Optional[str] = None,
+    outputs_dict: Optional[bool] = True,
+    param_dict: dict = None,
+    **kwargs,
+):
+    assert check_argument_types()
+    ncpu = kwargs.get("ncpu", 1)
+    torch.set_num_threads(ncpu)
+
+    if word_lm_train_config is not None:
+        raise NotImplementedError("Word LM is not implemented")
+    if ngpu > 1:
+        raise NotImplementedError("only single GPU decoding is supported")
+
+    logging.basicConfig(
+        level=log_level,
+        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+    )
+
+    if param_dict is not None:
+        hotword_list_or_file = param_dict.get('hotword')
+    else:
+        hotword_list_or_file = None
+
+    if ngpu >= 1 and torch.cuda.is_available():
+        device = "cuda"
+    else:
+        device = "cpu"
+
+    # 1. Set random-seed
+    set_all_random_seed(seed)
+
+    # 2. Build speech2vadsegment
+    speech2vadsegment_kwargs = dict(
+        vad_infer_config=vad_infer_config,
+        vad_model_file=vad_model_file,
+        vad_cmvn_file=vad_cmvn_file,
+        device=device,
+        dtype=dtype,
+    )
+    # logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
+    speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
+
+    # 3. Build speech2text
+    speech2text_kwargs = dict(
+        asr_train_config=asr_train_config,
+        asr_model_file=asr_model_file,
+        cmvn_file=cmvn_file,
+        lm_train_config=lm_train_config,
+        lm_file=lm_file,
+        token_type=token_type,
+        bpemodel=bpemodel,
+        device=device,
+        maxlenratio=maxlenratio,
+        minlenratio=minlenratio,
+        dtype=dtype,
+        beam_size=beam_size,
+        ctc_weight=ctc_weight,
+        lm_weight=lm_weight,
+        ngram_weight=ngram_weight,
+        penalty=penalty,
+        nbest=nbest,
+        hotword_list_or_file=hotword_list_or_file,
+    )
+    speech2text = Speech2Text(**speech2text_kwargs)
+    text2punc = None
+    if punc_model_file is not None:
+        text2punc = Text2Punc(punc_infer_config, punc_model_file, device=device, dtype=dtype)
+
+    if output_dir is not None:
+        writer = DatadirWriter(output_dir)
+        ibest_writer = writer[f"1best_recog"]
+        ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)
+
+    def _forward(data_path_and_name_and_type,
+                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+                 output_dir_v2: Optional[str] = None,
+                 fs: dict = None,
+                 param_dict: dict = None,
+                 **kwargs,
+                 ):
+
+        hotword_list_or_file = None
+        if param_dict is not None:
+            hotword_list_or_file = param_dict.get('hotword')
+
+        if 'hotword' in kwargs:
+            hotword_list_or_file = kwargs['hotword']
+
+        if speech2text.hotword_list is None:
+            speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
+
+        # 3. Build data-iterator
+        if data_path_and_name_and_type is None and raw_inputs is not None:
+            if isinstance(raw_inputs, torch.Tensor):
+                raw_inputs = raw_inputs.numpy()
+            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+        loader = ASRTask.build_streaming_iterator(
+            data_path_and_name_and_type,
+            dtype=dtype,
+            fs=fs,
+            batch_size=1,
+            key_file=key_file,
+            num_workers=num_workers,
+            preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
+            collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
+            allow_variable_data_keys=allow_variable_data_keys,
+            inference=True,
+        )
+
+        if param_dict is not None:
+            use_timestamp = param_dict.get('use_timestamp', True)
+        else:
+            use_timestamp = True
+
+        finish_count = 0
+        file_count = 1
+        lfr_factor = 6
+        # 7. Start for-loop
+        asr_result_list = []
+        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+        writer = None
+        if output_path is not None:
+            writer = DatadirWriter(output_path)
+            ibest_writer = writer[f"1best_recog"]
+
+        for keys, batch in loader:
+            assert isinstance(batch, dict), type(batch)
+            assert all(isinstance(s, str) for s in keys), keys
+            _bs = len(next(iter(batch.values())))
+            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+
+            vad_results = speech2vadsegment(**batch)
+            _, vadsegments = vad_results[0], vad_results[1][0]
+
+            speech, speech_lengths = batch["speech"], batch["speech_lengths"]
+
+            n = len(vadsegments)
+            data_with_index = [(vadsegments[i], i) for i in range(n)]
+            sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
+            results_sorted = []
+            for j, beg_idx in enumerate(range(0, n, batch_size)):
+                end_idx = min(n, beg_idx + batch_size)
+                speech_j, speech_lengths_j = slice_padding_fbank(speech, speech_lengths, sorted_data[beg_idx:end_idx])
+
+                batch = {"speech": speech_j, "speech_lengths": speech_lengths_j}
+                batch = to_device(batch, device=device)
+                results = speech2text(**batch)
+
+                if len(results) < 1:
+                    results = [["", [], [], [], [], [], []]]
+                results_sorted.extend(results)
+            restored_data = [0] * n
+            for j in range(n):
+                index = sorted_data[j][1]
+                restored_data[index] = results_sorted[j]
+            result = ["", [], [], [], [], [], []]
+            for j in range(n):
+                result[0] += restored_data[j][0]
+                result[1] += restored_data[j][1]
+                result[2] += restored_data[j][2]
+                if len(restored_data[j][4]) > 0:
+                    for t in restored_data[j][4]:
+                        t[0] += vadsegments[j][0]
+                        t[1] += vadsegments[j][0]
+                    result[4] += restored_data[j][4]
+            # result = [result[k]+restored_data[j][k] for k in range(len(result[:-2]))]
+
+            key = keys[0]
+            # result = result_segments[0]
+            text, token, token_int = result[0], result[1], result[2]
+            time_stamp = result[4] if len(result[4]) > 0 else None
+
+            if use_timestamp and time_stamp is not None:
+                postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
+            else:
+                postprocessed_result = postprocess_utils.sentence_postprocess(token)
+            text_postprocessed = ""
+            time_stamp_postprocessed = ""
+            text_postprocessed_punc = postprocessed_result
+            if len(postprocessed_result) == 3:
+                text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \
+                                                                           postprocessed_result[1], \
+                                                                           postprocessed_result[2]
+            else:
+                text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
+
+            text_postprocessed_punc = text_postprocessed
+            punc_id_list = []
+            if len(word_lists) > 0 and text2punc is not None:
+                text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
+
+            item = {'key': key, 'value': text_postprocessed_punc}
+            if text_postprocessed != "":
+                item['text_postprocessed'] = text_postprocessed
+            if time_stamp_postprocessed != "":
+                item['time_stamp'] = time_stamp_postprocessed
+
+            item['sentences'] = time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed)
+
+            asr_result_list.append(item)
+            finish_count += 1
+            # asr_utils.print_progress(finish_count / file_count)
+            if writer is not None:
+                # Write the result to each file
+                ibest_writer["token"][key] = " ".join(token)
+                ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+                ibest_writer["vad"][key] = "{}".format(vadsegments)
+                ibest_writer["text"][key] = " ".join(word_lists)
+                ibest_writer["text_with_punc"][key] = text_postprocessed_punc
+                if time_stamp_postprocessed is not None:
+                    ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)
+
+            logging.info("decoding, utt: {}, predictions: {}".format(key, text_postprocessed_punc))
+        return asr_result_list
+
+    return _forward
+
+
 def get_parser():
     parser = config_argparse.ArgumentParser(
         description="ASR Decoding",
diff --git a/funasr/bin/asr_inference_paraformer_vad.py b/funasr/bin/asr_inference_paraformer_vad.py
deleted file mode 100644
index 977dc9bb3..000000000
--- a/funasr/bin/asr_inference_paraformer_vad.py
+++ /dev/null
@@ -1,549 +0,0 @@
-#!/usr/bin/env python3
-
-import json
-import argparse
-import logging
-import sys
-import time
-from pathlib import Path
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-from typing import Dict
-from typing import Any
-from typing import List
-import math
-import numpy as np
-import torch
-from typeguard import check_argument_types
-
-from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
-from funasr.modules.beam_search.beam_search import Hypothesis
-from funasr.modules.scorers.ctc import CTCPrefixScorer
-from funasr.modules.scorers.length_bonus import LengthBonus
-from funasr.modules.subsampling import TooShortUttError
-from funasr.tasks.asr import ASRTaskParaformer as ASRTask
funasr.tasks.asr import ASRTaskParaformer as ASRTask -from funasr.tasks.lm import LMTask -from funasr.text.build_tokenizer import build_tokenizer -from funasr.text.token_id_converter import TokenIDConverter -from funasr.torch_utils.device_funcs import to_device -from funasr.torch_utils.set_all_random_seed import set_all_random_seed -from funasr.utils import config_argparse -from funasr.utils.cli_utils import get_commandline_args -from funasr.utils.types import str2bool -from funasr.utils.types import str2triple_str -from funasr.utils.types import str_or_none -from funasr.utils import asr_utils, wav_utils, postprocess_utils -from funasr.models.frontend.wav_frontend import WavFrontend -from funasr.tasks.vad import VADTask -from funasr.bin.punctuation_infer import Text2Punc -from funasr.bin.asr_inference_paraformer_vad_punc import Speech2Text -from funasr.bin.asr_inference_paraformer_vad_punc import Speech2VadSegment - - -def inference( - maxlenratio: float, - minlenratio: float, - batch_size: int, - beam_size: int, - ngpu: int, - ctc_weight: float, - lm_weight: float, - penalty: float, - log_level: Union[int, str], - data_path_and_name_and_type, - asr_train_config: Optional[str], - asr_model_file: Optional[str], - cmvn_file: Optional[str] = None, - raw_inputs: Union[np.ndarray, torch.Tensor] = None, - lm_train_config: Optional[str] = None, - lm_file: Optional[str] = None, - token_type: Optional[str] = None, - key_file: Optional[str] = None, - word_lm_train_config: Optional[str] = None, - bpemodel: Optional[str] = None, - allow_variable_data_keys: bool = False, - streaming: bool = False, - output_dir: Optional[str] = None, - dtype: str = "float32", - seed: int = 0, - ngram_weight: float = 0.9, - nbest: int = 1, - num_workers: int = 1, - vad_infer_config: Optional[str] = None, - vad_model_file: Optional[str] = None, - vad_cmvn_file: Optional[str] = None, - time_stamp_writer: bool = False, - punc_infer_config: Optional[str] = None, - punc_model_file: Optional[str] = None, - **kwargs, -): - - inference_pipeline = inference_modelscope( - maxlenratio=maxlenratio, - minlenratio=minlenratio, - batch_size=batch_size, - beam_size=beam_size, - ngpu=ngpu, - ctc_weight=ctc_weight, - lm_weight=lm_weight, - penalty=penalty, - log_level=log_level, - asr_train_config=asr_train_config, - asr_model_file=asr_model_file, - cmvn_file=cmvn_file, - raw_inputs=raw_inputs, - lm_train_config=lm_train_config, - lm_file=lm_file, - token_type=token_type, - key_file=key_file, - word_lm_train_config=word_lm_train_config, - bpemodel=bpemodel, - allow_variable_data_keys=allow_variable_data_keys, - streaming=streaming, - output_dir=output_dir, - dtype=dtype, - seed=seed, - ngram_weight=ngram_weight, - nbest=nbest, - num_workers=num_workers, - vad_infer_config=vad_infer_config, - vad_model_file=vad_model_file, - vad_cmvn_file=vad_cmvn_file, - time_stamp_writer=time_stamp_writer, - punc_infer_config=punc_infer_config, - punc_model_file=punc_model_file, - **kwargs, - ) - return inference_pipeline(data_path_and_name_and_type, raw_inputs) - -def inference_modelscope( - maxlenratio: float, - minlenratio: float, - batch_size: int, - beam_size: int, - ngpu: int, - ctc_weight: float, - lm_weight: float, - penalty: float, - log_level: Union[int, str], - # data_path_and_name_and_type, - asr_train_config: Optional[str], - asr_model_file: Optional[str], - cmvn_file: Optional[str] = None, - lm_train_config: Optional[str] = None, - lm_file: Optional[str] = None, - token_type: Optional[str] = None, - key_file: Optional[str] = None, - 
word_lm_train_config: Optional[str] = None, - bpemodel: Optional[str] = None, - allow_variable_data_keys: bool = False, - output_dir: Optional[str] = None, - dtype: str = "float32", - seed: int = 0, - ngram_weight: float = 0.9, - nbest: int = 1, - num_workers: int = 1, - vad_infer_config: Optional[str] = None, - vad_model_file: Optional[str] = None, - vad_cmvn_file: Optional[str] = None, - time_stamp_writer: bool = True, - punc_infer_config: Optional[str] = None, - punc_model_file: Optional[str] = None, - outputs_dict: Optional[bool] = True, - param_dict: dict = None, - **kwargs, -): - assert check_argument_types() - ncpu = kwargs.get("ncpu", 1) - torch.set_num_threads(ncpu) - - if word_lm_train_config is not None: - raise NotImplementedError("Word LM is not implemented") - if ngpu > 1: - raise NotImplementedError("only single GPU decoding is supported") - - logging.basicConfig( - level=log_level, - format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", - ) - - if param_dict is not None: - hotword_list_or_file = param_dict.get('hotword') - else: - hotword_list_or_file = None - - if ngpu >= 1 and torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" - - # 1. Set random-seed - set_all_random_seed(seed) - - # 2. Build speech2vadsegment - speech2vadsegment_kwargs = dict( - vad_infer_config=vad_infer_config, - vad_model_file=vad_model_file, - vad_cmvn_file=vad_cmvn_file, - device=device, - dtype=dtype, - ) - # logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs)) - speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs) - - # 3. Build speech2text - speech2text_kwargs = dict( - asr_train_config=asr_train_config, - asr_model_file=asr_model_file, - cmvn_file=cmvn_file, - lm_train_config=lm_train_config, - lm_file=lm_file, - token_type=token_type, - bpemodel=bpemodel, - device=device, - maxlenratio=maxlenratio, - minlenratio=minlenratio, - dtype=dtype, - beam_size=beam_size, - ctc_weight=ctc_weight, - lm_weight=lm_weight, - ngram_weight=ngram_weight, - penalty=penalty, - nbest=nbest, - hotword_list_or_file=hotword_list_or_file, - ) - speech2text = Speech2Text(**speech2text_kwargs) - text2punc = None - if punc_model_file is not None: - text2punc = Text2Punc(punc_infer_config, punc_model_file, device=device, dtype=dtype) - - if output_dir is not None: - writer = DatadirWriter(output_dir) - ibest_writer = writer[f"1best_recog"] - ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list) - - def _forward(data_path_and_name_and_type, - raw_inputs: Union[np.ndarray, torch.Tensor] = None, - output_dir_v2: Optional[str] = None, - fs: dict = None, - param_dict: dict = None, - **kwargs, - ): - - hotword_list_or_file = None - if param_dict is not None: - hotword_list_or_file = param_dict.get('hotword') - - if 'hotword' in kwargs: - hotword_list_or_file = kwargs['hotword'] - - if speech2text.hotword_list is None: - speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file) - - # 3. 
Build data-iterator - if data_path_and_name_and_type is None and raw_inputs is not None: - if isinstance(raw_inputs, torch.Tensor): - raw_inputs = raw_inputs.numpy() - data_path_and_name_and_type = [raw_inputs, "speech", "waveform"] - loader = ASRTask.build_streaming_iterator( - data_path_and_name_and_type, - dtype=dtype, - fs=fs, - batch_size=1, - key_file=key_file, - num_workers=num_workers, - preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False), - collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False), - allow_variable_data_keys=allow_variable_data_keys, - inference=True, - ) - - if param_dict is not None: - use_timestamp = param_dict.get('use_timestamp', True) - else: - use_timestamp = True - - finish_count = 0 - file_count = 1 - lfr_factor = 6 - # 7 .Start for-loop - asr_result_list = [] - output_path = output_dir_v2 if output_dir_v2 is not None else output_dir - writer = None - if output_path is not None: - writer = DatadirWriter(output_path) - ibest_writer = writer[f"1best_recog"] - - for keys, batch in loader: - assert isinstance(batch, dict), type(batch) - assert all(isinstance(s, str) for s in keys), keys - _bs = len(next(iter(batch.values()))) - assert len(keys) == _bs, f"{len(keys)} != {_bs}" - - vad_results = speech2vadsegment(**batch) - fbanks, vadsegments = vad_results[0], vad_results[1] - for i, segments in enumerate(vadsegments): - result_segments = [["", [], [], ]] - for j, segment_idx in enumerate(segments): - bed_idx, end_idx = int(segment_idx[0] / 10), int(segment_idx[1] / 10) - segment = fbanks[:, bed_idx:end_idx, :].to(device) - speech_lengths = torch.Tensor([end_idx - bed_idx]).int().to(device) - batch = {"speech": segment, "speech_lengths": speech_lengths, "begin_time": vadsegments[i][j][0], - "end_time": vadsegments[i][j][1]} - results = speech2text(**batch) - if len(results) < 1: - continue - - result_cur = [results[0][:-2]] - if j == 0: - result_segments = result_cur - else: - result_segments = [[result_segments[0][i] + result_cur[0][i] for i in range(len(result_cur[0]))]] - - key = keys[0] - result = result_segments[0] - text, token, token_int = result[0], result[1], result[2] - time_stamp = None if len(result) < 4 else result[3] - - if use_timestamp and time_stamp is not None: - postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp) - else: - postprocessed_result = postprocess_utils.sentence_postprocess(token) - text_postprocessed = "" - time_stamp_postprocessed = "" - text_postprocessed_punc = postprocessed_result - if len(postprocessed_result) == 3: - text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \ - postprocessed_result[1], \ - postprocessed_result[2] - else: - text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1] - text_postprocessed_punc = text_postprocessed - if len(word_lists) > 0 and text2punc is not None: - text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20) - - - item = {'key': key, 'value': text_postprocessed_punc} - if text_postprocessed != "": - item['text_postprocessed'] = text_postprocessed - if time_stamp_postprocessed != "": - item['time_stamp'] = time_stamp_postprocessed - - asr_result_list.append(item) - finish_count += 1 - # asr_utils.print_progress(finish_count / file_count) - if writer is not None: - # Write the result to each file - ibest_writer["token"][key] = " ".join(token) - ibest_writer["token_int"][key] = " ".join(map(str, token_int)) - ibest_writer["vad"][key] = 
"{}".format(vadsegments) - ibest_writer["text"][key] = " ".join(word_lists) - ibest_writer["text_with_punc"][key] = text_postprocessed_punc - if time_stamp_postprocessed is not None: - ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed) - - logging.info("decoding, utt: {}, predictions: {}".format(key, text_postprocessed_punc)) - - - return asr_result_list - return _forward - -def get_parser(): - parser = config_argparse.ArgumentParser( - description="ASR Decoding", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - - # Note(kamo): Use '_' instead of '-' as separator. - # '-' is confusing if written in yaml. - parser.add_argument( - "--log_level", - type=lambda x: x.upper(), - default="INFO", - choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), - help="The verbose level of logging", - ) - - parser.add_argument("--output_dir", type=str, required=True) - parser.add_argument( - "--ngpu", - type=int, - default=0, - help="The number of gpus. 0 indicates CPU mode", - ) - parser.add_argument("--seed", type=int, default=0, help="Random seed") - parser.add_argument( - "--dtype", - default="float32", - choices=["float16", "float32", "float64"], - help="Data type", - ) - parser.add_argument( - "--num_workers", - type=int, - default=1, - help="The number of workers used for DataLoader", - ) - - group = parser.add_argument_group("Input data related") - group.add_argument( - "--data_path_and_name_and_type", - type=str2triple_str, - required=False, - action="append", - ) - group.add_argument("--key_file", type=str_or_none) - group.add_argument("--allow_variable_data_keys", type=str2bool, default=False) - - group = parser.add_argument_group("The model configuration related") - group.add_argument( - "--asr_train_config", - type=str, - help="ASR training configuration", - ) - group.add_argument( - "--asr_model_file", - type=str, - help="ASR model parameter file", - ) - group.add_argument( - "--cmvn_file", - type=str, - help="Global cmvn file", - ) - group.add_argument( - "--lm_train_config", - type=str, - help="LM training configuration", - ) - group.add_argument( - "--lm_file", - type=str, - help="LM parameter file", - ) - group.add_argument( - "--word_lm_train_config", - type=str, - help="Word LM training configuration", - ) - group.add_argument( - "--word_lm_file", - type=str, - help="Word LM parameter file", - ) - group.add_argument( - "--ngram_file", - type=str, - help="N-gram parameter file", - ) - group.add_argument( - "--model_tag", - type=str, - help="Pretrained model tag. If specify this option, *_train_config and " - "*_file will be overwritten", - ) - - group = parser.add_argument_group("Beam-search related") - group.add_argument( - "--batch_size", - type=int, - default=1, - help="The batch size for inference", - ) - group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses") - group.add_argument("--beam_size", type=int, default=20, help="Beam size") - group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty") - group.add_argument( - "--maxlenratio", - type=float, - default=0.0, - help="Input length ratio to obtain max output length. " - "If maxlenratio=0.0 (default), it uses a end-detect " - "function " - "to automatically find maximum hypothesis lengths." 
- "If maxlenratio<0.0, its absolute value is interpreted" - "as a constant max output length", - ) - group.add_argument( - "--minlenratio", - type=float, - default=0.0, - help="Input length ratio to obtain min output length", - ) - group.add_argument( - "--ctc_weight", - type=float, - default=0.5, - help="CTC weight in joint decoding", - ) - group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight") - group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight") - group.add_argument("--streaming", type=str2bool, default=False) - group.add_argument("--time_stamp_writer", type=str2bool, default=False) - - group.add_argument( - "--frontend_conf", - default=None, - help="", - ) - group.add_argument("--raw_inputs", type=list, default=None) - # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}]) - - group = parser.add_argument_group("Text converter related") - group.add_argument( - "--token_type", - type=str_or_none, - default=None, - choices=["char", "bpe", None], - help="The token type for ASR model. " - "If not given, refers from the training args", - ) - group.add_argument( - "--bpemodel", - type=str_or_none, - default=None, - help="The model path of sentencepiece. " - "If not given, refers from the training args", - ) - group.add_argument( - "--vad_infer_config", - type=str, - help="VAD infer configuration", - ) - group.add_argument( - "--vad_model_file", - type=str, - help="VAD model parameter file", - ) - group.add_argument( - "--vad_cmvn_file", - type=str, - help="vad, Global cmvn file", - ) - group.add_argument( - "--punc_infer_config", - type=str, - help="VAD infer configuration", - ) - group.add_argument( - "--punc_model_file", - type=str, - help="VAD model parameter file", - ) - return parser - - -def main(cmd=None): - print(get_commandline_args(), file=sys.stderr) - parser = get_parser() - args = parser.parse_args(cmd) - kwargs = vars(args) - kwargs.pop("config", None) - inference(**kwargs) - - -if __name__ == "__main__": - main() diff --git a/funasr/bin/asr_inference_paraformer_vad_punc.py b/funasr/bin/asr_inference_paraformer_vad_punc.py deleted file mode 100644 index edaad379b..000000000 --- a/funasr/bin/asr_inference_paraformer_vad_punc.py +++ /dev/null @@ -1,891 +0,0 @@ -#!/usr/bin/env python3 - -import json -import argparse -import logging -from re import L -import sys -import time -import os -import codecs -import tempfile -import requests -from pathlib import Path -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import Union -from typing import Dict -from typing import Any -from typing import List -import math -import copy -import numpy as np -import torch -from typeguard import check_argument_types - -from funasr.fileio.datadir_writer import DatadirWriter -from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch -from funasr.modules.beam_search.beam_search import Hypothesis -from funasr.modules.scorers.ctc import CTCPrefixScorer -from funasr.modules.scorers.length_bonus import LengthBonus -from funasr.modules.subsampling import TooShortUttError -from funasr.tasks.asr import ASRTaskParaformer as ASRTask -from funasr.tasks.lm import LMTask -from funasr.text.build_tokenizer import build_tokenizer -from funasr.text.token_id_converter import TokenIDConverter -from funasr.torch_utils.device_funcs import to_device -from funasr.torch_utils.set_all_random_seed import 
set_all_random_seed -from funasr.utils import config_argparse -from funasr.utils.cli_utils import get_commandline_args -from funasr.utils.types import str2bool -from funasr.utils.types import str2triple_str -from funasr.utils.types import str_or_none -from funasr.utils import asr_utils, wav_utils, postprocess_utils -from funasr.models.frontend.wav_frontend import WavFrontend -from funasr.tasks.vad import VADTask -from funasr.bin.vad_inference import Speech2VadSegment -from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard -from funasr.bin.punctuation_infer import Text2Punc -from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer -from funasr.utils.vad_utils import slice_padding_fbank -from funasr.bin.asr_inference_paraformer import Speech2Text -# class Speech2Text: -# """Speech2Text class -# -# Examples: -# >>> import soundfile -# >>> speech2text = Speech2Text("asr_config.yml", "asr.pb") -# >>> audio, rate = soundfile.read("speech.wav") -# >>> speech2text(audio) -# [(text, token, token_int, hypothesis object), ...] -# -# """ -# -# def __init__( -# self, -# asr_train_config: Union[Path, str] = None, -# asr_model_file: Union[Path, str] = None, -# cmvn_file: Union[Path, str] = None, -# lm_train_config: Union[Path, str] = None, -# lm_file: Union[Path, str] = None, -# token_type: str = None, -# bpemodel: str = None, -# device: str = "cpu", -# maxlenratio: float = 0.0, -# minlenratio: float = 0.0, -# dtype: str = "float32", -# beam_size: int = 20, -# ctc_weight: float = 0.5, -# lm_weight: float = 1.0, -# ngram_weight: float = 0.9, -# penalty: float = 0.0, -# nbest: int = 1, -# frontend_conf: dict = None, -# hotword_list_or_file: str = None, -# **kwargs, -# ): -# assert check_argument_types() -# -# # 1. Build ASR model -# scorers = {} -# asr_model, asr_train_args = ASRTask.build_model_from_file( -# asr_train_config, asr_model_file, cmvn_file=cmvn_file, device=device -# ) -# frontend = None -# if asr_model.frontend is not None and asr_train_args.frontend_conf is not None: -# frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf) -# -# # logging.info("asr_model: {}".format(asr_model)) -# # logging.info("asr_train_args: {}".format(asr_train_args)) -# asr_model.to(dtype=getattr(torch, dtype)).eval() -# -# if asr_model.ctc != None: -# ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) -# scorers.update( -# ctc=ctc -# ) -# token_list = asr_model.token_list -# scorers.update( -# length_bonus=LengthBonus(len(token_list)), -# ) -# -# # 2. Build Language model -# if lm_train_config is not None: -# lm, lm_train_args = LMTask.build_model_from_file( -# lm_train_config, lm_file, device -# ) -# scorers["lm"] = lm.lm -# -# # 3. Build ngram model -# # ngram is not supported now -# ngram = None -# scorers["ngram"] = ngram -# -# # 4. 
Build BeamSearch object -# # transducer is not supported now -# beam_search_transducer = None -# -# weights = dict( -# decoder=1.0 - ctc_weight, -# ctc=ctc_weight, -# lm=lm_weight, -# ngram=ngram_weight, -# length_bonus=penalty, -# ) -# beam_search = BeamSearch( -# beam_size=beam_size, -# weights=weights, -# scorers=scorers, -# sos=asr_model.sos, -# eos=asr_model.eos, -# vocab_size=len(token_list), -# token_list=token_list, -# pre_beam_score_key=None if ctc_weight == 1.0 else "full", -# ) -# -# beam_search.to(device=device, dtype=getattr(torch, dtype)).eval() -# for scorer in scorers.values(): -# if isinstance(scorer, torch.nn.Module): -# scorer.to(device=device, dtype=getattr(torch, dtype)).eval() -# -# logging.info(f"Decoding device={device}, dtype={dtype}") -# -# # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text -# if token_type is None: -# token_type = asr_train_args.token_type -# if bpemodel is None: -# bpemodel = asr_train_args.bpemodel -# -# if token_type is None: -# tokenizer = None -# elif token_type == "bpe": -# if bpemodel is not None: -# tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel) -# else: -# tokenizer = None -# else: -# tokenizer = build_tokenizer(token_type=token_type) -# converter = TokenIDConverter(token_list=token_list) -# logging.info(f"Text tokenizer: {tokenizer}") -# -# self.asr_model = asr_model -# self.asr_train_args = asr_train_args -# self.converter = converter -# self.tokenizer = tokenizer -# -# # 6. [Optional] Build hotword list from str, local file or url -# self.hotword_list = None -# self.hotword_list = self.generate_hotwords_list(hotword_list_or_file) -# -# is_use_lm = lm_weight != 0.0 and lm_file is not None -# if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm: -# beam_search = None -# self.beam_search = beam_search -# logging.info(f"Beam_search: {self.beam_search}") -# self.beam_search_transducer = beam_search_transducer -# self.maxlenratio = maxlenratio -# self.minlenratio = minlenratio -# self.device = device -# self.dtype = dtype -# self.nbest = nbest -# self.frontend = frontend -# self.encoder_downsampling_factor = 1 -# if asr_train_args.encoder_conf["input_layer"] == "conv2d": -# self.encoder_downsampling_factor = 4 -# -# @torch.no_grad() -# def __call__( -# self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None, -# begin_time: int = 0, end_time: int = None, -# ): -# """Inference -# -# Args: -# speech: Input speech data -# Returns: -# text, token, token_int, hyp -# -# """ -# assert check_argument_types() -# -# # Input as audio signal -# if isinstance(speech, np.ndarray): -# speech = torch.tensor(speech) -# -# if self.frontend is not None: -# feats, feats_len = self.frontend.forward(speech, speech_lengths) -# # fbanks, fbanks_len = self.frontend.forward_fbank(speech, speech_lengths) -# # feats, feats_len = self.frontend.forward_lfr_cmvn(speech, speech_lengths) -# feats = to_device(feats, device=self.device) -# feats_len = feats_len.int() -# self.asr_model.frontend = None -# else: -# feats = speech -# feats_len = speech_lengths -# lfr_factor = max(1, (feats.size()[-1] // 80) - 1) -# batch = {"speech": feats, "speech_lengths": feats_len} -# -# # a. To device -# batch = to_device(batch, device=self.device) -# -# # b. 
Forward Encoder -# enc, enc_len = self.asr_model.encode(**batch) -# if isinstance(enc, tuple): -# enc = enc[0] -# # assert len(enc) == 1, len(enc) -# enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor -# -# predictor_outs = self.asr_model.calc_predictor(enc, enc_len) -# pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \ -# predictor_outs[2], predictor_outs[3] -# pre_token_length = pre_token_length.round().long() -# if torch.max(pre_token_length) < 1: -# return [] -# -# if not isinstance(self.asr_model, ContextualParaformer): -# if self.hotword_list: -# logging.warning("Hotword is given but asr model is not a ContextualParaformer.") -# decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length) -# decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] -# else: -# decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list) -# decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] -# -# if isinstance(self.asr_model, BiCifParaformer): -# _, _, us_alphas, us_peaks = self.asr_model.calc_predictor_timestamp(enc, enc_len, -# pre_token_length) # test no bias cif2 -# -# results = [] -# b, n, d = decoder_out.size() -# for i in range(b): -# x = enc[i, :enc_len[i], :] -# am_scores = decoder_out[i, :pre_token_length[i], :] -# if self.beam_search is not None: -# nbest_hyps = self.beam_search( -# x=x, am_scores=am_scores, maxlenratio=self.maxlenratio, minlenratio=self.minlenratio -# ) -# -# nbest_hyps = nbest_hyps[: self.nbest] -# else: -# yseq = am_scores.argmax(dim=-1) -# score = am_scores.max(dim=-1)[0] -# score = torch.sum(score, dim=-1) -# # pad with mask tokens to ensure compatibility with sos/eos tokens -# yseq = torch.tensor( -# [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device -# ) -# nbest_hyps = [Hypothesis(yseq=yseq, score=score)] -# -# for hyp in nbest_hyps: -# assert isinstance(hyp, (Hypothesis)), type(hyp) -# -# # remove sos/eos and get results -# last_pos = -1 -# if isinstance(hyp.yseq, list): -# token_int = hyp.yseq[1:last_pos] -# else: -# token_int = hyp.yseq[1:last_pos].tolist() -# -# # remove blank symbol id, which is assumed to be 0 -# token_int = list(filter(lambda x: x != 0 and x != 2, token_int)) -# if len(token_int) == 0: -# continue -# -# # Change integer-ids to tokens -# token = self.converter.ids2tokens(token_int) -# -# if self.tokenizer is not None: -# text = self.tokenizer.tokens2text(token) -# else: -# text = None -# -# if isinstance(self.asr_model, BiCifParaformer): -# _, timestamp = ts_prediction_lfr6_standard(us_alphas[i], -# us_peaks[i], -# copy.copy(token), -# vad_offset=begin_time) -# results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor)) -# else: -# results.append((text, token, token_int, enc_len_batch_total, lfr_factor)) -# -# # assert check_return_type(results) -# return results -# -# def generate_hotwords_list(self, hotword_list_or_file): -# # for None -# if hotword_list_or_file is None: -# hotword_list = None -# # for local txt inputs -# elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'): -# logging.info("Attempting to parse hotwords from local txt...") -# hotword_list = [] -# hotword_str_list = [] -# with codecs.open(hotword_list_or_file, 'r') as fin: -# for line in fin.readlines(): -# hw = line.strip() -# hotword_str_list.append(hw) -# 
hotword_list.append(self.converter.tokens2ids([i for i in hw])) -# hotword_list.append([self.asr_model.sos]) -# hotword_str_list.append('') -# logging.info("Initialized hotword list from file: {}, hotword list: {}." -# .format(hotword_list_or_file, hotword_str_list)) -# # for url, download and generate txt -# elif hotword_list_or_file.startswith('http'): -# logging.info("Attempting to parse hotwords from url...") -# work_dir = tempfile.TemporaryDirectory().name -# if not os.path.exists(work_dir): -# os.makedirs(work_dir) -# text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file)) -# local_file = requests.get(hotword_list_or_file) -# open(text_file_path, "wb").write(local_file.content) -# hotword_list_or_file = text_file_path -# hotword_list = [] -# hotword_str_list = [] -# with codecs.open(hotword_list_or_file, 'r') as fin: -# for line in fin.readlines(): -# hw = line.strip() -# hotword_str_list.append(hw) -# hotword_list.append(self.converter.tokens2ids([i for i in hw])) -# hotword_list.append([self.asr_model.sos]) -# hotword_str_list.append('') -# logging.info("Initialized hotword list from file: {}, hotword list: {}." -# .format(hotword_list_or_file, hotword_str_list)) -# # for text str input -# elif not hotword_list_or_file.endswith('.txt'): -# logging.info("Attempting to parse hotwords as str...") -# hotword_list = [] -# hotword_str_list = [] -# for hw in hotword_list_or_file.strip().split(): -# hotword_str_list.append(hw) -# hotword_list.append(self.converter.tokens2ids([i for i in hw])) -# hotword_list.append([self.asr_model.sos]) -# hotword_str_list.append('') -# logging.info("Hotword list: {}.".format(hotword_str_list)) -# else: -# hotword_list = None -# return hotword_list - - -def inference( - maxlenratio: float, - minlenratio: float, - batch_size: int, - beam_size: int, - ngpu: int, - ctc_weight: float, - lm_weight: float, - penalty: float, - log_level: Union[int, str], - data_path_and_name_and_type, - asr_train_config: Optional[str], - asr_model_file: Optional[str], - cmvn_file: Optional[str] = None, - raw_inputs: Union[np.ndarray, torch.Tensor] = None, - lm_train_config: Optional[str] = None, - lm_file: Optional[str] = None, - token_type: Optional[str] = None, - key_file: Optional[str] = None, - word_lm_train_config: Optional[str] = None, - bpemodel: Optional[str] = None, - allow_variable_data_keys: bool = False, - streaming: bool = False, - output_dir: Optional[str] = None, - dtype: str = "float32", - seed: int = 0, - ngram_weight: float = 0.9, - nbest: int = 1, - num_workers: int = 1, - vad_infer_config: Optional[str] = None, - vad_model_file: Optional[str] = None, - vad_cmvn_file: Optional[str] = None, - time_stamp_writer: bool = False, - punc_infer_config: Optional[str] = None, - punc_model_file: Optional[str] = None, - **kwargs, -): - inference_pipeline = inference_modelscope( - maxlenratio=maxlenratio, - minlenratio=minlenratio, - batch_size=batch_size, - beam_size=beam_size, - ngpu=ngpu, - ctc_weight=ctc_weight, - lm_weight=lm_weight, - penalty=penalty, - log_level=log_level, - asr_train_config=asr_train_config, - asr_model_file=asr_model_file, - cmvn_file=cmvn_file, - raw_inputs=raw_inputs, - lm_train_config=lm_train_config, - lm_file=lm_file, - token_type=token_type, - key_file=key_file, - word_lm_train_config=word_lm_train_config, - bpemodel=bpemodel, - allow_variable_data_keys=allow_variable_data_keys, - streaming=streaming, - output_dir=output_dir, - dtype=dtype, - seed=seed, - ngram_weight=ngram_weight, - nbest=nbest, - 
num_workers=num_workers, - vad_infer_config=vad_infer_config, - vad_model_file=vad_model_file, - vad_cmvn_file=vad_cmvn_file, - time_stamp_writer=time_stamp_writer, - punc_infer_config=punc_infer_config, - punc_model_file=punc_model_file, - **kwargs, - ) - return inference_pipeline(data_path_and_name_and_type, raw_inputs) - - -def inference_modelscope( - maxlenratio: float, - minlenratio: float, - batch_size: int, - beam_size: int, - ngpu: int, - ctc_weight: float, - lm_weight: float, - penalty: float, - log_level: Union[int, str], - # data_path_and_name_and_type, - asr_train_config: Optional[str], - asr_model_file: Optional[str], - cmvn_file: Optional[str] = None, - lm_train_config: Optional[str] = None, - lm_file: Optional[str] = None, - token_type: Optional[str] = None, - key_file: Optional[str] = None, - word_lm_train_config: Optional[str] = None, - bpemodel: Optional[str] = None, - allow_variable_data_keys: bool = False, - output_dir: Optional[str] = None, - dtype: str = "float32", - seed: int = 0, - ngram_weight: float = 0.9, - nbest: int = 1, - num_workers: int = 1, - vad_infer_config: Optional[str] = None, - vad_model_file: Optional[str] = None, - vad_cmvn_file: Optional[str] = None, - time_stamp_writer: bool = True, - punc_infer_config: Optional[str] = None, - punc_model_file: Optional[str] = None, - outputs_dict: Optional[bool] = True, - param_dict: dict = None, - **kwargs, -): - assert check_argument_types() - ncpu = kwargs.get("ncpu", 1) - torch.set_num_threads(ncpu) - - if word_lm_train_config is not None: - raise NotImplementedError("Word LM is not implemented") - if ngpu > 1: - raise NotImplementedError("only single GPU decoding is supported") - - logging.basicConfig( - level=log_level, - format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", - ) - - if param_dict is not None: - hotword_list_or_file = param_dict.get('hotword') - else: - hotword_list_or_file = None - - if ngpu >= 1 and torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" - - # 1. Set random-seed - set_all_random_seed(seed) - - # 2. Build speech2vadsegment - speech2vadsegment_kwargs = dict( - vad_infer_config=vad_infer_config, - vad_model_file=vad_model_file, - vad_cmvn_file=vad_cmvn_file, - device=device, - dtype=dtype, - ) - # logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs)) - speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs) - - # 3. 
Build speech2text - speech2text_kwargs = dict( - asr_train_config=asr_train_config, - asr_model_file=asr_model_file, - cmvn_file=cmvn_file, - lm_train_config=lm_train_config, - lm_file=lm_file, - token_type=token_type, - bpemodel=bpemodel, - device=device, - maxlenratio=maxlenratio, - minlenratio=minlenratio, - dtype=dtype, - beam_size=beam_size, - ctc_weight=ctc_weight, - lm_weight=lm_weight, - ngram_weight=ngram_weight, - penalty=penalty, - nbest=nbest, - hotword_list_or_file=hotword_list_or_file, - ) - speech2text = Speech2Text(**speech2text_kwargs) - text2punc = None - if punc_model_file is not None: - text2punc = Text2Punc(punc_infer_config, punc_model_file, device=device, dtype=dtype) - - if output_dir is not None: - writer = DatadirWriter(output_dir) - ibest_writer = writer[f"1best_recog"] - ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list) - - def _forward(data_path_and_name_and_type, - raw_inputs: Union[np.ndarray, torch.Tensor] = None, - output_dir_v2: Optional[str] = None, - fs: dict = None, - param_dict: dict = None, - **kwargs, - ): - - hotword_list_or_file = None - if param_dict is not None: - hotword_list_or_file = param_dict.get('hotword') - - if 'hotword' in kwargs: - hotword_list_or_file = kwargs['hotword'] - - if speech2text.hotword_list is None: - speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file) - - # 3. Build data-iterator - if data_path_and_name_and_type is None and raw_inputs is not None: - if isinstance(raw_inputs, torch.Tensor): - raw_inputs = raw_inputs.numpy() - data_path_and_name_and_type = [raw_inputs, "speech", "waveform"] - loader = ASRTask.build_streaming_iterator( - data_path_and_name_and_type, - dtype=dtype, - fs=fs, - batch_size=1, - key_file=key_file, - num_workers=num_workers, - preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False), - collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False), - allow_variable_data_keys=allow_variable_data_keys, - inference=True, - ) - - if param_dict is not None: - use_timestamp = param_dict.get('use_timestamp', True) - else: - use_timestamp = True - - finish_count = 0 - file_count = 1 - lfr_factor = 6 - # 7 .Start for-loop - asr_result_list = [] - output_path = output_dir_v2 if output_dir_v2 is not None else output_dir - writer = None - if output_path is not None: - writer = DatadirWriter(output_path) - ibest_writer = writer[f"1best_recog"] - - for keys, batch in loader: - assert isinstance(batch, dict), type(batch) - assert all(isinstance(s, str) for s in keys), keys - _bs = len(next(iter(batch.values()))) - assert len(keys) == _bs, f"{len(keys)} != {_bs}" - - vad_results = speech2vadsegment(**batch) - _, vadsegments = vad_results[0], vad_results[1][0] - - speech, speech_lengths = batch["speech"], batch["speech_lengths"] - - n = len(vadsegments) - data_with_index = [(vadsegments[i], i) for i in range(n)] - sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0]) - results_sorted = [] - for j, beg_idx in enumerate(range(0, n, batch_size)): - end_idx = min(n, beg_idx + batch_size) - speech_j, speech_lengths_j = slice_padding_fbank(speech, speech_lengths, sorted_data[beg_idx:end_idx]) - - batch = {"speech": speech_j, "speech_lengths": speech_lengths_j} - batch = to_device(batch, device=device) - results = speech2text(**batch) - - if len(results) < 1: - results = [["", [], [], [], [], [], []]] - results_sorted.extend(results) - restored_data = [0] * n - for j in range(n): - index = 
sorted_data[j][1] - restored_data[index] = results_sorted[j] - result = ["", [], [], [], [], [], []] - for j in range(n): - result[0] += restored_data[j][0] - result[1] += restored_data[j][1] - result[2] += restored_data[j][2] - for t in restored_data[j][4]: - t[0] += vadsegments[j][0] - t[1] += vadsegments[j][0] - result[4] += restored_data[j][4] - # result = [result[k]+restored_data[j][k] for k in range(len(result[:-2]))] - - key = keys[0] - # result = result_segments[0] - text, token, token_int = result[0], result[1], result[2] - time_stamp = None if len(result) < 5 else result[4] - - if use_timestamp and time_stamp is not None: - postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp) - else: - postprocessed_result = postprocess_utils.sentence_postprocess(token) - text_postprocessed = "" - time_stamp_postprocessed = "" - text_postprocessed_punc = postprocessed_result - if len(postprocessed_result) == 3: - text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \ - postprocessed_result[1], \ - postprocessed_result[2] - else: - text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1] - - text_postprocessed_punc = text_postprocessed - punc_id_list = [] - if len(word_lists) > 0 and text2punc is not None: - text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20) - - item = {'key': key, 'value': text_postprocessed_punc} - if text_postprocessed != "": - item['text_postprocessed'] = text_postprocessed - if time_stamp_postprocessed != "": - item['time_stamp'] = time_stamp_postprocessed - - item['sentences'] = time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed) - - asr_result_list.append(item) - finish_count += 1 - # asr_utils.print_progress(finish_count / file_count) - if writer is not None: - # Write the result to each file - ibest_writer["token"][key] = " ".join(token) - ibest_writer["token_int"][key] = " ".join(map(str, token_int)) - ibest_writer["vad"][key] = "{}".format(vadsegments) - ibest_writer["text"][key] = " ".join(word_lists) - ibest_writer["text_with_punc"][key] = text_postprocessed_punc - if time_stamp_postprocessed is not None: - ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed) - - logging.info("decoding, utt: {}, predictions: {}".format(key, text_postprocessed_punc)) - return asr_result_list - - return _forward - - -def get_parser(): - parser = config_argparse.ArgumentParser( - description="ASR Decoding", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - - # Note(kamo): Use '_' instead of '-' as separator. - # '-' is confusing if written in yaml. - parser.add_argument( - "--log_level", - type=lambda x: x.upper(), - default="INFO", - choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), - help="The verbose level of logging", - ) - - parser.add_argument("--output_dir", type=str, required=True) - parser.add_argument( - "--ngpu", - type=int, - default=0, - help="The number of gpus. 
0 indicates CPU mode", - ) - parser.add_argument("--seed", type=int, default=0, help="Random seed") - parser.add_argument( - "--dtype", - default="float32", - choices=["float16", "float32", "float64"], - help="Data type", - ) - parser.add_argument( - "--num_workers", - type=int, - default=1, - help="The number of workers used for DataLoader", - ) - - group = parser.add_argument_group("Input data related") - group.add_argument( - "--data_path_and_name_and_type", - type=str2triple_str, - required=False, - action="append", - ) - group.add_argument("--key_file", type=str_or_none) - group.add_argument("--allow_variable_data_keys", type=str2bool, default=False) - - group = parser.add_argument_group("The model configuration related") - group.add_argument( - "--asr_train_config", - type=str, - help="ASR training configuration", - ) - group.add_argument( - "--asr_model_file", - type=str, - help="ASR model parameter file", - ) - group.add_argument( - "--cmvn_file", - type=str, - help="Global cmvn file", - ) - group.add_argument( - "--lm_train_config", - type=str, - help="LM training configuration", - ) - group.add_argument( - "--lm_file", - type=str, - help="LM parameter file", - ) - group.add_argument( - "--word_lm_train_config", - type=str, - help="Word LM training configuration", - ) - group.add_argument( - "--word_lm_file", - type=str, - help="Word LM parameter file", - ) - group.add_argument( - "--ngram_file", - type=str, - help="N-gram parameter file", - ) - group.add_argument( - "--model_tag", - type=str, - help="Pretrained model tag. If specify this option, *_train_config and " - "*_file will be overwritten", - ) - - group = parser.add_argument_group("Beam-search related") - group.add_argument( - "--batch_size", - type=int, - default=1, - help="The batch size for inference", - ) - group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses") - group.add_argument("--beam_size", type=int, default=20, help="Beam size") - group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty") - group.add_argument( - "--maxlenratio", - type=float, - default=0.0, - help="Input length ratio to obtain max output length. " - "If maxlenratio=0.0 (default), it uses a end-detect " - "function " - "to automatically find maximum hypothesis lengths." - "If maxlenratio<0.0, its absolute value is interpreted" - "as a constant max output length", - ) - group.add_argument( - "--minlenratio", - type=float, - default=0.0, - help="Input length ratio to obtain min output length", - ) - group.add_argument( - "--ctc_weight", - type=float, - default=0.5, - help="CTC weight in joint decoding", - ) - group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight") - group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight") - group.add_argument("--streaming", type=str2bool, default=False) - group.add_argument("--time_stamp_writer", type=str2bool, default=False) - - group.add_argument( - "--frontend_conf", - default=None, - help="", - ) - group.add_argument("--raw_inputs", type=list, default=None) - # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}]) - - group = parser.add_argument_group("Text converter related") - group.add_argument( - "--token_type", - type=str_or_none, - default=None, - choices=["char", "bpe", None], - help="The token type for ASR model. 
" - "If not given, refers from the training args", - ) - group.add_argument( - "--bpemodel", - type=str_or_none, - default=None, - help="The model path of sentencepiece. " - "If not given, refers from the training args", - ) - group.add_argument( - "--vad_infer_config", - type=str, - help="VAD infer configuration", - ) - group.add_argument( - "--vad_model_file", - type=str, - help="VAD model parameter file", - ) - group.add_argument( - "--vad_cmvn_file", - type=str, - help="vad, Global cmvn file", - ) - group.add_argument( - "--punc_infer_config", - type=str, - help="VAD infer configuration", - ) - group.add_argument( - "--punc_model_file", - type=str, - help="VAD model parameter file", - ) - return parser - - -def main(cmd=None): - print(get_commandline_args(), file=sys.stderr) - parser = get_parser() - args = parser.parse_args(cmd) - kwargs = vars(args) - kwargs.pop("config", None) - inference(**kwargs) - - -if __name__ == "__main__": - main() diff --git a/funasr/bin/asr_inference_uniasr_vad.py b/funasr/bin/asr_inference_uniasr_vad.py deleted file mode 100644 index 52c29b8a9..000000000 --- a/funasr/bin/asr_inference_uniasr_vad.py +++ /dev/null @@ -1,695 +0,0 @@ -#!/usr/bin/env python3 -import argparse -import logging -import sys -from pathlib import Path -from typing import List -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import Union -from typing import Dict -from typing import Any - -import numpy as np -import torch -from typeguard import check_argument_types -from typeguard import check_return_type - -from funasr.fileio.datadir_writer import DatadirWriter -from funasr.modules.beam_search.beam_search import BeamSearchScama as BeamSearch -from funasr.modules.beam_search.beam_search import Hypothesis -from funasr.modules.scorers.ctc import CTCPrefixScorer -from funasr.modules.scorers.length_bonus import LengthBonus -from funasr.modules.subsampling import TooShortUttError -from funasr.tasks.asr import ASRTaskUniASR as ASRTask -from funasr.tasks.lm import LMTask -from funasr.text.build_tokenizer import build_tokenizer -from funasr.text.token_id_converter import TokenIDConverter -from funasr.torch_utils.device_funcs import to_device -from funasr.torch_utils.set_all_random_seed import set_all_random_seed -from funasr.utils import config_argparse -from funasr.utils.cli_utils import get_commandline_args -from funasr.utils.types import str2bool -from funasr.utils.types import str2triple_str -from funasr.utils.types import str_or_none -from funasr.utils import asr_utils, wav_utils, postprocess_utils -from funasr.models.frontend.wav_frontend import WavFrontend - - -header_colors = '\033[95m' -end_colors = '\033[0m' - - -class Speech2Text: - """Speech2Text class - - Examples: - >>> import soundfile - >>> speech2text = Speech2Text("asr_config.yml", "asr.pb") - >>> audio, rate = soundfile.read("speech.wav") - >>> speech2text(audio) - [(text, token, token_int, hypothesis object), ...] 
- - """ - - def __init__( - self, - asr_train_config: Union[Path, str] = None, - asr_model_file: Union[Path, str] = None, - cmvn_file: Union[Path, str] = None, - lm_train_config: Union[Path, str] = None, - lm_file: Union[Path, str] = None, - token_type: str = None, - bpemodel: str = None, - device: str = "cpu", - maxlenratio: float = 0.0, - minlenratio: float = 0.0, - dtype: str = "float32", - beam_size: int = 20, - ctc_weight: float = 0.5, - lm_weight: float = 1.0, - ngram_weight: float = 0.9, - penalty: float = 0.0, - nbest: int = 1, - token_num_relax: int = 1, - decoding_ind: int = 0, - decoding_mode: str = "model1", - frontend_conf: dict = None, - **kwargs, - ): - assert check_argument_types() - - # 1. Build ASR model - scorers = {} - asr_model, asr_train_args = ASRTask.build_model_from_file( - asr_train_config, asr_model_file, cmvn_file, device - ) - frontend = None - if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: - frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf) - - logging.info("asr_train_args: {}".format(asr_train_args)) - asr_model.to(dtype=getattr(torch, dtype)).eval() - if decoding_mode == "model1": - decoder = asr_model.decoder - else: - decoder = asr_model.decoder2 - - if asr_model.ctc is not None: - ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) - scorers.update( - ctc=ctc - ) - token_list = asr_model.token_list - scorers.update( - decoder=decoder, - length_bonus=LengthBonus(len(token_list)), - ) - - # 2. Build Language model - if lm_train_config is not None: - lm, lm_train_args = LMTask.build_model_from_file( - lm_train_config, lm_file, device - ) - scorers["lm"] = lm.lm - - # 3. Build ngram model - # ngram is not supported now - ngram = None - scorers["ngram"] = ngram - - # 4. Build BeamSearch object - # transducer is not supported now - beam_search_transducer = None - - weights = dict( - decoder=1.0 - ctc_weight, - ctc=ctc_weight, - lm=lm_weight, - ngram=ngram_weight, - length_bonus=penalty, - ) - beam_search = BeamSearch( - beam_size=beam_size, - weights=weights, - scorers=scorers, - sos=asr_model.sos, - eos=asr_model.eos, - vocab_size=len(token_list), - token_list=token_list, - pre_beam_score_key=None if ctc_weight == 1.0 else "full", - ) - - beam_search.to(device=device, dtype=getattr(torch, dtype)).eval() - for scorer in scorers.values(): - if isinstance(scorer, torch.nn.Module): - scorer.to(device=device, dtype=getattr(torch, dtype)).eval() - # logging.info(f"Beam_search: {beam_search}") - logging.info(f"Decoding device={device}, dtype={dtype}") - - # 5. [Optional] Build Text converter: e.g. 
bpe-sym -> Text - if token_type is None: - token_type = asr_train_args.token_type - if bpemodel is None: - bpemodel = asr_train_args.bpemodel - - if token_type is None: - tokenizer = None - elif token_type == "bpe": - if bpemodel is not None: - tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel) - else: - tokenizer = None - else: - tokenizer = build_tokenizer(token_type=token_type) - converter = TokenIDConverter(token_list=token_list) - logging.info(f"Text tokenizer: {tokenizer}") - - self.asr_model = asr_model - self.asr_train_args = asr_train_args - self.converter = converter - self.tokenizer = tokenizer - self.beam_search = beam_search - self.beam_search_transducer = beam_search_transducer - self.maxlenratio = maxlenratio - self.minlenratio = minlenratio - self.device = device - self.dtype = dtype - self.nbest = nbest - self.token_num_relax = token_num_relax - self.decoding_ind = decoding_ind - self.decoding_mode = decoding_mode - self.frontend = frontend - - @torch.no_grad() - def __call__( - self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None - ) -> List[ - Tuple[ - Optional[str], - List[str], - List[int], - Union[Hypothesis], - ] - ]: - """Inference - - Args: - speech: Input speech data - Returns: - text, token, token_int, hyp - - """ - assert check_argument_types() - - # Input as audio signal - if isinstance(speech, np.ndarray): - speech = torch.tensor(speech) - - if self.frontend is not None: - feats, feats_len = self.frontend.forward(speech, speech_lengths) - feats = to_device(feats, device=self.device) - feats_len = feats_len.int() - self.asr_model.frontend = None - else: - feats = speech - feats_len = speech_lengths - lfr_factor = max(1, (feats.size()[-1] // 80) - 1) - feats_raw = feats.clone().to(self.device) - batch = {"speech": feats, "speech_lengths": feats_len} - - # a. To device - batch = to_device(batch, device=self.device) - # b. Forward Encoder - _, enc, enc_len = self.asr_model.encode(**batch, ind=self.decoding_ind) - if isinstance(enc, tuple): - enc = enc[0] - assert len(enc) == 1, len(enc) - if self.decoding_mode == "model1": - predictor_outs = self.asr_model.calc_predictor_mask(enc, enc_len) - else: - enc, enc_len = self.asr_model.encode2(enc, enc_len, feats_raw, feats_len, ind=self.decoding_ind) - predictor_outs = self.asr_model.calc_predictor_mask2(enc, enc_len) - - scama_mask = predictor_outs[4] - pre_token_length = predictor_outs[1] - pre_acoustic_embeds = predictor_outs[0] - maxlen = pre_token_length.sum().item() + self.token_num_relax - minlen = max(0, pre_token_length.sum().item() - self.token_num_relax) - # c. 
Pass the encoder result to the beam search - nbest_hyps = self.beam_search( - x=enc[0], scama_mask=scama_mask, pre_acoustic_embeds=pre_acoustic_embeds, maxlenratio=self.maxlenratio, - minlenratio=self.minlenratio, maxlen=int(maxlen), minlen=int(minlen), - ) - - nbest_hyps = nbest_hyps[: self.nbest] - - results = [] - for hyp in nbest_hyps: - assert isinstance(hyp, Hypothesis), type(hyp) - - # remove sos/eos and get results - last_pos = -1 - if isinstance(hyp.yseq, list): - token_int = hyp.yseq[1:last_pos] - else: - token_int = hyp.yseq[1:last_pos].tolist() - - # remove blank symbol id, which is assumed to be 0 - token_int = list(filter(lambda x: x != 0, token_int)) - - # Change integer-ids to tokens - token = self.converter.ids2tokens(token_int) - token = list(filter(lambda x: x != "", token)) - - if self.tokenizer is not None: - text = self.tokenizer.tokens2text(token) - else: - text = None - results.append((text, token, token_int, hyp)) - - assert check_return_type(results) - return results - - -def inference( - maxlenratio: float, - minlenratio: float, - batch_size: int, - beam_size: int, - ngpu: int, - ctc_weight: float, - lm_weight: float, - penalty: float, - log_level: Union[int, str], - data_path_and_name_and_type, - asr_train_config: Optional[str], - asr_model_file: Optional[str], - ngram_file: Optional[str] = None, - cmvn_file: Optional[str] = None, - raw_inputs: Union[np.ndarray, torch.Tensor] = None, - lm_train_config: Optional[str] = None, - lm_file: Optional[str] = None, - token_type: Optional[str] = None, - key_file: Optional[str] = None, - word_lm_train_config: Optional[str] = None, - bpemodel: Optional[str] = None, - allow_variable_data_keys: bool = False, - streaming: bool = False, - output_dir: Optional[str] = None, - dtype: str = "float32", - seed: int = 0, - ngram_weight: float = 0.9, - nbest: int = 1, - num_workers: int = 1, - token_num_relax: int = 1, - decoding_ind: int = 0, - decoding_mode: str = "model1", - **kwargs, -): - inference_pipeline = inference_modelscope( - maxlenratio=maxlenratio, - minlenratio=minlenratio, - batch_size=batch_size, - beam_size=beam_size, - ngpu=ngpu, - ctc_weight=ctc_weight, - lm_weight=lm_weight, - penalty=penalty, - log_level=log_level, - asr_train_config=asr_train_config, - asr_model_file=asr_model_file, - cmvn_file=cmvn_file, - raw_inputs=raw_inputs, - lm_train_config=lm_train_config, - lm_file=lm_file, - token_type=token_type, - key_file=key_file, - word_lm_train_config=word_lm_train_config, - bpemodel=bpemodel, - allow_variable_data_keys=allow_variable_data_keys, - streaming=streaming, - output_dir=output_dir, - dtype=dtype, - seed=seed, - ngram_weight=ngram_weight, - ngram_file=ngram_file, - nbest=nbest, - num_workers=num_workers, - token_num_relax=token_num_relax, - decoding_ind=decoding_ind, - decoding_mode=decoding_mode, - **kwargs, - ) - return inference_pipeline(data_path_and_name_and_type, raw_inputs) - - -def inference_modelscope( - maxlenratio: float, - minlenratio: float, - batch_size: int, - beam_size: int, - ngpu: int, - ctc_weight: float, - lm_weight: float, - penalty: float, - log_level: Union[int, str], - # data_path_and_name_and_type, - asr_train_config: Optional[str], - asr_model_file: Optional[str], - ngram_file: Optional[str] = None, - cmvn_file: Optional[str] = None, - # raw_inputs: Union[np.ndarray, torch.Tensor] = None, - lm_train_config: Optional[str] = None, - lm_file: Optional[str] = None, - token_type: Optional[str] = None, - key_file: Optional[str] = None, - word_lm_train_config: Optional[str] = 
None, - bpemodel: Optional[str] = None, - allow_variable_data_keys: bool = False, - streaming: bool = False, - output_dir: Optional[str] = None, - dtype: str = "float32", - seed: int = 0, - ngram_weight: float = 0.9, - nbest: int = 1, - num_workers: int = 1, - token_num_relax: int = 1, - decoding_ind: int = 0, - decoding_mode: str = "model1", - param_dict: dict = None, - **kwargs, -): - assert check_argument_types() - if batch_size > 1: - raise NotImplementedError("batch decoding is not implemented") - if word_lm_train_config is not None: - raise NotImplementedError("Word LM is not implemented") - if ngpu > 1: - raise NotImplementedError("only single GPU decoding is supported") - - logging.basicConfig( - level=log_level, - format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", - ) - - if ngpu >= 1 and torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" - - if param_dict is not None and "decoding_model" in param_dict: - if param_dict["decoding_model"] == "fast": - decoding_ind = 0 - decoding_mode = "model1" - elif param_dict["decoding_model"] == "normal": - decoding_ind = 0 - decoding_mode = "model2" - elif param_dict["decoding_model"] == "offline": - decoding_ind = 1 - decoding_mode = "model2" - else: - raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"])) - - # 1. Set random-seed - set_all_random_seed(seed) - - # 2. Build speech2text - speech2text_kwargs = dict( - asr_train_config=asr_train_config, - asr_model_file=asr_model_file, - cmvn_file=cmvn_file, - lm_train_config=lm_train_config, - lm_file=lm_file, - ngram_file=ngram_file, - token_type=token_type, - bpemodel=bpemodel, - device=device, - maxlenratio=maxlenratio, - minlenratio=minlenratio, - dtype=dtype, - beam_size=beam_size, - ctc_weight=ctc_weight, - lm_weight=lm_weight, - ngram_weight=ngram_weight, - penalty=penalty, - nbest=nbest, - streaming=streaming, - token_num_relax=token_num_relax, - decoding_ind=decoding_ind, - decoding_mode=decoding_mode, - ) - speech2text = Speech2Text(**speech2text_kwargs) - - def _forward(data_path_and_name_and_type, - raw_inputs: Union[np.ndarray, torch.Tensor] = None, - output_dir_v2: Optional[str] = None, - fs: dict = None, - param_dict: dict = None, - **kwargs, - ): - # 3. 
Build data-iterator - if data_path_and_name_and_type is None and raw_inputs is not None: - if isinstance(raw_inputs, torch.Tensor): - raw_inputs = raw_inputs.numpy() - data_path_and_name_and_type = [raw_inputs, "speech", "waveform"] - loader = ASRTask.build_streaming_iterator( - data_path_and_name_and_type, - dtype=dtype, - fs=fs, - batch_size=batch_size, - key_file=key_file, - num_workers=num_workers, - preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False), - collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False), - allow_variable_data_keys=allow_variable_data_keys, - inference=True, - ) - - finish_count = 0 - file_count = 1 - # 7. Start the for-loop - # FIXME(kamo): The output format should be discussed - asr_result_list = [] - output_path = output_dir_v2 if output_dir_v2 is not None else output_dir - if output_path is not None: - writer = DatadirWriter(output_path) - else: - writer = None - - for keys, batch in loader: - assert isinstance(batch, dict), type(batch) - assert all(isinstance(s, str) for s in keys), keys - _bs = len(next(iter(batch.values()))) - assert len(keys) == _bs, f"{len(keys)} != {_bs}" - #batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")} - - # N-best list of (text, token, token_int, hyp_object) - try: - results = speech2text(**batch) - except TooShortUttError as e: - logging.warning(f"Utterance {keys} {e}") - hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[]) - results = [[" ", ["sil"], [2], hyp]] * nbest - - # Only supporting batch_size==1 - key = keys[0] - logging.info(f"Utterance: {key}") - for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results): - # Create a directory: outdir/{n}best_recog - if writer is not None: - ibest_writer = writer[f"{n}best_recog"] - - # Write the result to each file - ibest_writer["token"][key] = " ".join(token) - # ibest_writer["token_int"][key] = " ".join(map(str, token_int)) - ibest_writer["score"][key] = str(hyp.score) - - if text is not None: - text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token) - item = {'key': key, 'value': text_postprocessed} - asr_result_list.append(item) - finish_count += 1 - asr_utils.print_progress(finish_count / file_count) - if writer is not None: - ibest_writer["text"][key] = " ".join(word_lists) - return asr_result_list - - return _forward - - - -def get_parser(): - parser = config_argparse.ArgumentParser( - description="ASR Decoding", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - - # Note(kamo): Use '_' instead of '-' as separator. - # '-' is confusing if written in yaml. - parser.add_argument( - "--log_level", - type=lambda x: x.upper(), - default="INFO", - choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), - help="The verbose level of logging", - ) - - parser.add_argument("--output_dir", type=str, required=True) - parser.add_argument( - "--ngpu", - type=int, - default=0, - help="The number of gpus. 
0 indicates CPU mode", - ) - parser.add_argument("--seed", type=int, default=0, help="Random seed") - parser.add_argument( - "--dtype", - default="float32", - choices=["float16", "float32", "float64"], - help="Data type", - ) - parser.add_argument( - "--num_workers", - type=int, - default=1, - help="The number of workers used for DataLoader", - ) - - group = parser.add_argument_group("Input data related") - group.add_argument( - "--data_path_and_name_and_type", - type=str2triple_str, - required=False, - action="append", - ) - group.add_argument("--raw_inputs", type=list, default=None) - # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}]) - group.add_argument("--key_file", type=str_or_none) - group.add_argument("--allow_variable_data_keys", type=str2bool, default=False) - - group = parser.add_argument_group("The model configuration related") - group.add_argument( - "--asr_train_config", - type=str, - help="ASR training configuration", - ) - group.add_argument( - "--asr_model_file", - type=str, - help="ASR model parameter file", - ) - group.add_argument( - "--cmvn_file", - type=str, - help="Global cmvn file", - ) - group.add_argument( - "--lm_train_config", - type=str, - help="LM training configuration", - ) - group.add_argument( - "--lm_file", - type=str, - help="LM parameter file", - ) - group.add_argument( - "--word_lm_train_config", - type=str, - help="Word LM training configuration", - ) - group.add_argument( - "--word_lm_file", - type=str, - help="Word LM parameter file", - ) - group.add_argument( - "--ngram_file", - type=str, - help="N-gram parameter file", - ) - group.add_argument( - "--model_tag", - type=str, - help="Pretrained model tag. If this option is specified, *_train_config and " - "*_file will be overwritten", - ) - - group = parser.add_argument_group("Beam-search related") - group.add_argument( - "--batch_size", - type=int, - default=1, - help="The batch size for inference", - ) - group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses") - group.add_argument("--beam_size", type=int, default=20, help="Beam size") - group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty") - group.add_argument( - "--maxlenratio", - type=float, - default=0.0, - help="Input length ratio to obtain max output length. " - "If maxlenratio=0.0 (default), it uses an end-detect " - "function " - "to automatically find maximum hypothesis lengths. " - "If maxlenratio<0.0, its absolute value is interpreted " - "as a constant max output length", - ) - group.add_argument( - "--minlenratio", - type=float, - default=0.0, - help="Input length ratio to obtain min output length", - ) - group.add_argument( - "--ctc_weight", - type=float, - default=0.5, - help="CTC weight in joint decoding", - ) - group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight") - group.add_argument("--ngram_weight", type=float, default=0.9, help="N-gram weight") - group.add_argument("--streaming", type=str2bool, default=False) - - group = parser.add_argument_group("Text converter related") - group.add_argument( - "--token_type", - type=str_or_none, - default=None, - choices=["char", "bpe", None], - help="The token type for ASR model. " - "If not given, it is inferred from the training args", - ) - group.add_argument( - "--bpemodel", - type=str_or_none, - default=None, - help="The model path of sentencepiece. 
" - "If not given, refers from the training args", - ) - group.add_argument("--token_num_relax", type=int, default=1, help="") - group.add_argument("--decoding_ind", type=int, default=0, help="") - group.add_argument("--decoding_mode", type=str, default="model1", help="") - group.add_argument( - "--ctc_weight2", - type=float, - default=0.0, - help="CTC weight in joint decoding", - ) - return parser - - -def main(cmd=None): - print(get_commandline_args(), file=sys.stderr) - parser = get_parser() - args = parser.parse_args(cmd) - kwargs = vars(args) - kwargs.pop("config", None) - inference(**kwargs) - - -if __name__ == "__main__": - main() diff --git a/modelscope b/modelscope new file mode 120000 index 000000000..5af854c53 --- /dev/null +++ b/modelscope @@ -0,0 +1 @@ +../MaaS-lib/modelscope \ No newline at end of file