mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Update timestamp_tools.py
[fix] fix a bug in function "ts_prediction_lfr6_standard"
This commit is contained in:
parent
8b03379434
commit
df768884a2
@ -43,18 +43,18 @@ def ts_prediction_lfr6_standard(us_alphas,
|
|||||||
alphas, peaks = us_alphas, us_peaks
|
alphas, peaks = us_alphas, us_peaks
|
||||||
if char_list[-1] == '</s>':
|
if char_list[-1] == '</s>':
|
||||||
char_list = char_list[:-1]
|
char_list = char_list[:-1]
|
||||||
fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset
|
fire_place = torch.where(peaks>=1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset
|
||||||
if len(fire_place) != len(char_list) + 1:
|
if len(fire_place) != len(char_list) + 1:
|
||||||
alphas /= (alphas.sum() / (len(char_list) + 1))
|
alphas /= (alphas.sum() / (len(char_list) + 1))
|
||||||
alphas = alphas.unsqueeze(0)
|
alphas = alphas.unsqueeze(0)
|
||||||
peaks = cif_wo_hidden(alphas, threshold=1.0-1e-4)[0]
|
peaks = cif_wo_hidden(alphas, threshold=1.0-1e-4)[0]
|
||||||
fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset
|
fire_place = torch.where(peaks>=1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset
|
||||||
num_frames = peaks.shape[0]
|
num_frames = peaks.shape[0]
|
||||||
timestamp_list = []
|
timestamp_list = []
|
||||||
new_char_list = []
|
new_char_list = []
|
||||||
# for bicif model trained with large data, cif2 actually fires when a character starts
|
# for bicif model trained with large data, cif2 actually fires when a character starts
|
||||||
# so treat the frames between two peaks as the duration of the former token
|
# so treat the frames between two peaks as the duration of the former token
|
||||||
fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset
|
# fire_place = torch.where(peaks>=1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset
|
||||||
# assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
|
# assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
|
||||||
# begin silence
|
# begin silence
|
||||||
if fire_place[0] > START_END_THRESHOLD:
|
if fire_place[0] > START_END_THRESHOLD:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user