Merge pull request #932 from alibaba-damo-academy/dev_lhn

support chunk size select for chunk-hopping encoder
This commit is contained in:
hnluo 2023-09-11 17:40:03 +08:00 committed by GitHub
commit 9fcb3cc06b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 2 deletions

View File

@ -399,7 +399,7 @@ class Speech2TextParaformer:
@torch.no_grad()
def __call__(
self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
begin_time: int = 0, end_time: int = None,
decoding_ind: int = None, begin_time: int = 0, end_time: int = None,
):
"""Inference
@ -429,7 +429,9 @@ class Speech2TextParaformer:
batch = to_device(batch, device=self.device)
# b. Forward Encoder
enc, enc_len = self.asr_model.encode(**batch, ind=self.decoding_ind)
if decoding_ind is None:
decoding_ind = self.decoding_ind
enc, enc_len = self.asr_model.encode(**batch, ind=decoding_ind)
if isinstance(enc, tuple):
enc = enc[0]
# assert len(enc) == 1, len(enc)

View File

@ -236,6 +236,7 @@ def inference_paraformer(
timestamp_infer_config: Union[Path, str] = None,
timestamp_model_file: Union[Path, str] = None,
param_dict: dict = None,
decoding_ind: int = 0,
**kwargs,
):
ncpu = kwargs.get("ncpu", 1)
@ -290,6 +291,7 @@ def inference_paraformer(
nbest=nbest,
hotword_list_or_file=hotword_list_or_file,
clas_scale=clas_scale,
decoding_ind=decoding_ind,
)
speech2text = Speech2TextParaformer(**speech2text_kwargs)
@ -312,6 +314,7 @@ def inference_paraformer(
**kwargs,
):
decoding_ind = None
hotword_list_or_file = None
if param_dict is not None:
hotword_list_or_file = param_dict.get('hotword')
@ -319,6 +322,8 @@ def inference_paraformer(
hotword_list_or_file = kwargs['hotword']
if hotword_list_or_file is not None or 'hotword' in kwargs:
speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
if param_dict is not None and "decoding_ind" in param_dict:
decoding_ind = param_dict["decoding_ind"]
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
@ -365,6 +370,7 @@ def inference_paraformer(
# N-best list of (text, token, token_int, hyp_object)
time_beg = time.time()
batch["decoding_ind"] = decoding_ind
results = speech2text(**batch)
if len(results) < 1:
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
@ -1786,6 +1792,12 @@ def get_parser():
default=1,
help="The batch size for inference",
)
group.add_argument(
"--decoding_ind",
type=int,
default=0,
help="chunk select for chunk encoder",
)
group.add_argument("--nbest", type=int, default=5, help="Output N-best hypotheses")
group.add_argument("--beam_size", type=int, default=20, help="Beam size")
group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")