From 7584bbd6f3e321cc8bc970739a7cfce29ffcc18b Mon Sep 17 00:00:00 2001 From: "haoneng.lhn" Date: Thu, 27 Apr 2023 00:21:20 +0800 Subject: [PATCH] update paraformer streaming code --- .../bin/asr_inference_paraformer_streaming.py | 393 +++++------------- funasr/models/e2e_asr_paraformer.py | 4 +- funasr/models/encoder/sanm_encoder.py | 21 +- funasr/models/predictor/cif.py | 128 +++--- funasr/modules/embedding.py | 13 +- 5 files changed, 196 insertions(+), 363 deletions(-) diff --git a/funasr/bin/asr_inference_paraformer_streaming.py b/funasr/bin/asr_inference_paraformer_streaming.py index 821f69429..939ffe99f 100644 --- a/funasr/bin/asr_inference_paraformer_streaming.py +++ b/funasr/bin/asr_inference_paraformer_streaming.py @@ -19,7 +19,6 @@ from typing import List import numpy as np import torch -import torchaudio from typeguard import check_argument_types from funasr.fileio.datadir_writer import DatadirWriter @@ -40,11 +39,12 @@ from funasr.utils.types import str2bool from funasr.utils.types import str2triple_str from funasr.utils.types import str_or_none from funasr.utils import asr_utils, wav_utils, postprocess_utils -from funasr.models.frontend.wav_frontend import WavFrontend -from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer +from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export + np.set_printoptions(threshold=np.inf) + class Speech2Text: """Speech2Text class @@ -89,7 +89,7 @@ class Speech2Text: ) frontend = None if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: - frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf) + frontend = WavFrontendOnline(cmvn_file=cmvn_file, **asr_train_args.frontend_conf) logging.info("asr_model: {}".format(asr_model)) logging.info("asr_train_args: {}".format(asr_train_args)) @@ -189,8 +189,7 @@ class Speech2Text: @torch.no_grad() def __call__( - self, cache: dict, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None, - begin_time: int = 0, end_time: int = None, + self, cache: dict, speech: Union[torch.Tensor], speech_lengths: Union[torch.Tensor] = None ): """Inference @@ -201,38 +200,57 @@ class Speech2Text: """ assert check_argument_types() - - # Input as audio signal - if isinstance(speech, np.ndarray): - speech = torch.tensor(speech) - if self.frontend is not None: - feats, feats_len = self.frontend.forward(speech, speech_lengths) - feats = to_device(feats, device=self.device) - feats_len = feats_len.int() - self.asr_model.frontend = None + results = [] + cache_en = cache["encoder"] + if speech.shape[1] < 16 * 60 and cache["is_final"]: + cache["last_chunk"] = True + feats = cache["feats"] + feats_len = torch.tensor([feats.shape[1]]) else: - feats = speech - feats_len = speech_lengths - lfr_factor = max(1, (feats.size()[-1] // 80) - 1) - feats_len = cache["encoder"]["stride"] + cache["encoder"]["pad_left"] + cache["encoder"]["pad_right"] - feats = feats[:,cache["encoder"]["start_idx"]:cache["encoder"]["start_idx"]+feats_len,:] - feats_len = torch.tensor([feats_len]) - batch = {"speech": feats, "speech_lengths": feats_len, "cache": cache} + if self.frontend is not None: + feats, feats_len = self.frontend.forward(speech, speech_lengths, cache_en["is_final"]) + feats = to_device(feats, device=self.device) + feats_len = feats_len.int() + self.asr_model.frontend = None + else: + feats = speech + feats_len = speech_lengths - # a. To device + if feats.shape[1] != 0: + if cache_en["is_final"]: + if feats.shape[1] + cache_en["chunk_size"][2] < cache_en["chunk_size"][1]: + cache_en["last_chunk"] = True + else: + # first chunk + feats_chunk1 = feats[:, :cache_en["chunk_size"][1], :] + feats_len = torch.tensor([feats_chunk1.shape[1]]) + results_chunk1 = self.infer(feats_chunk1, feats_len, cache) + + # last chunk + cache_en["last_chunk"] = True + feats_chunk2 = feats[:, -(feats.shape[1] + cache_en["chunk_size"][2] - cache_en["chunk_size"][1]):, :] + feats_len = torch.tensor([feats_chunk2.shape[1]]) + results_chunk2 = self.infer(feats_chunk2, feats_len, cache) + + return results_chunk1 + results_chunk2 + + results = self.infer(feats, feats_len, cache) + + return results + + @torch.no_grad() + def infer(self, feats: Union[torch.Tensor], feats_len: Union[torch.Tensor], cache: List = None): + batch = {"speech": feats, "speech_lengths": feats_len} batch = to_device(batch, device=self.device) - # b. Forward Encoder - enc, enc_len = self.asr_model.encode_chunk(feats, feats_len, cache) + enc, enc_len = self.asr_model.encode_chunk(feats, feats_len, cache=cache) if isinstance(enc, tuple): enc = enc[0] # assert len(enc) == 1, len(enc) enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor predictor_outs = self.asr_model.calc_predictor_chunk(enc, cache) - pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \ - predictor_outs[2], predictor_outs[3] - pre_token_length = pre_token_length.floor().long() + pre_acoustic_embeds, pre_token_length= predictor_outs[0], predictor_outs[1] if torch.max(pre_token_length) < 1: return [] decoder_outs = self.asr_model.cal_decoder_with_predictor_chunk(enc, pre_acoustic_embeds, cache) @@ -279,166 +297,12 @@ class Speech2Text: text = self.tokenizer.tokens2text(token) else: text = None - - results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor)) + results.append(text) # assert check_return_type(results) return results -class Speech2TextExport: - """Speech2TextExport class - - """ - - def __init__( - self, - asr_train_config: Union[Path, str] = None, - asr_model_file: Union[Path, str] = None, - cmvn_file: Union[Path, str] = None, - lm_train_config: Union[Path, str] = None, - lm_file: Union[Path, str] = None, - token_type: str = None, - bpemodel: str = None, - device: str = "cpu", - maxlenratio: float = 0.0, - minlenratio: float = 0.0, - dtype: str = "float32", - beam_size: int = 20, - ctc_weight: float = 0.5, - lm_weight: float = 1.0, - ngram_weight: float = 0.9, - penalty: float = 0.0, - nbest: int = 1, - frontend_conf: dict = None, - hotword_list_or_file: str = None, - **kwargs, - ): - - # 1. Build ASR model - asr_model, asr_train_args = ASRTask.build_model_from_file( - asr_train_config, asr_model_file, cmvn_file, device - ) - frontend = None - if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: - frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf) - - logging.info("asr_model: {}".format(asr_model)) - logging.info("asr_train_args: {}".format(asr_train_args)) - asr_model.to(dtype=getattr(torch, dtype)).eval() - - token_list = asr_model.token_list - - logging.info(f"Decoding device={device}, dtype={dtype}") - - # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text - if token_type is None: - token_type = asr_train_args.token_type - if bpemodel is None: - bpemodel = asr_train_args.bpemodel - - if token_type is None: - tokenizer = None - elif token_type == "bpe": - if bpemodel is not None: - tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel) - else: - tokenizer = None - else: - tokenizer = build_tokenizer(token_type=token_type) - converter = TokenIDConverter(token_list=token_list) - logging.info(f"Text tokenizer: {tokenizer}") - - # self.asr_model = asr_model - self.asr_train_args = asr_train_args - self.converter = converter - self.tokenizer = tokenizer - - self.device = device - self.dtype = dtype - self.nbest = nbest - self.frontend = frontend - - model = Paraformer_export(asr_model, onnx=False) - self.asr_model = model - - @torch.no_grad() - def __call__( - self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None - ): - """Inference - - Args: - speech: Input speech data - Returns: - text, token, token_int, hyp - - """ - assert check_argument_types() - - # Input as audio signal - if isinstance(speech, np.ndarray): - speech = torch.tensor(speech) - - if self.frontend is not None: - feats, feats_len = self.frontend.forward(speech, speech_lengths) - feats = to_device(feats, device=self.device) - feats_len = feats_len.int() - self.asr_model.frontend = None - else: - feats = speech - feats_len = speech_lengths - - enc_len_batch_total = feats_len.sum() - lfr_factor = max(1, (feats.size()[-1] // 80) - 1) - batch = {"speech": feats, "speech_lengths": feats_len} - - # a. To device - batch = to_device(batch, device=self.device) - - decoder_outs = self.asr_model(**batch) - decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] - - results = [] - b, n, d = decoder_out.size() - for i in range(b): - am_scores = decoder_out[i, :ys_pad_lens[i], :] - - yseq = am_scores.argmax(dim=-1) - score = am_scores.max(dim=-1)[0] - score = torch.sum(score, dim=-1) - # pad with mask tokens to ensure compatibility with sos/eos tokens - yseq = torch.tensor( - yseq.tolist(), device=yseq.device - ) - nbest_hyps = [Hypothesis(yseq=yseq, score=score)] - - for hyp in nbest_hyps: - assert isinstance(hyp, (Hypothesis)), type(hyp) - - # remove sos/eos and get results - last_pos = -1 - if isinstance(hyp.yseq, list): - token_int = hyp.yseq[1:last_pos] - else: - token_int = hyp.yseq[1:last_pos].tolist() - - # remove blank symbol id, which is assumed to be 0 - token_int = list(filter(lambda x: x != 0 and x != 2, token_int)) - - # Change integer-ids to tokens - token = self.converter.ids2tokens(token_int) - - if self.tokenizer is not None: - text = self.tokenizer.tokens2text(token) - else: - text = None - - results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor)) - - return results - - def inference( maxlenratio: float, minlenratio: float, @@ -536,8 +400,6 @@ def inference_modelscope( **kwargs, ): assert check_argument_types() - ncpu = kwargs.get("ncpu", 1) - torch.set_num_threads(ncpu) if word_lm_train_config is not None: raise NotImplementedError("Word LM is not implemented") @@ -580,11 +442,9 @@ def inference_modelscope( penalty=penalty, nbest=nbest, ) - if export_mode: - speech2text = Speech2TextExport(**speech2text_kwargs) - else: - speech2text = Speech2Text(**speech2text_kwargs) - + + speech2text = Speech2Text(**speech2text_kwargs) + def _load_bytes(input): middle_data = np.frombuffer(input, dtype=np.int16) middle_data = np.asarray(middle_data) @@ -599,7 +459,33 @@ def inference_modelscope( offset = i.min + abs_max array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32) return array - + + def _prepare_cache(cache: dict = {}, chunk_size=[5,10,5], batch_size=1): + if len(cache) > 0: + return cache + + cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, 320)), + "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False, + "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560))} + cache["encoder"] = cache_en + + cache_de = {"decode_fsmn": None} + cache["decoder"] = cache_de + + return cache + + def _cache_reset(cache: dict = {}, chunk_size=[5,10,5], batch_size=1): + if len(cache) > 0: + cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, 320)), + "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False, + "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560))} + cache["encoder"] = cache_en + + cache_de = {"decode_fsmn": None} + cache["decoder"] = cache_de + + return cache + def _forward( data_path_and_name_and_type, raw_inputs: Union[np.ndarray, torch.Tensor] = None, @@ -610,123 +496,35 @@ def inference_modelscope( ): # 3. Build data-iterator + if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes": + raw_inputs = _load_bytes(data_path_and_name_and_type[0]) + raw_inputs = torch.tensor(raw_inputs) + if data_path_and_name_and_type is None and raw_inputs is not None: + if isinstance(raw_inputs, np.ndarray): + raw_inputs = torch.tensor(raw_inputs) is_final = False cache = {} + chunk_size = [5, 10, 5] if param_dict is not None and "cache" in param_dict: cache = param_dict["cache"] if param_dict is not None and "is_final" in param_dict: is_final = param_dict["is_final"] + if param_dict is not None and "chunk_size" in param_dict: + chunk_size = param_dict["chunk_size"] - if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes": - raw_inputs = _load_bytes(data_path_and_name_and_type[0]) - raw_inputs = torch.tensor(raw_inputs) - if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound": - raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0] - is_final = True - if data_path_and_name_and_type is None and raw_inputs is not None: - if isinstance(raw_inputs, np.ndarray): - raw_inputs = torch.tensor(raw_inputs) # 7 .Start for-loop # FIXME(kamo): The output format should be discussed about + raw_inputs = torch.unsqueeze(raw_inputs, axis=0) + input_lens = torch.tensor([raw_inputs.shape[1]]) asr_result_list = [] - results = [] - asr_result = "" - wait = True - if len(cache) == 0: - cache["encoder"] = {"start_idx": 0, "pad_left": 0, "stride": 10, "pad_right": 5, "cif_hidden": None, "cif_alphas": None, "is_final": is_final, "left": 0, "right": 0} - cache_de = {"decode_fsmn": None} - cache["decoder"] = cache_de - cache["first_chunk"] = True - cache["speech"] = [] - cache["accum_speech"] = 0 - if raw_inputs is not None: - if len(cache["speech"]) == 0: - cache["speech"] = raw_inputs - else: - cache["speech"] = torch.cat([cache["speech"], raw_inputs], dim=0) - cache["accum_speech"] += len(raw_inputs) - while cache["accum_speech"] >= 960: - if cache["first_chunk"]: - if cache["accum_speech"] >= 14400: - speech = torch.unsqueeze(cache["speech"], axis=0) - speech_length = torch.tensor([len(cache["speech"])]) - cache["encoder"]["pad_left"] = 5 - cache["encoder"]["pad_right"] = 5 - cache["encoder"]["stride"] = 10 - cache["encoder"]["left"] = 5 - cache["encoder"]["right"] = 0 - results = speech2text(cache, speech, speech_length) - cache["accum_speech"] -= 4800 - cache["first_chunk"] = False - cache["encoder"]["start_idx"] = -5 - cache["encoder"]["is_final"] = False - wait = False - else: - if is_final: - cache["encoder"]["stride"] = len(cache["speech"]) // 960 - cache["encoder"]["pad_left"] = 0 - cache["encoder"]["pad_right"] = 0 - speech = torch.unsqueeze(cache["speech"], axis=0) - speech_length = torch.tensor([len(cache["speech"])]) - results = speech2text(cache, speech, speech_length) - cache["accum_speech"] = 0 - wait = False - else: - break - else: - if cache["accum_speech"] >= 19200: - cache["encoder"]["start_idx"] += 10 - cache["encoder"]["stride"] = 10 - cache["encoder"]["pad_left"] = 5 - cache["encoder"]["pad_right"] = 5 - cache["encoder"]["left"] = 0 - cache["encoder"]["right"] = 0 - speech = torch.unsqueeze(cache["speech"], axis=0) - speech_length = torch.tensor([len(cache["speech"])]) - results = speech2text(cache, speech, speech_length) - cache["accum_speech"] -= 9600 - wait = False - else: - if is_final: - cache["encoder"]["is_final"] = True - if cache["accum_speech"] >= 14400: - cache["encoder"]["start_idx"] += 10 - cache["encoder"]["stride"] = 10 - cache["encoder"]["pad_left"] = 5 - cache["encoder"]["pad_right"] = 5 - cache["encoder"]["left"] = 0 - cache["encoder"]["right"] = cache["accum_speech"] // 960 - 15 - speech = torch.unsqueeze(cache["speech"], axis=0) - speech_length = torch.tensor([len(cache["speech"])]) - results = speech2text(cache, speech, speech_length) - cache["accum_speech"] -= 9600 - wait = False - else: - cache["encoder"]["start_idx"] += 10 - cache["encoder"]["stride"] = cache["accum_speech"] // 960 - 5 - cache["encoder"]["pad_left"] = 5 - cache["encoder"]["pad_right"] = 0 - cache["encoder"]["left"] = 0 - cache["encoder"]["right"] = 0 - speech = torch.unsqueeze(cache["speech"], axis=0) - speech_length = torch.tensor([len(cache["speech"])]) - results = speech2text(cache, speech, speech_length) - cache["accum_speech"] = 0 - wait = False - else: - break - - if len(results) >= 1: - asr_result += results[0][0] - if asr_result == "": - asr_result = "sil" - if wait: - asr_result = "waiting_for_more_voice" - item = {'key': "utt", 'value': asr_result} - asr_result_list.append(item) - else: - return [] + cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1) + cache["encoder"]["is_final"] = is_final + asr_result = speech2text(cache, raw_inputs, input_lens) + item = {'key': "utt", 'value': asr_result} + asr_result_list.append(item) + if is_final: + cache = _cache_reset(cache, chunk_size=chunk_size, batch_size=1) return asr_result_list return _forward @@ -921,4 +719,3 @@ if __name__ == "__main__": # rec_result = inference_16k_pipline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav') # print(rec_result) - diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py index 699d85fdb..d02783f49 100644 --- a/funasr/models/e2e_asr_paraformer.py +++ b/funasr/models/e2e_asr_paraformer.py @@ -712,9 +712,9 @@ class ParaformerOnline(Paraformer): def calc_predictor_chunk(self, encoder_out, cache=None): - pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = \ + pre_acoustic_embeds, pre_token_length = \ self.predictor.forward_chunk(encoder_out, cache["encoder"]) - return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index + return pre_acoustic_embeds, pre_token_length def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None): decoder_outs = self.decoder.forward_chunk( diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py index f2502bbb6..7d84ad5f1 100644 --- a/funasr/models/encoder/sanm_encoder.py +++ b/funasr/models/encoder/sanm_encoder.py @@ -6,9 +6,11 @@ from typing import Union import logging import torch import torch.nn as nn +import torch.nn.functional as F from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk from typeguard import check_argument_types import numpy as np +from funasr.torch_utils.device_funcs import to_device from funasr.modules.nets_utils import make_pad_mask from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask from funasr.modules.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder @@ -349,6 +351,23 @@ class SANMEncoder(AbsEncoder): return (xs_pad, intermediate_outs), olens, None return xs_pad, olens, None + def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}): + if len(cache) == 0: + return feats + # process last chunk + cache["feats"] = to_device(cache["feats"], device=feats.device) + overlap_feats = torch.cat((cache["feats"], feats), dim=1) + if cache["is_final"]: + cache["feats"] = overlap_feats[:, -cache["chunk_size"][0]:, :] + if not cache["last_chunk"]: + padding_length = sum(cache["chunk_size"]) - overlap_feats.shape[1] + overlap_feats = overlap_feats.transpose(1, 2) + overlap_feats = F.pad(overlap_feats, (0, padding_length)) + overlap_feats = overlap_feats.transpose(1, 2) + else: + cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :] + return overlap_feats + def forward_chunk(self, xs_pad: torch.Tensor, ilens: torch.Tensor, @@ -360,7 +379,7 @@ class SANMEncoder(AbsEncoder): xs_pad = xs_pad else: xs_pad = self.embed(xs_pad, cache) - + xs_pad = self._add_overlap_chunk(xs_pad, cache) encoder_outs = self.encoders0(xs_pad, None, None, None, None) xs_pad, masks = encoder_outs[0], encoder_outs[1] intermediate_outs = [] diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py index a5273f841..c59e24502 100644 --- a/funasr/models/predictor/cif.py +++ b/funasr/models/predictor/cif.py @@ -2,6 +2,7 @@ import torch from torch import nn import logging import numpy as np +from funasr.torch_utils.device_funcs import to_device from funasr.modules.nets_utils import make_pad_mask from funasr.modules.streaming_utils.utils import sequence_mask @@ -200,7 +201,7 @@ class CifPredictorV2(nn.Module): return acoustic_embeds, token_num, alphas, cif_peak def forward_chunk(self, hidden, cache=None): - b, t, d = hidden.size() + batch_size, len_time, hidden_size = hidden.shape h = hidden context = h.transpose(1, 2) queries = self.pad(context) @@ -211,58 +212,81 @@ class CifPredictorV2(nn.Module): alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold) alphas = alphas.squeeze(-1) - mask_chunk_predictor = None - if cache is not None: - mask_chunk_predictor = None - mask_chunk_predictor = torch.zeros_like(alphas) - mask_chunk_predictor[:, cache["pad_left"]:cache["stride"] + cache["pad_left"]] = 1.0 - - if mask_chunk_predictor is not None: - alphas = alphas * mask_chunk_predictor - - if cache is not None: - if cache["is_final"]: - alphas[:, cache["stride"] + cache["pad_left"] - 1] += 0.45 - if cache["cif_hidden"] is not None: - hidden = torch.cat((cache["cif_hidden"], hidden), 1) - if cache["cif_alphas"] is not None: - alphas = torch.cat((cache["cif_alphas"], alphas), -1) - token_num = alphas.sum(-1) - acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold) - len_time = alphas.size(-1) - last_fire_place = len_time - 1 - last_fire_remainds = 0.0 - pre_alphas_length = 0 - last_fire = False - - mask_chunk_peak_predictor = None - if cache is not None: - mask_chunk_peak_predictor = None - mask_chunk_peak_predictor = torch.zeros_like(cif_peak) - if cache["cif_alphas"] is not None: - pre_alphas_length = cache["cif_alphas"].size(-1) - mask_chunk_peak_predictor[:, :pre_alphas_length] = 1.0 - mask_chunk_peak_predictor[:, pre_alphas_length + cache["pad_left"]:pre_alphas_length + cache["stride"] + cache["pad_left"]] = 1.0 - - if mask_chunk_peak_predictor is not None: - cif_peak = cif_peak * mask_chunk_peak_predictor.squeeze(-1) - - for i in range(len_time): - if cif_peak[0][len_time - 1 - i] > self.threshold or cif_peak[0][len_time - 1 - i] == self.threshold: - last_fire_place = len_time - 1 - i - last_fire_remainds = cif_peak[0][len_time - 1 - i] - self.threshold - last_fire = True - break - if last_fire: - last_fire_remainds = torch.tensor([last_fire_remainds], dtype=alphas.dtype).to(alphas.device) - cache["cif_hidden"] = hidden[:, last_fire_place:, :] - cache["cif_alphas"] = torch.cat((last_fire_remainds.unsqueeze(0), alphas[:, last_fire_place+1:]), -1) - else: - cache["cif_hidden"] = hidden - cache["cif_alphas"] = alphas - token_num_int = token_num.floor().type(torch.int32).item() - return acoustic_embeds[:, 0:token_num_int, :], token_num, alphas, cif_peak + token_length = [] + list_fires = [] + list_frames = [] + cache_alphas = [] + cache_hiddens = [] + + if cache is not None and "chunk_size" in cache: + alphas[:, :cache["chunk_size"][0]] = 0.0 + alphas[:, sum(cache["chunk_size"][:2]):] = 0.0 + if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache: + cache["cif_hidden"] = to_device(cache["cif_hidden"], device=hidden.device) + cache["cif_alphas"] = to_device(cache["cif_alphas"], device=alphas.device) + hidden = torch.cat((cache["cif_hidden"], hidden), dim=1) + alphas = torch.cat((cache["cif_alphas"], alphas), dim=1) + if cache is not None and "last_chunk" in cache and cache["last_chunk"]: + tail_hidden = torch.zeros((batch_size, 1, hidden_size), device=hidden.device) + tail_alphas = torch.tensor([[self.tail_threshold]], device=alphas.device) + tail_alphas = torch.tile(tail_alphas, (batch_size, 1)) + hidden = torch.cat((hidden, tail_hidden), dim=1) + alphas = torch.cat((alphas, tail_alphas), dim=1) + + len_time = alphas.shape[1] + for b in range(batch_size): + integrate = 0.0 + frames = torch.zeros((hidden_size), device=hidden.device) + list_frame = [] + list_fire = [] + for t in range(len_time): + alpha = alphas[b][t] + if alpha + integrate < self.threshold: + integrate += alpha + list_fire.append(integrate) + frames += alpha * hidden[b][t] + else: + frames += (self.threshold - integrate) * hidden[b][t] + list_frame.append(frames) + integrate += alpha + list_fire.append(integrate) + integrate -= self.threshold + frames = integrate * hidden[b][t] + + cache_alphas.append(integrate) + if integrate > 0.0: + cache_hiddens.append(frames / integrate) + else: + cache_hiddens.append(frames) + + token_length.append(torch.tensor(len(list_frame), device=alphas.device)) + list_fires.append(list_fire) + list_frames.append(list_frame) + + cache["cif_alphas"] = torch.stack(cache_alphas, axis=0) + cache["cif_alphas"] = torch.unsqueeze(cache["cif_alphas"], axis=0) + cache["cif_hidden"] = torch.stack(cache_hiddens, axis=0) + cache["cif_hidden"] = torch.unsqueeze(cache["cif_hidden"], axis=0) + + max_token_len = max(token_length) + if max_token_len == 0: + return hidden, torch.stack(token_length, 0) + list_ls = [] + for b in range(batch_size): + pad_frames = torch.zeros((max_token_len - token_length[b], hidden_size), device=alphas.device) + if token_length[b] == 0: + list_ls.append(pad_frames) + else: + list_frames[b] = torch.stack(list_frames[b]) + list_ls.append(torch.cat((list_frames[b], pad_frames), dim=0)) + + cache["cif_alphas"] = torch.stack(cache_alphas, axis=0) + cache["cif_alphas"] = torch.unsqueeze(cache["cif_alphas"], axis=0) + cache["cif_hidden"] = torch.stack(cache_hiddens, axis=0) + cache["cif_hidden"] = torch.unsqueeze(cache["cif_hidden"], axis=0) + return torch.stack(list_ls, 0), torch.stack(token_length, 0) + def tail_process_fn(self, hidden, alphas, token_num=None, mask=None): b, t, d = hidden.size() diff --git a/funasr/modules/embedding.py b/funasr/modules/embedding.py index c347e24f1..aaac80a7d 100644 --- a/funasr/modules/embedding.py +++ b/funasr/modules/embedding.py @@ -425,21 +425,14 @@ class StreamSinusoidalPositionEncoder(torch.nn.Module): return encoding.type(dtype) def forward(self, x, cache=None): - start_idx = 0 - pad_left = 0 - pad_right = 0 batch_size, timesteps, input_dim = x.size() + start_idx = 0 if cache is not None: start_idx = cache["start_idx"] - pad_left = cache["left"] - pad_right = cache["right"] + cache["start_idx"] += timesteps positions = torch.arange(1, timesteps+start_idx+1)[None, :] position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device) - outputs = x + position_encoding[:, start_idx: start_idx + timesteps] - outputs = outputs.transpose(1, 2) - outputs = F.pad(outputs, (pad_left, pad_right)) - outputs = outputs.transpose(1, 2) - return outputs + return x + position_encoding[:, start_idx: start_idx + timesteps] class StreamingRelPositionalEncoding(torch.nn.Module): """Relative positional encoding.