mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Dev gzf exp (#1785)
* resume from step * batch * batch * batch * batch * batch * batch * batch * batch * batch * batch * batch * batch * batch * batch * batch * train_loss_avg train_acc_avg * train_loss_avg train_acc_avg * train_loss_avg train_acc_avg * log step * wav is not exist * wav is not exist * decoding * decoding * decoding * wechat * decoding key * decoding key * decoding key * decoding key * decoding key * decoding key * dynamic batch * start_data_split_i=0 * total_time/accum_grad * total_time/accum_grad * total_time/accum_grad * update avg slice * update avg slice * sensevoice sanm * sensevoice sanm * sensevoice sanm --------- Co-authored-by: 北念 <lzr265946@alibaba-inc.com>
This commit is contained in:
parent
79b09f1d67
commit
db9ec58cb4
@ -147,7 +147,9 @@ class EspnetStyleBatchSampler(DistributedSampler):
|
||||
start_idx = self.rank * batches_per_rank
|
||||
end_idx = start_idx + batches_per_rank
|
||||
rank_batches = buffer_batches[start_idx + self.start_step : end_idx]
|
||||
|
||||
self.batch_num = len(rank_batches)
|
||||
|
||||
logging.info(
|
||||
f"rank: {self.rank}, dataloader start from step: {self.start_step}, batch_num: {end_idx-start_idx}, batch_num_after_step: {len(rank_batches)}"
|
||||
)
|
||||
|
||||
@ -360,6 +360,7 @@ class SenseVoiceDecoder(nn.Module):
|
||||
"""Score."""
|
||||
ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
|
||||
logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=state)
|
||||
logp = torch.log_softmax(logp, dim=-1)
|
||||
return logp.squeeze(0)[-1, :], state
|
||||
|
||||
|
||||
|
||||
@ -1264,15 +1264,29 @@ class SenseVoiceSANM(nn.Module):
|
||||
if isinstance(task, str):
|
||||
task = [task]
|
||||
task = "".join([f"<|{x}|>" for x in task])
|
||||
initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
|
||||
|
||||
sos = kwargs.get("model_conf").get("sos")
|
||||
if isinstance(sos, str):
|
||||
initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
|
||||
|
||||
language = DecodingOptions.get("language", None)
|
||||
language = None if language == "auto" else language
|
||||
language = DecodingOptions.get("language", None)
|
||||
language = None if language == "auto" else language
|
||||
|
||||
sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
|
||||
sos_int = tokenizer.encode(sos, allowed_special="all")
|
||||
sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
|
||||
sos_int = tokenizer.encode(sos, allowed_special="all")
|
||||
else:
|
||||
language = DecodingOptions.get("language", None)
|
||||
language = None if language == "auto" else language
|
||||
initial_prompt = kwargs.get("initial_prompt", f"{task}")
|
||||
initial_prompt_lid = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
|
||||
initial_prompt_lid_int = tokenizer.encode(initial_prompt_lid, allowed_special="all")
|
||||
sos_int = [sos] + initial_prompt_lid_int
|
||||
eos = kwargs.get("model_conf").get("eos")
|
||||
eos_int = tokenizer.encode(eos, allowed_special="all")
|
||||
if isinstance(eos, str):
|
||||
eos_int = tokenizer.encode(eos, allowed_special="all")
|
||||
else:
|
||||
eos_int = [eos]
|
||||
|
||||
self.beam_search.sos = sos_int
|
||||
self.beam_search.eos = eos_int[0]
|
||||
|
||||
@ -1298,7 +1312,7 @@ class SenseVoiceSANM(nn.Module):
|
||||
self.beam_search.event_score_ga = DecodingOptions.get("gain_tokens_score", [1, 1, 1, 1])
|
||||
|
||||
encoder_out, encoder_out_lens = self.encode(
|
||||
speech[None, :, :].permute(0, 2, 1), speech_lengths
|
||||
speech[None, :, :], speech_lengths
|
||||
)
|
||||
|
||||
if text_token_int is not None:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user