diff --git a/egs_modelscope/common/modelscope_utils/modelscope_infer.sh b/egs_modelscope/common/modelscope_utils/modelscope_infer.sh index 80f0d166b..a0c606f7f 100755 --- a/egs_modelscope/common/modelscope_utils/modelscope_infer.sh +++ b/egs_modelscope/common/modelscope_utils/modelscope_infer.sh @@ -65,6 +65,7 @@ for dset in ${test_sets}; do ${decode_cmd} --max-jobs-run "${inference_nj}" JOB=1:"${inference_nj}" "${_logdir}"/asr_inference.JOB.log \ python -m funasr.bin.modelscope_infer \ --model_name ${model_name} \ + --model_revision ${model_revision} \ --wav_list ${_logdir}/keys.JOB.scp \ --output_file ${_logdir}/text.JOB \ --gpuid_list ${gpuid_list} \ diff --git a/funasr/bin/asr_inference_modelscope.py b/funasr/bin/asr_inference_modelscope.py new file mode 100755 index 000000000..fd9bd6609 --- /dev/null +++ b/funasr/bin/asr_inference_modelscope.py @@ -0,0 +1,687 @@ +#!/usr/bin/env python3 +# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved. +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +import argparse +import logging +import sys +from pathlib import Path +from typing import Any +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union +from typing import Dict + +import numpy as np +import torch +from typeguard import check_argument_types +from typeguard import check_return_type + +from funasr.fileio.datadir_writer import DatadirWriter +from funasr.modules.beam_search.batch_beam_search import BatchBeamSearch +from funasr.modules.beam_search.batch_beam_search_online_sim import BatchBeamSearchOnlineSim +from funasr.modules.beam_search.beam_search import BeamSearch +from funasr.modules.beam_search.beam_search import Hypothesis +from funasr.modules.scorers.ctc import CTCPrefixScorer +from funasr.modules.scorers.length_bonus import LengthBonus +from funasr.modules.scorers.scorer_interface import BatchScorerInterface +from funasr.modules.subsampling import TooShortUttError +from funasr.tasks.asr import ASRTask +from funasr.tasks.lm import LMTask +from funasr.text.build_tokenizer import build_tokenizer +from funasr.text.token_id_converter import TokenIDConverter +from funasr.torch_utils.device_funcs import to_device +from funasr.torch_utils.set_all_random_seed import set_all_random_seed +from funasr.utils import config_argparse +from funasr.utils.cli_utils import get_commandline_args +from funasr.utils.types import str2bool +from funasr.utils.types import str2triple_str +from funasr.utils.types import str_or_none +from funasr.utils import asr_utils, wav_utils, postprocess_utils +from funasr.models.frontend.wav_frontend import WavFrontend + +from modelscope.utils.logger import get_logger + +logger = get_logger() + +header_colors = '\033[95m' +end_colors = '\033[0m' + +global_asr_language: str = 'zh-cn' +global_sample_rate: Union[int, Dict[Any, int]] = { + 'audio_fs': 16000, + 'model_fs': 16000 +} + +class Speech2Text: + """Speech2Text class + + Examples: + >>> import soundfile + >>> speech2text = Speech2Text("asr_config.yml", "asr.pth") + >>> audio, rate = soundfile.read("speech.wav") + >>> speech2text(audio) + [(text, token, token_int, hypothesis object), ...] 
+ + """ + + def __init__( + self, + asr_train_config: Union[Path, str] = None, + asr_model_file: Union[Path, str] = None, + lm_train_config: Union[Path, str] = None, + lm_file: Union[Path, str] = None, + token_type: str = None, + bpemodel: str = None, + device: str = "cpu", + maxlenratio: float = 0.0, + minlenratio: float = 0.0, + batch_size: int = 1, + dtype: str = "float32", + beam_size: int = 20, + ctc_weight: float = 0.5, + lm_weight: float = 1.0, + ngram_weight: float = 0.9, + penalty: float = 0.0, + nbest: int = 1, + streaming: bool = False, + frontend_conf: dict = None, + **kwargs, + ): + assert check_argument_types() + + # 1. Build ASR model + scorers = {} + asr_model, asr_train_args = ASRTask.build_model_from_file( + asr_train_config, asr_model_file, device + ) + if asr_model.frontend is None and frontend_conf is not None: + frontend = WavFrontend(**frontend_conf) + asr_model.frontend = frontend + asr_model.to(dtype=getattr(torch, dtype)).eval() + + decoder = asr_model.decoder + + ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) + token_list = asr_model.token_list + scorers.update( + decoder=decoder, + ctc=ctc, + length_bonus=LengthBonus(len(token_list)), + ) + + # 2. Build Language model + if lm_train_config is not None: + lm, lm_train_args = LMTask.build_model_from_file( + lm_train_config, lm_file, device + ) + scorers["lm"] = lm.lm + + # 3. Build ngram model + # ngram is not supported now + ngram = None + scorers["ngram"] = ngram + + # 4. Build BeamSearch object + # transducer is not supported now + beam_search_transducer = None + + weights = dict( + decoder=1.0 - ctc_weight, + ctc=ctc_weight, + lm=lm_weight, + ngram=ngram_weight, + length_bonus=penalty, + ) + beam_search = BeamSearch( + beam_size=beam_size, + weights=weights, + scorers=scorers, + sos=asr_model.sos, + eos=asr_model.eos, + vocab_size=len(token_list), + token_list=token_list, + pre_beam_score_key=None if ctc_weight == 1.0 else "full", + ) + + # TODO(karita): make all scorers batchfied + if batch_size == 1: + non_batch = [ + k + for k, v in beam_search.full_scorers.items() + if not isinstance(v, BatchScorerInterface) + ] + if len(non_batch) == 0: + if streaming: + beam_search.__class__ = BatchBeamSearchOnlineSim + beam_search.set_streaming_config(asr_train_config) + logging.info( + "BatchBeamSearchOnlineSim implementation is selected." + ) + else: + beam_search.__class__ = BatchBeamSearch + logging.info("BatchBeamSearch implementation is selected.") + else: + logging.warning( + f"As non-batch scorers {non_batch} are found, " + f"fall back to non-batch implementation." + ) + + beam_search.to(device=device, dtype=getattr(torch, dtype)).eval() + for scorer in scorers.values(): + if isinstance(scorer, torch.nn.Module): + scorer.to(device=device, dtype=getattr(torch, dtype)).eval() + logging.info(f"Beam_search: {beam_search}") + logging.info(f"Decoding device={device}, dtype={dtype}") + + # 5. [Optional] Build Text converter: e.g. 
bpe-sym -> Text + if token_type is None: + token_type = asr_train_args.token_type + if bpemodel is None: + bpemodel = asr_train_args.bpemodel + + if token_type is None: + tokenizer = None + elif token_type == "bpe": + if bpemodel is not None: + tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel) + else: + tokenizer = None + else: + tokenizer = build_tokenizer(token_type=token_type) + converter = TokenIDConverter(token_list=token_list) + logging.info(f"Text tokenizer: {tokenizer}") + + self.asr_model = asr_model + self.asr_train_args = asr_train_args + self.converter = converter + self.tokenizer = tokenizer + self.beam_search = beam_search + self.beam_search_transducer = beam_search_transducer + self.maxlenratio = maxlenratio + self.minlenratio = minlenratio + self.device = device + self.dtype = dtype + self.nbest = nbest + + @torch.no_grad() + def __call__( + self, speech: Union[torch.Tensor, np.ndarray] + ) -> List[ + Tuple[ + Optional[str], + List[str], + List[int], + Union[Hypothesis], + ] + ]: + """Inference + + Args: + speech: Input speech data + Returns: + text, token, token_int, hyp + + """ + assert check_argument_types() + + # Input as audio signal + if isinstance(speech, np.ndarray): + speech = torch.tensor(speech) + + # data: (Nsamples,) -> (1, Nsamples) + speech = speech.unsqueeze(0).to(getattr(torch, self.dtype)) + lfr_factor = max(1, (speech.size()[-1] // 80) - 1) + # lengths: (1,) + lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1)) + batch = {"speech": speech, "speech_lengths": lengths} + + # a. To device + batch = to_device(batch, device=self.device) + + # b. Forward Encoder + enc, _ = self.asr_model.encode(**batch) + if isinstance(enc, tuple): + enc = enc[0] + assert len(enc) == 1, len(enc) + + # c. 
Passed the encoder result and the beam search + nbest_hyps = self.beam_search( + x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio + ) + + nbest_hyps = nbest_hyps[: self.nbest] + + results = [] + for hyp in nbest_hyps: + assert isinstance(hyp, (Hypothesis)), type(hyp) + + # remove sos/eos and get results + last_pos = -1 + if isinstance(hyp.yseq, list): + token_int = hyp.yseq[1:last_pos] + else: + token_int = hyp.yseq[1:last_pos].tolist() + + # remove blank symbol id, which is assumed to be 0 + token_int = list(filter(lambda x: x != 0, token_int)) + + # Change integer-ids to tokens + token = self.converter.ids2tokens(token_int) + + if self.tokenizer is not None: + text = self.tokenizer.tokens2text(token) + else: + text = None + results.append((text, token, token_int, hyp)) + + assert check_return_type(results) + return results + + +def inference( + maxlenratio: float, + minlenratio: float, + batch_size: int, + dtype: str, + beam_size: int, + ngpu: int, + seed: int, + ctc_weight: float, + lm_weight: float, + ngram_weight: float, + penalty: float, + nbest: int, + num_workers: int, + log_level: Union[int, str], + data_path_and_name_and_type: list, + audio_lists: Union[List[Any], bytes], + key_file: Optional[str], + asr_train_config: Optional[str], + asr_model_file: Optional[str], + lm_train_config: Optional[str], + lm_file: Optional[str], + word_lm_train_config: Optional[str], + token_type: Optional[str], + bpemodel: Optional[str], + output_dir: Optional[str], + allow_variable_data_keys: bool, + streaming: bool, + frontend_conf: dict = None, + fs: Union[dict, int] = 16000, + **kwargs, +) -> List[Any]: + assert check_argument_types() + if batch_size > 1: + raise NotImplementedError("batch decoding is not implemented") + if word_lm_train_config is not None: + raise NotImplementedError("Word LM is not implemented") + if ngpu > 1: + raise NotImplementedError("only single GPU decoding is supported") + + logging.basicConfig( + level=log_level, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + + if ngpu >= 1: + device = "cuda" + else: + device = "cpu" + features_type: str = data_path_and_name_and_type[1] + hop_length: int = 160 + sr: int = 16000 + if isinstance(fs, int): + sr = fs + else: + if 'model_fs' in fs and fs['model_fs'] is not None: + sr = fs['model_fs'] + if features_type != 'sound': + frontend_conf = None + if frontend_conf is not None: + if 'hop_length' in frontend_conf: + hop_length = frontend_conf['hop_length'] + + finish_count = 0 + file_count = 1 + if isinstance(audio_lists, bytes): + file_count = 1 + else: + file_count = len(audio_lists) + if len(data_path_and_name_and_type) >= 3 and frontend_conf is not None: + mvn_file = data_path_and_name_and_type[2] + mvn_data = wav_utils.extract_CMVN_featrures(mvn_file) + frontend_conf['mvn_data'] = mvn_data + # 1. Set random-seed + set_all_random_seed(seed) + + # 2. 
Build speech2text + speech2text_kwargs = dict( + asr_train_config=asr_train_config, + asr_model_file=asr_model_file, + lm_train_config=lm_train_config, + lm_file=lm_file, + token_type=token_type, + bpemodel=bpemodel, + device=device, + maxlenratio=maxlenratio, + minlenratio=minlenratio, + dtype=dtype, + beam_size=beam_size, + ctc_weight=ctc_weight, + lm_weight=lm_weight, + ngram_weight=ngram_weight, + penalty=penalty, + nbest=nbest, + streaming=streaming, + frontend_conf=frontend_conf, + ) + speech2text = Speech2Text(**speech2text_kwargs) + data_path_and_name_and_type_new = [ + audio_lists, data_path_and_name_and_type[0], data_path_and_name_and_type[1] + ] + # 3. Build data-iterator + loader = ASRTask.build_streaming_iterator_modelscope( + data_path_and_name_and_type_new, + dtype=dtype, + batch_size=batch_size, + key_file=key_file, + num_workers=num_workers, + preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False), + collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False), + allow_variable_data_keys=allow_variable_data_keys, + inference=True, + sample_rate=fs + ) + + # 7 .Start for-loop + # FIXME(kamo): The output format should be discussed about + asr_result_list = [] + for keys, batch in loader: + assert isinstance(batch, dict), type(batch) + assert all(isinstance(s, str) for s in keys), keys + _bs = len(next(iter(batch.values()))) + assert len(keys) == _bs, f"{len(keys)} != {_bs}" + batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")} + + # N-best list of (text, token, token_int, hyp_object) + try: + results = speech2text(**batch) + except TooShortUttError as e: + logging.warning(f"Utterance {keys} {e}") + hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[]) + results = [[" ", [""], [2], hyp]] * nbest + + # Only supporting batch_size==1 + key = keys[0] + for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results): + if text is not None: + text_postprocessed = postprocess_utils.sentence_postprocess(token) + item = {'key': key, 'value': text_postprocessed} + asr_result_list.append(item) + finish_count += 1 + asr_utils.print_progress(finish_count / file_count) + + return asr_result_list + + + +def set_parameters(language: str = None, + sample_rate: Union[int, Dict[Any, int]] = None): + if language is not None: + global global_asr_language + global_asr_language = language + if sample_rate is not None: + global global_sample_rate + global_sample_rate = sample_rate + + +def asr_inference(maxlenratio: float, + minlenratio: float, + beam_size: int, + ngpu: int, + ctc_weight: float, + lm_weight: float, + penalty: float, + name_and_type: list, + audio_lists: Union[List[Any], bytes], + asr_train_config: Optional[str], + asr_model_file: Optional[str], + nbest: int = 1, + num_workers: int = 1, + log_level: Union[int, str] = 'INFO', + batch_size: int = 1, + dtype: str = 'float32', + seed: int = 0, + key_file: Optional[str] = None, + lm_train_config: Optional[str] = None, + lm_file: Optional[str] = None, + word_lm_train_config: Optional[str] = None, + word_lm_file: Optional[str] = None, + ngram_file: Optional[str] = None, + ngram_weight: float = 0.9, + model_tag: Optional[str] = None, + token_type: Optional[str] = None, + bpemodel: Optional[str] = None, + allow_variable_data_keys: bool = False, + transducer_conf: Optional[dict] = None, + streaming: bool = False, + frontend_conf: dict = None, + fs: Union[dict, int] = None, + lang: Optional[str] = None, + outputdir: Optional[str] = None): + if lang is not None: + global 
global_asr_language + global_asr_language = lang + if fs is not None: + global global_sample_rate + global_sample_rate = fs + + # force use CPU if data type is bytes + if isinstance(audio_lists, bytes): + num_workers = 0 + ngpu = 0 + + return inference(output_dir=outputdir, + maxlenratio=maxlenratio, + minlenratio=minlenratio, + batch_size=batch_size, + dtype=dtype, + beam_size=beam_size, + ngpu=ngpu, + seed=seed, + ctc_weight=ctc_weight, + lm_weight=lm_weight, + ngram_weight=ngram_weight, + penalty=penalty, + nbest=nbest, + num_workers=num_workers, + log_level=log_level, + data_path_and_name_and_type=name_and_type, + audio_lists=audio_lists, + key_file=key_file, + asr_train_config=asr_train_config, + asr_model_file=asr_model_file, + lm_train_config=lm_train_config, + lm_file=lm_file, + word_lm_train_config=word_lm_train_config, + word_lm_file=word_lm_file, + ngram_file=ngram_file, + model_tag=model_tag, + token_type=token_type, + bpemodel=bpemodel, + allow_variable_data_keys=allow_variable_data_keys, + transducer_conf=transducer_conf, + streaming=streaming, + frontend_conf=frontend_conf) + + +def get_parser(): + parser = config_argparse.ArgumentParser( + description="ASR Decoding", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Note(kamo): Use '_' instead of '-' as separator. + # '-' is confusing if written in yaml. + parser.add_argument( + "--log_level", + type=lambda x: x.upper(), + default="INFO", + choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), + help="The verbose level of logging", + ) + + parser.add_argument("--output_dir", type=str, required=True) + parser.add_argument( + "--ngpu", + type=int, + default=0, + help="The number of gpus. 0 indicates CPU mode", + ) + parser.add_argument( + "--gpuid_list", + type=str, + default="", + help="The visible gpus", + ) + parser.add_argument("--seed", type=int, default=0, help="Random seed") + parser.add_argument( + "--dtype", + default="float32", + choices=["float16", "float32", "float64"], + help="Data type", + ) + parser.add_argument( + "--num_workers", + type=int, + default=1, + help="The number of workers used for DataLoader", + ) + + group = parser.add_argument_group("Input data related") + group.add_argument( + "--data_path_and_name_and_type", + type=str2triple_str, + required=True, + action="append", + ) + group.add_argument("--audio_lists", type=list, + default=[{'key':'EdevDEWdIYQ_0021', + 'file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}]) + group.add_argument("--key_file", type=str_or_none) + group.add_argument("--allow_variable_data_keys", type=str2bool, default=False) + + group = parser.add_argument_group("The model configuration related") + group.add_argument( + "--asr_train_config", + type=str, + help="ASR training configuration", + ) + group.add_argument( + "--asr_model_file", + type=str, + help="ASR model parameter file", + ) + group.add_argument( + "--lm_train_config", + type=str, + help="LM training configuration", + ) + group.add_argument( + "--lm_file", + type=str, + help="LM parameter file", + ) + group.add_argument( + "--word_lm_train_config", + type=str, + help="Word LM training configuration", + ) + group.add_argument( + "--word_lm_file", + type=str, + help="Word LM parameter file", + ) + group.add_argument( + "--ngram_file", + type=str, + help="N-gram parameter file", + ) + group.add_argument( + "--model_tag", + type=str, + help="Pretrained model tag. 
If specify this option, *_train_config and " + "*_file will be overwritten", + ) + + group = parser.add_argument_group("Beam-search related") + group.add_argument( + "--batch_size", + type=int, + default=1, + help="The batch size for inference", + ) + group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses") + group.add_argument("--beam_size", type=int, default=20, help="Beam size") + group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty") + group.add_argument( + "--maxlenratio", + type=float, + default=0.0, + help="Input length ratio to obtain max output length. " + "If maxlenratio=0.0 (default), it uses a end-detect " + "function " + "to automatically find maximum hypothesis lengths." + "If maxlenratio<0.0, its absolute value is interpreted" + "as a constant max output length", + ) + group.add_argument( + "--minlenratio", + type=float, + default=0.0, + help="Input length ratio to obtain min output length", + ) + group.add_argument( + "--ctc_weight", + type=float, + default=0.5, + help="CTC weight in joint decoding", + ) + group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight") + group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight") + group.add_argument("--streaming", type=str2bool, default=False) + + group = parser.add_argument_group("Text converter related") + group.add_argument( + "--token_type", + type=str_or_none, + default=None, + choices=["char", "bpe", None], + help="The token type for ASR model. " + "If not given, refers from the training args", + ) + group.add_argument( + "--bpemodel", + type=str_or_none, + default=None, + help="The model path of sentencepiece. " + "If not given, refers from the training args", + ) + + return parser + + +def main(cmd=None): + print(get_commandline_args(), file=sys.stderr) + parser = get_parser() + args = parser.parse_args(cmd) + kwargs = vars(args) + kwargs.pop("config", None) + inference(**kwargs) + + +if __name__ == "__main__": + main() diff --git a/funasr/bin/asr_inference_paraformer_modelscope.py b/funasr/bin/asr_inference_paraformer_modelscope.py new file mode 100755 index 000000000..d64fe2b25 --- /dev/null +++ b/funasr/bin/asr_inference_paraformer_modelscope.py @@ -0,0 +1,686 @@ +#!/usr/bin/env python3 +import argparse +import logging +import sys +import time +from pathlib import Path +from typing import Any +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union +from typing import List +from typing import Dict + +import numpy as np +import torch +from typeguard import check_argument_types + +from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch +from funasr.modules.beam_search.beam_search import Hypothesis +from funasr.modules.scorers.ctc import CTCPrefixScorer +from funasr.modules.scorers.length_bonus import LengthBonus +from funasr.modules.subsampling import TooShortUttError +from funasr.tasks.asr import ASRTaskParaformer as ASRTask +from funasr.tasks.lm import LMTask +from funasr.text.build_tokenizer import build_tokenizer +from funasr.text.token_id_converter import TokenIDConverter +from funasr.torch_utils.device_funcs import to_device +from funasr.torch_utils.set_all_random_seed import set_all_random_seed +from funasr.utils import config_argparse +from funasr.utils.cli_utils import get_commandline_args +from funasr.utils.types import str2bool +from funasr.utils.types import str2triple_str +from funasr.utils.types import str_or_none +from funasr.utils 
import asr_utils, wav_utils, postprocess_utils +from funasr.models.frontend.wav_frontend import WavFrontend + +from modelscope.utils.logger import get_logger + +logger = get_logger() + +header_colors = '\033[95m' +end_colors = '\033[0m' + +global_asr_language: str = 'zh-cn' +global_sample_rate: Union[int, Dict[Any, int]] = { + 'audio_fs': 16000, + 'model_fs': 16000 +} + + +class Speech2Text: + """Speech2Text class + + Examples: + >>> import soundfile + >>> speech2text = Speech2Text("asr_config.yml", "asr.pth") + >>> audio, rate = soundfile.read("speech.wav") + >>> speech2text(audio) + [(text, token, token_int, hypothesis object), ...] + + """ + + def __init__( + self, + asr_train_config: Union[Path, str] = None, + asr_model_file: Union[Path, str] = None, + lm_train_config: Union[Path, str] = None, + lm_file: Union[Path, str] = None, + token_type: str = None, + bpemodel: str = None, + device: str = "cpu", + maxlenratio: float = 0.0, + minlenratio: float = 0.0, + dtype: str = "float32", + beam_size: int = 20, + ctc_weight: float = 0.5, + lm_weight: float = 1.0, + ngram_weight: float = 0.9, + penalty: float = 0.0, + nbest: int = 1, + frontend_conf: dict = None, + **kwargs, + ): + assert check_argument_types() + + # 1. Build ASR model + scorers = {} + asr_model, asr_train_args = ASRTask.build_model_from_file( + asr_train_config, asr_model_file, device + ) + if asr_model.frontend is None and frontend_conf is not None: + frontend = WavFrontend(**frontend_conf) + asr_model.frontend = frontend + asr_model.to(dtype=getattr(torch, dtype)).eval() + + ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) + token_list = asr_model.token_list + scorers.update( + ctc=ctc, + length_bonus=LengthBonus(len(token_list)), + ) + + # 2. Build Language model + if lm_train_config is not None: + lm, lm_train_args = LMTask.build_model_from_file( + lm_train_config, lm_file, device + ) + scorers["lm"] = lm.lm + + # 3. Build ngram model + # ngram is not supported now + ngram = None + scorers["ngram"] = ngram + + # 4. Build BeamSearch object + # transducer is not supported now + beam_search_transducer = None + + weights = dict( + decoder=1.0 - ctc_weight, + ctc=ctc_weight, + lm=lm_weight, + ngram=ngram_weight, + length_bonus=penalty, + ) + beam_search = BeamSearch( + beam_size=beam_size, + weights=weights, + scorers=scorers, + sos=asr_model.sos, + eos=asr_model.eos, + vocab_size=len(token_list), + token_list=token_list, + pre_beam_score_key=None if ctc_weight == 1.0 else "full", + ) + + beam_search.to(device=device, dtype=getattr(torch, dtype)).eval() + for scorer in scorers.values(): + if isinstance(scorer, torch.nn.Module): + scorer.to(device=device, dtype=getattr(torch, dtype)).eval() + logging.info(f"Beam_search: {beam_search}") + logging.info(f"Decoding device={device}, dtype={dtype}") + + # 5. [Optional] Build Text converter: e.g. 
bpe-sym -> Text + if token_type is None: + token_type = asr_train_args.token_type + if bpemodel is None: + bpemodel = asr_train_args.bpemodel + + if token_type is None: + tokenizer = None + elif token_type == "bpe": + if bpemodel is not None: + tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel) + else: + tokenizer = None + else: + tokenizer = build_tokenizer(token_type=token_type) + converter = TokenIDConverter(token_list=token_list) + logging.info(f"Text tokenizer: {tokenizer}") + + self.asr_model = asr_model + self.asr_train_args = asr_train_args + self.converter = converter + self.tokenizer = tokenizer + self.beam_search = beam_search + self.beam_search_transducer = beam_search_transducer + self.maxlenratio = maxlenratio + self.minlenratio = minlenratio + self.device = device + self.dtype = dtype + self.nbest = nbest + + @torch.no_grad() + def __call__( + self, speech: Union[torch.Tensor, np.ndarray] + ): + """Inference + + Args: + speech: Input speech data + Returns: + text, token, token_int, hyp + + """ + assert check_argument_types() + + # Input as audio signal + if isinstance(speech, np.ndarray): + speech = torch.tensor(speech) + + # data: (Nsamples,) -> (1, Nsamples) + speech = speech.unsqueeze(0).to(getattr(torch, self.dtype)) + lfr_factor = max(1, (speech.size()[-1] // 80) - 1) + # lengths: (1,) + lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1)) + batch = {"speech": speech, "speech_lengths": lengths} + + # a. To device + batch = to_device(batch, device=self.device) + + # b. Forward Encoder + enc, enc_len = self.asr_model.encode(**batch) + if isinstance(enc, tuple): + enc = enc[0] + assert len(enc) == 1, len(enc) + + predictor_outs = self.asr_model.calc_predictor(enc, enc_len) + pre_acoustic_embeds, pre_token_length = predictor_outs[0], predictor_outs[1] + pre_token_length = torch.tensor([pre_acoustic_embeds.size(1)], device=pre_acoustic_embeds.device) + decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length) + decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] + + nbest_hyps = self.beam_search( + x=enc[0], am_scores=decoder_out[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio + ) + + nbest_hyps = nbest_hyps[: self.nbest] + results = [] + for hyp in nbest_hyps: + assert isinstance(hyp, (Hypothesis)), type(hyp) + + # remove sos/eos and get results + last_pos = -1 + if isinstance(hyp.yseq, list): + token_int = hyp.yseq[1:last_pos] + else: + token_int = hyp.yseq[1:last_pos].tolist() + + # remove blank symbol id, which is assumed to be 0 + token_int = list(filter(lambda x: x != 0, token_int)) + + # Change integer-ids to tokens + token = self.converter.ids2tokens(token_int) + + if self.tokenizer is not None: + text = self.tokenizer.tokens2text(token) + else: + text = None + + results.append((text, token, token_int, hyp, speech.size(1), lfr_factor)) + + # assert check_return_type(results) + return results + + +def inference( + maxlenratio: float, + minlenratio: float, + batch_size: int, + dtype: str, + beam_size: int, + ngpu: int, + seed: int, + ctc_weight: float, + lm_weight: float, + ngram_weight: float, + penalty: float, + nbest: int, + num_workers: int, + log_level: Union[int, str], + data_path_and_name_and_type: list, + audio_lists: Union[List[Any], bytes], + key_file: Optional[str], + asr_train_config: Optional[str], + asr_model_file: Optional[str], + lm_train_config: Optional[str], + lm_file: Optional[str], + word_lm_train_config: Optional[str], + model_tag: 
Optional[str], + token_type: Optional[str], + bpemodel: Optional[str], + output_dir: Optional[str], + allow_variable_data_keys: bool, + frontend_conf: dict = None, + fs: Union[dict, int] = 16000, + **kwargs, +) -> List[Any]: + assert check_argument_types() + if batch_size > 1: + raise NotImplementedError("batch decoding is not implemented") + if word_lm_train_config is not None: + raise NotImplementedError("Word LM is not implemented") + if ngpu > 1: + raise NotImplementedError("only single GPU decoding is supported") + + logging.basicConfig( + level=log_level, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + + if ngpu >= 1: + device = "cuda" + else: + device = "cpu" + # data_path_and_name_and_type = data_path_and_name_and_type[0] + features_type: str = data_path_and_name_and_type[1] + hop_length: int = 160 + sr: int = 16000 + if isinstance(fs, int): + sr = fs + else: + if 'model_fs' in fs and fs['model_fs'] is not None: + sr = fs['model_fs'] + if features_type != 'sound': + frontend_conf = None + if frontend_conf is not None: + if 'hop_length' in frontend_conf: + hop_length = frontend_conf['hop_length'] + + finish_count = 0 + file_count = 1 + if isinstance(audio_lists, bytes): + file_count = 1 + else: + file_count = len(audio_lists) + if len(data_path_and_name_and_type) >= 3 and frontend_conf is not None: + mvn_file = data_path_and_name_and_type[2] + mvn_data = wav_utils.extract_CMVN_featrures(mvn_file) + frontend_conf['mvn_data'] = mvn_data + + # 1. Set random-seed + set_all_random_seed(seed) + + # 2. Build speech2text + speech2text_kwargs = dict( + asr_train_config=asr_train_config, + asr_model_file=asr_model_file, + lm_train_config=lm_train_config, + lm_file=lm_file, + token_type=token_type, + bpemodel=bpemodel, + device=device, + maxlenratio=maxlenratio, + minlenratio=minlenratio, + dtype=dtype, + beam_size=beam_size, + ctc_weight=ctc_weight, + lm_weight=lm_weight, + ngram_weight=ngram_weight, + penalty=penalty, + nbest=nbest, + frontend_conf=frontend_conf, + ) + speech2text = Speech2Text(**speech2text_kwargs) + + data_path_and_name_and_type_new = [ + audio_lists, data_path_and_name_and_type[0], data_path_and_name_and_type[1] + ] + + # 3. 
Build data-iterator + loader = ASRTask.build_streaming_iterator_modelscope( + data_path_and_name_and_type_new, + dtype=dtype, + batch_size=batch_size, + key_file=key_file, + num_workers=num_workers, + preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False), + collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False), + allow_variable_data_keys=allow_variable_data_keys, + inference=True, + sample_rate=fs + ) + + forward_time_total = 0.0 + length_total = 0.0 + asr_result_list = [] + # 7 .Start for-loop + # FIXME(kamo): The output format should be discussed about + for keys, batch in loader: + assert isinstance(batch, dict), type(batch) + assert all(isinstance(s, str) for s in keys), keys + _bs = len(next(iter(batch.values()))) + assert len(keys) == _bs, f"{len(keys)} != {_bs}" + batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")} + + logging.info("decoding, utt_id: {}".format(keys)) + # N-best list of (text, token, token_int, hyp_object) + + try: + time_beg = time.time() + results = speech2text(**batch) + time_end = time.time() + forward_time = time_end - time_beg + lfr_factor = results[0][-1] + length = results[0][-2] + results = [results[0][:-2]] + forward_time_total += forward_time + length_total += length + logging.info( + "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}". + format(length, forward_time, 100 * forward_time / (length * lfr_factor))) + except TooShortUttError as e: + logging.warning(f"Utterance {keys} {e}") + hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[]) + results = [[" ", [""], [2], hyp]] * nbest + + # Only supporting batch_size==1 + key = keys[0] + for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results): + if text is not None: + text_postprocessed = postprocess_utils.sentence_postprocess(token) + item = {'key': key, 'value': text_postprocessed} + asr_result_list.append(item) + + logging.info("decoding, predictions: {}".format(text)) + finish_count += 1 + asr_utils.print_progress(finish_count / file_count) + + logging.info("decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}". 
+ format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor))) + if features_type == 'sound': + # data format is wav + length_total_seconds = length_total / sr + length_total_bytes = length_total * 2 + else: + # data format is kaldi_ark + length_total_seconds = length_total * hop_length / sr + length_total_bytes = length_total * hop_length * 2 + + logger.info( + header_colors + # noqa: * + 'decoding, feature length total: {}bytes, forward_time total: {:.4f}s, rtf avg: {:.4f}' + .format(length_total_bytes, forward_time_total, forward_time_total / + length_total_seconds) + end_colors) + + return asr_result_list + + +def set_parameters(language: str = None, + sample_rate: Union[int, Dict[Any, int]] = None): + if language is not None: + global global_asr_language + global_asr_language = language + if sample_rate is not None: + global global_sample_rate + global_sample_rate = sample_rate + + +def asr_inference(maxlenratio: float, + minlenratio: float, + beam_size: int, + ngpu: int, + ctc_weight: float, + lm_weight: float, + penalty: float, + name_and_type: list, + audio_lists: Union[List[Any], bytes], + asr_train_config: Optional[str], + asr_model_file: Optional[str], + nbest: int = 1, + num_workers: int = 1, + log_level: Union[int, str] = 'INFO', + batch_size: int = 1, + dtype: str = 'float32', + seed: int = 0, + key_file: Optional[str] = None, + lm_train_config: Optional[str] = None, + lm_file: Optional[str] = None, + word_lm_train_config: Optional[str] = None, + word_lm_file: Optional[str] = None, + ngram_file: Optional[str] = None, + ngram_weight: float = 0.9, + model_tag: Optional[str] = None, + token_type: Optional[str] = None, + bpemodel: Optional[str] = None, + allow_variable_data_keys: bool = False, + transducer_conf: Optional[dict] = None, + streaming: bool = False, + frontend_conf: dict = None, + fs: Union[dict, int] = None, + lang: Optional[str] = None, + outputdir: Optional[str] = None): + if lang is not None: + global global_asr_language + global_asr_language = lang + if fs is not None: + global global_sample_rate + global_sample_rate = fs + + # force use CPU if data type is bytes + if isinstance(audio_lists, bytes): + num_workers = 0 + ngpu = 0 + + return inference(output_dir=outputdir, + maxlenratio=maxlenratio, + minlenratio=minlenratio, + batch_size=batch_size, + dtype=dtype, + beam_size=beam_size, + ngpu=ngpu, + seed=seed, + ctc_weight=ctc_weight, + lm_weight=lm_weight, + ngram_weight=ngram_weight, + penalty=penalty, + nbest=nbest, + num_workers=num_workers, + log_level=log_level, + data_path_and_name_and_type=name_and_type, + audio_lists=audio_lists, + key_file=key_file, + asr_train_config=asr_train_config, + asr_model_file=asr_model_file, + lm_train_config=lm_train_config, + lm_file=lm_file, + word_lm_train_config=word_lm_train_config, + word_lm_file=word_lm_file, + ngram_file=ngram_file, + model_tag=model_tag, + token_type=token_type, + bpemodel=bpemodel, + allow_variable_data_keys=allow_variable_data_keys, + transducer_conf=transducer_conf, + streaming=streaming, + frontend_conf=frontend_conf) + + + +def get_parser(): + parser = config_argparse.ArgumentParser( + description="ASR Decoding", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Note(kamo): Use '_' instead of '-' as separator. + # '-' is confusing if written in yaml. 
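# Annotation (not part of the patch): an illustrative direct invocation of this
# script. All paths below are placeholders; only --output_dir and
# --data_path_and_name_and_type are required by the parser defined next.
#
#   python -m funasr.bin.asr_inference_paraformer_modelscope \
#       --data_path_and_name_and_type wav.scp,speech,sound \
#       --asr_train_config exp/paraformer/config.yaml \
#       --asr_model_file exp/paraformer/model.pth \
#       --output_dir exp/paraformer/decode_results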
+ parser.add_argument( + "--log_level", + type=lambda x: x.upper(), + default="INFO", + choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), + help="The verbose level of logging", + ) + + parser.add_argument("--output_dir", type=str, required=True) + parser.add_argument( + "--ngpu", + type=int, + default=0, + help="The number of gpus. 0 indicates CPU mode", + ) + parser.add_argument("--seed", type=int, default=0, help="Random seed") + parser.add_argument( + "--dtype", + default="float32", + choices=["float16", "float32", "float64"], + help="Data type", + ) + parser.add_argument( + "--num_workers", + type=int, + default=1, + help="The number of workers used for DataLoader", + ) + + group = parser.add_argument_group("Input data related") + group.add_argument( + "--data_path_and_name_and_type", + type=str2triple_str, + required=True, + action="append", + ) + group.add_argument("--audio_lists", type=list, default=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}]) + group.add_argument("--key_file", type=str_or_none) + group.add_argument("--allow_variable_data_keys", type=str2bool, default=False) + + group = parser.add_argument_group("The model configuration related") + group.add_argument( + "--asr_train_config", + type=str, + help="ASR training configuration", + ) + group.add_argument( + "--asr_model_file", + type=str, + help="ASR model parameter file", + ) + group.add_argument( + "--lm_train_config", + type=str, + help="LM training configuration", + ) + group.add_argument( + "--lm_file", + type=str, + help="LM parameter file", + ) + group.add_argument( + "--word_lm_train_config", + type=str, + help="Word LM training configuration", + ) + group.add_argument( + "--word_lm_file", + type=str, + help="Word LM parameter file", + ) + group.add_argument( + "--ngram_file", + type=str, + help="N-gram parameter file", + ) + group.add_argument( + "--model_tag", + type=str, + help="Pretrained model tag. If specify this option, *_train_config and " + "*_file will be overwritten", + ) + + group = parser.add_argument_group("Beam-search related") + group.add_argument( + "--batch_size", + type=int, + default=1, + help="The batch size for inference", + ) + group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses") + group.add_argument("--beam_size", type=int, default=20, help="Beam size") + group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty") + group.add_argument( + "--maxlenratio", + type=float, + default=0.0, + help="Input length ratio to obtain max output length. " + "If maxlenratio=0.0 (default), it uses a end-detect " + "function " + "to automatically find maximum hypothesis lengths." 
+ "If maxlenratio<0.0, its absolute value is interpreted" + "as a constant max output length", + ) + group.add_argument( + "--minlenratio", + type=float, + default=0.0, + help="Input length ratio to obtain min output length", + ) + group.add_argument( + "--ctc_weight", + type=float, + default=0.5, + help="CTC weight in joint decoding", + ) + group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight") + group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight") + group.add_argument("--streaming", type=str2bool, default=False) + + group.add_argument( + "--asr_model_config", + default=None, + help="", + ) + + group = parser.add_argument_group("Text converter related") + group.add_argument( + "--token_type", + type=str_or_none, + default=None, + choices=["char", "bpe", None], + help="The token type for ASR model. " + "If not given, refers from the training args", + ) + group.add_argument( + "--bpemodel", + type=str_or_none, + default=None, + help="The model path of sentencepiece. " + "If not given, refers from the training args", + ) + + return parser + + +def main(cmd=None): + print(get_commandline_args(), file=sys.stderr) + parser = get_parser() + args = parser.parse_args(cmd) + kwargs = vars(args) + kwargs.pop("config", None) + inference(**kwargs) + + +if __name__ == "__main__": + main() diff --git a/funasr/bin/modelscope_infer.py b/funasr/bin/modelscope_infer.py index 440c88163..74c2fb7ae 100755 --- a/funasr/bin/modelscope_infer.py +++ b/funasr/bin/modelscope_infer.py @@ -15,6 +15,10 @@ if __name__ == '__main__': type=str, default="speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", help="model name in modelscope") + parser.add_argument("--model_revision", + type=str, + default="v1.0.3", + help="model revision in modelscope") parser.add_argument("--local_model_path", type=str, default=None, @@ -62,7 +66,8 @@ if __name__ == '__main__': if args.local_model_path is None: inference_pipeline = pipeline( task=Tasks.auto_speech_recognition, - model="damo/{}".format(args.model_name)) + model="damo/{}".format(args.model_name), + model_revision=args.model_revision) else: inference_pipeline = pipeline( task=Tasks.auto_speech_recognition, diff --git a/funasr/datasets/iterable_dataset_modelscope.py b/funasr/datasets/iterable_dataset_modelscope.py new file mode 100644 index 000000000..860492c5d --- /dev/null +++ b/funasr/datasets/iterable_dataset_modelscope.py @@ -0,0 +1,349 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Part of the implementation is borrowed from espnet/espnet. 
+"""Iterable dataset module.""" +import copy +from io import StringIO +from pathlib import Path +from typing import Callable, Collection, Dict, Iterator, Tuple, Union + +import kaldiio +import numpy as np +import soundfile +import torch +from funasr.datasets.dataset import ESPnetDataset +from torch.utils.data.dataset import IterableDataset +from typeguard import check_argument_types + +from funasr.utils import wav_utils + + +def load_kaldi(input): + retval = kaldiio.load_mat(input) + if isinstance(retval, tuple): + assert len(retval) == 2, len(retval) + if isinstance(retval[0], int) and isinstance(retval[1], np.ndarray): + # sound scp case + rate, array = retval + elif isinstance(retval[1], int) and isinstance(retval[0], np.ndarray): + # Extended ark format case + array, rate = retval + else: + raise RuntimeError( + f'Unexpected type: {type(retval[0])}, {type(retval[1])}') + + # Multichannel wave fie + # array: (NSample, Channel) or (Nsample) + + else: + # Normal ark case + assert isinstance(retval, np.ndarray), type(retval) + array = retval + return array + + +DATA_TYPES = { + 'sound': + lambda x: soundfile.read(x)[0], + 'kaldi_ark': + load_kaldi, + 'npy': + np.load, + 'text_int': + lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.long, delimiter=' '), + 'csv_int': + lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.long, delimiter=','), + 'text_float': + lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.float32, delimiter=' ' + ), + 'csv_float': + lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.float32, delimiter=',' + ), + 'text': + lambda x: x, +} + + +class IterableESPnetDatasetModelScope(IterableDataset): + """Pytorch Dataset class for ESPNet. + + Examples: + >>> dataset = IterableESPnetDataset([('wav.scp', 'input', 'sound'), + ... ('token_int', 'output', 'text_int')], + ... ) + >>> for uid, data in dataset: + ... 
data + {'input': per_utt_array, 'output': per_utt_array} + """ + def __init__(self, + path_name_type_list: Collection[Tuple[any, str, str]], + preprocess: Callable[[str, Dict[str, np.ndarray]], + Dict[str, np.ndarray]] = None, + float_dtype: str = 'float32', + int_dtype: str = 'long', + key_file: str = None, + sample_rate: Union[dict, int] = 16000): + assert check_argument_types() + if len(path_name_type_list) == 0: + raise ValueError( + '1 or more elements are required for "path_name_type_list"') + + self.preprocess = preprocess + + self.float_dtype = float_dtype + self.int_dtype = int_dtype + self.key_file = key_file + self.sample_rate = sample_rate + + self.debug_info = {} + non_iterable_list = [] + self.path_name_type_list = [] + + path_list = path_name_type_list[0] + name = path_name_type_list[1] + _type = path_name_type_list[2] + if name in self.debug_info: + raise RuntimeError(f'"{name}" is duplicated for data-key') + self.debug_info[name] = path_list, _type + # for path, name, _type in path_name_type_list: + for path in path_list: + self.path_name_type_list.append((path, name, _type)) + + if len(non_iterable_list) != 0: + # Some types doesn't support iterable mode + self.non_iterable_dataset = ESPnetDataset( + path_name_type_list=non_iterable_list, + preprocess=preprocess, + float_dtype=float_dtype, + int_dtype=int_dtype, + ) + else: + self.non_iterable_dataset = None + + self.apply_utt2category = False + + def has_name(self, name) -> bool: + return name in self.debug_info + + def names(self) -> Tuple[str, ...]: + return tuple(self.debug_info) + + def __repr__(self): + _mes = self.__class__.__name__ + _mes += '(' + for name, (path, _type) in self.debug_info.items(): + _mes += f'\n {name}: {{"path": "{path}", "type": "{_type}"}}' + _mes += f'\n preprocess: {self.preprocess})' + return _mes + + def __iter__( + self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]: + torch.set_printoptions(profile='default') + count = len(self.path_name_type_list) + for idx in range(count): + # 2. Load the entry from each line and create a dict + data = {} + # 2.a. Load data streamingly + + # value: /home/fsc/code/MaaS/MaaS-lib-nls-asr/data/test/audios/asr_example.wav + value = self.path_name_type_list[idx][0]['file'] + uid = self.path_name_type_list[idx][0]['key'] + # name: speech + name = self.path_name_type_list[idx][1] + _type = self.path_name_type_list[idx][2] + func = DATA_TYPES[_type] + array = func(value) + + # 2.b. audio resample + if _type == 'sound': + audio_sr: int = 16000 + model_sr: int = 16000 + if isinstance(self.sample_rate, int): + model_sr = self.sample_rate + else: + if 'audio_sr' in self.sample_rate: + audio_sr = self.sample_rate['audio_sr'] + if 'model_sr' in self.sample_rate: + model_sr = self.sample_rate['model_sr'] + array = wav_utils.torch_resample(array, audio_sr, model_sr) + + # array: [ 1.25122070e-03 ... ] + data[name] = array + + # 3. [Option] Apply preprocessing + # e.g. espnet2.train.preprocessor:CommonPreprocessor + if self.preprocess is not None: + data = self.preprocess(uid, data) + # data: {'speech': array([ 1.25122070e-03 ... 6.10351562e-03])} + + # 4. Force data-precision + for name in data: + # value is np.ndarray data + value = data[name] + if not isinstance(value, np.ndarray): + raise RuntimeError( + f'All values must be converted to np.ndarray object ' + f'by preprocessing, but "{name}" is still {type(value)}.' 
+ ) + + # Cast to desired type + if value.dtype.kind == 'f': + value = value.astype(self.float_dtype) + elif value.dtype.kind == 'i': + value = value.astype(self.int_dtype) + else: + raise NotImplementedError( + f'Not supported dtype: {value.dtype}') + data[name] = value + + yield uid, data + + if count == 0: + raise RuntimeError('No iteration') + + +class IterableESPnetBytesModelScope(IterableDataset): + """Pytorch audio bytes class for ESPNet. + + Examples: + >>> dataset = IterableESPnetBytes([('audio bytes', 'input', 'sound'), + ... ('token_int', 'output', 'text_int')], + ... ) + >>> for uid, data in dataset: + ... data + {'input': per_utt_array, 'output': per_utt_array} + """ + def __init__(self, + path_name_type_list: Collection[Tuple[any, str, str]], + preprocess: Callable[[str, Dict[str, np.ndarray]], + Dict[str, np.ndarray]] = None, + float_dtype: str = 'float32', + int_dtype: str = 'long', + key_file: str = None, + sample_rate: Union[dict, int] = 16000): + assert check_argument_types() + if len(path_name_type_list) == 0: + raise ValueError( + '1 or more elements are required for "path_name_type_list"') + + self.preprocess = preprocess + + self.float_dtype = float_dtype + self.int_dtype = int_dtype + self.key_file = key_file + self.sample_rate = sample_rate + + self.debug_info = {} + non_iterable_list = [] + self.path_name_type_list = [] + + audio_data = path_name_type_list[0] + name = path_name_type_list[1] + _type = path_name_type_list[2] + if name in self.debug_info: + raise RuntimeError(f'"{name}" is duplicated for data-key') + self.debug_info[name] = audio_data, _type + self.path_name_type_list.append((audio_data, name, _type)) + + if len(non_iterable_list) != 0: + # Some types doesn't support iterable mode + self.non_iterable_dataset = ESPnetDataset( + path_name_type_list=non_iterable_list, + preprocess=preprocess, + float_dtype=float_dtype, + int_dtype=int_dtype, + ) + else: + self.non_iterable_dataset = None + + self.apply_utt2category = False + + if float_dtype == 'float32': + self.np_dtype = np.float32 + + def has_name(self, name) -> bool: + return name in self.debug_info + + def names(self) -> Tuple[str, ...]: + return tuple(self.debug_info) + + def __repr__(self): + _mes = self.__class__.__name__ + _mes += '(' + for name, (path, _type) in self.debug_info.items(): + _mes += f'\n {name}: {{"path": "{path}", "type": "{_type}"}}' + _mes += f'\n preprocess: {self.preprocess})' + return _mes + + def __iter__( + self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]: + + torch.set_printoptions(profile='default') + # 2. Load the entry from each line and create a dict + data = {} + # 2.a. Load data streamingly + + value = self.path_name_type_list[0][0] + uid = 'pcm_data' + # name: speech + name = self.path_name_type_list[0][1] + _type = self.path_name_type_list[0][2] + func = DATA_TYPES[_type] + # array: [ 1.25122070e-03 ... ] + # data[name] = np.frombuffer(value, dtype=self.np_dtype) + + # 2.b. byte(PCM16) to float32 + middle_data = np.frombuffer(value, dtype=np.int16) + middle_data = np.asarray(middle_data) + if middle_data.dtype.kind not in 'iu': + raise TypeError("'middle_data' must be an array of integers") + dtype = np.dtype('float32') + if dtype.kind != 'f': + raise TypeError("'dtype' must be a floating point type") + + i = np.iinfo(middle_data.dtype) + abs_max = 2**(i.bits - 1) + offset = i.min + abs_max + array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, + dtype=self.np_dtype) + + # 2.c. 
audio resample + if _type == 'sound': + audio_sr: int = 16000 + model_sr: int = 16000 + if isinstance(self.sample_rate, int): + model_sr = self.sample_rate + else: + if 'audio_sr' in self.sample_rate: + audio_sr = self.sample_rate['audio_sr'] + if 'model_sr' in self.sample_rate: + model_sr = self.sample_rate['model_sr'] + array = wav_utils.torch_resample(array, audio_sr, model_sr) + + data[name] = array + + # 3. [Option] Apply preprocessing + # e.g. espnet2.train.preprocessor:CommonPreprocessor + if self.preprocess is not None: + data = self.preprocess(uid, data) + # data: {'speech': array([ 1.25122070e-03 ... 6.10351562e-03])} + + # 4. Force data-precision + for name in data: + # value is np.ndarray data + value = data[name] + if not isinstance(value, np.ndarray): + raise RuntimeError( + f'All values must be converted to np.ndarray object ' + f'by preprocessing, but "{name}" is still {type(value)}.') + + # Cast to desired type + if value.dtype.kind == 'f': + value = value.astype(self.float_dtype) + elif value.dtype.kind == 'i': + value = value.astype(self.int_dtype) + else: + raise NotImplementedError( + f'Not supported dtype: {value.dtype}') + data[name] = value + + yield uid, data diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py index 5ea28f31a..89f7cf09d 100644 --- a/funasr/models/e2e_asr_paraformer.py +++ b/funasr/models/e2e_asr_paraformer.py @@ -330,9 +330,10 @@ class Paraformer(AbsESPnetModel): def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens): - decoder_out, _ = self.decoder( + decoder_outs = self.decoder( encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens ) + decoder_out = decoder_outs[0] decoder_out = torch.log_softmax(decoder_out, dim=-1) return decoder_out, ys_pad_lens @@ -553,7 +554,6 @@ class ParaformerBert(Paraformer): postencoder: Optional[AbsPostEncoder], decoder: AbsDecoder, ctc: CTC, - joint_network: Optional[torch.nn.Module], ctc_weight: float = 0.5, interctc_weight: float = 0.0, ignore_id: int = -1, @@ -590,7 +590,6 @@ class ParaformerBert(Paraformer): postencoder=postencoder, decoder=decoder, ctc=ctc, - joint_network=joint_network, ctc_weight=ctc_weight, interctc_weight=interctc_weight, ignore_id=ignore_id, diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py new file mode 100644 index 000000000..c0b28ff94 --- /dev/null +++ b/funasr/models/frontend/wav_frontend.py @@ -0,0 +1,155 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Part of the implementation is borrowed from espnet/espnet. 
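# Annotation (not part of the patch): a standalone sketch of the PCM16 -> float32
# conversion performed by IterableESPnetBytesModelScope above, using synthetic
# samples. int16 values are shifted by the dtype offset and scaled by 2**15 so
# the resulting floats lie in [-1.0, 1.0).
import numpy as np

pcm_bytes = np.array([0, 16384, -32768], dtype=np.int16).tobytes()
samples = np.frombuffer(pcm_bytes, dtype=np.int16)
info = np.iinfo(samples.dtype)
abs_max = 2 ** (info.bits - 1)      # 32768 for int16
offset = info.min + abs_max         # 0 for int16
floats = (samples.astype(np.float32) - offset) / abs_max
print(floats)                       # [ 0.   0.5 -1. ]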
+ +import copy +from typing import Optional, Tuple, Union + +import humanfriendly +import numpy as np +import torch +import torchaudio.compliance.kaldi as kaldi +from funasr.models.frontend.abs_frontend import AbsFrontend +from funasr.layers.log_mel import LogMel +from funasr.layers.stft import Stft +from funasr.utils.get_default_kwargs import get_default_kwargs +from funasr.modules.frontends.frontend import Frontend +from typeguard import check_argument_types + + +def apply_cmvn(inputs, mvn): # noqa + """ + Apply CMVN with mvn data + """ + + device = inputs.device + dtype = inputs.dtype + frame, dim = inputs.shape + + meams = np.tile(mvn[0:1, :dim], (frame, 1)) + vars = np.tile(mvn[1:2, :dim], (frame, 1)) + inputs += torch.from_numpy(meams).type(dtype).to(device) + inputs *= torch.from_numpy(vars).type(dtype).to(device) + + return inputs.type(torch.float32) + + +def apply_lfr(inputs, lfr_m, lfr_n): + LFR_inputs = [] + T = inputs.shape[0] + T_lfr = int(np.ceil(T / lfr_n)) + left_padding = inputs[0].repeat((lfr_m - 1) // 2, 1) + inputs = torch.vstack((left_padding, inputs)) + T = T + (lfr_m - 1) // 2 + for i in range(T_lfr): + if lfr_m <= T - i * lfr_n: + LFR_inputs.append((inputs[i * lfr_n:i * lfr_n + lfr_m]).view(1, -1)) + else: # process last LFR frame + num_padding = lfr_m - (T - i * lfr_n) + frame = (inputs[i * lfr_n:]).view(-1) + for _ in range(num_padding): + frame = torch.hstack((frame, inputs[-1])) + LFR_inputs.append(frame) + LFR_outputs = torch.vstack(LFR_inputs) + return LFR_outputs.type(torch.float32) + + +class WavFrontend(AbsFrontend): + """Conventional frontend structure for ASR. + """ + def __init__( + self, + fs: Union[int, str] = 16000, + n_fft: int = 512, + win_length: int = 400, + hop_length: int = 160, + window: Optional[str] = 'hamming', + center: bool = True, + normalized: bool = False, + onesided: bool = True, + n_mels: int = 80, + fmin: int = None, + fmax: int = None, + lfr_m: int = 1, + lfr_n: int = 1, + htk: bool = False, + mvn_data=None, + frontend_conf: Optional[dict] = get_default_kwargs(Frontend), + apply_stft: bool = True, + ): + assert check_argument_types() + super().__init__() + if isinstance(fs, str): + fs = humanfriendly.parse_size(fs) + + # Deepcopy (In general, dict shouldn't be used as default arg) + frontend_conf = copy.deepcopy(frontend_conf) + self.hop_length = hop_length + self.win_length = win_length + self.window = window + self.fs = fs + self.mvn_data = mvn_data + self.lfr_m = lfr_m + self.lfr_n = lfr_n + + if apply_stft: + self.stft = Stft( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + center=center, + window=window, + normalized=normalized, + onesided=onesided, + ) + else: + self.stft = None + self.apply_stft = apply_stft + + if frontend_conf is not None: + self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf) + else: + self.frontend = None + + self.logmel = LogMel( + fs=fs, + n_fft=n_fft, + n_mels=n_mels, + fmin=fmin, + fmax=fmax, + htk=htk, + ) + self.n_mels = n_mels + self.frontend_type = 'default' + + def output_size(self) -> int: + return self.n_mels + + def forward( + self, input: torch.Tensor, + input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + + sample_frequency = self.fs + num_mel_bins = self.n_mels + frame_length = self.win_length * 1000 / sample_frequency + frame_shift = self.hop_length * 1000 / sample_frequency + + waveform = input * (1 << 15) + + mat = kaldi.fbank(waveform, + num_mel_bins=num_mel_bins, + frame_length=frame_length, + frame_shift=frame_shift, + dither=1.0, + 
energy_floor=0.0, + window_type=self.window, + sample_frequency=sample_frequency) + if self.lfr_m != 1 or self.lfr_n != 1: + mat = apply_lfr(mat, self.lfr_m, self.lfr_n) + if self.mvn_data is not None: + mat = apply_cmvn(mat, self.mvn_data) + + input_feats = mat[None, :] + feats_lens = torch.randn(1) + feats_lens.fill_(input_feats.shape[1]) + + return input_feats, feats_lens diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py index cf60eafd1..ea41c6c17 100644 --- a/funasr/models/predictor/cif.py +++ b/funasr/models/predictor/cif.py @@ -4,7 +4,7 @@ from torch import nn from funasr.modules.nets_utils import make_pad_mask class CifPredictor(nn.Module): - def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0): + def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, tail_threshold=0.45): super(CifPredictor, self).__init__() self.pad = nn.ConstantPad1d((l_order, r_order), 0) diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py index 5ea78c349..d7164233c 100644 --- a/funasr/tasks/abs_task.py +++ b/funasr/tasks/abs_task.py @@ -38,6 +38,7 @@ from funasr.datasets.dataset import AbsDataset from funasr.datasets.dataset import DATA_TYPES from funasr.datasets.dataset import ESPnetDataset from funasr.datasets.iterable_dataset import IterableESPnetDataset +from funasr.datasets.iterable_dataset_modelscope import IterableESPnetDatasetModelScope, IterableESPnetBytesModelScope from funasr.iterators.abs_iter_factory import AbsIterFactory from funasr.iterators.chunk_iter_factory import ChunkIterFactory from funasr.iterators.multiple_iter_factory import MultipleIterFactory @@ -1026,7 +1027,7 @@ class AbsTask(ABC): @classmethod def check_task_requirements( cls, - dataset: Union[AbsDataset, IterableESPnetDataset], + dataset: Union[AbsDataset, IterableESPnetDataset, IterableESPnetDatasetModelScope, IterableESPnetBytesModelScope], allow_variable_data_keys: bool, train: bool, inference: bool = False, @@ -1748,6 +1749,64 @@ class AbsTask(ABC): **kwargs, ) + @classmethod + def build_streaming_iterator_modelscope( + cls, + data_path_and_name_and_type, + preprocess_fn, + collate_fn, + key_file: str = None, + batch_size: int = 1, + dtype: str = np.float32, + num_workers: int = 1, + allow_variable_data_keys: bool = False, + ngpu: int = 0, + inference: bool = False, + sample_rate: Union[dict, int] = 16000 + ) -> DataLoader: + """Build DataLoader using iterable dataset""" + assert check_argument_types() + # For backward compatibility for pytorch DataLoader + if collate_fn is not None: + kwargs = dict(collate_fn=collate_fn) + else: + kwargs = {} + + audio_data = data_path_and_name_and_type[0] + if isinstance(audio_data, bytes): + dataset = IterableESPnetBytesModelScope( + data_path_and_name_and_type, + float_dtype=dtype, + preprocess=preprocess_fn, + key_file=key_file, + sample_rate=sample_rate + ) + else: + dataset = IterableESPnetDatasetModelScope( + data_path_and_name_and_type, + float_dtype=dtype, + preprocess=preprocess_fn, + key_file=key_file, + sample_rate=sample_rate + ) + + if dataset.apply_utt2category: + kwargs.update(batch_size=1) + else: + kwargs.update(batch_size=batch_size) + + cls.check_task_requirements(dataset, + allow_variable_data_keys, + train=False, + inference=inference) + + return DataLoader( + dataset=dataset, + pin_memory=ngpu > 0, + num_workers=num_workers, + **kwargs, + ) + # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~ @classmethod 
def build_model_from_file( diff --git a/funasr/utils/asr_env_checking.py b/funasr/utils/asr_env_checking.py new file mode 100644 index 000000000..c393ee529 --- /dev/null +++ b/funasr/utils/asr_env_checking.py @@ -0,0 +1,85 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import shutil +import ssl + +import nltk + +# mkdir nltk_data dir if not exist +try: + nltk.data.find('.') +except LookupError: + dir_list = nltk.data.path + for dir_item in dir_list: + if not os.path.exists(dir_item): + os.mkdir(dir_item) + if os.path.exists(dir_item): + break + +# download one package if nltk_data not exist +try: + nltk.data.find('.') +except: # noqa: * + try: + _create_unverified_https_context = ssl._create_unverified_context + except AttributeError: + pass + else: + ssl._create_default_https_context = _create_unverified_https_context + + nltk.download('cmudict', halt_on_error=False, raise_on_error=True) + +# deploy taggers/averaged_perceptron_tagger +try: + nltk.data.find('taggers/averaged_perceptron_tagger') +except: # noqa: * + data_dir = nltk.data.find('.') + target_dir = os.path.join(data_dir, 'taggers') + if not os.path.exists(target_dir): + os.mkdir(target_dir) + src_file = os.path.join(os.path.dirname(__file__), '..', 'nltk_packages', + 'averaged_perceptron_tagger.zip') + shutil.copyfile(src_file, + os.path.join(target_dir, 'averaged_perceptron_tagger.zip')) + shutil._unpack_zipfile( + os.path.join(target_dir, 'averaged_perceptron_tagger.zip'), target_dir) + +# deploy corpora/cmudict +try: + nltk.data.find('corpora/cmudict') +except: # noqa: * + data_dir = nltk.data.find('.') + target_dir = os.path.join(data_dir, 'corpora') + if not os.path.exists(target_dir): + os.mkdir(target_dir) + src_file = os.path.join(os.path.dirname(__file__), '..', 'nltk_packages', + 'cmudict.zip') + shutil.copyfile(src_file, os.path.join(target_dir, 'cmudict.zip')) + shutil._unpack_zipfile(os.path.join(target_dir, 'cmudict.zip'), target_dir) + +try: + nltk.data.find('taggers/averaged_perceptron_tagger') +except: # noqa: * + try: + _create_unverified_https_context = ssl._create_unverified_context + except AttributeError: + pass + else: + ssl._create_default_https_context = _create_unverified_https_context + + nltk.download('averaged_perceptron_tagger', + halt_on_error=False, + raise_on_error=True) + +try: + nltk.data.find('corpora/cmudict') +except: # noqa: * + try: + _create_unverified_https_context = ssl._create_unverified_context + except AttributeError: + pass + else: + ssl._create_default_https_context = _create_unverified_https_context + + nltk.download('cmudict', halt_on_error=False, raise_on_error=True) diff --git a/funasr/utils/asr_utils.py b/funasr/utils/asr_utils.py new file mode 100644 index 000000000..4258f05aa --- /dev/null +++ b/funasr/utils/asr_utils.py @@ -0,0 +1,327 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
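Review note on funasr/utils/asr_env_checking.py above: it deploys bundled NLTK packages (cmudict, averaged_perceptron_tagger) and falls back to downloading them over an unverified SSL context. A minimal sketch of the find-or-download pattern the module repeats, factored into a helper; the helper name is ours, not part of the patch:

import ssl

import nltk


def ensure_nltk_resource(resource_path: str, package: str) -> None:
    # nltk.data.find raises LookupError when the resource is missing.
    try:
        nltk.data.find(resource_path)
    except LookupError:
        # Mirror asr_env_checking: tolerate hosts with broken certificate chains.
        try:
            ssl._create_default_https_context = ssl._create_unverified_context
        except AttributeError:
            pass
        nltk.download(package, halt_on_error=False, raise_on_error=True)


ensure_nltk_resource('corpora/cmudict', 'cmudict')
ensure_nltk_resource('taggers/averaged_perceptron_tagger', 'averaged_perceptron_tagger')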
+ +import os +import struct +from typing import Any, Dict, List, Union + +import librosa +import numpy as np +import pkg_resources +from modelscope.utils.logger import get_logger + +logger = get_logger() + +green_color = '\033[1;32m' +red_color = '\033[0;31;40m' +yellow_color = '\033[0;33;40m' +end_color = '\033[0m' + +global_asr_language = 'zh-cn' + + +def get_version(): + return float(pkg_resources.get_distribution('easyasr').version) + + +def sample_rate_checking(audio_in: Union[str, bytes], audio_format: str): + r_audio_fs = None + + if audio_format == 'wav': + r_audio_fs = get_sr_from_wav(audio_in) + elif audio_format == 'pcm' and isinstance(audio_in, bytes): + r_audio_fs = get_sr_from_bytes(audio_in) + + return r_audio_fs + + +def type_checking(audio_in: Union[str, bytes], + audio_fs: int = None, + recog_type: str = None, + audio_format: str = None): + r_recog_type = recog_type + r_audio_format = audio_format + r_wav_path = audio_in + + if isinstance(audio_in, str): + assert os.path.exists(audio_in), f'wav_path:{audio_in} does not exist' + elif isinstance(audio_in, bytes): + assert len(audio_in) > 0, 'audio in is empty' + r_audio_format = 'pcm' + r_recog_type = 'wav' + + if r_recog_type is None: + # audio_in is wav, recog_type is wav_file + if os.path.isfile(audio_in): + if audio_in.endswith('.wav') or audio_in.endswith('.WAV'): + r_recog_type = 'wav' + r_audio_format = 'wav' + + # recog_type is datasets_file + elif os.path.isdir(audio_in): + dir_name = os.path.basename(audio_in) + if 'test' in dir_name: + r_recog_type = 'test' + elif 'dev' in dir_name: + r_recog_type = 'dev' + elif 'train' in dir_name: + r_recog_type = 'train' + + if r_audio_format is None: + if find_file_by_ends(audio_in, '.ark'): + r_audio_format = 'kaldi_ark' + elif find_file_by_ends(audio_in, '.wav') or find_file_by_ends( + audio_in, '.WAV'): + r_audio_format = 'wav' + elif find_file_by_ends(audio_in, '.records'): + r_audio_format = 'tfrecord' + + if r_audio_format == 'kaldi_ark' and r_recog_type != 'wav': + # datasets with kaldi_ark file + r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../')) + elif r_audio_format == 'tfrecord' and r_recog_type != 'wav': + # datasets with tensorflow records file + r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../')) + elif r_audio_format == 'wav' and r_recog_type != 'wav': + # datasets with waveform files + r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../../')) + + return r_recog_type, r_audio_format, r_wav_path + + +def get_sr_from_bytes(wav: bytes): + sr = None + data = wav + if len(data) > 44: + try: + header_fields = {} + header_fields['ChunkID'] = str(data[0:4], 'UTF-8') + header_fields['Format'] = str(data[8:12], 'UTF-8') + header_fields['Subchunk1ID'] = str(data[12:16], 'UTF-8') + if header_fields['ChunkID'] == 'RIFF' and header_fields[ + 'Format'] == 'WAVE' and header_fields[ + 'Subchunk1ID'] == 'fmt ': + header_fields['SampleRate'] = struct.unpack(' List[str]: + dir_files = os.listdir(dir_path) + for file in dir_files: + file_path = os.path.join(dir_path, file) + if os.path.isfile(file_path): + if file_path.endswith('.wav') or file_path.endswith('.WAV'): + wav_list.append(file_path) + elif os.path.isdir(file_path): + recursion_dir_all_wav(wav_list, file_path) + + return wav_list + + +def set_parameters(language: str = None): + if language is not None: + global global_asr_language + global_asr_language = language + + +def compute_wer(hyp_list: List[Any], + ref_list: List[Any], + lang: str = None) -> Dict[str, Any]: + assert len(hyp_list) > 0, 
'hyp list is empty' + assert len(ref_list) > 0, 'ref list is empty' + + if lang is not None: + global global_asr_language + global_asr_language = lang + + rst = { + 'Wrd': 0, + 'Corr': 0, + 'Ins': 0, + 'Del': 0, + 'Sub': 0, + 'Snt': 0, + 'Err': 0.0, + 'S.Err': 0.0, + 'wrong_words': 0, + 'wrong_sentences': 0 + } + + for h_item in hyp_list: + for r_item in ref_list: + if h_item['key'] == r_item['key']: + out_item = compute_wer_by_line(h_item['value'], + r_item['value'], + global_asr_language) + rst['Wrd'] += out_item['nwords'] + rst['Corr'] += out_item['cor'] + rst['wrong_words'] += out_item['wrong'] + rst['Ins'] += out_item['ins'] + rst['Del'] += out_item['del'] + rst['Sub'] += out_item['sub'] + rst['Snt'] += 1 + if out_item['wrong'] > 0: + rst['wrong_sentences'] += 1 + print_wrong_sentence(key=h_item['key'], + hyp=h_item['value'], + ref=r_item['value']) + else: + print_correct_sentence(key=h_item['key'], + hyp=h_item['value'], + ref=r_item['value']) + + break + + if rst['Wrd'] > 0: + rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2) + if rst['Snt'] > 0: + rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2) + + return rst + + +def compute_wer_by_line(hyp: List[str], + ref: List[str], + lang: str = 'zh-cn') -> Dict[str, Any]: + if lang != 'zh-cn': + hyp = hyp.split() + ref = ref.split() + + hyp = list(map(lambda x: x.lower(), hyp)) + ref = list(map(lambda x: x.lower(), ref)) + + len_hyp = len(hyp) + len_ref = len(ref) + + cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16) + + ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8) + + for i in range(len_hyp + 1): + cost_matrix[i][0] = i + for j in range(len_ref + 1): + cost_matrix[0][j] = j + + for i in range(1, len_hyp + 1): + for j in range(1, len_ref + 1): + if hyp[i - 1] == ref[j - 1]: + cost_matrix[i][j] = cost_matrix[i - 1][j - 1] + else: + substitution = cost_matrix[i - 1][j - 1] + 1 + insertion = cost_matrix[i - 1][j] + 1 + deletion = cost_matrix[i][j - 1] + 1 + + compare_val = [substitution, insertion, deletion] + + min_val = min(compare_val) + operation_idx = compare_val.index(min_val) + 1 + cost_matrix[i][j] = min_val + ops_matrix[i][j] = operation_idx + + match_idx = [] + i = len_hyp + j = len_ref + rst = { + 'nwords': len_ref, + 'cor': 0, + 'wrong': 0, + 'ins': 0, + 'del': 0, + 'sub': 0 + } + while i >= 0 or j >= 0: + i_idx = max(0, i) + j_idx = max(0, j) + + if ops_matrix[i_idx][j_idx] == 0: # correct + if i - 1 >= 0 and j - 1 >= 0: + match_idx.append((j - 1, i - 1)) + rst['cor'] += 1 + + i -= 1 + j -= 1 + + elif ops_matrix[i_idx][j_idx] == 2: # insert + i -= 1 + rst['ins'] += 1 + + elif ops_matrix[i_idx][j_idx] == 3: # delete + j -= 1 + rst['del'] += 1 + + elif ops_matrix[i_idx][j_idx] == 1: # substitute + i -= 1 + j -= 1 + rst['sub'] += 1 + + if i < 0 and j >= 0: + rst['del'] += 1 + elif j < 0 and i >= 0: + rst['ins'] += 1 + + match_idx.reverse() + wrong_cnt = cost_matrix[len_hyp][len_ref] + rst['wrong'] = wrong_cnt + + return rst + + +def print_wrong_sentence(key: str, hyp: str, ref: str): + space = len(key) + print(key + yellow_color + ' ref: ' + ref) + print(' ' * space + red_color + ' hyp: ' + hyp + end_color) + + +def print_correct_sentence(key: str, hyp: str, ref: str): + space = len(key) + print(key + yellow_color + ' ref: ' + ref) + print(' ' * space + green_color + ' hyp: ' + hyp + end_color) + + +def print_progress(percent): + if percent > 1: + percent = 1 + res = int(50 * percent) * '#' + print('\r[%-50s] %d%%' % (res, int(100 * percent)), end='') diff --git 
a/funasr/utils/postprocess_utils.py b/funasr/utils/postprocess_utils.py new file mode 100644 index 000000000..72080ae74 --- /dev/null +++ b/funasr/utils/postprocess_utils.py @@ -0,0 +1,174 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import string +from typing import Any, List, Union + + +def isChinese(ch: str): + if '\u4e00' <= ch <= '\u9fff': + return True + return False + + +def isAllChinese(word: Union[List[Any], str]): + word_lists = [] + table = str.maketrans('', '', string.punctuation) + for i in word: + cur = i.translate(table) + cur = cur.replace(' ', '') + cur = cur.replace('', '') + cur = cur.replace('', '') + word_lists.append(cur) + + if len(word_lists) == 0: + return False + + for ch in word_lists: + if isChinese(ch) is False: + return False + return True + + +def isAllAlpha(word: Union[List[Any], str]): + word_lists = [] + table = str.maketrans('', '', string.punctuation) + for i in word: + cur = i.translate(table) + cur = cur.replace(' ', '') + cur = cur.replace('', '') + cur = cur.replace('', '') + word_lists.append(cur) + + if len(word_lists) == 0: + return False + + for ch in word_lists: + if ch.isalpha() is False: + return False + elif ch.isalpha() is True and isChinese(ch) is True: + return False + + return True + + +def abbr_dispose(words: List[Any]) -> List[Any]: + words_size = len(words) + word_lists = [] + abbr_begin = [] + abbr_end = [] + last_num = -1 + for num in range(words_size): + if num <= last_num: + continue + + if len(words[num]) == 1 and words[num].encode('utf-8').isalpha(): + if num + 1 < words_size and words[ + num + 1] == ' ' and num + 2 < words_size and len( + words[num + + 2]) == 1 and words[num + + 2].encode('utf-8').isalpha(): + # found the begin of abbr + abbr_begin.append(num) + num += 2 + abbr_end.append(num) + # to find the end of abbr + while True: + num += 1 + if num < words_size and words[num] == ' ': + num += 1 + if num < words_size and len( + words[num]) == 1 and words[num].encode( + 'utf-8').isalpha(): + abbr_end.pop() + abbr_end.append(num) + last_num = num + else: + break + else: + break + + last_num = -1 + for num in range(words_size): + if num <= last_num: + continue + + if num in abbr_begin: + word_lists.append(words[num].upper()) + num += 1 + while num < words_size: + if num in abbr_end: + word_lists.append(words[num].upper()) + last_num = num + break + else: + if words[num].encode('utf-8').isalpha(): + word_lists.append(words[num].upper()) + num += 1 + else: + word_lists.append(words[num]) + + return word_lists + + +def sentence_postprocess(words: List[Any]): + middle_lists = [] + word_lists = [] + word_item = '' + + # wash words lists + for i in words: + word = '' + if isinstance(i, str): + word = i + else: + word = i.decode('utf-8') + + if word in ['', '', '']: + continue + else: + middle_lists.append(word) + + # all chinese characters + if isAllChinese(middle_lists): + for ch in middle_lists: + word_lists.append(ch.replace(' ', '')) + + # all alpha characters + elif isAllAlpha(middle_lists): + for ch in middle_lists: + word = '' + if '@@' in ch: + word = ch.replace('@@', '') + word_item += word + else: + word_item += ch + word_lists.append(word_item) + word_lists.append(' ') + word_item = '' + + # mix characters + else: + alpha_blank = False + for ch in middle_lists: + word = '' + if isAllChinese(ch): + if alpha_blank is True: + word_lists.pop() + word_lists.append(ch) + alpha_blank = False + elif '@@' in ch: + word = ch.replace('@@', '') + word_item += word + alpha_blank = False + elif isAllAlpha(ch): + word_item 
+= ch + word_lists.append(word_item) + word_lists.append(' ') + word_item = '' + alpha_blank = True + else: + raise ValueError('invalid character: {}'.format(ch)) + + word_lists = abbr_dispose(word_lists) + sentence = ''.join(word_lists).strip() + return sentence diff --git a/funasr/utils/wav_utils.py b/funasr/utils/wav_utils.py new file mode 100644 index 000000000..d8564f29e --- /dev/null +++ b/funasr/utils/wav_utils.py @@ -0,0 +1,178 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import math +import os +from typing import Any, Dict, Union + +import kaldiio +import librosa +import numpy as np +import torch +import torchaudio +import torchaudio.compliance.kaldi as kaldi + + +def ndarray_resample(audio_in: np.ndarray, + fs_in: int = 16000, + fs_out: int = 16000) -> np.ndarray: + audio_out = audio_in + if fs_in != fs_out: + audio_out = librosa.resample(audio_in, orig_sr=fs_in, target_sr=fs_out) + return audio_out + + +def torch_resample(audio_in: torch.Tensor, + fs_in: int = 16000, + fs_out: int = 16000) -> torch.Tensor: + audio_out = audio_in + if fs_in != fs_out: + audio_out = torchaudio.transforms.Resample(orig_freq=fs_in, + new_freq=fs_out)(audio_in) + return audio_out + + +def extract_CMVN_featrures(mvn_file): + """ + extract CMVN from cmvn.ark + """ + + if not os.path.exists(mvn_file): + return None + try: + cmvn = kaldiio.load_mat(mvn_file) + means = [] + variance = [] + + for i in range(cmvn.shape[1] - 1): + means.append(float(cmvn[0][i])) + + count = float(cmvn[0][-1]) + + for i in range(cmvn.shape[1] - 1): + variance.append(float(cmvn[1][i])) + + for i in range(len(means)): + means[i] /= count + variance[i] = variance[i] / count - means[i] * means[i] + if variance[i] < 1.0e-20: + variance[i] = 1.0e-20 + variance[i] = 1.0 / math.sqrt(variance[i]) + + cmvn = np.array([means, variance]) + return cmvn + except Exception: + cmvn = extract_CMVN_features_txt(mvn_file) + return cmvn + + +def extract_CMVN_features_txt(mvn_file): # noqa + with open(mvn_file, 'r', encoding='utf-8') as f: + lines = f.readlines() + + add_shift_list = [] + rescale_list = [] + for i in range(len(lines)): + line_item = lines[i].split() + if line_item[0] == '': + line_item = lines[i + 1].split() + if line_item[0] == '': + add_shift_line = line_item[3:(len(line_item) - 1)] + add_shift_list = list(add_shift_line) + continue + elif line_item[0] == '': + line_item = lines[i + 1].split() + if line_item[0] == '': + rescale_line = line_item[3:(len(line_item) - 1)] + rescale_list = list(rescale_line) + continue + add_shift_list_f = [float(s) for s in add_shift_list] + rescale_list_f = [float(s) for s in rescale_list] + cmvn = np.array([add_shift_list_f, rescale_list_f]) + return cmvn + + +def build_LFR_features(inputs, m=7, n=6): # noqa + """ + Actually, this implements stacking frames and skipping frames. + if m = 1 and n = 1, just return the origin features. + if m = 1 and n > 1, it works like skipping. + if m > 1 and n = 1, it works like stacking but only support right frames. + if m > 1 and n > 1, it works like LFR. 
+ + Args: + inputs_batch: inputs is T x D np.ndarray + m: number of frames to stack + n: number of frames to skip + """ + # LFR_inputs_batch = [] + # for inputs in inputs_batch: + LFR_inputs = [] + T = inputs.shape[0] + T_lfr = int(np.ceil(T / n)) + left_padding = np.tile(inputs[0], ((m - 1) // 2, 1)) + inputs = np.vstack((left_padding, inputs)) + T = T + (m - 1) // 2 + for i in range(T_lfr): + if m <= T - i * n: + LFR_inputs.append(np.hstack(inputs[i * n:i * n + m])) + else: # process last LFR frame + num_padding = m - (T - i * n) + frame = np.hstack(inputs[i * n:]) + for _ in range(num_padding): + frame = np.hstack((frame, inputs[-1])) + LFR_inputs.append(frame) + return np.vstack(LFR_inputs) + + +def compute_fbank(wav_file, + num_mel_bins=80, + frame_length=25, + frame_shift=10, + dither=0.0, + is_pcm=False, + fs: Union[int, Dict[Any, int]] = 16000): + audio_sr: int = 16000 + model_sr: int = 16000 + if isinstance(fs, int): + model_sr = fs + audio_sr = fs + else: + model_sr = fs['model_fs'] + audio_sr = fs['audio_fs'] + + if is_pcm is True: + # byte(PCM16) to float32, and resample + value = wav_file + middle_data = np.frombuffer(value, dtype=np.int16) + middle_data = np.asarray(middle_data) + if middle_data.dtype.kind not in 'iu': + raise TypeError("'middle_data' must be an array of integers") + dtype = np.dtype('float32') + if dtype.kind != 'f': + raise TypeError("'dtype' must be a floating point type") + + i = np.iinfo(middle_data.dtype) + abs_max = 2**(i.bits - 1) + offset = i.min + abs_max + waveform = np.frombuffer( + (middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32) + waveform = ndarray_resample(waveform, audio_sr, model_sr) + waveform = torch.from_numpy(waveform.reshape(1, -1)) + else: + # load pcm from wav, and resample + waveform, audio_sr = torchaudio.load(wav_file) + waveform = waveform * (1 << 15) + waveform = torch_resample(waveform, audio_sr, model_sr) + + mat = kaldi.fbank(waveform, + num_mel_bins=num_mel_bins, + frame_length=frame_length, + frame_shift=frame_shift, + dither=dither, + energy_floor=0.0, + window_type='hamming', + sample_frequency=model_sr) + + input_feats = mat + + return input_feats diff --git a/funasr/version.txt b/funasr/version.txt index 6e8bf73aa..b1e80bb24 100644 --- a/funasr/version.txt +++ b/funasr/version.txt @@ -1 +1 @@ -0.1.0 +0.1.3
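Review note on build_streaming_iterator_modelscope (abs_task.py above): it wraps either IterableESPnetDatasetModelScope or IterableESPnetBytesModelScope depending on whether the first element of data_path_and_name_and_type is raw PCM bytes. A hypothetical call site, inferred from that bytes check; the ASRTask entry point, the list layout, and the dump/test/wav.scp path are assumptions for illustration:

from funasr.tasks.asr import ASRTask

loader = ASRTask.build_streaming_iterator_modelscope(
    data_path_and_name_and_type=['dump/test/wav.scp', 'speech', 'sound'],  # or raw PCM bytes as the first element
    preprocess_fn=None,
    collate_fn=None,
    batch_size=1,
    num_workers=1,
    inference=True,
    sample_rate={'audio_fs': 16000, 'model_fs': 16000},
)
for keys, batch in loader:  # assumed to yield (utterance ids, feature dict) like the other iterable datasets
    print(keys)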
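Review note on get_sr_from_bytes (asr_utils.py above): it reads the sample rate straight out of the 44-byte RIFF header of in-memory WAV data. A minimal standalone sketch of that idea; the byte offsets follow the canonical WAVE layout and are an assumption here, not copied from the patch:

import struct


def wav_bytes_sample_rate(data: bytes):
    # Canonical PCM WAV header: 'RIFF' at offset 0, 'WAVE' at 8, 'fmt ' at 12,
    # and the sample rate stored as a little-endian uint32 at bytes 24..27.
    if len(data) > 44 and data[0:4] == b'RIFF' and data[8:12] == b'WAVE' and data[12:16] == b'fmt ':
        return struct.unpack('<I', data[24:28])[0]
    return None


# e.g.: with open('speech.wav', 'rb') as f: print(wav_bytes_sample_rate(f.read()))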
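Review note on compute_wer (asr_utils.py above): it expects hypothesis and reference lists of {'key', 'value'} dicts matched by key and, for zh-cn, scores at the character level. A small worked example against the code as written:

from funasr.utils.asr_utils import compute_wer

hyp = [{'key': 'utt1', 'value': '今天天气很好'}]
ref = [{'key': 'utt1', 'value': '今天天气真好'}]
rst = compute_wer(hyp, ref, lang='zh-cn')
# One substitution out of six reference characters:
# rst['Wrd'] == 6, rst['Sub'] == 1, rst['Err'] == 16.67, rst['S.Err'] == 100.0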
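Review note on sentence_postprocess (postprocess_utils.py above): it merges BPE pieces marked with '@@', drops the separators between Chinese characters, keeps spaces between English words, and upper-cases spelled-out abbreviations via abbr_dispose. A quick check of that behaviour; the token list is made up for illustration:

from funasr.utils.postprocess_utils import sentence_postprocess

tokens = ['今', '天', 'we', 'use', 'i', 'b', 'm']
print(sentence_postprocess(tokens))  # expected roughly: 今天we use IBM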
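Review note on compute_fbank and build_LFR_features (wav_utils.py above): together they produce the 80-dim fbank features and the low-frame-rate stacking that WavFrontend applies via apply_lfr, with m=7, n=6 as the defaults in build_LFR_features. A shape check, assuming a 16 kHz speech.wav placeholder on disk:

from funasr.utils.wav_utils import build_LFR_features, compute_fbank

feats = compute_fbank('speech.wav', num_mel_bins=80, fs=16000)   # torch.Tensor of shape (T, 80)
lfr = build_LFR_features(feats.numpy(), m=7, n=6)                # np.ndarray of shape (ceil(T / 6), 7 * 80)
print(feats.shape, lfr.shape)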