diff --git a/funasr/runtime/python/onnxruntime/demo.py b/funasr/runtime/python/onnxruntime/demo.py index d35960383..5209f319a 100644 --- a/funasr/runtime/python/onnxruntime/demo.py +++ b/funasr/runtime/python/onnxruntime/demo.py @@ -2,7 +2,7 @@ from rapid_paraformer import Paraformer model_dir = "/Users/shixian/code/funasr2/export/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch" -model_dir = "/Users/shixian/code/funasr2/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" +# model_dir = "/Users/shixian/code/funasr2/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" model = Paraformer(model_dir, batch_size=1) diff --git a/funasr/runtime/python/onnxruntime/rapid_paraformer/paraformer_onnx.py b/funasr/runtime/python/onnxruntime/rapid_paraformer/paraformer_onnx.py index ed9b030df..9b8a67bb0 100644 --- a/funasr/runtime/python/onnxruntime/rapid_paraformer/paraformer_onnx.py +++ b/funasr/runtime/python/onnxruntime/rapid_paraformer/paraformer_onnx.py @@ -41,17 +41,16 @@ class Paraformer(): ) self.ort_infer = OrtInferSession(model_file, device_id) self.batch_size = batch_size + self.plot = True def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List: waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq) waveform_nums = len(waveform_list) - asr_res = [] for beg_idx in range(0, waveform_nums, self.batch_size): res = {} end_idx = min(waveform_nums, beg_idx + self.batch_size) feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx]) - try: outputs = self.infer(feats, feats_len) am_scores, valid_token_lens = outputs[0], outputs[1] @@ -68,11 +67,17 @@ class Paraformer(): preds, raw_token = self.decode(am_scores, valid_token_lens)[0] res['preds'] = preds if us_cif_peak is not None: - timestamp = time_stamp_lfr6_onnx(us_cif_peak, copy.copy(raw_token)) + timestamp, timestamp_total = time_stamp_lfr6_onnx(us_cif_peak, copy.copy(raw_token)) res['timestamp'] = timestamp + if self.plot: + self.plot_wave_timestamp(waveform_list[0], timestamp_total) asr_res.append(res) return asr_res + def plot_wave_timestamp(self, wav, text_timestamp): + # TODO: Plot the wav and timestamp results with matplotlib + import pdb; pdb.set_trace() + def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List: def load_wav(path: str) -> np.ndarray: diff --git a/funasr/runtime/python/onnxruntime/rapid_paraformer/utils/timestamp_utils.py b/funasr/runtime/python/onnxruntime/rapid_paraformer/utils/timestamp_utils.py index 32dd8498e..767e864fc 100644 --- a/funasr/runtime/python/onnxruntime/rapid_paraformer/utils/timestamp_utils.py +++ b/funasr/runtime/python/onnxruntime/rapid_paraformer/utils/timestamp_utils.py @@ -9,7 +9,6 @@ def time_stamp_lfr6_onnx(us_cif_peak, char_list, begin_time=0.0): TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled cif_peak = us_cif_peak.reshape(-1) num_frames = cif_peak.shape[-1] - import pdb; pdb.set_trace() if char_list[-1] == '': char_list = char_list[:-1] # char_list = [i for i in text] @@ -49,11 +48,11 @@ def time_stamp_lfr6_onnx(us_cif_peak, char_list, begin_time=0.0): timestamp_list[i][0] = timestamp_list[i][0] + begin_time / 1000.0 timestamp_list[i][1] = timestamp_list[i][1] + begin_time / 1000.0 assert len(new_char_list) == len(timestamp_list) - res_txt = "" + res_total = [] for char, timestamp in zip(new_char_list, timestamp_list): - res_txt += "{} {} {};".format(char, timestamp[0], timestamp[1]) + res_total.append([char, timestamp[0], timestamp[1]]) # += "{} {} {};".format(char, timestamp[0], timestamp[1]) res = [] for char, timestamp in zip(new_char_list, timestamp_list): if char != '': res.append([int(timestamp[0] * 1000), int(timestamp[1] * 1000)]) - return res \ No newline at end of file + return res, res_total \ No newline at end of file