funasr1.0 streaming demo

This commit is contained in:
游雁 2024-01-12 12:05:08 +08:00
parent a0d77813ac
commit 40d1f80030
5 changed files with 67 additions and 75 deletions

View File

@ -4,8 +4,33 @@
# MIT License (https://opensource.org/licenses/MIT)
from funasr import AutoModel
wav_file = "/Users/zhifu/funasr_github/test_local/asr_example.wav"
chunk_size = 60000 # ms
model = AutoModel(model="/Users/zhifu/Downloads/modelscope_models/speech_fsmn_vad_zh-cn-16k-common-streaming", model_revision="v2.0.0")
model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.0")
res = model(input=wav_file,
chunk_size=chunk_size,
)
print(res)
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav")
print(res)
#
# import soundfile
# import os
#
# # wav_file = os.path.join(model.model_path, "example/vad_example.wav")
# speech, sample_rate = soundfile.read(wav_file)
#
# chunk_stride = int(chunk_size * 16000 / 1000)
#
# cache = {}
#
# for i in range(int(len((speech)-1)/chunk_stride+1)):
# speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
# is_final = i == int(len((speech)-1)/chunk_stride+1)
# res = model(input=speech_chunk,
# cache=cache,
# is_final=is_final,
# chunk_size=chunk_size,
# )
# print(res)

View File

@ -593,15 +593,16 @@ class FsmnVAD(nn.Module):
results = []
for i in range(batch_size):
if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
results[i] = json.dumps(results[i])
if ibest_writer is not None:
ibest_writer["text"][key[i]] = segments[i]
result_i = {"key": key[i], "value": segments[i]}
results.append(result_i)
if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
results[i] = json.dumps(results[i])
return results, meta_data
def DetectCommonFrames(self) -> int:

View File

@ -496,7 +496,7 @@ class FsmnVADStreaming(nn.Module):
def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: Dict[str, torch.Tensor] = dict(),
is_final: bool = False
):
if not cache:
if len(cache) == 0:
self.AllResetDetection()
self.waveform = waveform # compute decibel for each frame
self.ComputeDecibel()
@ -521,13 +521,15 @@ class FsmnVADStreaming(nn.Module):
if is_final:
# reset class variables and clear the dict for the next query
self.AllResetDetection()
return segments, cache
return segments
def init_cache(self, cache: dict = {}, **kwargs):
cache["frontend"] = {}
cache["prev_samples"] = torch.empty(0)
cache["encoder"] = {}
return cache
def generate(self,
data_in,
data_lengths=None,
@ -543,7 +545,7 @@ class FsmnVADStreaming(nn.Module):
meta_data = {}
chunk_size = kwargs.get("chunk_size", 50) # 50ms
chunk_stride_samples = chunk_size * 16
chunk_stride_samples = int(chunk_size * frontend.fs / 1000)
time1 = time.perf_counter()
cfg = {"is_final": kwargs.get("is_final", False)}
@ -552,7 +554,7 @@ class FsmnVADStreaming(nn.Module):
audio_fs=kwargs.get("fs", 16000),
data_type=kwargs.get("data_type", "sound"),
tokenizer=tokenizer,
**cfg,
cache=cfg,
)
_is_final = cfg["is_final"] # if data_in is a file or url, set is_final=True
@ -562,9 +564,9 @@ class FsmnVADStreaming(nn.Module):
audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))
n = len(audio_sample) // chunk_stride_samples + int(_is_final)
m = len(audio_sample) % chunk_stride_samples * (1 - int(_is_final))
tokens = []
n = int(len(audio_sample) // chunk_stride_samples + int(_is_final))
m = int(len(audio_sample) % chunk_stride_samples * (1 - int(_is_final)))
segments = []
for i in range(n):
kwargs["is_final"] = _is_final and i == n - 1
audio_sample_i = audio_sample[i * chunk_stride_samples:(i + 1) * chunk_stride_samples]
@ -576,58 +578,22 @@ class FsmnVADStreaming(nn.Module):
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
meta_data = {}
audio_sample_list = [data_in]
if isinstance(data_in, torch.Tensor): # fbank
speech, speech_lengths = data_in, data_lengths
if len(speech.shape) < 3:
speech = speech[None, :, :]
if speech_lengths is None:
speech_lengths = speech.shape[1]
else:
# extract fbank feats
time1 = time.perf_counter()
audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
frontend=frontend)
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
meta_data[
"batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
# b. Forward Encoder streaming
t_offset = 0
feats = speech
feats_len = speech_lengths.max().item()
waveform = pad_sequence(audio_sample_list, batch_first=True).to(device=kwargs["device"]) # data: [batch, N]
cache = kwargs.get("cache", {})
batch_size = kwargs.get("batch_size", 1)
step = min(feats_len, 6000)
segments = [[]] * batch_size
for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
if t_offset + step >= feats_len - 1:
step = feats_len - t_offset
is_final = True
else:
is_final = False
speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
batch = {
"feats": feats[:, t_offset:t_offset + step, :],
"waveform": waveform[:, t_offset * 160:min(waveform.shape[-1], (t_offset + step - 1) * 160 + 400)],
"is_final": is_final,
"cache": cache
"feats": speech,
"waveform": cache["frontend"]["waveforms"],
"is_final": kwargs["is_final"],
"cache": cache["encoder"]
}
segments_i = self.forward(**batch)
print(segments_i)
segments.extend(segments_i)
segments_part, cache = self.forward(**batch)
if segments_part:
for batch_num in range(0, batch_size):
segments[batch_num] += segments_part[batch_num]
cache["prev_samples"] = audio_sample[:-m]
if _is_final:
self.init_cache(cache, **kwargs)
ibest_writer = None
if ibest_writer is None and kwargs.get("output_dir") is not None:
@ -635,16 +601,15 @@ class FsmnVADStreaming(nn.Module):
ibest_writer = writer[f"{1}best_recog"]
results = []
for i in range(batch_size):
if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
results[i] = json.dumps(results[i])
if ibest_writer is not None:
ibest_writer["text"][key[i]] = segments[i]
result_i = {"key": key[0], "value": segments}
if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
result_i = json.dumps(result_i)
results.append(result_i)
if ibest_writer is not None:
ibest_writer["text"][key[0]] = segments
result_i = {"key": key[i], "value": segments[i]}
results.append(result_i)
return results, meta_data

View File

@ -523,7 +523,7 @@ class ParaformerStreaming(Paraformer):
meta_data = {}
chunk_size = kwargs.get("chunk_size", [0, 10, 5])
chunk_stride_samples = chunk_size[1] * 960 # 600ms
chunk_stride_samples = int(chunk_size[1] * 960) # 600ms
time1 = time.perf_counter()
cfg = {"is_final": kwargs.get("is_final", False)}
@ -532,7 +532,7 @@ class ParaformerStreaming(Paraformer):
audio_fs=kwargs.get("fs", 16000),
data_type=kwargs.get("data_type", "sound"),
tokenizer=tokenizer,
**cfg,
cache=cfg,
)
_is_final = cfg["is_final"] # if data_in is a file or url, set is_final=True
@ -542,8 +542,8 @@ class ParaformerStreaming(Paraformer):
audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))
n = len(audio_sample) // chunk_stride_samples + int(_is_final)
m = len(audio_sample) % chunk_stride_samples * (1-int(_is_final))
n = int(len(audio_sample) // chunk_stride_samples + int(_is_final))
m = int(len(audio_sample) % chunk_stride_samples * (1-int(_is_final)))
tokens = []
for i in range(n):
kwargs["is_final"] = _is_final and i == n -1

View File

@ -48,7 +48,8 @@ def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs:
pass
# if data_in is a file or url, set is_final=True
kwargs["is_final"] = True
if "cache" in kwargs:
kwargs["cache"]["is_final"] = True
elif isinstance(data_or_path_or_list, str) and data_type == "text" and tokenizer is not None:
data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
elif isinstance(data_or_path_or_list, np.ndarray): # audio sample point