"""RWKV encoder definition for Transducer models.""" import math from typing import Dict, List, Optional, Tuple import torch from funasr.models.encoder.abs_encoder import AbsEncoder from funasr.modules.rwkv import RWKV from funasr.modules.layer_norm import LayerNorm from funasr.modules.rwkv_subsampling import RWKVConvInput from funasr.modules.nets_utils import make_source_mask class RWKVEncoder(AbsEncoder): """RWKV encoder module. Based on https://arxiv.org/pdf/2305.13048.pdf. Args: vocab_size: Vocabulary size. output_size: Input/Output size. context_size: Context size for WKV computation. linear_size: FeedForward hidden size. attention_size: SelfAttention hidden size. normalization_type: Normalization layer type. normalization_args: Normalization layer arguments. num_blocks: Number of RWKV blocks. embed_dropout_rate: Dropout rate for embedding layer. att_dropout_rate: Dropout rate for the attention module. ffn_dropout_rate: Dropout rate for the feed-forward module. """ def __init__( self, input_size: int, output_size: int = 512, context_size: int = 1024, linear_size: Optional[int] = None, attention_size: Optional[int] = None, num_blocks: int = 4, att_dropout_rate: float = 0.0, ffn_dropout_rate: float = 0.0, dropout_rate: float = 0.0, subsampling_factor: int =4, time_reduction_factor: int = 1, kernel: int = 3, ) -> None: """Construct a RWKVEncoder object.""" super().__init__() self.embed = RWKVConvInput( input_size, [output_size//4, output_size//2, output_size], subsampling_factor, conv_kernel_size=kernel, output_size=output_size, ) self.subsampling_factor = subsampling_factor linear_size = output_size * 4 if linear_size is None else linear_size attention_size = output_size if attention_size is None else attention_size self.rwkv_blocks = torch.nn.ModuleList( [ RWKV( output_size, linear_size, attention_size, context_size, block_id, num_blocks, att_dropout_rate=att_dropout_rate, ffn_dropout_rate=ffn_dropout_rate, dropout_rate=dropout_rate, ) for block_id in range(num_blocks) ] ) self.embed_norm = LayerNorm(output_size) self.final_norm = LayerNorm(output_size) self._output_size = output_size self.context_size = context_size self.num_blocks = num_blocks self.time_reduction_factor = time_reduction_factor def output_size(self) -> int: return self._output_size def forward(self, x: torch.Tensor, x_len) -> torch.Tensor: """Encode source label sequences. Args: x: Encoder input sequences. (B, L) Returns: out: Encoder output sequences. 
            olens: Encoder output sequence lengths. (B,)
        """
        _, length, _ = x.size()

        assert length <= self.context_size * self.subsampling_factor, (
            "Context size is too short for current length: %d versus %d"
            % (length, self.context_size * self.subsampling_factor)
        )

        mask = make_source_mask(x_len).to(x.device)

        x, mask = self.embed(x, mask, None)
        x = self.embed_norm(x)

        olens = mask.eq(0).sum(1)

        for block in self.rwkv_blocks:
            x, _ = block(x)

        # for streaming inference
        # x = self.rwkv_infer(x)

        x = self.final_norm(x)

        if self.time_reduction_factor > 1:
            x = x[:, :: self.time_reduction_factor, :]
            olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1

        return x, olens, None

    def rwkv_infer(self, xs_pad: torch.Tensor) -> torch.Tensor:
        """Encode input sequences frame by frame with recurrent state (streaming).

        Args:
            xs_pad: Encoder input sequences. (B, U, D)

        Returns:
            xs_out: Encoder output sequences. (B, U, D)
        """
        batch_size = xs_pad.shape[0]

        hidden_sizes = [self._output_size for _ in range(5)]

        # One state tensor per RWKV state slot, shared across all blocks.
        state = [
            torch.zeros(
                (batch_size, 1, hidden_sizes[i], self.num_blocks),
                dtype=torch.float32,
                device=xs_pad.device,
            )
            for i in range(5)
        ]
        # The last slot tracks the running maximum of the WKV exponent and is
        # initialized to a large negative value (acting as -inf).
        state[4] -= 1e30

        xs_out = []

        for t in range(xs_pad.shape[1]):
            x_t = xs_pad[:, t, :]

            for block in self.rwkv_blocks:
                x_t, state = block(x_t, state=state)

            xs_out.append(x_t)

        xs_out = torch.stack(xs_out, dim=1)

        return xs_out
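
# Minimal usage sketch (illustrative only, not part of the FunASR API): builds an
# encoder on random fbank-like features and runs a forward pass. The feature
# dimension, batch shapes, and hyper-parameters below are assumptions chosen for
# this example, not values prescribed by the RWKV paper or by FunASR.
if __name__ == "__main__":
    encoder = RWKVEncoder(
        input_size=80,  # assumed 80-dim filterbank features
        output_size=256,
        context_size=1024,
        num_blocks=4,
    )

    feats = torch.randn(2, 200, 80)       # (B, L, D_in) dummy batch
    feats_len = torch.tensor([200, 160])  # valid lengths per utterance

    out, out_lens, _ = encoder(feats, feats_len)

    # With subsampling_factor=4 the output shape is roughly (2, 200 // 4, 256).
    print(out.shape, out_lens)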