mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
funasr1.0 paraformer_streaming
This commit is contained in:
parent
78ffd04ac9
commit
a75bbb028e
@ -3,36 +3,44 @@
|
||||
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||
# MIT License (https://opensource.org/licenses/MIT)
|
||||
|
||||
# from funasr import AutoModel
|
||||
#
|
||||
# model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revison="v2.0.0")
|
||||
#
|
||||
# res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
|
||||
# print(res)
|
||||
from funasr import AutoModel
|
||||
|
||||
chunk_size = [0, 10, 5] #[0, 10, 5] 600ms, [0, 8, 4] 480ms
|
||||
encoder_chunk_look_back = 4 #number of chunks to lookback for encoder self-attention
|
||||
decoder_chunk_look_back = 1 #number of encoder chunks to lookback for decoder cross-attention
|
||||
|
||||
from funasr import AutoFrontend
|
||||
|
||||
frontend = AutoFrontend(model="/Users/zhifu/Downloads/modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", model_revison="v2.0.0")
|
||||
|
||||
model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", model_revison="v2.0.0")
|
||||
cache = {}
|
||||
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
|
||||
cache=cache,
|
||||
is_final=True,
|
||||
chunk_size=chunk_size,
|
||||
encoder_chunk_look_back=encoder_chunk_look_back,
|
||||
decoder_chunk_look_back=decoder_chunk_look_back,
|
||||
)
|
||||
print(res)
|
||||
|
||||
|
||||
import soundfile
|
||||
speech, sample_rate = soundfile.read("/Users/zhifu/Downloads/modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/example/asr_example.wav")
|
||||
import os
|
||||
|
||||
speech, sample_rate = soundfile.read(os.path.expanduser('~')+
|
||||
"/.cache/modelscope/hub/damo/"+
|
||||
"speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/"+
|
||||
"example/asr_example.wav")
|
||||
|
||||
chunk_size = [0, 10, 5] #[0, 10, 5] 600ms, [0, 8, 4] 480ms
|
||||
chunk_stride = chunk_size[1] * 960 # 600ms、480ms
|
||||
# first chunk, 600ms
|
||||
|
||||
cache = {}
|
||||
|
||||
for i in range(int(len((speech)-1)/chunk_stride+1)):
|
||||
speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
|
||||
fbanks = frontend(input=speech_chunk,
|
||||
batch_size=2,
|
||||
cache=cache)
|
||||
|
||||
|
||||
# for batch_idx, fbank_dict in enumerate(fbanks):
|
||||
# res = model(**fbank_dict)
|
||||
# print(res)
|
||||
is_final = i == int(len((speech)-1)/chunk_stride+1)
|
||||
res = model(input=speech_chunk,
|
||||
cache=cache,
|
||||
is_final=is_final,
|
||||
chunk_size=chunk_size,
|
||||
encoder_chunk_look_back=encoder_chunk_look_back,
|
||||
decoder_chunk_look_back=decoder_chunk_look_back,
|
||||
)
|
||||
print(res)
|
||||
|
||||
@ -1,14 +0,0 @@
|
||||
|
||||
# download model
|
||||
local_path_root=../modelscope_models
|
||||
mkdir -p ${local_path_root}
|
||||
local_path=${local_path_root}/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
|
||||
git clone https://www.modelscope.cn/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch.git ${local_path}
|
||||
|
||||
|
||||
python funasr/bin/train.py \
|
||||
+model="../modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
|
||||
+token_list="../modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/tokens.txt" \
|
||||
+train_data_set_list="data/list/audio_datasets.jsonl" \
|
||||
+output_dir="outputs/debug/ckpt/funasr2/exp2" \
|
||||
+device="cpu"
|
||||
@ -1,5 +1,5 @@
|
||||
|
||||
model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
|
||||
model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"
|
||||
model_revision="v2.0.0"
|
||||
|
||||
python funasr/bin/inference.py \
|
||||
|
||||
@ -205,7 +205,8 @@ class CifPredictorV2(nn.Module):
|
||||
|
||||
return acoustic_embeds, token_num, alphas, cif_peak
|
||||
|
||||
def forward_chunk(self, hidden, cache=None):
|
||||
def forward_chunk(self, hidden, cache=None, **kwargs):
|
||||
is_final = kwargs.get("is_final", False)
|
||||
batch_size, len_time, hidden_size = hidden.shape
|
||||
h = hidden
|
||||
context = h.transpose(1, 2)
|
||||
@ -226,14 +227,14 @@ class CifPredictorV2(nn.Module):
|
||||
|
||||
if cache is not None and "chunk_size" in cache:
|
||||
alphas[:, :cache["chunk_size"][0]] = 0.0
|
||||
if "is_final" in cache and not cache["is_final"]:
|
||||
if not is_final:
|
||||
alphas[:, sum(cache["chunk_size"][:2]):] = 0.0
|
||||
if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
|
||||
cache["cif_hidden"] = to_device(cache["cif_hidden"], device=hidden.device)
|
||||
cache["cif_alphas"] = to_device(cache["cif_alphas"], device=alphas.device)
|
||||
hidden = torch.cat((cache["cif_hidden"], hidden), dim=1)
|
||||
alphas = torch.cat((cache["cif_alphas"], alphas), dim=1)
|
||||
if cache is not None and "is_final" in cache and cache["is_final"]:
|
||||
if cache is not None and is_final:
|
||||
tail_hidden = torch.zeros((batch_size, 1, hidden_size), device=hidden.device)
|
||||
tail_alphas = torch.tensor([[self.tail_threshold]], device=alphas.device)
|
||||
tail_alphas = torch.tile(tail_alphas, (batch_size, 1))
|
||||
@ -277,7 +278,7 @@ class CifPredictorV2(nn.Module):
|
||||
|
||||
max_token_len = max(token_length)
|
||||
if max_token_len == 0:
|
||||
return hidden, torch.stack(token_length, 0)
|
||||
return hidden, torch.stack(token_length, 0), None, None
|
||||
list_ls = []
|
||||
for b in range(batch_size):
|
||||
pad_frames = torch.zeros((max_token_len - token_length[b], hidden_size), device=alphas.device)
|
||||
@ -291,7 +292,7 @@ class CifPredictorV2(nn.Module):
|
||||
cache["cif_alphas"] = torch.unsqueeze(cache["cif_alphas"], axis=0)
|
||||
cache["cif_hidden"] = torch.stack(cache_hiddens, axis=0)
|
||||
cache["cif_hidden"] = torch.unsqueeze(cache["cif_hidden"], axis=0)
|
||||
return torch.stack(list_ls, 0), torch.stack(token_length, 0)
|
||||
return torch.stack(list_ls, 0), torch.stack(token_length, 0), None, None
|
||||
|
||||
|
||||
def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
|
||||
|
||||
@ -64,8 +64,8 @@ class ParaformerStreaming(Paraformer):
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
import pdb;
|
||||
pdb.set_trace()
|
||||
# import pdb;
|
||||
# pdb.set_trace()
|
||||
self.sampling_ratio = kwargs.get("sampling_ratio", 0.2)
|
||||
|
||||
|
||||
@ -375,11 +375,10 @@ class ParaformerStreaming(Paraformer):
|
||||
|
||||
return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index
|
||||
|
||||
def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None):
|
||||
|
||||
pre_acoustic_embeds, pre_token_length = \
|
||||
self.predictor.forward_chunk(encoder_out, cache["encoder"])
|
||||
return pre_acoustic_embeds, pre_token_length
|
||||
def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None, **kwargs):
|
||||
is_final = kwargs.get("is_final", False)
|
||||
|
||||
return self.predictor.forward_chunk(encoder_out, cache["encoder"], is_final=is_final)
|
||||
|
||||
def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
|
||||
decoder_outs = self.decoder(
|
||||
@ -416,7 +415,7 @@ class ParaformerStreaming(Paraformer):
|
||||
"chunk_size": chunk_size}
|
||||
cache["decoder"] = cache_decoder
|
||||
cache["frontend"] = {}
|
||||
cache["prev_samples"] = []
|
||||
cache["prev_samples"] = torch.empty(0)
|
||||
|
||||
return cache
|
||||
|
||||
@ -432,12 +431,12 @@ class ParaformerStreaming(Paraformer):
|
||||
speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
|
||||
|
||||
# Encoder
|
||||
encoder_out, encoder_out_lens = self.encode_chunk(speech, speech_lengths, cache=cache)
|
||||
encoder_out, encoder_out_lens = self.encode_chunk(speech, speech_lengths, cache=cache, is_final=kwargs.get("is_final", False))
|
||||
if isinstance(encoder_out, tuple):
|
||||
encoder_out = encoder_out[0]
|
||||
|
||||
# predictor
|
||||
predictor_outs = self.calc_predictor_chunk(encoder_out, encoder_out_lens, cache=cache)
|
||||
predictor_outs = self.calc_predictor_chunk(encoder_out, encoder_out_lens, cache=cache, is_final=kwargs.get("is_final", False))
|
||||
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
|
||||
predictor_outs[2], predictor_outs[3]
|
||||
pre_token_length = pre_token_length.round().long()
|
||||
@ -476,10 +475,7 @@ class ParaformerStreaming(Paraformer):
|
||||
)
|
||||
nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
|
||||
for nbest_idx, hyp in enumerate(nbest_hyps):
|
||||
ibest_writer = None
|
||||
if ibest_writer is None and kwargs.get("output_dir") is not None:
|
||||
writer = DatadirWriter(kwargs.get("output_dir"))
|
||||
ibest_writer = writer[f"{nbest_idx + 1}best_recog"]
|
||||
|
||||
# remove sos/eos and get results
|
||||
last_pos = -1
|
||||
if isinstance(hyp.yseq, list):
|
||||
@ -490,22 +486,15 @@ class ParaformerStreaming(Paraformer):
|
||||
# remove blank symbol id, which is assumed to be 0
|
||||
token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
|
||||
|
||||
if tokenizer is not None:
|
||||
# Change integer-ids to tokens
|
||||
token = tokenizer.ids2tokens(token_int)
|
||||
text = tokenizer.tokens2text(token)
|
||||
|
||||
text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
|
||||
|
||||
result_i = {"key": key[i], "text": text_postprocessed}
|
||||
|
||||
if ibest_writer is not None:
|
||||
ibest_writer["token"][key[i]] = " ".join(token)
|
||||
# ibest_writer["text"][key[i]] = text
|
||||
ibest_writer["text"][key[i]] = text_postprocessed
|
||||
else:
|
||||
result_i = {"key": key[i], "token_int": token_int}
|
||||
results.append(result_i)
|
||||
|
||||
# Change integer-ids to tokens
|
||||
token = tokenizer.ids2tokens(token_int)
|
||||
# text = tokenizer.tokens2text(token)
|
||||
|
||||
result_i = token
|
||||
|
||||
|
||||
results.extend(result_i)
|
||||
|
||||
return results
|
||||
|
||||
@ -515,6 +504,7 @@ class ParaformerStreaming(Paraformer):
|
||||
key: list = None,
|
||||
tokenizer=None,
|
||||
frontend=None,
|
||||
cache: dict={},
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
@ -526,9 +516,10 @@ class ParaformerStreaming(Paraformer):
|
||||
self.init_beam_search(**kwargs)
|
||||
self.nbest = kwargs.get("nbest", 1)
|
||||
|
||||
cache = kwargs.get("cache", {})
|
||||
|
||||
if len(cache) == 0:
|
||||
self.init_cache(cache, **kwargs)
|
||||
_is_final = kwargs.get("is_final", False)
|
||||
|
||||
meta_data = {}
|
||||
chunk_size = kwargs.get("chunk_size", [0, 10, 5])
|
||||
@ -542,22 +533,41 @@ class ParaformerStreaming(Paraformer):
|
||||
meta_data["load_data"] = f"{time2 - time1:0.3f}"
|
||||
assert len(audio_sample_list) == 1, "batch_size must be set 1"
|
||||
|
||||
audio_sample = cache["prev_samples"] + audio_sample_list[0]
|
||||
audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))
|
||||
|
||||
n = len(audio_sample) // chunk_stride_samples
|
||||
m = len(audio_sample) % chunk_stride_samples
|
||||
n = len(audio_sample) // chunk_stride_samples + int(_is_final)
|
||||
m = len(audio_sample) % chunk_stride_samples * (1-int(_is_final))
|
||||
tokens = []
|
||||
for i in range(n):
|
||||
kwargs["is_final"] = _is_final and i == n -1
|
||||
audio_sample_i = audio_sample[i*chunk_stride_samples:(i+1)*chunk_stride_samples]
|
||||
|
||||
# extract fbank feats
|
||||
speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
|
||||
frontend=frontend, cache=cache["frontend"])
|
||||
frontend=frontend, cache=cache["frontend"], is_final=kwargs["is_final"])
|
||||
time3 = time.perf_counter()
|
||||
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
|
||||
meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
|
||||
|
||||
result_i = self.generate_chunk(speech, speech_lengths, **kwargs)
|
||||
tokens_i = self.generate_chunk(speech, speech_lengths, key=key, tokenizer=tokenizer, cache=cache, frontend=frontend, **kwargs)
|
||||
tokens.extend(tokens_i)
|
||||
|
||||
text_postprocessed, _ = postprocess_utils.sentence_postprocess(tokens)
|
||||
|
||||
result_i = {"key": key[0], "text": text_postprocessed}
|
||||
result = [result_i]
|
||||
|
||||
|
||||
cache["prev_samples"] = audio_sample[:-m]
|
||||
if _is_final:
|
||||
self.init_cache(cache, **kwargs)
|
||||
|
||||
if kwargs.get("output_dir"):
|
||||
writer = DatadirWriter(kwargs.get("output_dir"))
|
||||
ibest_writer = writer[f"{1}best_recog"]
|
||||
ibest_writer["token"][key[0]] = " ".join(tokens)
|
||||
ibest_writer["text"][key[0]] = text_postprocessed
|
||||
|
||||
return result, meta_data
|
||||
|
||||
|
||||
|
||||
@ -423,7 +423,9 @@ class SANMEncoderChunkOpt(nn.Module):
|
||||
xs_pad: torch.Tensor,
|
||||
ilens: torch.Tensor,
|
||||
cache: dict = None,
|
||||
**kwargs,
|
||||
):
|
||||
is_final = kwargs.get("is_final", False)
|
||||
xs_pad *= self.output_size() ** 0.5
|
||||
if self.embed is None:
|
||||
xs_pad = xs_pad
|
||||
|
||||
@ -43,7 +43,7 @@ def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs:
|
||||
elif isinstance(data_or_path_or_list, str) and data_type == "text" and tokenizer is not None:
|
||||
data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
|
||||
elif isinstance(data_or_path_or_list, np.ndarray): # audio sample point
|
||||
data_or_path_or_list = np.squeeze(data_or_path_or_list) # [n_samples,]
|
||||
data_or_path_or_list = torch.from_numpy(data_or_path_or_list).squeeze() # [n_samples,]
|
||||
else:
|
||||
pass
|
||||
# print(f"unsupport data type: {data_or_path_or_list}, return raw data")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user