mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
add en sv model
This commit is contained in:
parent
31dda90f2d
commit
777ae05adb
@ -0,0 +1,39 @@
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
import numpy as np
|
||||
|
||||
if __name__ == '__main__':
|
||||
inference_sv_pipline = pipeline(
|
||||
task=Tasks.speaker_verification,
|
||||
model='damo/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch'
|
||||
)
|
||||
|
||||
# extract speaker embedding
|
||||
# for url use "spk_embedding" as key
|
||||
rec_result = inference_sv_pipline(
|
||||
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav')
|
||||
enroll = rec_result["spk_embedding"]
|
||||
|
||||
# for local file use "spk_embedding" as key
|
||||
rec_result = inference_sv_pipline(audio_in='sv_example_same.wav')["test1"]
|
||||
same = rec_result["spk_embedding"]
|
||||
|
||||
import soundfile
|
||||
wav = soundfile.read('sv_example_enroll.wav')[0]
|
||||
# for raw inputs use "spk_embedding" as key
|
||||
spk_embedding = inference_sv_pipline(audio_in=wav)["spk_embedding"]
|
||||
|
||||
rec_result = inference_sv_pipline(
|
||||
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_different.wav')
|
||||
different = rec_result["spk_embedding"]
|
||||
|
||||
# calculate cosine similarity for same speaker
|
||||
sv_threshold = 0.9465
|
||||
same_cos = np.sum(enroll * same) / (np.linalg.norm(enroll) * np.linalg.norm(same))
|
||||
same_cos = max(same_cos - sv_threshold, 0.0) / (1.0 - sv_threshold) * 100.0
|
||||
print("Similarity:", same_cos)
|
||||
|
||||
# calculate cosine similarity for different speaker
|
||||
diff_cos = np.sum(enroll * different) / (np.linalg.norm(enroll) * np.linalg.norm(different))
|
||||
diff_cos = max(diff_cos - sv_threshold, 0.0) / (1.0 - sv_threshold) * 100.0
|
||||
print("Similarity:", diff_cos)
|
||||
@ -0,0 +1,21 @@
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
if __name__ == '__main__':
|
||||
inference_sv_pipline = pipeline(
|
||||
task=Tasks.speaker_verification,
|
||||
model='speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch'
|
||||
)
|
||||
|
||||
# the same speaker
|
||||
rec_result = inference_sv_pipline(audio_in=(
|
||||
'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav',
|
||||
'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_same.wav'))
|
||||
print("Similarity", rec_result["scores"])
|
||||
|
||||
# different speakers
|
||||
rec_result = inference_sv_pipline(audio_in=(
|
||||
'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav',
|
||||
'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_different.wav'))
|
||||
|
||||
print("Similarity", rec_result["scores"])
|
||||
@ -12,20 +12,20 @@ if __name__ == '__main__':
|
||||
# for url use "utt_id" as key
|
||||
rec_result = inference_sv_pipline(
|
||||
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav')
|
||||
enroll = rec_result["utt_id"]
|
||||
enroll = rec_result["spk_embedding"]
|
||||
|
||||
# for local file use "utt_id" as key
|
||||
rec_result = inference_sv_pipline(audio_in='sv_example_same.wav')["test1"]
|
||||
same = rec_result["test1"]
|
||||
same = rec_result["spk_embedding"]
|
||||
|
||||
import soundfile
|
||||
wav = soundfile.read('sv_example_enroll.wav')[0]
|
||||
# for raw inputs use "utt_id" as key
|
||||
spk_embedding = inference_sv_pipline(audio_in=wav)["utt_id"]
|
||||
spk_embedding = inference_sv_pipline(audio_in=wav)["spk_embedding"]
|
||||
|
||||
rec_result = inference_sv_pipline(
|
||||
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_different.wav')
|
||||
different = rec_result["utt_id"]
|
||||
different = rec_result["spk_embedding"]
|
||||
|
||||
# 对相同的说话人计算余弦相似度
|
||||
sv_threshold = 0.9465
|
||||
|
||||
@ -387,7 +387,6 @@ class ResNet34_SP_L2Reg(AbsEncoder):
|
||||
return var_dict_torch_update
|
||||
|
||||
|
||||
|
||||
class ResNet34Diar(ResNet34):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@ -1,14 +1,18 @@
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
from typing import Collection
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import yaml
|
||||
from typeguard import check_argument_types
|
||||
from typeguard import check_return_type
|
||||
|
||||
@ -21,7 +25,7 @@ from funasr.models.e2e_asr import ESPnetASRModel
|
||||
from funasr.models.decoder.abs_decoder import AbsDecoder
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
from funasr.models.encoder.rnn_encoder import RNNEncoder
|
||||
from funasr.models.encoder.resnet34_encoder import ResNet34
|
||||
from funasr.models.encoder.resnet34_encoder import ResNet34, ResNet34_SP_L2Reg
|
||||
from funasr.models.pooling.statistic_pooling import StatisticPooling
|
||||
from funasr.models.decoder.sv_decoder import DenseDecoder
|
||||
from funasr.models.e2e_sv import ESPnetSVModel
|
||||
@ -103,6 +107,7 @@ encoder_choices = ClassChoices(
|
||||
"encoder",
|
||||
classes=dict(
|
||||
resnet34=ResNet34,
|
||||
resnet34_sp_l2reg=ResNet34_SP_L2Reg,
|
||||
rnn=RNNEncoder,
|
||||
),
|
||||
type_check=AbsEncoder,
|
||||
@ -394,9 +399,16 @@ class SVTask(AbsTask):
|
||||
|
||||
# 7. Pooling layer
|
||||
pooling_class = pooling_choices.get_class(args.pooling_type)
|
||||
pooling_dim = (2, 3)
|
||||
eps = 1e-12
|
||||
if hasattr(args, "pooling_type_conf"):
|
||||
if "pooling_dim" in args.pooling_type_conf:
|
||||
pooling_dim = args.pooling_type_conf["pooling_dim"]
|
||||
if "eps" in args.pooling_type_conf:
|
||||
eps = args.pooling_type_conf["eps"]
|
||||
pooling_layer = pooling_class(
|
||||
pooling_dim=(2, 3),
|
||||
eps=1e-12,
|
||||
pooling_dim=pooling_dim,
|
||||
eps=eps,
|
||||
)
|
||||
if args.pooling_type == "statistic":
|
||||
encoder_output_size *= 2
|
||||
@ -435,3 +447,95 @@ class SVTask(AbsTask):
|
||||
|
||||
assert check_return_type(model)
|
||||
return model
|
||||
|
||||
# ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
|
||||
@classmethod
|
||||
def build_model_from_file(
|
||||
cls,
|
||||
config_file: Union[Path, str] = None,
|
||||
model_file: Union[Path, str] = None,
|
||||
cmvn_file: Union[Path, str] = None,
|
||||
device: str = "cpu",
|
||||
):
|
||||
"""Build model from the files.
|
||||
|
||||
This method is used for inference or fine-tuning.
|
||||
|
||||
Args:
|
||||
config_file: The yaml file saved when training.
|
||||
model_file: The model file saved when training.
|
||||
cmvn_file: The cmvn file for front-end
|
||||
device: Device type, "cpu", "cuda", or "cuda:N".
|
||||
|
||||
"""
|
||||
assert check_argument_types()
|
||||
if config_file is None:
|
||||
assert model_file is not None, (
|
||||
"The argument 'model_file' must be provided "
|
||||
"if the argument 'config_file' is not specified."
|
||||
)
|
||||
config_file = Path(model_file).parent / "config.yaml"
|
||||
else:
|
||||
config_file = Path(config_file)
|
||||
|
||||
with config_file.open("r", encoding="utf-8") as f:
|
||||
args = yaml.safe_load(f)
|
||||
if cmvn_file is not None:
|
||||
args["cmvn_file"] = cmvn_file
|
||||
args = argparse.Namespace(**args)
|
||||
model = cls.build_model(args)
|
||||
if not isinstance(model, AbsESPnetModel):
|
||||
raise RuntimeError(
|
||||
f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
|
||||
)
|
||||
model.to(device)
|
||||
model_dict = dict()
|
||||
model_name_pth = None
|
||||
if model_file is not None:
|
||||
logging.info("model_file is {}".format(model_file))
|
||||
if device == "cuda":
|
||||
device = f"cuda:{torch.cuda.current_device()}"
|
||||
model_dir = os.path.dirname(model_file)
|
||||
model_name = os.path.basename(model_file)
|
||||
if "model.ckpt-" in model_name or ".bin" in model_name:
|
||||
if ".bin" in model_name:
|
||||
model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb'))
|
||||
else:
|
||||
model_name_pth = os.path.join(model_dir, "{}.pth".format(model_name))
|
||||
if os.path.exists(model_name_pth):
|
||||
logging.info("model_file is load from pth: {}".format(model_name_pth))
|
||||
model_dict = torch.load(model_name_pth, map_location=device)
|
||||
else:
|
||||
model_dict = cls.convert_tf2torch(model, model_file)
|
||||
model.load_state_dict(model_dict)
|
||||
else:
|
||||
model_dict = torch.load(model_file, map_location=device)
|
||||
model.load_state_dict(model_dict)
|
||||
if model_name_pth is not None and not os.path.exists(model_name_pth):
|
||||
torch.save(model_dict, model_name_pth)
|
||||
logging.info("model_file is saved to pth: {}".format(model_name_pth))
|
||||
|
||||
return model, args
|
||||
|
||||
@classmethod
|
||||
def convert_tf2torch(
|
||||
cls,
|
||||
model,
|
||||
ckpt,
|
||||
):
|
||||
logging.info("start convert tf model to torch model")
|
||||
from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
|
||||
var_dict_tf = load_tf_dict(ckpt)
|
||||
var_dict_torch = model.state_dict()
|
||||
var_dict_torch_update = dict()
|
||||
# speech encoder
|
||||
var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
|
||||
var_dict_torch_update.update(var_dict_torch_update_local)
|
||||
# pooling layer
|
||||
var_dict_torch_update_local = model.pooling_layer.convert_tf2torch(var_dict_tf, var_dict_torch)
|
||||
var_dict_torch_update.update(var_dict_torch_update_local)
|
||||
# decoder
|
||||
var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
|
||||
var_dict_torch_update.update(var_dict_torch_update_local)
|
||||
|
||||
return var_dict_torch_update
|
||||
|
||||
Loading…
Reference in New Issue
Block a user