This commit is contained in:
游雁 2024-07-12 11:42:28 +08:00
parent a7fd8f8544
commit b069dba3be

View File

@ -855,37 +855,14 @@ class LLMASR4(nn.Module):
def __init__(
self,
specaug: str = None,
specaug_conf: dict = None,
normalize: str = None,
normalize_conf: dict = None,
audio_encoder: str = None,
audio_encoder_conf: dict = None,
audio_adaptor: str = None,
audio_adaptor_conf: dict = None,
decoder: str = None,
decoder_conf: dict = None,
ctc: str = None,
ctc_conf: dict = None,
ctc_weight: float = 0.5,
llm: str = None,
llm_conf: dict = None,
input_size: int = 80,
vocab_size: int = -1,
ignore_id: int = -1,
blank_id: int = 0,
sos: int = 1,
eos: int = 2,
lsm_weight: float = 0.0,
length_normalized_loss: bool = False,
report_cer: bool = True,
report_wer: bool = True,
sym_space: str = "<space>",
sym_blank: str = "<blank>",
# extract_feats_in_collect_stats: bool = True,
share_embedding: bool = False,
# preencoder: Optional[AbsPreEncoder] = None,
# postencoder: Optional[AbsPostEncoder] = None,
**kwargs,
):
@ -1005,12 +982,12 @@ class LLMASR4(nn.Module):
batch_size_speech, frames, _ = speech.shape
batch_size, token_num = input_ids.shape
with torch.cuda.amp.autocast(enabled=False):
# audio encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
# with torch.cuda.amp.autocast(enabled=False):
# audio encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
# audio_adaptor
encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
# audio_adaptor
encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
batch_size, token_num, dims = inputs_embeds.shape
fake_token_len = kwargs.get("fake_token_len")