diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py index 5787f1dbf..c194179ab 100644 --- a/funasr/utils/timestamp_tools.py +++ b/funasr/utils/timestamp_tools.py @@ -7,6 +7,24 @@ import edit_distance from itertools import zip_longest +def cif_wo_hidden(alphas, threshold): + batch_size, len_time = alphas.size() + # loop varss + integrate = torch.zeros([batch_size], device=alphas.device) + # intermediate vars along time + list_fires = [] + for t in range(len_time): + alpha = alphas[:, t] + integrate += alpha + list_fires.append(integrate) + fire_place = integrate >= threshold + integrate = torch.where(fire_place, + integrate - torch.ones([batch_size], device=alphas.device), + integrate) + fires = torch.stack(list_fires, 1) + return fires + + def ts_prediction_lfr6_standard(us_alphas, us_peaks, char_list, @@ -20,25 +38,23 @@ def ts_prediction_lfr6_standard(us_alphas, MAX_TOKEN_DURATION = 12 TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled if len(us_alphas.shape) == 2: - _, peaks = us_alphas[0], us_peaks[0] # support inference batch_size=1 only + alphas, peaks = us_alphas[0], us_peaks[0] # support inference batch_size=1 only else: - _, peaks = us_alphas, us_peaks - num_frames = peaks.shape[0] + alphas, peaks = us_alphas, us_peaks if char_list[-1] == '': char_list = char_list[:-1] + fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset + if len(fire_place) != len(char_list) + 1: + alphas /= (alphas.sum() / (len(char_list) + 1)) + alphas = alphas.unsqueeze(0) + peaks = cif_wo_hidden(alphas, threshold=1.0-1e-4)[0] + fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset + num_frames = peaks.shape[0] timestamp_list = [] new_char_list = [] # for bicif model trained with large data, cif2 actually fires when a character starts # so treat the frames between two peaks as the duration of the former token fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset - num_peak = len(fire_place) - if num_peak != len(char_list) + 1: - logging.warning("length mismatch, result might be incorrect.") - logging.warning("num_peaks: {}, num_chars+1: {}, which is supposed to be same.".format(num_peak, len(char_list)+1)) - if num_peak > len(char_list) + 1: - fire_place = fire_place[:len(char_list) - 1] - elif num_peak < len(char_list) + 1: - char_list = char_list[:num_peak + 1] # assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1 # begin silence if fire_place[0] > START_END_THRESHOLD: