Mirror of https://github.com/modelscope/FunASR, synced 2025-09-15 14:48:36 +08:00
fix bug for timestamp inference
This commit is contained in:
parent 6a31591716
commit e740ec08b7
@@ -7,6 +7,24 @@ import edit_distance
 from itertools import zip_longest
 
 
+def cif_wo_hidden(alphas, threshold):
+    batch_size, len_time = alphas.size()
+    # loop vars
+    integrate = torch.zeros([batch_size], device=alphas.device)
+    # intermediate vars along time
+    list_fires = []
+    for t in range(len_time):
+        alpha = alphas[:, t]
+        integrate += alpha
+        list_fires.append(integrate)
+        fire_place = integrate >= threshold
+        integrate = torch.where(fire_place,
+                                integrate - torch.ones([batch_size], device=alphas.device),
+                                integrate)
+    fires = torch.stack(list_fires, 1)
+    return fires
+
+
 def ts_prediction_lfr6_standard(us_alphas,
                                 us_peaks,
                                 char_list,
@@ -20,25 +38,23 @@ def ts_prediction_lfr6_standard(us_alphas,
     MAX_TOKEN_DURATION = 12
     TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled
     if len(us_alphas.shape) == 2:
-        _, peaks = us_alphas[0], us_peaks[0] # support inference batch_size=1 only
+        alphas, peaks = us_alphas[0], us_peaks[0] # support inference batch_size=1 only
     else:
-        _, peaks = us_alphas, us_peaks
-    num_frames = peaks.shape[0]
+        alphas, peaks = us_alphas, us_peaks
     if char_list[-1] == '</s>':
         char_list = char_list[:-1]
+    fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset
+    if len(fire_place) != len(char_list) + 1:
+        alphas /= (alphas.sum() / (len(char_list) + 1))
+        alphas = alphas.unsqueeze(0)
+        peaks = cif_wo_hidden(alphas, threshold=1.0-1e-4)[0]
+        fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset
+    num_frames = peaks.shape[0]
     timestamp_list = []
     new_char_list = []
     # for bicif model trained with large data, cif2 actually fires when a character starts
     # so treat the frames between two peaks as the duration of the former token
-    fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset
     num_peak = len(fire_place)
-    if num_peak != len(char_list) + 1:
-        logging.warning("length mismatch, result might be incorrect.")
-        logging.warning("num_peaks: {}, num_chars+1: {}, which is supposed to be same.".format(num_peak, len(char_list)+1))
-        if num_peak > len(char_list) + 1:
-            fire_place = fire_place[:len(char_list) - 1]
-        elif num_peak < len(char_list) + 1:
-            char_list = char_list[:num_peak + 1]
     # assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
     # begin silence
     if fire_place[0] > START_END_THRESHOLD:
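For context, a minimal sketch (not part of the commit) of the re-fire path this change introduces, assuming torch is installed and cif_wo_hidden is defined as in the hunk above. The alphas values and char_list are made up for illustration, and both passes use cif_wo_hidden only to keep the sketch self-contained; in ts_prediction_lfr6_standard the first fire_place comes from the model's us_peaks.

import torch

# Toy per-frame CIF weights and decoded tokens (made-up values for the demo).
alphas = torch.tensor([0.3, 0.4, 0.5, 0.2, 0.6, 0.1, 0.4, 0.3])
char_list = ['a', 'b']

threshold = 1.0 - 1e-4
# First pass: fire peaks directly from the raw weights.
peaks = cif_wo_hidden(alphas.unsqueeze(0), threshold=threshold)[0]
fire_place = torch.where(peaks > threshold)[0]
# The raw weights sum to 2.8, so only 2 peaks fire instead of the
# expected len(char_list) + 1 = 3.

if len(fire_place) != len(char_list) + 1:
    # The fix: rescale the weights so they integrate to len(char_list) + 1,
    # then re-fire with cif_wo_hidden; the peak count now matches tokens + 1.
    alphas = alphas / (alphas.sum() / (len(char_list) + 1))
    peaks = cif_wo_hidden(alphas.unsqueeze(0), threshold=threshold)[0]
    fire_place = torch.where(peaks > threshold)[0]

assert len(fire_place) == len(char_list) + 1

Re-firing from renormalized alphas takes the place of the removed warning branch, which on a mismatch truncated either fire_place or char_list before computing timestamps.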