mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
* update * update setup * update setup * update setup * update setup * update setup * update setup * update * update * update setup
253 lines
7.8 KiB
Python
Executable File
253 lines
7.8 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
# -*- encoding: utf-8 -*-
|
|
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
|
# MIT License (https://opensource.org/licenses/MIT)
|
|
|
|
import argparse
|
|
import logging
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from typing import List
|
|
from typing import Optional
|
|
from typing import Union
|
|
|
|
import torch
|
|
|
|
from funasr.bin.punc_infer import Text2Punc, Text2PuncVADRealtime
|
|
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
|
|
from funasr.utils import config_argparse
|
|
from funasr.utils.cli_utils import get_commandline_args
|
|
from funasr.utils.types import str2triple_str
|
|
from funasr.utils.types import str_or_none
|
|
|
|
|
|
def inference_punc(
    batch_size: int,
    dtype: str,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    key_file: Optional[str],
    train_config: Optional[str],
    model_file: Optional[str],
    output_dir: Optional[str] = None,
    param_dict: dict = None,
    **kwargs,
):
    """Build the offline punctuation-inference callable.

    Configures logging, selects the device, seeds the random generators and
    constructs a ``Text2Punc`` instance once; the returned closure reuses it
    for every call.

    Args:
        batch_size: Accepted for interface compatibility; not used here.
        dtype: Accepted for interface compatibility; not used here.
        ngpu: "cuda" is selected when this is >= 1 and CUDA is available,
            otherwise "cpu".
        seed: Value passed to ``set_all_random_seed``.
        num_workers: Accepted for interface compatibility; not used here.
        log_level: Level passed to ``logging.basicConfig``.
        key_file: Accepted for interface compatibility; not used here.
        train_config: First positional argument of ``Text2Punc``.
        model_file: Second positional argument of ``Text2Punc``.
        output_dir: Default directory for the ``infer.out`` result file.
        param_dict: Accepted for interface compatibility; not used here.
        **kwargs: Extra options; ignored by this function.

    Returns:
        The ``_forward`` closure described below.
    """
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)
    text2punc = Text2Punc(train_config, model_file, device)

    def _forward(
        data_path_and_name_and_type,
        raw_inputs: Union[List[Any], bytes, str] = None,
        output_dir_v2: Optional[str] = None,
        cache: List[Any] = None,
        param_dict: dict = None,
    ):
        """Punctuate either a raw text or the lines of text files.

        Args:
            data_path_and_name_and_type: Iterable of 3-tuples whose first
                element is the path of a UTF-8 text file with one
                ``key<TAB>text`` entry per line. Only read when
                ``raw_inputs`` is None.
            raw_inputs: Text to punctuate directly. ``.strip()`` is called
                on it, so in practice it must be string-like.
            output_dir_v2: Overrides the enclosing ``output_dir`` when given.
            cache: Unused in this mode.
            param_dict: Unused in this mode.

        Returns:
            A list of ``{'key': ..., 'value': ...}`` dicts. For raw input the
            list has a single entry with key "demo" and nothing is written
            to disk.
        """
        results = []

        if raw_inputs is not None:
            line = raw_inputs.strip()
            key = "demo"
            if line == "":
                # Nothing to punctuate: echo an empty value without calling the model.
                results.append({'key': key, 'value': ""})
                return results
            result, _ = text2punc(line)
            results.append({'key': key, 'value': result})
            return results

        for inference_text, _, _ in data_path_and_name_and_type:
            with open(inference_text, "r", encoding="utf-8") as fin:
                for line in fin:
                    segs = line.strip().split("\t")
                    # Lines that are not exactly "key<TAB>text" are skipped silently.
                    if len(segs) != 2:
                        continue
                    key, text = segs
                    if len(text) == 0:
                        continue
                    result, _ = text2punc(text)
                    results.append({'key': key, 'value': result})

        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path is not None:
            output_file_name = "infer.out"
            Path(output_path).mkdir(parents=True, exist_ok=True)
            output_file_path = (Path(output_path) / output_file_name).absolute()
            with open(output_file_path, "w", encoding="utf-8") as fout:
                for item_i in results:
                    key_out = item_i["key"]
                    value_out = item_i["value"]
                    fout.write(f"{key_out}\t{value_out}\n")
        return results

    return _forward
|
|
|
|
|
|
def inference_punc_vad_realtime(
    batch_size: int,
    dtype: str,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    # cache: list,
    key_file: Optional[str],
    train_config: Optional[str],
    model_file: Optional[str],
    output_dir: Optional[str] = None,
    param_dict: dict = None,
    **kwargs,
):
    """Build the realtime (VAD) punctuation-inference callable.

    Sets the torch thread count from ``kwargs["ncpu"]`` (default 1), selects
    the device, seeds the random generators and constructs a
    ``Text2PuncVADRealtime`` instance that the returned closure reuses.

    Args:
        batch_size: Accepted for interface compatibility; not used here.
        dtype: Accepted for interface compatibility; not used here.
        ngpu: "cuda" is selected when this is >= 1 and CUDA is available,
            otherwise "cpu".
        seed: Value passed to ``set_all_random_seed``.
        num_workers: Accepted for interface compatibility; not used here.
        log_level: Accepted for interface compatibility; not used here.
        key_file: Accepted for interface compatibility; not used here.
        train_config: First positional argument of ``Text2PuncVADRealtime``.
        model_file: Second positional argument of ``Text2PuncVADRealtime``.
        output_dir: Accepted for interface compatibility; not used here.
        param_dict: Accepted for interface compatibility; not used here
            (the closure takes its own ``param_dict``).
        **kwargs: Only ``ncpu`` is read.

    Returns:
        The ``_forward`` closure described below.
    """
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)

    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)
    text2punc = Text2PuncVADRealtime(train_config, model_file, device)

    def _forward(
        data_path_and_name_and_type,
        raw_inputs: Union[List[Any], bytes, str] = None,
        output_dir_v2: Optional[str] = None,
        cache: List[Any] = None,
        param_dict: dict = None,
    ):
        """Punctuate one raw text chunk, carrying the cache across calls.

        Args:
            data_path_and_name_and_type: Ignored in this mode.
            raw_inputs: Text chunk to punctuate. ``.strip()`` is called on
                it, so in practice it must be string-like.
            output_dir_v2: Ignored in this mode.
            cache: Ignored; the cache travels in ``param_dict`` instead.
            param_dict: Must be a mapping with a "cache" entry. It is read
                unconditionally, so passing None raises ``TypeError``. The
                "cache" entry is overwritten with the cache returned by the
                model after each non-empty chunk.

        Returns:
            A list with one ``{'key': 'demo', 'value': ...}`` dict, or an
            empty list when ``raw_inputs`` is None.
        """
        results = []
        cache_in = param_dict["cache"]

        if raw_inputs is not None:
            line = raw_inputs.strip()
            key = "demo"
            if line == "":
                # Empty chunk: the model is not called and the cache is left untouched.
                results.append({'key': key, 'value': ""})
                return results
            result, _, cache = text2punc(line, cache_in)
            # Hand the updated cache back to the caller through the shared dict.
            param_dict["cache"] = cache
            results.append({'key': key, 'value': result})

        return results

    return _forward
|
|
|
|
|
|
def inference_launch(mode, **kwargs):
    """Return the inference callable for ``mode``.

    "punc" is built by ``inference_punc`` and "punc_VadRealtime" by
    ``inference_punc_vad_realtime``; ``kwargs`` are forwarded unchanged.
    Any other mode is logged at INFO level and yields ``None``.
    """
    if mode not in ("punc", "punc_VadRealtime"):
        logging.info("Unknown decoding mode: {}".format(mode))
        return None

    if mode == "punc":
        return inference_punc(**kwargs)
    return inference_punc_vad_realtime(**kwargs)
|
|
|
|
|
|
def get_parser():
    """Create the command-line parser for punctuation inference.

    Returns:
        A ``config_argparse.ArgumentParser`` with the logging, device, input
        data and model options registered below.
    """
    parser = config_argparse.ArgumentParser(
        description="Punctuation inference",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--gpuid_list", type=str, required=True)
    parser.add_argument(
        "--ngpu",
        type=int,
        default=0,
        help="The number of gpus. 0 indicates CPU mode",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    # The help text used to be a copy of --seed's; main() divides the job id
    # by this value to pick an entry of --gpuid_list.
    parser.add_argument(
        "--njob",
        type=int,
        default=1,
        help="The number of jobs sharing each entry of --gpuid_list",
    )
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="The number of workers used for DataLoader",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="The batch size for inference",
    )

    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        action="append",
        required=False,
    )
    group.add_argument("--raw_inputs", type=str, required=False)
    group.add_argument("--key_file", type=str_or_none)
    # NOTE(review): argparse applies ``type`` to the raw string, so ``list``
    # splits the value into single characters and ``dict`` rejects any
    # non-empty string. Confirm whether these two options are meant to be
    # usable from the command line.
    group.add_argument("--cache", type=list, required=False)
    group.add_argument("--param_dict", type=dict, required=False)

    group = parser.add_argument_group("The model configuration related")
    group.add_argument("--train_config", type=str)
    group.add_argument("--model_file", type=str)
    group.add_argument("--mode", type=str, default="punc")
    return parser
|
|
|
|
|
|
def main(cmd=None):
    """Run punctuation inference from the command line.

    Args:
        cmd: Argument list handed to ``parse_args``; None means ``sys.argv``.

    Returns:
        Whatever the selected inference callable returns.

    Raises:
        ValueError: If ``--mode`` does not name a known decoding mode.
    """
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    args = parser.parse_args(cmd)
    # ``kwargs`` is the namespace's own dict, so the pops below also remove
    # the attributes from ``args``.
    kwargs = vars(args)
    kwargs.pop("config", None)

    # set logging messages
    logging.basicConfig(
        level=args.log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    logging.info("Decoding args: {}".format(kwargs))

    # gpu setting
    if args.ngpu > 0:
        # The job id is parsed from the text after the last "." of
        # --output_dir, so that suffix must be an integer in GPU mode.
        jobid = int(args.output_dir.split(".")[-1])
        gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = gpuid

    kwargs.pop("gpuid_list", None)
    kwargs.pop("njob", None)
    inference_pipeline = inference_launch(**kwargs)
    if inference_pipeline is None:
        # inference_launch only logs unknown modes; fail with a clear message
        # instead of "'NoneType' object is not callable".
        raise ValueError("Unknown decoding mode: {}".format(kwargs.get("mode")))

    # Forward the parsed raw text and param_dict as well; previously they
    # were dropped, so a run with only --raw_inputs failed while iterating
    # a None data path.
    return inference_pipeline(
        kwargs["data_path_and_name_and_type"],
        raw_inputs=kwargs.get("raw_inputs"),
        param_dict=kwargs.get("param_dict"),
    )
|
|
|
|
|
|
# Script entry point: run the CLI only when this file is executed directly.
if __name__ == "__main__":
    main()
|