From 546262a0c61abf30c6861ff40f9a32f39b91baaa Mon Sep 17 00:00:00 2001 From: lzr265946 Date: Thu, 16 Feb 2023 15:22:14 +0800 Subject: [PATCH] remove useless code --- funasr/bin/asr_inference_paraformer_vad.py | 1 - .../bin/asr_inference_paraformer_vad_punc.py | 12 +-- funasr/utils/timestamp_tools.py | 82 ------------------- 3 files changed, 4 insertions(+), 91 deletions(-) diff --git a/funasr/bin/asr_inference_paraformer_vad.py b/funasr/bin/asr_inference_paraformer_vad.py index dbb271986..c01c6ba5e 100644 --- a/funasr/bin/asr_inference_paraformer_vad.py +++ b/funasr/bin/asr_inference_paraformer_vad.py @@ -38,7 +38,6 @@ from funasr.utils.types import str_or_none from funasr.utils import asr_utils, wav_utils, postprocess_utils from funasr.models.frontend.wav_frontend import WavFrontend from funasr.tasks.vad import VADTask -from funasr.utils.timestamp_tools import time_stamp_lfr6 from funasr.bin.punctuation_infer import Text2Punc from funasr.bin.asr_inference_paraformer_vad_punc import Speech2Text from funasr.bin.asr_inference_paraformer_vad_punc import Speech2VadSegment diff --git a/funasr/bin/asr_inference_paraformer_vad_punc.py b/funasr/bin/asr_inference_paraformer_vad_punc.py index c4bb61bd1..755cc9cca 100644 --- a/funasr/bin/asr_inference_paraformer_vad_punc.py +++ b/funasr/bin/asr_inference_paraformer_vad_punc.py @@ -39,7 +39,7 @@ from funasr.utils.types import str_or_none from funasr.utils import asr_utils, wav_utils, postprocess_utils from funasr.models.frontend.wav_frontend import WavFrontend from funasr.tasks.vad import VADTask -from funasr.utils.timestamp_tools import time_stamp_lfr6, time_stamp_lfr6_pl +from funasr.utils.timestamp_tools import time_stamp_lfr6_pl from funasr.bin.punctuation_infer import Text2Punc from funasr.models.e2e_asr_paraformer import BiCifParaformer @@ -282,12 +282,8 @@ class Speech2Text: else: text = None - if isinstance(self.asr_model, BiCifParaformer): - timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time) - results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor)) - else: - time_stamp = time_stamp_lfr6(alphas[i:i + 1, ], enc_len[i:i + 1, ], copy.copy(token), begin_time, end_time) - results.append((text, token, token_int, time_stamp, enc_len_batch_total, lfr_factor)) + timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time) + results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor)) # assert check_return_type(results) return results @@ -617,7 +613,7 @@ def inference_modelscope( result = result_segments[0] text, token, token_int = result[0], result[1], result[2] time_stamp = None if len(result) < 4 else result[3] - + if use_timestamp and time_stamp is not None: postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp) else: diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py index 33d1255cc..f966aeee9 100644 --- a/funasr/utils/timestamp_tools.py +++ b/funasr/utils/timestamp_tools.py @@ -4,88 +4,6 @@ import logging import numpy as np from typing import Any, List, Tuple, Union -def cut_interval(alphas: torch.Tensor, start: int, end: int, tail: bool): - if not tail: - if end == start + 1: - cut = (end + start) / 2.0 - else: - alpha = alphas[start+1: end].tolist() - reverse_steps = 1 - for reverse_alpha in alpha[::-1]: - if reverse_alpha > 0.35: - reverse_steps += 1 - else: - break - cut = end - reverse_steps - else: - if end != len(alphas) - 1: - cut = end + 1 - else: - cut = start + 1 - return float(cut) - -def time_stamp_lfr6(alphas: torch.Tensor, speech_lengths: torch.Tensor, raw_text: List[str], begin: int = 0, end: int = None): - time_stamp_list = [] - alphas = alphas[0] - text = copy.deepcopy(raw_text) - if end is None: - time = speech_lengths * 60 / 1000 - sacle_rate = (time / speech_lengths[0]).tolist() - else: - time = (end - begin) / 1000 - sacle_rate = (time / speech_lengths[0]).tolist() - - predictor = (alphas > 0.5).int() - fire_places = torch.nonzero(predictor == 1).squeeze(1).tolist() - - cuts = [] - npeak = int(predictor.sum()) - nchar = len(raw_text) - if npeak - 1 == nchar: - fire_places = torch.where((alphas > 0.5) == 1)[0].tolist() - for i in range(len(fire_places)): - if fire_places[i] < len(alphas) - 1: - if 0.05 < alphas[fire_places[i]+1] < 0.5: - fire_places[i] += 1 - elif npeak < nchar: - lost_num = nchar - npeak - lost_fire = speech_lengths[0].tolist() - fire_places[-1] - interval_distance = lost_fire // (lost_num + 1) - for i in range(1, lost_num + 1): - fire_places.append(fire_places[-1] + interval_distance) - elif npeak - 1 > nchar: - redundance_num = npeak - 1 - nchar - for i in range(redundance_num): - fire_places.pop() - - cuts.append(0) - start_sil = True - if start_sil: - text.insert(0, '') - - for i in range(len(fire_places)-1): - cuts.append(cut_interval(alphas, fire_places[i], fire_places[i+1], tail=(i==len(fire_places)-2))) - - for i in range(2, len(fire_places)-2): - if fire_places[i-2] == fire_places[i-1] - 1 and fire_places[i-1] != fire_places[i] - 1: - cuts[i-1] += 1 - - if cuts[-1] != len(alphas) - 1: - text.append('') - cuts.append(speech_lengths[0].tolist()) - cuts.insert(-1, (cuts[-1] + cuts[-2]) * 0.5) - sec_fire_places = np.array(cuts) * sacle_rate - for i in range(1, len(sec_fire_places) - 1): - start, end = sec_fire_places[i], sec_fire_places[i+1] - if i == len(sec_fire_places) - 2: - end = time - time_stamp_list.append([int(round(start, 2) * 1000) + begin, int(round(end, 2) * 1000) + begin]) - text = text[1:] - if npeak - 1 == nchar or npeak > nchar: - return time_stamp_list[:-1] - else: - return time_stamp_list - def time_stamp_lfr6_pl(us_alphas, us_cif_peak, char_list, begin_time=0.0, end_time=None): START_END_THRESHOLD = 5 TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled