Mirror of https://github.com/modelscope/FunASR, synced 2025-09-15 14:48:36 +08:00
Merge pull request #932 from alibaba-damo-academy/dev_lhn
Support chunk-size selection for the chunk-hopping encoder
This commit is contained in:
commit 9fcb3cc06b
@@ -399,7 +399,7 @@ class Speech2TextParaformer:
     @torch.no_grad()
     def __call__(
         self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
-        begin_time: int = 0, end_time: int = None,
+        decoding_ind: int = None, begin_time: int = 0, end_time: int = None,
     ):
         """Inference

@@ -429,7 +429,9 @@ class Speech2TextParaformer:
         batch = to_device(batch, device=self.device)

         # b. Forward Encoder
-        enc, enc_len = self.asr_model.encode(**batch, ind=self.decoding_ind)
+        if decoding_ind is None:
+            decoding_ind = self.decoding_ind
+        enc, enc_len = self.asr_model.encode(**batch, ind=decoding_ind)
         if isinstance(enc, tuple):
             enc = enc[0]
         # assert len(enc) == 1, len(enc)

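The two hunks above let callers override the chunk-size index per call: a decoding_ind argument passed to __call__ takes precedence, and the instance-level self.decoding_ind set at construction time remains the fallback. A minimal sketch of that pattern in isolation; the class below is a made-up stand-in, not FunASR code (the real encoder call is self.asr_model.encode(**batch, ind=decoding_ind)):

class ChunkSelectSketch:
    def __init__(self, decoding_ind: int = 0):
        # default chunk-size index, fixed at construction time
        self.decoding_ind = decoding_ind

    def __call__(self, speech, decoding_ind: int = None):
        # per-call argument wins; otherwise fall back to the instance default
        if decoding_ind is None:
            decoding_ind = self.decoding_ind
        return self.encode(speech, ind=decoding_ind)

    def encode(self, speech, ind: int):
        # stand-in for self.asr_model.encode(**batch, ind=...)
        return f"encoded {speech!r} with chunk-size index {ind}"

With s = ChunkSelectSketch(decoding_ind=0), s("utt1") encodes with index 0, while s("utt1", decoding_ind=1) switches the chunk size for that call only.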
@@ -236,6 +236,7 @@ def inference_paraformer(
     timestamp_infer_config: Union[Path, str] = None,
     timestamp_model_file: Union[Path, str] = None,
     param_dict: dict = None,
+    decoding_ind: int = 0,
     **kwargs,
 ):
     ncpu = kwargs.get("ncpu", 1)
@@ -290,6 +291,7 @@ def inference_paraformer(
         nbest=nbest,
         hotword_list_or_file=hotword_list_or_file,
         clas_scale=clas_scale,
+        decoding_ind=decoding_ind,
     )

     speech2text = Speech2TextParaformer(**speech2text_kwargs)
@@ -312,6 +314,7 @@ def inference_paraformer(
         **kwargs,
     ):

+        decoding_ind = None
         hotword_list_or_file = None
         if param_dict is not None:
             hotword_list_or_file = param_dict.get('hotword')
@@ -319,6 +322,8 @@ def inference_paraformer(
             hotword_list_or_file = kwargs['hotword']
         if hotword_list_or_file is not None or 'hotword' in kwargs:
             speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
+        if param_dict is not None and "decoding_ind" in param_dict:
+            decoding_ind = param_dict["decoding_ind"]

         # 3. Build data-iterator
         if data_path_and_name_and_type is None and raw_inputs is not None:
@@ -365,6 +370,7 @@ def inference_paraformer(
             # N-best list of (text, token, token_int, hyp_object)

             time_beg = time.time()
+            batch["decoding_ind"] = decoding_ind
             results = speech2text(**batch)
             if len(results) < 1:
                 hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])

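Inside the launcher, the index now arrives from two places: the decoding_ind keyword of inference_paraformer, which is baked into the Speech2TextParaformer constructor as the default, and an optional per-run param_dict entry, which is injected into each batch before decoding. A short sketch of the resulting precedence with hypothetical values; the param_dict handling mirrors the diff, everything else is a placeholder:

decoding_ind = None                       # per-run value starts unset
param_dict = {"decoding_ind": 1}          # hypothetical caller input

if param_dict is not None and "decoding_ind" in param_dict:
    decoding_ind = param_dict["decoding_ind"]

batch = {"speech": ..., "speech_lengths": ...}   # placeholder batch contents
batch["decoding_ind"] = decoding_ind
# speech2text(**batch) would now decode with index 1; had param_dict omitted
# the key, decoding_ind would stay None and __call__ would fall back to the
# constructor-level default, as in the earlier sketch.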
@@ -1786,6 +1792,12 @@ def get_parser():
         default=1,
         help="The batch size for inference",
     )
+    group.add_argument(
+        "--decoding_ind",
+        type=int,
+        default=0,
+        help="chunk select for chunk encoder",
+    )
     group.add_argument("--nbest", type=int, default=5, help="Output N-best hypotheses")
     group.add_argument("--beam_size", type=int, default=20, help="Beam size")
     group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")

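Finally, get_parser() gains a --decoding_ind flag so the index can be set from the command line. The argument definition below is copied from the diff; the rest of the sketch only demonstrates standard argparse behavior, since the script that consumes this parser is not shown here (the group name is a guess):

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group("decoding")   # group name is a guess
group.add_argument(
    "--decoding_ind",
    type=int,
    default=0,
    help="chunk select for chunk encoder",
)
args = parser.parse_args(["--decoding_ind", "1"])
assert args.decoding_ind == 1                   # overrides the default of 0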