funasr1.0 paraformer_streaming

游雁 2024-01-11 17:36:30 +08:00
parent 78ffd04ac9
commit a75bbb028e
7 changed files with 85 additions and 78 deletions

View File

@@ -3,36 +3,44 @@
 # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 # MIT License (https://opensource.org/licenses/MIT)
-# from funasr import AutoModel
-#
-# model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revison="v2.0.0")
-#
-# res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
-# print(res)
-from funasr import AutoFrontend
-frontend = AutoFrontend(model="/Users/zhifu/Downloads/modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", model_revison="v2.0.0")
+from funasr import AutoModel
+
+chunk_size = [0, 10, 5] #[0, 10, 5] 600ms, [0, 8, 4] 480ms
+encoder_chunk_look_back = 4 #number of chunks to lookback for encoder self-attention
+decoder_chunk_look_back = 1 #number of encoder chunks to lookback for decoder cross-attention
+
+model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", model_revison="v2.0.0")
+cache = {}
+res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
+            cache=cache,
+            is_final=True,
+            chunk_size=chunk_size,
+            encoder_chunk_look_back=encoder_chunk_look_back,
+            decoder_chunk_look_back=decoder_chunk_look_back,
+            )
+print(res)
 
 import soundfile
+import os
-speech, sample_rate = soundfile.read("/Users/zhifu/Downloads/modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/example/asr_example.wav")
+speech, sample_rate = soundfile.read(os.path.expanduser('~')+
+                                     "/.cache/modelscope/hub/damo/"+
+                                     "speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/"+
+                                     "example/asr_example.wav")
-chunk_size = [0, 10, 5] #[0, 10, 5] 600ms, [0, 8, 4] 480ms
 chunk_stride = chunk_size[1] * 960 # 600ms, 480ms
-# first chunk, 600ms
 cache = {}
 for i in range(int(len((speech)-1)/chunk_stride+1)):
     speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
-    fbanks = frontend(input=speech_chunk,
-                      batch_size=2,
-                      cache=cache)
-    # for batch_idx, fbank_dict in enumerate(fbanks):
-    #     res = model(**fbank_dict)
-    #     print(res)
+    is_final = i == int(len((speech)-1)/chunk_stride+1)
+    res = model(input=speech_chunk,
+                cache=cache,
+                is_final=is_final,
+                chunk_size=chunk_size,
+                encoder_chunk_look_back=encoder_chunk_look_back,
+                decoder_chunk_look_back=decoder_chunk_look_back,
+                )
+    print(res)
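
For reference, a minimal end-to-end sketch of the streaming loop this demo introduces (not part of the commit). It assumes the same online model and cached example wav as above; the final-chunk test is written here as i == total_chunk_num - 1, because as extracted above the demo compares i against the loop bound itself, a value range() never yields.

import math
import os
import soundfile
from funasr import AutoModel

chunk_size = [0, 10, 5]            # [0, 10, 5] -> 600 ms chunks, [0, 8, 4] -> 480 ms
encoder_chunk_look_back = 4        # encoder self-attention looks back this many chunks
decoder_chunk_look_back = 1        # decoder cross-attention looks back this many encoder chunks

model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
                  model_revision="v2.0.0")

wav_path = os.path.expanduser(
    "~/.cache/modelscope/hub/damo/"
    "speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/example/asr_example.wav")
speech, sample_rate = soundfile.read(wav_path)

chunk_stride = chunk_size[1] * 960                 # samples per 600 ms chunk at 16 kHz
total_chunk_num = math.ceil(len(speech) / chunk_stride)

cache = {}                                          # carries encoder/decoder/frontend state across chunks
for i in range(total_chunk_num):
    speech_chunk = speech[i * chunk_stride:(i + 1) * chunk_stride]
    is_final = i == total_chunk_num - 1             # flush the tail on the last chunk
    res = model(input=speech_chunk,
                cache=cache,
                is_final=is_final,
                chunk_size=chunk_size,
                encoder_chunk_look_back=encoder_chunk_look_back,
                decoder_chunk_look_back=decoder_chunk_look_back)
    print(res)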

View File

@@ -1,14 +0,0 @@
-# download model
-local_path_root=../modelscope_models
-mkdir -p ${local_path_root}
-local_path=${local_path_root}/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
-git clone https://www.modelscope.cn/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch.git ${local_path}
-
-python funasr/bin/train.py \
-+model="../modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
-+token_list="../modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/tokens.txt" \
-+train_data_set_list="data/list/audio_datasets.jsonl" \
-+output_dir="outputs/debug/ckpt/funasr2/exp2" \
-+device="cpu"

View File

@@ -1,5 +1,5 @@
-model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"
 model_revision="v2.0.0"
 
 python funasr/bin/inference.py \

View File

@@ -205,7 +205,8 @@ class CifPredictorV2(nn.Module):
         return acoustic_embeds, token_num, alphas, cif_peak
 
-    def forward_chunk(self, hidden, cache=None):
+    def forward_chunk(self, hidden, cache=None, **kwargs):
+        is_final = kwargs.get("is_final", False)
         batch_size, len_time, hidden_size = hidden.shape
         h = hidden
         context = h.transpose(1, 2)
@@ -226,14 +227,14 @@ class CifPredictorV2(nn.Module):
         if cache is not None and "chunk_size" in cache:
             alphas[:, :cache["chunk_size"][0]] = 0.0
-            if "is_final" in cache and not cache["is_final"]:
+            if not is_final:
                 alphas[:, sum(cache["chunk_size"][:2]):] = 0.0
         if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
             cache["cif_hidden"] = to_device(cache["cif_hidden"], device=hidden.device)
             cache["cif_alphas"] = to_device(cache["cif_alphas"], device=alphas.device)
             hidden = torch.cat((cache["cif_hidden"], hidden), dim=1)
             alphas = torch.cat((cache["cif_alphas"], alphas), dim=1)
-        if cache is not None and "is_final" in cache and cache["is_final"]:
+        if cache is not None and is_final:
             tail_hidden = torch.zeros((batch_size, 1, hidden_size), device=hidden.device)
             tail_alphas = torch.tensor([[self.tail_threshold]], device=alphas.device)
             tail_alphas = torch.tile(tail_alphas, (batch_size, 1))
@@ -277,7 +278,7 @@ class CifPredictorV2(nn.Module):
         max_token_len = max(token_length)
         if max_token_len == 0:
-            return hidden, torch.stack(token_length, 0)
+            return hidden, torch.stack(token_length, 0), None, None
         list_ls = []
         for b in range(batch_size):
             pad_frames = torch.zeros((max_token_len - token_length[b], hidden_size), device=alphas.device)
@@ -291,7 +292,7 @@ class CifPredictorV2(nn.Module):
             cache["cif_alphas"] = torch.unsqueeze(cache["cif_alphas"], axis=0)
             cache["cif_hidden"] = torch.stack(cache_hiddens, axis=0)
             cache["cif_hidden"] = torch.unsqueeze(cache["cif_hidden"], axis=0)
-        return torch.stack(list_ls, 0), torch.stack(token_length, 0)
+        return torch.stack(list_ls, 0), torch.stack(token_length, 0), None, None
 
     def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
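
Read together, these hunks make two changes to CifPredictorV2.forward_chunk: the end-of-stream flag now arrives as an explicit is_final keyword argument instead of being looked up in the cache dict, and the chunk path now returns four values (padding with two Nones) so callers can unpack it exactly like the offline predictor. The frame masking itself is unchanged; below is a small stand-alone illustration of what it does, using plain tensors and the default chunk_size = [0, 10, 5], read here (an assumption, not stated in the diff) as [left context, current chunk, look-ahead] encoder frames.

import torch

chunk_size = [0, 10, 5]                      # assumed meaning: [left context, current chunk, look-ahead] frames
alphas = torch.ones(1, sum(chunk_size))      # stand-in CIF weights, one per frame in this chunk window

is_final = False
alphas[:, :chunk_size[0]] = 0.0              # never emit tokens from the left-context frames
if not is_final:
    alphas[:, sum(chunk_size[:2]):] = 0.0    # hold back the look-ahead frames until the stream ends

print(alphas)                                # weights survive only on the 10 "current" frames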

View File

@@ -64,8 +64,8 @@ class ParaformerStreaming(Paraformer):
         super().__init__(*args, **kwargs)
 
-        import pdb;
-        pdb.set_trace()
+        # import pdb;
+        # pdb.set_trace()
 
         self.sampling_ratio = kwargs.get("sampling_ratio", 0.2)
@@ -375,11 +375,10 @@ class ParaformerStreaming(Paraformer):
         return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index
 
-    def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None):
-        pre_acoustic_embeds, pre_token_length = \
-            self.predictor.forward_chunk(encoder_out, cache["encoder"])
-        return pre_acoustic_embeds, pre_token_length
+    def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None, **kwargs):
+        is_final = kwargs.get("is_final", False)
+        return self.predictor.forward_chunk(encoder_out, cache["encoder"], is_final=is_final)
 
     def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
         decoder_outs = self.decoder(
@@ -416,7 +415,7 @@ class ParaformerStreaming(Paraformer):
                          "chunk_size": chunk_size}
         cache["decoder"] = cache_decoder
         cache["frontend"] = {}
-        cache["prev_samples"] = []
+        cache["prev_samples"] = torch.empty(0)
 
         return cache
@@ -432,12 +431,12 @@ class ParaformerStreaming(Paraformer):
             speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
 
         # Encoder
-        encoder_out, encoder_out_lens = self.encode_chunk(speech, speech_lengths, cache=cache)
+        encoder_out, encoder_out_lens = self.encode_chunk(speech, speech_lengths, cache=cache, is_final=kwargs.get("is_final", False))
         if isinstance(encoder_out, tuple):
             encoder_out = encoder_out[0]
 
         # predictor
-        predictor_outs = self.calc_predictor_chunk(encoder_out, encoder_out_lens, cache=cache)
+        predictor_outs = self.calc_predictor_chunk(encoder_out, encoder_out_lens, cache=cache, is_final=kwargs.get("is_final", False))
         pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
                                                                         predictor_outs[2], predictor_outs[3]
         pre_token_length = pre_token_length.round().long()
@@ -476,10 +475,7 @@ class ParaformerStreaming(Paraformer):
             )
             nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
 
         for nbest_idx, hyp in enumerate(nbest_hyps):
-            ibest_writer = None
-            if ibest_writer is None and kwargs.get("output_dir") is not None:
-                writer = DatadirWriter(kwargs.get("output_dir"))
-                ibest_writer = writer[f"{nbest_idx + 1}best_recog"]
             # remove sos/eos and get results
             last_pos = -1
             if isinstance(hyp.yseq, list):
@@ -490,22 +486,15 @@ class ParaformerStreaming(Paraformer):
             # remove blank symbol id, which is assumed to be 0
             token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
 
-            if tokenizer is not None:
             # Change integer-ids to tokens
             token = tokenizer.ids2tokens(token_int)
-            text = tokenizer.tokens2text(token)
+            # text = tokenizer.tokens2text(token)
-            text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+            result_i = token
-            result_i = {"key": key[i], "text": text_postprocessed}
+            results.extend(result_i)
-            if ibest_writer is not None:
-                ibest_writer["token"][key[i]] = " ".join(token)
-                # ibest_writer["text"][key[i]] = text
-                ibest_writer["text"][key[i]] = text_postprocessed
-            else:
-                result_i = {"key": key[i], "token_int": token_int}
-                results.append(result_i)
 
         return results
@@ -515,6 +504,7 @@ class ParaformerStreaming(Paraformer):
         key: list = None,
         tokenizer=None,
         frontend=None,
+        cache: dict = {},
         **kwargs,
     ):
@@ -526,9 +516,10 @@ class ParaformerStreaming(Paraformer):
             self.init_beam_search(**kwargs)
             self.nbest = kwargs.get("nbest", 1)
 
-        cache = kwargs.get("cache", {})
         if len(cache) == 0:
             self.init_cache(cache, **kwargs)
+        _is_final = kwargs.get("is_final", False)
 
         meta_data = {}
         chunk_size = kwargs.get("chunk_size", [0, 10, 5])
@@ -542,22 +533,41 @@ class ParaformerStreaming(Paraformer):
         meta_data["load_data"] = f"{time2 - time1:0.3f}"
         assert len(audio_sample_list) == 1, "batch_size must be set 1"
 
-        audio_sample = cache["prev_samples"] + audio_sample_list[0]
-        n = len(audio_sample) // chunk_stride_samples
-        m = len(audio_sample) % chunk_stride_samples
+        audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))
+        n = len(audio_sample) // chunk_stride_samples + int(_is_final)
+        m = len(audio_sample) % chunk_stride_samples * (1 - int(_is_final))
+        tokens = []
         for i in range(n):
+            kwargs["is_final"] = _is_final and i == n - 1
             audio_sample_i = audio_sample[i*chunk_stride_samples:(i+1)*chunk_stride_samples]
 
             # extract fbank feats
             speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
-                                                   frontend=frontend, cache=cache["frontend"])
+                                                   frontend=frontend, cache=cache["frontend"], is_final=kwargs["is_final"])
             time3 = time.perf_counter()
             meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
             meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
 
-            result_i = self.generate_chunk(speech, speech_lengths, **kwargs)
+            tokens_i = self.generate_chunk(speech, speech_lengths, key=key, tokenizer=tokenizer, cache=cache, frontend=frontend, **kwargs)
+            tokens.extend(tokens_i)
+
+        text_postprocessed, _ = postprocess_utils.sentence_postprocess(tokens)
+        result_i = {"key": key[0], "text": text_postprocessed}
+        result = [result_i]
 
         cache["prev_samples"] = audio_sample[:-m]
+        if _is_final:
+            self.init_cache(cache, **kwargs)
+
+        if kwargs.get("output_dir"):
+            writer = DatadirWriter(kwargs.get("output_dir"))
+            ibest_writer = writer[f"{1}best_recog"]
+            ibest_writer["token"][key[0]] = " ".join(tokens)
+            ibest_writer["text"][key[0]] = text_postprocessed
+
+        return result, meta_data
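
The bookkeeping in inference() is the heart of the streaming change: leftover samples from the previous call are prepended, n counts the chunks to decode in this call (one extra, trailing partial chunk when is_final is set), and m is the remainder carried to the next call (forced to 0 on the final call). A tiny stand-alone check of that arithmetic with the default 600 ms stride; the sample counts are made up for illustration.

chunk_size = [0, 10, 5]
chunk_stride_samples = chunk_size[1] * 960       # 9600 samples = 600 ms at 16 kHz

def chunk_counts(num_samples: int, is_final: bool):
    # mirrors the n/m computation added in this hunk
    n = num_samples // chunk_stride_samples + int(is_final)
    m = num_samples % chunk_stride_samples * (1 - int(is_final))
    return n, m

print(chunk_counts(30000, False))   # (3, 1200): decode 3 full chunks, carry 1200 samples over
print(chunk_counts(30000, True))    # (4, 0): also decode the 1200-sample tail as the final chunk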

View File

@@ -423,7 +423,9 @@ class SANMEncoderChunkOpt(nn.Module):
         xs_pad: torch.Tensor,
         ilens: torch.Tensor,
         cache: dict = None,
+        **kwargs,
     ):
+        is_final = kwargs.get("is_final", False)
         xs_pad *= self.output_size() ** 0.5
         if self.embed is None:
             xs_pad = xs_pad

View File

@@ -43,7 +43,7 @@ def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs:
     elif isinstance(data_or_path_or_list, str) and data_type == "text" and tokenizer is not None:
         data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
     elif isinstance(data_or_path_or_list, np.ndarray):  # audio sample point
-        data_or_path_or_list = np.squeeze(data_or_path_or_list)  # [n_samples,]
+        data_or_path_or_list = torch.from_numpy(data_or_path_or_list).squeeze()  # [n_samples,]
     else:
         pass
         # print(f"unsupport data type: {data_or_path_or_list}, return raw data")