Mirror of https://github.com/modelscope/FunASR, synced 2025-09-15 14:48:36 +08:00
funasr1.0 paraformer_streaming
parent d342c642fa
commit 47088b8d1e
@@ -375,7 +375,7 @@ class ParaformerStreaming(Paraformer):
         return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index

-    def calc_predictor_chunk(self, encoder_out, cache=None):
+    def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None):
         pre_acoustic_embeds, pre_token_length = \
             self.predictor.forward_chunk(encoder_out, cache["encoder"])
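The predictor behind `forward_chunk` is a CIF (continuous integrate-and-fire) module, so streaming it means carrying partially accumulated acoustic weight across chunk boundaries; that state lives in the `cif_hidden` / `cif_alphas` entries that `init_cache` creates below. A minimal sketch of the carry-over idea, not FunASR's actual implementation, assuming batch size 1 and a simplified firing rule (real CIF splits the frame weight exactly at the threshold crossing):

import torch

def cif_forward_chunk(hidden, alphas, cache, threshold=1.0):
    # Prepend the state left over from the previous chunk (batch size 1 assumed).
    hidden = torch.cat([cache["cif_hidden"], hidden], dim=1)   # (1, T+1, D)
    alphas = torch.cat([cache["cif_alphas"], alphas], dim=1)   # (1, T+1)
    fired, acc = [], 0.0
    frame = torch.zeros_like(hidden[:, 0])                     # running weighted sum, (1, D)
    for t in range(hidden.size(1)):
        a = alphas[0, t].item()
        acc += a
        frame = frame + a * hidden[:, t]
        if acc >= threshold:   # enough acoustic weight accumulated: emit one token embedding
            fired.append(frame)
            acc -= threshold
            frame = torch.zeros_like(frame)
    # Stash the un-fired remainder so the next chunk keeps integrating from here.
    cache["cif_hidden"] = frame.unsqueeze(1)
    cache["cif_alphas"] = torch.tensor([[acc]])
    return torch.stack(fired, dim=1) if fired else hidden.new_zeros(1, 0, hidden.size(2))

cache = {"cif_hidden": torch.zeros(1, 1, 4), "cif_alphas": torch.zeros(1, 1)}
token_embeds = cif_forward_chunk(torch.randn(1, 10, 4), torch.rand(1, 10), cache)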
@@ -389,48 +389,72 @@ class ParaformerStreaming(Paraformer):
         decoder_out = torch.log_softmax(decoder_out, dim=-1)
         return decoder_out, ys_pad_lens

-    def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
+    def cal_decoder_with_predictor_chunk(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, cache=None):
         decoder_outs = self.decoder.forward_chunk(
             encoder_out, sematic_embeds, cache["decoder"]
         )
         decoder_out = decoder_outs
         decoder_out = torch.log_softmax(decoder_out, dim=-1)
-        return decoder_out
+        return decoder_out, ys_pad_lens

-    def generate(self,
-                 speech: torch.Tensor,
-                 speech_lengths: torch.Tensor,
-                 tokenizer=None,
-                 **kwargs,
-                 ):
-        is_use_ctc = kwargs.get("ctc_weight", 0.0) > 0.00001 and self.ctc != None
-        print(is_use_ctc)
-        is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
-        if self.beam_search is None and (is_use_lm or is_use_ctc):
-            logging.info("enable beam_search")
-            self.init_beam_search(speech, speech_lengths, **kwargs)
-            self.nbest = kwargs.get("nbest", 1)
+    def init_cache(self, cache: dict = {}, **kwargs):
+        chunk_size = kwargs.get("chunk_size", [0, 10, 5])
+        encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
+        decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
+        batch_size = 1
+
+        enc_output_size = kwargs["encoder_conf"]["output_size"]
+        feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
+        cache_encoder = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
+                         "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size,
+                         "encoder_chunk_look_back": encoder_chunk_look_back, "last_chunk": False, "opt": None,
+                         "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
+                         "tail_chunk": False}
+        cache["encoder"] = cache_encoder
+
+        cache_decoder = {"decode_fsmn": None, "decoder_chunk_look_back": decoder_chunk_look_back, "opt": None,
+                         "chunk_size": chunk_size}
+        cache["decoder"] = cache_decoder
+        cache["frontend"] = {}
+        cache["prev_samples"] = []
+
+        return cache

     def generate_chunk(self,
                        speech,
                        speech_lengths=None,
                        key: list = None,
                        tokenizer=None,
                        frontend=None,
                        **kwargs,
                        ):
         cache = kwargs.get("cache", {})
         speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])

-        # Forward Encoder
-        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        # Encoder
+        encoder_out, encoder_out_lens = self.encode_chunk(speech, speech_lengths, cache=cache)
         if isinstance(encoder_out, tuple):
             encoder_out = encoder_out[0]

         # predictor
-        predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
+        predictor_outs = self.calc_predictor_chunk(encoder_out, encoder_out_lens, cache=cache)
         pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
                                                                         predictor_outs[2], predictor_outs[3]
         pre_token_length = pre_token_length.round().long()
         if torch.max(pre_token_length) < 1:
             return []
-        decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens, pre_acoustic_embeds,
-                                                       pre_token_length)
+        decoder_outs = self.cal_decoder_with_predictor_chunk(encoder_out,
+                                                             encoder_out_lens,
+                                                             pre_acoustic_embeds,
+                                                             pre_token_length,
+                                                             cache=cache
+                                                             )
         decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]

         results = []
         b, n, d = decoder_out.size()
         if isinstance(key[0], (list, tuple)):
             key = key[0]
         for i in range(b):
             x = encoder_out[i, :encoder_out_lens[i], :]
             am_scores = decoder_out[i, :pre_token_length[i], :]
@@ -451,9 +475,11 @@ class ParaformerStreaming(Paraformer):
                     [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
                 )
                 nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
             for hyp in nbest_hyps:
                 assert isinstance(hyp, (Hypothesis)), type(hyp)

             for nbest_idx, hyp in enumerate(nbest_hyps):
                 ibest_writer = None
                 if ibest_writer is None and kwargs.get("output_dir") is not None:
                     writer = DatadirWriter(kwargs.get("output_dir"))
                     ibest_writer = writer[f"{nbest_idx + 1}best_recog"]
                 # remove sos/eos and get results
                 last_pos = -1
                 if isinstance(hyp.yseq, list):
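When beam search is disabled, the hunk above falls back to a greedy pick over `am_scores` (the `(pre_token_length[i], vocab_size)` slice of log-posteriors computed in `generate_chunk`) and wraps it in a single `Hypothesis`. A minimal sketch of that greedy step, mirroring the offline Paraformer path; the stand-in tensor shape is an assumption:

import torch

# Stand-in for the (pre_token_length[i], vocab_size) log-posterior slice.
am_scores = torch.log_softmax(torch.randn(4, 100), dim=-1)

yseq = am_scores.argmax(dim=-1)           # best token id at each CIF firing point
score = am_scores.max(dim=-1)[0].sum()    # log-probability of the greedy path
# The hunk above then brackets yseq with sos/eos before filtering them back out.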
@@ -462,15 +488,76 @@ class ParaformerStreaming(Paraformer):
                     token_int = hyp.yseq[1:last_pos].tolist()

                 # remove blank symbol id, which is assumed to be 0
-                token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
+                token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))

-                # Change integer-ids to tokens
-                token = tokenizer.ids2tokens(token_int)
-                text = tokenizer.tokens2text(token)
-
-                timestamp = []
-
-                results.append((text, token, timestamp))
+                if tokenizer is not None:
+                    # Change integer-ids to tokens
+                    token = tokenizer.ids2tokens(token_int)
+                    text = tokenizer.tokens2text(token)
+
+                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+
+                    result_i = {"key": key[i], "text": text_postprocessed}
+
+                    if ibest_writer is not None:
+                        ibest_writer["token"][key[i]] = " ".join(token)
+                        # ibest_writer["text"][key[i]] = text
+                        ibest_writer["text"][key[i]] = text_postprocessed
+                else:
+                    result_i = {"key": key[i], "token_int": token_int}
+                results.append(result_i)

         return results

+    def generate(self,
+                 data_in,
+                 data_lengths=None,
+                 key: list = None,
+                 tokenizer=None,
+                 frontend=None,
+                 **kwargs,
+                 ):
+        # init beamsearch
+        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
+        is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+        if self.beam_search is None and (is_use_lm or is_use_ctc):
+            logging.info("enable beam_search")
+            self.init_beam_search(**kwargs)
+            self.nbest = kwargs.get("nbest", 1)
+
+        cache = kwargs.get("cache", {})
+        if len(cache) == 0:
+            self.init_cache(cache, **kwargs)
+
+        meta_data = {}
+        chunk_size = kwargs.get("chunk_size", [0, 10, 5])
+        chunk_stride_samples = chunk_size[1] * 960  # 600ms
+
+        time1 = time.perf_counter()
+        audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
+                                                        data_type=kwargs.get("data_type", "sound"),
+                                                        tokenizer=tokenizer)
+        time2 = time.perf_counter()
+        meta_data["load_data"] = f"{time2 - time1:0.3f}"
+        assert len(audio_sample_list) == 1, "batch_size must be set 1"
+
+        audio_sample = cache["prev_samples"] + audio_sample_list[0]
+
+        n = len(audio_sample) // chunk_stride_samples
+        m = len(audio_sample) % chunk_stride_samples
+        for i in range(n):
+            audio_sample_i = audio_sample[i * chunk_stride_samples:(i + 1) * chunk_stride_samples]
+
+            # extract fbank feats
+            speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
+                                                   frontend=frontend, cache=cache["frontend"])
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+
+            result_i = self.generate_chunk(speech, speech_lengths, **kwargs)
+
+        cache["prev_samples"] = audio_sample[:-m]
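Putting the new pieces together: `generate` buffers raw samples across calls in `cache["prev_samples"]`, slices them into strides of `chunk_size[1] * 960` samples (10 encoder frames at 60 ms per LFR-stacked frame at 16 kHz, i.e. 9600 samples = 600 ms), and pushes each stride through the fbank frontend and `generate_chunk`, with encoder, predictor, and decoder state all threaded through `cache`. The FunASR README documents driving the resulting streaming model roughly like this; the wav path and look-back values here are illustrative:

import soundfile
from funasr import AutoModel

chunk_size = [0, 10, 5]          # [look-back, current, look-ahead] in encoder frames -> 600 ms steps
model = AutoModel(model="paraformer-zh-streaming")

speech, sample_rate = soundfile.read("asr_example.wav")   # illustrative 16 kHz mono wav
chunk_stride = chunk_size[1] * 960                        # 9600 samples = 600 ms

cache = {}   # persists across calls; init_cache fills it on the first chunk
for i in range(0, len(speech), chunk_stride):
    res = model.generate(input=speech[i:i + chunk_stride], cache=cache,
                         is_final=i + chunk_stride >= len(speech),
                         chunk_size=chunk_size,
                         encoder_chunk_look_back=4, decoder_chunk_look_back=1)
    print(res)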
@@ -13,7 +13,7 @@ def get_readme():

 MODULE_NAME = 'funasr_onnx'
-VERSION_NUM = '0.2.4'
+VERSION_NUM = '0.2.5'

 setuptools.setup(
     name=MODULE_NAME,