From 4cd79db451786548d8a100f25c3b03da0eb30f4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Tue, 16 May 2023 14:08:57 +0800 Subject: [PATCH] inference --- funasr/bin/asr_train_paraformer.py | 55 --- funasr/bin/asr_train_transducer.py | 46 -- funasr/bin/asr_train_uniasr.py | 46 -- funasr/bin/sa_asr_inference.py | 692 ----------------------------- funasr/bin/sa_asr_train.py | 50 --- 5 files changed, 889 deletions(-) delete mode 100755 funasr/bin/asr_train_paraformer.py delete mode 100755 funasr/bin/asr_train_transducer.py delete mode 100755 funasr/bin/asr_train_uniasr.py delete mode 100644 funasr/bin/sa_asr_inference.py delete mode 100755 funasr/bin/sa_asr_train.py diff --git a/funasr/bin/asr_train_paraformer.py b/funasr/bin/asr_train_paraformer.py deleted file mode 100755 index 223be14f4..000000000 --- a/funasr/bin/asr_train_paraformer.py +++ /dev/null @@ -1,55 +0,0 @@ -# -*- encoding: utf-8 -*- -#!/usr/bin/env python3 -# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. -# MIT License (https://opensource.org/licenses/MIT) - -import os - -from funasr.tasks.asr import ASRTaskParaformer as ASRTask - - -# for ASR Training -def parse_args(): - parser = ASRTask.get_parser() - parser.add_argument( - "--mode", - type=str, - default="asr", - help="mode", - ) - parser.add_argument( - "--gpu_id", - type=int, - default=0, - help="local gpu id.", - ) - args = parser.parse_args() - return args - - -def main(args=None, cmd=None): - # for ASR Training - ASRTask.main(args=args, cmd=cmd) - - -if __name__ == '__main__': - args = parse_args() - - # setup local gpu_id - os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) - - # DDP settings - if args.ngpu > 1: - args.distributed = True - else: - args.distributed = False - assert args.num_worker_count == 1 - - # re-compute batch size: when dataset type is small - if args.dataset_type == "small": - if args.batch_size is not None: - args.batch_size = args.batch_size * args.ngpu - if args.batch_bins is not None: - args.batch_bins = args.batch_bins * args.ngpu - - main(args=args) diff --git a/funasr/bin/asr_train_transducer.py b/funasr/bin/asr_train_transducer.py deleted file mode 100755 index fe418dbc9..000000000 --- a/funasr/bin/asr_train_transducer.py +++ /dev/null @@ -1,46 +0,0 @@ -#!/usr/bin/env python3 - -import os - -from funasr.tasks.asr import ASRTransducerTask - - -# for ASR Training -def parse_args(): - parser = ASRTransducerTask.get_parser() - parser.add_argument( - "--gpu_id", - type=int, - default=0, - help="local gpu id.", - ) - args = parser.parse_args() - return args - - -def main(args=None, cmd=None): - # for ASR Training - ASRTransducerTask.main(args=args, cmd=cmd) - - -if __name__ == '__main__': - args = parse_args() - - # setup local gpu_id - os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) - - # DDP settings - if args.ngpu > 1: - args.distributed = True - else: - args.distributed = False - assert args.num_worker_count == 1 - - # re-compute batch size: when dataset type is small - if args.dataset_type == "small": - if args.batch_size is not None: - args.batch_size = args.batch_size * args.ngpu - if args.batch_bins is not None: - args.batch_bins = args.batch_bins * args.ngpu - - main(args=args) diff --git a/funasr/bin/asr_train_uniasr.py b/funasr/bin/asr_train_uniasr.py deleted file mode 100755 index a40b5032c..000000000 --- a/funasr/bin/asr_train_uniasr.py +++ /dev/null @@ -1,46 +0,0 @@ -#!/usr/bin/env python3 - -import os - -from funasr.tasks.asr import ASRTaskUniASR - - -# for ASR Training -def parse_args(): - parser = ASRTaskUniASR.get_parser() - parser.add_argument( - "--gpu_id", - type=int, - default=0, - help="local gpu id.", - ) - args = parser.parse_args() - return args - - -def main(args=None, cmd=None): - # for ASR Training - ASRTaskUniASR.main(args=args, cmd=cmd) - - -if __name__ == '__main__': - args = parse_args() - - # setup local gpu_id - os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) - - # DDP settings - if args.ngpu > 1: - args.distributed = True - else: - args.distributed = False - assert args.num_worker_count == 1 - - # re-compute batch size: when dataset type is small - if args.dataset_type == "small": - if args.batch_size is not None: - args.batch_size = args.batch_size * args.ngpu - if args.batch_bins is not None: - args.batch_bins = args.batch_bins * args.ngpu - - main(args=args) diff --git a/funasr/bin/sa_asr_inference.py b/funasr/bin/sa_asr_inference.py deleted file mode 100644 index 7a5ba8313..000000000 --- a/funasr/bin/sa_asr_inference.py +++ /dev/null @@ -1,692 +0,0 @@ -# -*- encoding: utf-8 -*- -#!/usr/bin/env python3 -# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. -# MIT License (https://opensource.org/licenses/MIT) - -import argparse -import logging -import sys -from pathlib import Path -from typing import Any -from typing import List -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import Union -from typing import Dict - -import numpy as np -import torch -from typeguard import check_argument_types -from typeguard import check_return_type - -from funasr.fileio.datadir_writer import DatadirWriter -from funasr.modules.beam_search.batch_beam_search_online_sim import BatchBeamSearchOnlineSim -from funasr.modules.beam_search.beam_search_sa_asr import BeamSearch -from funasr.modules.beam_search.beam_search_sa_asr import Hypothesis -from funasr.modules.scorers.ctc import CTCPrefixScorer -from funasr.modules.scorers.length_bonus import LengthBonus -from funasr.modules.scorers.scorer_interface import BatchScorerInterface -from funasr.modules.subsampling import TooShortUttError -from funasr.tasks.sa_asr import ASRTask -from funasr.tasks.lm import LMTask -from funasr.text.build_tokenizer import build_tokenizer -from funasr.text.token_id_converter import TokenIDConverter -from funasr.torch_utils.device_funcs import to_device -from funasr.torch_utils.set_all_random_seed import set_all_random_seed -from funasr.utils import config_argparse -from funasr.utils.cli_utils import get_commandline_args -from funasr.utils.types import str2bool -from funasr.utils.types import str2triple_str -from funasr.utils.types import str_or_none -from funasr.utils import asr_utils, wav_utils, postprocess_utils -from funasr.models.frontend.wav_frontend import WavFrontend -from funasr.tasks.asr import frontend_choices - - -header_colors = '\033[95m' -end_colors = '\033[0m' - - -class Speech2Text: - """Speech2Text class - - Examples: - >>> import soundfile - >>> speech2text = Speech2Text("asr_config.yml", "asr.pb") - >>> audio, rate = soundfile.read("speech.wav") - >>> speech2text(audio) - [(text, token, token_int, hypothesis object), ...] - - """ - - def __init__( - self, - asr_train_config: Union[Path, str] = None, - asr_model_file: Union[Path, str] = None, - cmvn_file: Union[Path, str] = None, - lm_train_config: Union[Path, str] = None, - lm_file: Union[Path, str] = None, - token_type: str = None, - bpemodel: str = None, - device: str = "cpu", - maxlenratio: float = 0.0, - minlenratio: float = 0.0, - batch_size: int = 1, - dtype: str = "float32", - beam_size: int = 20, - ctc_weight: float = 0.5, - lm_weight: float = 1.0, - ngram_weight: float = 0.9, - penalty: float = 0.0, - nbest: int = 1, - streaming: bool = False, - frontend_conf: dict = None, - **kwargs, - ): - assert check_argument_types() - - # 1. Build ASR model - scorers = {} - asr_model, asr_train_args = ASRTask.build_model_from_file( - asr_train_config, asr_model_file, cmvn_file, device - ) - frontend = None - if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: - if asr_train_args.frontend=='wav_frontend': - frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf) - else: - frontend_class=frontend_choices.get_class(asr_train_args.frontend) - frontend = frontend_class(**asr_train_args.frontend_conf).eval() - - logging.info("asr_model: {}".format(asr_model)) - logging.info("asr_train_args: {}".format(asr_train_args)) - asr_model.to(dtype=getattr(torch, dtype)).eval() - - decoder = asr_model.decoder - - ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) - token_list = asr_model.token_list - scorers.update( - decoder=decoder, - ctc=ctc, - length_bonus=LengthBonus(len(token_list)), - ) - - # 2. Build Language model - if lm_train_config is not None: - lm, lm_train_args = LMTask.build_model_from_file( - lm_train_config, lm_file, None, device - ) - scorers["lm"] = lm.lm - - # 3. Build ngram model - # ngram is not supported now - ngram = None - scorers["ngram"] = ngram - - # 4. Build BeamSearch object - # transducer is not supported now - beam_search_transducer = None - - weights = dict( - decoder=1.0 - ctc_weight, - ctc=ctc_weight, - lm=lm_weight, - ngram=ngram_weight, - length_bonus=penalty, - ) - beam_search = BeamSearch( - beam_size=beam_size, - weights=weights, - scorers=scorers, - sos=asr_model.sos, - eos=asr_model.eos, - vocab_size=len(token_list), - token_list=token_list, - pre_beam_score_key=None if ctc_weight == 1.0 else "full", - ) - - # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text - if token_type is None: - token_type = asr_train_args.token_type - if bpemodel is None: - bpemodel = asr_train_args.bpemodel - - if token_type is None: - tokenizer = None - elif token_type == "bpe": - if bpemodel is not None: - tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel) - else: - tokenizer = None - else: - tokenizer = build_tokenizer(token_type=token_type) - converter = TokenIDConverter(token_list=token_list) - logging.info(f"Text tokenizer: {tokenizer}") - - self.asr_model = asr_model - self.asr_train_args = asr_train_args - self.converter = converter - self.tokenizer = tokenizer - self.beam_search = beam_search - self.beam_search_transducer = beam_search_transducer - self.maxlenratio = maxlenratio - self.minlenratio = minlenratio - self.device = device - self.dtype = dtype - self.nbest = nbest - self.frontend = frontend - - @torch.no_grad() - def __call__( - self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray], profile: Union[torch.Tensor, np.ndarray], profile_lengths: Union[torch.Tensor, np.ndarray] - ) -> List[ - Tuple[ - Optional[str], - Optional[str], - List[str], - List[int], - Union[Hypothesis], - ] - ]: - """Inference - - Args: - speech: Input speech data - Returns: - text, text_id, token, token_int, hyp - - """ - assert check_argument_types() - - # Input as audio signal - if isinstance(speech, np.ndarray): - speech = torch.tensor(speech) - - if isinstance(profile, np.ndarray): - profile = torch.tensor(profile) - - if self.frontend is not None: - feats, feats_len = self.frontend.forward(speech, speech_lengths) - feats = to_device(feats, device=self.device) - feats_len = feats_len.int() - self.asr_model.frontend = None - else: - feats = speech - feats_len = speech_lengths - lfr_factor = max(1, (feats.size()[-1] // 80) - 1) - batch = {"speech": feats, "speech_lengths": feats_len} - - # a. To device - batch = to_device(batch, device=self.device) - - # b. Forward Encoder - asr_enc, _, spk_enc = self.asr_model.encode(**batch) - if isinstance(asr_enc, tuple): - asr_enc = asr_enc[0] - if isinstance(spk_enc, tuple): - spk_enc = spk_enc[0] - assert len(asr_enc) == 1, len(asr_enc) - assert len(spk_enc) == 1, len(spk_enc) - - # c. Passed the encoder result and the beam search - nbest_hyps = self.beam_search( - asr_enc[0], spk_enc[0], profile[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio - ) - - nbest_hyps = nbest_hyps[: self.nbest] - - results = [] - for hyp in nbest_hyps: - assert isinstance(hyp, (Hypothesis)), type(hyp) - - # remove sos/eos and get results - last_pos = -1 - if isinstance(hyp.yseq, list): - token_int = hyp.yseq[1: last_pos] - else: - token_int = hyp.yseq[1: last_pos].tolist() - - spk_weigths=torch.stack(hyp.spk_weigths, dim=0) - - token_ori = self.converter.ids2tokens(token_int) - text_ori = self.tokenizer.tokens2text(token_ori) - - text_ori_spklist = text_ori.split('$') - cur_index = 0 - spk_choose = [] - for i in range(len(text_ori_spklist)): - text_ori_split = text_ori_spklist[i] - n = len(text_ori_split) - spk_weights_local = spk_weigths[cur_index: cur_index + n] - cur_index = cur_index + n + 1 - spk_weights_local = spk_weights_local.mean(dim=0) - spk_choose_local = spk_weights_local.argmax(-1) - spk_choose.append(spk_choose_local.item() + 1) - - # remove blank symbol id, which is assumed to be 0 - token_int = list(filter(lambda x: x != 0, token_int)) - - # Change integer-ids to tokens - token = self.converter.ids2tokens(token_int) - - if self.tokenizer is not None: - text = self.tokenizer.tokens2text(token) - else: - text = None - - text_spklist = text.split('$') - assert len(spk_choose) == len(text_spklist) - - spk_list=[] - for i in range(len(text_spklist)): - text_split = text_spklist[i] - n = len(text_split) - spk_list.append(str(spk_choose[i]) * n) - - text_id = '$'.join(spk_list) - - assert len(text) == len(text_id) - - results.append((text, text_id, token, token_int, hyp)) - - assert check_return_type(results) - return results - -def inference( - maxlenratio: float, - minlenratio: float, - batch_size: int, - beam_size: int, - ngpu: int, - ctc_weight: float, - lm_weight: float, - penalty: float, - log_level: Union[int, str], - data_path_and_name_and_type, - asr_train_config: Optional[str], - asr_model_file: Optional[str], - cmvn_file: Optional[str] = None, - raw_inputs: Union[np.ndarray, torch.Tensor] = None, - lm_train_config: Optional[str] = None, - lm_file: Optional[str] = None, - token_type: Optional[str] = None, - key_file: Optional[str] = None, - word_lm_train_config: Optional[str] = None, - bpemodel: Optional[str] = None, - allow_variable_data_keys: bool = False, - streaming: bool = False, - output_dir: Optional[str] = None, - dtype: str = "float32", - seed: int = 0, - ngram_weight: float = 0.9, - nbest: int = 1, - num_workers: int = 1, - mc: bool = False, - **kwargs, -): - inference_pipeline = inference_modelscope( - maxlenratio=maxlenratio, - minlenratio=minlenratio, - batch_size=batch_size, - beam_size=beam_size, - ngpu=ngpu, - ctc_weight=ctc_weight, - lm_weight=lm_weight, - penalty=penalty, - log_level=log_level, - asr_train_config=asr_train_config, - asr_model_file=asr_model_file, - cmvn_file=cmvn_file, - raw_inputs=raw_inputs, - lm_train_config=lm_train_config, - lm_file=lm_file, - token_type=token_type, - key_file=key_file, - word_lm_train_config=word_lm_train_config, - bpemodel=bpemodel, - allow_variable_data_keys=allow_variable_data_keys, - streaming=streaming, - output_dir=output_dir, - dtype=dtype, - seed=seed, - ngram_weight=ngram_weight, - nbest=nbest, - num_workers=num_workers, - mc=mc, - **kwargs, - ) - return inference_pipeline(data_path_and_name_and_type, raw_inputs) - -def inference_modelscope( - maxlenratio: float, - minlenratio: float, - batch_size: int, - beam_size: int, - ngpu: int, - ctc_weight: float, - lm_weight: float, - penalty: float, - log_level: Union[int, str], - # data_path_and_name_and_type, - asr_train_config: Optional[str], - asr_model_file: Optional[str], - cmvn_file: Optional[str] = None, - lm_train_config: Optional[str] = None, - lm_file: Optional[str] = None, - token_type: Optional[str] = None, - key_file: Optional[str] = None, - word_lm_train_config: Optional[str] = None, - bpemodel: Optional[str] = None, - allow_variable_data_keys: bool = False, - streaming: bool = False, - output_dir: Optional[str] = None, - dtype: str = "float32", - seed: int = 0, - ngram_weight: float = 0.9, - nbest: int = 1, - num_workers: int = 1, - mc: bool = False, - param_dict: dict = None, - **kwargs, -): - assert check_argument_types() - if batch_size > 1: - raise NotImplementedError("batch decoding is not implemented") - if word_lm_train_config is not None: - raise NotImplementedError("Word LM is not implemented") - if ngpu > 1: - raise NotImplementedError("only single GPU decoding is supported") - - for handler in logging.root.handlers[:]: - logging.root.removeHandler(handler) - - logging.basicConfig( - level=log_level, - format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", - ) - - if ngpu >= 1 and torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" - - # 1. Set random-seed - set_all_random_seed(seed) - - # 2. Build speech2text - speech2text_kwargs = dict( - asr_train_config=asr_train_config, - asr_model_file=asr_model_file, - cmvn_file=cmvn_file, - lm_train_config=lm_train_config, - lm_file=lm_file, - token_type=token_type, - bpemodel=bpemodel, - device=device, - maxlenratio=maxlenratio, - minlenratio=minlenratio, - dtype=dtype, - beam_size=beam_size, - ctc_weight=ctc_weight, - lm_weight=lm_weight, - ngram_weight=ngram_weight, - penalty=penalty, - nbest=nbest, - streaming=streaming, - ) - logging.info("speech2text_kwargs: {}".format(speech2text_kwargs)) - speech2text = Speech2Text(**speech2text_kwargs) - - def _forward(data_path_and_name_and_type, - raw_inputs: Union[np.ndarray, torch.Tensor] = None, - output_dir_v2: Optional[str] = None, - fs: dict = None, - param_dict: dict = None, - **kwargs, - ): - # 3. Build data-iterator - if data_path_and_name_and_type is None and raw_inputs is not None: - if isinstance(raw_inputs, torch.Tensor): - raw_inputs = raw_inputs.numpy() - data_path_and_name_and_type = [raw_inputs, "speech", "waveform"] - loader = ASRTask.build_streaming_iterator( - data_path_and_name_and_type, - dtype=dtype, - fs=fs, - mc=mc, - batch_size=batch_size, - key_file=key_file, - num_workers=num_workers, - preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False), - collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False), - allow_variable_data_keys=allow_variable_data_keys, - inference=True, - ) - - finish_count = 0 - file_count = 1 - # 7 .Start for-loop - # FIXME(kamo): The output format should be discussed about - asr_result_list = [] - output_path = output_dir_v2 if output_dir_v2 is not None else output_dir - if output_path is not None: - writer = DatadirWriter(output_path) - else: - writer = None - - for keys, batch in loader: - assert isinstance(batch, dict), type(batch) - assert all(isinstance(s, str) for s in keys), keys - _bs = len(next(iter(batch.values()))) - assert len(keys) == _bs, f"{len(keys)} != {_bs}" - # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")} - # N-best list of (text, token, token_int, hyp_object) - try: - results = speech2text(**batch) - except TooShortUttError as e: - logging.warning(f"Utterance {keys} {e}") - hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[]) - results = [[" ", ["sil"], [2], hyp]] * nbest - - # Only supporting batch_size==1 - key = keys[0] - for n, (text, text_id, token, token_int, hyp) in zip(range(1, nbest + 1), results): - # Create a directory: outdir/{n}best_recog - if writer is not None: - ibest_writer = writer[f"{n}best_recog"] - - # Write the result to each file - ibest_writer["token"][key] = " ".join(token) - ibest_writer["token_int"][key] = " ".join(map(str, token_int)) - ibest_writer["score"][key] = str(hyp.score) - ibest_writer["text_id"][key] = text_id - - if text is not None: - text_postprocessed, _ = postprocess_utils.sentence_postprocess(token) - item = {'key': key, 'value': text_postprocessed} - asr_result_list.append(item) - finish_count += 1 - asr_utils.print_progress(finish_count / file_count) - if writer is not None: - ibest_writer["text"][key] = text - - logging.info("uttid: {}".format(key)) - logging.info("text predictions: {}".format(text)) - logging.info("text_id predictions: {}\n".format(text_id)) - return asr_result_list - - return _forward - -def get_parser(): - parser = config_argparse.ArgumentParser( - description="ASR Decoding", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - - # Note(kamo): Use '_' instead of '-' as separator. - # '-' is confusing if written in yaml. - parser.add_argument( - "--log_level", - type=lambda x: x.upper(), - default="INFO", - choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), - help="The verbose level of logging", - ) - - parser.add_argument("--output_dir", type=str, required=True) - parser.add_argument( - "--ngpu", - type=int, - default=0, - help="The number of gpus. 0 indicates CPU mode", - ) - parser.add_argument( - "--gpuid_list", - type=str, - default="", - help="The visible gpus", - ) - parser.add_argument("--seed", type=int, default=0, help="Random seed") - parser.add_argument( - "--dtype", - default="float32", - choices=["float16", "float32", "float64"], - help="Data type", - ) - parser.add_argument( - "--num_workers", - type=int, - default=1, - help="The number of workers used for DataLoader", - ) - - group = parser.add_argument_group("Input data related") - group.add_argument( - "--data_path_and_name_and_type", - type=str2triple_str, - required=False, - action="append", - ) - group.add_argument("--raw_inputs", type=list, default=None) - # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}]) - group.add_argument("--key_file", type=str_or_none) - group.add_argument("--allow_variable_data_keys", type=str2bool, default=False) - - group = parser.add_argument_group("The model configuration related") - group.add_argument( - "--asr_train_config", - type=str, - help="ASR training configuration", - ) - group.add_argument( - "--asr_model_file", - type=str, - help="ASR model parameter file", - ) - group.add_argument( - "--cmvn_file", - type=str, - help="Global cmvn file", - ) - group.add_argument( - "--lm_train_config", - type=str, - help="LM training configuration", - ) - group.add_argument( - "--lm_file", - type=str, - help="LM parameter file", - ) - group.add_argument( - "--word_lm_train_config", - type=str, - help="Word LM training configuration", - ) - group.add_argument( - "--word_lm_file", - type=str, - help="Word LM parameter file", - ) - group.add_argument( - "--ngram_file", - type=str, - help="N-gram parameter file", - ) - group.add_argument( - "--model_tag", - type=str, - help="Pretrained model tag. If specify this option, *_train_config and " - "*_file will be overwritten", - ) - - group = parser.add_argument_group("Beam-search related") - group.add_argument( - "--batch_size", - type=int, - default=1, - help="The batch size for inference", - ) - group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses") - group.add_argument("--beam_size", type=int, default=20, help="Beam size") - group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty") - group.add_argument( - "--maxlenratio", - type=float, - default=0.0, - help="Input length ratio to obtain max output length. " - "If maxlenratio=0.0 (default), it uses a end-detect " - "function " - "to automatically find maximum hypothesis lengths." - "If maxlenratio<0.0, its absolute value is interpreted" - "as a constant max output length", - ) - group.add_argument( - "--minlenratio", - type=float, - default=0.0, - help="Input length ratio to obtain min output length", - ) - group.add_argument( - "--ctc_weight", - type=float, - default=0.5, - help="CTC weight in joint decoding", - ) - group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight") - group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight") - group.add_argument("--streaming", type=str2bool, default=False) - - group = parser.add_argument_group("Text converter related") - group.add_argument( - "--token_type", - type=str_or_none, - default=None, - choices=["char", "bpe", None], - help="The token type for ASR model. " - "If not given, refers from the training args", - ) - group.add_argument( - "--bpemodel", - type=str_or_none, - default=None, - help="The model path of sentencepiece. " - "If not given, refers from the training args", - ) - - return parser - - -def main(cmd=None): - print(get_commandline_args(), file=sys.stderr) - parser = get_parser() - args = parser.parse_args(cmd) - kwargs = vars(args) - kwargs.pop("config", None) - inference(**kwargs) - - -if __name__ == "__main__": - main() diff --git a/funasr/bin/sa_asr_train.py b/funasr/bin/sa_asr_train.py deleted file mode 100755 index 67106cf48..000000000 --- a/funasr/bin/sa_asr_train.py +++ /dev/null @@ -1,50 +0,0 @@ -# -*- encoding: utf-8 -*- -#!/usr/bin/env python3 -# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. -# MIT License (https://opensource.org/licenses/MIT) - -import os - -from funasr.tasks.sa_asr import ASRTask - - -# for ASR Training -def parse_args(): - parser = ASRTask.get_parser() - parser.add_argument( - "--gpu_id", - type=int, - default=0, - help="local gpu id.", - ) - args = parser.parse_args() - return args - - -def main(args=None, cmd=None): - # for ASR Training - ASRTask.main(args=args, cmd=cmd) - - -if __name__ == '__main__': - args = parse_args() - - # setup local gpu_id - if args.ngpu > 0: - os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) - - # DDP settings - if args.ngpu > 1: - args.distributed = True - else: - args.distributed = False - assert args.num_worker_count == 1 - - # re-compute batch size: when dataset type is small - if args.dataset_type == "small": - if args.batch_size is not None and args.ngpu > 0: - args.batch_size = args.batch_size * args.ngpu - if args.batch_bins is not None and args.ngpu > 0: - args.batch_bins = args.batch_bins * args.ngpu - - main(args=args)