FunASR/funasr/utils/timestamp_tools.py
2023-07-20 18:57:16 +08:00

325 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import codecs
import logging
import argparse
import numpy as np
import edit_distance
from itertools import zip_longest
def ts_prediction_lfr6_standard(us_alphas,
                                us_peaks,
                                char_list,
                                vad_offset=0.0,
                                force_time_shift=-1.5,
                                sil_in_str=True
                                ):
    """Derive per-token timestamps from upsampled CIF peaks.

    Every frame whose peak value exceeds ``1.0 - 1e-4`` is a firing position.
    The frames between two consecutive firing positions are treated as the
    duration of the earlier token; over-long tokens and the head/tail of the
    utterance are padded with ``'<sil>'`` segments.

    Args:
        us_alphas: Upsampled alphas. Only ``.shape`` is inspected, to detect a
            leading batch dimension (batch size 1 only).
        us_peaks: Upsampled peaks, batched like ``us_alphas``. Passed to
            ``torch.where``, so it is expected to be a torch tensor.
        char_list: Tokens to align. A trailing ``'</s>'`` is dropped.
        vad_offset: When non-zero, ``vad_offset / 1000.0`` is added to every
            timestamp (presumably a VAD segment start in milliseconds).
        force_time_shift: Constant shift, in frames, added to every firing
            position.
        sil_in_str: Whether ``'<sil>'`` segments appear in the returned string.

    Returns:
        A ``(res_txt, res)`` tuple. ``res_txt`` is a concatenation of
        ``"token start end;"`` entries; ``res`` is a list of
        ``[start * 1000, end * 1000]`` integer pairs for non-silence tokens.
        ``("", [])`` is returned when ``char_list`` is empty or no peak fires.
    """
    if not len(char_list):
        return "", []
    START_END_THRESHOLD = 5
    MAX_TOKEN_DURATION = 12
    TIME_RATE = 10.0 * 6 / 1000 / 3  # 3 times upsampled
    # support inference batch_size=1 only
    peaks = us_peaks[0] if len(us_alphas.shape) == 2 else us_peaks
    num_frames = peaks.shape[0]
    if char_list[-1] == '</s>':
        char_list = char_list[:-1]
    timestamp_list = []
    new_char_list = []
    # for bicif model trained with large data, cif2 actually fires when a character starts
    # so treat the frames between two peaks as the duration of the former token
    fire_place = torch.where(peaks > 1.0 - 1e-4)[0].cpu().numpy() + force_time_shift  # total offset
    num_peak = len(fire_place)
    # number of peaks is supposed to be number of tokens + 1
    expected_peaks = len(char_list) + 1
    if num_peak != expected_peaks:
        logging.warning("length mismatch, result might be incorrect.")
        logging.warning("num_peaks: {}, num_chars+1: {}, which is supposed to be same.".format(num_peak, len(char_list)+1))
        if num_peak > expected_peaks:
            # Keep exactly tokens + 1 peaks so that every token still gets a
            # segment (the previous ``len(char_list) - 1`` dropped two tokens).
            fire_place = fire_place[:expected_peaks]
        else:
            # Only ``num_peak - 1`` intervals exist, so only that many tokens
            # can be timed.
            char_list = char_list[:max(num_peak - 1, 0)]
    if len(fire_place) == 0:
        # Nothing fired: there is no position to anchor any timestamp on.
        return "", []
    # begin silence
    if fire_place[0] > START_END_THRESHOLD:
        timestamp_list.append([0.0, fire_place[0]*TIME_RATE])
        new_char_list.append('<sil>')
    # tokens timestamp
    for i in range(len(fire_place) - 1):
        start, end = fire_place[i], fire_place[i + 1]
        new_char_list.append(char_list[i])
        if MAX_TOKEN_DURATION < 0 or end - start <= MAX_TOKEN_DURATION:
            timestamp_list.append([start*TIME_RATE, end*TIME_RATE])
        else:
            # cut the duration to token and sil of the 0-weight frames last long
            _split = start + MAX_TOKEN_DURATION
            timestamp_list.append([start*TIME_RATE, _split*TIME_RATE])
            timestamp_list.append([_split*TIME_RATE, end*TIME_RATE])
            new_char_list.append('<sil>')
    # tail token and end silence
    if num_frames - fire_place[-1] > START_END_THRESHOLD:
        _end = (num_frames + fire_place[-1]) * 0.5
        if timestamp_list:
            # Stretch the last segment up to the midpoint before the silence.
            timestamp_list[-1][1] = _end*TIME_RATE
        timestamp_list.append([_end*TIME_RATE, num_frames*TIME_RATE])
        new_char_list.append("<sil>")
    elif timestamp_list:
        timestamp_list[-1][1] = num_frames*TIME_RATE
    if vad_offset:  # add offset time in model with vad
        offset = vad_offset / 1000.0
        for segment in timestamp_list:
            segment[0] = segment[0] + offset
            segment[1] = segment[1] + offset
    # The +0.0005 followed by a 5-character cut keeps three decimals for
    # values below 10.
    res_txt = "".join(
        "{} {} {};".format(char, str(timestamp[0]+0.0005)[:5], str(timestamp[1]+0.0005)[:5])
        for char, timestamp in zip(new_char_list, timestamp_list)
        if sil_in_str or char != '<sil>'
    )
    res = [
        [int(timestamp[0] * 1000), int(timestamp[1] * 1000)]
        for char, timestamp in zip(new_char_list, timestamp_list)
        if char != '<sil>'
    ]
    return res_txt, res
def time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed):
    """Group token timestamps into sentences using predicted punctuation ids.

    Args:
        punc_id_list: One punctuation id per token. Ids greater than 1 close
            the current sentence and append ``punc_list[id - 2]``. ``None`` or
            an empty list means "no punctuation available".
        time_stamp_postprocessed: List of ``[start, end]`` pairs, one per token.
        text_postprocessed: Whitespace-separated token string.

    Returns:
        A list of dicts with keys ``text``, ``start``, ``end``, ``text_seg``
        and ``ts_list``. Without punctuation a single dict covering the whole
        input is returned, whose ``text``/``text_seg`` are token lists. With
        punctuation, ``text`` and ``text_seg`` are strings and tokens after
        the last sentence-closing punctuation are not emitted. An empty list
        is returned when text or timestamps are missing or empty.
    """
    # Punctuation for ids 2..5: fullwidth comma, ideographic full stop,
    # fullwidth question mark, ideographic comma. Written as escapes so the
    # characters cannot be lost or confused by tooling.
    # NOTE(review): these literals were empty strings in the reviewed copy
    # (which made the append below a no-op); confirm the set and order against
    # the punctuation model's label list.
    punc_list = ['\uff0c', '\u3002', '\uff1f', '\u3001']
    res = []
    if text_postprocessed is None or time_stamp_postprocessed is None:
        return res
    if len(time_stamp_postprocessed) == 0 or len(text_postprocessed) == 0:
        return res

    if punc_id_list is None or len(punc_id_list) == 0:
        res.append({
            'text': text_postprocessed.split(),
            "start": time_stamp_postprocessed[0][0],
            "end": time_stamp_postprocessed[-1][1],
            'text_seg': text_postprocessed.split(),
            "ts_list": time_stamp_postprocessed,
        })
        return res
    if len(punc_id_list) != len(time_stamp_postprocessed):
        print(" warning length mistach!!!!!!")

    def _is_ascii_letter(ch):
        return 'a' <= ch <= 'z' or 'A' <= ch <= 'Z'

    sentence_text = ""
    sentence_text_seg = ""
    ts_list = []
    sentence_start = time_stamp_postprocessed[0][0]
    sentence_end = time_stamp_postprocessed[0][1]
    texts = text_postprocessed.split()
    # zip_longest pads the shorter inputs with None, handled field by field.
    for punc_id, time_stamp, text in zip_longest(punc_id_list, time_stamp_postprocessed, texts, fillvalue=None):
        if text is not None:
            # Put a space before a token when it, or the text so far, ends on
            # an ASCII letter; other tokens are joined directly.
            if _is_ascii_letter(text[0]):
                sentence_text += ' ' + text
            elif len(sentence_text) and _is_ascii_letter(sentence_text[-1]):
                sentence_text += ' ' + text
            else:
                sentence_text += text
            sentence_text_seg += text + ' '
        ts_list.append(time_stamp)

        # A missing punctuation id is treated as id 1 (no sentence break).
        punc_id = int(punc_id) if punc_id is not None else 1
        sentence_end = time_stamp[1] if time_stamp is not None else sentence_end

        if punc_id > 1:
            sentence_text += punc_list[punc_id - 2]
            res.append({
                'text': sentence_text,
                "start": sentence_start,
                "end": sentence_end,
                "text_seg": sentence_text_seg,
                "ts_list": ts_list
            })
            sentence_text = ''
            sentence_text_seg = ''
            ts_list = []
            # The next sentence starts where this one ended.
            sentence_start = sentence_end
    return res
class AverageShiftCalculator():
    """Compute the average timestamp shift between two timestamp files.

    Each file holds one utterance per line in the form
    ``uttid tok start end;tok start end;...``. Calling the instance with two
    file paths logs the average absolute start/end deviation over tokens the
    two files share, and the utterance with the largest per-token deviation.
    """

    def __init__(self):
        logging.warning("Calculating average shift.")

    def __call__(self, file1, file2):
        """Read both files, compare the shared utterances and log the result."""
        uttid_list1, ts_dict1 = self.read_timestamps(file1)
        uttid_list2, ts_dict2 = self.read_timestamps(file2)
        uttid_intersection = self._intersection(uttid_list1, uttid_list2)
        res = self.as_cal(uttid_intersection, ts_dict1, ts_dict2)
        logging.warning("Average shift of {} and {}: {}.".format(file1, file2, str(res)[:8]))
        logging.warning("Following timestamp pair differs most: {}, detail:{}".format(self.max_shift, self.max_shift_uttid))

    def _intersection(self, list1, list2):
        """Return the utterance ids present in both lists (unordered)."""
        set1 = set(list1)
        set2 = set(list2)
        if set1 == set2:
            logging.warning("Uttid same checked.")
            return set1
        itsc = list(set1 & set2)
        logging.warning("Uttid differs: file1 {}, file2 {}, lines same {}.".format(len(list1), len(list2), len(itsc)))
        return itsc

    def read_timestamps(self, file):
        """Read a timestamp file in standard format.

        Blank lines are skipped. The file is decoded as UTF-8.

        Returns:
            ``(uttid_list, ts_dict)`` where ``ts_dict[uttid]`` is a tuple of
            the comma-joined tokens and a list of ``(start, end)`` floats.
        """
        uttid_list = []
        ts_dict = {}
        with open(file, 'r', encoding='utf-8') as fin:
            for line in fin:
                fields = line.split()
                if not fields:
                    continue
                uttid = fields[0]
                uttid_list.append(uttid)
                body = " ".join(fields[1:])
                chars = []
                ts_list = []
                for pd in body.split(';'):
                    if not len(pd):
                        continue
                    char, start, end = pd.lstrip(" ").split(' ')
                    chars.append(char)
                    ts_list.append((float(start), float(end)))
                ts_dict[uttid] = (",".join(chars), ts_list)
        logging.warning("File {} read done.".format(file))
        return uttid_list, ts_dict

    def _shift(self, filtered_timestamp_list1, filtered_timestamp_list2):
        """Sum absolute start and end differences over paired timestamps.

        Returns:
            ``(shift_time, num_tokens)`` where ``num_tokens`` is the length of
            the first list.
        """
        shift_time = 0
        for fts1, fts2 in zip(filtered_timestamp_list1, filtered_timestamp_list2):
            shift_time += abs(fts1[0] - fts2[0]) + abs(fts1[1] - fts2[1])
        num_tokens = len(filtered_timestamp_list1)
        return shift_time, num_tokens

    def as_cal(self, uttid_list, ts_dict1, ts_dict2):
        """Calculate the average shift between timestamp1 and timestamp2.

        When characters differ, an edit-distance alignment is used and only
        the error between the same characters is accumulated. Sets
        ``self.max_shift`` / ``self.max_shift_uttid`` to the worst utterance.

        Returns:
            Accumulated shift divided by the number of compared tokens, or
            ``0.0`` (with a warning) when no token could be compared.
        """
        self._accumlated_shift = 0
        self._accumlated_tokens = 0
        self.max_shift = 0
        self.max_shift_uttid = None
        for uttid in uttid_list:
            (t1, ts1) = ts_dict1[uttid]
            (t2, ts2) = ts_dict2[uttid]
            tokens1 = t1.split(',')
            tokens2 = t2.split(',')
            # One 0/1 flag per token of each side; 1 marks a token matched
            # ("equal") in the alignment of tokens1 against tokens2.
            keep1, keep2 = [], []
            for opcode in edit_distance.SequenceMatcher(tokens1, tokens2).get_opcodes():
                tag = opcode[0]
                if tag == "equal":
                    keep1.append(1)
                    keep2.append(1)
                    continue
                if tag == "replace" or tag == "insert":
                    keep2.append(0)  # consumes a token of tokens2
                if tag == "replace" or tag == "delete":
                    keep1.append(0)  # consumes a token of tokens1
            fts1 = [ts for flag, ts, _ in zip(keep1, ts1, tokens1) if flag]
            fts2 = [ts for flag, ts, _ in zip(keep2, ts2, tokens2) if flag]
            if len(fts1) == len(fts2):
                shift_time, num_tokens = self._shift(fts1, fts2)
                self._accumlated_shift += shift_time
                self._accumlated_tokens += num_tokens
                # Guard against utterances without a single matched token.
                if num_tokens and shift_time / num_tokens > self.max_shift:
                    self.max_shift = shift_time / num_tokens
                    self.max_shift_uttid = uttid
            else:
                logging.warning("length mismatch")
        if not self._accumlated_tokens:
            logging.warning("No comparable tokens found, average shift is reported as 0.0.")
            return 0.0
        return self._accumlated_shift / self._accumlated_tokens
def convert_external_alphas(alphas_file, text_file, output_file):
    """Convert externally dumped alphas into timestamp strings.

    ``alphas_file`` and ``text_file`` are read in lockstep; each line starts
    with an utterance id followed by the alpha values / the tokens. For every
    utterance the chunk padding is removed from the alphas, they are rescaled
    so that ``len(tokens) + 1`` peaks fire, and the resulting timestamp string
    (without ``<sil>`` entries) is written to ``output_file`` as
    ``"uttid result"``. All files are handled as UTF-8.

    Line pairs that are both blank are skipped. Utterances whose alphas are
    empty after padding removal, or whose total weight is below 1, cannot be
    rescaled and are skipped with a warning.

    Raises:
        AssertionError: If the utterance ids of a line pair differ.
    """
    from funasr.models.predictor.cif import cif_wo_hidden
    with open(alphas_file, 'r', encoding='utf-8') as f1, \
            open(text_file, 'r', encoding='utf-8') as f2, \
            open(output_file, 'w', encoding='utf-8') as f3:
        for line1, line2 in zip(f1, f2):
            fields1 = line1.split()
            fields2 = line2.split()
            if not fields1 and not fields2:
                continue
            assert fields1[:1] == fields2[:1], \
                "uttid mismatch between alphas and text: {} vs {}".format(fields1[:1], fields2[:1])
            uttid = fields1[0]
            alphas = [float(i) for i in fields1[1:]]
            new_alphas = np.array(remove_chunk_padding(alphas))
            if new_alphas.size == 0:
                logging.warning("Empty alphas for {}, skipped.".format(uttid))
                continue
            # Nudge the last weight by 1e-4 (presumably so the accumulated
            # weight reaches the firing threshold for the final peak).
            new_alphas[-1] += 1e-4
            # split() already yields a token list, no further splitting needed.
            text = fields2[1:]
            fired = int(new_alphas.sum())
            if fired == 0:
                logging.warning("Alphas of {} sum to less than 1, skipped.".format(uttid))
                continue
            if len(text) + 1 != fired:
                # force resize
                new_alphas *= (len(text) + 1) / fired
            peaks = cif_wo_hidden(torch.Tensor(new_alphas).unsqueeze(0), 1.0-1e-4)
            res_str, _ = ts_prediction_lfr6_standard(new_alphas, peaks[0], list(text),
                                                     force_time_shift=-7.0,
                                                     sil_in_str=False)
            f3.write("{} {}\n".format(uttid, res_str))
def remove_chunk_padding(alphas):
    """Remove the padding part in alphas if using chunk paraformer for GPU.

    The first ``START_ZERO`` values are dropped. After that the input is read
    as blocks of ``REAL_FRAMES`` values separated by ``MID_ZERO`` padding
    values. The first block is always kept (even when shorter than
    ``REAL_FRAMES``); a later block is kept only when at least ``REAL_FRAMES``
    values remain at its start, so a trailing partial block is discarded.

    Args:
        alphas: Sequence of alpha values (any sliceable sequence).

    Returns:
        A new list with the kept values in their original order.
    """
    START_ZERO = 45
    MID_ZERO = 75
    REAL_FRAMES = 360  # for chunk based encoder 10-120-10 and fsmn padding 5
    stride = REAL_FRAMES + MID_ZERO
    body = alphas[START_ZERO:]  # remove the padding at beginning
    total = len(body)
    new_alphas = list(body[:REAL_FRAMES])
    # Walk by offset instead of re-slicing and re-concatenating the lists on
    # every block, which copied the whole data once per block.
    offset = stride
    while total - offset >= REAL_FRAMES:
        new_alphas.extend(body[offset:offset + REAL_FRAMES])
        offset += stride
    return new_alphas
SUPPORTED_MODES = ['cal_aas', 'read_ext_alphas']


def main(args):
    """Run the tool selected by ``args.mode``.

    ``'cal_aas'`` compares the timestamp files ``args.input`` and
    ``args.input2`` with ``AverageShiftCalculator``; ``'read_ext_alphas'``
    converts the alphas in ``args.input`` and the text in ``args.input2`` into
    ``args.output``. Any other mode is reported through ``logging.error``.
    """
    selected = args.mode
    if selected == 'cal_aas':
        calculator = AverageShiftCalculator()
        calculator(args.input, args.input2)
        return
    if selected == 'read_ext_alphas':
        convert_external_alphas(args.input, args.input2, args.output)
        return
    logging.error("Mode {} not in SUPPORTED_MODES: {}.".format(selected, SUPPORTED_MODES))
if __name__ == '__main__':
    # Command-line entry point: parse the options and hand them to main().
    cli = argparse.ArgumentParser(description='timestamp tools')
    cli.add_argument(
        '--mode',
        type=str,
        default=None,
        choices=SUPPORTED_MODES,
        help='timestamp related toolbox',
    )
    # Plain string options, registered in the order they show up in --help.
    for flag, description in (
        ('--input', 'input file path'),
        ('--output', 'output file name'),
        ('--input2', 'input2 file path'),
    ):
        cli.add_argument(flag, default=None, type=str, help=description)
    cli.add_argument(
        '--kaldi-ts-type',
        type=str,
        default='v2',
        choices=['v0', 'v1', 'v2'],
        help='kaldi timestamp to write',
    )
    main(cli.parse_args())