Merge pull request #184 from alibaba-damo-academy/dev_timestamp

Dev timestamp
This commit is contained in:
zhifu gao 2023-03-03 21:53:20 +08:00 committed by GitHub
commit 00c1d9119c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 68 additions and 35 deletions

View File

@ -22,6 +22,8 @@ class Paraformer():
def __init__(self, model_dir: Union[str, Path] = None,
batch_size: int = 1,
device_id: Union[str, int] = "-1",
plot_timestamp_to: str = "",
pred_bias: int = 1,
):
if not Path(model_dir).exists():
@ -40,17 +42,17 @@ class Paraformer():
)
self.ort_infer = torch.jit.load(model_file)
self.batch_size = batch_size
self.plot_timestamp_to = plot_timestamp_to
self.pred_bias = pred_bias
def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List:
waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
waveform_nums = len(waveform_list)
asr_res = []
for beg_idx in range(0, waveform_nums, self.batch_size):
res = {}
end_idx = min(waveform_nums, beg_idx + self.batch_size)
feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
try:
outputs = self.ort_infer(feats, feats_len)
am_scores, valid_token_lens = outputs[0], outputs[1]
@ -65,15 +67,42 @@ class Paraformer():
preds = ['']
else:
am_scores, valid_token_lens = am_scores.detach().cpu().numpy(), valid_token_lens.detach().cpu().numpy()
preds, raw_token = self.decode(am_scores, valid_token_lens)[0]
res['preds'] = preds
if us_cif_peak is not None:
us_alphas, us_cif_peak = us_alphas.cpu().numpy(), us_cif_peak.cpu().numpy()
timestamp = time_stamp_lfr6_pl(us_alphas, us_cif_peak, copy.copy(raw_token), log=False)
res['timestamp'] = timestamp
asr_res.append(res)
preds = self.decode(am_scores, valid_token_lens)
if us_cif_peak is None:
for pred in preds:
asr_res.append({'preds': pred})
else:
for pred, us_cif_peak_ in zip(preds, us_cif_peak):
text, tokens = pred
timestamp, timestamp_total = time_stamp_lfr6_onnx(us_cif_peak_, copy.copy(tokens))
if len(self.plot_timestamp_to):
self.plot_wave_timestamp(waveform_list[0], timestamp_total, self.plot_timestamp_to)
asr_res.append({'preds': text, 'timestamp': timestamp})
return asr_res
def plot_wave_timestamp(self, wav, text_timestamp, dest):
# TODO: Plot the wav and timestamp results with matplotlib
import matplotlib
matplotlib.use('Agg')
matplotlib.rc("font", family='Alibaba PuHuiTi') # set it to a font that your system supports
import matplotlib.pyplot as plt
fig, ax1 = plt.subplots(figsize=(11, 3.5), dpi=320)
ax2 = ax1.twinx()
ax2.set_ylim([0, 2.0])
# plot waveform
ax1.set_ylim([-0.3, 0.3])
time = np.arange(wav.shape[0]) / 16000
ax1.plot(time, wav/wav.max()*0.3, color='gray', alpha=0.4)
# plot lines and text
for (char, start, end) in text_timestamp:
ax1.vlines(start, -0.3, 0.3, ls='--')
ax1.vlines(end, -0.3, 0.3, ls='--')
x_adj = 0.045 if char != '<sil>' else 0.12
ax1.text((start + end) * 0.5 - x_adj, 0, char)
# plt.legend()
plotname = "{}/timestamp.png".format(dest)
plt.savefig(plotname, bbox_inches='tight')
def load_data(self,
wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
def load_wav(path: str) -> np.ndarray:
@ -148,9 +177,7 @@ class Paraformer():
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
# token = token[:valid_token_num-1]
token = token[:valid_token_num-self.pred_bias]
texts = sentence_postprocess(token)
text = texts[0]
# text = self.tokenizer.tokens2text(token)
return text, token
return texts

View File

@ -1,12 +1,15 @@
from rapid_paraformer import Paraformer
model_dir = "/Users/shixian/code/funasr2/export/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
# model_dir = "/Users/shixian/code/funasr2/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
#model_dir = "/Users/shixian/code/funasr/export/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
#model_dir = "/Users/shixian/code/funasr/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
model_dir = "/Users/shixian/code/funasr/export/damo/speech_paraformer-tiny-commandword_asr_nat-zh-cn-16k-vocab544-pytorch"
model = Paraformer(model_dir, batch_size=1)
# if you use paraformer-tiny-commandword_asr_nat-zh-cn-16k-vocab544-pytorch, you should set pred_bias=0
# plot_timestamp_to works only when using speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch
model = Paraformer(model_dir, batch_size=2, plot_timestamp_to="./", pred_bias=0)
wav_path = ['/Users/shixian/code/funasr2/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav']
wav_path = "/Users/shixian/code/funasr/export/damo/speech_paraformer-tiny-commandword_asr_nat-zh-cn-16k-vocab544-pytorch/example/asr_example.wav"
result = model(wav_path)
print(result)

View File

@ -24,7 +24,8 @@ class Paraformer():
def __init__(self, model_dir: Union[str, Path] = None,
batch_size: int = 1,
device_id: Union[str, int] = "-1",
plot_timestamp: bool = False,
plot_timestamp_to: str = "",
pred_bias: int = 1,
):
if not Path(model_dir).exists():
@ -43,14 +44,15 @@ class Paraformer():
)
self.ort_infer = OrtInferSession(model_file, device_id)
self.batch_size = batch_size
self.plot = plot_timestamp
self.plot_timestamp_to = plot_timestamp_to
self.pred_bias = pred_bias
def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List:
waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
waveform_nums = len(waveform_list)
asr_res = []
for beg_idx in range(0, waveform_nums, self.batch_size):
res = {}
end_idx = min(waveform_nums, beg_idx + self.batch_size)
feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
try:
@ -66,17 +68,20 @@ class Paraformer():
logging.warning("input wav is silence or noise")
preds = ['']
else:
preds, raw_token = self.decode(am_scores, valid_token_lens)[0]
res['preds'] = preds
if us_cif_peak is not None:
timestamp, timestamp_total = time_stamp_lfr6_onnx(us_cif_peak, copy.copy(raw_token))
res['timestamp'] = timestamp
if self.plot:
self.plot_wave_timestamp(waveform_list[0], timestamp_total)
asr_res.append(res)
preds = self.decode(am_scores, valid_token_lens)
if us_cif_peak is None:
for pred in preds:
asr_res.append({'preds': pred})
else:
for pred, us_cif_peak_ in zip(preds, us_cif_peak):
text, tokens = pred
timestamp, timestamp_total = time_stamp_lfr6_onnx(us_cif_peak_, copy.copy(tokens))
if len(self.plot_timestamp_to):
self.plot_wave_timestamp(waveform_list[0], timestamp_total, self.plot_timestamp_to)
asr_res.append({'preds': text, 'timestamp': timestamp})
return asr_res
def plot_wave_timestamp(self, wav, text_timestamp):
def plot_wave_timestamp(self, wav, text_timestamp, dest):
# TODO: Plot the wav and timestamp results with matplotlib
import matplotlib
matplotlib.use('Agg')
@ -96,7 +101,7 @@ class Paraformer():
x_adj = 0.045 if char != '<sil>' else 0.12
ax1.text((start + end) * 0.5 - x_adj, 0, char)
# plt.legend()
plotname = "funasr/runtime/python/onnxruntime/debug.png"
plotname = "{}/timestamp.png".format(dest)
plt.savefig(plotname, bbox_inches='tight')
def load_data(self,
@ -171,9 +176,7 @@ class Paraformer():
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
# token = token[:valid_token_num-1]
token = token[:valid_token_num-self.pred_bias]
texts = sentence_postprocess(token)
text = texts[0]
# text = self.tokenizer.tokens2text(token)
return text, token
return texts