mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
168 lines
5.7 KiB
Python
168 lines
5.7 KiB
Python
from __future__ import print_function
|
|
import numpy as np
|
|
import os
|
|
import kaldiio
|
|
from multiprocessing import Pool
|
|
import argparse
|
|
from tqdm import tqdm
|
|
import math
|
|
from funasr.utils.types import str2triple_str
|
|
import logging
|
|
from typing import List, Union, Tuple, Sequence
|
|
from funasr.bin.sv_inference import inference_modelscope
|
|
import soundfile
|
|
import torch
|
|
|
|
|
|
class MultiProcessRunner:
|
|
def __init__(self, fn):
|
|
self.process = fn
|
|
|
|
def run(self):
|
|
parser = argparse.ArgumentParser("")
|
|
# Task-independent options
|
|
parser.add_argument("--njobs", type=int, default=16)
|
|
parser.add_argument("--debug", action="store_true", default=False)
|
|
parser.add_argument("--no_pbar", action="store_true", default=False)
|
|
parser.add_argument("--verbose", action="store_true", default=False)
|
|
parser.add_argument("--log_level", type=str, default="INFO")
|
|
parser.add_argument("--sr", type=int, default=16000)
|
|
|
|
task_list, shared_param, args = self.prepare(parser)
|
|
chunk_size = int(math.ceil(float(len(task_list)) / args.njobs))
|
|
if args.verbose:
|
|
print("Split {} tasks into {} sub-tasks with chunk_size {}".format(len(task_list), args.njobs, chunk_size))
|
|
subtask_list = [(i, task_list[i * chunk_size: (i + 1) * chunk_size], shared_param, args)
|
|
for i in range(args.njobs)]
|
|
result_list = self.pool_run(subtask_list, args)
|
|
self.post(result_list, args)
|
|
|
|
def prepare(self, parser: argparse.ArgumentParser):
|
|
raise NotImplementedError("Please implement the prepare function.")
|
|
|
|
def post(self, results_list: list, args: argparse.Namespace):
|
|
raise NotImplementedError("Please implement the post function.")
|
|
|
|
def pool_run(self, tasks: list, args: argparse.Namespace):
|
|
results = []
|
|
if args.debug:
|
|
one_result = self.process(tasks[0])
|
|
results.append(one_result)
|
|
else:
|
|
pool = Pool(args.njobs)
|
|
for one_result in tqdm(pool.imap(self.process, tasks), total=len(tasks), ascii=True, disable=args.no_pbar):
|
|
results.append(one_result)
|
|
pool.close()
|
|
|
|
return results
|
|
|
|
|
|
class MyRunner(MultiProcessRunner):
|
|
def prepare(self, parser: argparse.ArgumentParser):
|
|
parser.add_argument(
|
|
"--gpu_inference",
|
|
type=bool,
|
|
default=False
|
|
)
|
|
parser.add_argument(
|
|
"--data_path_and_name_and_type",
|
|
type=str2triple_str,
|
|
required=True,
|
|
action="append"
|
|
)
|
|
parser.add_argument(
|
|
"--gpu_devices",
|
|
type=lambda devices: devices.split(","),
|
|
default=None,
|
|
)
|
|
args = parser.parse_args()
|
|
|
|
logging.basicConfig(
|
|
level=args.log_level,
|
|
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
|
)
|
|
|
|
if args.gpu_inference and (args.gpu_devices is None or len(args.gpu_devices) == 0):
|
|
logging.warning("gpu_inference is set to True, but gpu_devices is not given, use CPU instead.")
|
|
args.gpu_inference = False
|
|
|
|
if args.gpu_inference:
|
|
args.njobs = args.njobs * len(args.gpu_devices)
|
|
|
|
speech_dict = {}
|
|
ref_speech_dict = {}
|
|
for _path, _name, _type in args.data_path_and_name_and_type:
|
|
if _name == "speech":
|
|
speech_dict = self.read_data_path(_path)
|
|
elif _name == "ref_speech":
|
|
ref_speech_dict = self.read_data_path(_path)
|
|
|
|
task_list, args.njobs = self.get_key_list(args.data_path_and_name_and_type, args.njobs)
|
|
|
|
return task_list, [speech_dict, ref_speech_dict], args
|
|
|
|
def read_data_path(self, file_path):
|
|
results = {}
|
|
for line in open(file_path, "r"):
|
|
key, path = line.strip().split(" ", 1)
|
|
results[key] = path
|
|
|
|
return results
|
|
|
|
def get_key_list(
|
|
self,
|
|
data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
|
|
njobs: int
|
|
):
|
|
first_data = data_path_and_name_and_type[0]
|
|
content = open(first_data[0], "r").readlines()
|
|
line_number = len(content)
|
|
njobs = min(njobs, line_number)
|
|
logging.warning("njobs is reduced to {}, since only {} lines exist in {}".format(
|
|
njobs, line_number, first_data[0],
|
|
))
|
|
key_list = [line.strip().split(" ", 1)[0] for line in content]
|
|
return key_list, njobs
|
|
|
|
def post(self, results_list: list, args: argparse.Namespace):
|
|
for results in results_list:
|
|
for key, value in results:
|
|
logging.info("{} {}".format(key, value))
|
|
|
|
|
|
def process(task_args):
|
|
task_id, key_list, [speech_dict, ref_speech_dict], args = task_args
|
|
if args.gpu_inference:
|
|
device = args.gpu_devices[task_id % len(args.gpu_devices)]
|
|
torch.cuda.set_device("cuda:".format(device))
|
|
inference_func = inference_modelscope(
|
|
output_dir=None,
|
|
batch_size=1,
|
|
dtype="float32",
|
|
ngpu=1 if args.gpu_inference else 0,
|
|
seed=0,
|
|
num_workers=0,
|
|
log_level=logging.INFO,
|
|
key_file=None,
|
|
sv_train_config="sv.yaml",
|
|
sv_model_file="sv.pb",
|
|
model_tag=None,
|
|
allow_variable_data_keys=True,
|
|
streaming=False,
|
|
embedding_node="resnet1_dense",
|
|
sv_threshold=0.9465,
|
|
)
|
|
results = {}
|
|
for key in key_list:
|
|
speech = soundfile.read(speech_dict[key])[0]
|
|
ref_speech = soundfile.read(ref_speech_dict[key])[0]
|
|
ret = inference_func(None, (speech, ref_speech))
|
|
results[key] = ret["value"]
|
|
|
|
return results
|
|
|
|
|
|
if __name__ == '__main__':
|
|
my_runner = MyRunner(process)
|
|
my_runner.run()
|