diff --git a/egs_modelscope/speaker_verification/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/infer.py b/egs_modelscope/speaker_verification/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/infer.py new file mode 100644 index 000000000..d3975ae4c --- /dev/null +++ b/egs_modelscope/speaker_verification/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/infer.py @@ -0,0 +1,39 @@ +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +import numpy as np + +if __name__ == '__main__': + inference_sv_pipline = pipeline( + task=Tasks.speaker_verification, + model='damo/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch' + ) + + # extract speaker embedding + # for url use "spk_embedding" as key + rec_result = inference_sv_pipline( + audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav') + enroll = rec_result["spk_embedding"] + + # for local file use "spk_embedding" as key + rec_result = inference_sv_pipline(audio_in='sv_example_same.wav') + same = rec_result["spk_embedding"] + + import soundfile + wav = soundfile.read('sv_example_enroll.wav')[0] + # for raw inputs use "spk_embedding" as key + spk_embedding = inference_sv_pipline(audio_in=wav)["spk_embedding"] + + rec_result = inference_sv_pipline( + audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_different.wav') + different = rec_result["spk_embedding"] + + # calculate cosine similarity for same speaker + sv_threshold = 0.9465 + same_cos = np.sum(enroll * same) / (np.linalg.norm(enroll) * np.linalg.norm(same)) + same_cos = max(same_cos - sv_threshold, 0.0) / (1.0 - sv_threshold) * 100.0 + print("Similarity:", same_cos) + + # calculate cosine similarity for different speaker + diff_cos = np.sum(enroll * different) / (np.linalg.norm(enroll) * np.linalg.norm(different)) + diff_cos = max(diff_cos - sv_threshold, 0.0) / (1.0 - sv_threshold) * 100.0 + print("Similarity:", diff_cos) 
diff --git a/egs_modelscope/speaker_verification/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/infer_sv.py b/egs_modelscope/speaker_verification/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/infer_sv.py new file mode 100644 index 000000000..1151cebdc --- /dev/null +++ b/egs_modelscope/speaker_verification/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/infer_sv.py @@ -0,0 +1,21 @@ +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks + +if __name__ == '__main__': + inference_sv_pipline = pipeline( + task=Tasks.speaker_verification, + model='damo/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch' + ) + + # the same speaker + rec_result = inference_sv_pipline(audio_in=( + 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav', + 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_same.wav')) + print("Similarity", rec_result["scores"]) + + # different speakers + rec_result = inference_sv_pipline(audio_in=( + 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav', + 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_different.wav')) + + print("Similarity", rec_result["scores"]) diff --git a/egs_modelscope/speaker_verification/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/infer.py b/egs_modelscope/speaker_verification/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/infer.py index a48088c8d..87f38013b 100644 --- a/egs_modelscope/speaker_verification/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/infer.py +++ b/egs_modelscope/speaker_verification/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/infer.py @@ -12,20 +12,20 @@ if __name__ == '__main__': # for url use "utt_id" as key rec_result = inference_sv_pipline( audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav') - enroll = rec_result["utt_id"] + enroll = 
rec_result["spk_embedding"] # for local file use "utt_id" as key rec_result = inference_sv_pipline(audio_in='sv_example_same.wav')["test1"] - same = rec_result["test1"] + same = rec_result["spk_embedding"] import soundfile wav = soundfile.read('sv_example_enroll.wav')[0] # for raw inputs use "utt_id" as key - spk_embedding = inference_sv_pipline(audio_in=wav)["utt_id"] + spk_embedding = inference_sv_pipline(audio_in=wav)["spk_embedding"] rec_result = inference_sv_pipline( audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_different.wav') - different = rec_result["utt_id"] + different = rec_result["spk_embedding"] # 对相同的说话人计算余弦相似度 sv_threshold = 0.9465 diff --git a/funasr/models/encoder/resnet34_encoder.py b/funasr/models/encoder/resnet34_encoder.py index 952ce1597..930f7e032 100644 --- a/funasr/models/encoder/resnet34_encoder.py +++ b/funasr/models/encoder/resnet34_encoder.py @@ -387,7 +387,6 @@ class ResNet34_SP_L2Reg(AbsEncoder): return var_dict_torch_update - class ResNet34Diar(ResNet34): def __init__( self, diff --git a/funasr/tasks/sv.py b/funasr/tasks/sv.py index 16384a7ad..1b08c4dad 100644 --- a/funasr/tasks/sv.py +++ b/funasr/tasks/sv.py @@ -1,14 +1,18 @@ import argparse import logging +import os +from pathlib import Path from typing import Callable from typing import Collection from typing import Dict from typing import List from typing import Optional from typing import Tuple +from typing import Union import numpy as np import torch +import yaml from typeguard import check_argument_types from typeguard import check_return_type @@ -21,7 +25,7 @@ from funasr.models.e2e_asr import ESPnetASRModel from funasr.models.decoder.abs_decoder import AbsDecoder from funasr.models.encoder.abs_encoder import AbsEncoder from funasr.models.encoder.rnn_encoder import RNNEncoder -from funasr.models.encoder.resnet34_encoder import ResNet34 +from funasr.models.encoder.resnet34_encoder import ResNet34, ResNet34_SP_L2Reg from 
funasr.models.pooling.statistic_pooling import StatisticPooling from funasr.models.decoder.sv_decoder import DenseDecoder from funasr.models.e2e_sv import ESPnetSVModel @@ -103,6 +107,7 @@ encoder_choices = ClassChoices( "encoder", classes=dict( resnet34=ResNet34, + resnet34_sp_l2reg=ResNet34_SP_L2Reg, rnn=RNNEncoder, ), type_check=AbsEncoder, @@ -394,9 +399,16 @@ class SVTask(AbsTask): # 7. Pooling layer pooling_class = pooling_choices.get_class(args.pooling_type) + pooling_dim = (2, 3) + eps = 1e-12 + if hasattr(args, "pooling_type_conf"): + if "pooling_dim" in args.pooling_type_conf: + pooling_dim = args.pooling_type_conf["pooling_dim"] + if "eps" in args.pooling_type_conf: + eps = args.pooling_type_conf["eps"] pooling_layer = pooling_class( - pooling_dim=(2, 3), - eps=1e-12, + pooling_dim=pooling_dim, + eps=eps, ) if args.pooling_type == "statistic": encoder_output_size *= 2 @@ -435,3 +447,95 @@ class SVTask(AbsTask): assert check_return_type(model) return model + + # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~ + @classmethod + def build_model_from_file( + cls, + config_file: Union[Path, str] = None, + model_file: Union[Path, str] = None, + cmvn_file: Union[Path, str] = None, + device: str = "cpu", + ): + """Build model from the files. + + This method is used for inference or fine-tuning. + + Args: + config_file: The yaml file saved when training. + model_file: The model file saved when training. + cmvn_file: The cmvn file for front-end + device: Device type, "cpu", "cuda", or "cuda:N". + + """ + assert check_argument_types() + if config_file is None: + assert model_file is not None, ( + "The argument 'model_file' must be provided " + "if the argument 'config_file' is not specified." 
+ ) + config_file = Path(model_file).parent / "config.yaml" + else: + config_file = Path(config_file) + + with config_file.open("r", encoding="utf-8") as f: + args = yaml.safe_load(f) + if cmvn_file is not None: + args["cmvn_file"] = cmvn_file + args = argparse.Namespace(**args) + model = cls.build_model(args) + if not isinstance(model, AbsESPnetModel): + raise RuntimeError( + f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}" + ) + model.to(device) + model_dict = dict() + model_name_pth = None + if model_file is not None: + logging.info("model_file is {}".format(model_file)) + if device == "cuda": + device = f"cuda:{torch.cuda.current_device()}" + model_dir = os.path.dirname(model_file) + model_name = os.path.basename(model_file) + if "model.ckpt-" in model_name or ".bin" in model_name: + if ".bin" in model_name: + model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb')) + else: + model_name_pth = os.path.join(model_dir, "{}.pth".format(model_name)) + if os.path.exists(model_name_pth): + logging.info("model_file is load from pth: {}".format(model_name_pth)) + model_dict = torch.load(model_name_pth, map_location=device) + else: + model_dict = cls.convert_tf2torch(model, model_file) + model.load_state_dict(model_dict) + else: + model_dict = torch.load(model_file, map_location=device) + model.load_state_dict(model_dict) + if model_name_pth is not None and not os.path.exists(model_name_pth): + torch.save(model_dict, model_name_pth) + logging.info("model_file is saved to pth: {}".format(model_name_pth)) + + return model, args + + @classmethod + def convert_tf2torch( + cls, + model, + ckpt, + ): + logging.info("start convert tf model to torch model") + from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict + var_dict_tf = load_tf_dict(ckpt) + var_dict_torch = model.state_dict() + var_dict_torch_update = dict() + # speech encoder + var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch) 
+ var_dict_torch_update.update(var_dict_torch_update_local) + # pooling layer + var_dict_torch_update_local = model.pooling_layer.convert_tf2torch(var_dict_tf, var_dict_torch) + var_dict_torch_update.update(var_dict_torch_update_local) + # decoder + var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch) + var_dict_torch_update.update(var_dict_torch_update_local) + + return var_dict_torch_update