from itertools import zip_longest import torch import copy import codecs import logging import edit_distance import argparse import pdb import numpy as np from typing import Any, List, Tuple, Union def ts_prediction_lfr6_standard(us_alphas, us_peaks, char_list, vad_offset=0.0, force_time_shift=-1.5, sil_in_str=True ): if not len(char_list): return "", [] START_END_THRESHOLD = 5 MAX_TOKEN_DURATION = 12 TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled if len(us_alphas.shape) == 2: _, peaks = us_alphas[0], us_peaks[0] # support inference batch_size=1 only else: _, peaks = us_alphas, us_peaks num_frames = peaks.shape[0] if char_list[-1] == '': char_list = char_list[:-1] timestamp_list = [] new_char_list = [] # for bicif model trained with large data, cif2 actually fires when a character starts # so treat the frames between two peaks as the duration of the former token fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset num_peak = len(fire_place) assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1 # begin silence if fire_place[0] > START_END_THRESHOLD: # char_list.insert(0, '') timestamp_list.append([0.0, fire_place[0]*TIME_RATE]) new_char_list.append('') # tokens timestamp for i in range(len(fire_place)-1): new_char_list.append(char_list[i]) if MAX_TOKEN_DURATION < 0 or fire_place[i+1] - fire_place[i] <= MAX_TOKEN_DURATION: timestamp_list.append([fire_place[i]*TIME_RATE, fire_place[i+1]*TIME_RATE]) else: # cut the duration to token and sil of the 0-weight frames last long _split = fire_place[i] + MAX_TOKEN_DURATION timestamp_list.append([fire_place[i]*TIME_RATE, _split*TIME_RATE]) timestamp_list.append([_split*TIME_RATE, fire_place[i+1]*TIME_RATE]) new_char_list.append('') # tail token and end silence # new_char_list.append(char_list[-1]) if num_frames - fire_place[-1] > START_END_THRESHOLD: _end = (num_frames + fire_place[-1]) * 0.5 # _end = fire_place[-1] timestamp_list[-1][1] = _end*TIME_RATE timestamp_list.append([_end*TIME_RATE, num_frames*TIME_RATE]) new_char_list.append("") else: timestamp_list[-1][1] = num_frames*TIME_RATE if vad_offset: # add offset time in model with vad for i in range(len(timestamp_list)): timestamp_list[i][0] = timestamp_list[i][0] + vad_offset / 1000.0 timestamp_list[i][1] = timestamp_list[i][1] + vad_offset / 1000.0 res_txt = "" for char, timestamp in zip(new_char_list, timestamp_list): #if char != '': if not sil_in_str and char == '': continue res_txt += "{} {} {};".format(char, str(timestamp[0]+0.0005)[:5], str(timestamp[1]+0.0005)[:5]) res = [] for char, timestamp in zip(new_char_list, timestamp_list): if char != '': res.append([int(timestamp[0] * 1000), int(timestamp[1] * 1000)]) return res_txt, res def time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed): res = [] if text_postprocessed is None: return res if time_stamp_postprocessed is None: return res if len(time_stamp_postprocessed) == 0: return res if len(text_postprocessed) == 0: return res if punc_id_list is None or len(punc_id_list) == 0: res.append({ 'text': text_postprocessed.split(), "start": time_stamp_postprocessed[0][0], "end": time_stamp_postprocessed[-1][1] }) return res if len(punc_id_list) != len(time_stamp_postprocessed): print(" warning length mistach!!!!!!") sentence_text = '' sentence_start = time_stamp_postprocessed[0][0] sentence_end = time_stamp_postprocessed[0][1] texts = text_postprocessed.split() punc_stamp_text_list = list(zip_longest(punc_id_list, time_stamp_postprocessed, texts, fillvalue=None)) for punc_stamp_text in punc_stamp_text_list: punc_id, time_stamp, text = punc_stamp_text sentence_text += text if text is not None else '' punc_id = int(punc_id) if punc_id is not None else 1 sentence_end = time_stamp[1] if time_stamp is not None else sentence_end if punc_id == 2: sentence_text += ',' res.append({ 'text': sentence_text, "start": sentence_start, "end": sentence_end }) sentence_text = '' sentence_start = sentence_end elif punc_id == 3: sentence_text += '.' res.append({ 'text': sentence_text, "start": sentence_start, "end": sentence_end }) sentence_text = '' sentence_start = sentence_end elif punc_id == 4: sentence_text += '?' res.append({ 'text': sentence_text, "start": sentence_start, "end": sentence_end }) sentence_text = '' sentence_start = sentence_end return res class AverageShiftCalculator(): def __init__(self): logging.warning("Calculating average shift.") def __call__(self, file1, file2): uttid_list1, ts_dict1 = self.read_timestamps(file1) uttid_list2, ts_dict2 = self.read_timestamps(file2) uttid_intersection = self._intersection(uttid_list1, uttid_list2) res = self.as_cal(uttid_intersection, ts_dict1, ts_dict2) logging.warning("Average shift of {} and {}: {}.".format(file1, file2, str(res)[:8])) logging.warning("Following timestamp pair differs most: {}, detail:{}".format(self.max_shift, self.max_shift_uttid)) def _intersection(self, list1, list2): set1 = set(list1) set2 = set(list2) if set1 == set2: logging.warning("Uttid same checked.") return set1 itsc = list(set1 & set2) logging.warning("Uttid differs: file1 {}, file2 {}, lines same {}.".format(len(list1), len(list2), len(itsc))) return itsc def read_timestamps(self, file): # read timestamps file in standard format uttid_list = [] ts_dict = {} with codecs.open(file, 'r') as fin: for line in fin.readlines(): text = '' ts_list = [] line = line.rstrip() uttid = line.split()[0] uttid_list.append(uttid) body = " ".join(line.split()[1:]) for pd in body.split(';'): if not len(pd): continue # pdb.set_trace() char, start, end = pd.lstrip(" ").split(' ') text += char + ',' ts_list.append((float(start), float(end))) # ts_lists.append(ts_list) ts_dict[uttid] = (text[:-1], ts_list) logging.warning("File {} read done.".format(file)) return uttid_list, ts_dict def _shift(self, filtered_timestamp_list1, filtered_timestamp_list2): shift_time = 0 for fts1, fts2 in zip(filtered_timestamp_list1, filtered_timestamp_list2): shift_time += abs(fts1[0] - fts2[0]) + abs(fts1[1] - fts2[1]) num_tokens = len(filtered_timestamp_list1) return shift_time, num_tokens def as_cal(self, uttid_list, ts_dict1, ts_dict2): # calculate average shift between timestamp1 and timestamp2 # when characters differ, use edit distance alignment # and calculate the error between the same characters self._accumlated_shift = 0 self._accumlated_tokens = 0 self.max_shift = 0 self.max_shift_uttid = None for uttid in uttid_list: (t1, ts1) = ts_dict1[uttid] (t2, ts2) = ts_dict2[uttid] _align, _align2, _align3 = [], [], [] fts1, fts2 = [], [] _t1, _t2 = [], [] sm = edit_distance.SequenceMatcher(t1.split(','), t2.split(',')) s = sm.get_opcodes() for j in range(len(s)): if s[j][0] == "replace" or s[j][0] == "insert": _align.append(0) if s[j][0] == "replace" or s[j][0] == "delete": _align3.append(0) elif s[j][0] == "equal": _align.append(1) _align3.append(1) else: continue # use s to index t2 for a, ts , t in zip(_align, ts2, t2.split(',')): if a: fts2.append(ts) _t2.append(t) sm2 = edit_distance.SequenceMatcher(t2.split(','), t1.split(',')) s = sm2.get_opcodes() for j in range(len(s)): if s[j][0] == "replace" or s[j][0] == "insert": _align2.append(0) elif s[j][0] == "equal": _align2.append(1) else: continue # use s2 tp index t1 for a, ts, t in zip(_align3, ts1, t1.split(',')): if a: fts1.append(ts) _t1.append(t) if len(fts1) == len(fts2): shift_time, num_tokens = self._shift(fts1, fts2) self._accumlated_shift += shift_time self._accumlated_tokens += num_tokens if shift_time/num_tokens > self.max_shift: self.max_shift = shift_time/num_tokens self.max_shift_uttid = uttid else: logging.warning("length mismatch") return self._accumlated_shift / self._accumlated_tokens def convert_external_alphas(alphas_file, text_file, output_file): from funasr.models.predictor.cif import cif_wo_hidden with open(alphas_file, 'r') as f1, open(text_file, 'r') as f2, open(output_file, 'w') as f3: for line1, line2 in zip(f1.readlines(), f2.readlines()): line1 = line1.rstrip() line2 = line2.rstrip() assert line1.split()[0] == line2.split()[0] uttid = line1.split()[0] alphas = [float(i) for i in line1.split()[1:]] new_alphas = np.array(remove_chunk_padding(alphas)) new_alphas[-1] += 1e-4 text = line2.split()[1:] if len(text) + 1 != int(new_alphas.sum()): # force resize new_alphas *= (len(text) + 1) / int(new_alphas.sum()) peaks = cif_wo_hidden(torch.Tensor(new_alphas).unsqueeze(0), 1.0-1e-4) if " " in text: text = text.split() else: text = [i for i in text] res_str, _ = ts_prediction_lfr6_standard(new_alphas, peaks[0], text, force_time_shift=-7.0, sil_in_str=False) f3.write("{} {}\n".format(uttid, res_str)) def remove_chunk_padding(alphas): # remove the padding part in alphas if using chunk paraformer for GPU START_ZERO = 45 MID_ZERO = 75 REAL_FRAMES = 360 # for chunk based encoder 10-120-10 and fsmn padding 5 alphas = alphas[START_ZERO:] # remove the padding at beginning new_alphas = [] while True: new_alphas = new_alphas + alphas[:REAL_FRAMES] alphas = alphas[REAL_FRAMES+MID_ZERO:] if len(alphas) < REAL_FRAMES: break return new_alphas SUPPORTED_MODES = ['cal_aas', 'read_ext_alphas'] def main(args): if args.mode == 'cal_aas': asc = AverageShiftCalculator() asc(args.input, args.input2) elif args.mode == 'read_ext_alphas': convert_external_alphas(args.input, args.input2, args.output) else: logging.error("Mode {} not in SUPPORTED_MODES: {}.".format(args.mode, SUPPORTED_MODES)) if __name__ == '__main__': parser = argparse.ArgumentParser(description='timestamp tools') parser.add_argument('--mode', default=None, type=str, choices=SUPPORTED_MODES, help='timestamp related toolbox') parser.add_argument('--input', default=None, type=str, help='input file path') parser.add_argument('--output', default=None, type=str, help='output file name') parser.add_argument('--input2', default=None, type=str, help='input2 file path') parser.add_argument('--kaldi-ts-type', default='v2', type=str, choices=['v0', 'v1', 'v2'], help='kaldi timestamp to write') args = parser.parse_args() main(args)