This commit is contained in:
shixian.shi 2024-01-17 19:21:08 +08:00
parent 9a9c3b75b5
commit 7458e39ff0
2 changed files with 20 additions and 18 deletions

View File

@ -11,6 +11,7 @@ res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/Ma
print(res) print(res)
''' can not use currently
from funasr import AutoFrontend from funasr import AutoFrontend
frontend = AutoFrontend(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.2") frontend = AutoFrontend(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.2")
@ -19,4 +20,5 @@ fbanks = frontend(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/
for batch_idx, fbank_dict in enumerate(fbanks): for batch_idx, fbank_dict in enumerate(fbanks):
res = model.generate(**fbank_dict) res = model.generate(**fbank_dict)
print(res) print(res)
'''

View File

@ -235,23 +235,23 @@ class BiCifParaformer(Paraformer):
self.nbest = kwargs.get("nbest", 1) self.nbest = kwargs.get("nbest", 1)
meta_data = {} meta_data = {}
if isinstance(data_in, torch.Tensor): # fbank # if isinstance(data_in, torch.Tensor): # fbank
speech, speech_lengths = data_in, data_lengths # speech, speech_lengths = data_in, data_lengths
if len(speech.shape) < 3: # if len(speech.shape) < 3:
speech = speech[None, :, :] # speech = speech[None, :, :]
if speech_lengths is None: # if speech_lengths is None:
speech_lengths = speech.shape[1] # speech_lengths = speech.shape[1]
else: # else:
# extract fbank feats # extract fbank feats
time1 = time.perf_counter() time1 = time.perf_counter()
audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000)) audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
time2 = time.perf_counter() time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}" meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
frontend=frontend) frontend=frontend)
time3 = time.perf_counter() time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}" meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
speech = speech.to(device=kwargs["device"]) speech = speech.to(device=kwargs["device"])
speech_lengths = speech_lengths.to(device=kwargs["device"]) speech_lengths = speech_lengths.to(device=kwargs["device"])