diff --git a/funasr/datasets/audio_datasets/espnet_samplers.py b/funasr/datasets/audio_datasets/espnet_samplers.py
index b358fa379..004201e53 100644
--- a/funasr/datasets/audio_datasets/espnet_samplers.py
+++ b/funasr/datasets/audio_datasets/espnet_samplers.py
@@ -147,7 +147,9 @@ class EspnetStyleBatchSampler(DistributedSampler):
         start_idx = self.rank * batches_per_rank
         end_idx = start_idx + batches_per_rank
         rank_batches = buffer_batches[start_idx + self.start_step : end_idx]
+        self.batch_num = len(rank_batches)
+
         logging.info(
             f"rank: {self.rank}, dataloader start from step: {self.start_step}, batch_num: {end_idx-start_idx}, batch_num_after_step: {len(rank_batches)}"
         )
 
diff --git a/funasr/models/sense_voice/decoder.py b/funasr/models/sense_voice/decoder.py
index 60af29ab8..ff933d77b 100644
--- a/funasr/models/sense_voice/decoder.py
+++ b/funasr/models/sense_voice/decoder.py
@@ -360,6 +360,7 @@ class SenseVoiceDecoder(nn.Module):
         """Score."""
         ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
         logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=state)
+        logp = torch.log_softmax(logp, dim=-1)
         return logp.squeeze(0)[-1, :], state
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index 127d5a0a5..22272eefc 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -1264,15 +1264,29 @@ class SenseVoiceSANM(nn.Module):
         if isinstance(task, str):
             task = [task]
         task = "".join([f"<|{x}|>" for x in task])
-        initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
+
+        sos = kwargs.get("model_conf").get("sos")
+        if isinstance(sos, str):
+            initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
 
-        language = DecodingOptions.get("language", None)
-        language = None if language == "auto" else language
+            language = DecodingOptions.get("language", None)
+            language = None if language == "auto" else language
 
-        sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
-        sos_int = tokenizer.encode(sos, allowed_special="all")
+            sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
+            sos_int = tokenizer.encode(sos, allowed_special="all")
+        else:
+            language = DecodingOptions.get("language", None)
+            language = None if language == "auto" else language
+            initial_prompt = kwargs.get("initial_prompt", f"{task}")
+            initial_prompt_lid = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
+            initial_prompt_lid_int = tokenizer.encode(initial_prompt_lid, allowed_special="all")
+            sos_int = [sos] + initial_prompt_lid_int
 
         eos = kwargs.get("model_conf").get("eos")
-        eos_int = tokenizer.encode(eos, allowed_special="all")
+        if isinstance(eos, str):
+            eos_int = tokenizer.encode(eos, allowed_special="all")
+        else:
+            eos_int = [eos]
+
         self.beam_search.sos = sos_int
         self.beam_search.eos = eos_int[0]
@@ -1298,7 +1312,7 @@ class SenseVoiceSANM(nn.Module):
         self.beam_search.event_score_ga = DecodingOptions.get("gain_tokens_score", [1, 1, 1, 1])
 
         encoder_out, encoder_out_lens = self.encode(
-            speech[None, :, :].permute(0, 2, 1), speech_lengths
+            speech[None, :, :], speech_lengths
        )
 
         if text_token_int is not None:
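
In espnet_samplers.py, the sampler now records how many batches remain on this rank once the resume offset is applied. A minimal sketch of the slicing involved, with toy values (names mirror the diff; the numbers are made up):

```python
# Toy reproduction of the per-rank batch slicing in EspnetStyleBatchSampler,
# and of what the newly stored self.batch_num records: the number of batches
# left on this rank after skipping the first start_step batches on resume.
buffer_batches = [[i] for i in range(20)]     # 20 pre-built batches
num_replicas, rank, start_step = 2, 1, 3      # made-up world size / rank / resume step

batches_per_rank = len(buffer_batches) // num_replicas
start_idx = rank * batches_per_rank
end_idx = start_idx + batches_per_rank
rank_batches = buffer_batches[start_idx + start_step : end_idx]

batch_num = len(rank_batches)                 # what self.batch_num now holds
assert batch_num == batches_per_rank - start_step   # 10 - 3 == 7
```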
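In decoder.py, score() now normalizes the decoder output with log_softmax before returning it to the beam search. A minimal sketch (shapes assumed) of why this matters: beam search ranks hypotheses by summing per-step token scores, which is only a fair comparison when each step contributes log-probabilities, since raw logits differ from log-probs by a per-step logsumexp offset that varies across hypotheses.

```python
import torch

logits = torch.randn(1, 5, 32)                # (batch, length, vocab) raw decoder output
logp = torch.log_softmax(logits, dim=-1)

# Each per-position vocabulary distribution now sums to 1 in probability space.
assert torch.allclose(logp.exp().sum(dim=-1), torch.ones(1, 5))

# score() hands the last position's scores to the beam search.
last_step_scores = logp.squeeze(0)[-1, :]     # shape: (vocab,)
```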
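In model.py, model_conf's sos/eos may now arrive either as special-token strings (to be tokenizer-encoded) or as raw integer ids (used directly; an integer sos is prepended to the encoded prompt via sos_int = [sos] + initial_prompt_lid_int). A hedged sketch of the eos half of that dispatch, with a toy encoder standing in for tokenizer.encode(..., allowed_special="all"):

```python
from typing import List, Union

def resolve_eos(eos: Union[str, int], encode) -> List[int]:
    # String form: encode the special token to ids.
    # Integer form: wrap the raw id in a list so eos_int[0] keeps working.
    if isinstance(eos, str):
        return encode(eos)
    return [eos]

# Toy stand-in for tokenizer.encode(..., allowed_special="all").
toy_encode = lambda s: [len(s)]
assert resolve_eos("<|endoftext|>", toy_encode) == [13]
assert resolve_eos(2, toy_encode) == [2]
```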
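The second model.py hunk drops the permute(0, 2, 1) on the encoder input, which presumably fixes a layout mismatch: assuming encode() expects (batch, time, feature), the old call handed it (batch, feature, time) with the axes swapped. Illustrative shapes:

```python
import torch

speech = torch.randn(100, 80)                 # (time, feature), one utterance
batched = speech[None, :, :]                  # (1, time, feature), now fed as-is
assert batched.shape == (1, 100, 80)
# Old behavior: batched.permute(0, 2, 1) would have produced (1, 80, 100),
# i.e. feature on axis 1 and time on axis 2.
```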