FunASR/egs/cnceleb/resnet34/sid/speaker_verification.py
2023-01-16 18:46:40 +08:00

168 lines
5.7 KiB
Python

from __future__ import print_function
import numpy as np
import os
import kaldiio
from multiprocessing import Pool
import argparse
from tqdm import tqdm
import math
from funasr.utils.types import str2triple_str
import logging
from typing import List, Union, Tuple, Sequence
from funasr.bin.sv_inference import inference_modelscope
import soundfile
import torch
class MultiProcessRunner:
def __init__(self, fn):
self.process = fn
def run(self):
parser = argparse.ArgumentParser("")
# Task-independent options
parser.add_argument("--njobs", type=int, default=16)
parser.add_argument("--debug", action="store_true", default=False)
parser.add_argument("--no_pbar", action="store_true", default=False)
parser.add_argument("--verbose", action="store_true", default=False)
parser.add_argument("--log_level", type=str, default="INFO")
parser.add_argument("--sr", type=int, default=16000)
task_list, shared_param, args = self.prepare(parser)
chunk_size = int(math.ceil(float(len(task_list)) / args.njobs))
if args.verbose:
print("Split {} tasks into {} sub-tasks with chunk_size {}".format(len(task_list), args.njobs, chunk_size))
subtask_list = [(i, task_list[i * chunk_size: (i + 1) * chunk_size], shared_param, args)
for i in range(args.njobs)]
result_list = self.pool_run(subtask_list, args)
self.post(result_list, args)
def prepare(self, parser: argparse.ArgumentParser):
raise NotImplementedError("Please implement the prepare function.")
def post(self, results_list: list, args: argparse.Namespace):
raise NotImplementedError("Please implement the post function.")
def pool_run(self, tasks: list, args: argparse.Namespace):
results = []
if args.debug:
one_result = self.process(tasks[0])
results.append(one_result)
else:
pool = Pool(args.njobs)
for one_result in tqdm(pool.imap(self.process, tasks), total=len(tasks), ascii=True, disable=args.no_pbar):
results.append(one_result)
pool.close()
return results
class MyRunner(MultiProcessRunner):
def prepare(self, parser: argparse.ArgumentParser):
parser.add_argument(
"--gpu_inference",
type=bool,
default=False
)
parser.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
required=True,
action="append"
)
parser.add_argument(
"--gpu_devices",
type=lambda devices: devices.split(","),
default=None,
)
args = parser.parse_args()
logging.basicConfig(
level=args.log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
if args.gpu_inference and (args.gpu_devices is None or len(args.gpu_devices) == 0):
logging.warning("gpu_inference is set to True, but gpu_devices is not given, use CPU instead.")
args.gpu_inference = False
if args.gpu_inference:
args.njobs = args.njobs * len(args.gpu_devices)
speech_dict = {}
ref_speech_dict = {}
for _path, _name, _type in args.data_path_and_name_and_type:
if _name == "speech":
speech_dict = self.read_data_path(_path)
elif _name == "ref_speech":
ref_speech_dict = self.read_data_path(_path)
task_list, args.njobs = self.get_key_list(args.data_path_and_name_and_type, args.njobs)
return task_list, [speech_dict, ref_speech_dict], args
def read_data_path(self, file_path):
results = {}
for line in open(file_path, "r"):
key, path = line.strip().split(" ", 1)
results[key] = path
return results
def get_key_list(
self,
data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
njobs: int
):
first_data = data_path_and_name_and_type[0]
content = open(first_data[0], "r").readlines()
line_number = len(content)
njobs = min(njobs, line_number)
logging.warning("njobs is reduced to {}, since only {} lines exist in {}".format(
njobs, line_number, first_data[0],
))
key_list = [line.strip().split(" ", 1)[0] for line in content]
return key_list, njobs
def post(self, results_list: list, args: argparse.Namespace):
for results in results_list:
for key, value in results:
logging.info("{} {}".format(key, value))
def process(task_args):
task_id, key_list, [speech_dict, ref_speech_dict], args = task_args
if args.gpu_inference:
device = args.gpu_devices[task_id % len(args.gpu_devices)]
torch.cuda.set_device("cuda:".format(device))
inference_func = inference_modelscope(
output_dir=None,
batch_size=1,
dtype="float32",
ngpu=1 if args.gpu_inference else 0,
seed=0,
num_workers=0,
log_level=logging.INFO,
key_file=None,
sv_train_config="sv.yaml",
sv_model_file="sv.pb",
model_tag=None,
allow_variable_data_keys=True,
streaming=False,
embedding_node="resnet1_dense",
sv_threshold=0.9465,
)
results = {}
for key in key_list:
speech = soundfile.read(speech_dict[key])[0]
ref_speech = soundfile.read(ref_speech_dict[key])[0]
ret = inference_func(None, (speech, ref_speech))
results[key] = ret["value"]
return results
if __name__ == '__main__':
my_runner = MyRunner(process)
my_runner.run()