FunASR/egs/aishell2/transformer/utils/combine_cmvn_file.py
2023-05-12 17:25:54 +08:00

73 lines
1.6 KiB
Python
Executable File

import argparse
import json
import os
import numpy as np
def get_parser():
    """Build the command-line parser for combining per-job CMVN statistics.

    Options:
        --dim: length of the mean/variance statistic vectors (default 80).
        --cmvn_dir: directory holding the per-job ``cmvn.<i>.json`` files
            (required).
        --nj: number of per-job files, read as ``cmvn.1.json`` ..
            ``cmvn.<nj>.json`` (default 1).
        --output_dir: directory where the combined ``cmvn.json`` is written
            (required).

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(
        description="combine cmvn file",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--dim",
        default=80,
        type=int,
        help="feature dim",
    )
    # Required options get no default: it could never be used, and the old
    # `default=False` on a str option showed up as "(default: False)" in help.
    parser.add_argument(
        "--cmvn_dir",
        required=True,
        type=str,
        help="cmvn dir",
    )
    # Previously `required=True` together with `default=1`, which contradict
    # each other; the default is now honoured (passing --nj still works).
    parser.add_argument(
        "--nj",
        default=1,
        type=int,
        help="num of cmvn files",
    )
    parser.add_argument(
        "--output_dir",
        required=True,
        type=str,
        help="output dir",
    )
    return parser
def _combine_cmvn_stats(cmvn_dir, nj, dim):
    """Sum the per-job CMVN statistics stored in ``cmvn_dir``.

    Reads ``cmvn.1.json`` .. ``cmvn.<nj>.json``. Each file must contain
    ``mean_stats`` and ``var_stats`` (numeric lists of length ``dim``) and
    ``total_frames`` (a number).

    Returns:
        dict: ``mean_stats`` / ``var_stats`` as element-wise sums (plain
        Python lists) and ``total_frames`` as the summed frame count, in
        that key order.

    Raises:
        ValueError: if ``nj`` < 1, or a file's statistics are not length ``dim``.
    """
    # With no inputs the result would be all zeros with 0 frames, which is
    # useless as normalisation stats (presumably divided by total_frames
    # downstream), so fail loudly instead.
    if nj < 1:
        raise ValueError("nj must be >= 1, got {}".format(nj))
    total_means = np.zeros(dim)
    total_vars = np.zeros(dim)
    total_frames = 0
    for i in range(1, nj + 1):
        path = os.path.join(cmvn_dir, "cmvn.{}.json".format(i))
        with open(path, encoding="utf-8") as fin:
            cmvn_stats = json.load(fin)
        means = np.asarray(cmvn_stats["mean_stats"], dtype=float)
        variances = np.asarray(cmvn_stats["var_stats"], dtype=float)
        # Without this check numpy would broadcast e.g. a length-1 list
        # across every dimension and silently corrupt the totals.
        if means.shape != (dim,) or variances.shape != (dim,):
            raise ValueError(
                "{}: expected mean_stats/var_stats of length {}, got shapes {} and {}".format(
                    path, dim, means.shape, variances.shape
                )
            )
        total_means += means
        total_vars += variances
        total_frames += cmvn_stats["total_frames"]
    return {
        "mean_stats": total_means.tolist(),
        "var_stats": total_vars.tolist(),
        "total_frames": total_frames,
    }


def main():
    """Combine per-job CMVN statistics into a single ``cmvn.json``.

    Sums the statistics in ``<cmvn_dir>/cmvn.{1..nj}.json`` and writes the
    totals as JSON to ``<output_dir>/cmvn.json``, creating ``output_dir`` if
    it does not exist yet.
    """
    args = get_parser().parse_args()
    cmvn_info = _combine_cmvn_stats(args.cmvn_dir, args.nj, args.dim)
    os.makedirs(args.output_dir, exist_ok=True)
    cmvn_file = os.path.join(args.output_dir, "cmvn.json")
    with open(cmvn_file, "w", encoding="utf-8") as fout:
        json.dump(cmvn_info, fout)
if __name__ == "__main__":
    # Run the combiner only when executed as a script, not when imported.
    main()