diff --git a/funasr/models/llm_asr/adaptor.py b/funasr/models/llm_asr/adaptor.py
index 9b79ed28a..c93988328 100644
--- a/funasr/models/llm_asr/adaptor.py
+++ b/funasr/models/llm_asr/adaptor.py
@@ -83,25 +83,27 @@ class Transformer(nn.Module):
         from funasr.models.transformer.attention import MultiHeadedAttention
         from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
 
-        self.blocks = nn.ModuleList(
-            [
-                EncoderLayer(
-                    llm_dim,
-                    MultiHeadedAttention(
-                        kwargs.get("attention_heads", 8),
+        self.blocks = None
+        if kwargs.get("n_layer", 2) > 0:
+            self.blocks = nn.ModuleList(
+                [
+                    EncoderLayer(
                         llm_dim,
-                        kwargs.get("attention_dropout_rate", 0.0),
-                    ),
-                    PositionwiseFeedForward(
-                        llm_dim,
-                        llm_dim // 4,
+                        MultiHeadedAttention(
+                            kwargs.get("attention_heads", 8),
+                            llm_dim,
+                            kwargs.get("attention_dropout_rate", 0.0),
+                        ),
+                        PositionwiseFeedForward(
+                            llm_dim,
+                            llm_dim // 4,
+                            kwargs.get("dropout_rate", 0.0),
+                        ),
                         kwargs.get("dropout_rate", 0.0),
-                    ),
-                    kwargs.get("dropout_rate", 0.0),
-                )
-                for i in range(kwargs.get("n_layer", 2))
-            ]
-        )
+                    )
+                    for i in range(kwargs.get("n_layer", 2))
+                ]
+            )
 
     def forward(self, x, ilens=None):
 
@@ -123,6 +125,7 @@ class Transformer(nn.Module):
         olens = None
         olens = (ilens - 1) // self.k + 1
         masks = (~make_pad_mask(olens)[:, None, :]).to(x.device)
-        for layer, block in enumerate(self.blocks):
-            x, masks = block(x, masks)
+        if self.blocks is not None:
+            for layer, block in enumerate(self.blocks):
+                x, masks = block(x, masks)
         return x, olens
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index b13912346..d94058cca 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -481,7 +481,7 @@ class LLMASR2(nn.Module):
         batch_size, token_num, dims = inputs_embeds.shape
 
         fbank_mask[fbank_mask < 0] = 0
-        fbank_fake_lens = fbank_mask.sum(-1)
+        fbank_fake_lens = fbank_mask.sum(-1).to(torch.int32)
 
         # _, l, _ = encoder_out.shape
         for batch_idx in range(batch_size):
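
For reference, a minimal sketch of the pattern the `adaptor.py` change introduces: only build the transformer blocks when `n_layer > 0`, and guard the forward loop so a `None` `self.blocks` is skipped. `ToyAdaptor` and its dimensions are hypothetical stand-ins, not funasr's `Transformer` adaptor itself:

```python
# Minimal sketch of the guarded-construction pattern from this patch.
# `ToyAdaptor` is a hypothetical stand-in, not funasr's Transformer adaptor.
import torch
import torch.nn as nn


class ToyAdaptor(nn.Module):
    def __init__(self, llm_dim=512, encoder_dim=256, n_layer=2):
        super().__init__()
        self.proj = nn.Linear(encoder_dim, llm_dim)

        # Only build the transformer blocks when n_layer > 0; otherwise
        # self.blocks stays None and no extra parameters are allocated.
        self.blocks = None
        if n_layer > 0:
            self.blocks = nn.ModuleList(
                [nn.TransformerEncoderLayer(llm_dim, nhead=8, batch_first=True)
                 for _ in range(n_layer)]
            )

    def forward(self, x):
        x = self.proj(x)
        # Guard the loop the same way the patched forward() does.
        if self.blocks is not None:
            for block in self.blocks:
                x = block(x)
        return x


# n_layer=0 -> the adaptor reduces to the linear projection alone.
adaptor = ToyAdaptor(n_layer=0)
print(adaptor(torch.randn(2, 10, 256)).shape)  # torch.Size([2, 10, 512])
```

With this guard in place, setting `n_layer: 0` in the adaptor config reduces the module to its projection path instead of failing or allocating unused transformer layers, which is the behavior the patched `Transformer.forward` relies on.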