mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
commit
6f18b5619a
@ -1,6 +1,10 @@
|
||||
import torch
|
||||
import copy
|
||||
import codecs
|
||||
import logging
|
||||
import edit_distance
|
||||
import argparse
|
||||
import pdb
|
||||
import numpy as np
|
||||
from typing import Any, List, Tuple, Union
|
||||
|
||||
@ -9,7 +13,8 @@ def ts_prediction_lfr6_standard(us_alphas,
|
||||
us_peaks,
|
||||
char_list,
|
||||
vad_offset=0.0,
|
||||
force_time_shift=-1.5
|
||||
force_time_shift=-1.5,
|
||||
sil_in_str=True
|
||||
):
|
||||
if not len(char_list):
|
||||
return []
|
||||
@ -62,6 +67,8 @@ def ts_prediction_lfr6_standard(us_alphas,
|
||||
timestamp_list[i][1] = timestamp_list[i][1] + vad_offset / 1000.0
|
||||
res_txt = ""
|
||||
for char, timestamp in zip(new_char_list, timestamp_list):
|
||||
#if char != '<sil>':
|
||||
if not sil_in_str and char == '<sil>': continue
|
||||
res_txt += "{} {} {};".format(char, str(timestamp[0]+0.0005)[:5], str(timestamp[1]+0.0005)[:5])
|
||||
res = []
|
||||
for char, timestamp in zip(new_char_list, timestamp_list):
|
||||
@ -121,4 +128,181 @@ def time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocess
|
||||
return res
|
||||
|
||||
|
||||
class AverageShiftCalculator():
|
||||
def __init__(self):
|
||||
logging.warning("Calculating average shift.")
|
||||
def __call__(self, file1, file2):
|
||||
uttid_list1, ts_dict1 = self.read_timestamps(file1)
|
||||
uttid_list2, ts_dict2 = self.read_timestamps(file2)
|
||||
uttid_intersection = self._intersection(uttid_list1, uttid_list2)
|
||||
res = self.as_cal(uttid_intersection, ts_dict1, ts_dict2)
|
||||
logging.warning("Average shift of {} and {}: {}.".format(file1, file2, str(res)[:8]))
|
||||
logging.warning("Following timestamp pair differs most: {}, detail:{}".format(self.max_shift, self.max_shift_uttid))
|
||||
|
||||
def _intersection(self, list1, list2):
|
||||
set1 = set(list1)
|
||||
set2 = set(list2)
|
||||
if set1 == set2:
|
||||
logging.warning("Uttid same checked.")
|
||||
return set1
|
||||
itsc = list(set1 & set2)
|
||||
logging.warning("Uttid differs: file1 {}, file2 {}, lines same {}.".format(len(list1), len(list2), len(itsc)))
|
||||
return itsc
|
||||
|
||||
def read_timestamps(self, file):
|
||||
# read timestamps file in standard format
|
||||
uttid_list = []
|
||||
ts_dict = {}
|
||||
with codecs.open(file, 'r') as fin:
|
||||
for line in fin.readlines():
|
||||
text = ''
|
||||
ts_list = []
|
||||
line = line.rstrip()
|
||||
uttid = line.split()[0]
|
||||
uttid_list.append(uttid)
|
||||
body = " ".join(line.split()[1:])
|
||||
for pd in body.split(';'):
|
||||
if not len(pd): continue
|
||||
# pdb.set_trace()
|
||||
char, start, end = pd.lstrip(" ").split(' ')
|
||||
text += char + ','
|
||||
ts_list.append((float(start), float(end)))
|
||||
# ts_lists.append(ts_list)
|
||||
ts_dict[uttid] = (text[:-1], ts_list)
|
||||
logging.warning("File {} read done.".format(file))
|
||||
return uttid_list, ts_dict
|
||||
|
||||
def _shift(self, filtered_timestamp_list1, filtered_timestamp_list2):
|
||||
shift_time = 0
|
||||
for fts1, fts2 in zip(filtered_timestamp_list1, filtered_timestamp_list2):
|
||||
shift_time += abs(fts1[0] - fts2[0]) + abs(fts1[1] - fts2[1])
|
||||
num_tokens = len(filtered_timestamp_list1)
|
||||
return shift_time, num_tokens
|
||||
|
||||
def as_cal(self, uttid_list, ts_dict1, ts_dict2):
|
||||
# calculate average shift between timestamp1 and timestamp2
|
||||
# when characters differ, use edit distance alignment
|
||||
# and calculate the error between the same characters
|
||||
self._accumlated_shift = 0
|
||||
self._accumlated_tokens = 0
|
||||
self.max_shift = 0
|
||||
self.max_shift_uttid = None
|
||||
for uttid in uttid_list:
|
||||
(t1, ts1) = ts_dict1[uttid]
|
||||
(t2, ts2) = ts_dict2[uttid]
|
||||
_align, _align2, _align3 = [], [], []
|
||||
fts1, fts2 = [], []
|
||||
_t1, _t2 = [], []
|
||||
sm = edit_distance.SequenceMatcher(t1.split(','), t2.split(','))
|
||||
s = sm.get_opcodes()
|
||||
for j in range(len(s)):
|
||||
if s[j][0] == "replace" or s[j][0] == "insert":
|
||||
_align.append(0)
|
||||
if s[j][0] == "replace" or s[j][0] == "delete":
|
||||
_align3.append(0)
|
||||
elif s[j][0] == "equal":
|
||||
_align.append(1)
|
||||
_align3.append(1)
|
||||
else:
|
||||
continue
|
||||
# use s to index t2
|
||||
for a, ts , t in zip(_align, ts2, t2.split(',')):
|
||||
if a:
|
||||
fts2.append(ts)
|
||||
_t2.append(t)
|
||||
sm2 = edit_distance.SequenceMatcher(t2.split(','), t1.split(','))
|
||||
s = sm2.get_opcodes()
|
||||
for j in range(len(s)):
|
||||
if s[j][0] == "replace" or s[j][0] == "insert":
|
||||
_align2.append(0)
|
||||
elif s[j][0] == "equal":
|
||||
_align2.append(1)
|
||||
else:
|
||||
continue
|
||||
# use s2 tp index t1
|
||||
for a, ts, t in zip(_align3, ts1, t1.split(',')):
|
||||
if a:
|
||||
fts1.append(ts)
|
||||
_t1.append(t)
|
||||
if len(fts1) == len(fts2):
|
||||
shift_time, num_tokens = self._shift(fts1, fts2)
|
||||
self._accumlated_shift += shift_time
|
||||
self._accumlated_tokens += num_tokens
|
||||
if shift_time/num_tokens > self.max_shift:
|
||||
self.max_shift = shift_time/num_tokens
|
||||
self.max_shift_uttid = uttid
|
||||
else:
|
||||
logging.warning("length mismatch")
|
||||
return self._accumlated_shift / self._accumlated_tokens
|
||||
|
||||
|
||||
def convert_external_alphas(alphas_file, text_file, output_file):
|
||||
from funasr.models.predictor.cif import cif_wo_hidden
|
||||
with open(alphas_file, 'r') as f1, open(text_file, 'r') as f2, open(output_file, 'w') as f3:
|
||||
for line1, line2 in zip(f1.readlines(), f2.readlines()):
|
||||
line1 = line1.rstrip()
|
||||
line2 = line2.rstrip()
|
||||
assert line1.split()[0] == line2.split()[0]
|
||||
uttid = line1.split()[0]
|
||||
alphas = [float(i) for i in line1.split()[1:]]
|
||||
new_alphas = np.array(remove_chunk_padding(alphas))
|
||||
new_alphas[-1] += 1e-4
|
||||
text = line2.split()[1:]
|
||||
if len(text) + 1 != int(new_alphas.sum()):
|
||||
# force resize
|
||||
new_alphas *= (len(text) + 1) / int(new_alphas.sum())
|
||||
peaks = cif_wo_hidden(torch.Tensor(new_alphas).unsqueeze(0), 1.0-1e-4)
|
||||
if " " in text:
|
||||
text = text.split()
|
||||
else:
|
||||
text = [i for i in text]
|
||||
res_str, _ = ts_prediction_lfr6_standard(new_alphas, peaks[0], text,
|
||||
force_time_shift=-7.0,
|
||||
sil_in_str=False)
|
||||
f3.write("{} {}\n".format(uttid, res_str))
|
||||
|
||||
|
||||
def remove_chunk_padding(alphas):
|
||||
# remove the padding part in alphas if using chunk paraformer for GPU
|
||||
START_ZERO = 45
|
||||
MID_ZERO = 75
|
||||
REAL_FRAMES = 360 # for chunk based encoder 10-120-10 and fsmn padding 5
|
||||
alphas = alphas[START_ZERO:] # remove the padding at beginning
|
||||
new_alphas = []
|
||||
while True:
|
||||
new_alphas = new_alphas + alphas[:REAL_FRAMES]
|
||||
alphas = alphas[REAL_FRAMES+MID_ZERO:]
|
||||
if len(alphas) < REAL_FRAMES: break
|
||||
return new_alphas
|
||||
|
||||
SUPPORTED_MODES = ['cal_aas', 'read_ext_alphas']
|
||||
|
||||
|
||||
def main(args):
|
||||
if args.mode == 'cal_aas':
|
||||
asc = AverageShiftCalculator()
|
||||
asc(args.input, args.input2)
|
||||
elif args.mode == 'read_ext_alphas':
|
||||
convert_external_alphas(args.input, args.input2, args.output)
|
||||
else:
|
||||
logging.error("Mode {} not in SUPPORTED_MODES: {}.".format(args.mode, SUPPORTED_MODES))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='timestamp tools')
|
||||
parser.add_argument('--mode',
|
||||
default=None,
|
||||
type=str,
|
||||
choices=SUPPORTED_MODES,
|
||||
help='timestamp related toolbox')
|
||||
parser.add_argument('--input', default=None, type=str, help='input file path')
|
||||
parser.add_argument('--output', default=None, type=str, help='output file name')
|
||||
parser.add_argument('--input2', default=None, type=str, help='input2 file path')
|
||||
parser.add_argument('--kaldi-ts-type',
|
||||
default='v2',
|
||||
type=str,
|
||||
choices=['v0', 'v1', 'v2'],
|
||||
help='kaldi timestamp to write')
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user