Merge pull request #207 from alibaba-damo-academy/dev_dzh

Dev dzh
This commit is contained in:
zhifu gao 2023-03-10 18:24:39 +08:00 committed by GitHub
commit 9be8a443d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 6004 additions and 22 deletions

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,97 @@
from funasr.bin.diar_inference_launch import inference_launch
import os
def test_fbank_cpu_infer():
    """Run SOND diarization on CPU from precomputed fbank features (kaldi ark)."""
    inputs = [
        ("data/unit_test/test_feats.scp", "speech", "kaldi_ark"),
        ("data/unit_test/test_profile.scp", "profile", "kaldi_ark"),
    ]
    pipe = inference_launch(
        mode="sond",
        diar_train_config="sond_fbank.yaml",
        diar_model_file="sond.pth",
        output_dir="./outputs",
        num_workers=0,
        log_level="INFO",
    )
    print(pipe(inputs))
def test_fbank_gpu_infer():
    """Run SOND diarization on one GPU from precomputed fbank features (kaldi ark)."""
    inputs = [
        ("data/unit_test/test_feats.scp", "speech", "kaldi_ark"),
        ("data/unit_test/test_profile.scp", "profile", "kaldi_ark"),
    ]
    pipe = inference_launch(
        mode="sond",
        diar_train_config="sond_fbank.yaml",
        diar_model_file="sond.pth",
        output_dir="./outputs",
        ngpu=1,
        num_workers=1,
        log_level="INFO",
    )
    print(pipe(inputs))
def test_wav_gpu_infer():
    """Run SOND diarization on one GPU directly from wav files (features computed on the fly)."""
    inputs = [
        ("data/unit_test/test_wav.scp", "speech", "sound"),
        ("data/unit_test/test_profile.scp", "profile", "kaldi_ark"),
    ]
    pipe = inference_launch(
        mode="sond",
        diar_train_config="config.yaml",
        diar_model_file="sond.pth",
        output_dir="./outputs",
        ngpu=1,
        num_workers=1,
        log_level="WARNING",
    )
    print(pipe(inputs))
def test_without_profile_gpu_infer():
    """Demo mode: diarize a recording given raw speaker wavs instead of a precomputed profile."""
    # One record plus the enrollment wavs of its candidate speakers.
    raw_inputs = [[
        "data/unit_test/raw_inputs/record.wav",
        "data/unit_test/raw_inputs/spk1.wav",
        "data/unit_test/raw_inputs/spk2.wav",
        "data/unit_test/raw_inputs/spk3.wav",
        "data/unit_test/raw_inputs/spk4.wav"
    ]]
    pipe = inference_launch(
        mode="sond_demo",
        diar_train_config="config.yaml",
        diar_model_file="sond.pth",
        output_dir="./outputs",
        ngpu=1,
        num_workers=1,
        log_level="WARNING",
        param_dict={},
    )
    print(pipe(raw_inputs=raw_inputs))
if __name__ == '__main__':
    # Restrict CUDA to device 7; only relevant for the GPU test variants below.
    os.environ["CUDA_VISIBLE_DEVICES"] = "7"
    test_fbank_cpu_infer()
    # Uncomment one of these to exercise GPU / raw-wav / profile-free inference:
    # test_fbank_gpu_infer()
    # test_wav_gpu_infer()
    # test_without_profile_gpu_infer()

View File

@ -0,0 +1,39 @@
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
import numpy as np
if __name__ == '__main__':
    # Speaker-verification pipeline; each call returns a dict whose
    # "spk_embedding" entry holds the speaker embedding.
    sv_pipeline = pipeline(
        task=Tasks.speaker_verification,
        model='damo/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch'
    )

    # Embedding from a URL input.
    enroll_emb = sv_pipeline(
        audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/sv_example_enroll.wav')["spk_embedding"]
    # Embedding from a local file input.
    same_emb = sv_pipeline(audio_in='example/sv_example_same.wav')["spk_embedding"]
    # Embedding from a raw waveform input.
    import soundfile
    wav = soundfile.read('example/sv_example_enroll.wav')[0]
    spk_embedding = sv_pipeline(audio_in=wav)["spk_embedding"]
    diff_emb = sv_pipeline(
        audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/sv_example_different.wav')["spk_embedding"]

    # Cosine similarity, rescaled so that scores at/below the threshold map
    # to 0 and a perfect match maps to 100.
    sv_threshold = 0.80

    def scaled_cosine(a, b):
        cos = np.sum(a * b) / (np.linalg.norm(a) * np.linalg.norm(b))
        return max(cos - sv_threshold, 0.0) / (1.0 - sv_threshold) * 100.0

    # Same speaker, then different speakers.
    print("Similarity:", scaled_cosine(enroll_emb, same_emb))
    print("Similarity:", scaled_cosine(enroll_emb, diff_emb))

View File

@ -0,0 +1,21 @@
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
if __name__ == '__main__':
    sv_pipeline = pipeline(
        task=Tasks.speaker_verification,
        model='damo/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch'
    )
    enroll_url = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/sv_example_enroll.wav'
    # Pairwise verification: pass an (enroll, test) tuple and read the "scores"
    # key. First pair is the same speaker, second pair is different speakers.
    for test_url in (
        'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/sv_example_same.wav',
        'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/sv_example_different.wav',
    ):
        rec_result = sv_pipeline(audio_in=(enroll_url, test_url))
        print("Similarity", rec_result["scores"])

View File

@ -9,14 +9,20 @@ if __name__ == '__main__':
)
# 提取不同句子的说话人嵌入码
# for url use "utt_id" as key
rec_result = inference_sv_pipline(
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav')
enroll = rec_result["spk_embedding"]
rec_result = inference_sv_pipline(
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_same.wav')
# for local file use "utt_id" as key
rec_result = inference_sv_pipline(audio_in='sv_example_same.wav')["test1"]
same = rec_result["spk_embedding"]
import soundfile
wav = soundfile.read('sv_example_enroll.wav')[0]
# for raw inputs use "utt_id" as key
spk_embedding = inference_sv_pipline(audio_in=wav)["spk_embedding"]
rec_result = inference_sv_pipline(
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_different.wav')
different = rec_result["spk_embedding"]

View File

@ -342,6 +342,7 @@ class DiarSondModel(AbsESPnetModel):
if isinstance(self.ci_scorer, AbsEncoder):
ci_simi = self.ci_scorer(ge_in, ge_len)[0]
ci_simi = torch.reshape(ci_simi, [bb, self.max_spk_num, tt]).permute([0, 2, 1])
else:
ci_simi = self.ci_scorer(speech_encoder_outputs, speaker_encoder_outputs)

View File

@ -137,12 +137,12 @@ class ConvEncoder(AbsEncoder):
self.out_padding = nn.ConstantPad1d((left_padding, right_padding), 0.0)
self.conv_out = nn.Conv1d(
num_units,
num_units,
out_units,
kernel_size,
)
if self.out_norm:
self.after_norm = LayerNorm(num_units)
self.after_norm = LayerNorm(out_units)
def output_size(self) -> int:
    """Return the encoder's output feature dimension."""
    # NOTE(review): the diff above changes conv_out to project to `out_units`
    # (and after_norm to LayerNorm(out_units)), but this still reports
    # `num_units`; if out_units != num_units callers would get a stale size.
    # Confirm whether this should return the conv_out width instead.
    return self.num_units

View File

@ -272,7 +272,7 @@ class SelfAttentionEncoder(AbsEncoder):
position embedded tensor and mask
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
xs_pad *= self.output_size()**0.5
xs_pad = xs_pad * self.output_size()**0.5
if self.embed is None:
xs_pad = xs_pad
elif (

View File

@ -387,7 +387,6 @@ class ResNet34_SP_L2Reg(AbsEncoder):
return var_dict_torch_update
class ResNet34Diar(ResNet34):
def __init__(
self,
@ -613,3 +612,230 @@ class ResNet34Diar(ResNet34):
logging.warning("{} is missed from tf checkpoint".format(name))
return var_dict_torch_update
class ResNet34SpL2RegDiar(ResNet34_SP_L2Reg):
    """ResNet34 speech encoder (SP + L2-reg variant) adapted for diarization.

    Runs the parent ResNet34_SP_L2Reg trunk, applies (windowed) statistic
    pooling over time, then two dense+BN bottleneck layers; ``forward``
    returns the intermediate tensor named by ``embedding_node``. Also
    provides helpers to convert a TensorFlow checkpoint into a torch
    state dict.
    """

    def __init__(
        self,
        input_size,
        embedding_node="resnet1_dense",
        use_head_conv=True,
        batchnorm_momentum=0.5,
        use_head_maxpool=False,
        num_nodes_pooling_layer=256,
        layers_in_block=(3, 4, 6, 3),
        filters_in_block=(32, 64, 128, 256),
        num_nodes_resnet1=256,
        num_nodes_last_layer=256,
        pooling_type="window_shift",
        pool_size=20,
        stride=1,
        tf2torch_tensor_name_prefix_torch="encoder",
        tf2torch_tensor_name_prefix_tf="seq2seq/speech_encoder"
    ):
        """Initialize the encoder.

        Args:
            input_size: Input feature dimension (forwarded to the parent trunk).
            embedding_node: Name of the endpoint ``forward`` returns
                (e.g. "resnet1_dense", "resnet2_bn", "pooling").
            use_head_conv / batchnorm_momentum / use_head_maxpool /
            num_nodes_pooling_layer / layers_in_block / filters_in_block:
                Parent ResNet34_SP_L2Reg trunk configuration.
            num_nodes_resnet1: Width of the first bottleneck (resnet1_*).
            num_nodes_last_layer: Width of the second bottleneck (resnet2_*).
            pooling_type: "frame_gsp" for a single global statistic pooling,
                anything else for windowed (shifted-window) pooling.
            pool_size: Window length for windowed statistic pooling.
            stride: Window shift for windowed statistic pooling.
            tf2torch_tensor_name_prefix_torch: Torch-side name prefix used
                when converting a tf checkpoint.
            tf2torch_tensor_name_prefix_tf: Tf-side scope prefix of the same
                tensors in the checkpoint.
        """
        super(ResNet34SpL2RegDiar, self).__init__(
            input_size,
            use_head_conv=use_head_conv,
            batchnorm_momentum=batchnorm_momentum,
            use_head_maxpool=use_head_maxpool,
            num_nodes_pooling_layer=num_nodes_pooling_layer,
            layers_in_block=layers_in_block,
            filters_in_block=filters_in_block,
        )
        self.embedding_node = embedding_node
        self.num_nodes_resnet1 = num_nodes_resnet1
        self.num_nodes_last_layer = num_nodes_last_layer
        self.pooling_type = pooling_type
        self.pool_size = pool_size
        self.stride = stride
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
        # Input width is num_nodes_pooling_layer * 2 — presumably statistic
        # pooling concatenates mean and std over the window; TODO confirm
        # against statistic_pooling's implementation.
        self.resnet1_dense = torch.nn.Linear(num_nodes_pooling_layer * 2, num_nodes_resnet1)
        self.resnet1_bn = torch.nn.BatchNorm1d(num_nodes_resnet1, eps=1e-3, momentum=batchnorm_momentum)
        self.resnet2_dense = torch.nn.Linear(num_nodes_resnet1, num_nodes_last_layer)
        self.resnet2_bn = torch.nn.BatchNorm1d(num_nodes_last_layer, eps=1e-3, momentum=batchnorm_momentum)

    def output_size(self) -> int:
        """Feature dimension of the endpoint selected by ``embedding_node``."""
        if self.embedding_node.startswith("resnet1"):
            return self.num_nodes_resnet1
        elif self.embedding_node.startswith("resnet2"):
            return self.num_nodes_last_layer
        return self.num_nodes_pooling_layer

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Encode padded speech features.

        Args:
            xs_pad: Padded input features (shape as expected by the parent
                trunk — not visible here; TODO confirm).
            ilens: Valid lengths per batch element.
            prev_states: Unused; kept for the AbsEncoder interface.

        Returns:
            Tuple of (endpoint tensor named by ``self.embedding_node``,
            updated lengths, None).
        """
        # Named intermediate tensors, so the caller-selected embedding_node
        # can be looked up after the full stack has run.
        endpoints = OrderedDict()
        res_out, ilens = super().forward(xs_pad, ilens)
        endpoints["resnet0_bn"] = res_out
        if self.pooling_type == "frame_gsp":
            # Single global statistic pooling over the time axis (dim 2).
            features = statistic_pooling(res_out, ilens, (2, ))
        else:
            # Shifted-window statistic pooling keeps a (downsampled) time axis.
            features, ilens = windowed_statistic_pooling(res_out, ilens, (2, ), self.pool_size, self.stride)
        features = features.transpose(1, 2)
        endpoints["pooling"] = features
        features = self.resnet1_dense(features)
        endpoints["resnet1_dense"] = features
        features = F.relu(features)
        endpoints["resnet1_relu"] = features
        # BatchNorm1d expects (B, C, T); transpose around it.
        features = self.resnet1_bn(features.transpose(1, 2)).transpose(1, 2)
        endpoints["resnet1_bn"] = features
        features = self.resnet2_dense(features)
        endpoints["resnet2_dense"] = features
        features = F.relu(features)
        endpoints["resnet2_relu"] = features
        features = self.resnet2_bn(features.transpose(1, 2)).transpose(1, 2)
        endpoints["resnet2_bn"] = features
        return endpoints[self.embedding_node], ilens, None

    def gen_tf2torch_map_dict(self):
        """Build the torch-parameter-name -> tf-tensor mapping for conversion.

        Each entry maps a torch state-dict key either to a dict with the tf
        tensor name plus optional squeeze/transpose post-processing, or (for
        ``num_batches_tracked`` buffers) directly to an integer to assign.
        """
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        # Constant written into every BatchNorm num_batches_tracked buffer;
        # presumably the tf training step count — TODO confirm.
        train_steps = 720000
        map_dict_local = {
            # torch: conv1d.weight in "out_channel in_channel kernel_size"
            # tf   : conv1d.weight in "kernel_size in_channel out_channel"
            # torch: linear.weight in "out_channel in_channel"
            # tf   : dense.weight  in "in_channel out_channel"
            "{}.pre_conv.weight".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (3, 2, 0, 1),
                 },
            "{}.pre_conv_bn.bias".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.weight".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.running_mean".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/moving_mean".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.running_var".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/moving_variance".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.num_batches_tracked".format(tensor_name_prefix_torch): train_steps
        }
        # resnet{0,1,2}_dense / _bn bottleneck layers. Only resnet1_*/resnet2_*
        # are defined in this class; resnet0_* presumably comes from the parent
        # trunk — TODO confirm. layer 0's dense kernel is 3-D in tf (conv-like),
        # hence the different transpose.
        for layer_idx in range(3):
            map_dict_local.update({
                "{}.resnet{}_dense.weight".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_dense/kernel".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": (2, 1, 0) if layer_idx == 0 else (1, 0),
                     },
                "{}.resnet{}_dense.bias".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_dense/bias".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.weight".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/gamma".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.bias".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/beta".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.running_mean".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/moving_mean".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.running_var".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/moving_variance".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.num_batches_tracked".format(tensor_name_prefix_torch, layer_idx): train_steps
            })
        # ResNet trunk: every conv/bn in every residual layer of every block
        # ("1", "2" are the two main convs, "_sc" the shortcut projection).
        for block_idx in range(len(self.layers_in_block)):
            for layer_idx in range(self.layers_in_block[block_idx]):
                for i in ["1", "2", "_sc"]:
                    map_dict_local.update({
                        "{}.block_{}.layer_{}.conv{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/conv{}/kernel".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": (3, 2, 0, 1),
                             },
                        "{}.block_{}.layer_{}.bn{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/gamma".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.bias".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/beta".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.running_mean".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/moving_mean".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.running_var".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/moving_variance".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.num_batches_tracked".format(tensor_name_prefix_torch, block_idx, layer_idx, i): train_steps,
                    })
        return map_dict_local

    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):
        """Convert this encoder's tensors from a tf checkpoint dict.

        Args:
            var_dict_tf: Mapping of tf tensor name -> numpy array.
            var_dict_torch: The torch model's state dict (used for names
                and shape checks).

        Returns:
            Dict of torch parameter name -> converted torch tensor, covering
            only keys under ``tf2torch_tensor_name_prefix_torch``.
        """
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            if name.startswith(self.tf2torch_tensor_name_prefix_torch):
                if name in map_dict:
                    if "num_batches_tracked" not in name:
                        name_tf = map_dict[name]["name"]
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                        if map_dict[name]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        # Shape mismatch here means the map dict (or the
                        # checkpoint) is out of sync with the torch model.
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf,
                                                      var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                        ))
                    else:
                        # num_batches_tracked buffers map to a plain int.
                        var_dict_torch_update[name] = torch.from_numpy(np.array(map_dict[name])).type(torch.int64).to("cpu")
                        logging.info("torch tensor: {}, manually assigning to: {}".format(
                            name, map_dict[name]
                        ))
                else:
                    logging.warning("{} is missed from tf checkpoint".format(name))
        return var_dict_torch_update

View File

@ -82,13 +82,16 @@ def windowed_statistic_pooling(
tt = xs_pad.shape[2]
num_chunk = int(math.ceil(tt / pooling_stride))
pad = pooling_size // 2
features = F.pad(xs_pad, (0, 0, pad, pad), "reflect")
if xs_pad.shape == 4:
features = F.pad(xs_pad, (0, 0, pad, pad), "reflect")
else:
features = F.pad(xs_pad, (pad, pad), "reflect")
stat_list = []
for i in range(num_chunk):
# B x C
st, ed = i*pooling_stride, i*pooling_stride+pooling_size
stat = statistic_pooling(features[:, :, st: ed, :], pooling_dim=pooling_dim)
stat = statistic_pooling(features[:, :, st: ed], pooling_dim=pooling_dim)
stat_list.append(stat.unsqueeze(2))
# B x C x T

View File

@ -23,7 +23,7 @@ from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.layers.label_aggregation import LabelAggregate
from funasr.models.ctc import CTC
from funasr.models.encoder.resnet34_encoder import ResNet34Diar
from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
@ -122,6 +122,7 @@ encoder_choices = ClassChoices(
fsmn=FsmnEncoder,
conv=ConvEncoder,
resnet34=ResNet34Diar,
resnet34_sp_l2reg=ResNet34SpL2RegDiar,
sanm_chunk_opt=SANMEncoderChunkOpt,
data2vec_encoder=Data2VecEncoder,
ecapa_tdnn=ECAPA_TDNN,
@ -160,6 +161,7 @@ ci_scorer_choices = ClassChoices(
classes=dict(
dot=DotScorer,
cosine=CosScorer,
conv=ConvEncoder,
),
type_check=torch.nn.Module,
default=None,
@ -571,19 +573,24 @@ class DiarTask(AbsTask):
var_dict_torch = model.state_dict()
var_dict_torch_update = dict()
# speech encoder
var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
if model.encoder is not None:
var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
# speaker encoder
var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
if model.speaker_encoder is not None:
var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
# cd scorer
var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
if model.cd_scorer is not None:
var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
# ci scorer
var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
if model.ci_scorer is not None:
var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
# decoder
var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
if model.decoder is not None:
var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
return var_dict_torch_update

View File

@ -1,14 +1,18 @@
import argparse
import logging
import os
from pathlib import Path
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import numpy as np
import torch
import yaml
from typeguard import check_argument_types
from typeguard import check_return_type
@ -21,7 +25,7 @@ from funasr.models.e2e_asr import ESPnetASRModel
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.resnet34_encoder import ResNet34
from funasr.models.encoder.resnet34_encoder import ResNet34, ResNet34_SP_L2Reg
from funasr.models.pooling.statistic_pooling import StatisticPooling
from funasr.models.decoder.sv_decoder import DenseDecoder
from funasr.models.e2e_sv import ESPnetSVModel
@ -103,6 +107,7 @@ encoder_choices = ClassChoices(
"encoder",
classes=dict(
resnet34=ResNet34,
resnet34_sp_l2reg=ResNet34_SP_L2Reg,
rnn=RNNEncoder,
),
type_check=AbsEncoder,
@ -394,9 +399,16 @@ class SVTask(AbsTask):
# 7. Pooling layer
pooling_class = pooling_choices.get_class(args.pooling_type)
pooling_dim = (2, 3)
eps = 1e-12
if hasattr(args, "pooling_type_conf"):
if "pooling_dim" in args.pooling_type_conf:
pooling_dim = args.pooling_type_conf["pooling_dim"]
if "eps" in args.pooling_type_conf:
eps = args.pooling_type_conf["eps"]
pooling_layer = pooling_class(
pooling_dim=(2, 3),
eps=1e-12,
pooling_dim=pooling_dim,
eps=eps,
)
if args.pooling_type == "statistic":
encoder_output_size *= 2
@ -435,3 +447,95 @@ class SVTask(AbsTask):
assert check_return_type(model)
return model
# ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
@classmethod
def build_model_from_file(
    cls,
    config_file: Union[Path, str] = None,
    model_file: Union[Path, str] = None,
    cmvn_file: Union[Path, str] = None,
    device: str = "cpu",
):
    """Build an SV model and load its weights from saved files.

    Used for inference or fine-tuning. If ``model_file`` points at a tf
    checkpoint ("model.ckpt-*") or a ".bin" file, the weights are converted
    via ``convert_tf2torch`` and the converted state dict is cached next to
    the original file; otherwise it is loaded directly with ``torch.load``.

    Args:
        config_file: The yaml file saved when training. Defaults to
            "config.yaml" next to ``model_file``.
        model_file: The model file saved when training.
        cmvn_file: The cmvn file for the front-end (overrides the config).
        device: Device type, "cpu", "cuda", or "cuda:N".

    Returns:
        Tuple of (model moved to ``device``, parsed config Namespace).
    """
    assert check_argument_types()
    if config_file is None:
        assert model_file is not None, (
            "The argument 'model_file' must be provided "
            "if the argument 'config_file' is not specified."
        )
        # Fall back to the config saved alongside the model.
        config_file = Path(model_file).parent / "config.yaml"
    else:
        config_file = Path(config_file)
    with config_file.open("r", encoding="utf-8") as f:
        args = yaml.safe_load(f)
    if cmvn_file is not None:
        args["cmvn_file"] = cmvn_file
    args = argparse.Namespace(**args)
    model = cls.build_model(args)
    if not isinstance(model, AbsESPnetModel):
        raise RuntimeError(
            f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
        )
    model.to(device)
    model_dict = dict()
    model_name_pth = None
    if model_file is not None:
        logging.info("model_file is {}".format(model_file))
        # NOTE(review): `device` is made explicit ("cuda:N") only AFTER
        # model.to(device) above — confirm whether the model is intended to
        # land on the current CUDA device in that case.
        if device == "cuda":
            device = f"cuda:{torch.cuda.current_device()}"
        model_dir = os.path.dirname(model_file)
        model_name = os.path.basename(model_file)
        # tf checkpoints / .bin files need conversion (with a cached copy).
        if "model.ckpt-" in model_name or ".bin" in model_name:
            if ".bin" in model_name:
                # NOTE(review): caches the converted torch state dict under a
                # ".pb" name — ".pb" usually denotes a tf graph; confirm the
                # intended extension.
                model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb'))
            else:
                model_name_pth = os.path.join(model_dir, "{}.pth".format(model_name))
            if os.path.exists(model_name_pth):
                # Reuse a previously converted-and-cached state dict.
                logging.info("model_file is load from pth: {}".format(model_name_pth))
                model_dict = torch.load(model_name_pth, map_location=device)
            else:
                model_dict = cls.convert_tf2torch(model, model_file)
            model.load_state_dict(model_dict)
        else:
            model_dict = torch.load(model_file, map_location=device)
            model.load_state_dict(model_dict)
    # Persist the converted weights so the next load skips conversion.
    if model_name_pth is not None and not os.path.exists(model_name_pth):
        torch.save(model_dict, model_name_pth)
        logging.info("model_file is saved to pth: {}".format(model_name_pth))
    return model, args
@classmethod
def convert_tf2torch(
    cls,
    model,
    ckpt,
):
    """Convert a TensorFlow checkpoint into a torch state dict for this model.

    Delegates to each sub-module's own ``convert_tf2torch`` (speech encoder,
    pooling layer, decoder) and merges their partial state dicts.
    """
    logging.info("start convert tf model to torch model")
    from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
    var_dict_tf = load_tf_dict(ckpt)
    var_dict_torch = model.state_dict()
    converted = dict()
    # Order matches the model's structure: encoder, then pooling, then decoder.
    for component in (model.encoder, model.pooling_layer, model.decoder):
        converted.update(component.convert_tf2torch(var_dict_tf, var_dict_torch))
    return converted