diff --git a/funasr/runtime/python/libtorch/torch_paraformer/paraformer_bin.py b/funasr/runtime/python/libtorch/torch_paraformer/paraformer_bin.py index e6b33d464..3545ccf48 100644 --- a/funasr/runtime/python/libtorch/torch_paraformer/paraformer_bin.py +++ b/funasr/runtime/python/libtorch/torch_paraformer/paraformer_bin.py @@ -22,6 +22,8 @@ class Paraformer(): def __init__(self, model_dir: Union[str, Path] = None, batch_size: int = 1, device_id: Union[str, int] = "-1", + plot_timestamp_to: str = "", + pred_bias: int = 1, ): if not Path(model_dir).exists(): @@ -40,17 +42,17 @@ class Paraformer(): ) self.ort_infer = torch.jit.load(model_file) self.batch_size = batch_size + self.plot_timestamp_to = plot_timestamp_to + self.pred_bias = pred_bias def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List: waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq) waveform_nums = len(waveform_list) - asr_res = [] for beg_idx in range(0, waveform_nums, self.batch_size): - res = {} + end_idx = min(waveform_nums, beg_idx + self.batch_size) feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx]) - try: outputs = self.ort_infer(feats, feats_len) am_scores, valid_token_lens = outputs[0], outputs[1] @@ -65,15 +67,42 @@ class Paraformer(): preds = [''] else: am_scores, valid_token_lens = am_scores.detach().cpu().numpy(), valid_token_lens.detach().cpu().numpy() - preds, raw_token = self.decode(am_scores, valid_token_lens)[0] - res['preds'] = preds - if us_cif_peak is not None: - us_alphas, us_cif_peak = us_alphas.cpu().numpy(), us_cif_peak.cpu().numpy() - timestamp = time_stamp_lfr6_pl(us_alphas, us_cif_peak, copy.copy(raw_token), log=False) - res['timestamp'] = timestamp - asr_res.append(res) + preds = self.decode(am_scores, valid_token_lens) + if us_cif_peak is None: + for pred in preds: + asr_res.append({'preds': pred}) + else: + for pred, us_cif_peak_ in zip(preds, us_cif_peak): + text, tokens = pred + timestamp, timestamp_total = time_stamp_lfr6_onnx(us_cif_peak_, copy.copy(tokens)) + if len(self.plot_timestamp_to): + self.plot_wave_timestamp(waveform_list[0], timestamp_total, self.plot_timestamp_to) + asr_res.append({'preds': text, 'timestamp': timestamp}) return asr_res + def plot_wave_timestamp(self, wav, text_timestamp, dest): + # TODO: Plot the wav and timestamp results with matplotlib + import matplotlib + matplotlib.use('Agg') + matplotlib.rc("font", family='Alibaba PuHuiTi') # set it to a font that your system supports + import matplotlib.pyplot as plt + fig, ax1 = plt.subplots(figsize=(11, 3.5), dpi=320) + ax2 = ax1.twinx() + ax2.set_ylim([0, 2.0]) + # plot waveform + ax1.set_ylim([-0.3, 0.3]) + time = np.arange(wav.shape[0]) / 16000 + ax1.plot(time, wav/wav.max()*0.3, color='gray', alpha=0.4) + # plot lines and text + for (char, start, end) in text_timestamp: + ax1.vlines(start, -0.3, 0.3, ls='--') + ax1.vlines(end, -0.3, 0.3, ls='--') + x_adj = 0.045 if char != '' else 0.12 + ax1.text((start + end) * 0.5 - x_adj, 0, char) + # plt.legend() + plotname = "{}/timestamp.png".format(dest) + plt.savefig(plotname, bbox_inches='tight') + def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List: def load_wav(path: str) -> np.ndarray: @@ -148,9 +177,7 @@ class Paraformer(): # Change integer-ids to tokens token = self.converter.ids2tokens(token_int) - # token = token[:valid_token_num-1] + token = token[:valid_token_num-self.pred_bias] texts = sentence_postprocess(token) - text = texts[0] - # text = self.tokenizer.tokens2text(token) - return text, token + return texts diff --git a/funasr/runtime/python/onnxruntime/demo.py b/funasr/runtime/python/onnxruntime/demo.py index 3135b4d82..1b887daec 100644 --- a/funasr/runtime/python/onnxruntime/demo.py +++ b/funasr/runtime/python/onnxruntime/demo.py @@ -1,12 +1,15 @@ from rapid_paraformer import Paraformer -model_dir = "/Users/shixian/code/funasr2/export/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch" -# model_dir = "/Users/shixian/code/funasr2/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" +#model_dir = "/Users/shixian/code/funasr/export/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch" +#model_dir = "/Users/shixian/code/funasr/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" +model_dir = "/Users/shixian/code/funasr/export/damo/speech_paraformer-tiny-commandword_asr_nat-zh-cn-16k-vocab544-pytorch" -model = Paraformer(model_dir, batch_size=1) +# if you use paraformer-tiny-commandword_asr_nat-zh-cn-16k-vocab544-pytorch, you should set pred_bias=0 +# plot_timestamp_to works only when using speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch +model = Paraformer(model_dir, batch_size=2, plot_timestamp_to="./", pred_bias=0) -wav_path = ['/Users/shixian/code/funasr2/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav'] +wav_path = "/Users/shixian/code/funasr/export/damo/speech_paraformer-tiny-commandword_asr_nat-zh-cn-16k-vocab544-pytorch/example/asr_example.wav" result = model(wav_path) print(result) \ No newline at end of file diff --git a/funasr/runtime/python/onnxruntime/rapid_paraformer/paraformer_onnx.py b/funasr/runtime/python/onnxruntime/rapid_paraformer/paraformer_onnx.py index 4a55bdfea..850f007c7 100644 --- a/funasr/runtime/python/onnxruntime/rapid_paraformer/paraformer_onnx.py +++ b/funasr/runtime/python/onnxruntime/rapid_paraformer/paraformer_onnx.py @@ -24,7 +24,8 @@ class Paraformer(): def __init__(self, model_dir: Union[str, Path] = None, batch_size: int = 1, device_id: Union[str, int] = "-1", - plot_timestamp: bool = False, + plot_timestamp_to: str = "", + pred_bias: int = 1, ): if not Path(model_dir).exists(): @@ -43,14 +44,15 @@ class Paraformer(): ) self.ort_infer = OrtInferSession(model_file, device_id) self.batch_size = batch_size - self.plot = plot_timestamp + self.plot_timestamp_to = plot_timestamp_to + self.pred_bias = pred_bias def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List: waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq) waveform_nums = len(waveform_list) asr_res = [] for beg_idx in range(0, waveform_nums, self.batch_size): - res = {} + end_idx = min(waveform_nums, beg_idx + self.batch_size) feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx]) try: @@ -66,17 +68,20 @@ class Paraformer(): logging.warning("input wav is silence or noise") preds = [''] else: - preds, raw_token = self.decode(am_scores, valid_token_lens)[0] - res['preds'] = preds - if us_cif_peak is not None: - timestamp, timestamp_total = time_stamp_lfr6_onnx(us_cif_peak, copy.copy(raw_token)) - res['timestamp'] = timestamp - if self.plot: - self.plot_wave_timestamp(waveform_list[0], timestamp_total) - asr_res.append(res) + preds = self.decode(am_scores, valid_token_lens) + if us_cif_peak is None: + for pred in preds: + asr_res.append({'preds': pred}) + else: + for pred, us_cif_peak_ in zip(preds, us_cif_peak): + text, tokens = pred + timestamp, timestamp_total = time_stamp_lfr6_onnx(us_cif_peak_, copy.copy(tokens)) + if len(self.plot_timestamp_to): + self.plot_wave_timestamp(waveform_list[0], timestamp_total, self.plot_timestamp_to) + asr_res.append({'preds': text, 'timestamp': timestamp}) return asr_res - def plot_wave_timestamp(self, wav, text_timestamp): + def plot_wave_timestamp(self, wav, text_timestamp, dest): # TODO: Plot the wav and timestamp results with matplotlib import matplotlib matplotlib.use('Agg') @@ -96,7 +101,7 @@ class Paraformer(): x_adj = 0.045 if char != '' else 0.12 ax1.text((start + end) * 0.5 - x_adj, 0, char) # plt.legend() - plotname = "funasr/runtime/python/onnxruntime/debug.png" + plotname = "{}/timestamp.png".format(dest) plt.savefig(plotname, bbox_inches='tight') def load_data(self, @@ -171,9 +176,7 @@ class Paraformer(): # Change integer-ids to tokens token = self.converter.ids2tokens(token_int) - # token = token[:valid_token_num-1] + token = token[:valid_token_num-self.pred_bias] texts = sentence_postprocess(token) - text = texts[0] - # text = self.tokenizer.tokens2text(token) - return text, token + return texts