diff --git a/egs/aishell/transformer/utils/cmvn-converter.py b/egs/aishell/transformer/utils/cmvn-converter.py new file mode 100644 index 000000000..97b398157 --- /dev/null +++ b/egs/aishell/transformer/utils/cmvn-converter.py @@ -0,0 +1,54 @@ +import argparse +import json +import numpy as np + + +def get_parser(): + parser = argparse.ArgumentParser( + description="cmvn converter", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--cmvn-json", + "-c", + default=False, + required=True, + type=str, + help="cmvn json file", + ) + parser.add_argument( + "--am-mvn", + "-a", + default=False, + required=True, + type=str, + help="am mvn file", + ) + return parser + +def main(): + parser = get_parser() + args = parser.parse_args() + + with open(args.cmvn_json, "r") as fin: + cmvn_dict = json.load(fin) + + mean_stats = np.array(cmvn_dict["mean_stats"]) + var_stats = np.array(cmvn_dict["var_stats"]) + total_frame = np.array(cmvn_dict["total_frames"]) + + print(mean_stats.dtype) + mean = -1.0 * mean_stats / total_frame + var = 1.0 / np.sqrt(var_stats / total_frame - mean * mean) + dims = mean.shape[0] + with open(args.am_mvn, 'w') as fout: + fout.write("" + "\n" + " " + str(dims) + " " + str(dims) + '\n' + "[ 0 ]" + "\n" + " " + str(dims) + " " + str(dims) + "\n") + mean_str = str(list(mean)).replace(',', '').replace('[', '[ ').replace(']', ' ]') + fout.write(" 0 " + mean_str + '\n') + fout.write(" " + str(dims) + " " + str(dims) + '\n') + var_str = str(list(var)).replace(',', '').replace('[', '[ ').replace(']', ' ]') + fout.write(" 0 " + var_str + '\n') + fout.write("" + '\n') + +if __name__ == '__main__': + main()