mirror of https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Merge pull request #95 from alibaba-damo-academy/dev_dzh
Add sound model
commit 60aef2aa96
6  egs/alimeeting/diarization/sond/README.md  Normal file
@@ -0,0 +1,6 @@
# Results
You will get a DER of about 4.21%, as reported in [1], Table 6, line "SOND Oracle Profile".

# Reference
[1] Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis, Zhihao Du, Shiliang Zhang,
Siqi Zheng, Zhijie Yan. EMNLP 2022.
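For context, the DER above follows the standard definition also used by calc_diarization_error in funasr/models/e2e_diar_sond.py below: missed speech, false-alarm speech, and speaker-confusion time, all divided by total scored speaker time. A minimal sketch of that ratio (illustrative numbers, not AliMeeting results):

def der(miss, falarm, confusion, scored):
    # (missed + false alarm + confusion) speaker time over scored speaker time
    return (miss + falarm + confusion) / scored

print(der(1.2, 0.5, 0.4, 50.0))  # 0.042 -> 4.2% DER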
2740  egs/alimeeting/diarization/sond/config.yaml  Normal file
File diff suppressed because it is too large.
2728  egs/alimeeting/diarization/sond/config_fbank.yaml  Normal file
File diff suppressed because it is too large.
24  egs/alimeeting/diarization/sond/infer_alimeeting_test.py  Normal file
@@ -0,0 +1,24 @@
import sys

from funasr.bin.diar_inference_launch import inference_launch


def main():
    # Paths default to the files downloaded by run.sh stage 0.
    diar_config_path = sys.argv[1] if len(sys.argv) > 1 else "sond_fbank.yaml"
    diar_model_path = sys.argv[2] if len(sys.argv) > 2 else "sond.pth"
    output_dir = sys.argv[3] if len(sys.argv) > 3 else "./outputs"
    # (file, input name, file type) triples consumed by the data loader
    data_path_and_name_and_type = [
        ("data/test_rmsil/feats.scp", "speech", "kaldi_ark"),
        ("data/test_rmsil/test_rmsil_tdnn6_xvec.scp", "profile", "kaldi_ark"),
    ]
    pipeline = inference_launch(
        mode="sond",
        diar_train_config=diar_config_path,
        diar_model_file=diar_model_path,
        output_dir=output_dir,
        num_workers=1,
    )
    pipeline(data_path_and_name_and_type)


if __name__ == '__main__':
    main()
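The *.scp files passed above follow the Kaldi script-file convention: each line is an utterance ID followed by the location of its matrix inside a binary ark file. A hypothetical sketch using the same helper that local/convert_label_to_rttm.py imports (the entry shown is made up):

from funasr.utils.misc import load_scp_as_dict

# hypothetical line in feats.scp:
# "R8001_M8004-0000000-0001600 data/test_rmsil/feats.ark:41"
utt2ark = load_scp_as_dict("data/test_rmsil/feats.scp")
for utt_id, ark_pointer in list(utt2ark.items())[:3]:
    print(utt_id, "->", ark_pointer)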
132  egs/alimeeting/diarization/sond/local/convert_label_to_rttm.py  Normal file
@@ -0,0 +1,132 @@
import os
from collections import OrderedDict

import numpy as np
from scipy.ndimage import median_filter
from tqdm import tqdm

from funasr.utils.job_runner import MultiProcessRunnerV3
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict


class MyRunner(MultiProcessRunnerV3):
    def prepare(self, parser):
        parser.add_argument("label_txt", type=str)
        parser.add_argument("map_scp", type=str)
        parser.add_argument("out_rttm", type=str)
        parser.add_argument("--n_spk", type=int, default=4)
        parser.add_argument("--chunk_len", type=int, default=1600)
        parser.add_argument("--shift_len", type=int, default=400)
        parser.add_argument("--ignore_len", type=int, default=5)
        parser.add_argument("--smooth_size", type=int, default=7)
        parser.add_argument("--vote_prob", type=float, default=0.5)
        # assumption: process() reads args.sr but no --sr option was defined in
        # this parser; added here with the 16 kHz rate used by the recipe
        # (drop if MultiProcessRunnerV3 already defines it)
        parser.add_argument("--sr", type=int, default=16000)
        args = parser.parse_args()

        if not os.path.exists(os.path.dirname(args.out_rttm)):
            os.makedirs(os.path.dirname(args.out_rttm))

        utt2labels = load_scp_as_list(args.label_txt, 'list')
        utt2labels = sorted(utt2labels, key=lambda x: x[0])
        meeting2map = load_scp_as_dict(args.map_scp)
        meeting2labels = OrderedDict()
        for utt_id, chunk_label in utt2labels:
            mid = utt_id.split("-")[0]
            if mid not in meeting2labels:
                meeting2labels[mid] = []
            meeting2labels[mid].append(chunk_label)
        task_list = [(mid, labels, meeting2map[mid]) for mid, labels in meeting2labels.items()]

        return task_list, None, args

    def post(self, result_list, args):
        with open(args.out_rttm, "wt") as fd:
            for results in result_list:
                fd.writelines(results)


def int2vec(x, vec_dim=8, dtype=int):
    # the builtin int is used as dtype (np.int was removed in NumPy 1.24)
    b = ('{:0' + str(vec_dim) + 'b}').format(x)
    # little-endian order: lower bit first
    return (np.array(list(b)[::-1]) == '1').astype(dtype)


def seq2arr(seq, vec_dim=8):
    return np.row_stack([int2vec(int(x), vec_dim) for x in seq])


def sample2ms(sample, sr=16000):
    # converts a sample index to 10 ms frames (centiseconds), despite the name
    return int(float(sample) / sr * 100)


def calc_multi_labels(chunk_label_list, chunk_len, shift_len, n_spk, vote_prob=0.5):
    n_chunk = len(chunk_label_list)
    last_chunk_valid_frame = len(chunk_label_list[-1]) - (chunk_len - shift_len)
    n_frame = (n_chunk - 2) * shift_len + chunk_len + last_chunk_valid_frame
    multi_labels = np.zeros((n_frame, n_spk), dtype=float)
    weight = np.zeros((n_frame, 1), dtype=float)
    for i in range(n_chunk):
        raw_label = chunk_label_list[i]
        for k in range(len(raw_label)):
            if raw_label[k] == '<unk>':
                raw_label[k] = raw_label[k-1] if k > 0 else '0'
        chunk_multi_label = seq2arr(raw_label, n_spk)
        chunk_len = chunk_multi_label.shape[0]  # actual length of this chunk
        multi_labels[i*shift_len:i*shift_len+chunk_len, :] += chunk_multi_label
        weight[i*shift_len:i*shift_len+chunk_len, :] += 1
    multi_labels = multi_labels / weight  # normalizing vote
    multi_labels = (multi_labels > vote_prob).astype(int)  # voting results
    return multi_labels


def calc_spk_turns(label_arr, spk_list):
    turn_list = []
    length = label_arr.shape[0]
    n_spk = label_arr.shape[1]
    for k in range(n_spk):
        if spk_list[k] == "None":
            continue
        in_utt = False
        start = 0
        for i in range(length):
            if label_arr[i, k] == 1 and in_utt is False:
                start = i
                in_utt = True
            if label_arr[i, k] == 0 and in_utt is True:
                turn_list.append([spk_list[k], start, i - start])
                in_utt = False
        if in_utt:
            turn_list.append([spk_list[k], start, length - start])
    return turn_list


def smooth_multi_labels(multi_label, win_len):
    multi_label = median_filter(multi_label, (win_len, 1), mode="constant", cval=0.0).astype(int)
    return multi_label


def process(task_args):
    _, task_list, _, args = task_args
    spk_list = ["spk{}".format(i+1) for i in range(args.n_spk)]
    template = "SPEAKER {} 1 {:.2f} {:.2f} <NA> <NA> {} <NA> <NA>\n"
    results = []
    for mid, chunk_label_list, map_file_path in tqdm(task_list, total=len(task_list), ascii=True, disable=args.no_pbar):
        utt2map = load_scp_as_list(map_file_path, 'list')
        multi_labels = calc_multi_labels(chunk_label_list, args.chunk_len, args.shift_len, args.n_spk, args.vote_prob)
        multi_labels = smooth_multi_labels(multi_labels, args.smooth_size)
        org_len = sample2ms(int(utt2map[-1][1][1]), args.sr)
        org_multi_labels = np.zeros((org_len, args.n_spk))
        for seg_id, [org_st, org_ed, st, ed] in utt2map:
            org_st, org_dur = sample2ms(int(org_st), args.sr), sample2ms(int(org_ed) - int(org_st), args.sr)
            st, dur = sample2ms(int(st), args.sr), sample2ms(int(ed) - int(st), args.sr)
            ll = min(org_multi_labels[org_st: org_st+org_dur, :].shape[0], multi_labels[st: st+dur, :].shape[0])
            org_multi_labels[org_st: org_st+ll, :] = multi_labels[st: st+ll, :]
        spk_turns = calc_spk_turns(org_multi_labels, spk_list)
        spk_turns = sorted(spk_turns, key=lambda x: x[1])
        for spk, st, dur in spk_turns:
            # TODO: handle the leak of segments at the change points
            if dur > args.ignore_len:
                results.append(template.format(mid, float(st)/100, float(dur)/100, spk))
    return results


if __name__ == '__main__':
    my_runner = MyRunner(process)
    my_runner.run()
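The script above decodes SOND's power-set labels: each frame's token is an integer whose binary digits, read little-endian, say which speakers are active, so overlapped speech is a single symbol rather than several. A small worked example of int2vec (toy values):

import numpy as np

def int2vec(x, vec_dim=4):
    b = ('{:0' + str(vec_dim) + 'b}').format(x)
    return (np.array(list(b)[::-1]) == '1').astype(int)

print(int2vec(5))  # [1 0 1 0]: label 5 = 0b0101 -> speakers 1 and 3 active
print(int2vec(3))  # [1 1 0 0]: label 3 = 0b0011 -> speakers 1 and 2 overlap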
5  egs/alimeeting/diarization/sond/path.sh  Executable file
@@ -0,0 +1,5 @@
# repository root: four levels above egs/alimeeting/diarization/sond
export FUNASR_DIR=$PWD/../../../..

# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PATH=$FUNASR_DIR/funasr/bin:$PATH
48  egs/alimeeting/diarization/sond/run.sh  Normal file
@@ -0,0 +1,48 @@
#!/bin/bash

. ./path.sh || exit 1;

stage=0
stop_stage=2

. utils/parse_options.sh || exit 1;

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
  echo "Downloading AliMeeting test set data..."
  wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/alimeeting_test_data_for_sond.tar.gz
  echo "Done. Extracting data..."
  tar zxf alimeeting_test_data_for_sond.tar.gz
  echo "Done."

  echo "Downloading pre-trained models..."
  git clone https://www.modelscope.cn/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch.git
  git clone https://www.modelscope.cn/damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch.git
  ln -s speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth ./sv.pth
  cp speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.yaml ./sv.yaml
  ln -s speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/sond.pth ./sond.pth
  cp speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/sond_fbank.yaml ./sond_fbank.yaml
  cp speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/sond.yaml ./sond.yaml
  echo "Done."

  echo "Downloading dscore for scoring..."
  git clone https://github.com/nryant/dscore.git
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  echo "Calculating diarization results..."
  python infer_alimeeting_test.py sond_fbank.yaml sond.pth outputs
  python local/convert_label_to_rttm.py \
    outputs/labels.txt \
    data/test_rmsil/raw_rmsil_map.scp \
    outputs/prediction_sm_83.rttm \
    --ignore_len 10 --no_pbar --smooth_size 83 \
    --vote_prob 0.5 --n_spk 16
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  echo "Scoring..."
  python dscore/score.py \
    -r data/test_rmsil/test_org.crttm \
    -s outputs/prediction_sm_83.rttm \
    --collar 0.25
fi
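Stage 1 writes outputs/labels.txt (one line per chunk: the key plus per-frame power-set tokens), and convert_label_to_rttm.py merges the overlapping chunks by voting before smoothing and RTTM conversion. A toy sketch of that overlap-add vote, mirroring calc_multi_labels (made-up chunk values):

import numpy as np

# two hypothetical overlapping chunks of 4 frames, hop of 2, one speaker bit
chunks = [np.array([1, 1, 1, 0]), np.array([1, 0, 0, 0])]
shift = 2
votes, weight = np.zeros(6), np.zeros(6)
for i, c in enumerate(chunks):
    votes[i * shift:i * shift + len(c)] += c
    weight[i * shift:i * shift + len(c)] += 1
print((votes / weight > 0.5).astype(int))  # [1 1 1 0 0 0]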
97  egs/alimeeting/diarization/sond/unit_test.py  Normal file
@@ -0,0 +1,97 @@
import os

from funasr.bin.diar_inference_launch import inference_launch


def test_fbank_cpu_infer():
    # CPU inference on pre-computed fbank features
    diar_config_path = "config_fbank.yaml"
    diar_model_path = "sond.pth"
    output_dir = "./outputs"
    data_path_and_name_and_type = [
        ("data/unit_test/test_feats.scp", "speech", "kaldi_ark"),
        ("data/unit_test/test_profile.scp", "profile", "kaldi_ark"),
    ]
    pipeline = inference_launch(
        mode="sond",
        diar_train_config=diar_config_path,
        diar_model_file=diar_model_path,
        output_dir=output_dir,
        num_workers=1,
        log_level="WARNING",
    )
    results = pipeline(data_path_and_name_and_type)
    print(results)


def test_fbank_gpu_infer():
    # GPU inference on pre-computed fbank features
    diar_config_path = "config_fbank.yaml"
    diar_model_path = "sond.pth"
    output_dir = "./outputs"
    data_path_and_name_and_type = [
        ("data/unit_test/test_feats.scp", "speech", "kaldi_ark"),
        ("data/unit_test/test_profile.scp", "profile", "kaldi_ark"),
    ]
    pipeline = inference_launch(
        mode="sond",
        diar_train_config=diar_config_path,
        diar_model_file=diar_model_path,
        output_dir=output_dir,
        ngpu=1,
        num_workers=1,
        log_level="WARNING",
    )
    results = pipeline(data_path_and_name_and_type)
    print(results)


def test_wav_gpu_infer():
    # GPU inference directly on waveforms ("sound" type); config.yaml extracts fbank on the fly
    diar_config_path = "config.yaml"
    diar_model_path = "sond.pth"
    output_dir = "./outputs"
    data_path_and_name_and_type = [
        ("data/unit_test/test_wav.scp", "speech", "sound"),
        ("data/unit_test/test_profile.scp", "profile", "kaldi_ark"),
    ]
    pipeline = inference_launch(
        mode="sond",
        diar_train_config=diar_config_path,
        diar_model_file=diar_model_path,
        output_dir=output_dir,
        ngpu=1,
        num_workers=1,
        log_level="WARNING",
    )
    results = pipeline(data_path_and_name_and_type)
    print(results)


def test_without_profile_gpu_infer():
    # GPU inference without pre-computed profiles: "sond_demo" extracts
    # x-vectors from the enrollment waveforms with the sv model
    diar_config_path = "config.yaml"
    diar_model_path = "sond.pth"
    output_dir = "./outputs"
    raw_inputs = [[
        "data/unit_test/raw_inputs/record.wav",
        "data/unit_test/raw_inputs/spk1.wav",
        "data/unit_test/raw_inputs/spk2.wav",
        "data/unit_test/raw_inputs/spk3.wav",
        "data/unit_test/raw_inputs/spk4.wav"
    ]]
    pipeline = inference_launch(
        mode="sond_demo",
        diar_train_config=diar_config_path,
        diar_model_file=diar_model_path,
        output_dir=output_dir,
        ngpu=1,
        num_workers=1,
        log_level="WARNING",
        param_dict={},
    )
    results = pipeline(raw_inputs=raw_inputs)
    print(results)


if __name__ == '__main__':
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    test_fbank_cpu_infer()
    test_fbank_gpu_infer()
    test_wav_gpu_infer()
    test_without_profile_gpu_infer()
179  funasr/bin/diar_inference_launch.py  Executable file
@@ -0,0 +1,179 @@
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License (https://opensource.org/licenses/MIT)

import argparse
import logging
import os
import sys
from typing import Union, Dict, Any

from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none


def get_parser():
    parser = config_argparse.ArgumentParser(
        description="Speaker Diarization",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Note(kamo): Use '_' instead of '-' as separator.
    # '-' is confusing if written in yaml.
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )

    parser.add_argument("--output_dir", type=str, required=False)
    parser.add_argument(
        "--ngpu",
        type=int,
        default=0,
        help="The number of gpus. 0 indicates CPU mode",
    )
    parser.add_argument(
        "--njob",
        type=int,
        default=1,
        help="The number of jobs for each gpu",
    )
    parser.add_argument(
        "--gpuid_list",
        type=str,
        default="",
        help="The visible gpus",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="The number of workers used for DataLoader",
    )

    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        required=False,
        action="append",
    )
    group.add_argument("--key_file", type=str_or_none)
    group.add_argument("--allow_variable_data_keys", type=str2bool, default=True)

    group = parser.add_argument_group("The model configuration related")
    group.add_argument(
        "--vad_infer_config",
        type=str,
        help="VAD infer configuration",
    )
    group.add_argument(
        "--vad_model_file",
        type=str,
        help="VAD model parameter file",
    )
    group.add_argument(
        "--diar_train_config",
        type=str,
        help="Diarization training configuration",
    )
    group.add_argument(
        "--diar_model_file",
        type=str,
        help="Diarization model parameter file",
    )
    group.add_argument(
        "--cmvn_file",
        type=str,
        help="Global CMVN file",
    )
    group.add_argument(
        "--model_tag",
        type=str,
        help="Pretrained model tag. If this option is specified, *_train_config and "
        "*_file will be overwritten",
    )

    group = parser.add_argument_group("The inference configuration related")
    group.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="The batch size for inference",
    )
    group.add_argument(
        "--diar_smooth_size",
        type=int,
        default=121,
        help="The smoothing size for post-processing"
    )

    return parser


def inference_launch(mode, **kwargs):
    if mode == "sond":
        from funasr.bin.sond_inference import inference_modelscope
        return inference_modelscope(**kwargs)
    elif mode == "sond_demo":
        # demo mode: extract speaker profiles on the fly with the sv model
        from funasr.bin.sond_inference import inference_modelscope
        param_dict = {
            "extract_profile": True,
            "sv_train_config": "sv.yaml",
            "sv_model_file": "sv.pth",
        }
        if "param_dict" in kwargs:
            kwargs["param_dict"].update(param_dict)
        else:
            kwargs["param_dict"] = param_dict
        return inference_modelscope(**kwargs)
    else:
        logging.info("Unknown decoding mode: {}".format(mode))
        return None


def main(cmd=None):
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    parser.add_argument(
        "--mode",
        type=str,
        default="sond",
        help="The decoding mode",
    )
    args = parser.parse_args(cmd)
    kwargs = vars(args)
    kwargs.pop("config", None)

    # set logging messages
    logging.basicConfig(
        level=args.log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    logging.info("Decoding args: {}".format(kwargs))

    # gpu setting: the numeric suffix of output_dir (e.g. exp/out.3) encodes the
    # job id; every njob consecutive jobs share one gpu from gpuid_list
    if args.ngpu > 0:
        jobid = int(args.output_dir.split(".")[-1])
        gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = gpuid

    inference_launch(**kwargs)


if __name__ == "__main__":
    main()
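The job-to-GPU mapping in main() assumes output_dir carries a numeric suffix (e.g. outputs.3 for job 3) and packs njob consecutive jobs onto each GPU from gpuid_list; a quick illustration with hypothetical values:

gpuid_list, njob = "0,1", 2
for jobid in (1, 2, 3, 4):
    gpuid = gpuid_list.split(",")[(jobid - 1) // njob]
    print("job", jobid, "-> GPU", gpuid)  # jobs 1-2 -> GPU 0, jobs 3-4 -> GPU 1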
544  funasr/bin/sond_inference.py  Executable file
@@ -0,0 +1,544 @@
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License (https://opensource.org/licenses/MIT)

import argparse
import logging
import os
import sys
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union

from collections import OrderedDict
import numpy as np
import soundfile
import torch
from torch.nn import functional as F
from typeguard import check_argument_types
from typeguard import check_return_type

from funasr.utils.cli_utils import get_commandline_args
from funasr.tasks.diar import DiarTask
from funasr.tasks.asr import ASRTask
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from scipy.ndimage import median_filter
from funasr.utils.misc import statistic_model_parameters


class Speech2Diarization:
    """Speech2Diarization class

    Examples:
        >>> import soundfile
        >>> import numpy as np
        >>> speech2diar = Speech2Diarization("diar_sond_config.yml", "diar_sond.pth")
        >>> profile = np.load("profiles.npy")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> speech2diar(audio, profile)
        {"spk1": [(int, int), ...], ...}

    """

    def __init__(
        self,
        diar_train_config: Union[Path, str] = None,
        diar_model_file: Union[Path, str] = None,
        device: str = "cpu",
        batch_size: int = 1,
        dtype: str = "float32",
        streaming: bool = False,
        smooth_size: int = 83,
        dur_threshold: float = 10,
    ):
        assert check_argument_types()

        # TODO: 1. Build Diarization model
        diar_model, diar_train_args = DiarTask.build_model_from_file(
            config_file=diar_train_config,
            model_file=diar_model_file,
            device=device
        )
        logging.info("diar_model: {}".format(diar_model))
        logging.info("model parameter number: {}".format(statistic_model_parameters(diar_model)))
        logging.info("diar_train_args: {}".format(diar_train_args))
        diar_model.to(dtype=getattr(torch, dtype)).eval()

        self.diar_model = diar_model
        self.diar_train_args = diar_train_args
        self.token_list = diar_train_args.token_list
        self.smooth_size = smooth_size
        self.dur_threshold = dur_threshold
        self.device = device
        self.dtype = dtype

    def smooth_multi_labels(self, multi_label):
        multi_label = median_filter(multi_label, (self.smooth_size, 1), mode="constant", cval=0.0).astype(int)
        return multi_label

    @staticmethod
    def calc_spk_turns(label_arr, spk_list):
        turn_list = []
        length = label_arr.shape[0]
        n_spk = label_arr.shape[1]
        for k in range(n_spk):
            if spk_list[k] == "None":
                continue
            in_utt = False
            start = 0
            for i in range(length):
                if label_arr[i, k] == 1 and in_utt is False:
                    start = i
                    in_utt = True
                if label_arr[i, k] == 0 and in_utt is True:
                    turn_list.append([spk_list[k], start, i - start])
                    in_utt = False
            if in_utt:
                turn_list.append([spk_list[k], start, length - start])
        return turn_list

    @staticmethod
    def seq2arr(seq, vec_dim=8):
        def int2vec(x, vec_dim=8, dtype=int):
            # the builtin int is used as dtype (np.int was removed in NumPy 1.24)
            b = ('{:0' + str(vec_dim) + 'b}').format(x)
            # little-endian order: lower bit first
            return (np.array(list(b)[::-1]) == '1').astype(dtype)

        return np.row_stack([int2vec(int(x), vec_dim) for x in seq])

    def post_processing(self, raw_logits: torch.Tensor, spk_num: int):
        logits_idx = raw_logits.argmax(-1)  # B, T, vocab_size -> B, T
        # upsampling outputs to match inputs
        # (F.interpolate is the non-deprecated spelling of F.upsample)
        ut = logits_idx.shape[1] * self.diar_model.encoder.time_ds_ratio
        logits_idx = F.interpolate(
            logits_idx.unsqueeze(1).float(),
            size=(ut, ),
            mode="nearest",
        ).squeeze(1).long()
        logits_idx = logits_idx[0].tolist()
        pse_labels = [self.token_list[x] for x in logits_idx]
        multi_labels = self.seq2arr(pse_labels, spk_num)[:, :spk_num]  # remove padding speakers
        multi_labels = self.smooth_multi_labels(multi_labels)
        spk_list = ["spk{}".format(i + 1) for i in range(spk_num)]
        spk_turns = self.calc_spk_turns(multi_labels, spk_list)
        results = OrderedDict()
        for spk, st, dur in spk_turns:
            if spk not in results:
                results[spk] = []
            if dur > self.dur_threshold:
                results[spk].append((st, st+dur))

        # sort segments in ascending order of start time
        for spk in results:
            results[spk] = sorted(results[spk], key=lambda x: x[0])

        return results, pse_labels

    @torch.no_grad()
    def __call__(
        self,
        speech: Union[torch.Tensor, np.ndarray],
        profile: Union[torch.Tensor, np.ndarray],
    ):
        """Inference

        Args:
            speech: Input speech data
            profile: Speaker profiles
        Returns:
            diarization results for each speaker

        """
        assert check_argument_types()
        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        if isinstance(profile, np.ndarray):
            profile = torch.tensor(profile)

        # data: (Nsamples,) -> (1, Nsamples)
        speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
        profile = profile.unsqueeze(0).to(getattr(torch, self.dtype))
        # lengths: (1,)
        speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
        profile_lengths = profile.new_full([1], dtype=torch.long, fill_value=profile.size(1))
        batch = {"speech": speech, "speech_lengths": speech_lengths,
                 "profile": profile, "profile_lengths": profile_lengths}
        # a. To device
        batch = to_device(batch, device=self.device)

        logits = self.diar_model.prediction_forward(**batch)
        results, pse_labels = self.post_processing(logits, profile.shape[1])

        return results, pse_labels

    @staticmethod
    def from_pretrained(
        model_tag: Optional[str] = None,
        **kwargs: Optional[Any],
    ):
        """Build a Speech2Diarization instance from a pretrained model.

        Args:
            model_tag (Optional[str]): Model tag of the pretrained models.
                Currently, the tags of espnet_model_zoo are supported.

        Returns:
            Speech2Diarization: Speech2Diarization instance.

        """
        if model_tag is not None:
            try:
                from espnet_model_zoo.downloader import ModelDownloader

            except ImportError:
                logging.error(
                    "`espnet_model_zoo` is not installed. "
                    "Please install via `pip install -U espnet_model_zoo`."
                )
                raise
            d = ModelDownloader()
            kwargs.update(**d.download_and_unpack(model_tag))

        return Speech2Diarization(**kwargs)


def inference_modelscope(
    diar_train_config: str,
    diar_model_file: str,
    output_dir: Optional[str] = None,
    batch_size: int = 1,
    dtype: str = "float32",
    ngpu: int = 0,
    seed: int = 0,
    num_workers: int = 0,
    log_level: Union[int, str] = "INFO",
    key_file: Optional[str] = None,
    model_tag: Optional[str] = None,
    allow_variable_data_keys: bool = True,
    streaming: bool = False,
    smooth_size: int = 83,
    dur_threshold: int = 10,
    out_format: str = "vad",
    param_dict: Optional[dict] = None,
    **kwargs,
):
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    logging.info("param_dict: {}".format(param_dict))

    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2a. Build speech2xvec [Optional]
    if param_dict is not None and "extract_profile" in param_dict and param_dict["extract_profile"]:
        assert "sv_train_config" in param_dict, "sv_train_config must be provided in param_dict."
        assert "sv_model_file" in param_dict, "sv_model_file must be provided in param_dict."
        sv_train_config = param_dict["sv_train_config"]
        sv_model_file = param_dict["sv_model_file"]
        from funasr.bin.sv_inference import Speech2Xvector
        speech2xvector_kwargs = dict(
            sv_train_config=sv_train_config,
            sv_model_file=sv_model_file,
            device=device,
            dtype=dtype,
            streaming=streaming,
            embedding_node="resnet1_dense"
        )
        logging.info("speech2xvector_kwargs: {}".format(speech2xvector_kwargs))
        speech2xvector = Speech2Xvector.from_pretrained(
            model_tag=model_tag,
            **speech2xvector_kwargs,
        )
        speech2xvector.sv_model.eval()

    # 2b. Build speech2diar
    speech2diar_kwargs = dict(
        diar_train_config=diar_train_config,
        diar_model_file=diar_model_file,
        device=device,
        dtype=dtype,
        streaming=streaming,
        smooth_size=smooth_size,
        dur_threshold=dur_threshold,
    )
    logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
    speech2diar = Speech2Diarization.from_pretrained(
        model_tag=model_tag,
        **speech2diar_kwargs,
    )
    speech2diar.diar_model.eval()

    def output_results_str(results: dict, uttid: str):
        rst = []
        mid = uttid.rsplit("-", 1)[0]
        for key in results:
            results[key] = [(x[0]/100, x[1]/100) for x in results[key]]
        if out_format == "vad":
            for spk, segs in results.items():
                rst.append("{} {}".format(spk, segs))
        else:
            template = "SPEAKER {} 0 {:.2f} {:.2f} <NA> <NA> {} <NA> <NA>"
            for spk, segs in results.items():
                rst.extend([template.format(mid, st, ed, spk) for st, ed in segs])

        return "\n".join(rst)

    def _forward(
        data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
        raw_inputs: List[List[Union[np.ndarray, torch.Tensor, str]]] = None,
        output_dir_v2: Optional[str] = None,
        param_dict: Optional[dict] = None,
    ):
        logging.info("param_dict: {}".format(param_dict))
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, (list, tuple)):
                assert all([len(example) >= 2 for example in raw_inputs]), \
                    "Each test case in raw_inputs must contain at least two items " \
                    "(>=2): the recording and one or more profile waveforms."

                def prepare_dataset():
                    for idx, example in enumerate(raw_inputs):
                        # read each waveform given as a file path
                        example = [soundfile.read(x)[0] if isinstance(x, str) else x
                                   for x in example]
                        # convert torch tensors to numpy arrays
                        example = [x.numpy() if isinstance(x, torch.Tensor) else x
                                   for x in example]
                        speech = example[0]
                        logging.info("Extracting profiles for {} waveforms".format(len(example)-1))
                        profile = [speech2xvector.calculate_embedding(x) for x in example[1:]]
                        profile = torch.cat(profile, dim=0)
                        yield ["test{}".format(idx)], {"speech": [speech], "profile": [profile]}

                loader = prepare_dataset()
            else:
                raise TypeError("raw_inputs must be a list or tuple in [speech, profile1, profile2, ...] ")
        else:
            # 3. Build data-iterator
            loader = ASRTask.build_streaming_iterator(
                data_path_and_name_and_type,
                dtype=dtype,
                batch_size=batch_size,
                key_file=key_file,
                num_workers=num_workers,
                preprocess_fn=None,
                collate_fn=None,
                allow_variable_data_keys=allow_variable_data_keys,
                inference=True,
            )

        # 7. Start for-loop
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path is not None:
            os.makedirs(output_path, exist_ok=True)
            output_writer = open("{}/result.txt".format(output_path), "w")
            pse_label_writer = open("{}/labels.txt".format(output_path), "w")
        logging.info("Start to diarize...")
        result_list = []
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}

            results, pse_labels = speech2diar(**batch)
            # Only supporting batch_size==1
            key, value = keys[0], output_results_str(results, keys[0])
            item = {"key": key, "value": value}
            result_list.append(item)
            if output_path is not None:
                # terminate each utterance record with a newline so successive
                # utterances do not run together in result.txt
                output_writer.write(value + "\n")
                output_writer.flush()
                pse_label_writer.write("{} {}\n".format(key, " ".join(pse_labels)))
                pse_label_writer.flush()

        if output_path is not None:
            output_writer.close()
            pse_label_writer.close()

        return result_list

    return _forward


def inference(
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    diar_train_config: Optional[str],
    diar_model_file: Optional[str],
    output_dir: Optional[str] = None,
    batch_size: int = 1,
    dtype: str = "float32",
    ngpu: int = 0,
    seed: int = 0,
    num_workers: int = 1,
    log_level: Union[int, str] = "INFO",
    key_file: Optional[str] = None,
    model_tag: Optional[str] = None,
    allow_variable_data_keys: bool = True,
    streaming: bool = False,
    smooth_size: int = 83,
    dur_threshold: int = 10,
    out_format: str = "vad",
    **kwargs,
):
    inference_pipeline = inference_modelscope(
        diar_train_config=diar_train_config,
        diar_model_file=diar_model_file,
        output_dir=output_dir,
        batch_size=batch_size,
        dtype=dtype,
        ngpu=ngpu,
        seed=seed,
        num_workers=num_workers,
        log_level=log_level,
        key_file=key_file,
        model_tag=model_tag,
        allow_variable_data_keys=allow_variable_data_keys,
        streaming=streaming,
        smooth_size=smooth_size,
        dur_threshold=dur_threshold,
        out_format=out_format,
        **kwargs,
    )

    return inference_pipeline(data_path_and_name_and_type, raw_inputs=None)


def get_parser():
    parser = config_argparse.ArgumentParser(
        description="Speaker diarization inference",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Note(kamo): Use '_' instead of '-' as separator.
    # '-' is confusing if written in yaml.
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )

    parser.add_argument("--output_dir", type=str, required=False)
    parser.add_argument(
        "--ngpu",
        type=int,
        default=0,
        help="The number of gpus. 0 indicates CPU mode",
    )
    parser.add_argument(
        "--gpuid_list",
        type=str,
        default="",
        help="The visible gpus",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="The number of workers used for DataLoader",
    )

    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        required=False,
        action="append",
    )
    group.add_argument("--key_file", type=str_or_none)
    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)

    group = parser.add_argument_group("The model configuration related")
    group.add_argument(
        "--diar_train_config",
        type=str,
        help="Diarization training configuration",
    )
    group.add_argument(
        "--diar_model_file",
        type=str,
        help="Diarization model parameter file",
    )
    group.add_argument(
        "--dur_threshold",
        type=int,
        default=10,
        help="Threshold for dropping short segments, in frames"
    )
    group.add_argument(
        "--smooth_size",
        type=int,
        default=83,
        help="Smoothing window length in frames"
    )
    group.add_argument(
        "--model_tag",
        type=str,
        help="Pretrained model tag. If this option is specified, *_train_config and "
        "*_file will be overwritten",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="The batch size for inference",
    )
    parser.add_argument("--streaming", type=str2bool, default=False)

    return parser


def main(cmd=None):
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    args = parser.parse_args(cmd)
    kwargs = vars(args)
    kwargs.pop("config", None)
    logging.info("args: {}".format(kwargs))
    if args.output_dir is None:
        jobid, n_gpu = 1, 1
        gpuid = args.gpuid_list.split(",")[jobid-1]
    else:
        # job id is encoded as the numeric suffix of output_dir (e.g. exp/out.3)
        jobid = int(args.output_dir.split(".")[-1])
        n_gpu = len(args.gpuid_list.split(","))
        gpuid = args.gpuid_list.split(",")[(jobid - 1) % n_gpu]
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
    results_list = inference(**kwargs)
    for results in results_list:
        print("{} {}".format(results["key"], results["value"]))


if __name__ == "__main__":
    main()
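For out_format="vad" the writer above prints, per speaker, the segment list in seconds (segment indices are in 10 ms frames, hence the division by 100). A toy rendering with made-up segments:

from collections import OrderedDict

results = OrderedDict(spk1=[(0, 320), (500, 780)], spk2=[(300, 510)])  # 10 ms frames
for spk, segs in results.items():
    segs = [(st / 100, ed / 100) for st, ed in segs]
    print("{} {}".format(spk, segs))
# spk1 [(0.0, 3.2), (5.0, 7.8)]
# spk2 [(3.0, 5.1)]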
funasr/bin/sv_inference.py
@@ -1,4 +1,7 @@
 #!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License (https://opensource.org/licenses/MIT)
+
 import argparse
 import logging
 import os
@@ -26,7 +29,7 @@ from funasr.utils import config_argparse
 from funasr.utils.types import str2bool
 from funasr.utils.types import str2triple_str
 from funasr.utils.types import str_or_none
-
+from funasr.utils.misc import statistic_model_parameters

 class Speech2Xvector:
     """Speech2Xvector class
@@ -59,6 +62,7 @@ class Speech2Xvector:
             device=device
         )
         logging.info("sv_model: {}".format(sv_model))
+        logging.info("model parameter number: {}".format(statistic_model_parameters(sv_model)))
         logging.info("sv_train_args: {}".format(sv_train_args))
         sv_model.to(dtype=getattr(torch, dtype)).eval()

@@ -156,17 +160,17 @@ class Speech2Xvector:


 def inference_modelscope(
-    output_dir: Optional[str],
-    batch_size: int,
-    dtype: str,
-    ngpu: int,
-    seed: int,
-    num_workers: int,
-    log_level: Union[int, str],
-    key_file: Optional[str],
-    sv_train_config: Optional[str],
-    sv_model_file: Optional[str],
-    model_tag: Optional[str],
+    output_dir: Optional[str] = None,
+    batch_size: int = 1,
+    dtype: str = "float32",
+    ngpu: int = 1,
+    seed: int = 0,
+    num_workers: int = 0,
+    log_level: Union[int, str] = "INFO",
+    key_file: Optional[str] = None,
+    sv_train_config: Optional[str] = "sv.yaml",
+    sv_model_file: Optional[str] = "sv.pth",
+    model_tag: Optional[str] = None,
     allow_variable_data_keys: bool = True,
     streaming: bool = False,
     embedding_node: str = "resnet1_dense",
@@ -214,7 +218,6 @@ def inference_modelscope(
     data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
     raw_inputs: Union[np.ndarray, torch.Tensor] = None,
     output_dir_v2: Optional[str] = None,
-    fs: dict = None,
     param_dict: Optional[dict] = None,
 ):
     logging.info("param_dict: {}".format(param_dict))
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
-# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
-#  Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License (https://opensource.org/licenses/MIT)

 import argparse
 import logging
402  funasr/models/e2e_diar_sond.py  Normal file
@@ -0,0 +1,402 @@
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License (https://opensource.org/licenses/MIT)

from contextlib import contextmanager
from distutils.version import LooseVersion
from itertools import permutations
from typing import Dict
from typing import Optional
from typing import Tuple

import numpy as np
import torch
from torch.nn import functional as F
from typeguard import check_argument_types

from funasr.modules.nets_utils import to_device
from funasr.modules.nets_utils import make_pad_mask
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield


class DiarSondModel(AbsESPnetModel):
    """Speaker overlap-aware neural diarization model
    reference: https://arxiv.org/abs/2211.10243
    """

    def __init__(
        self,
        vocab_size: int,
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        encoder: AbsEncoder,
        speaker_encoder: AbsEncoder,
        ci_scorer: torch.nn.Module,
        cd_scorer: torch.nn.Module,
        decoder: torch.nn.Module,
        token_list: list,
        lsm_weight: float = 0.1,
        length_normalized_loss: bool = False,
        max_spk_num: int = 16,
        label_aggregator: Optional[torch.nn.Module] = None,
        normlize_speech_speaker: bool = False,
    ):
        assert check_argument_types()

        super().__init__()

        self.encoder = encoder
        self.speaker_encoder = speaker_encoder
        self.ci_scorer = ci_scorer
        self.cd_scorer = cd_scorer
        self.normalize = normalize
        self.frontend = frontend
        self.specaug = specaug
        self.label_aggregator = label_aggregator
        self.decoder = decoder
        self.token_list = token_list
        self.max_spk_num = max_spk_num
        self.normalize_speech_speaker = normlize_speech_speaker

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor = None,
        profile: torch.Tensor = None,
        profile_lengths: torch.Tensor = None,
        spk_labels: torch.Tensor = None,
        spk_labels_lengths: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Speaker Encoder + CI Scorer + CD Scorer + Decoder + Calc loss

        Args:
            speech: (Batch, samples)
            speech_lengths: (Batch,); default None for the chunk iterator,
                which does not return speech_lengths
                (see espnet2/iterators/chunk_iter_factory.py)
            profile: (Batch, N_spk, dim)
            profile_lengths: (Batch,)
            spk_labels: (Batch, )
        """
        assert speech.shape[0] == spk_labels.shape[0], (speech.shape, spk_labels.shape)
        batch_size = speech.shape[0]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        # NOTE: this EEND-style branch references self.attractor, self.pit_loss,
        # self.attractor_loss and self.attractor_weight, none of which are
        # initialized in __init__ above; they must be supplied externally.
        if self.attractor is None:
            # 2a. Decoder (basically a prediction layer after encoder_out)
            pred = self.decoder(encoder_out, encoder_out_lens)
        else:
            # 2b. Encoder Decoder Attractors
            # Shuffle the chronological order of encoder_out, then calculate attractor
            encoder_out_shuffled = encoder_out.clone()
            for i in range(len(encoder_out_lens)):
                encoder_out_shuffled[i, : encoder_out_lens[i], :] = encoder_out[
                    i, torch.randperm(encoder_out_lens[i]), :
                ]
            attractor, att_prob = self.attractor(
                encoder_out_shuffled,
                encoder_out_lens,
                to_device(
                    self,
                    torch.zeros(
                        encoder_out.size(0), spk_labels.size(2) + 1, encoder_out.size(2)
                    ),
                ),
            )
            # Remove the final attractor which does not correspond to a speaker
            # Then multiply the attractors and encoder_out
            pred = torch.bmm(encoder_out, attractor[:, :-1, :].permute(0, 2, 1))
        # 3. Aggregate time-domain labels
        if self.label_aggregator is not None:
            spk_labels, spk_labels_lengths = self.label_aggregator(
                spk_labels, spk_labels_lengths
            )

        # If encoder uses conv* as input_layer (i.e., subsampling),
        # the sequence length of 'pred' might be slightly less than the
        # length of 'spk_labels'. Here we force them to be equal.
        length_diff_tolerance = 2
        length_diff = spk_labels.shape[1] - pred.shape[1]
        if length_diff > 0 and length_diff <= length_diff_tolerance:
            spk_labels = spk_labels[:, 0 : pred.shape[1], :]

        if self.attractor is None:
            loss_pit, loss_att = None, None
            loss, perm_idx, perm_list, label_perm = self.pit_loss(
                pred, spk_labels, encoder_out_lens
            )
        else:
            loss_pit, perm_idx, perm_list, label_perm = self.pit_loss(
                pred, spk_labels, encoder_out_lens
            )
            loss_att = self.attractor_loss(att_prob, spk_labels)
            loss = loss_pit + self.attractor_weight * loss_att
        (
            correct,
            num_frames,
            speech_scored,
            speech_miss,
            speech_falarm,
            speaker_scored,
            speaker_miss,
            speaker_falarm,
            speaker_error,
        ) = self.calc_diarization_error(pred, label_perm, encoder_out_lens)

        if speech_scored > 0 and num_frames > 0:
            sad_mr, sad_fr, mi, fa, cf, acc, der = (
                speech_miss / speech_scored,
                speech_falarm / speech_scored,
                speaker_miss / speaker_scored,
                speaker_falarm / speaker_scored,
                speaker_error / speaker_scored,
                correct / num_frames,
                (speaker_miss + speaker_falarm + speaker_error) / speaker_scored,
            )
        else:
            sad_mr, sad_fr, mi, fa, cf, acc, der = 0, 0, 0, 0, 0, 0, 0

        stats = dict(
            loss=loss.detach(),
            loss_att=loss_att.detach() if loss_att is not None else None,
            loss_pit=loss_pit.detach() if loss_pit is not None else None,
            sad_mr=sad_mr,
            sad_fr=sad_fr,
            mi=mi,
            fa=fa,
            cf=cf,
            acc=acc,
            der=der,
        )

        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        spk_labels: torch.Tensor = None,
        spk_labels_lengths: torch.Tensor = None,
    ) -> Dict[str, torch.Tensor]:
        feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode_speaker(
        self,
        profile: torch.Tensor,
        profile_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        with autocast(False):
            if profile.shape[1] < self.max_spk_num:
                profile = F.pad(profile, [0, 0, 0, self.max_spk_num-profile.shape[1], 0, 0], "constant", 0.0)
            profile_mask = (torch.linalg.norm(profile, ord=2, dim=2, keepdim=True) > 0).float()
            profile = F.normalize(profile, dim=2)
            if self.speaker_encoder is not None:
                profile = self.speaker_encoder(profile, profile_lengths)[0]
                return profile * profile_mask, profile_lengths
            else:
                return profile, profile_lengths

    def encode_speech(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.encoder is not None:
            speech, speech_lengths = self.encode(speech, speech_lengths)
            speech_mask = ~make_pad_mask(speech_lengths, maxlen=speech.shape[1])
            speech_mask = speech_mask.to(speech.device).unsqueeze(-1).float()
            return speech * speech_mask, speech_lengths
        else:
            return speech, speech_lengths

    @staticmethod
    def concate_speech_ivc(
        speech: torch.Tensor,
        ivc: torch.Tensor
    ) -> torch.Tensor:
        nn, tt = ivc.shape[1], speech.shape[1]
        speech = speech.unsqueeze(dim=1)  # B x 1 x T x D
        speech = speech.expand(-1, nn, -1, -1)  # B x N x T x D
        ivc = ivc.unsqueeze(dim=2)  # B x N x 1 x D
        ivc = ivc.expand(-1, -1, tt, -1)  # B x N x T x D
        sd_in = torch.cat([speech, ivc], dim=3)  # B x N x T x 2D
        return sd_in

    def calc_similarity(
        self,
        speech_encoder_outputs: torch.Tensor,
        speaker_encoder_outputs: torch.Tensor,
        seq_len: torch.Tensor = None,
        spk_len: torch.Tensor = None,
    ) -> torch.Tensor:
        bb, tt = speech_encoder_outputs.shape[0], speech_encoder_outputs.shape[1]
        d_sph, d_spk = speech_encoder_outputs.shape[2], speaker_encoder_outputs.shape[2]
        if self.normalize_speech_speaker:
            speech_encoder_outputs = F.normalize(speech_encoder_outputs, dim=2)
            speaker_encoder_outputs = F.normalize(speaker_encoder_outputs, dim=2)
        ge_in = self.concate_speech_ivc(speech_encoder_outputs, speaker_encoder_outputs)
        ge_in = torch.reshape(ge_in, [bb * self.max_spk_num, tt, d_sph + d_spk])
        ge_len = seq_len.unsqueeze(1).expand(-1, self.max_spk_num)
        ge_len = torch.reshape(ge_len, [bb * self.max_spk_num])
        cd_simi = self.cd_scorer(ge_in, ge_len)[0]
        cd_simi = torch.reshape(cd_simi, [bb, self.max_spk_num, tt, 1])
        cd_simi = cd_simi.squeeze(dim=3).permute([0, 2, 1])

        if isinstance(self.ci_scorer, AbsEncoder):
            ci_simi = self.ci_scorer(ge_in, ge_len)[0]
        else:
            ci_simi = self.ci_scorer(speech_encoder_outputs, speaker_encoder_outputs)
        simi = torch.cat([cd_simi, ci_simi], dim=2)

        return simi

    def post_net_forward(self, simi, seq_len):
        logits = self.decoder(simi, seq_len)[0]

        return logits

    def prediction_forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        profile: torch.Tensor,
        profile_lengths: torch.Tensor,
    ) -> torch.Tensor:
        # speech encoding
        speech, speech_lengths = self.encode_speech(speech, speech_lengths)
        # speaker encoding
        profile, profile_lengths = self.encode_speaker(profile, profile_lengths)
        # calculating similarity
        similarity = self.calc_similarity(speech, profile, speech_lengths, profile_lengths)
        # post net forward
        logits = self.post_net_forward(similarity, speech_lengths)

        return logits

    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch,)
        """
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # 4. Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim)
        encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)

        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        assert encoder_out.size(1) <= encoder_out_lens.max(), (
            encoder_out.size(),
            encoder_out_lens.max(),
        )

        return encoder_out, encoder_out_lens

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = speech.shape[0]
        speech_lengths = (
            speech_lengths
            if speech_lengths is not None
            else torch.ones(batch_size).int() * speech.shape[1]
        )

        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            # Frontend
            # e.g. STFT and Feature extract
            # data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths

    @staticmethod
    def calc_diarization_error(pred, label, length):
        # Note (jiatong): Credit to https://github.com/hitachi-speech/EEND

        (batch_size, max_len, num_output) = label.size()
        # mask the padding part
        mask = np.zeros((batch_size, max_len, num_output))
        for i in range(batch_size):
            mask[i, : length[i], :] = 1

        # pred and label have the shape (batch_size, max_len, num_output)
        label_np = label.data.cpu().numpy().astype(int)
        pred_np = (pred.data.cpu().numpy() > 0).astype(int)
        label_np = label_np * mask
        pred_np = pred_np * mask
        length = length.data.cpu().numpy()

        # compute speech activity detection error
        n_ref = np.sum(label_np, axis=2)
        n_sys = np.sum(pred_np, axis=2)
        speech_scored = float(np.sum(n_ref > 0))
        speech_miss = float(np.sum(np.logical_and(n_ref > 0, n_sys == 0)))
        speech_falarm = float(np.sum(np.logical_and(n_ref == 0, n_sys > 0)))

        # compute speaker diarization error
        speaker_scored = float(np.sum(n_ref))
        speaker_miss = float(np.sum(np.maximum(n_ref - n_sys, 0)))
        speaker_falarm = float(np.sum(np.maximum(n_sys - n_ref, 0)))
        n_map = np.sum(np.logical_and(label_np == 1, pred_np == 1), axis=2)
        speaker_error = float(np.sum(np.minimum(n_ref, n_sys) - n_map))
        correct = float(1.0 * np.sum((label_np == pred_np) * mask) / num_output)
        num_frames = np.sum(length)
        return (
            correct,
            num_frames,
            speech_scored,
            speech_miss,
            speech_falarm,
            speaker_scored,
            speaker_miss,
            speaker_falarm,
            speaker_error,
        )
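concate_speech_ivc above tiles the T-frame speech encoding against the N speaker profiles so that every (speaker, frame) pair sees both vectors before the CD scorer; a quick shape check with toy sizes:

import torch

B, N, T, D = 2, 16, 100, 256
speech, ivc = torch.randn(B, T, D), torch.randn(B, N, D)
speech = speech.unsqueeze(1).expand(-1, N, -1, -1)  # B x N x T x D
ivc = ivc.unsqueeze(2).expand(-1, -1, T, -1)        # B x N x T x D
sd_in = torch.cat([speech, ivc], dim=3)             # B x N x T x 2D
print(sd_in.shape)  # torch.Size([2, 16, 100, 512])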
0  funasr/models/encoder/opennmt_encoders/__init__.py  Normal file
38  funasr/models/encoder/opennmt_encoders/ci_scorers.py  Normal file
@ -0,0 +1,38 @@
import torch
from torch.nn import functional as F


class DotScorer(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(
        self,
        xs_pad: torch.Tensor,
        spk_emb: torch.Tensor,
    ):
        # xs_pad: B, T, D
        # spk_emb: B, N, D
        scores = torch.matmul(xs_pad, spk_emb.transpose(1, 2))
        return scores

    def convert_tf2torch(self, var_dict_tf, var_dict_torch):
        return {}


class CosScorer(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(
        self,
        xs_pad: torch.Tensor,
        spk_emb: torch.Tensor,
    ):
        # xs_pad: B, T, D
        # spk_emb: B, N, D
        scores = F.cosine_similarity(xs_pad.unsqueeze(2), spk_emb.unsqueeze(1), dim=-1)
        return scores

    def convert_tf2torch(self, var_dict_tf, var_dict_torch):
        return {}
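# --- Editor's illustrative sketch (not part of the commit) -------------------
# Both scorers above map frame-level features (B, T, D) and speaker profiles
# (B, N, D) to per-frame, per-speaker scores (B, T, N); DotScorer returns raw
# inner products while CosScorer is bounded in [-1, 1]. Assumes DotScorer and
# CosScorer are in scope (e.g. run inside this module); tensors are dummies.
import torch

frames = torch.randn(2, 50, 256)     # B, T, D encoder outputs
profiles = torch.randn(2, 4, 256)    # B, N, D speaker embeddings
dot_scores = DotScorer()(frames, profiles)
cos_scores = CosScorer()(frames, profiles)
assert dot_scores.shape == cos_scores.shape == (2, 50, 4)
# ------------------------------------------------------------------------------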
277
funasr/models/encoder/opennmt_encoders/conv_encoder.py
Normal file
@ -0,0 +1,277 @@
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
from typeguard import check_argument_types
import numpy as np
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.layer_norm import LayerNorm
from funasr.models.encoder.abs_encoder import AbsEncoder
import math
from funasr.modules.repeat import repeat


class EncoderLayer(nn.Module):
    def __init__(
        self,
        input_units,
        num_units,
        kernel_size=3,
        activation="tanh",
        stride=1,
        include_batch_norm=False,
        residual=False
    ):
        super().__init__()
        left_padding = math.ceil((kernel_size - stride) / 2)
        right_padding = kernel_size - stride - left_padding
        self.conv_padding = nn.ConstantPad1d((left_padding, right_padding), 0.0)
        self.conv1d = nn.Conv1d(
            input_units,
            num_units,
            kernel_size,
            stride,
        )
        self.activation = self.get_activation(activation)
        if include_batch_norm:
            self.bn = nn.BatchNorm1d(num_units, momentum=0.99, eps=1e-3)
        self.residual = residual
        self.include_batch_norm = include_batch_norm
        self.input_units = input_units
        self.num_units = num_units
        self.stride = stride

    @staticmethod
    def get_activation(activation):
        if activation == "tanh":
            return nn.Tanh()
        else:
            return nn.ReLU()

    def forward(self, xs_pad, ilens=None):
        outputs = self.conv1d(self.conv_padding(xs_pad))
        if self.residual and self.stride == 1 and self.input_units == self.num_units:
            outputs = outputs + xs_pad

        if self.include_batch_norm:
            outputs = self.bn(outputs)

        # add parenthesis for repeat module
        return self.activation(outputs), ilens


class ConvEncoder(AbsEncoder):
    """
    author: Speech Lab, Alibaba Group, China
    Convolution encoder in OpenNMT framework
    """

    def __init__(
        self,
        num_layers,
        input_units,
        num_units,
        kernel_size=3,
        dropout_rate=0.3,
        position_encoder=None,
        activation='tanh',
        auxiliary_states=True,
        out_units=None,
        out_norm=False,
        out_residual=False,
        include_batchnorm=False,
        regularization_weight=0.0,
        stride=1,
        tf2torch_tensor_name_prefix_torch: str = "speaker_encoder",
        tf2torch_tensor_name_prefix_tf: str = "EAND/speaker_encoder",
    ):
        assert check_argument_types()
        super().__init__()
        self._output_size = num_units

        self.num_layers = num_layers
        self.input_units = input_units
        self.num_units = num_units
        self.kernel_size = kernel_size
        self.dropout_rate = dropout_rate
        self.position_encoder = position_encoder
        self.out_units = out_units
        self.auxiliary_states = auxiliary_states
        self.out_norm = out_norm
        self.activation = activation
        self.out_residual = out_residual
        self.include_batch_norm = include_batchnorm
        self.regularization_weight = regularization_weight
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
        if isinstance(stride, int):
            self.stride = [stride] * self.num_layers
        else:
            self.stride = stride
        self.downsample_rate = 1
        for s in self.stride:
            self.downsample_rate *= s

        self.dropout = nn.Dropout(dropout_rate)
        self.cnn_a = repeat(
            self.num_layers,
            lambda lnum: EncoderLayer(
                input_units if lnum == 0 else num_units,
                num_units,
                kernel_size,
                activation,
                self.stride[lnum],
                include_batchnorm,
                residual=True if lnum > 0 else False
            )
        )

        if self.out_units is not None:
            left_padding = math.ceil((kernel_size - stride) / 2)
            right_padding = kernel_size - stride - left_padding
            self.out_padding = nn.ConstantPad1d((left_padding, right_padding), 0.0)
            self.conv_out = nn.Conv1d(
                num_units,
                num_units,
                kernel_size,
            )

        if self.out_norm:
            self.after_norm = LayerNorm(num_units)

    def output_size(self) -> int:
        return self.num_units

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:

        inputs = xs_pad
        if self.position_encoder is not None:
            inputs = self.position_encoder(inputs)

        if self.dropout_rate > 0:
            inputs = self.dropout(inputs)

        outputs, _ = self.cnn_a(inputs.transpose(1, 2), ilens)

        if self.out_units is not None:
            outputs = self.conv_out(self.out_padding(outputs))

        outputs = outputs.transpose(1, 2)
        if self.out_norm:
            outputs = self.after_norm(outputs)

        if self.out_residual:
            outputs = outputs + inputs

        return outputs, ilens, None

    def gen_tf2torch_map_dict(self):
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {
            # torch: conv1d.weight in "out_channel in_channel kernel_size"
            # tf   : conv1d.weight in "kernel_size in_channel out_channel"
            # torch: linear.weight in "out_channel in_channel"
            # tf   : dense.weight in "in_channel out_channel"
            "{}.cnn_a.0.conv1d.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cnn_a/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (2, 1, 0),
                 },
            "{}.cnn_a.0.conv1d.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cnn_a/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },

            "{}.cnn_a.layeridx.conv1d.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cnn_a/conv1d_layeridx/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (2, 1, 0),
                 },
            "{}.cnn_a.layeridx.conv1d.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cnn_a/conv1d_layeridx/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
        }
        if self.out_units is not None:
            # add output layer
            map_dict_local.update({
                "{}.conv_out.weight".format(tensor_name_prefix_torch):
                    {"name": "{}/cnn_a/conv1d_{}/kernel".format(tensor_name_prefix_tf, self.num_layers),
                     "squeeze": None,
                     "transpose": (2, 1, 0),
                     },  # tf: (1, 256, 256) -> torch: (256, 256, 1)
                "{}.conv_out.bias".format(tensor_name_prefix_torch):
                    {"name": "{}/cnn_a/conv1d_{}/bias".format(tensor_name_prefix_tf, self.num_layers),
                     "squeeze": None,
                     "transpose": None,
                     },  # tf: (256,) -> torch: (256,)
            })

        return map_dict_local

    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):

        map_dict = self.gen_tf2torch_map_dict()

        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            if name.startswith(self.tf2torch_tensor_name_prefix_torch):
                # process special (first and last) layers
                if name in map_dict:
                    name_tf = map_dict[name]["name"]
                    data_tf = var_dict_tf[name_tf]
                    if map_dict[name]["squeeze"] is not None:
                        data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                    if map_dict[name]["transpose"] is not None:
                        data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                    assert var_dict_torch[name].size() == data_tf.size(), \
                        "{}, {}, {} != {}".format(name, name_tf,
                                                  var_dict_torch[name].size(), data_tf.size())
                    var_dict_torch_update[name] = data_tf
                    logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                        name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                    ))
                # process general layers
                else:
                    # self.tf2torch_tensor_name_prefix_torch may include ".", solve this case
                    names = name.replace(self.tf2torch_tensor_name_prefix_torch, "todo").split('.')
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf,
                                                      var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                        ))
                    else:
                        logging.warning("{} is missed from tf checkpoint".format(name))

        return var_dict_torch_update
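# --- Editor's illustrative sketch (not part of the commit) -------------------
# Minimal use of the ConvEncoder defined above: with the default stride of 1 in
# every layer the frame count and ilens pass through unchanged, and the output
# dimension equals num_units. Assumes ConvEncoder is in scope; values are
# arbitrary dummies.
import torch

encoder = ConvEncoder(num_layers=3, input_units=80, num_units=256, kernel_size=3)
xs = torch.randn(2, 120, 80)        # (batch, frames, feature dim)
ilens = torch.tensor([120, 97])
ys, olens, _ = encoder(xs, ilens)
assert ys.shape == (2, 120, 256) and torch.equal(olens, ilens)
# ------------------------------------------------------------------------------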
335
funasr/models/encoder/opennmt_encoders/fsmn_encoder.py
Normal file
@ -0,0 +1,335 @@
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
from typeguard import check_argument_types
import numpy as np
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.layer_norm import LayerNorm
from funasr.models.encoder.abs_encoder import AbsEncoder
import math
from funasr.modules.repeat import repeat
from funasr.modules.multi_layer_conv import FsmnFeedForward


class FsmnBlock(torch.nn.Module):
    def __init__(
        self,
        n_feat,
        dropout_rate,
        kernel_size,
        fsmn_shift=0,
    ):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_rate)
        self.fsmn_block = nn.Conv1d(n_feat, n_feat, kernel_size, stride=1,
                                    padding=0, groups=n_feat, bias=False)
        # padding
        left_padding = (kernel_size - 1) // 2
        if fsmn_shift > 0:
            left_padding = left_padding + fsmn_shift
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)

    def forward(self, inputs, mask, mask_shfit_chunk=None):
        b, t, d = inputs.size()
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shfit_chunk is not None:
                mask = mask * mask_shfit_chunk

            inputs = inputs * mask
        x = inputs.transpose(1, 2)
        x = self.pad_fn(x)
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        x = x + inputs
        x = self.dropout(x)
        return x * mask


class EncoderLayer(torch.nn.Module):
    def __init__(
        self,
        in_size,
        size,
        feed_forward,
        fsmn_block,
        dropout_rate=0.0
    ):
        super().__init__()
        self.in_size = in_size
        self.size = size
        self.ffn = feed_forward
        self.memory = fsmn_block
        self.dropout = nn.Dropout(dropout_rate)

    def forward(
        self,
        xs_pad: torch.Tensor,
        mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # xs_pad in Batch, Time, Dim

        context = self.ffn(xs_pad)[0]
        memory = self.memory(context, mask)

        memory = self.dropout(memory)
        if self.in_size == self.size:
            return memory + xs_pad, mask

        return memory, mask


class FsmnEncoder(AbsEncoder):
    """Encoder using Fsmn
    """

    def __init__(self,
                 in_units,
                 filter_size,
                 fsmn_num_layers,
                 dnn_num_layers,
                 num_memory_units=512,
                 ffn_inner_dim=2048,
                 dropout_rate=0.0,
                 shift=0,
                 position_encoder=None,
                 sample_rate=1,
                 out_units=None,
                 tf2torch_tensor_name_prefix_torch="post_net",
                 tf2torch_tensor_name_prefix_tf="EAND/post_net"
                 ):
        """Initializes the parameters of the encoder.

        Args:
          filter_size: the total order of memory block
          fsmn_num_layers: The number of fsmn layers.
          dnn_num_layers: The number of dnn layers
          num_units: The number of memory units.
          ffn_inner_dim: The number of units of the inner linear transformation
            in the feed forward layer.
          dropout_rate: The probability to drop units from the outputs.
          shift: left padding, to control delay
          position_encoder: The :class:`opennmt.layers.position.PositionEncoder` to
            apply on inputs or ``None``.
        """
        super(FsmnEncoder, self).__init__()
        self.in_units = in_units
        self.filter_size = filter_size
        self.fsmn_num_layers = fsmn_num_layers
        self.dnn_num_layers = dnn_num_layers
        self.num_memory_units = num_memory_units
        self.ffn_inner_dim = ffn_inner_dim
        self.dropout_rate = dropout_rate
        self.shift = shift
        if not isinstance(shift, list):
            self.shift = [shift for _ in range(self.fsmn_num_layers)]
        self.sample_rate = sample_rate
        if not isinstance(sample_rate, list):
            self.sample_rate = [sample_rate for _ in range(self.fsmn_num_layers)]
        self.position_encoder = position_encoder
        self.dropout = nn.Dropout(dropout_rate)
        self.out_units = out_units
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf

        self.fsmn_layers = repeat(
            self.fsmn_num_layers,
            lambda lnum: EncoderLayer(
                in_units if lnum == 0 else num_memory_units,
                num_memory_units,
                FsmnFeedForward(
                    in_units if lnum == 0 else num_memory_units,
                    ffn_inner_dim,
                    num_memory_units,
                    1,
                    dropout_rate
                ),
                FsmnBlock(
                    num_memory_units,
                    dropout_rate,
                    filter_size,
                    self.shift[lnum]
                )
            ),
        )

        self.dnn_layers = repeat(
            dnn_num_layers,
            lambda lnum: FsmnFeedForward(
                num_memory_units,
                ffn_inner_dim,
                num_memory_units,
                1,
                dropout_rate,
            )
        )
        if out_units is not None:
            self.conv1d = nn.Conv1d(num_memory_units, out_units, 1, 1)

    def output_size(self) -> int:
        return self.num_memory_units

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        inputs = xs_pad
        if self.position_encoder is not None:
            inputs = self.position_encoder(inputs)

        inputs = self.dropout(inputs)
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        inputs = self.fsmn_layers(inputs, masks)[0]
        inputs = self.dnn_layers(inputs)[0]

        if self.out_units is not None:
            inputs = self.conv1d(inputs.transpose(1, 2)).transpose(1, 2)

        return inputs, ilens, None

    def gen_tf2torch_map_dict(self):
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {
            # torch: conv1d.weight in "out_channel in_channel kernel_size"
            # tf   : conv1d.weight in "kernel_size in_channel out_channel"
            # torch: linear.weight in "out_channel in_channel"
            # tf   : dense.weight in "in_channel out_channel"
            # for fsmn_layers
            "{}.fsmn_layers.layeridx.ffn.norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/fsmn_layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.fsmn_layers.layeridx.ffn.norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/fsmn_layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.fsmn_layers.layeridx.ffn.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/fsmn_layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.fsmn_layers.layeridx.ffn.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/fsmn_layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (2, 1, 0),
                 },
            "{}.fsmn_layers.layeridx.ffn.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/fsmn_layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (2, 1, 0),
                 },
            "{}.fsmn_layers.layeridx.memory.fsmn_block.weight".format(tensor_name_prefix_torch):
                {"name": "{}/fsmn_layer_layeridx/memory/depth_conv_w".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 2, 0),
                 },  # (1, 31, 512, 1) -> (31, 512, 1) -> (512, 1, 31)

            # for dnn_layers
            "{}.dnn_layers.layeridx.norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.dnn_layers.layeridx.norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.dnn_layers.layeridx.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.dnn_layers.layeridx.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (2, 1, 0),
                 },
            "{}.dnn_layers.layeridx.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (2, 1, 0),
                 },

        }
        if self.out_units is not None:
            # add output layer
            map_dict_local.update({
                "{}.conv1d.weight".format(tensor_name_prefix_torch):
                    {"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
                     "squeeze": None,
                     "transpose": (2, 1, 0),
                     },
                "{}.conv1d.bias".format(tensor_name_prefix_torch):
                    {"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
                     "squeeze": None,
                     "transpose": None,
                     },
            })

        return map_dict_local

    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):

        map_dict = self.gen_tf2torch_map_dict()

        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            if name.startswith(self.tf2torch_tensor_name_prefix_torch):
                # process special (first and last) layers
                if name in map_dict:
                    name_tf = map_dict[name]["name"]
                    data_tf = var_dict_tf[name_tf]
                    if map_dict[name]["squeeze"] is not None:
                        data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                    if map_dict[name]["transpose"] is not None:
                        data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                    assert var_dict_torch[name].size() == data_tf.size(), \
                        "{}, {}, {} != {}".format(name, name_tf,
                                                  var_dict_torch[name].size(), data_tf.size())
                    var_dict_torch_update[name] = data_tf
                    logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                        name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                    ))
                # process general layers
                else:
                    # self.tf2torch_tensor_name_prefix_torch may include ".", solve this case
                    names = name.replace(self.tf2torch_tensor_name_prefix_torch, "todo").split('.')
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf,
                                                      var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                        ))
                    else:
                        logging.warning("{} is missed from tf checkpoint".format(name))

        return var_dict_torch_update
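# --- Editor's illustrative sketch (not part of the commit) -------------------
# Minimal use of the FsmnEncoder defined above as a post-net: each layer is an
# FsmnFeedForward projection followed by a depthwise-conv memory block of order
# `filter_size`, whose symmetric padding preserves the frame count for an odd
# kernel. Assumes FsmnEncoder is in scope; values are arbitrary dummies.
import torch

post_net = FsmnEncoder(in_units=256, filter_size=11, fsmn_num_layers=4,
                       dnn_num_layers=1, num_memory_units=512, ffn_inner_dim=1024)
xs = torch.randn(2, 120, 256)
ilens = torch.tensor([120, 88])
ys, olens, _ = post_net(xs, ilens)
assert ys.shape == (2, 120, 512)
# ------------------------------------------------------------------------------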
480
funasr/models/encoder/opennmt_encoders/self_attention_encoder.py
Normal file
@ -0,0 +1,480 @@
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import logging
import torch
import torch.nn as nn
from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
from typeguard import check_argument_types
import numpy as np
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.attention import MultiHeadSelfAttention, MultiHeadedAttentionSANM
from funasr.modules.embedding import SinusoidalPositionEncoder
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.multi_layer_conv import Conv1dLinear
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
from funasr.modules.positionwise_feed_forward import (
    PositionwiseFeedForward,  # noqa: H301
)
from funasr.modules.repeat import repeat
from funasr.modules.subsampling import Conv2dSubsampling
from funasr.modules.subsampling import Conv2dSubsampling2
from funasr.modules.subsampling import Conv2dSubsampling6
from funasr.modules.subsampling import Conv2dSubsampling8
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder


class EncoderLayer(nn.Module):
    def __init__(
        self,
        in_size,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(in_size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.in_size = in_size
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate
        self.dropout_rate = dropout_rate

    def forward(self, x, mask, cache=None, mask_att_chunk_encoder=None):
        """Compute encoded features.

        Args:
            x_input (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).

        """
        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            return x, mask

        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if self.concat_after:
            x_concat = torch.cat((x, self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
            else:
                x = stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.dropout(
                    self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)
                )
            else:
                x = stoch_layer_coeff * self.dropout(
                    self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)
                )
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        return x, mask, cache, mask_att_chunk_encoder


class SelfAttentionEncoder(AbsEncoder):
    """
    author: Speech Lab, Alibaba Group, China
    Self attention encoder in OpenNMT framework
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        tf2torch_tensor_name_prefix_torch: str = "encoder",
        tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
        out_units=None,
    ):
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size

        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                SinusoidalPositionEncoder(),
            )
        elif input_layer is None:
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        elif input_layer == "pe":
            self.embed = SinusoidalPositionEncoder()
        elif input_layer == "null":
            self.embed = None
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                output_size,
                output_size,
                MultiHeadSelfAttention(
                    attention_heads,
                    output_size,
                    output_size,
                    attention_dropout_rate,
                ),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ) if lnum > 0 else EncoderLayer(
                input_size,
                output_size,
                MultiHeadSelfAttention(
                    attention_heads,
                    input_size if input_layer == "pe" or input_layer == "null" else output_size,
                    output_size,
                    attention_dropout_rate,
                ),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None
        self.dropout = nn.Dropout(dropout_rate)
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
        self.out_units = out_units
        if out_units is not None:
            self.output_linear = nn.Linear(output_size, out_units)

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: Not to be used now.
        Returns:
            position embedded tensor and mask
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        xs_pad *= self.output_size()**0.5
        if self.embed is None:
            xs_pad = xs_pad
        elif (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)

        xs_pad = self.dropout(xs_pad)
        # encoder_outs = self.encoders0(xs_pad, masks)
        # xs_pad, masks = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            encoder_outs = self.encoders(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                encoder_outs = encoder_layer(xs_pad, masks)
                xs_pad, masks = encoder_outs[0], encoder_outs[1]

                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad

                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)

                    intermediate_outs.append((layer_idx + 1, encoder_out))

                    if self.interctc_use_conditioning:
                        ctc_out = ctc.softmax(encoder_out)
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        if self.out_units is not None:
            xs_pad = self.output_linear(xs_pad)
        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None

    def gen_tf2torch_map_dict(self):
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {
            # cicd
            # torch: conv1d.weight in "out_channel in_channel kernel_size"
            # tf   : conv1d.weight in "kernel_size in_channel out_channel"
            # torch: linear.weight in "out_channel in_channel"
            # tf   : dense.weight in "in_channel out_channel"
            "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (768,256),(1,256,768)
            "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (768,),(768,)
            "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # ffn
            "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)
            "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # out norm
            "{}.after_norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.after_norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
        }
        if self.out_units is not None:
            map_dict_local.update({
                "{}.output_linear.weight".format(tensor_name_prefix_torch):
                    {"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
                     "squeeze": 0,
                     "transpose": (1, 0),
                     },
                "{}.output_linear.bias".format(tensor_name_prefix_torch):
                    {"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
                     "squeeze": None,
                     "transpose": None,
                     },  # (256,),(256,)
            })

        return map_dict_local

    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):

        map_dict = self.gen_tf2torch_map_dict()

        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            if name.startswith(self.tf2torch_tensor_name_prefix_torch):
                # process special (first and last) layers
                if name in map_dict:
                    name_tf = map_dict[name]["name"]
                    data_tf = var_dict_tf[name_tf]
                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                    if map_dict[name]["squeeze"] is not None:
                        data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                    if map_dict[name]["transpose"] is not None:
                        data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                    assert var_dict_torch[name].size() == data_tf.size(), \
                        "{}, {}, {} != {}".format(name, name_tf,
                                                  var_dict_torch[name].size(), data_tf.size())
                    var_dict_torch_update[name] = data_tf
                    logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                        name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                    ))
                # process general layers
                else:
                    # self.tf2torch_tensor_name_prefix_torch may include ".", solve this case
                    names = name.replace(self.tf2torch_tensor_name_prefix_torch, "todo").split('.')
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf,
                                                      var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                        ))
                    else:
                        logging.warning("{} is missed from tf checkpoint".format(name))

        return var_dict_torch_update
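# --- Editor's illustrative sketch (not part of the commit) -------------------
# Minimal use of the SelfAttentionEncoder defined above. With input_layer="pe"
# nothing subsamples, so olens mirrors ilens; equal input_size and output_size
# keep the first block's residual path active. Assumes the class is in scope;
# values are arbitrary dummies.
import torch

encoder = SelfAttentionEncoder(input_size=256, output_size=256, attention_heads=4,
                               linear_units=1024, num_blocks=2, input_layer="pe")
xs = torch.randn(2, 120, 256)
ilens = torch.tensor([120, 75])
ys, olens, _ = encoder(xs, ilens)
assert ys.shape == (2, 120, 256)
# ------------------------------------------------------------------------------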
@ -1,7 +1,11 @@
import torch
from torch.nn import functional as F
from funasr.models.encoder.abs_encoder import AbsEncoder
from typing import Tuple
from typing import Tuple, Optional
from funasr.models.pooling.statistic_pooling import statistic_pooling, windowed_statistic_pooling
from collections import OrderedDict
import logging
import numpy as np


class BasicLayer(torch.nn.Module):
@ -116,10 +120,18 @@ class ResNet34(AbsEncoder):
        self.resnet0_dense = torch.nn.Conv2d(filters_in_block[-1], num_nodes_pooling_layer, 1)
        self.resnet0_bn = torch.nn.BatchNorm2d(num_nodes_pooling_layer, eps=1e-3, momentum=batchnorm_momentum)

        self.time_ds_ratio = 8

    def output_size(self) -> int:
        return self.num_nodes_pooling_layer

    def forward(self, xs_pad: torch.Tensor, ilens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        features = xs_pad
        assert features.size(-1) == self.input_size, \
            "Dimension of features {} doesn't match the input_size {}.".format(features.size(-1), self.input_size)
@ -141,4 +153,463 @@ class ResNet34(AbsEncoder):
        features = F.relu(features)
        features = self.resnet0_bn(features)

        return features, ilens // 8
        return features, resnet_out_lens

# Note: For training, this implementation is not equivalent to tf because of the kernel_regularizer in tf.layers.
# TODO: implement kernel_regularizer in torch with manual loss addition or weight_decay in the optimizer
class ResNet34_SP_L2Reg(AbsEncoder):
    def __init__(
        self,
        input_size,
        use_head_conv=True,
        batchnorm_momentum=0.5,
        use_head_maxpool=False,
        num_nodes_pooling_layer=256,
        layers_in_block=(3, 4, 6, 3),
        filters_in_block=(32, 64, 128, 256),
        tf2torch_tensor_name_prefix_torch="encoder",
        tf2torch_tensor_name_prefix_tf="EAND/speech_encoder",
        tf_train_steps=720000,
    ):
        super(ResNet34_SP_L2Reg, self).__init__()

        self.use_head_conv = use_head_conv
        self.use_head_maxpool = use_head_maxpool
        self.num_nodes_pooling_layer = num_nodes_pooling_layer
        self.layers_in_block = layers_in_block
        self.filters_in_block = filters_in_block
        self.input_size = input_size
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
        self.tf_train_steps = tf_train_steps

        pre_filters = filters_in_block[0]
        if use_head_conv:
            self.pre_conv = torch.nn.Conv2d(1, pre_filters, 3, 1, 1, bias=False, padding_mode="zeros")
            self.pre_conv_bn = torch.nn.BatchNorm2d(pre_filters, eps=1e-3, momentum=batchnorm_momentum)

        if use_head_maxpool:
            self.head_maxpool = torch.nn.MaxPool2d(3, 1, padding=1)

        for i in range(len(layers_in_block)):
            if i == 0:
                in_filters = pre_filters if self.use_head_conv else 1
            else:
                in_filters = filters_in_block[i-1]

            block = BasicBlock(in_filters,
                               filters=filters_in_block[i],
                               num_layer=layers_in_block[i],
                               stride=1 if i == 0 else 2,
                               bn_momentum=batchnorm_momentum)
            self.add_module("block_{}".format(i), block)

        self.resnet0_dense = torch.nn.Conv1d(filters_in_block[-1] * input_size // 8, num_nodes_pooling_layer, 1)
        self.resnet0_bn = torch.nn.BatchNorm1d(num_nodes_pooling_layer, eps=1e-3, momentum=batchnorm_momentum)

        self.time_ds_ratio = 8

    def output_size(self) -> int:
        return self.num_nodes_pooling_layer

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        features = xs_pad
        assert features.size(-1) == self.input_size, \
            "Dimension of features {} doesn't match the input_size {}.".format(features.size(-1), self.input_size)
        features = torch.unsqueeze(features, dim=1)
        if self.use_head_conv:
            features = self.pre_conv(features)
            features = self.pre_conv_bn(features)
            features = F.relu(features)

        if self.use_head_maxpool:
            features = self.head_maxpool(features)

        resnet_outs, resnet_out_lens = features, ilens
        for i in range(len(self.layers_in_block)):
            block = self._modules["block_{}".format(i)]
            resnet_outs, resnet_out_lens = block(resnet_outs, resnet_out_lens)

        # B, C, T, F
        bb, cc, tt, ff = resnet_outs.shape
        resnet_outs = torch.reshape(resnet_outs.permute(0, 3, 1, 2), [bb, ff*cc, tt])
        features = self.resnet0_dense(resnet_outs)
        features = F.relu(features)
        features = self.resnet0_bn(features)

        return features, resnet_out_lens

    def gen_tf2torch_map_dict(self):
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        train_steps = self.tf_train_steps
        map_dict_local = {
            # torch: conv1d.weight in "out_channel in_channel kernel_size"
            # tf   : conv1d.weight in "kernel_size in_channel out_channel"
            # torch: linear.weight in "out_channel in_channel"
            # tf   : dense.weight in "in_channel out_channel"
            "{}.pre_conv.weight".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (3, 2, 0, 1),
                 },
            "{}.pre_conv_bn.bias".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.weight".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.running_mean".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/moving_mean".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.running_var".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/moving_variance".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.num_batches_tracked".format(tensor_name_prefix_torch): train_steps
        }
        for layer_idx in range(3):
            map_dict_local.update({
                "{}.resnet{}_dense.weight".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_dense/kernel".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": (2, 1, 0) if layer_idx == 0 else (1, 0),
                     },
                "{}.resnet{}_dense.bias".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_dense/bias".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.weight".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/gamma".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.bias".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/beta".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.running_mean".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/moving_mean".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.running_var".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/moving_variance".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.num_batches_tracked".format(tensor_name_prefix_torch, layer_idx): train_steps
            })

        for block_idx in range(len(self.layers_in_block)):
            for layer_idx in range(self.layers_in_block[block_idx]):
                for i in ["1", "2", "_sc"]:
                    map_dict_local.update({
                        "{}.block_{}.layer_{}.conv{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/conv{}/kernel".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": (3, 2, 0, 1),
                             },
                        "{}.block_{}.layer_{}.bn{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/gamma".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.bias".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/beta".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.running_mean".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/moving_mean".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.running_var".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/moving_variance".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.num_batches_tracked".format(tensor_name_prefix_torch, block_idx, layer_idx, i): train_steps,
                    })

        return map_dict_local

    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):

        map_dict = self.gen_tf2torch_map_dict()

        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            if name.startswith(self.tf2torch_tensor_name_prefix_torch):
                if name in map_dict:
                    if "num_batches_tracked" not in name:
                        name_tf = map_dict[name]["name"]
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                        if map_dict[name]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf,
                                                      var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                        ))
                    else:
                        var_dict_torch_update[name] = torch.Tensor(map_dict[name]).type(torch.int64).to("cpu")
                        logging.info("torch tensor: {}, manually assigning to: {}".format(
                            name, map_dict[name]
                        ))
                else:
                    logging.warning("{} is missed from tf checkpoint".format(name))

        return var_dict_torch_update
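# --- Editor's illustrative sketch (not part of the commit) -------------------
# The "transpose" entries in the map dicts above encode the layout difference
# spelled out in their comments: a TF conv1d kernel is stored as
# (kernel_size, in_channels, out_channels) while torch.nn.Conv1d expects
# (out_channels, in_channels, kernel_size), hence transpose (2, 1, 0).
# Dummy shapes below are for illustration only.
import numpy as np

tf_kernel = np.zeros((3, 80, 256), dtype=np.float32)   # (K, C_in, C_out) in TF
torch_kernel = np.transpose(tf_kernel, (2, 1, 0))      # (C_out, C_in, K) in torch
assert torch_kernel.shape == (256, 80, 3)
# ------------------------------------------------------------------------------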
|
||||
|
||||
|
||||
class ResNet34Diar(ResNet34):
|
||||
def __init__(
|
||||
self,
|
||||
input_size,
|
||||
embedding_node="resnet1_dense",
|
||||
use_head_conv=True,
|
||||
batchnorm_momentum=0.5,
|
||||
use_head_maxpool=False,
|
||||
num_nodes_pooling_layer=256,
|
||||
layers_in_block=(3, 4, 6, 3),
|
||||
filters_in_block=(32, 64, 128, 256),
|
||||
num_nodes_resnet1=256,
|
||||
num_nodes_last_layer=256,
|
||||
pooling_type="window_shift",
|
||||
pool_size=20,
|
||||
stride=1,
|
||||
tf2torch_tensor_name_prefix_torch="encoder",
|
||||
tf2torch_tensor_name_prefix_tf="seq2seq/speech_encoder"
|
||||
):
|
||||
super(ResNet34Diar, self).__init__(
|
||||
input_size,
|
||||
use_head_conv=use_head_conv,
|
||||
batchnorm_momentum=batchnorm_momentum,
|
||||
use_head_maxpool=use_head_maxpool,
|
||||
num_nodes_pooling_layer=num_nodes_pooling_layer,
|
||||
layers_in_block=layers_in_block,
|
||||
filters_in_block=filters_in_block,
|
||||
)
|
||||
|
||||
self.embedding_node = embedding_node
|
||||
self.num_nodes_resnet1 = num_nodes_resnet1
|
||||
self.num_nodes_last_layer = num_nodes_last_layer
|
||||
self.pooling_type = pooling_type
|
||||
self.pool_size = pool_size
|
||||
self.stride = stride
|
||||
self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
|
||||
self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
|
||||
|
||||
self.resnet1_dense = torch.nn.Linear(num_nodes_pooling_layer * 2, num_nodes_resnet1)
|
||||
self.resnet1_bn = torch.nn.BatchNorm1d(num_nodes_resnet1, eps=1e-3, momentum=batchnorm_momentum)
|
||||
|
||||
self.resnet2_dense = torch.nn.Linear(num_nodes_resnet1, num_nodes_last_layer)
|
||||
self.resnet2_bn = torch.nn.BatchNorm1d(num_nodes_last_layer, eps=1e-3, momentum=batchnorm_momentum)
|
||||
|
||||
def output_size(self) -> int:
|
||||
if self.embedding_node.startswith("resnet1"):
|
||||
return self.num_nodes_resnet1
|
||||
elif self.embedding_node.startswith("resnet2"):
|
||||
return self.num_nodes_last_layer
|
||||
|
||||
return self.num_nodes_pooling_layer
|
||||
|
||||
    def forward(
            self,
            xs_pad: torch.Tensor,
            ilens: torch.Tensor,
            prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:

        endpoints = OrderedDict()
        res_out, ilens = super().forward(xs_pad, ilens)
        endpoints["resnet0_bn"] = res_out
        if self.pooling_type == "frame_gsp":
            features = statistic_pooling(res_out, ilens, (3, ))
        else:
            features, ilens = windowed_statistic_pooling(res_out, ilens, (2, 3), self.pool_size, self.stride)
            features = features.transpose(1, 2)
        endpoints["pooling"] = features

        features = self.resnet1_dense(features)
        endpoints["resnet1_dense"] = features
        features = F.relu(features)
        endpoints["resnet1_relu"] = features
        features = self.resnet1_bn(features.transpose(1, 2)).transpose(1, 2)
        endpoints["resnet1_bn"] = features

        features = self.resnet2_dense(features)
        endpoints["resnet2_dense"] = features
        features = F.relu(features)
        endpoints["resnet2_relu"] = features
        features = self.resnet2_bn(features.transpose(1, 2)).transpose(1, 2)
        endpoints["resnet2_bn"] = features

        return endpoints[self.embedding_node], ilens, None

    def gen_tf2torch_map_dict(self):
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        train_steps = 300000
        map_dict_local = {
            # torch: conv1d.weight in "out_channel in_channel kernel_size"
            # tf   : conv1d.weight in "kernel_size in_channel out_channel"
            # torch: linear.weight in "out_channel in_channel"
            # tf   : dense.weight  in "in_channel out_channel"
            "{}.pre_conv.weight".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (3, 2, 0, 1),
                 },
            "{}.pre_conv_bn.bias".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.weight".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.running_mean".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/moving_mean".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.running_var".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/moving_variance".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.num_batches_tracked".format(tensor_name_prefix_torch): train_steps
        }
        for layer_idx in range(3):
            map_dict_local.update({
                "{}.resnet{}_dense.weight".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_dense/kernel".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": (3, 2, 0, 1) if layer_idx == 0 else (1, 0),
                     },
                "{}.resnet{}_dense.bias".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_dense/bias".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.weight".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/gamma".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.bias".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/beta".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.running_mean".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/moving_mean".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.running_var".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/moving_variance".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.num_batches_tracked".format(tensor_name_prefix_torch, layer_idx): train_steps
            })

        for block_idx in range(len(self.layers_in_block)):
            for layer_idx in range(self.layers_in_block[block_idx]):
                for i in ["1", "2", "_sc"]:
                    map_dict_local.update({
                        "{}.block_{}.layer_{}.conv{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/conv{}/kernel".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": (3, 2, 0, 1),
                             },
                        "{}.block_{}.layer_{}.bn{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/gamma".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.bias".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/beta".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.running_mean".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/moving_mean".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.running_var".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/moving_variance".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.num_batches_tracked".format(tensor_name_prefix_torch, block_idx, layer_idx, i): train_steps,
                    })

        return map_dict_local

    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):

        map_dict = self.gen_tf2torch_map_dict()

        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            if name.startswith(self.tf2torch_tensor_name_prefix_torch):
                if name in map_dict:
                    if "num_batches_tracked" not in name:
                        name_tf = map_dict[name]["name"]
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                        if map_dict[name]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf,
                                                      var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                        ))
                    else:
                        # num_batches_tracked is a scalar counter, assigned directly from train_steps
                        var_dict_torch_update[name] = torch.tensor(map_dict[name]).type(torch.int64).to("cpu")
                        logging.info("torch tensor: {}, manually assigning to: {}".format(
                            name, map_dict[name]
                        ))
                else:
                    logging.warning("{} is missed from tf checkpoint".format(name))

        return var_dict_torch_update
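
The map entries above drive the whole TF-to-PyTorch conversion: each checkpoint array is optionally squeezed and transposed before being copied into the torch state dict. A minimal standalone sketch of that recipe (the tensor name and shapes below are illustrative, not taken from a real checkpoint):

import numpy as np
import torch

# A hypothetical entry in the style of gen_tf2torch_map_dict: a TF conv kernel stored as
# (kernel_h, kernel_w, in_ch, out_ch) becomes a torch kernel (out_ch, in_ch, kernel_h, kernel_w).
entry = {"name": "seq2seq/speech_encoder/pre_conv/kernel", "squeeze": None, "transpose": (3, 2, 0, 1)}
data_tf = np.zeros((3, 3, 1, 32), dtype=np.float32)  # stand-in for var_dict_tf[entry["name"]]

if entry["squeeze"] is not None:
    data_tf = np.squeeze(data_tf, axis=entry["squeeze"])
if entry["transpose"] is not None:
    data_tf = np.transpose(data_tf, entry["transpose"])
weight = torch.from_numpy(data_tf).float()
print(weight.shape)  # torch.Size([32, 1, 3, 3])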
@@ -90,7 +90,9 @@ class WavFrontend(AbsFrontend):
        filter_length_max: int = -1,
        lfr_m: int = 1,
        lfr_n: int = 1,
        dither: float = 1.0,
        snip_edges: bool = True,
        upsacle_samples: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
@@ -105,6 +107,8 @@ class WavFrontend(AbsFrontend):
        self.lfr_n = lfr_n
        self.cmvn_file = cmvn_file
        self.dither = dither
        self.snip_edges = snip_edges
        self.upsacle_samples = upsacle_samples

    def output_size(self) -> int:
        return self.n_mels * self.lfr_m
@@ -119,7 +123,8 @@ class WavFrontend(AbsFrontend):
        for i in range(batch_size):
            waveform_length = input_lengths[i]
            waveform = input[i][:waveform_length]
            if self.upsacle_samples:
                waveform = waveform * (1 << 15)
            waveform = waveform.unsqueeze(0)
            mat = kaldi.fbank(waveform,
                              num_mel_bins=self.n_mels,
@@ -128,7 +133,8 @@ class WavFrontend(AbsFrontend):
                              dither=self.dither,
                              energy_floor=0.0,
                              window_type=self.window,
                              sample_frequency=self.fs,
                              snip_edges=self.snip_edges)

            if self.lfr_m != 1 or self.lfr_n != 1:
                mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
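
A minimal sketch of the feature-extraction path this hunk configures, calling torchaudio's Kaldi-compatible fbank directly (the 80 mel bins and 25 ms / 10 ms framing are illustrative values, not read from a FunASR config):

import torch
import torchaudio.compliance.kaldi as kaldi

waveform = torch.randn(1, 16000)  # one second of fake 16 kHz audio, shape (channel, samples)
waveform = waveform * (1 << 15)   # rescale to 16-bit range, as upsacle_samples=True does
feats = kaldi.fbank(waveform,
                    num_mel_bins=80,
                    frame_length=25,
                    frame_shift=10,
                    dither=1.0,
                    energy_floor=0.0,
                    window_type="hamming",
                    sample_frequency=16000,
                    snip_edges=True)
print(feats.shape)  # (num_frames, 80)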
@@ -2,7 +2,10 @@ import torch
from typing import Tuple
from typing import Union
from funasr.modules.nets_utils import make_non_pad_mask
from torch.nn import functional as F
import math

VAR2STD_EPSILON = 1e-12

class StatisticPooling(torch.nn.Module):
    def __init__(self, pooling_dim: Union[int, Tuple] = 2, eps=1e-12):
@@ -34,3 +37,59 @@ class StatisticPooling(torch.nn.Module):
        stat_pooling = torch.cat([mean, stddev], dim=1)

        return stat_pooling

    def convert_tf2torch(self, var_dict_tf, var_dict_torch):
        return {}


def statistic_pooling(
        xs_pad: torch.Tensor,
        ilens: torch.Tensor = None,
        pooling_dim: Tuple = (2, 3)
) -> torch.Tensor:
    # xs_pad in (Batch, Channel, Time, Frequency)

    if ilens is None:
        seq_mask = torch.ones_like(xs_pad).to(xs_pad)
    else:
        seq_mask = make_non_pad_mask(ilens, xs_pad, length_dim=2).to(xs_pad)
    mean = (torch.sum(xs_pad, dim=pooling_dim, keepdim=True) /
            torch.sum(seq_mask, dim=pooling_dim, keepdim=True))
    squared_difference = torch.pow(xs_pad - mean, 2.0)
    variance = (torch.sum(squared_difference, dim=pooling_dim, keepdim=True) /
                torch.sum(seq_mask, dim=pooling_dim, keepdim=True))
    for i in reversed(pooling_dim):
        mean, variance = torch.squeeze(mean, dim=i), torch.squeeze(variance, dim=i)

    value_mask = torch.less_equal(variance, VAR2STD_EPSILON).float()
    variance = (1.0 - value_mask) * variance + value_mask * VAR2STD_EPSILON
    stddev = torch.sqrt(variance)

    stat_pooling = torch.cat([mean, stddev], dim=1)

    return stat_pooling

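Shape-wise, statistic_pooling collapses the time and frequency axes into per-channel mean and standard deviation, doubling the channel dimension. A quick check with random input, run in the same module (sizes are illustrative):

import torch

x = torch.randn(2, 256, 100, 8)   # (Batch, Channel, Time, Frequency)
ilens = torch.tensor([100, 80])   # valid frames per utterance
emb = statistic_pooling(x, ilens)
print(emb.shape)                  # torch.Size([2, 512]): mean and stddev concatenated on dim=1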

def windowed_statistic_pooling(
        xs_pad: torch.Tensor,
        ilens: torch.Tensor = None,
        pooling_dim: Tuple = (2, 3),
        pooling_size: int = 20,
        pooling_stride: int = 1
) -> Tuple[torch.Tensor, int]:
    # xs_pad in (Batch, Channel, Time, Frequency)

    tt = xs_pad.shape[2]
    num_chunk = int(math.ceil(tt / pooling_stride))
    pad = pooling_size // 2
    features = F.pad(xs_pad, (0, 0, pad, pad), "reflect")
    stat_list = []

    for i in range(num_chunk):
        # B x C
        st, ed = i * pooling_stride, i * pooling_stride + pooling_size
        stat = statistic_pooling(features[:, :, st: ed, :], pooling_dim=pooling_dim)
        stat_list.append(stat.unsqueeze(2))

    # B x C x T
    return torch.cat(stat_list, dim=2), ilens / pooling_stride
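
windowed_statistic_pooling applies the same statistics over a sliding window, so instead of one utterance-level vector it yields a frame-level sequence of local mean/stddev embeddings. For example, in the same module (illustrative sizes):

import torch

x = torch.randn(2, 256, 100, 8)   # (Batch, Channel, Time, Frequency)
ilens = torch.tensor([100, 80])
seq, out_lens = windowed_statistic_pooling(x, ilens, pooling_size=20, pooling_stride=1)
print(seq.shape)                  # torch.Size([2, 512, 100]): ceil(T / stride) pooled frames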
@@ -622,4 +622,108 @@ class MultiHeadedAttentionCrossAtt(nn.Module):
        q_h, k_h, v_h = self.forward_qkv(x, memory)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        return self.forward_attention(v_h, scores, memory_mask)


class MultiHeadSelfAttention(nn.Module):
    """Multi-Head Self-Attention layer.

    Args:
        n_head (int): The number of heads.
        in_feat (int): The number of input features.
        n_feat (int): The number of output features.
        dropout_rate (float): Dropout rate.

    """

    def __init__(self, n_head, in_feat, n_feat, dropout_rate):
        """Construct a MultiHeadSelfAttention object."""
        super(MultiHeadSelfAttention, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(self, x):
        """Transform the input into query, key and value.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, in_feat).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time, d_k).
            torch.Tensor: Untransformed value tensor (#batch, time, n_feat).

        """
        b, t, d = x.size()
        q_k_v = self.linear_q_k_v(x)
        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)

        return q_h, k_h, v_h, v

    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).

        """
        n_batch = value.size(0)
        if mask is not None:
            if mask_att_chunk_encoder is not None:
                mask = mask * mask_att_chunk_encoder

            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)

            min_value = float(
                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
            )
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, x, mask, mask_att_chunk_encoder=None):
        """Compute scaled dot product self-attention.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, in_feat).
            mask (torch.Tensor): Mask tensor (#batch, 1, time) or
                (#batch, time, time).

        Returns:
            torch.Tensor: Output tensor (#batch, time, n_feat).

        """
        q_h, k_h, v_h, v = self.forward_qkv(x)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
        return att_outs
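
A minimal sketch of MultiHeadSelfAttention in isolation, run in the same module: it projects an in_feat-dimensional input to n_feat with a fused q/k/v projection (sizes are illustrative):

import torch

mha = MultiHeadSelfAttention(n_head=4, in_feat=128, n_feat=256, dropout_rate=0.1)
x = torch.randn(2, 50, 128)                    # (batch, time, in_feat)
mask = torch.ones(2, 1, 50, dtype=torch.long)  # 1 marks a valid frame
y = mha(x, mask)
print(y.shape)                                 # torch.Size([2, 50, 256])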
@@ -63,6 +63,58 @@ class MultiLayeredConv1d(torch.nn.Module):
        return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)


class FsmnFeedForward(torch.nn.Module):
    """Position-wise feed forward for FSMN blocks.

    This is a module of multi-layered conv1d designed
    to replace the position-wise feed-forward network
    in an FSMN block.
    """

    def __init__(self, in_chans, hidden_chans, out_chans, kernel_size, dropout_rate):
        """Initialize FsmnFeedForward module.

        Args:
            in_chans (int): Number of input channels.
            hidden_chans (int): Number of hidden channels.
            out_chans (int): Number of output channels.
            kernel_size (int): Kernel size of conv1d.
            dropout_rate (float): Dropout rate.

        """
        super(FsmnFeedForward, self).__init__()
        self.w_1 = torch.nn.Conv1d(
            in_chans,
            hidden_chans,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
        )
        self.w_2 = torch.nn.Conv1d(
            hidden_chans,
            out_chans,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
            bias=False
        )
        self.norm = torch.nn.LayerNorm(hidden_chans)
        self.dropout = torch.nn.Dropout(dropout_rate)

    def forward(self, x, ilens=None):
        """Calculate forward propagation.

        Args:
            x (torch.Tensor): Batch of input tensors (B, T, in_chans).

        Returns:
            torch.Tensor: Batch of output tensors (B, T, out_chans).

        """
        x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
        return self.w_2(self.norm(self.dropout(x)).transpose(-1, 1)).transpose(-1, 1), ilens

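Because both conv1d layers use matched "same" padding, FsmnFeedForward preserves the sequence length and only changes the channel dimension. For instance, in the same module (illustrative sizes):

import torch

ffn = FsmnFeedForward(in_chans=80, hidden_chans=256, out_chans=128, kernel_size=3, dropout_rate=0.1)
x = torch.randn(2, 100, 80)  # (B, T, in_chans)
y, _ = ffn(x)
print(y.shape)               # torch.Size([2, 100, 128])
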
class Conv1dLinear(torch.nn.Module):
    """Conv1D + Linear for Transformer block.

585
funasr/tasks/diar.py
Normal file
@@ -0,0 +1,585 @@
import argparse
import logging
import os
from pathlib import Path
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import numpy as np
import torch
import yaml
from typeguard import check_argument_types
from typeguard import check_return_type

from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.layers.label_aggregation import LabelAggregate
from funasr.models.ctc import CTC
from funasr.models.encoder.resnet34_encoder import ResNet34Diar
from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
from funasr.models.e2e_diar_sond import DiarSondModel
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.conformer_encoder import ConformerEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.transformer_encoder import TransformerEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.postencoder.hugging_face_transformers_postencoder import (
    HuggingFaceTransformersPostEncoder,  # noqa: H301
)
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.preencoder.linear import LinearProjection
from funasr.models.preencoder.sinc import LightweightSincConvs
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
from funasr.tasks.abs_task import AbsTask
from funasr.torch_utils.initialize import initialize
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.types import float_or_none
from funasr.utils.types import int_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none

frontend_choices = ClassChoices(
    name="frontend",
    classes=dict(
        default=DefaultFrontend,
        sliding_window=SlidingWindow,
        s3prl=S3prlFrontend,
        fused=FusedFrontends,
        wav_frontend=WavFrontend,
    ),
    type_check=AbsFrontend,
    default="default",
)
specaug_choices = ClassChoices(
    name="specaug",
    classes=dict(
        specaug=SpecAug,
        specaug_lfr=SpecAugLFR,
    ),
    type_check=AbsSpecAug,
    default=None,
    optional=True,
)
normalize_choices = ClassChoices(
    "normalize",
    classes=dict(
        global_mvn=GlobalMVN,
        utterance_mvn=UtteranceMVN,
    ),
    type_check=AbsNormalize,
    default=None,
    optional=True,
)
label_aggregator_choices = ClassChoices(
    "label_aggregator",
    classes=dict(
        label_aggregator=LabelAggregate
    ),
    type_check=torch.nn.Module,
    default=None,
    optional=True,
)
model_choices = ClassChoices(
    "model",
    classes=dict(
        sond=DiarSondModel,
    ),
    type_check=AbsESPnetModel,
    default="sond",
)
encoder_choices = ClassChoices(
    "encoder",
    classes=dict(
        conformer=ConformerEncoder,
        transformer=TransformerEncoder,
        rnn=RNNEncoder,
        sanm=SANMEncoder,
        san=SelfAttentionEncoder,
        fsmn=FsmnEncoder,
        conv=ConvEncoder,
        resnet34=ResNet34Diar,
        sanm_chunk_opt=SANMEncoderChunkOpt,
        data2vec_encoder=Data2VecEncoder,
    ),
    type_check=AbsEncoder,
    default="resnet34",
)
speaker_encoder_choices = ClassChoices(
    "speaker_encoder",
    classes=dict(
        conformer=ConformerEncoder,
        transformer=TransformerEncoder,
        rnn=RNNEncoder,
        sanm=SANMEncoder,
        san=SelfAttentionEncoder,
        fsmn=FsmnEncoder,
        conv=ConvEncoder,
        sanm_chunk_opt=SANMEncoderChunkOpt,
        data2vec_encoder=Data2VecEncoder,
    ),
    type_check=AbsEncoder,
    default=None,
    optional=True
)
cd_scorer_choices = ClassChoices(
    "cd_scorer",
    classes=dict(
        san=SelfAttentionEncoder,
    ),
    type_check=AbsEncoder,
    default=None,
    optional=True,
)
ci_scorer_choices = ClassChoices(
    "ci_scorer",
    classes=dict(
        dot=DotScorer,
        cosine=CosScorer,
    ),
    type_check=torch.nn.Module,
    default=None,
    optional=True,
)
# decoder is used for output (e.g. post_net in SOND)
decoder_choices = ClassChoices(
    "decoder",
    classes=dict(
        rnn=RNNEncoder,
        fsmn=FsmnEncoder,
    ),
    type_check=torch.nn.Module,
    default="fsmn",
)

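Each ClassChoices instance turns a pair of command-line options (--<name> and --<name>_conf) into a class lookup plus constructor keyword arguments. A sketch of the resolution step, mirroring how build_model uses it below (the option values are illustrative):

# e.g. --encoder resnet34 --encoder_conf '{"embedding_node": "resnet1_dense"}'
encoder_class = encoder_choices.get_class("resnet34")                   # -> ResNet34Diar
encoder = encoder_class(input_size=80, embedding_node="resnet1_dense")  # 80-dim features assumed
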
class DiarTask(AbsTask):
    # If you need more than 1 optimizer, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --specaug and --specaug_conf
        specaug_choices,
        # --normalize and --normalize_conf
        normalize_choices,
        # --model and --model_conf
        model_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --speaker_encoder and --speaker_encoder_conf
        speaker_encoder_choices,
        # --cd_scorer and --cd_scorer_conf
        cd_scorer_choices,
        # --ci_scorer and --ci_scorer_conf
        ci_scorer_choices,
        # --decoder and --decoder_conf
        decoder_choices,
    ]

    # If you need to modify train() or eval() procedures, change the Trainer class here
    trainer = Trainer

    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(description="Task related")

        # NOTE(kamo): add_arguments(..., required=True) can't be used
        # to provide --print_config mode. Instead of it, do as
        # required = parser.get_default("required")
        # required += ["token_list"]

        group.add_argument(
            "--token_list",
            type=str_or_none,
            default=None,
            help="A text mapping int-id to token",
        )
        group.add_argument(
            "--split_with_space",
            type=str2bool,
            default=True,
            help="Whether to split text using <space>",
        )
        group.add_argument(
            "--seg_dict_file",
            type=str,
            default=None,
            help="seg_dict_file for text processing",
        )
        group.add_argument(
            "--init",
            type=lambda x: str_or_none(x.lower()),
            default=None,
            help="The initialization method",
            choices=[
                "chainer",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
                None,
            ],
        )

        group.add_argument(
            "--input_size",
            type=int_or_none,
            default=None,
            help="The number of input dimensions of the feature",
        )

        group = parser.add_argument_group(description="Preprocess related")
        group.add_argument(
            "--use_preprocessor",
            type=str2bool,
            default=True,
            help="Apply preprocessing to data or not",
        )
        group.add_argument(
            "--token_type",
            type=str,
            default="char",
            choices=["char"],
            help="The text will be tokenized at the specified token level",
        )
        parser.add_argument(
            "--speech_volume_normalize",
            type=float_or_none,
            default=None,
            help="Scale the maximum amplitude to the given value.",
        )
        parser.add_argument(
            "--rir_scp",
            type=str_or_none,
            default=None,
            help="The file path of the rir scp file.",
        )
        parser.add_argument(
            "--rir_apply_prob",
            type=float,
            default=1.0,
            help="The probability of applying RIR convolution.",
        )
        parser.add_argument(
            "--cmvn_file",
            type=str_or_none,
            default=None,
            help="The file path of the cmvn file.",
        )
        parser.add_argument(
            "--noise_scp",
            type=str_or_none,
            default=None,
            help="The file path of the noise scp file.",
        )
        parser.add_argument(
            "--noise_apply_prob",
            type=float,
            default=1.0,
            help="The probability of applying noise addition.",
        )
        parser.add_argument(
            "--noise_db_range",
            type=str,
            default="13_15",
            help="The range of noise decibel levels.",
        )

        for class_choices in cls.class_choices_list:
            # Append --<name> and --<name>_conf.
            # e.g. --encoder and --encoder_conf
            class_choices.add_arguments(group)

    @classmethod
    def build_collate_fn(
            cls, args: argparse.Namespace, train: bool
    ) -> Callable[
        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
        Tuple[List[str], Dict[str, torch.Tensor]],
    ]:
        assert check_argument_types()
        # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)

    @classmethod
    def build_preprocess_fn(
            cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
        assert check_argument_types()
        if args.use_preprocessor:
            retval = CommonPreprocessor(
                train=train,
                token_type=args.token_type,
                token_list=args.token_list,
                bpemodel=None,
                non_linguistic_symbols=None,
                text_cleaner=None,
                g2p_type=None,
                split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
                seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
                # NOTE(kamo): Check attribute existence for backward compatibility
                rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
                rir_apply_prob=args.rir_apply_prob
                if hasattr(args, "rir_apply_prob")
                else 1.0,
                noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
                noise_apply_prob=args.noise_apply_prob
                if hasattr(args, "noise_apply_prob")
                else 1.0,
                noise_db_range=args.noise_db_range
                if hasattr(args, "noise_db_range")
                else "13_15",
                speech_volume_normalize=args.speech_volume_normalize
                if hasattr(args, "speech_volume_normalize")
                else None,
            )
        else:
            retval = None
        assert check_return_type(retval)
        return retval

    @classmethod
    def required_data_names(
            cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        if not inference:
            retval = ("speech", "profile", "label")
        else:
            # Recognition mode
            retval = ("speech", "profile")
        return retval

    @classmethod
    def optional_data_names(
            cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        retval = ()
        assert check_return_type(retval)
        return retval

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]

            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. Frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == 'wav_frontend':
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from the data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)

        # 5. Speaker encoder
        if getattr(args, "speaker_encoder", None) is not None:
            speaker_encoder_class = speaker_encoder_choices.get_class(args.speaker_encoder)
            speaker_encoder = speaker_encoder_class(**args.speaker_encoder_conf)
        else:
            speaker_encoder = None

        # 6. CI & CD scorers
        if getattr(args, "ci_scorer", None) is not None:
            ci_scorer_class = ci_scorer_choices.get_class(args.ci_scorer)
            ci_scorer = ci_scorer_class(**args.ci_scorer_conf)
        else:
            ci_scorer = None

        if getattr(args, "cd_scorer", None) is not None:
            cd_scorer_class = cd_scorer_choices.get_class(args.cd_scorer)
            cd_scorer = cd_scorer_class(**args.cd_scorer_conf)
        else:
            cd_scorer = None

        # 7. Decoder
        decoder_class = decoder_choices.get_class(args.decoder)
        decoder = decoder_class(**args.decoder_conf)

        # 8. Label aggregator
        if getattr(args, "label_aggregator", None) is not None:
            label_aggregator_class = label_aggregator_choices.get_class(args.label_aggregator)
            label_aggregator = label_aggregator_class(**args.label_aggregator_conf)
        else:
            label_aggregator = None

        # 9. Build model
        model_class = model_choices.get_class(args.model)
        model = model_class(
            vocab_size=vocab_size,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            label_aggregator=label_aggregator,
            encoder=encoder,
            speaker_encoder=speaker_encoder,
            ci_scorer=ci_scorer,
            cd_scorer=cd_scorer,
            decoder=decoder,
            token_list=token_list,
            **args.model_conf,
        )

        # 10. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model

    # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
    @classmethod
    def build_model_from_file(
            cls,
            config_file: Union[Path, str] = None,
            model_file: Union[Path, str] = None,
            cmvn_file: Union[Path, str] = None,
            device: str = "cpu",
    ):
        """Build model from the files.

        This method is used for inference or fine-tuning.

        Args:
            config_file: The yaml file saved when training.
            model_file: The model file saved when training.
            cmvn_file: The cmvn file for the front-end.
            device: Device type, "cpu", "cuda", or "cuda:N".

        """
        assert check_argument_types()
        if config_file is None:
            assert model_file is not None, (
                "The argument 'model_file' must be provided "
                "if the argument 'config_file' is not specified."
            )
            config_file = Path(model_file).parent / "config.yaml"
        else:
            config_file = Path(config_file)

        with config_file.open("r", encoding="utf-8") as f:
            args = yaml.safe_load(f)
        if cmvn_file is not None:
            args["cmvn_file"] = cmvn_file
        args = argparse.Namespace(**args)
        model = cls.build_model(args)
        if not isinstance(model, AbsESPnetModel):
            raise RuntimeError(
                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
            )
        model.to(device)
        model_dict = dict()
        model_name_pth = None
        if model_file is not None:
            logging.info("model_file is {}".format(model_file))
            if device == "cuda":
                device = f"cuda:{torch.cuda.current_device()}"
            model_dir = os.path.dirname(model_file)
            model_name = os.path.basename(model_file)
            if "model.ckpt-" in model_name or ".bin" in model_name:
                if ".bin" in model_name:
                    model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb'))
                else:
                    model_name_pth = os.path.join(model_dir, "{}.pth".format(model_name))
                if os.path.exists(model_name_pth):
                    logging.info("model_file is loaded from pth: {}".format(model_name_pth))
                    model_dict = torch.load(model_name_pth, map_location=device)
                else:
                    model_dict = cls.convert_tf2torch(model, model_file)
                model.load_state_dict(model_dict)
            else:
                model_dict = torch.load(model_file, map_location=device)
                model.load_state_dict(model_dict)
        if model_name_pth is not None and not os.path.exists(model_name_pth):
            torch.save(model_dict, model_name_pth)
            logging.info("model_file is saved to pth: {}".format(model_name_pth))

        return model, args

    @classmethod
    def convert_tf2torch(
            cls,
            model,
            ckpt,
    ):
        logging.info("start converting tf model to torch model")
        from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
        var_dict_tf = load_tf_dict(ckpt)
        var_dict_torch = model.state_dict()
        var_dict_torch_update = dict()
        # speech encoder
        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # speaker encoder
        var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # cd scorer
        var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # ci scorer
        var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # decoder
        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)

        return var_dict_torch_update
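
For inference, the usual entry point is DiarTask.build_model_from_file, which restores the training config, instantiates the model, and loads the PyTorch checkpoint (or converts a TF one as above). A minimal sketch, with the file names assumed rather than taken from a real run:

from funasr.tasks.diar import DiarTask

model, args = DiarTask.build_model_from_file(
    config_file="config.yaml",  # saved at training time
    model_file="sond.pth",      # a model.ckpt-* or .bin file would trigger TF conversion
    device="cpu",
)
model.eval()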
103
funasr/utils/job_runner.py
Normal file
@@ -0,0 +1,103 @@
from __future__ import print_function
from multiprocessing import Pool
import argparse
from tqdm import tqdm
import math


class MultiProcessRunner:
    def __init__(self, fn):
        self.args = None
        self.process = fn

    def run(self):
        parser = argparse.ArgumentParser("")
        # Task-independent options
        parser.add_argument("--nj", type=int, default=16)
        parser.add_argument("--debug", action="store_true", default=False)
        parser.add_argument("--no_pbar", action="store_true", default=False)
        parser.add_argument("--verbose", action="store_true", default=False)

        task_list, args = self.prepare(parser)
        result_list = self.pool_run(task_list, args)
        self.post(result_list, args)

    def prepare(self, parser):
        raise NotImplementedError("Please implement the prepare function.")

    def post(self, result_list, args):
        raise NotImplementedError("Please implement the post function.")

    def pool_run(self, tasks, args):
        results = []
        if args.debug:
            one_result = self.process(tasks[0])
            results.append(one_result)
        else:
            pool = Pool(args.nj)
            for one_result in tqdm(pool.imap(self.process, tasks), total=len(tasks), ascii=True, disable=args.no_pbar):
                results.append(one_result)
            pool.close()

        return results


class MultiProcessRunnerV2:
    def __init__(self, fn):
        self.args = None
        self.process = fn

    def run(self):
        parser = argparse.ArgumentParser("")
        # Task-independent options
        parser.add_argument("--nj", type=int, default=16)
        parser.add_argument("--debug", action="store_true", default=False)
        parser.add_argument("--no_pbar", action="store_true", default=False)
        parser.add_argument("--verbose", action="store_true", default=False)

        task_list, args = self.prepare(parser)
        chunk_size = int(math.ceil(float(len(task_list)) / args.nj))
        if args.verbose:
            print("Split {} tasks into {} sub-tasks with chunk_size {}".format(len(task_list), args.nj, chunk_size))
        subtask_list = [task_list[i*chunk_size: (i+1)*chunk_size] for i in range(args.nj)]
        result_list = self.pool_run(subtask_list, args)
        self.post(result_list, args)

    def prepare(self, parser):
        raise NotImplementedError("Please implement the prepare function.")

    def post(self, result_list, args):
        raise NotImplementedError("Please implement the post function.")

    def pool_run(self, tasks, args):
        results = []
        if args.debug:
            one_result = self.process(tasks[0])
            results.append(one_result)
        else:
            pool = Pool(args.nj)
            for one_result in tqdm(pool.imap(self.process, tasks), total=len(tasks), ascii=True, disable=args.no_pbar):
                results.append(one_result)
            pool.close()

        return results


class MultiProcessRunnerV3(MultiProcessRunnerV2):
    def run(self):
        parser = argparse.ArgumentParser("")
        # Task-independent options
        parser.add_argument("--nj", type=int, default=16)
        parser.add_argument("--debug", action="store_true", default=False)
        parser.add_argument("--no_pbar", action="store_true", default=False)
        parser.add_argument("--verbose", action="store_true", default=False)
        parser.add_argument("--sr", type=int, default=16000)

        task_list, shared_param, args = self.prepare(parser)
        chunk_size = int(math.ceil(float(len(task_list)) / args.nj))
        if args.verbose:
            print("Split {} tasks into {} sub-tasks with chunk_size {}".format(len(task_list), args.nj, chunk_size))
        subtask_list = [(i, task_list[i * chunk_size: (i + 1) * chunk_size], shared_param, args)
                        for i in range(args.nj)]
        result_list = self.pool_run(subtask_list, args)
        self.post(result_list, args)
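
A minimal sketch of subclassing MultiProcessRunnerV2: prepare builds the task list, the worker function receives one chunk of tasks, and post collects the per-chunk results (the squaring task is purely illustrative):

from funasr.utils.job_runner import MultiProcessRunnerV2


def square_chunk(chunk):
    return [x * x for x in chunk]


class SquareRunner(MultiProcessRunnerV2):
    def prepare(self, parser):
        args = parser.parse_args()
        return list(range(100)), args

    def post(self, result_list, args):
        total = sum(len(chunk) for chunk in result_list)
        print("processed {} items in {} chunks".format(total, len(result_list)))


if __name__ == "__main__":
    SquareRunner(square_chunk).run()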
48
funasr/utils/misc.py
Normal file
@@ -0,0 +1,48 @@
import io
from collections import OrderedDict
import numpy as np


def statistic_model_parameters(model, prefix=None):
    var_dict = model.state_dict()
    numel = 0
    for i, key in enumerate(sorted(list([x for x in var_dict.keys() if "num_batches_tracked" not in x]))):
        if prefix is None or key.startswith(prefix):
            numel += var_dict[key].numel()
    return numel


def int2vec(x, vec_dim=8, dtype=np.int):
    b = ('{:0' + str(vec_dim) + 'b}').format(x)
    # little-endian order: lower bit first
    return (np.array(list(b)[::-1]) == '1').astype(dtype)


def seq2arr(seq, vec_dim=8):
    return np.row_stack([int2vec(int(x), vec_dim) for x in seq])


def load_scp_as_dict(scp_path, value_type='str', kv_sep=" "):
    with io.open(scp_path, 'r', encoding='utf-8') as f:
        ret_dict = OrderedDict()
        for one_line in f.readlines():
            one_line = one_line.strip()
            pos = one_line.find(kv_sep)
            key, value = one_line[:pos], one_line[pos + 1:]
            if value_type == 'list':
                value = value.split(' ')
            ret_dict[key] = value
        return ret_dict


def load_scp_as_list(scp_path, value_type='str', kv_sep=" "):
    with io.open(scp_path, 'r', encoding='utf8') as f:
        ret_dict = []
        for one_line in f.readlines():
            one_line = one_line.strip()
            pos = one_line.find(kv_sep)
            key, value = one_line[:pos], one_line[pos + 1:]
            if value_type == 'list':
                value = value.split(' ')
            ret_dict.append((key, value))
        return ret_dict
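
int2vec encodes an integer as a little-endian binary vector and seq2arr stacks those codes into a matrix, which is handy for expanding power-set style labels into per-bit activity vectors. For example:

print(int2vec(5, vec_dim=4))       # [1 0 1 0]: bit k set means 2**k is present in x
print(seq2arr("0123", vec_dim=2))  # rows [0 0], [1 0], [0 1], [1 1]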