Merge pull request #95 from alibaba-damo-academy/dev_dzh

Add SOND model
zhifu gao 2023-02-10 19:32:39 +08:00 committed by GitHub
commit 60aef2aa96
26 changed files with 9488 additions and 22 deletions


@@ -0,0 +1,6 @@
# Results
You should obtain a DER of about 4.21%, as reported in [1], Table 6, row "SOND Oracle Profile".
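To reproduce this number, the recipe's inference entry point can be invoked directly. A minimal sketch (config, checkpoint, and scp paths follow the defaults used in `infer_alimeeting_test.py` and `run.sh` in this recipe):

```python
from funasr.bin.diar_inference_launch import inference_launch

# Build the SOND inference pipeline (file names follow run.sh stage 0)
pipeline = inference_launch(
    mode="sond",
    diar_train_config="sond_fbank.yaml",
    diar_model_file="sond.pth",
    output_dir="./outputs",
    num_workers=1,
)
# Kaldi-style feature and speaker-profile lists prepared by the recipe
pipeline([
    ("data/test_rmsil/feats.scp", "speech", "kaldi_ark"),
    ("data/test_rmsil/test_rmsil_tdnn6_xvec.scp", "profile", "kaldi_ark"),
])
```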
# Reference
[1] Zhihao Du, Shiliang Zhang, Siqi Zheng, Zhijie Yan. Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis. EMNLP 2022.

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -0,0 +1,24 @@
from funasr.bin.diar_inference_launch import inference_launch
import sys
def main():
diar_config_path = sys.argv[1] if len(sys.argv) > 1 else "sond_fbank.yaml"
diar_model_path = sys.argv[2] if len(sys.argv) > 2 else "sond.pth"
output_dir = sys.argv[3] if len(sys.argv) > 3 else "./outputs"
data_path_and_name_and_type = [
("data/test_rmsil/feats.scp", "speech", "kaldi_ark"),
("data/test_rmsil/test_rmsil_tdnn6_xvec.scp", "profile", "kaldi_ark"),
]
pipeline = inference_launch(
mode="sond",
diar_train_config=diar_config_path,
diar_model_file=diar_model_path,
output_dir=output_dir,
num_workers=1
)
pipeline(data_path_and_name_and_type)
if __name__ == '__main__':
main()
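# Usage sketch: python infer_alimeeting_test.py [config] [model] [output_dir]
# Positional arguments default to sond_fbank.yaml, sond.pth and ./outputs.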


@@ -0,0 +1,132 @@
import os
from funasr.utils.job_runner import MultiProcessRunnerV3
import numpy as np
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
from collections import OrderedDict
from tqdm import tqdm
from scipy.ndimage import median_filter
class MyRunner(MultiProcessRunnerV3):
def prepare(self, parser):
parser.add_argument("label_txt", type=str)
parser.add_argument("map_scp", type=str)
parser.add_argument("out_rttm", type=str)
parser.add_argument("--n_spk", type=int, default=4)
parser.add_argument("--chunk_len", type=int, default=1600)
parser.add_argument("--shift_len", type=int, default=400)
parser.add_argument("--ignore_len", type=int, default=5)
parser.add_argument("--smooth_size", type=int, default=7)
parser.add_argument("--vote_prob", type=float, default=0.5)
args = parser.parse_args()
if not os.path.exists(os.path.dirname(args.out_rttm)):
os.makedirs(os.path.dirname(args.out_rttm))
utt2labels = load_scp_as_list(args.label_txt, 'list')
utt2labels = sorted(utt2labels, key=lambda x: x[0])
meeting2map = load_scp_as_dict(args.map_scp)
meeting2labels = OrderedDict()
for utt_id, chunk_label in utt2labels:
mid = utt_id.split("-")[0]
if mid not in meeting2labels:
meeting2labels[mid] = []
meeting2labels[mid].append(chunk_label)
task_list = [(mid, labels, meeting2map[mid]) for mid, labels in meeting2labels.items()]
return task_list, None, args
def post(self, result_list, args):
with open(args.out_rttm, "wt") as fd:
for results in result_list:
fd.writelines(results)
def int2vec(x, vec_dim=8, dtype=int):  # np.int was removed in NumPy 1.24
b = ('{:0' + str(vec_dim) + 'b}').format(x)
# little-endian order: lower bit first
return (np.array(list(b)[::-1]) == '1').astype(dtype)
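# Example: int2vec(5, vec_dim=4) -> array([1, 0, 1, 0]); the little-endian
# binary code of the power-set label 5 marks speakers 1 and 3 as active.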
def seq2arr(seq, vec_dim=8):
return np.vstack([int2vec(int(x), vec_dim) for x in seq])
def sample2ms(sample, sr=16000):
return int(float(sample) / sr * 100)
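# Note: despite its name, sample2ms converts a sample index to 10 ms frames
# (centiseconds), e.g. sample2ms(16000, sr=16000) -> 100.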
def calc_multi_labels(chunk_label_list, chunk_len, shift_len, n_spk, vote_prob=0.5):
n_chunk = len(chunk_label_list)
last_chunk_valid_frame = len(chunk_label_list[-1]) - (chunk_len - shift_len)
n_frame = (n_chunk - 2) * shift_len + chunk_len + last_chunk_valid_frame
multi_labels = np.zeros((n_frame, n_spk), dtype=float)
weight = np.zeros((n_frame, 1), dtype=float)
for i in range(n_chunk):
raw_label = chunk_label_list[i]
for k in range(len(raw_label)):
if raw_label[k] == '<unk>':
raw_label[k] = raw_label[k-1] if k > 0 else '0'
chunk_multi_label = seq2arr(raw_label, n_spk)
chunk_len = chunk_multi_label.shape[0]
multi_labels[i*shift_len:i*shift_len+chunk_len, :] += chunk_multi_label
weight[i*shift_len:i*shift_len+chunk_len, :] += 1
multi_labels = multi_labels / weight # normalizing vote
multi_labels = (multi_labels > vote_prob).astype(int) # voting results
return multi_labels
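# calc_multi_labels overlap-add example (hypothetical sizes): with
# chunk_len=4 and shift_len=2, frames 2-3 are covered by chunks 0 and 1,
# so their per-speaker votes are averaged before thresholding at vote_prob.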
def calc_spk_turns(label_arr, spk_list):
turn_list = []
length = label_arr.shape[0]
n_spk = label_arr.shape[1]
for k in range(n_spk):
if spk_list[k] == "None":
continue
in_utt = False
start = 0
for i in range(length):
if label_arr[i, k] == 1 and in_utt is False:
start = i
in_utt = True
if label_arr[i, k] == 0 and in_utt is True:
turn_list.append([spk_list[k], start, i - start])
in_utt = False
if in_utt:
turn_list.append([spk_list[k], start, length - start])
return turn_list
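# calc_spk_turns example: an activity column [0, 1, 1, 0] for spk2 yields a
# single turn [spk2, start=1, duration=2] (in frame indices).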
def smooth_multi_labels(multi_label, win_len):
multi_label = median_filter(multi_label, (win_len, 1), mode="constant", cval=0.0).astype(int)
return multi_label
def process(task_args):
_, task_list, _, args = task_args
spk_list = ["spk{}".format(i+1) for i in range(args.n_spk)]
template = "SPEAKER {} 1 {:.2f} {:.2f} <NA> <NA> {} <NA> <NA>\n"
results = []
for mid, chunk_label_list, map_file_path in tqdm(task_list, total=len(task_list), ascii=True, disable=args.no_pbar):
utt2map = load_scp_as_list(map_file_path, 'list')
multi_labels = calc_multi_labels(chunk_label_list, args.chunk_len, args.shift_len, args.n_spk, args.vote_prob)
multi_labels = smooth_multi_labels(multi_labels, args.smooth_size)
org_len = sample2ms(int(utt2map[-1][1][1]), args.sr)
org_multi_labels = np.zeros((org_len, args.n_spk))
for seg_id, [org_st, org_ed, st, ed] in utt2map:
org_st, org_dur = sample2ms(int(org_st), args.sr), sample2ms(int(org_ed) - int(org_st), args.sr)
st, dur = sample2ms(int(st), args.sr), sample2ms(int(ed) - int(st), args.sr)
ll = min(org_multi_labels[org_st: org_st+org_dur, :].shape[0], multi_labels[st: st+dur, :].shape[0])
org_multi_labels[org_st: org_st+ll, :] = multi_labels[st: st+ll, :]
spk_turns = calc_spk_turns(org_multi_labels, spk_list)
spk_turns = sorted(spk_turns, key=lambda x: x[1])
for spk, st, dur in spk_turns:
# TODO: handle the leak of segments at the change points
if dur > args.ignore_len:
results.append(template.format(mid, float(st)/100, float(dur)/100, spk))
return results
if __name__ == '__main__':
my_runner = MyRunner(process)
my_runner.run()
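# Example invocation (mirrors run.sh stage 1):
#   python local/convert_label_to_rttm.py outputs/labels.txt \
#       data/test_rmsil/raw_rmsil_map.scp outputs/prediction_sm_83.rttm \
#       --ignore_len 10 --no_pbar --smooth_size 83 --vote_prob 0.5 --n_spk 16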


@@ -0,0 +1,5 @@
export FUNASR_DIR=$PWD/../../..
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PATH=$FUNASR_DIR/funasr/bin:$PATH


@@ -0,0 +1,48 @@
#!/bin/bash
. ./path.sh || exit 1;
stage=0
stop_stage=2
. utils/parse_options.sh || exit 1;
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "Downloading AliMeeting test set data..."
wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/alimeeting_test_data_for_sond.tar.gz
echo "Done. Extracting data..."
tar zxf alimeeting_test_data_for_sond.tar.gz
echo "Done."
echo "Downloading Pre-trained model..."
git clone https://www.modelscope.cn/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch.git
git clone https://www.modelscope.cn/damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch.git
ln -s speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth ./sv.pth
cp speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.yaml ./sv.yaml
ln -s speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/sond.pth ./sond.pth
cp speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/sond_fbank.yaml ./sond_fbank.yaml
cp speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/sond.yaml ./sond.yaml
echo "Done."
echo "Downloading dscore for scoring..."
git clone https://github.com/nryant/dscore.git
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "Calculating diarization results..."
python infer_alimeeting_test.py sond_fbank.yaml sond.pth outputs
python local/convert_label_to_rttm.py \
outputs/labels.txt \
data/test_rmsil/raw_rmsil_map.scp \
outputs/prediction_sm_83.rttm \
--ignore_len 10 --no_pbar --smooth_size 83 \
--vote_prob 0.5 --n_spk 16
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "Scoring..."
python dscore/score.py \
-r data/test_rmsil/test_org.crttm \
-s outputs/prediction_sm_83.rttm \
--collar 0.25
fi
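# With the released model, scoring should reproduce a DER of about 4.21%
# (see the README: [1], Table 6, row "SOND Oracle Profile").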


@@ -0,0 +1,97 @@
from funasr.bin.diar_inference_launch import inference_launch
import os
def test_fbank_cpu_infer():
diar_config_path = "config_fbank.yaml"
diar_model_path = "sond.pth"
output_dir = "./outputs"
data_path_and_name_and_type = [
("data/unit_test/test_feats.scp", "speech", "kaldi_ark"),
("data/unit_test/test_profile.scp", "profile", "kaldi_ark"),
]
pipeline = inference_launch(
mode="sond",
diar_train_config=diar_config_path,
diar_model_file=diar_model_path,
output_dir=output_dir,
num_workers=1,
log_level="WARNING",
)
results = pipeline(data_path_and_name_and_type)
print(results)
def test_fbank_gpu_infer():
diar_config_path = "config_fbank.yaml"
diar_model_path = "sond.pth"
output_dir = "./outputs"
data_path_and_name_and_type = [
("data/unit_test/test_feats.scp", "speech", "kaldi_ark"),
("data/unit_test/test_profile.scp", "profile", "kaldi_ark"),
]
pipeline = inference_launch(
mode="sond",
diar_train_config=diar_config_path,
diar_model_file=diar_model_path,
output_dir=output_dir,
ngpu=1,
num_workers=1,
log_level="WARNING",
)
results = pipeline(data_path_and_name_and_type)
print(results)
def test_wav_gpu_infer():
diar_config_path = "config.yaml"
diar_model_path = "sond.pth"
output_dir = "./outputs"
data_path_and_name_and_type = [
("data/unit_test/test_wav.scp", "speech", "sound"),
("data/unit_test/test_profile.scp", "profile", "kaldi_ark"),
]
pipeline = inference_launch(
mode="sond",
diar_train_config=diar_config_path,
diar_model_file=diar_model_path,
output_dir=output_dir,
ngpu=1,
num_workers=1,
log_level="WARNING",
)
results = pipeline(data_path_and_name_and_type)
print(results)
def test_without_profile_gpu_infer():
diar_config_path = "config.yaml"
diar_model_path = "sond.pth"
output_dir = "./outputs"
raw_inputs = [[
"data/unit_test/raw_inputs/record.wav",
"data/unit_test/raw_inputs/spk1.wav",
"data/unit_test/raw_inputs/spk2.wav",
"data/unit_test/raw_inputs/spk3.wav",
"data/unit_test/raw_inputs/spk4.wav"
]]
pipeline = inference_launch(
mode="sond_demo",
diar_train_config=diar_config_path,
diar_model_file=diar_model_path,
output_dir=output_dir,
ngpu=1,
num_workers=1,
log_level="WARNING",
param_dict={},
)
results = pipeline(raw_inputs=raw_inputs)
print(results)
if __name__ == '__main__':
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
test_fbank_cpu_infer()
test_fbank_gpu_infer()
test_wav_gpu_infer()
test_without_profile_gpu_infer()


@@ -0,0 +1,179 @@
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import argparse
import logging
import os
import sys
from typing import Union, Dict, Any
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
def get_parser():
parser = config_argparse.ArgumentParser(
description="Speaker Verification",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=False)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument(
"--njob",
type=int,
default=1,
help="The number of jobs for each gpu",
)
parser.add_argument(
"--gpuid_list",
type=str,
default="",
help="The visible gpus",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
required=False,
action="append",
)
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=True)
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--vad_infer_config",
type=str,
help="VAD infer configuration",
)
group.add_argument(
"--vad_model_file",
type=str,
help="VAD model parameter file",
)
group.add_argument(
"--diar_train_config",
type=str,
help="ASR training configuration",
)
group.add_argument(
"--diar_model_file",
type=str,
help="ASR model parameter file",
)
group.add_argument(
"--cmvn_file",
type=str,
help="Global CMVN file",
)
group.add_argument(
"--model_tag",
type=str,
help="Pretrained model tag. If specify this option, *_train_config and "
"*_file will be overwritten",
)
group = parser.add_argument_group("The inference configuration related")
group.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
group.add_argument(
"--diar_smooth_size",
type=int,
default=121,
help="The smoothing size for post-processing"
)
return parser
def inference_launch(mode, **kwargs):
if mode == "sond":
from funasr.bin.sond_inference import inference_modelscope
return inference_modelscope(**kwargs)
elif mode == "sond_demo":
from funasr.bin.sond_inference import inference_modelscope
param_dict = {
"extract_profile": True,
"sv_train_config": "sv.yaml",
"sv_model_file": "sv.pth",
}
if "param_dict" in kwargs:
kwargs["param_dict"].update(param_dict)
else:
kwargs["param_dict"] = param_dict
return inference_modelscope(**kwargs)
else:
logging.info("Unknown decoding mode: {}".format(mode))
return None
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
parser.add_argument(
"--mode",
type=str,
default="sond",
help="The decoding mode",
)
args = parser.parse_args(cmd)
kwargs = vars(args)
kwargs.pop("config", None)
# set logging messages
logging.basicConfig(
level=args.log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
logging.info("Decoding args: {}".format(kwargs))
# gpu setting
if args.ngpu > 0:
jobid = int(args.output_dir.split(".")[-1])
gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
inference_launch(**kwargs)
if __name__ == "__main__":
main()
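# Example CLI invocation (assuming str2triple_str parses comma-separated
# "path,name,type" triples):
#   python funasr/bin/diar_inference_launch.py --mode sond \
#       --diar_train_config sond_fbank.yaml --diar_model_file sond.pth \
#       --data_path_and_name_and_type data/test_rmsil/feats.scp,speech,kaldi_ark \
#       --output_dir ./outputs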

funasr/bin/sond_inference.py (executable file, 544 lines)

@@ -0,0 +1,544 @@
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import argparse
import logging
import os
import sys
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from collections import OrderedDict
import numpy as np
import soundfile
import torch
from torch.nn import functional as F
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.utils.cli_utils import get_commandline_args
from funasr.tasks.diar import DiarTask
from funasr.tasks.asr import ASRTask
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from scipy.ndimage import median_filter
from funasr.utils.misc import statistic_model_parameters
class Speech2Diarization:
"""Speech2Xvector class
Examples:
>>> import soundfile
>>> import numpy as np
>>> speech2diar = Speech2Diarization("diar_sond_config.yml", "diar_sond.pth")
>>> profile = np.load("profiles.npy")
>>> audio, rate = soundfile.read("speech.wav")
>>> speech2diar(audio, profile)
{"spk1": [(int, int), ...], ...}
"""
def __init__(
self,
diar_train_config: Union[Path, str] = None,
diar_model_file: Union[Path, str] = None,
device: str = "cpu",
batch_size: int = 1,
dtype: str = "float32",
streaming: bool = False,
smooth_size: int = 83,
dur_threshold: float = 10,
):
assert check_argument_types()
# TODO: 1. Build Diarization model
diar_model, diar_train_args = DiarTask.build_model_from_file(
config_file=diar_train_config,
model_file=diar_model_file,
device=device
)
logging.info("diar_model: {}".format(diar_model))
logging.info("model parameter number: {}".format(statistic_model_parameters(diar_model)))
logging.info("diar_train_args: {}".format(diar_train_args))
diar_model.to(dtype=getattr(torch, dtype)).eval()
self.diar_model = diar_model
self.diar_train_args = diar_train_args
self.token_list = diar_train_args.token_list
self.smooth_size = smooth_size
self.dur_threshold = dur_threshold
self.device = device
self.dtype = dtype
def smooth_multi_labels(self, multi_label):
multi_label = median_filter(multi_label, (self.smooth_size, 1), mode="constant", cval=0.0).astype(int)
return multi_label
@staticmethod
def calc_spk_turns(label_arr, spk_list):
turn_list = []
length = label_arr.shape[0]
n_spk = label_arr.shape[1]
for k in range(n_spk):
if spk_list[k] == "None":
continue
in_utt = False
start = 0
for i in range(length):
if label_arr[i, k] == 1 and in_utt is False:
start = i
in_utt = True
if label_arr[i, k] == 0 and in_utt is True:
turn_list.append([spk_list[k], start, i - start])
in_utt = False
if in_utt:
turn_list.append([spk_list[k], start, length - start])
return turn_list
@staticmethod
def seq2arr(seq, vec_dim=8):
def int2vec(x, vec_dim=8, dtype=int):  # np.int was removed in NumPy 1.24
b = ('{:0' + str(vec_dim) + 'b}').format(x)
# little-endian order: lower bit first
return (np.array(list(b)[::-1]) == '1').astype(dtype)
return np.vstack([int2vec(int(x), vec_dim) for x in seq])
def post_processing(self, raw_logits: torch.Tensor, spk_num: int):
logits_idx = raw_logits.argmax(-1) # B, T, vocab_size -> B, T
# upsampling outputs to match inputs
ut = logits_idx.shape[1] * self.diar_model.encoder.time_ds_ratio
logits_idx = F.interpolate(  # F.upsample is deprecated in favor of F.interpolate
logits_idx.unsqueeze(1).float(),
size=(ut, ),
mode="nearest",
).squeeze(1).long()
logits_idx = logits_idx[0].tolist()
pse_labels = [self.token_list[x] for x in logits_idx]
multi_labels = self.seq2arr(pse_labels, spk_num)[:, :spk_num] # remove padding speakers
multi_labels = self.smooth_multi_labels(multi_labels)
spk_list = ["spk{}".format(i + 1) for i in range(spk_num)]
spk_turns = self.calc_spk_turns(multi_labels, spk_list)
results = OrderedDict()
for spk, st, dur in spk_turns:
if spk not in results:
results[spk] = []
if dur > self.dur_threshold:
results[spk].append((st, st+dur))
# sort segments in start time ascending
for spk in results:
results[spk] = sorted(results[spk], key=lambda x: x[0])
return results, pse_labels
@torch.no_grad()
def __call__(
self,
speech: Union[torch.Tensor, np.ndarray],
profile: Union[torch.Tensor, np.ndarray],
):
"""Inference
Args:
speech: Input speech data
profile: Speaker profiles
Returns:
diarization results for each speaker
"""
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
if isinstance(profile, np.ndarray):
profile = torch.tensor(profile)
# data: (Nsamples,) -> (1, Nsamples)
speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
profile = profile.unsqueeze(0).to(getattr(torch, self.dtype))
# lengths: (1,)
speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
profile_lengths = profile.new_full([1], dtype=torch.long, fill_value=profile.size(1))
batch = {"speech": speech, "speech_lengths": speech_lengths,
"profile": profile, "profile_lengths": profile_lengths}
# a. To device
batch = to_device(batch, device=self.device)
logits = self.diar_model.prediction_forward(**batch)
results, pse_labels = self.post_processing(logits, profile.shape[1])
return results, pse_labels
@staticmethod
def from_pretrained(
model_tag: Optional[str] = None,
**kwargs: Optional[Any],
):
"""Build Speech2Xvector instance from the pretrained model.
Args:
model_tag (Optional[str]): Model tag of the pretrained models.
Currently, the tags of espnet_model_zoo are supported.
Returns:
Speech2Diarization: Speech2Diarization instance.
"""
if model_tag is not None:
try:
from espnet_model_zoo.downloader import ModelDownloader
except ImportError:
logging.error(
"`espnet_model_zoo` is not installed. "
"Please install via `pip install -U espnet_model_zoo`."
)
raise
d = ModelDownloader()
kwargs.update(**d.download_and_unpack(model_tag))
return Speech2Diarization(**kwargs)
def inference_modelscope(
diar_train_config: str,
diar_model_file: str,
output_dir: Optional[str] = None,
batch_size: int = 1,
dtype: str = "float32",
ngpu: int = 0,
seed: int = 0,
num_workers: int = 0,
log_level: Union[int, str] = "INFO",
key_file: Optional[str] = None,
model_tag: Optional[str] = None,
allow_variable_data_keys: bool = True,
streaming: bool = False,
smooth_size: int = 83,
dur_threshold: int = 10,
out_format: str = "vad",
param_dict: Optional[dict] = None,
**kwargs,
):
assert check_argument_types()
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
logging.info("param_dict: {}".format(param_dict))
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
# 1. Set random-seed
set_all_random_seed(seed)
# 2a. Build speech2xvec [Optional]
if param_dict is not None and "extract_profile" in param_dict and param_dict["extract_profile"]:
assert "sv_train_config" in param_dict, "sv_train_config must be provided param_dict."
assert "sv_model_file" in param_dict, "sv_model_file must be provided in param_dict."
sv_train_config = param_dict["sv_train_config"]
sv_model_file = param_dict["sv_model_file"]
from funasr.bin.sv_inference import Speech2Xvector
speech2xvector_kwargs = dict(
sv_train_config=sv_train_config,
sv_model_file=sv_model_file,
device=device,
dtype=dtype,
streaming=streaming,
embedding_node="resnet1_dense"
)
logging.info("speech2xvector_kwargs: {}".format(speech2xvector_kwargs))
speech2xvector = Speech2Xvector.from_pretrained(
model_tag=model_tag,
**speech2xvector_kwargs,
)
speech2xvector.sv_model.eval()
# 2b. Build speech2diar
speech2diar_kwargs = dict(
diar_train_config=diar_train_config,
diar_model_file=diar_model_file,
device=device,
dtype=dtype,
streaming=streaming,
smooth_size=smooth_size,
dur_threshold=dur_threshold,
)
logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
speech2diar = Speech2Diarization.from_pretrained(
model_tag=model_tag,
**speech2diar_kwargs,
)
speech2diar.diar_model.eval()
def output_results_str(results: dict, uttid: str):
rst = []
mid = uttid.rsplit("-", 1)[0]
for key in results:
results[key] = [(x[0]/100, x[1]/100) for x in results[key]]
if out_format == "vad":
for spk, segs in results.items():
rst.append("{} {}".format(spk, segs))
else:
template = "SPEAKER {} 0 {:.2f} {:.2f} <NA> <NA> {} <NA> <NA>"
for spk, segs in results.items():
rst.extend([template.format(mid, st, ed, spk) for st, ed in segs])
return "\n".join(rst)
def _forward(
data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
raw_inputs: List[List[Union[np.ndarray, torch.Tensor, str]]] = None,
output_dir_v2: Optional[str] = None,
param_dict: Optional[dict] = None,
):
logging.info("param_dict: {}".format(param_dict))
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, (list, tuple)):
assert all([len(example) >= 2 for example in raw_inputs]), \
"The length of test case in raw_inputs must larger than 1 (>=2)."
def prepare_dataset():
for idx, example in enumerate(raw_inputs):
# read waveform file
example = [soundfile.read(x)[0] if isinstance(x, str) else x
for x in example]
# convert torch tensor to numpy array
example = [x.numpy() if isinstance(x, torch.Tensor) else x
for x in example]
speech = example[0]
logging.info("Extracting profiles for {} waveforms".format(len(example)-1))
profile = [speech2xvector.calculate_embedding(x) for x in example[1:]]
profile = torch.cat(profile, dim=0)
yield ["test{}".format(idx)], {"speech": [speech], "profile": [profile]}
loader = prepare_dataset()
else:
raise TypeError("raw_inputs must be a list or tuple in [speech, profile1, profile2, ...] ")
else:
# 3. Build data-iterator
loader = ASRTask.build_streaming_iterator(
data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
preprocess_fn=None,
collate_fn=None,
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
# 7. Start for-loop
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
if output_path is not None:
os.makedirs(output_path, exist_ok=True)
output_writer = open("{}/result.txt".format(output_path), "w")
pse_label_writer = open("{}/labels.txt".format(output_path), "w")
logging.info("Start to diarize...")
result_list = []
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
results, pse_labels = speech2diar(**batch)
# Only supporting batch_size==1
key, value = keys[0], output_results_str(results, keys[0])
item = {"key": key, "value": value}
result_list.append(item)
if output_path is not None:
output_writer.write(value + "\n")  # separate successive utterances
output_writer.flush()
pse_label_writer.write("{} {}\n".format(key, " ".join(pse_labels)))
pse_label_writer.flush()
if output_path is not None:
output_writer.close()
pse_label_writer.close()
return result_list
return _forward
def inference(
data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
diar_train_config: Optional[str],
diar_model_file: Optional[str],
output_dir: Optional[str] = None,
batch_size: int = 1,
dtype: str = "float32",
ngpu: int = 0,
seed: int = 0,
num_workers: int = 1,
log_level: Union[int, str] = "INFO",
key_file: Optional[str] = None,
model_tag: Optional[str] = None,
allow_variable_data_keys: bool = True,
streaming: bool = False,
smooth_size: int = 83,
dur_threshold: int = 10,
out_format: str = "vad",
**kwargs,
):
inference_pipeline = inference_modelscope(
diar_train_config=diar_train_config,
diar_model_file=diar_model_file,
output_dir=output_dir,
batch_size=batch_size,
dtype=dtype,
ngpu=ngpu,
seed=seed,
num_workers=num_workers,
log_level=log_level,
key_file=key_file,
model_tag=model_tag,
allow_variable_data_keys=allow_variable_data_keys,
streaming=streaming,
smooth_size=smooth_size,
dur_threshold=dur_threshold,
out_format=out_format,
**kwargs,
)
return inference_pipeline(data_path_and_name_and_type, raw_inputs=None)
def get_parser():
parser = config_argparse.ArgumentParser(
description="Speaker verification/x-vector extraction",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=False)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument(
"--gpuid_list",
type=str,
default="",
help="The visible gpus",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
required=False,
action="append",
)
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--diar_train_config",
type=str,
help="diarization training configuration",
)
group.add_argument(
"--diar_model_file",
type=str,
help="diarization model parameter file",
)
group.add_argument(
"--dur_threshold",
type=int,
default=10,
help="The threshold for short segments in number frames"
)
group.add_argument(
"--smooth_size",
type=int,
default=83,
help="The smoothing window length in number of frames"
)
group.add_argument(
"--model_tag",
type=str,
help="Pretrained model tag. If specify this option, *_train_config and "
"*_file will be overwritten",
)
parser.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
parser.add_argument("--streaming", type=str2bool, default=False)
return parser
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
kwargs = vars(args)
kwargs.pop("config", None)
logging.info("args: {}".format(kwargs))
if args.output_dir is None:
jobid, n_gpu = 1, 1
gpuid = args.gpuid_list.split(",")[jobid-1]
else:
jobid = int(args.output_dir.split(".")[-1])
n_gpu = len(args.gpuid_list.split(","))
gpuid = args.gpuid_list.split(",")[(jobid - 1) % n_gpu]
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
results_list = inference(**kwargs)
for results in results_list:
print("{} {}".format(results["key"], results["value"]))
if __name__ == "__main__":
main()


@@ -1,4 +1,7 @@
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import argparse
import logging
import os
@@ -26,7 +29,7 @@ from funasr.utils import config_argparse
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils.misc import statistic_model_parameters
class Speech2Xvector:
"""Speech2Xvector class
@@ -59,6 +62,7 @@ class Speech2Xvector:
device=device
)
logging.info("sv_model: {}".format(sv_model))
logging.info("model parameter number: {}".format(statistic_model_parameters(sv_model)))
logging.info("sv_train_args: {}".format(sv_train_args))
sv_model.to(dtype=getattr(torch, dtype)).eval()
@@ -156,17 +160,17 @@ class Speech2Xvector:
def inference_modelscope(
output_dir: Optional[str],
batch_size: int,
dtype: str,
ngpu: int,
seed: int,
num_workers: int,
log_level: Union[int, str],
key_file: Optional[str],
sv_train_config: Optional[str],
sv_model_file: Optional[str],
model_tag: Optional[str],
output_dir: Optional[str] = None,
batch_size: int = 1,
dtype: str = "float32",
ngpu: int = 1,
seed: int = 0,
num_workers: int = 0,
log_level: Union[int, str] = "INFO",
key_file: Optional[str] = None,
sv_train_config: Optional[str] = "sv.yaml",
sv_model_file: Optional[str] = "sv.pth",
model_tag: Optional[str] = None,
allow_variable_data_keys: bool = True,
streaming: bool = False,
embedding_node: str = "resnet1_dense",
@@ -214,7 +218,6 @@ def inference_modelscope(
data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
fs: dict = None,
param_dict: Optional[dict] = None,
):
logging.info("param_dict: {}".format(param_dict))


@@ -1,6 +1,6 @@
#!/usr/bin/env python3
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import argparse
import logging


@@ -0,0 +1,402 @@
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
from contextlib import contextmanager
from distutils.version import LooseVersion
from itertools import permutations
from typing import Dict
from typing import Optional
from typing import Tuple
import numpy as np
import torch
from torch.nn import functional as F
from typeguard import check_argument_types
from funasr.modules.nets_utils import to_device
from funasr.modules.nets_utils import make_pad_mask
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
else:
# Nothing to do if torch<1.6.0
@contextmanager
def autocast(enabled=True):
yield
class DiarSondModel(AbsESPnetModel):
"""Speaker overlap-aware neural diarization model
reference: https://arxiv.org/abs/2211.10243
"""
def __init__(
self,
vocab_size: int,
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
encoder: AbsEncoder,
speaker_encoder: AbsEncoder,
ci_scorer: torch.nn.Module,
cd_scorer: torch.nn.Module,
decoder: torch.nn.Module,
token_list: list,
lsm_weight: float = 0.1,
length_normalized_loss: bool = False,
max_spk_num: int = 16,
label_aggregator: Optional[torch.nn.Module] = None,
normalize_speech_speaker: bool = False,
):
assert check_argument_types()
super().__init__()
self.encoder = encoder
self.speaker_encoder = speaker_encoder
self.ci_scorer = ci_scorer
self.cd_scorer = cd_scorer
self.normalize = normalize
self.frontend = frontend
self.specaug = specaug
self.label_aggregator = label_aggregator
self.decoder = decoder
self.token_list = token_list
self.max_spk_num = max_spk_num
self.normalize_speech_speaker = normalize_speech_speaker
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor = None,
profile: torch.Tensor = None,
profile_lengths: torch.Tensor = None,
spk_labels: torch.Tensor = None,
spk_labels_lengths: torch.Tensor = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Speaker Encoder + CI Scorer + CD Scorer + Decoder + Calc loss
Args:
speech: (Batch, samples)
speech_lengths: (Batch,) default None for the chunk iterator, because
the chunk iterator does not return speech_lengths
(see espnet2/iterators/chunk_iter_factory.py).
profile: (Batch, N_spk, dim)
profile_lengths: (Batch,)
spk_labels: (Batch, )
"""
assert speech.shape[0] == spk_labels.shape[0], (speech.shape, spk_labels.shape)
batch_size = speech.shape[0]
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
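# NOTE: the attractor/PIT branch below follows the EEND-EDA recipe and
# assumes self.attractor, self.pit_loss, self.attractor_loss and
# self.attractor_weight exist; this class does not define them in
# __init__ (SOND inference goes through prediction_forward instead).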
if self.attractor is None:
# 2a. Decoder (basically a prediction layer after encoder_out)
pred = self.decoder(encoder_out, encoder_out_lens)
else:
# 2b. Encoder Decoder Attractors
# Shuffle the chronological order of encoder_out, then calculate attractor
encoder_out_shuffled = encoder_out.clone()
for i in range(len(encoder_out_lens)):
encoder_out_shuffled[i, : encoder_out_lens[i], :] = encoder_out[
i, torch.randperm(encoder_out_lens[i]), :
]
attractor, att_prob = self.attractor(
encoder_out_shuffled,
encoder_out_lens,
to_device(
self,
torch.zeros(
encoder_out.size(0), spk_labels.size(2) + 1, encoder_out.size(2)
),
),
)
# Remove the final attractor which does not correspond to a speaker
# Then multiply the attractors and encoder_out
pred = torch.bmm(encoder_out, attractor[:, :-1, :].permute(0, 2, 1))
# 3. Aggregate time-domain labels
if self.label_aggregator is not None:
spk_labels, spk_labels_lengths = self.label_aggregator(
spk_labels, spk_labels_lengths
)
# If encoder uses conv* as input_layer (i.e., subsampling),
# the sequence length of 'pred' might be slightly less than the
# length of 'spk_labels'. Here we force them to be equal.
length_diff_tolerance = 2
length_diff = spk_labels.shape[1] - pred.shape[1]
if length_diff > 0 and length_diff <= length_diff_tolerance:
spk_labels = spk_labels[:, 0 : pred.shape[1], :]
if self.attractor is None:
loss_pit, loss_att = None, None
loss, perm_idx, perm_list, label_perm = self.pit_loss(
pred, spk_labels, encoder_out_lens
)
else:
loss_pit, perm_idx, perm_list, label_perm = self.pit_loss(
pred, spk_labels, encoder_out_lens
)
loss_att = self.attractor_loss(att_prob, spk_labels)
loss = loss_pit + self.attractor_weight * loss_att
(
correct,
num_frames,
speech_scored,
speech_miss,
speech_falarm,
speaker_scored,
speaker_miss,
speaker_falarm,
speaker_error,
) = self.calc_diarization_error(pred, label_perm, encoder_out_lens)
if speech_scored > 0 and num_frames > 0:
sad_mr, sad_fr, mi, fa, cf, acc, der = (
speech_miss / speech_scored,
speech_falarm / speech_scored,
speaker_miss / speaker_scored,
speaker_falarm / speaker_scored,
speaker_error / speaker_scored,
correct / num_frames,
(speaker_miss + speaker_falarm + speaker_error) / speaker_scored,
)
else:
sad_mr, sad_fr, mi, fa, cf, acc, der = 0, 0, 0, 0, 0, 0, 0
stats = dict(
loss=loss.detach(),
loss_att=loss_att.detach() if loss_att is not None else None,
loss_pit=loss_pit.detach() if loss_pit is not None else None,
sad_mr=sad_mr,
sad_fr=sad_fr,
mi=mi,
fa=fa,
cf=cf,
acc=acc,
der=der,
)
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
def collect_feats(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
spk_labels: torch.Tensor = None,
spk_labels_lengths: torch.Tensor = None,
) -> Dict[str, torch.Tensor]:
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
return {"feats": feats, "feats_lengths": feats_lengths}
def encode_speaker(
self,
profile: torch.Tensor,
profile_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
with autocast(False):
if profile.shape[1] < self.max_spk_num:
profile = F.pad(profile, [0, 0, 0, self.max_spk_num-profile.shape[1], 0, 0], "constant", 0.0)
profile_mask = (torch.linalg.norm(profile, ord=2, dim=2, keepdim=True) > 0).float()
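# Profiles with zero L2 norm are padding for absent speakers; the mask is
# re-applied after the speaker encoder to keep those rows zero.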
profile = F.normalize(profile, dim=2)
if self.speaker_encoder is not None:
profile = self.speaker_encoder(profile, profile_lengths)[0]
return profile * profile_mask, profile_lengths
else:
return profile, profile_lengths
def encode_speech(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.encoder is not None:
speech, speech_lengths = self.encode(speech, speech_lengths)
speech_mask = ~make_pad_mask(speech_lengths, maxlen=speech.shape[1])
speech_mask = speech_mask.to(speech.device).unsqueeze(-1).float()
return speech * speech_mask, speech_lengths
else:
return speech, speech_lengths
@staticmethod
def concate_speech_ivc(
speech: torch.Tensor,
ivc: torch.Tensor
) -> torch.Tensor:
nn, tt = ivc.shape[1], speech.shape[1]
speech = speech.unsqueeze(dim=1) # B x 1 x T x D
speech = speech.expand(-1, nn, -1, -1) # B x N x T x D
ivc = ivc.unsqueeze(dim=2) # B x N x 1 x D
ivc = ivc.expand(-1, -1, tt, -1) # B x N x T x D
sd_in = torch.cat([speech, ivc], dim=3) # B x N x T x 2D
return sd_in
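# Shape example (hypothetical sizes): speech (B=2, T=100, D=256) and
# ivc (B=2, N=16, D=256) broadcast to a joint tensor of shape (2, 16, 100, 512).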
def calc_similarity(
self,
speech_encoder_outputs: torch.Tensor,
speaker_encoder_outputs: torch.Tensor,
seq_len: torch.Tensor = None,
spk_len: torch.Tensor = None,
) -> torch.Tensor:
bb, tt = speech_encoder_outputs.shape[0], speech_encoder_outputs.shape[1]
d_sph, d_spk = speech_encoder_outputs.shape[2], speaker_encoder_outputs.shape[2]
if self.normalize_speech_speaker:
speech_encoder_outputs = F.normalize(speech_encoder_outputs, dim=2)
speaker_encoder_outputs = F.normalize(speaker_encoder_outputs, dim=2)
ge_in = self.concate_speech_ivc(speech_encoder_outputs, speaker_encoder_outputs)
ge_in = torch.reshape(ge_in, [bb * self.max_spk_num, tt, d_sph + d_spk])
ge_len = seq_len.unsqueeze(1).expand(-1, self.max_spk_num)
ge_len = torch.reshape(ge_len, [bb * self.max_spk_num])
cd_simi = self.cd_scorer(ge_in, ge_len)[0]
cd_simi = torch.reshape(cd_simi, [bb, self.max_spk_num, tt, 1])
cd_simi = cd_simi.squeeze(dim=3).permute([0, 2, 1])
if isinstance(self.ci_scorer, AbsEncoder):
ci_simi = self.ci_scorer(ge_in, ge_len)[0]
else:
ci_simi = self.ci_scorer(speech_encoder_outputs, speaker_encoder_outputs)
simi = torch.cat([cd_simi, ci_simi], dim=2)
return simi
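# With the DotScorer/CosScorer added in this PR, ci_simi is
# (B, T, max_spk_num), so simi is (B, T, 2 * max_spk_num):
# context-dependent (CD) scores first, context-independent (CI) after.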
def post_net_forward(self, simi, seq_len):
logits = self.decoder(simi, seq_len)[0]
return logits
def prediction_forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
profile: torch.Tensor,
profile_lengths: torch.Tensor,
) -> torch.Tensor:
# speech encoding
speech, speech_lengths = self.encode_speech(speech, speech_lengths)
# speaker encoding
profile, profile_lengths = self.encode_speaker(profile, profile_lengths)
# calculating similarity
similarity = self.calc_similarity(speech, profile, speech_lengths, profile_lengths)
# post net forward
logits = self.post_net_forward(similarity, speech_lengths)
return logits
def encode(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch,)
"""
with autocast(False):
# 1. Extract feats
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
# 2. Data augmentation
if self.specaug is not None and self.training:
feats, feats_lengths = self.specaug(feats, feats_lengths)
# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
feats, feats_lengths = self.normalize(feats, feats_lengths)
# 4. Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim)
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
assert encoder_out.size(0) == speech.size(0), (
encoder_out.size(),
speech.size(0),
)
assert encoder_out.size(1) <= encoder_out_lens.max(), (
encoder_out.size(),
encoder_out_lens.max(),
)
return encoder_out, encoder_out_lens
def _extract_feats(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = speech.shape[0]
speech_lengths = (
speech_lengths
if speech_lengths is not None
else torch.ones(batch_size).int() * speech.shape[1]
)
assert speech_lengths.dim() == 1, speech_lengths.shape
# for data-parallel
speech = speech[:, : speech_lengths.max()]
if self.frontend is not None:
# Frontend
# e.g. STFT and Feature extract
# data_loader may send time-domain signal in this case
# speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
feats, feats_lengths = self.frontend(speech, speech_lengths)
else:
# No frontend and no feature extract
feats, feats_lengths = speech, speech_lengths
return feats, feats_lengths
@staticmethod
def calc_diarization_error(pred, label, length):
# Note (jiatong): Credit to https://github.com/hitachi-speech/EEND
(batch_size, max_len, num_output) = label.size()
# mask the padding part
mask = np.zeros((batch_size, max_len, num_output))
for i in range(batch_size):
mask[i, : length[i], :] = 1
# pred and label have the shape (batch_size, max_len, num_output)
label_np = label.data.cpu().numpy().astype(int)
pred_np = (pred.data.cpu().numpy() > 0).astype(int)
label_np = label_np * mask
pred_np = pred_np * mask
length = length.data.cpu().numpy()
# compute speech activity detection error
n_ref = np.sum(label_np, axis=2)
n_sys = np.sum(pred_np, axis=2)
speech_scored = float(np.sum(n_ref > 0))
speech_miss = float(np.sum(np.logical_and(n_ref > 0, n_sys == 0)))
speech_falarm = float(np.sum(np.logical_and(n_ref == 0, n_sys > 0)))
# compute speaker diarization error
speaker_scored = float(np.sum(n_ref))
speaker_miss = float(np.sum(np.maximum(n_ref - n_sys, 0)))
speaker_falarm = float(np.sum(np.maximum(n_sys - n_ref, 0)))
n_map = np.sum(np.logical_and(label_np == 1, pred_np == 1), axis=2)
speaker_error = float(np.sum(np.minimum(n_ref, n_sys) - n_map))
correct = float(1.0 * np.sum((label_np == pred_np) * mask) / num_output)
num_frames = np.sum(length)
return (
correct,
num_frames,
speech_scored,
speech_miss,
speech_falarm,
speaker_scored,
speaker_miss,
speaker_falarm,
speaker_error,
)


@@ -0,0 +1,38 @@
import torch
from torch.nn import functional as F
class DotScorer(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(
self,
xs_pad: torch.Tensor,
spk_emb: torch.Tensor,
):
# xs_pad: B, T, D
# spk_emb: B, N, D
scores = torch.matmul(xs_pad, spk_emb.transpose(1, 2))
return scores
def convert_tf2torch(self, var_dict_tf, var_dict_torch):
return {}
class CosScorer(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(
self,
xs_pad: torch.Tensor,
spk_emb: torch.Tensor,
):
# xs_pad: B, T, D
# spk_emb: B, N, D
scores = F.cosine_similarity(xs_pad.unsqueeze(2), spk_emb.unsqueeze(1), dim=-1)
return scores
def convert_tf2torch(self, var_dict_tf, var_dict_torch):
return {}
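# Minimal usage sketch (hypothetical shapes):
#   scorer = DotScorer()
#   scores = scorer(torch.randn(2, 100, 256), torch.randn(2, 16, 256))
#   # scores: (2, 100, 16) frame-vs-speaker similarity matrix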


@@ -0,0 +1,277 @@
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
from typeguard import check_argument_types
import numpy as np
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.layer_norm import LayerNorm
from funasr.models.encoder.abs_encoder import AbsEncoder
import math
from funasr.modules.repeat import repeat
class EncoderLayer(nn.Module):
def __init__(
self,
input_units,
num_units,
kernel_size=3,
activation="tanh",
stride=1,
include_batch_norm=False,
residual=False
):
super().__init__()
left_padding = math.ceil((kernel_size - stride) / 2)
right_padding = kernel_size - stride - left_padding
self.conv_padding = nn.ConstantPad1d((left_padding, right_padding), 0.0)
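# With left = ceil((k - s) / 2) and right = k - s - left, the padded input
# of length T + k - s yields floor(T / s) frames, i.e. the convolution is
# length-preserving when stride == 1 (e.g. k=3, s=1 -> pad (1, 1)).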
self.conv1d = nn.Conv1d(
input_units,
num_units,
kernel_size,
stride,
)
self.activation = self.get_activation(activation)
if include_batch_norm:
self.bn = nn.BatchNorm1d(num_units, momentum=0.99, eps=1e-3)
self.residual = residual
self.include_batch_norm = include_batch_norm
self.input_units = input_units
self.num_units = num_units
self.stride = stride
@staticmethod
def get_activation(activation):
if activation == "tanh":
return nn.Tanh()
else:
return nn.ReLU()
def forward(self, xs_pad, ilens=None):
outputs = self.conv1d(self.conv_padding(xs_pad))
if self.residual and self.stride == 1 and self.input_units == self.num_units:
outputs = outputs + xs_pad
if self.include_batch_norm:
outputs = self.bn(outputs)
# add parenthesis for repeat module
return self.activation(outputs), ilens
class ConvEncoder(AbsEncoder):
"""
author: Speech Lab, Alibaba Group, China
Convolutional encoder in the OpenNMT style
"""
def __init__(
self,
num_layers,
input_units,
num_units,
kernel_size=3,
dropout_rate=0.3,
position_encoder=None,
activation='tanh',
auxiliary_states=True,
out_units=None,
out_norm=False,
out_residual=False,
include_batchnorm=False,
regularization_weight=0.0,
stride=1,
tf2torch_tensor_name_prefix_torch: str = "speaker_encoder",
tf2torch_tensor_name_prefix_tf: str = "EAND/speaker_encoder",
):
assert check_argument_types()
super().__init__()
self._output_size = num_units
self.num_layers = num_layers
self.input_units = input_units
self.num_units = num_units
self.kernel_size = kernel_size
self.dropout_rate = dropout_rate
self.position_encoder = position_encoder
self.out_units = out_units
self.auxiliary_states = auxiliary_states
self.out_norm = out_norm
self.activation = activation
self.out_residual = out_residual
self.include_batch_norm = include_batchnorm
self.regularization_weight = regularization_weight
self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
if isinstance(stride, int):
self.stride = [stride] * self.num_layers
else:
self.stride = stride
self.downsample_rate = 1
for s in self.stride:
self.downsample_rate *= s
self.dropout = nn.Dropout(dropout_rate)
self.cnn_a = repeat(
self.num_layers,
lambda lnum: EncoderLayer(
input_units if lnum == 0 else num_units,
num_units,
kernel_size,
activation,
self.stride[lnum],
include_batchnorm,
residual=True if lnum > 0 else False
)
)
if self.out_units is not None:
# conv_out is created with stride 1, so pad for a length-preserving conv
left_padding = math.ceil((kernel_size - 1) / 2)
right_padding = kernel_size - 1 - left_padding
self.out_padding = nn.ConstantPad1d((left_padding, right_padding), 0.0)
self.conv_out = nn.Conv1d(
num_units,
num_units,
kernel_size,
)
if self.out_norm:
self.after_norm = LayerNorm(num_units)
def output_size(self) -> int:
return self.num_units
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
inputs = xs_pad
if self.position_encoder is not None:
inputs = self.position_encoder(inputs)
if self.dropout_rate > 0:
inputs = self.dropout(inputs)
outputs, _ = self.cnn_a(inputs.transpose(1, 2), ilens)
if self.out_units is not None:
outputs = self.conv_out(self.out_padding(outputs))
outputs = outputs.transpose(1, 2)
if self.out_norm:
outputs = self.after_norm(outputs)
if self.out_residual:
outputs = outputs + inputs
return outputs, ilens, None
def gen_tf2torch_map_dict(self):
tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
map_dict_local = {
# torch: conv1d.weight in "out_channel in_channel kernel_size"
# tf : conv1d.weight in "kernel_size in_channel out_channel"
# torch: linear.weight in "out_channel in_channel"
# tf : dense.weight in "in_channel out_channel"
"{}.cnn_a.0.conv1d.weight".format(tensor_name_prefix_torch):
{"name": "{}/cnn_a/conv1d/kernel".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": (2, 1, 0),
},
"{}.cnn_a.0.conv1d.bias".format(tensor_name_prefix_torch):
{"name": "{}/cnn_a/conv1d/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.cnn_a.layeridx.conv1d.weight".format(tensor_name_prefix_torch):
{"name": "{}/cnn_a/conv1d_layeridx/kernel".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": (2, 1, 0),
},
"{}.cnn_a.layeridx.conv1d.bias".format(tensor_name_prefix_torch):
{"name": "{}/cnn_a/conv1d_layeridx/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
}
if self.out_units is not None:
# add output layer
map_dict_local.update({
"{}.conv_out.weight".format(tensor_name_prefix_torch):
{"name": "{}/cnn_a/conv1d_{}/kernel".format(tensor_name_prefix_tf, self.num_layers),
"squeeze": None,
"transpose": (2, 1, 0),
}, # tf: (1, 256, 256) -> torch: (256, 256, 1)
"{}.conv_out.bias".format(tensor_name_prefix_torch):
{"name": "{}/cnn_a/conv1d_{}/bias".format(tensor_name_prefix_tf, self.num_layers),
"squeeze": None,
"transpose": None,
}, # tf: (256,) -> torch: (256,)
})
return map_dict_local
def convert_tf2torch(self,
var_dict_tf,
var_dict_torch,
):
map_dict = self.gen_tf2torch_map_dict()
var_dict_torch_update = dict()
for name in sorted(var_dict_torch.keys(), reverse=False):
if name.startswith(self.tf2torch_tensor_name_prefix_torch):
# process special (first and last) layers
if name in map_dict:
name_tf = map_dict[name]["name"]
data_tf = var_dict_tf[name_tf]
if map_dict[name]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
if map_dict[name]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), \
"{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[name].size(), data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
))
# process general layers
else:
# self.tf2torch_tensor_name_prefix_torch may include ".", solve this case
names = name.replace(self.tf2torch_tensor_name_prefix_torch, "todo").split('.')
layeridx = int(names[2])
name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
if name_q in map_dict.keys():
name_v = map_dict[name_q]["name"]
name_tf = name_v.replace("layeridx", "{}".format(layeridx))
data_tf = var_dict_tf[name_tf]
if map_dict[name_q]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
if map_dict[name_q]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), \
"{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[name].size(), data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
))
else:
logging.warning("{} is missed from tf checkpoint".format(name))
return var_dict_torch_update


@@ -0,0 +1,335 @@
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
from typeguard import check_argument_types
import numpy as np
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.layer_norm import LayerNorm
from funasr.models.encoder.abs_encoder import AbsEncoder
import math
from funasr.modules.repeat import repeat
from funasr.modules.multi_layer_conv import FsmnFeedForward
class FsmnBlock(torch.nn.Module):
def __init__(
self,
n_feat,
dropout_rate,
kernel_size,
fsmn_shift=0,
):
super().__init__()
self.dropout = nn.Dropout(p=dropout_rate)
self.fsmn_block = nn.Conv1d(n_feat, n_feat, kernel_size, stride=1,
padding=0, groups=n_feat, bias=False)
# padding
left_padding = (kernel_size - 1) // 2
if fsmn_shift > 0:
left_padding = left_padding + fsmn_shift
right_padding = kernel_size - 1 - left_padding
self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
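# A positive fsmn_shift moves padding from the right edge to the left, so
# the memory block sees more past context and less lookahead.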
def forward(self, inputs, mask, mask_shfit_chunk=None):
b, t, d = inputs.size()
if mask is not None:
mask = torch.reshape(mask, (b, -1, 1))
if mask_shfit_chunk is not None:
mask = mask * mask_shfit_chunk
inputs = inputs * mask
x = inputs.transpose(1, 2)
x = self.pad_fn(x)
x = self.fsmn_block(x)
x = x.transpose(1, 2)
x = x + inputs
x = self.dropout(x)
return x * mask
class EncoderLayer(torch.nn.Module):
def __init__(
self,
in_size,
size,
feed_forward,
fsmn_block,
dropout_rate=0.0
):
super().__init__()
self.in_size = in_size
self.size = size
self.ffn = feed_forward
self.memory = fsmn_block
self.dropout = nn.Dropout(dropout_rate)
def forward(
self,
xs_pad: torch.Tensor,
mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# xs_pad in Batch, Time, Dim
context = self.ffn(xs_pad)[0]
memory = self.memory(context, mask)
memory = self.dropout(memory)
if self.in_size == self.size:
return memory + xs_pad, mask
return memory, mask
class FsmnEncoder(AbsEncoder):
"""Encoder using Fsmn
"""
def __init__(self,
in_units,
filter_size,
fsmn_num_layers,
dnn_num_layers,
num_memory_units=512,
ffn_inner_dim=2048,
dropout_rate=0.0,
shift=0,
position_encoder=None,
sample_rate=1,
out_units=None,
tf2torch_tensor_name_prefix_torch="post_net",
tf2torch_tensor_name_prefix_tf="EAND/post_net"
):
"""Initializes the parameters of the encoder.
Args:
filter_size: the total order of memory block
fsmn_num_layers: The number of fsmn layers.
dnn_num_layers: The number of dnn layers
num_memory_units: The number of memory units.
ffn_inner_dim: The number of units of the inner linear transformation
in the feed forward layer.
dropout_rate: The probability to drop units from the outputs.
shift: left padding, to control delay
position_encoder: The :class:`opennmt.layers.position.PositionEncoder` to
apply on inputs or ``None``.
"""
super(FsmnEncoder, self).__init__()
self.in_units = in_units
self.filter_size = filter_size
self.fsmn_num_layers = fsmn_num_layers
self.dnn_num_layers = dnn_num_layers
self.num_memory_units = num_memory_units
self.ffn_inner_dim = ffn_inner_dim
self.dropout_rate = dropout_rate
self.shift = shift
if not isinstance(shift, list):
self.shift = [shift for _ in range(self.fsmn_num_layers)]
self.sample_rate = sample_rate
if not isinstance(sample_rate, list):
self.sample_rate = [sample_rate for _ in range(self.fsmn_num_layers)]
self.position_encoder = position_encoder
self.dropout = nn.Dropout(dropout_rate)
self.out_units = out_units
self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
self.fsmn_layers = repeat(
self.fsmn_num_layers,
lambda lnum: EncoderLayer(
in_units if lnum == 0 else num_memory_units,
num_memory_units,
FsmnFeedForward(
in_units if lnum == 0 else num_memory_units,
ffn_inner_dim,
num_memory_units,
1,
dropout_rate
),
FsmnBlock(
num_memory_units,
dropout_rate,
filter_size,
self.shift[lnum]
)
),
)
self.dnn_layers = repeat(
dnn_num_layers,
lambda lnum: FsmnFeedForward(
num_memory_units,
ffn_inner_dim,
num_memory_units,
1,
dropout_rate,
)
)
if out_units is not None:
self.conv1d = nn.Conv1d(num_memory_units, out_units, 1, 1)
def output_size(self) -> int:
return self.num_memory_units
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
inputs = xs_pad
if self.position_encoder is not None:
inputs = self.position_encoder(inputs)
inputs = self.dropout(inputs)
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
inputs = self.fsmn_layers(inputs, masks)[0]
inputs = self.dnn_layers(inputs)[0]
if self.out_units is not None:
inputs = self.conv1d(inputs.transpose(1, 2)).transpose(1, 2)
return inputs, ilens, None
def gen_tf2torch_map_dict(self):
tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
map_dict_local = {
# torch: conv1d.weight in "out_channel in_channel kernel_size"
# tf : conv1d.weight in "kernel_size in_channel out_channel"
# torch: linear.weight in "out_channel in_channel"
# tf : dense.weight in "in_channel out_channel"
# for fsmn_layers
"{}.fsmn_layers.layeridx.ffn.norm.bias".format(tensor_name_prefix_torch):
{"name": "{}/fsmn_layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.fsmn_layers.layeridx.ffn.norm.weight".format(tensor_name_prefix_torch):
{"name": "{}/fsmn_layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.fsmn_layers.layeridx.ffn.w_1.bias".format(tensor_name_prefix_torch):
{"name": "{}/fsmn_layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.fsmn_layers.layeridx.ffn.w_1.weight".format(tensor_name_prefix_torch):
{"name": "{}/fsmn_layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": (2, 1, 0),
},
"{}.fsmn_layers.layeridx.ffn.w_2.weight".format(tensor_name_prefix_torch):
{"name": "{}/fsmn_layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": (2, 1, 0),
},
"{}.fsmn_layers.layeridx.memory.fsmn_block.weight".format(tensor_name_prefix_torch):
{"name": "{}/fsmn_layer_layeridx/memory/depth_conv_w".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 2, 0),
}, # (1, 31, 512, 1) -> (31, 512, 1) -> (512, 1, 31)
# for dnn_layers
"{}.dnn_layers.layeridx.norm.bias".format(tensor_name_prefix_torch):
{"name": "{}/dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.dnn_layers.layeridx.norm.weight".format(tensor_name_prefix_torch):
{"name": "{}/dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.dnn_layers.layeridx.w_1.bias".format(tensor_name_prefix_torch):
{"name": "{}/dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.dnn_layers.layeridx.w_1.weight".format(tensor_name_prefix_torch):
{"name": "{}/dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": (2, 1, 0),
},
"{}.dnn_layers.layeridx.w_2.weight".format(tensor_name_prefix_torch):
{"name": "{}/dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": (2, 1, 0),
},
}
if self.out_units is not None:
# add output layer
map_dict_local.update({
"{}.conv1d.weight".format(tensor_name_prefix_torch):
{"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": (2, 1, 0),
},
"{}.conv1d.bias".format(tensor_name_prefix_torch):
{"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
})
return map_dict_local
def convert_tf2torch(self,
var_dict_tf,
var_dict_torch,
):
map_dict = self.gen_tf2torch_map_dict()
var_dict_torch_update = dict()
for name in sorted(var_dict_torch.keys(), reverse=False):
if name.startswith(self.tf2torch_tensor_name_prefix_torch):
# process special (first and last) layers
if name in map_dict:
name_tf = map_dict[name]["name"]
data_tf = var_dict_tf[name_tf]
if map_dict[name]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
if map_dict[name]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), \
"{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[name].size(), data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
))
# process general layers
else:
# self.tf2torch_tensor_name_prefix_torch may include ".", handle that case
names = name.replace(self.tf2torch_tensor_name_prefix_torch, "todo").split('.')
layeridx = int(names[2])
name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
if name_q in map_dict.keys():
name_v = map_dict[name_q]["name"]
name_tf = name_v.replace("layeridx", "{}".format(layeridx))
data_tf = var_dict_tf[name_tf]
if map_dict[name_q]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
if map_dict[name_q]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), \
"{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[name].size(), data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
))
else:
logging.warning("{} is missing from the tf checkpoint".format(name))
return var_dict_torch_update
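A minimal usage sketch of FsmnEncoder (a hedged example, not part of the commit; argument values are illustrative, and the import path follows the one used in funasr/tasks/diar.py below):

enc = FsmnEncoder(in_units=80, filter_size=31, fsmn_num_layers=4, dnn_num_layers=1, num_memory_units=512, ffn_inner_dim=2048)
xs, ilens = torch.randn(2, 100, 80), torch.tensor([100, 73])
out, olens, _ = enc(xs, ilens)  # out: (2, 100, 512); the time length is preserved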

View File

@ -0,0 +1,480 @@
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import logging
import torch
import torch.nn as nn
from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
from typeguard import check_argument_types
import numpy as np
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.attention import MultiHeadSelfAttention, MultiHeadedAttentionSANM
from funasr.modules.embedding import SinusoidalPositionEncoder
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.multi_layer_conv import Conv1dLinear
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
from funasr.modules.positionwise_feed_forward import (
PositionwiseFeedForward, # noqa: H301
)
from funasr.modules.repeat import repeat
from funasr.modules.subsampling import Conv2dSubsampling
from funasr.modules.subsampling import Conv2dSubsampling2
from funasr.modules.subsampling import Conv2dSubsampling6
from funasr.modules.subsampling import Conv2dSubsampling8
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder
class EncoderLayer(nn.Module):
def __init__(
self,
in_size,
size,
self_attn,
feed_forward,
dropout_rate,
normalize_before=True,
concat_after=False,
stochastic_depth_rate=0.0,
):
"""Construct an EncoderLayer object."""
super(EncoderLayer, self).__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(in_size)
self.norm2 = LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.in_size = in_size
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear = nn.Linear(size + size, size)
self.stochastic_depth_rate = stochastic_depth_rate
self.dropout_rate = dropout_rate
def forward(self, x, mask, cache=None, mask_att_chunk_encoder=None):
"""Compute encoded features.
Args:
x (torch.Tensor): Input tensor (#batch, time, size).
mask (torch.Tensor): Mask tensor for the input (#batch, time).
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
Returns:
torch.Tensor: Output tensor (#batch, time, size).
torch.Tensor: Mask tensor (#batch, time).
torch.Tensor: Cache tensor, passed through unchanged.
torch.Tensor: Chunk-attention mask, passed through unchanged.
"""
skip_layer = False
# with stochastic depth, residual connection `x + f(x)` becomes
# `x <- x + 1 / (1 - p) * f(x)` at training time.
stoch_layer_coeff = 1.0
if self.training and self.stochastic_depth_rate > 0:
skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
if skip_layer:
if cache is not None:
x = torch.cat([cache, x], dim=1)
return x, mask
residual = x
if self.normalize_before:
x = self.norm1(x)
if self.concat_after:
x_concat = torch.cat((x, self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
if self.in_size == self.size:
x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
else:
x = stoch_layer_coeff * self.concat_linear(x_concat)
else:
if self.in_size == self.size:
x = residual + stoch_layer_coeff * self.dropout(
self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)
)
else:
x = stoch_layer_coeff * self.dropout(
self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)
)
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm2(x)
return x, mask, cache, mask_att_chunk_encoder
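# Hedged numeric check (illustrative, not part of the commit): a branch that
# survives with probability (1 - p) and is scaled by 1 / (1 - p) keeps the
# expected residual contribution unchanged, which is why the coefficient
# above is 1.0 / (1 - self.stochastic_depth_rate).
_p = 0.1
_keep = (torch.rand(100000) >= _p).float()
assert abs((_keep / (1.0 - _p)).mean().item() - 1.0) < 0.02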
class SelfAttentionEncoder(AbsEncoder):
"""
author: Speech Lab, Alibaba Group, China
Self-attention encoder in the OpenNMT framework
"""
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: Optional[str] = "conv2d",
pos_enc_class=SinusoidalPositionEncoder,
normalize_before: bool = True,
concat_after: bool = False,
positionwise_layer_type: str = "linear",
positionwise_conv_kernel_size: int = 1,
padding_idx: int = -1,
interctc_layer_idx: List[int] = [],
interctc_use_conditioning: bool = False,
tf2torch_tensor_name_prefix_torch: str = "encoder",
tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
out_units=None,
):
assert check_argument_types()
super().__init__()
self._output_size = output_size
if input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(input_size, output_size),
torch.nn.LayerNorm(output_size),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == "conv2d":
self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
elif input_layer == "conv2d2":
self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
elif input_layer == "conv2d6":
self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
elif input_layer == "conv2d8":
self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
elif input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
SinusoidalPositionEncoder(),
)
elif input_layer is None:
if input_size == output_size:
self.embed = None
else:
self.embed = torch.nn.Linear(input_size, output_size)
elif input_layer == "pe":
self.embed = SinusoidalPositionEncoder()
elif input_layer == "null":
self.embed = None
else:
raise ValueError("unknown input_layer: " + input_layer)
self.normalize_before = normalize_before
if positionwise_layer_type == "linear":
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
output_size,
linear_units,
dropout_rate,
)
elif positionwise_layer_type == "conv1d":
positionwise_layer = MultiLayeredConv1d
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
elif positionwise_layer_type == "conv1d-linear":
positionwise_layer = Conv1dLinear
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
else:
raise NotImplementedError("Support only linear or conv1d.")
self.encoders = repeat(
num_blocks,
lambda lnum: EncoderLayer(
output_size,
output_size,
MultiHeadSelfAttention(
attention_heads,
output_size,
output_size,
attention_dropout_rate,
),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
) if lnum > 0 else EncoderLayer(
input_size,
output_size,
MultiHeadSelfAttention(
attention_heads,
input_size if input_layer == "pe" or input_layer == "null" else output_size,
output_size,
attention_dropout_rate,
),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
),
)
if self.normalize_before:
self.after_norm = LayerNorm(output_size)
self.interctc_layer_idx = interctc_layer_idx
if len(interctc_layer_idx) > 0:
assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
self.interctc_use_conditioning = interctc_use_conditioning
self.conditioning_layer = None
self.dropout = nn.Dropout(dropout_rate)
self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
self.out_units = out_units
if out_units is not None:
self.output_linear = nn.Linear(output_size, out_units)
def output_size(self) -> int:
return self._output_size
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
ctc: CTC = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Embed positions in tensor.
Args:
xs_pad: input tensor (B, L, D)
ilens: input length (B)
prev_states: Not to be used now.
Returns:
position embedded tensor and mask
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
xs_pad *= self.output_size()**0.5
if self.embed is None:
xs_pad = xs_pad
elif (
isinstance(self.embed, Conv2dSubsampling)
or isinstance(self.embed, Conv2dSubsampling2)
or isinstance(self.embed, Conv2dSubsampling6)
or isinstance(self.embed, Conv2dSubsampling8)
):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
raise TooShortUttError(
f"has {xs_pad.size(1)} frames and is too short for subsampling "
+ f"(it needs more than {limit_size} frames), return empty results",
xs_pad.size(1),
limit_size,
)
xs_pad, masks = self.embed(xs_pad, masks)
else:
xs_pad = self.embed(xs_pad)
xs_pad = self.dropout(xs_pad)
# encoder_outs = self.encoders0(xs_pad, masks)
# xs_pad, masks = encoder_outs[0], encoder_outs[1]
intermediate_outs = []
if len(self.interctc_layer_idx) == 0:
encoder_outs = self.encoders(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
else:
for layer_idx, encoder_layer in enumerate(self.encoders):
encoder_outs = encoder_layer(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
if layer_idx + 1 in self.interctc_layer_idx:
encoder_out = xs_pad
# intermediate outputs are also normalized
if self.normalize_before:
encoder_out = self.after_norm(encoder_out)
intermediate_outs.append((layer_idx + 1, encoder_out))
if self.interctc_use_conditioning:
ctc_out = ctc.softmax(encoder_out)
xs_pad = xs_pad + self.conditioning_layer(ctc_out)
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
if self.out_units is not None:
xs_pad = self.output_linear(xs_pad)
olens = masks.squeeze(1).sum(1)
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None
def gen_tf2torch_map_dict(self):
tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
map_dict_local = {
# torch: conv1d.weight in "out_channel in_channel kernel_size"
# tf : conv1d.weight in "kernel_size in_channel out_channel"
# torch: linear.weight in "out_channel in_channel"
# tf : dense.weight in "in_channel out_channel"
"{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (768,256),(1,256,768)
"{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (768,),(768,)
"{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (256,256),(1,256,256)
"{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
# ffn
"{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (1024,256),(1,256,1024)
"{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (1024,),(1024,)
"{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (256,1024),(1,1024,256)
"{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
{"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
# out norm
"{}.after_norm.weight".format(tensor_name_prefix_torch):
{"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.after_norm.bias".format(tensor_name_prefix_torch):
{"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
}
if self.out_units is not None:
map_dict_local.update({
"{}.output_linear.weight".format(tensor_name_prefix_torch):
{"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
},
"{}.output_linear.bias".format(tensor_name_prefix_torch):
{"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
})
return map_dict_local
def convert_tf2torch(self,
var_dict_tf,
var_dict_torch,
):
map_dict = self.gen_tf2torch_map_dict()
var_dict_torch_update = dict()
for name in sorted(var_dict_torch.keys(), reverse=False):
if name.startswith(self.tf2torch_tensor_name_prefix_torch):
# process special (first and last) layers
if name in map_dict:
name_tf = map_dict[name]["name"]
data_tf = var_dict_tf[name_tf]
if map_dict[name]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
if map_dict[name]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), \
"{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[name].size(), data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
))
# process general layers
else:
# self.tf2torch_tensor_name_prefix_torch may include ".", handle that case
names = name.replace(self.tf2torch_tensor_name_prefix_torch, "todo").split('.')
layeridx = int(names[2])
name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
if name_q in map_dict.keys():
name_v = map_dict[name_q]["name"]
name_tf = name_v.replace("layeridx", "{}".format(layeridx))
data_tf = var_dict_tf[name_tf]
if map_dict[name_q]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
if map_dict[name_q]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), \
"{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[name].size(), data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
))
else:
logging.warning("{} is missing from the tf checkpoint".format(name))
return var_dict_torch_update
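A minimal usage sketch of SelfAttentionEncoder (a hedged example, not part of the commit; with input_layer="pe" the feature dimension is kept and only sinusoidal positions are added):

enc = SelfAttentionEncoder(input_size=256, output_size=256, attention_heads=4, linear_units=1024, num_blocks=2, input_layer="pe")
out, olens, _ = enc(torch.randn(2, 50, 256), torch.tensor([50, 42]))  # out: (2, 50, 256); olens equals ilens (no subsampling)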

View File

@ -1,7 +1,11 @@
import torch
from torch.nn import functional as F
from funasr.models.encoder.abs_encoder import AbsEncoder
from typing import Tuple
from typing import Tuple, Optional
from funasr.models.pooling.statistic_pooling import statistic_pooling, windowed_statistic_pooling
from collections import OrderedDict
import logging
import numpy as np
class BasicLayer(torch.nn.Module):
@ -116,10 +120,18 @@ class ResNet34(AbsEncoder):
self.resnet0_dense = torch.nn.Conv2d(filters_in_block[-1], num_nodes_pooling_layer, 1)
self.resnet0_bn = torch.nn.BatchNorm2d(num_nodes_pooling_layer, eps=1e-3, momentum=batchnorm_momentum)
self.time_ds_ratio = 8
def output_size(self) -> int:
return self.num_nodes_pooling_layer
def forward(self, xs_pad: torch.Tensor, ilens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
features = xs_pad
assert features.size(-1) == self.input_size, \
"Dimension of features {} doesn't match the input_size {}.".format(features.size(-1), self.input_size)
@ -141,4 +153,463 @@ class ResNet34(AbsEncoder):
features = F.relu(features)
features = self.resnet0_bn(features)
return features, ilens // 8
return features, resnet_out_lens
# Note: For training, this implementation is not equivalent to tf because of the kernel_regularizer in tf.layers.
# TODO: implement kernel_regularizer in torch with manual loss addition or weight_decay in the optimizer
class ResNet34_SP_L2Reg(AbsEncoder):
def __init__(
self,
input_size,
use_head_conv=True,
batchnorm_momentum=0.5,
use_head_maxpool=False,
num_nodes_pooling_layer=256,
layers_in_block=(3, 4, 6, 3),
filters_in_block=(32, 64, 128, 256),
tf2torch_tensor_name_prefix_torch="encoder",
tf2torch_tensor_name_prefix_tf="EAND/speech_encoder",
tf_train_steps=720000,
):
super(ResNet34_SP_L2Reg, self).__init__()
self.use_head_conv = use_head_conv
self.use_head_maxpool = use_head_maxpool
self.num_nodes_pooling_layer = num_nodes_pooling_layer
self.layers_in_block = layers_in_block
self.filters_in_block = filters_in_block
self.input_size = input_size
self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
self.tf_train_steps = tf_train_steps
pre_filters = filters_in_block[0]
if use_head_conv:
self.pre_conv = torch.nn.Conv2d(1, pre_filters, 3, 1, 1, bias=False, padding_mode="zeros")
self.pre_conv_bn = torch.nn.BatchNorm2d(pre_filters, eps=1e-3, momentum=batchnorm_momentum)
if use_head_maxpool:
self.head_maxpool = torch.nn.MaxPool2d(3, 1, padding=1)
for i in range(len(layers_in_block)):
if i == 0:
in_filters = pre_filters if self.use_head_conv else 1
else:
in_filters = filters_in_block[i-1]
block = BasicBlock(in_filters,
filters=filters_in_block[i],
num_layer=layers_in_block[i],
stride=1 if i == 0 else 2,
bn_momentum=batchnorm_momentum)
self.add_module("block_{}".format(i), block)
self.resnet0_dense = torch.nn.Conv1d(filters_in_block[-1] * input_size // 8, num_nodes_pooling_layer, 1)
self.resnet0_bn = torch.nn.BatchNorm1d(num_nodes_pooling_layer, eps=1e-3, momentum=batchnorm_momentum)
self.time_ds_ratio = 8
def output_size(self) -> int:
return self.num_nodes_pooling_layer
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
features = xs_pad
assert features.size(-1) == self.input_size, \
"Dimension of features {} doesn't match the input_size {}.".format(features.size(-1), self.input_size)
features = torch.unsqueeze(features, dim=1)
if self.use_head_conv:
features = self.pre_conv(features)
features = self.pre_conv_bn(features)
features = F.relu(features)
if self.use_head_maxpool:
features = self.head_maxpool(features)
resnet_outs, resnet_out_lens = features, ilens
for i in range(len(self.layers_in_block)):
block = self._modules["block_{}".format(i)]
resnet_outs, resnet_out_lens = block(resnet_outs, resnet_out_lens)
# B, C, T, F
bb, cc, tt, ff = resnet_outs.shape
resnet_outs = torch.reshape(resnet_outs.permute(0, 3, 1, 2), [bb, ff*cc, tt])
features = self.resnet0_dense(resnet_outs)
features = F.relu(features)
features = self.resnet0_bn(features)
return features, resnet_out_lens
def gen_tf2torch_map_dict(self):
tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
train_steps = self.tf_train_steps
map_dict_local = {
# torch: conv1d.weight in "out_channel in_channel kernel_size"
# tf : conv1d.weight in "kernel_size in_channel out_channel"
# torch: linear.weight in "out_channel in_channel"
# tf : dense.weight in "in_channel out_channel"
"{}.pre_conv.weight".format(tensor_name_prefix_torch):
{"name": "{}/pre_conv/kernel".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": (3, 2, 0, 1),
},
"{}.pre_conv_bn.bias".format(tensor_name_prefix_torch):
{"name": "{}/pre_conv_bn/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.pre_conv_bn.weight".format(tensor_name_prefix_torch):
{"name": "{}/pre_conv_bn/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.pre_conv_bn.running_mean".format(tensor_name_prefix_torch):
{"name": "{}/pre_conv_bn/moving_mean".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.pre_conv_bn.running_var".format(tensor_name_prefix_torch):
{"name": "{}/pre_conv_bn/moving_variance".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.pre_conv_bn.num_batches_tracked".format(tensor_name_prefix_torch): train_steps
}
for layer_idx in range(3):
map_dict_local.update({
"{}.resnet{}_dense.weight".format(tensor_name_prefix_torch, layer_idx):
{"name": "{}/resnet{}_dense/kernel".format(tensor_name_prefix_tf, layer_idx),
"squeeze": None,
"transpose": (2, 1, 0) if layer_idx == 0 else (1, 0),
},
"{}.resnet{}_dense.bias".format(tensor_name_prefix_torch, layer_idx):
{"name": "{}/resnet{}_dense/bias".format(tensor_name_prefix_tf, layer_idx),
"squeeze": None,
"transpose": None,
},
"{}.resnet{}_bn.weight".format(tensor_name_prefix_torch, layer_idx):
{"name": "{}/resnet{}_bn/gamma".format(tensor_name_prefix_tf, layer_idx),
"squeeze": None,
"transpose": None,
},
"{}.resnet{}_bn.bias".format(tensor_name_prefix_torch, layer_idx):
{"name": "{}/resnet{}_bn/beta".format(tensor_name_prefix_tf, layer_idx),
"squeeze": None,
"transpose": None,
},
"{}.resnet{}_bn.running_mean".format(tensor_name_prefix_torch, layer_idx):
{"name": "{}/resnet{}_bn/moving_mean".format(tensor_name_prefix_tf, layer_idx),
"squeeze": None,
"transpose": None,
},
"{}.resnet{}_bn.running_var".format(tensor_name_prefix_torch, layer_idx):
{"name": "{}/resnet{}_bn/moving_variance".format(tensor_name_prefix_tf, layer_idx),
"squeeze": None,
"transpose": None,
},
"{}.resnet{}_bn.num_batches_tracked".format(tensor_name_prefix_torch, layer_idx): train_steps
})
for block_idx in range(len(self.layers_in_block)):
for layer_idx in range(self.layers_in_block[block_idx]):
for i in ["1", "2", "_sc"]:
map_dict_local.update({
"{}.block_{}.layer_{}.conv{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
{"name": "{}/block_{}/layer_{}/conv{}/kernel".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
"squeeze": None,
"transpose": (3, 2, 0, 1),
},
"{}.block_{}.layer_{}.bn{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
{"name": "{}/block_{}/layer_{}/bn{}/gamma".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
"squeeze": None,
"transpose": None,
},
"{}.block_{}.layer_{}.bn{}.bias".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
{"name": "{}/block_{}/layer_{}/bn{}/beta".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
"squeeze": None,
"transpose": None,
},
"{}.block_{}.layer_{}.bn{}.running_mean".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
{"name": "{}/block_{}/layer_{}/bn{}/moving_mean".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
"squeeze": None,
"transpose": None,
},
"{}.block_{}.layer_{}.bn{}.running_var".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
{"name": "{}/block_{}/layer_{}/bn{}/moving_variance".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
"squeeze": None,
"transpose": None,
},
"{}.block_{}.layer_{}.bn{}.num_batches_tracked".format(tensor_name_prefix_torch, block_idx, layer_idx, i): train_steps,
})
return map_dict_local
def convert_tf2torch(self,
var_dict_tf,
var_dict_torch,
):
map_dict = self.gen_tf2torch_map_dict()
var_dict_torch_update = dict()
for name in sorted(var_dict_torch.keys(), reverse=False):
if name.startswith(self.tf2torch_tensor_name_prefix_torch):
if name in map_dict:
if "num_batches_tracked" not in name:
name_tf = map_dict[name]["name"]
data_tf = var_dict_tf[name_tf]
if map_dict[name]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
if map_dict[name]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), \
"{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[name].size(), data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
))
else:
var_dict_torch_update[name] = torch.tensor(map_dict[name]).type(torch.int64).to("cpu")
logging.info("torch tensor: {}, manually assigning to: {}".format(
name, map_dict[name]
))
else:
logging.warning("{} is missing from the tf checkpoint".format(name))
return var_dict_torch_update
class ResNet34Diar(ResNet34):
def __init__(
self,
input_size,
embedding_node="resnet1_dense",
use_head_conv=True,
batchnorm_momentum=0.5,
use_head_maxpool=False,
num_nodes_pooling_layer=256,
layers_in_block=(3, 4, 6, 3),
filters_in_block=(32, 64, 128, 256),
num_nodes_resnet1=256,
num_nodes_last_layer=256,
pooling_type="window_shift",
pool_size=20,
stride=1,
tf2torch_tensor_name_prefix_torch="encoder",
tf2torch_tensor_name_prefix_tf="seq2seq/speech_encoder"
):
super(ResNet34Diar, self).__init__(
input_size,
use_head_conv=use_head_conv,
batchnorm_momentum=batchnorm_momentum,
use_head_maxpool=use_head_maxpool,
num_nodes_pooling_layer=num_nodes_pooling_layer,
layers_in_block=layers_in_block,
filters_in_block=filters_in_block,
)
self.embedding_node = embedding_node
self.num_nodes_resnet1 = num_nodes_resnet1
self.num_nodes_last_layer = num_nodes_last_layer
self.pooling_type = pooling_type
self.pool_size = pool_size
self.stride = stride
self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
self.resnet1_dense = torch.nn.Linear(num_nodes_pooling_layer * 2, num_nodes_resnet1)
self.resnet1_bn = torch.nn.BatchNorm1d(num_nodes_resnet1, eps=1e-3, momentum=batchnorm_momentum)
self.resnet2_dense = torch.nn.Linear(num_nodes_resnet1, num_nodes_last_layer)
self.resnet2_bn = torch.nn.BatchNorm1d(num_nodes_last_layer, eps=1e-3, momentum=batchnorm_momentum)
def output_size(self) -> int:
if self.embedding_node.startswith("resnet1"):
return self.num_nodes_resnet1
elif self.embedding_node.startswith("resnet2"):
return self.num_nodes_last_layer
return self.num_nodes_pooling_layer
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
endpoints = OrderedDict()
res_out, ilens = super().forward(xs_pad, ilens)
endpoints["resnet0_bn"] = res_out
if self.pooling_type == "frame_gsp":
features = statistic_pooling(res_out, ilens, (3, ))
else:
features, ilens = windowed_statistic_pooling(res_out, ilens, (2, 3), self.pool_size, self.stride)
features = features.transpose(1, 2)
endpoints["pooling"] = features
features = self.resnet1_dense(features)
endpoints["resnet1_dense"] = features
features = F.relu(features)
endpoints["resnet1_relu"] = features
features = self.resnet1_bn(features.transpose(1, 2)).transpose(1, 2)
endpoints["resnet1_bn"] = features
features = self.resnet2_dense(features)
endpoints["resnet2_dense"] = features
features = F.relu(features)
endpoints["resnet2_relu"] = features
features = self.resnet2_bn(features.transpose(1, 2)).transpose(1, 2)
endpoints["resnet2_bn"] = features
return endpoints[self.embedding_node], ilens, None
def gen_tf2torch_map_dict(self):
tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
train_steps = 300000
map_dict_local = {
# torch: conv1d.weight in "out_channel in_channel kernel_size"
# tf : conv1d.weight in "kernel_size in_channel out_channel"
# torch: linear.weight in "out_channel in_channel"
# tf : dense.weight in "in_channel out_channel"
"{}.pre_conv.weight".format(tensor_name_prefix_torch):
{"name": "{}/pre_conv/kernel".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": (3, 2, 0, 1),
},
"{}.pre_conv_bn.bias".format(tensor_name_prefix_torch):
{"name": "{}/pre_conv_bn/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.pre_conv_bn.weight".format(tensor_name_prefix_torch):
{"name": "{}/pre_conv_bn/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.pre_conv_bn.running_mean".format(tensor_name_prefix_torch):
{"name": "{}/pre_conv_bn/moving_mean".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.pre_conv_bn.running_var".format(tensor_name_prefix_torch):
{"name": "{}/pre_conv_bn/moving_variance".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
},
"{}.pre_conv_bn.num_batches_tracked".format(tensor_name_prefix_torch): train_steps
}
for layer_idx in range(3):
map_dict_local.update({
"{}.resnet{}_dense.weight".format(tensor_name_prefix_torch, layer_idx):
{"name": "{}/resnet{}_dense/kernel".format(tensor_name_prefix_tf, layer_idx),
"squeeze": None,
"transpose": (3, 2, 0, 1) if layer_idx == 0 else (1, 0),
},
"{}.resnet{}_dense.bias".format(tensor_name_prefix_torch, layer_idx):
{"name": "{}/resnet{}_dense/bias".format(tensor_name_prefix_tf, layer_idx),
"squeeze": None,
"transpose": None,
},
"{}.resnet{}_bn.weight".format(tensor_name_prefix_torch, layer_idx):
{"name": "{}/resnet{}_bn/gamma".format(tensor_name_prefix_tf, layer_idx),
"squeeze": None,
"transpose": None,
},
"{}.resnet{}_bn.bias".format(tensor_name_prefix_torch, layer_idx):
{"name": "{}/resnet{}_bn/beta".format(tensor_name_prefix_tf, layer_idx),
"squeeze": None,
"transpose": None,
},
"{}.resnet{}_bn.running_mean".format(tensor_name_prefix_torch, layer_idx):
{"name": "{}/resnet{}_bn/moving_mean".format(tensor_name_prefix_tf, layer_idx),
"squeeze": None,
"transpose": None,
},
"{}.resnet{}_bn.running_var".format(tensor_name_prefix_torch, layer_idx):
{"name": "{}/resnet{}_bn/moving_variance".format(tensor_name_prefix_tf, layer_idx),
"squeeze": None,
"transpose": None,
},
"{}.resnet{}_bn.num_batches_tracked".format(tensor_name_prefix_torch, layer_idx): train_steps
})
for block_idx in range(len(self.layers_in_block)):
for layer_idx in range(self.layers_in_block[block_idx]):
for i in ["1", "2", "_sc"]:
map_dict_local.update({
"{}.block_{}.layer_{}.conv{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
{"name": "{}/block_{}/layer_{}/conv{}/kernel".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
"squeeze": None,
"transpose": (3, 2, 0, 1),
},
"{}.block_{}.layer_{}.bn{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
{"name": "{}/block_{}/layer_{}/bn{}/gamma".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
"squeeze": None,
"transpose": None,
},
"{}.block_{}.layer_{}.bn{}.bias".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
{"name": "{}/block_{}/layer_{}/bn{}/beta".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
"squeeze": None,
"transpose": None,
},
"{}.block_{}.layer_{}.bn{}.running_mean".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
{"name": "{}/block_{}/layer_{}/bn{}/moving_mean".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
"squeeze": None,
"transpose": None,
},
"{}.block_{}.layer_{}.bn{}.running_var".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
{"name": "{}/block_{}/layer_{}/bn{}/moving_variance".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
"squeeze": None,
"transpose": None,
},
"{}.block_{}.layer_{}.bn{}.num_batches_tracked".format(tensor_name_prefix_torch, block_idx, layer_idx, i): train_steps,
})
return map_dict_local
def convert_tf2torch(self,
var_dict_tf,
var_dict_torch,
):
map_dict = self.gen_tf2torch_map_dict()
var_dict_torch_update = dict()
for name in sorted(var_dict_torch.keys(), reverse=False):
if name.startswith(self.tf2torch_tensor_name_prefix_torch):
if name in map_dict:
if "num_batches_tracked" not in name:
name_tf = map_dict[name]["name"]
data_tf = var_dict_tf[name_tf]
if map_dict[name]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
if map_dict[name]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), \
"{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[name].size(), data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
))
else:
var_dict_torch_update[name] = torch.tensor(map_dict[name]).type(torch.int64).to("cpu")
logging.info("torch tensor: {}, manually assigning to: {}".format(
name, map_dict[name]
))
else:
logging.warning("{} is missing from the tf checkpoint".format(name))
return var_dict_torch_update
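A minimal usage sketch of ResNet34Diar (a hedged example, not part of the commit): with the default window_shift pooling, time is downsampled by time_ds_ratio = 8 before windowed statistic pooling, and the default resnet1_dense node yields 256-dim frame-level speaker embeddings:

enc = ResNet34Diar(input_size=80)
emb, olens, _ = enc(torch.randn(2, 400, 80), torch.tensor([400, 320]))  # emb: (2, 50, 256)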

View File

@ -90,7 +90,9 @@ class WavFrontend(AbsFrontend):
filter_length_max: int = -1,
lfr_m: int = 1,
lfr_n: int = 1,
dither: float = 1.0
dither: float = 1.0,
snip_edges: bool = True,
upsacle_samples: bool = True,
):
assert check_argument_types()
super().__init__()
@ -105,6 +107,8 @@ class WavFrontend(AbsFrontend):
self.lfr_n = lfr_n
self.cmvn_file = cmvn_file
self.dither = dither
self.snip_edges = snip_edges
self.upsacle_samples = upsacle_samples
def output_size(self) -> int:
return self.n_mels * self.lfr_m
@ -119,7 +123,8 @@ class WavFrontend(AbsFrontend):
for i in range(batch_size):
waveform_length = input_lengths[i]
waveform = input[i][:waveform_length]
waveform = waveform * (1 << 15)
if self.upsacle_samples:
waveform = waveform * (1 << 15)
waveform = waveform.unsqueeze(0)
mat = kaldi.fbank(waveform,
num_mel_bins=self.n_mels,
@ -128,7 +133,8 @@ class WavFrontend(AbsFrontend):
dither=self.dither,
energy_floor=0.0,
window_type=self.window,
sample_frequency=self.fs)
sample_frequency=self.fs,
snip_edges=self.snip_edges)
if self.lfr_m != 1 or self.lfr_n != 1:
mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
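For context, a hedged sketch of the underlying fbank call this hunk extends (torchaudio's Kaldi-compliance API; the parameter values here are illustrative, not the frontend's defaults):

import torchaudio.compliance.kaldi as kaldi
waveform = torch.randn(1, 16000)  # 1 s at 16 kHz, illustrative
mat = kaldi.fbank(waveform * (1 << 15), num_mel_bins=80, dither=1.0, energy_floor=0.0, window_type="hamming", sample_frequency=16000, snip_edges=False)  # snip_edges is the new option threaded through above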

View File

@ -2,7 +2,10 @@ import torch
from typing import Tuple
from typing import Union
from funasr.modules.nets_utils import make_non_pad_mask
from torch.nn import functional as F
import math
VAR2STD_EPSILON = 1e-12
class StatisticPooling(torch.nn.Module):
def __init__(self, pooling_dim: Union[int, Tuple] = 2, eps=1e-12):
@ -34,3 +37,59 @@ class StatisticPooling(torch.nn.Module):
stat_pooling = torch.cat([mean, stddev], dim=1)
return stat_pooling
def convert_tf2torch(self, var_dict_tf, var_dict_torch):
return {}
def statistic_pooling(
xs_pad: torch.Tensor,
ilens: torch.Tensor = None,
pooling_dim: Tuple = (2, 3)
) -> torch.Tensor:
# xs_pad in (Batch, Channel, Time, Frequency)
if ilens is None:
seq_mask = torch.ones_like(xs_pad).to(xs_pad)
else:
seq_mask = make_non_pad_mask(ilens, xs_pad, length_dim=2).to(xs_pad)
mean = (torch.sum(xs_pad, dim=pooling_dim, keepdim=True) /
torch.sum(seq_mask, dim=pooling_dim, keepdim=True))
squared_difference = torch.pow(xs_pad - mean, 2.0)
variance = (torch.sum(squared_difference, dim=pooling_dim, keepdim=True) /
torch.sum(seq_mask, dim=pooling_dim, keepdim=True))
for i in reversed(pooling_dim):
mean, variance = torch.squeeze(mean, dim=i), torch.squeeze(variance, dim=i)
value_mask = torch.less_equal(variance, VAR2STD_EPSILON).float()
variance = (1.0 - value_mask) * variance + value_mask * VAR2STD_EPSILON
stddev = torch.sqrt(variance)
stat_pooling = torch.cat([mean, stddev], dim=1)
return stat_pooling
def windowed_statistic_pooling(
xs_pad: torch.Tensor,
ilens: torch.Tensor = None,
pooling_dim: Tuple = (2, 3),
pooling_size: int = 20,
pooling_stride: int = 1
) -> Tuple[torch.Tensor, int]:
# xs_pad in (Batch, Channel, Time, Frequency)
tt = xs_pad.shape[2]
num_chunk = int(math.ceil(tt / pooling_stride))
pad = pooling_size // 2
features = F.pad(xs_pad, (0, 0, pad, pad), "reflect")
stat_list = []
for i in range(num_chunk):
# B x C
st, ed = i*pooling_stride, i*pooling_stride+pooling_size
stat = statistic_pooling(features[:, :, st: ed, :], pooling_dim=pooling_dim)
stat_list.append(stat.unsqueeze(2))
# B x C x T
return torch.cat(stat_list, dim=2), ilens / pooling_stride
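A minimal shape sketch for the two pooling helpers (a hedged example, not part of the commit): statistic_pooling concatenates mean and stddev over the pooled dims, doubling the channel count; windowed_statistic_pooling does the same per sliding window along time:

x, ilens = torch.randn(2, 256, 100, 10), torch.tensor([100, 80])  # (Batch, Channel, Time, Freq)
g = statistic_pooling(x, ilens)  # (2, 512): utterance-level mean||stddev
w, olens = windowed_statistic_pooling(x, ilens, pooling_size=20, pooling_stride=4)  # w: (2, 512, 25)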

View File

@ -622,4 +622,108 @@ class MultiHeadedAttentionCrossAtt(nn.Module):
q_h, k_h, v_h = self.forward_qkv(x, memory)
q_h = q_h * self.d_k ** (-0.5)
scores = torch.matmul(q_h, k_h.transpose(-2, -1))
return self.forward_attention(v_h, scores, memory_mask)
return self.forward_attention(v_h, scores, memory_mask)
class MultiHeadSelfAttention(nn.Module):
"""Multi-Head Attention layer.
Args:
n_head (int): The number of heads.
in_feat (int): The number of input features.
n_feat (int): The number of attention features.
dropout_rate (float): Dropout rate.
"""
def __init__(self, n_head, in_feat, n_feat, dropout_rate):
"""Construct an MultiHeadedAttention object."""
super(MultiHeadSelfAttention, self).__init__()
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.h = n_head
self.linear_out = nn.Linear(n_feat, n_feat)
self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
self.attn = None
self.dropout = nn.Dropout(p=dropout_rate)
def forward_qkv(self, x):
"""Transform query, key and value.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
Returns:
torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
"""
b, t, d = x.size()
q_k_v = self.linear_q_k_v(x)
q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time1, d_k)
k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
return q_h, k_h, v_h, v
def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
"""Compute attention context vector.
Args:
value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
Returns:
torch.Tensor: Transformed value (#batch, time1, d_model)
weighted by the attention score (#batch, time1, time2).
"""
n_batch = value.size(0)
if mask is not None:
if mask_att_chunk_encoder is not None:
mask = mask * mask_att_chunk_encoder
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
min_value = float(
numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
)
scores = scores.masked_fill(mask, min_value)
self.attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
) # (batch, head, time1, time2)
else:
self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
p_attn = self.dropout(self.attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
) # (batch, time1, d_model)
return self.linear_out(x) # (batch, time1, d_model)
def forward(self, x, mask, mask_att_chunk_encoder=None):
"""Compute scaled dot product attention.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2).
Returns:
torch.Tensor: Output tensor (#batch, time1, d_model).
"""
q_h, k_h, v_h, v = self.forward_qkv(x)
q_h = q_h * self.d_k ** (-0.5)
scores = torch.matmul(q_h, k_h.transpose(-2, -1))
att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
return att_outs
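A minimal usage sketch of MultiHeadSelfAttention (a hedged example, not part of the commit): the fused linear_q_k_v projects the input once and splits it into query, key and value:

attn = MultiHeadSelfAttention(n_head=4, in_feat=256, n_feat=256, dropout_rate=0.0)
y = attn(torch.randn(2, 50, 256), torch.ones(2, 1, 50))  # (2, 50, 256)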

View File

@ -63,6 +63,58 @@ class MultiLayeredConv1d(torch.nn.Module):
return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
class FsmnFeedForward(torch.nn.Module):
"""Position-wise feed forward for FSMN blocks.
This is a module of multi-layered conv1d designed
to replace the position-wise feed-forward network
in an FSMN block.
"""
def __init__(self, in_chans, hidden_chans, out_chans, kernel_size, dropout_rate):
"""Initialize FsmnFeedForward module.
Args:
in_chans (int): Number of input channels.
hidden_chans (int): Number of hidden channels.
out_chans (int): Number of output channels.
kernel_size (int): Kernel size of conv1d.
dropout_rate (float): Dropout rate.
"""
super(FsmnFeedForward, self).__init__()
self.w_1 = torch.nn.Conv1d(
in_chans,
hidden_chans,
kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
)
self.w_2 = torch.nn.Conv1d(
hidden_chans,
out_chans,
kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
bias=False
)
self.norm = torch.nn.LayerNorm(hidden_chans)
self.dropout = torch.nn.Dropout(dropout_rate)
def forward(self, x, ilens=None):
"""Calculate forward propagation.
Args:
x (torch.Tensor): Batch of input tensors (B, T, in_chans).
Returns:
torch.Tensor: Batch of output tensors (B, T, out_chans).
"""
x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
return self.w_2(self.norm(self.dropout(x)).transpose(-1, 1)).transpose(-1, 1), ilens
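# Hedged usage sketch (illustrative, not part of the commit): with
# kernel_size=1 the two Conv1d layers act position-wise, so time is unchanged.
_ffn = FsmnFeedForward(in_chans=256, hidden_chans=1024, out_chans=256, kernel_size=1, dropout_rate=0.0)
_y, _ = _ffn(torch.randn(2, 50, 256))  # (2, 50, 256)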
class Conv1dLinear(torch.nn.Module):
"""Conv1D + Linear for Transformer block.

funasr/tasks/diar.py Normal file (585 lines)
View File

@ -0,0 +1,585 @@
import argparse
import logging
import os
from pathlib import Path
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import numpy as np
import torch
import yaml
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.layers.label_aggregation import LabelAggregate
from funasr.models.ctc import CTC
from funasr.models.encoder.resnet34_encoder import ResNet34Diar
from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
from funasr.models.e2e_diar_sond import DiarSondModel
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.conformer_encoder import ConformerEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.transformer_encoder import TransformerEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.postencoder.hugging_face_transformers_postencoder import (
HuggingFaceTransformersPostEncoder, # noqa: H301
)
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.preencoder.linear import LinearProjection
from funasr.models.preencoder.sinc import LightweightSincConvs
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
from funasr.tasks.abs_task import AbsTask
from funasr.torch_utils.initialize import initialize
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.types import float_or_none
from funasr.utils.types import int_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none
frontend_choices = ClassChoices(
name="frontend",
classes=dict(
default=DefaultFrontend,
sliding_window=SlidingWindow,
s3prl=S3prlFrontend,
fused=FusedFrontends,
wav_frontend=WavFrontend,
),
type_check=AbsFrontend,
default="default",
)
specaug_choices = ClassChoices(
name="specaug",
classes=dict(
specaug=SpecAug,
specaug_lfr=SpecAugLFR,
),
type_check=AbsSpecAug,
default=None,
optional=True,
)
normalize_choices = ClassChoices(
"normalize",
classes=dict(
global_mvn=GlobalMVN,
utterance_mvn=UtteranceMVN,
),
type_check=AbsNormalize,
default=None,
optional=True,
)
label_aggregator_choices = ClassChoices(
"label_aggregator",
classes=dict(
label_aggregator=LabelAggregate
),
type_check=torch.nn.Module,
default=None,
optional=True,
)
model_choices = ClassChoices(
"model",
classes=dict(
sond=DiarSondModel,
),
type_check=AbsESPnetModel,
default="sond",
)
encoder_choices = ClassChoices(
"encoder",
classes=dict(
conformer=ConformerEncoder,
transformer=TransformerEncoder,
rnn=RNNEncoder,
sanm=SANMEncoder,
san=SelfAttentionEncoder,
fsmn=FsmnEncoder,
conv=ConvEncoder,
resnet34=ResNet34Diar,
sanm_chunk_opt=SANMEncoderChunkOpt,
data2vec_encoder=Data2VecEncoder,
),
type_check=AbsEncoder,
default="resnet34",
)
speaker_encoder_choices = ClassChoices(
"speaker_encoder",
classes=dict(
conformer=ConformerEncoder,
transformer=TransformerEncoder,
rnn=RNNEncoder,
sanm=SANMEncoder,
san=SelfAttentionEncoder,
fsmn=FsmnEncoder,
conv=ConvEncoder,
sanm_chunk_opt=SANMEncoderChunkOpt,
data2vec_encoder=Data2VecEncoder,
),
type_check=AbsEncoder,
default=None,
optional=True
)
cd_scorer_choices = ClassChoices(
"cd_scorer",
classes=dict(
san=SelfAttentionEncoder,
),
type_check=AbsEncoder,
default=None,
optional=True,
)
ci_scorer_choices = ClassChoices(
"ci_scorer",
classes=dict(
dot=DotScorer,
cosine=CosScorer,
),
type_check=torch.nn.Module,
default=None,
optional=True,
)
# decoder is used for output (e.g. post_net in SOND)
decoder_choices = ClassChoices(
"decoder",
classes=dict(
rnn=RNNEncoder,
fsmn=FsmnEncoder,
),
type_check=torch.nn.Module,
default="fsmn",
)
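# Hedged usage sketch (illustrative, not part of the commit): get_class is the
# accessor used by build_model below to resolve a registry name to its class.
_decoder_cls = decoder_choices.get_class("fsmn")  # resolves to FsmnEncoder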
class DiarTask(AbsTask):
# If you need more than 1 optimizer, change this value
num_optimizers: int = 1
# Add variable objects configurations
class_choices_list = [
# --frontend and --frontend_conf
frontend_choices,
# --specaug and --specaug_conf
specaug_choices,
# --normalize and --normalize_conf
normalize_choices,
# --model and --model_conf
model_choices,
# --encoder and --encoder_conf
encoder_choices,
# --speaker_encoder and --speaker_encoder_conf
speaker_encoder_choices,
# --cd_scorer and cd_scorer_conf
cd_scorer_choices,
# --ci_scorer and ci_scorer_conf
ci_scorer_choices,
# --decoder and --decoder_conf
decoder_choices,
]
# If you need to modify train() or eval() procedures, change Trainer class here
trainer = Trainer
@classmethod
def add_task_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(description="Task related")
# NOTE(kamo): add_arguments(..., required=True) can't be used
# to provide --print_config mode. Instead of it, do as
# required = parser.get_default("required")
# required += ["token_list"]
group.add_argument(
"--token_list",
type=str_or_none,
default=None,
help="A text mapping int-id to token",
)
group.add_argument(
"--split_with_space",
type=str2bool,
default=True,
help="whether to split text using <space>",
)
group.add_argument(
"--seg_dict_file",
type=str,
default=None,
help="seg_dict_file for text processing",
)
group.add_argument(
"--init",
type=lambda x: str_or_none(x.lower()),
default=None,
help="The initialization method",
choices=[
"chainer",
"xavier_uniform",
"xavier_normal",
"kaiming_uniform",
"kaiming_normal",
None,
],
)
group.add_argument(
"--input_size",
type=int_or_none,
default=None,
help="The number of input dimension of the feature",
)
group = parser.add_argument_group(description="Preprocess related")
group.add_argument(
"--use_preprocessor",
type=str2bool,
default=True,
help="Apply preprocessing to data or not",
)
group.add_argument(
"--token_type",
type=str,
default="char",
choices=["char"],
help="The text will be tokenized in the specified level token",
)
parser.add_argument(
"--speech_volume_normalize",
type=float_or_none,
default=None,
help="Scale the maximum amplitude to the given value.",
)
parser.add_argument(
"--rir_scp",
type=str_or_none,
default=None,
help="The file path of rir scp file.",
)
parser.add_argument(
"--rir_apply_prob",
type=float,
default=1.0,
help="THe probability for applying RIR convolution.",
)
parser.add_argument(
"--cmvn_file",
type=str_or_none,
default=None,
help="The file path of noise scp file.",
)
parser.add_argument(
"--noise_scp",
type=str_or_none,
default=None,
help="The file path of noise scp file.",
)
parser.add_argument(
"--noise_apply_prob",
type=float,
default=1.0,
help="The probability applying Noise adding.",
)
parser.add_argument(
"--noise_db_range",
type=str,
default="13_15",
help="The range of noise decibel level.",
)
for class_choices in cls.class_choices_list:
# Append --<name> and --<name>_conf.
# e.g. --encoder and --encoder_conf
class_choices.add_arguments(group)
@classmethod
def build_collate_fn(
cls, args: argparse.Namespace, train: bool
) -> Callable[
[Collection[Tuple[str, Dict[str, np.ndarray]]]],
Tuple[List[str], Dict[str, torch.Tensor]],
]:
assert check_argument_types()
# NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
@classmethod
def build_preprocess_fn(
cls, args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
assert check_argument_types()
if args.use_preprocessor:
retval = CommonPreprocessor(
train=train,
token_type=args.token_type,
token_list=args.token_list,
bpemodel=None,
non_linguistic_symbols=None,
text_cleaner=None,
g2p_type=None,
split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
# NOTE(kamo): Check attribute existence for backward compatibility
rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
rir_apply_prob=args.rir_apply_prob
if hasattr(args, "rir_apply_prob")
else 1.0,
noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
noise_apply_prob=args.noise_apply_prob
if hasattr(args, "noise_apply_prob")
else 1.0,
noise_db_range=args.noise_db_range
if hasattr(args, "noise_db_range")
else "13_15",
speech_volume_normalize=args.speech_volume_normalize
if hasattr(args, "speech_volume_normalize")
else None,
)
else:
retval = None
assert check_return_type(retval)
return retval
@classmethod
def required_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
if not inference:
retval = ("speech", "profile", "label")
else:
# Recognition mode
retval = ("speech", "profile")
return retval
@classmethod
def optional_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
retval = ()
assert check_return_type(retval)
return retval
@classmethod
def build_model(cls, args: argparse.Namespace):
assert check_argument_types()
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
# Overwriting token_list to keep it as "portable".
args.token_list = list(token_list)
elif isinstance(args.token_list, (tuple, list)):
token_list = list(args.token_list)
else:
raise RuntimeError("token_list must be str or list")
vocab_size = len(token_list)
logging.info(f"Vocabulary size: {vocab_size}")
# 1. frontend
if args.input_size is None:
# Extract features in the model
frontend_class = frontend_choices.get_class(args.frontend)
if args.frontend == 'wav_frontend':
frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
else:
frontend = frontend_class(**args.frontend_conf)
input_size = frontend.output_size()
else:
# Give features from data-loader
args.frontend = None
args.frontend_conf = {}
frontend = None
input_size = args.input_size
# 2. Data augmentation for spectrogram
if args.specaug is not None:
specaug_class = specaug_choices.get_class(args.specaug)
specaug = specaug_class(**args.specaug_conf)
else:
specaug = None
# 3. Normalization layer
if args.normalize is not None:
normalize_class = normalize_choices.get_class(args.normalize)
normalize = normalize_class(**args.normalize_conf)
else:
normalize = None
# 4. Encoder
encoder_class = encoder_choices.get_class(args.encoder)
encoder = encoder_class(input_size=input_size, **args.encoder_conf)
# 5. speaker encoder
if getattr(args, "speaker_encoder", None) is not None:
speaker_encoder_class = speaker_encoder_choices.get_class(args.speaker_encoder)
speaker_encoder = speaker_encoder_class(**args.speaker_encoder_conf)
else:
speaker_encoder = None
# 6. CI & CD scorer
if getattr(args, "ci_scorer", None) is not None:
ci_scorer_class = ci_scorer_choices.get_class(args.ci_scorer)
ci_scorer = ci_scorer_class(**args.ci_scorer_conf)
else:
ci_scorer = None
if getattr(args, "cd_scorer", None) is not None:
cd_scorer_class = cd_scorer_choices.get_class(args.cd_scorer)
cd_scorer = cd_scorer_class(**args.cd_scorer_conf)
else:
cd_scorer = None
# 7. Decoder
decoder_class = decoder_choices.get_class(args.decoder)
decoder = decoder_class(**args.decoder_conf)
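# 8. Label aggregator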
if getattr(args, "label_aggregator", None) is not None:
label_aggregator_class = label_aggregator_choices.get_class(args.label_aggregator)
label_aggregator = label_aggregator_class(**args.label_aggregator_conf)
else:
label_aggregator = None
# 9. Build model
model_class = model_choices.get_class(args.model)
model = model_class(
vocab_size=vocab_size,
frontend=frontend,
specaug=specaug,
normalize=normalize,
label_aggregator=label_aggregator,
encoder=encoder,
speaker_encoder=speaker_encoder,
ci_scorer=ci_scorer,
cd_scorer=cd_scorer,
decoder=decoder,
token_list=token_list,
**args.model_conf,
)
# 10. Initialize
if args.init is not None:
initialize(model, args.init)
assert check_return_type(model)
return model
# ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
@classmethod
def build_model_from_file(
cls,
config_file: Union[Path, str] = None,
model_file: Union[Path, str] = None,
cmvn_file: Union[Path, str] = None,
device: str = "cpu",
):
"""Build model from the files.
This method is used for inference or fine-tuning.
Args:
config_file: The yaml file saved when training.
model_file: The model file saved when training.
cmvn_file: The cmvn file for the front-end.
device: Device type, "cpu", "cuda", or "cuda:N".
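Example (illustrative; the class name and checkpoint paths are assumptions):
>>> model, args = DiarTask.build_model_from_file(
...     config_file="exp/sond/config.yaml",
...     model_file="exp/sond/sond.pth",
...     device="cpu",
... )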
"""
assert check_argument_types()
if config_file is None:
assert model_file is not None, (
"The argument 'model_file' must be provided "
"if the argument 'config_file' is not specified."
)
config_file = Path(model_file).parent / "config.yaml"
else:
config_file = Path(config_file)
with config_file.open("r", encoding="utf-8") as f:
args = yaml.safe_load(f)
if cmvn_file is not None:
args["cmvn_file"] = cmvn_file
args = argparse.Namespace(**args)
model = cls.build_model(args)
if not isinstance(model, AbsESPnetModel):
raise RuntimeError(
f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
)
model.to(device)
model_dict = dict()
model_name_pth = None
if model_file is not None:
logging.info("model_file is {}".format(model_file))
if device == "cuda":
device = f"cuda:{torch.cuda.current_device()}"
model_dir = os.path.dirname(model_file)
model_name = os.path.basename(model_file)
if "model.ckpt-" in model_name or ".bin" in model_name:
if ".bin" in model_name:
model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb'))
else:
model_name_pth = os.path.join(model_dir, "{}.pth".format(model_name))
if os.path.exists(model_name_pth):
logging.info("model_file is load from pth: {}".format(model_name_pth))
model_dict = torch.load(model_name_pth, map_location=device)
else:
model_dict = cls.convert_tf2torch(model, model_file)
model.load_state_dict(model_dict)
else:
model_dict = torch.load(model_file, map_location=device)
model.load_state_dict(model_dict)
if model_name_pth is not None and not os.path.exists(model_name_pth):
torch.save(model_dict, model_name_pth)
logging.info("model_file is saved to pth: {}".format(model_name_pth))
return model, args
@classmethod
def convert_tf2torch(
cls,
model,
ckpt,
):
logging.info("start convert tf model to torch model")
from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
var_dict_tf = load_tf_dict(ckpt)
var_dict_torch = model.state_dict()
var_dict_torch_update = dict()
# speech encoder
var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
# speaker encoder
var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
# cd scorer
var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
# ci scorer
var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
# decoder
var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
return var_dict_torch_update

funasr/utils/job_runner.py Normal file

@ -0,0 +1,103 @@
from __future__ import print_function
from multiprocessing import Pool
import argparse
from tqdm import tqdm
import math
class MultiProcessRunner:
def __init__(self, fn):
self.args = None
self.process = fn
def run(self):
parser = argparse.ArgumentParser("")
# Task-independent options
parser.add_argument("--nj", type=int, default=16)
parser.add_argument("--debug", action="store_true", default=False)
parser.add_argument("--no_pbar", action="store_true", default=False)
parser.add_argument("--verbose", action="store_ture", default=False)
task_list, args = self.prepare(parser)
result_list = self.pool_run(task_list, args)
self.post(result_list, args)
def prepare(self, parser):
raise NotImplementedError("Please implement the prepare function.")
def post(self, result_list, args):
raise NotImplementedError("Please implement the post function.")
def pool_run(self, tasks, args):
results = []
if args.debug:
one_result = self.process(tasks[0])
results.append(one_result)
else:
pool = Pool(args.nj)
for one_result in tqdm(pool.imap(self.process, tasks), total=len(tasks), ascii=True, disable=args.no_pbar):
results.append(one_result)
pool.close()
return results
class MultiProcessRunnerV2:
def __init__(self, fn):
self.args = None
self.process = fn
def run(self):
parser = argparse.ArgumentParser("")
# Task-independent options
parser.add_argument("--nj", type=int, default=16)
parser.add_argument("--debug", action="store_true", default=False)
parser.add_argument("--no_pbar", action="store_true", default=False)
parser.add_argument("--verbose", action="store_true", default=False)
task_list, args = self.prepare(parser)
chunk_size = int(math.ceil(float(len(task_list)) / args.nj))
if args.verbose:
print("Split {} tasks into {} sub-tasks with chunk_size {}".format(len(task_list), args.nj, chunk_size))
subtask_list = [task_list[i*chunk_size: (i+1)*chunk_size] for i in range(args.nj)]
result_list = self.pool_run(subtask_list, args)
self.post(result_list, args)
def prepare(self, parser):
raise NotImplementedError("Please implement the prepare function.")
def post(self, result_list, args):
raise NotImplementedError("Please implement the post function.")
def pool_run(self, tasks, args):
results = []
if args.debug:
one_result = self.process(tasks[0])
results.append(one_result)
else:
pool = Pool(args.nj)
for one_result in tqdm(pool.imap(self.process, tasks), total=len(tasks), ascii=True, disable=args.no_pbar):
results.append(one_result)
pool.close()
return results
class MultiProcessRunnerV3(MultiProcessRunnerV2):
def run(self):
parser = argparse.ArgumentParser("")
# Task-independent options
parser.add_argument("--nj", type=int, default=16)
parser.add_argument("--debug", action="store_true", default=False)
parser.add_argument("--no_pbar", action="store_true", default=False)
parser.add_argument("--verbose", action="store_true", default=False)
parser.add_argument("--sr", type=int, default=16000)
task_list, shared_param, args = self.prepare(parser)
chunk_size = int(math.ceil(float(len(task_list)) / args.nj))
if args.verbose:
print("Split {} tasks into {} sub-tasks with chunk_size {}".format(len(task_list), args.nj, chunk_size))
subtask_list = [(i, task_list[i * chunk_size: (i + 1) * chunk_size], shared_param, args)
for i in range(args.nj)]
result_list = self.pool_run(subtask_list, args)
self.post(result_list, args)
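# ---- Usage sketch (illustrative, not part of the original file) ----
# The worker fn passed to __init__ receives one subtask tuple
# (worker_id, sub_task_list, shared_param, args) as built in run() above.
def _demo_process(task):
    worker_id, sub_tasks, shared_param, args = task
    return [t * t for t in sub_tasks]

class _DemoRunner(MultiProcessRunnerV3):
    def prepare(self, parser):
        args = parser.parse_args()
        return list(range(100)), None, args

    def post(self, result_list, args):
        # result_list holds one list per worker
        print(sum(sum(chunk) for chunk in result_list))

if __name__ == "__main__":
    _DemoRunner(_demo_process).run()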

funasr/utils/misc.py Normal file

@ -0,0 +1,48 @@
import io
from collections import OrderedDict
import numpy as np
def statistic_model_parameters(model, prefix=None):
var_dict = model.state_dict()
numel = 0
for key in sorted(x for x in var_dict.keys() if "num_batches_tracked" not in x):
if prefix is None or key.startswith(prefix):
numel += var_dict[key].numel()
return numel
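# e.g. (illustrative): statistic_model_parameters(model) counts all parameters;
# statistic_model_parameters(model, prefix="encoder.") counts only the encoder's.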
def int2vec(x, vec_dim=8, dtype=int):
b = ('{:0' + str(vec_dim) + 'b}').format(x)
# little-endian order: lower bit first
return (np.array(list(b)[::-1]) == '1').astype(dtype)
def seq2arr(seq, vec_dim=8):
return np.vstack([int2vec(int(x), vec_dim) for x in seq])
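# Worked example (illustrative): with vec_dim=4, the label 5 (binary 0101)
# decodes lower-bit-first into the activity vector [1, 0, 1, 0]:
#   int2vec(5, vec_dim=4)      -> array([1, 0, 1, 0])
#   seq2arr([5, 3], vec_dim=4) -> [[1, 0, 1, 0],
#                                  [1, 1, 0, 0]]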
def load_scp_as_dict(scp_path, value_type='str', kv_sep=" "):
with io.open(scp_path, 'r', encoding='utf-8') as f:
ret_dict = OrderedDict()
for one_line in f.readlines():
one_line = one_line.strip()
pos = one_line.find(kv_sep)
key, value = one_line[:pos], one_line[pos + 1:]
if value_type == 'list':
value = value.split(' ')
ret_dict[key] = value
return ret_dict
def load_scp_as_list(scp_path, value_type='str', kv_sep=" "):
with io.open(scp_path, 'r', encoding='utf-8') as f:
ret_list = []
for one_line in f.readlines():
one_line = one_line.strip()
pos = one_line.find(kv_sep)
key, value = one_line[:pos], one_line[pos + 1:]
if value_type == 'list':
value = value.split(' ')
ret_list.append((key, value))
return ret_list
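# Usage sketch (illustrative; the scp paths are hypothetical):
#   utt2feat = load_scp_as_dict("data/test/feats.scp")      # key -> "path"
#   utt2toks = load_scp_as_list("data/test/text", "list")   # [(key, [tok, ...]), ...]
# Note that each line is split at the first kv_sep only, so values may
# themselves contain further separators.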