mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
92 lines
3.0 KiB
Python
92 lines
3.0 KiB
Python
import editdistance
|
|
import sys
|
|
import os
|
|
from itertools import permutations
|
|
|
|
|
|
def load_transcripts(file_path):
    """Read a transcript file into a list of ``(meeting_id, transcript)`` pairs.

    Each non-blank line holds a meeting id, whitespace, and the transcript.
    The line is split on the first run of whitespace only, so a line that
    carries just an id yields an empty transcript instead of raising, and a
    transcript that itself contains spaces is kept whole.

    Args:
        file_path: Path of the transcript file; it is decoded as UTF-8.

    Returns:
        A list of ``(meeting_id, transcript)`` tuples in file order, both
        fields stripped of surrounding whitespace.
    """
    trans_list = []
    # The context manager closes the handle even if a line fails to decode.
    with open(file_path, "rt", encoding="utf-8") as trans_file:
        for one_line in trans_file:
            fields = one_line.strip().split(None, 1)
            if not fields:
                # Blank line: nothing to record.
                continue
            meeting_id = fields[0]
            trans = fields[1].strip() if len(fields) > 1 else ""
            trans_list.append((meeting_id, trans))

    return trans_list
|
|
|
|
def calc_spk_trans(trans):
    """Split a multi-speaker transcript into per-speaker entries.

    Speaker segments are separated by ``$``. Each segment is stripped of
    surrounding whitespace and paired with its position, rendered as a
    string.

    Args:
        trans: Transcript string with ``$`` between speaker segments.

    Returns:
        A list of ``(index, text)`` tuples, one per segment. An empty input
        still yields one entry, ``[("0", "")]``.
    """
    return [
        (str(position), segment.strip())
        for position, segment in enumerate(trans.split("$"))
    ]
|
|
|
|
def calc_cer(ref_trans, hyp_trans):
    """Find the speaker assignment with the fewest character errors.

    Both transcripts are split into per-speaker texts with
    ``calc_spk_trans``. The side with fewer speakers is padded with empty
    texts, then every assignment of hypothesis speakers to reference speakers
    is scored by summed edit distance and the smallest total is kept.

    Args:
        ref_trans: Reference transcript, speakers separated by ``$``.
        hyp_trans: Hypothesis transcript, speakers separated by ``$``.

    Returns:
        A tuple ``(errors, ref_length, ref_spk_num, hyp_spk_num)``: the
        minimum total edit distance, the total number of reference
        characters, and the speaker counts before padding. ``ref_length`` is
        0 for an empty reference; no ratio is computed here, so that case
        does not raise.
    """
    ref_spk_trans = calc_spk_trans(ref_trans)
    hyp_spk_trans = calc_spk_trans(hyp_trans)
    ref_spk_num, hyp_spk_num = len(ref_spk_trans), len(hyp_spk_trans)

    # Pad with empty speakers so both sides can be matched one-to-one.
    num_spk = max(ref_spk_num, hyp_spk_num)
    ref_texts = [text for _, text in ref_spk_trans]
    ref_texts.extend([""] * (num_spk - ref_spk_num))
    hyp_texts = [text for _, text in hyp_spk_trans]
    hyp_texts.extend([""] * (num_spk - hyp_spk_num))

    # Every full assignment covers all reference speakers, so the
    # denominator is the same for all of them and can be computed once.
    ref_length = sum(len(text) for text in ref_texts)

    pair_cost = {}  # (ref index, hyp index) -> edit distance
    best_err = None  # None means "no assignment scored yet"
    for perm in permutations(range(num_spk)):
        perm_err = 0
        for ref_idx, hyp_idx in enumerate(perm):
            ref_text, hyp_text = ref_texts[ref_idx], hyp_texts[hyp_idx]
            if best_err is not None:
                # The length difference is a lower bound on the edit
                # distance, so this can rule the assignment out cheaply.
                length_gap = abs(len(ref_text) - len(hyp_text))
                if perm_err + length_gap > best_err:
                    break
            pair = (ref_idx, hyp_idx)
            if pair not in pair_cost:
                pair_cost[pair] = editdistance.eval(ref_text, hyp_text)
            perm_err += pair_cost[pair]
            if best_err is not None and perm_err > best_err:
                # Already worse than the best complete assignment.
                break
        else:
            # Reached only when the assignment was not pruned, i.e. it is
            # the first one scored or no worse than the current best.
            best_err = perm_err

    return best_err, ref_length, ref_spk_num, hyp_spk_num
|
|
|
|
|
|
def _to_percent(err, total):
    """Return ``err / total`` as a percentage, tolerating ``total == 0``.

    With nothing to score against, the rate is 0.0 when there are no errors
    and infinity otherwise.
    """
    if total:
        return err / total * 100.0
    return 0.0 if err == 0 else float("inf")


def main():
    """Score hypothesis transcripts against references and write a report.

    Command-line arguments read from ``sys.argv``:
        1. path of the reference transcript file
        2. path of the hypothesis transcript file
        3. path of the result file to write

    The result file gets one line per meeting
    (``<id> <error %> <ref speakers> <hyp speakers>``) followed by a final
    ``CP-CER: <overall error %>`` line.

    Raises:
        AssertionError: If the meeting ids of a reference/hypothesis pair
            differ.
    """
    ref = sys.argv[1]
    hyp = sys.argv[2]
    result_path = sys.argv[3]
    ref_list = load_transcripts(ref)
    hyp_list = load_transcripts(hyp)

    error, count = 0, 0
    # The context manager closes (and flushes) the report even if scoring
    # raises part-way through.
    with open(result_path, "w") as result_file:
        # NOTE(review): zip stops at the shorter list, so extra meetings on
        # either side are ignored - confirm this is intended.
        for (ref_id, ref_trans), (hyp_id, hyp_trans) in zip(ref_list, hyp_list):
            if ref_id != hyp_id:
                # Raised explicitly so the check is not stripped under -O.
                raise AssertionError(
                    "meeting id mismatch: {} != {}".format(ref_id, hyp_id)
                )
            dist, length, ref_spk_num, hyp_spk_num = calc_cer(ref_trans, hyp_trans)
            error, count = error + dist, count + length
            result_file.write(
                "{} {:.2f} {} {}\n".format(
                    ref_id, _to_percent(dist, length), ref_spk_num, hyp_spk_num
                )
            )

        result_file.write("CP-CER: {:.2f}\n".format(_to_percent(error, count)))
|
|
|
|
|
|
# Run the scorer only when this file is executed as a script.
if __name__ == "__main__":
    main()
|