funasr1.0 paraformer_streaming

This commit is contained in:
游雁 2024-01-11 00:09:36 +08:00
parent d342c642fa
commit 47088b8d1e
2 changed files with 121 additions and 34 deletions

View File

@ -375,7 +375,7 @@ class ParaformerStreaming(Paraformer):
return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index
def calc_predictor_chunk(self, encoder_out, cache=None):
def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None):
pre_acoustic_embeds, pre_token_length = \
self.predictor.forward_chunk(encoder_out, cache["encoder"])
@ -389,48 +389,72 @@ class ParaformerStreaming(Paraformer):
decoder_out = torch.log_softmax(decoder_out, dim=-1)
return decoder_out, ys_pad_lens
def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
def cal_decoder_with_predictor_chunk(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, cache=None):
decoder_outs = self.decoder.forward_chunk(
encoder_out, sematic_embeds, cache["decoder"]
)
decoder_out = decoder_outs
decoder_out = torch.log_softmax(decoder_out, dim=-1)
return decoder_out
return decoder_out, ys_pad_lens
def init_cache(self, cache: dict = {}, **kwargs):
chunk_size = kwargs.get("chunk_size", [0, 10, 5])
encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
batch_size = 1
def generate(self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
tokenizer=None,
**kwargs,
):
enc_output_size = kwargs["encoder_conf"]["output_size"]
feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
cache_encoder = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
"cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size,
"encoder_chunk_look_back": encoder_chunk_look_back, "last_chunk": False, "opt": None,
"feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
"tail_chunk": False}
cache["encoder"] = cache_encoder
is_use_ctc = kwargs.get("ctc_weight", 0.0) > 0.00001 and self.ctc != None
print(is_use_ctc)
is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
cache_decoder = {"decode_fsmn": None, "decoder_chunk_look_back": decoder_chunk_look_back, "opt": None,
"chunk_size": chunk_size}
cache["decoder"] = cache_decoder
cache["frontend"] = {}
cache["prev_samples"] = []
if self.beam_search is None and (is_use_lm or is_use_ctc):
logging.info("enable beam_search")
self.init_beam_search(speech, speech_lengths, **kwargs)
self.nbest = kwargs.get("nbest", 1)
return cache
def generate_chunk(self,
speech,
speech_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
cache = kwargs.get("cache", {})
speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
# Forward Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
# Encoder
encoder_out, encoder_out_lens = self.encode_chunk(speech, speech_lengths, cache=cache)
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
# predictor
predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
predictor_outs = self.calc_predictor_chunk(encoder_out, encoder_out_lens, cache=cache)
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
predictor_outs[2], predictor_outs[3]
pre_token_length = pre_token_length.round().long()
if torch.max(pre_token_length) < 1:
return []
decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens, pre_acoustic_embeds,
pre_token_length)
decoder_outs = self.cal_decoder_with_predictor_chunk(encoder_out,
encoder_out_lens,
pre_acoustic_embeds,
pre_token_length,
cache=cache
)
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
results = []
b, n, d = decoder_out.size()
if isinstance(key[0], (list, tuple)):
key = key[0]
for i in range(b):
x = encoder_out[i, :encoder_out_lens[i], :]
am_scores = decoder_out[i, :pre_token_length[i], :]
@ -451,9 +475,11 @@ class ParaformerStreaming(Paraformer):
[self.sos] + yseq.tolist() + [self.eos], device=yseq.device
)
nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
for hyp in nbest_hyps:
assert isinstance(hyp, (Hypothesis)), type(hyp)
for nbest_idx, hyp in enumerate(nbest_hyps):
ibest_writer = None
if ibest_writer is None and kwargs.get("output_dir") is not None:
writer = DatadirWriter(kwargs.get("output_dir"))
ibest_writer = writer[f"{nbest_idx + 1}best_recog"]
# remove sos/eos and get results
last_pos = -1
if isinstance(hyp.yseq, list):
@ -462,15 +488,76 @@ class ParaformerStreaming(Paraformer):
token_int = hyp.yseq[1:last_pos].tolist()
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
# Change integer-ids to tokens
token = tokenizer.ids2tokens(token_int)
text = tokenizer.tokens2text(token)
timestamp = []
results.append((text, token, timestamp))
if tokenizer is not None:
# Change integer-ids to tokens
token = tokenizer.ids2tokens(token_int)
text = tokenizer.tokens2text(token)
text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
result_i = {"key": key[i], "text": text_postprocessed}
if ibest_writer is not None:
ibest_writer["token"][key[i]] = " ".join(token)
# ibest_writer["text"][key[i]] = text
ibest_writer["text"][key[i]] = text_postprocessed
else:
result_i = {"key": key[i], "token_int": token_int}
results.append(result_i)
return results
def generate(self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
# init beamsearch
is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
if self.beam_search is None and (is_use_lm or is_use_ctc):
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
cache = kwargs.get("cache", {})
if len(cache) == 0:
self.init_cache(cache, **kwargs)
meta_data = {}
chunk_size = kwargs.get("chunk_size", [0, 10, 5])
chunk_stride_samples = chunk_size[1] * 960 # 600ms
time1 = time.perf_counter()
audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
data_type=kwargs.get("data_type", "sound"),
tokenizer=tokenizer)
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
assert len(audio_sample_list) == 1, "batch_size must be set 1"
audio_sample = cache["prev_samples"] + audio_sample_list[0]
n = len(audio_sample) // chunk_stride_samples
m = len(audio_sample) % chunk_stride_samples
for i in range(n):
audio_sample_i = audio_sample[i*chunk_stride_samples:(i+1)*chunk_stride_samples]
# extract fbank feats
speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
frontend=frontend, cache=cache["frontend"])
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
result_i = self.generate_chunk(speech, speech_lengths, **kwargs)
cache["prev_samples"] = audio_sample[:-m]

View File

@ -13,7 +13,7 @@ def get_readme():
MODULE_NAME = 'funasr_onnx'
VERSION_NUM = '0.2.4'
VERSION_NUM = '0.2.5'
setuptools.setup(
name=MODULE_NAME,