FunASR/egs/alimeeting/sa_asr/local/compute_cpcer.py
yhliang e8528b8f62
Dev lyh (#645)
* update

* update

* fix bug

* fix bug
2023-06-16 20:16:47 +08:00

92 lines
3.0 KiB
Python

import editdistance
import sys
import os
from itertools import permutations
def load_transcripts(file_path):
    """Read a file of ``<id> <transcript>`` lines.

    Each non-blank line is split at its first run of whitespace into an id
    and a transcript (the transcript may itself contain spaces). A line that
    holds only an id is read as an empty transcript, e.g. an empty
    hypothesis. Blank lines are skipped.

    Args:
        file_path: path to a UTF-8 encoded text file.

    Returns:
        List of ``(id, transcript)`` tuples in file order, both stripped of
        surrounding whitespace.
    """
    trans_list = []
    with open(file_path, "rt", encoding="utf-8") as f:
        for line in f:
            # maxsplit=1 keeps any spaces inside the transcript instead of
            # failing the two-way unpack the way split(" ") did.
            fields = line.split(maxsplit=1)
            if not fields:
                continue  # blank line
            meeting_id = fields[0]
            trans = fields[1].strip() if len(fields) > 1 else ""
            trans_list.append((meeting_id, trans))
    return trans_list
def calc_spk_trans(trans):
    """Split a multi-speaker transcript into labelled per-speaker segments.

    Speakers are separated by "$". Each segment is stripped of surrounding
    whitespace and paired with its zero-based position rendered as a string.

    Args:
        trans: transcript such as ``"abc$def"``.

    Returns:
        List of ``(label, segment)`` tuples, e.g.
        ``[("0", "abc"), ("1", "def")]``. An empty string yields
        ``[("0", "")]``.
    """
    return [
        (str(position), segment.strip())
        for position, segment in enumerate(trans.split("$"))
    ]
def calc_cer(ref_trans, hyp_trans):
    """Return the minimum-permutation edit distance between speaker-split transcripts.

    Both transcripts are split on "$" into per-speaker segments via
    ``calc_spk_trans``. The side with fewer speakers is padded with empty
    segments. Every one-to-one assignment of hypothesis speakers to reference
    speakers is then scored by the summed ``editdistance.eval`` of the paired
    segments, and the smallest sum is kept.

    Args:
        ref_trans: reference transcript, speakers separated by "$".
        hyp_trans: hypothesis transcript, speakers separated by "$".

    Returns:
        ``(errors, ref_length, ref_spk_num, hyp_spk_num)``:
        - ``errors``: the minimum summed edit distance over all permutations.
        - ``ref_length``: the total character count of all reference segments
          (the same for every permutation). It is 0 for an all-empty
          reference, so callers must guard the division.
        - ``ref_spk_num`` / ``hyp_spk_num``: the speaker counts before padding.
    """
    ref_segs = [seg for _, seg in calc_spk_trans(ref_trans)]
    hyp_segs = [seg for _, seg in calc_spk_trans(hyp_trans)]
    ref_spk_num, hyp_spk_num = len(ref_segs), len(hyp_segs)
    num_spk = max(ref_spk_num, hyp_spk_num)
    ref_segs.extend([""] * (num_spk - ref_spk_num))
    hyp_segs.extend([""] * (num_spk - hyp_spk_num))
    # Every reference segment is scored exactly once in any permutation, so
    # the denominator is permutation-invariant.
    ref_length = sum(len(seg) for seg in ref_segs)

    cost_cache = {}

    def pair_cost(i, j):
        # Edit distance of reference speaker i vs hypothesis speaker j,
        # memoised because the same pair recurs across permutations.
        if (i, j) not in cost_cache:
            cost_cache[(i, j)] = editdistance.eval(ref_segs[i], hyp_segs[j])
        return cost_cache[(i, j)]

    best = None
    for perm in permutations(range(num_spk)):
        total = 0
        for i, j in enumerate(perm):
            # Branch and bound. |len(a) - len(b)| is a lower bound on the edit
            # distance, so drop this permutation (before paying for the
            # edit-distance call) once it can no longer beat the best complete
            # permutation seen so far.
            lower_bound = abs(len(ref_segs[i]) - len(hyp_segs[j]))
            if best is not None and total + lower_bound >= best:
                break
            total += pair_cost(i, j)
            if best is not None and total >= best:
                break
        else:
            # Loop finished without pruning: strictly better than `best`.
            best = total
    # num_spk >= 1 (calc_spk_trans never returns an empty list), so the first
    # permutation always completes and `best` is set.
    return best, ref_length, ref_spk_num, hyp_spk_num
def _percent(numerator, denominator):
    """Return numerator / denominator * 100, or NaN when the denominator is 0."""
    if denominator == 0:
        return float("nan")
    return numerator / denominator * 100.0


def main():
    """Score hypothesis transcripts against references and write cpCER results.

    Usage: ``<script> <ref_file> <hyp_file> <result_file>``

    Both input files hold ``<id> <transcript>`` lines in the same order. For
    each id the result file gets ``<id> <cer%> <ref_spk_num> <hyp_spk_num>``.
    A final ``CP-CER: <cer%>`` line reports the error rate pooled over all
    ids. A CER whose reference length is 0 is written as ``nan``.

    Raises:
        ValueError: if the two files have different numbers of entries or
            their ids do not line up.
    """
    if len(sys.argv) < 4:
        sys.exit("usage: {} <ref> <hyp> <result>".format(sys.argv[0]))
    ref_path, hyp_path, result_path = sys.argv[1], sys.argv[2], sys.argv[3]
    ref_list = load_transcripts(ref_path)
    hyp_list = load_transcripts(hyp_path)

    # Validate alignment up front: zip() would silently drop trailing
    # entries, and assert is stripped under `python -O`.
    if len(ref_list) != len(hyp_list):
        raise ValueError("ref has {} entries but hyp has {}".format(
            len(ref_list), len(hyp_list)))
    for (ref_id, _), (hyp_id, _) in zip(ref_list, hyp_list):
        if ref_id != hyp_id:
            raise ValueError("id mismatch: ref {!r} vs hyp {!r}".format(ref_id, hyp_id))

    error, count = 0, 0
    with open(result_path, "w", encoding="utf-8") as result_file:
        for (mid, ref_trans), (_, hyp_trans) in zip(ref_list, hyp_list):
            dist, length, ref_spk_num, hyp_spk_num = calc_cer(ref_trans, hyp_trans)
            error += dist
            count += length
            result_file.write("{} {:.2f} {} {}\n".format(
                mid, _percent(dist, length), ref_spk_num, hyp_spk_num))
        result_file.write("CP-CER: {:.2f}\n".format(_percent(error, count)))
if __name__ == '__main__':
    # Score only when run as a script; importing the module has no side effects.
    main()