diff --git a/funasr/models/sense_voice/model_small.py b/funasr/models/sense_voice/model_small.py index 652217170..052069947 100644 --- a/funasr/models/sense_voice/model_small.py +++ b/funasr/models/sense_voice/model_small.py @@ -1,6 +1,6 @@ import time import torch -from torch import nn +from torch import Tensor, nn import torch.nn.functional as F from typing import Iterable, Optional @@ -12,6 +12,7 @@ from funasr.train_utils.device_funcs import force_gatherable from funasr.losses.label_smoothing_loss import LabelSmoothingLoss from funasr.metrics.compute_acc import compute_accuracy, th_accuracy from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank +from funasr.models.transformer.utils.nets_utils import make_pad_mask class SinusoidalPositionEncoder(torch.nn.Module): @@ -905,9 +906,937 @@ class SenseVoiceSmall(nn.Module): return results, meta_data def export(self, **kwargs): - from export_meta import export_rebuild_model + from .export_meta import export_rebuild_model if "max_seq_len" not in kwargs: kwargs["max_seq_len"] = 512 models = export_rebuild_model(model=self, **kwargs) return models + + +class Linear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + return F.linear( + x, + self.weight.to(x.dtype), + None if self.bias is None else self.bias.to(x.dtype), + ) + + +class Conv1d(nn.Conv1d): + def _conv_forward(self, x, weight, bias): + return super()._conv_forward( + x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype) + ) + + +def sinusoids(length, channels, max_timescale=10000): + """Returns sinusoids for positional embedding""" + assert channels % 2 == 0 + log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)) + scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) + + +class MultiHeadAttention(nn.Module): + def __init__(self, n_state: int, n_head: int): + super().__init__() + self.n_head = n_head + self.query = Linear(n_state, n_state) + self.key = Linear(n_state, n_state, bias=False) + self.value = Linear(n_state, n_state) + self.out = Linear(n_state, n_state) + + def forward( + self, + x: Tensor, + xa: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + kv_cache: Optional[dict] = None, + **kwargs, + ): + is_pad_mask = kwargs.get("is_pad_mask", False) + + q = self.query(x) + + if kv_cache is None or xa is None or self.key not in kv_cache: + # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors; + # otherwise, perform key/value projections for self- or cross-attention as usual. + k = self.key(x if xa is None else xa) + v = self.value(x if xa is None else xa) + else: + # for cross-attention, calculate keys and values once and reuse in subsequent calls. + k = kv_cache[self.key] + v = kv_cache[self.value] + + wv, qk = self.qkv_attention(q, k, v, mask, is_pad_mask=is_pad_mask) + return self.out(wv), qk + + def qkv_attention( + self, + q: Tensor, + k: Tensor, + v: Tensor, + mask: Optional[Tensor] = None, + **kwargs, + ): + is_pad_mask = kwargs.get("is_pad_mask", False) + n_batch, n_ctx, n_state = q.shape + scale = (n_state // self.n_head) ** -0.25 + q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale + k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale + v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + + qk = q @ k + if mask is not None: + if not is_pad_mask: + qk = qk + mask[:n_ctx, :n_ctx] + else: + mask = mask.unsqueeze(1).eq(0) # (batch, 1, t, 1) + min_value = -float( + "inf" + ) # min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min) + qk = qk.masked_fill(mask, min_value) + + qk = qk.float() + + w = F.softmax(qk, dim=-1).to(q.dtype) + if mask is not None and is_pad_mask: + w = w.masked_fill(mask, 0.0) + return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach() + + +class MultiHeadAttentionSdpa(nn.Module): + def __init__(self, n_state: int, n_head: int): + super().__init__() + self.n_head = n_head + self.query = Linear(n_state, n_state) + self.key = Linear(n_state, n_state, bias=False) + self.value = Linear(n_state, n_state) + self.out = Linear(n_state, n_state) + + def forward( + self, + x: Tensor, + xa: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + kv_cache: Optional[dict] = None, + **kwargs, + ): + is_pad_mask = kwargs.get("is_pad_mask", False) + + q = self.query(x) + + if kv_cache is None or xa is None or self.key not in kv_cache: + # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors; + # otherwise, perform key/value projections for self- or cross-attention as usual. + k = self.key(x if xa is None else xa) + v = self.value(x if xa is None else xa) + else: + # for cross-attention, calculate keys and values once and reuse in subsequent calls. + k = kv_cache[self.key] + v = kv_cache[self.value] + + wv, qk = self.qkv_attention(q, k, v, mask, is_pad_mask=is_pad_mask, is_causal=False) + return self.out(wv), qk + + def qkv_attention( + self, + q: Tensor, + k: Tensor, + v: Tensor, + mask: Optional[Tensor] = None, + **kwargs, + ): + is_pad_mask = kwargs.get("is_pad_mask", False) + is_causal = kwargs.get("is_causal", False) + n_batch, n_ctx, n_state = q.shape + scale = (n_state // self.n_head) ** -0.5 + q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + + if mask is not None: + if not is_pad_mask: + mask = None + is_causal = True + else: + mask = mask.unsqueeze(1).to(torch.bool) # (batch, 1, 1, t) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + dropout_p=0.0, + is_causal=is_causal, + scale=scale, + ) + if mask is not None: + attn_output = attn_output.masked_fill(mask.transpose(2, 3).logical_not(), 0.0) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.flatten(start_dim=2) + return attn_output, None + + +# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral +class RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base + ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as( + self.inv_freq + ) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class MultiHeadAttentionRoPE(nn.Module): + def __init__(self, linear_units: int, attention_heads: int, **kwargs): + super().__init__() + self.attention_heads = attention_heads + self.query = Linear(linear_units, linear_units) + self.key = Linear(linear_units, linear_units, bias=False) + self.value = Linear(linear_units, linear_units) + self.out = Linear(linear_units, linear_units) + self.rotary_emb = RotaryEmbedding( + attention_heads, + max_position_embeddings=kwargs.get("max_position_embeddings", 2048), + base=kwargs.get("rope_theta", 10000), + ) + + def forward( + self, + x: Tensor, + mask: Optional[Tensor] = None, + **kwargs, + ): + + q = self.query(x) + k = self.key(x) + v = self.value(x) + + wv, qk = self.qkv_attention(q, k, v, mask, **kwargs) + return self.out(wv), qk + + def qkv_attention( + self, + q: Tensor, + k: Tensor, + v: Tensor, + mask: Optional[Tensor] = None, + **kwargs, + ): + + n_batch, n_ctx, n_state = q.shape + scale = (n_state // self.n_head) ** -0.25 + q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale + k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale + v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + + position_ids = kwargs.get("position_ids", None) + kv_seq_len = v.shape[-2] + cos, sin = self.rotary_emb(v, seq_len=kv_seq_len) + q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids) + + qk = q @ k + if mask is not None: + mask = mask.unsqueeze(1).eq(0) # (batch, 1, t, 1) + min_value = -float( + "inf" + ) # min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min) + qk = qk.masked_fill(mask, min_value) + + qk = qk.float() + + w = F.softmax(qk, dim=-1).to(q.dtype) + if mask is not None: + w = w.masked_fill(mask, 0.0) + return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach() + + +class MultiHeadAttentionSdpaRoPE(nn.Module): + def __init__(self, linear_units: int, attention_heads: int, **kwargs): + super().__init__() + self.attention_heads = attention_heads + self.query = Linear(linear_units, linear_units) + self.key = Linear(linear_units, linear_units, bias=False) + self.value = Linear(linear_units, linear_units) + self.out = Linear(linear_units, linear_units) + self.rotary_emb = RotaryEmbedding( + attention_heads, + max_position_embeddings=kwargs.get("max_position_embeddings", 2048), + base=kwargs.get("rope_theta", 10000), + ) + + def forward( + self, + x: Tensor, + mask: Optional[Tensor] = None, + **kwargs, + ): + + q = self.query(x) + k = self.key(x) + v = self.value(x) + + wv, qk = self.qkv_attention(q, k, v, mask, **kwargs) + return self.out(wv), qk + + def qkv_attention( + self, + q: Tensor, + k: Tensor, + v: Tensor, + mask: Optional[Tensor] = None, + **kwargs, + ): + + is_causal = kwargs.get("is_causal", False) + n_batch, n_ctx, n_state = q.shape + scale = (n_state // self.n_head) ** -0.5 + q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + + position_ids = kwargs.get("position_ids", None) + kv_seq_len = v.shape[-2] + cos, sin = self.rotary_emb(v, seq_len=kv_seq_len) + q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids) + + if mask is not None: + mask = mask.unsqueeze(1).to(torch.bool) # (batch, 1, 1, t) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + dropout_p=0.0, + is_causal=is_causal, + scale=scale, + ) + if mask is not None: + attn_output = attn_output.masked_fill(mask.transpose(2, 3).logical_not(), 0.0) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.flatten(start_dim=2) + return attn_output, None + + +class MultiHeadAttentionFSMNRoPE(nn.Module): + def __init__(self, linear_units: int, attention_heads: int, **kwargs): + super().__init__() + self.attention_heads = attention_heads + self.query = Linear(linear_units, linear_units) + self.key = Linear(linear_units, linear_units, bias=False) + self.value = Linear(linear_units, linear_units) + self.out = Linear(linear_units, linear_units) + self.rotary_emb = RotaryEmbedding( + attention_heads, + max_position_embeddings=kwargs.get("max_position_embeddings", 2048), + base=kwargs.get("rope_theta", 10000), + ) + + self.fsmn_block = nn.Conv1d( + linear_units, + linear_units, + kwargs.get("kernel_size", 15), + stride=1, + padding=0, + groups=linear_units, + bias=False, + ) + # padding + left_padding = (kwargs.get("kernel_size", 15) - 1) // 2 + left_padding = left_padding + kwargs.get("sanm_shfit", 0) + right_padding = kwargs.get("kernel_size", 15) - 1 - left_padding + self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0) + + def fsmn(self, inputs, mask): + b, t, d = inputs.size() + if mask is not None: + mask = torch.reshape(mask, (b, -1, 1)) + inputs = inputs * mask + + x = inputs.transpose(1, 2) + x = self.pad_fn(x) + x = self.fsmn_block(x) + x = x.transpose(1, 2) + inputs + # x = self.dropout(x) + if mask is not None: + x = x * mask + return x + + def forward( + self, + x: Tensor, + mask: Optional[Tensor] = None, + **kwargs, + ): + q = self.query(x) + k = self.key(x) + v = self.value(x) + + memory = self.fsmn(v, mask=mask) + wv, qk = self.qkv_attention(q, k, v, mask, **kwargs) + return self.out(wv) + memory, qk + + def qkv_attention( + self, + q: Tensor, + k: Tensor, + v: Tensor, + mask: Optional[Tensor] = None, + **kwargs, + ): + + n_batch, n_ctx, n_state = q.shape + scale = (n_state // self.n_head) ** -0.25 + q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale + k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale + v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + + position_ids = kwargs.get("position_ids", None) + kv_seq_len = v.shape[-2] + cos, sin = self.rotary_emb(v, seq_len=kv_seq_len) + q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids) + + qk = q @ k + if mask is not None: + mask = mask.unsqueeze(1).eq(0) # (batch, 1, t, 1) + min_value = -float( + "inf" + ) # min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min) + qk = qk.masked_fill(mask, min_value) + + qk = qk.float() + + w = F.softmax(qk, dim=-1).to(q.dtype) + if mask is not None: + w = w.masked_fill(mask, 0.0) + + return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach() + + +class MultiHeadAttentionFSMNSdpaRoPE(nn.Module): + def __init__(self, linear_units: int, attention_heads: int, **kwargs): + super().__init__() + + self.attention_heads = attention_heads + self.query = Linear(linear_units, linear_units) + self.key = Linear(linear_units, linear_units, bias=False) + self.value = Linear(linear_units, linear_units) + self.out = Linear(linear_units, linear_units) + self.rotary_emb = RotaryEmbedding( + attention_heads, + max_position_embeddings=kwargs.get("max_position_embeddings", 2048), + base=kwargs.get("rope_theta", 10000), + ) + + self.fsmn_block = nn.Conv1d( + linear_units, + linear_units, + kwargs.get("kernel_size", 15), + stride=1, + padding=0, + groups=linear_units, + bias=False, + ) + # padding + left_padding = (kwargs.get("kernel_size", 15) - 1) // 2 + left_padding = left_padding + kwargs.get("sanm_shfit", 0) + right_padding = kwargs.get("kernel_size", 15) - 1 - left_padding + self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0) + + def fsmn(self, inputs, mask): + b, t, d = inputs.size() + if mask is not None: + mask = torch.reshape(mask, (b, -1, 1)) + inputs = inputs * mask + + x = inputs.transpose(1, 2) + x = self.pad_fn(x) + x = self.fsmn_block(x) + x = x.transpose(1, 2) + inputs + # x = self.dropout(x) + if mask is not None: + x = x * mask + return x + + def forward( + self, + x: Tensor, + mask: Optional[Tensor] = None, + **kwargs, + ): + + q = self.query(x) + k = self.key(x) + v = self.value(x) + memory = self.fsmn(v, mask=mask) + + wv, qk = self.qkv_attention(q, k, v, mask, **kwargs) + return self.out(wv) + memory, qk + + def qkv_attention( + self, + q: Tensor, + k: Tensor, + v: Tensor, + mask: Optional[Tensor] = None, + **kwargs, + ): + is_causal = kwargs.get("is_causal", False) + n_batch, n_ctx, n_state = q.shape + scale = (n_state // self.n_head) ** -0.5 + q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + + position_ids = kwargs.get("position_ids", None) + kv_seq_len = v.shape[-2] + cos, sin = self.rotary_emb(v, seq_len=kv_seq_len) + q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids) + + if mask is not None: + mask = mask.unsqueeze(1).to(torch.bool) # (batch, 1, 1, t) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + dropout_p=0.0, + is_causal=is_causal, + scale=scale, + ) + if mask is not None: + attn_output = attn_output.masked_fill(mask.transpose(2, 3).logical_not(), 0.0) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.flatten(start_dim=2) + return attn_output, None + + +att_type_dict = { + "default": MultiHeadAttention, + "sdpa": MultiHeadAttentionSdpa, + "self_att": MultiHeadAttentionRoPE, + "self_att_sdpa": MultiHeadAttentionSdpaRoPE, + "self_att_fsmn": MultiHeadAttentionFSMNRoPE, + "self_att_fsmn_sdpa": MultiHeadAttentionFSMNSdpaRoPE, +} + + +class EncoderLayerSANMLarge(nn.Module): + def __init__(self, linear_units: int, attention_heads: int, **kwargs): + super().__init__() + + att_type = kwargs.get("att_type", "self_att_fsmn_sdpa") + self.attn = att_type_dict[att_type](linear_units, attention_heads) + self.attn_ln = LayerNorm(linear_units) + + n_mlp = linear_units * 4 + self.mlp = nn.Sequential( + Linear(linear_units, n_mlp), nn.GELU(), Linear(n_mlp, linear_units) + ) + self.mlp_ln = LayerNorm(linear_units) + + def forward( + self, + x: Tensor, + mask: Optional[Tensor] = None, + **kwargs, + ): + is_pad_mask = kwargs.get("is_pad_mask", False) + + x = x + self.attn(self.attn_ln(x), mask=mask, is_pad_mask=is_pad_mask)[0] + + x = x + self.mlp(self.mlp_ln(x)) + return x + + +@tables.register("encoder_classes", "SenseVoiceEncoder") +class SenseVoiceEncoder(nn.Module): + def __init__( + self, + input_size, + n_ctx: int, + linear_units: int, + attention_heads: int, + num_blocks: int, + **kwargs, + ): + super().__init__() + self.conv1 = Conv1d(input_size, linear_units, kernel_size=3, stride=2, padding=1) + self.conv2 = Conv1d(linear_units, linear_units, kernel_size=3, stride=2, padding=1) + + self.blocks = nn.ModuleList( + [ + EncoderLayerSANMLarge( + linear_units, attention_heads, att_type=kwargs.get("att_type", "default") + ) + for _ in range(num_blocks) + ] + ) + self.ln_post = LayerNorm(linear_units) + self.use_padmask = kwargs.get("use_padmask", True) + self.downsample_rate = kwargs.get("downsample_rate", 4) + + def forward( + self, + x: torch.Tensor, + ilens: torch.Tensor = None, + **kwargs, + ): + use_padmask = self.use_padmask + x = F.gelu(self.conv1(x)) + x = F.gelu(self.conv2(x)) + x = x.permute(0, 2, 1) + + n_frames = x.size(1) + max_pos = n_frames + # max_pos = self.positional_embedding.size(0) + # max_pos = n_frames if n_frames < max_pos else max_pos + # x = (x[:, :max_pos, :] + self.positional_embedding[None, :max_pos, :]).to(x.dtype) + + if ilens is not None: + if self.downsample_rate == 4: + olens = ( + 1 + + (ilens - self.conv1.kernel_size[0] + 2 * self.conv1.padding[0]) + // self.conv1.stride[0] + ) + else: + olens = ilens + olens = ( + 1 + + (olens - self.conv2.kernel_size[0] + 2 * self.conv2.padding[0]) + // self.conv2.stride[0] + ) + olens = torch.clamp(olens, max=max_pos) + else: + olens = None + + if use_padmask and olens is not None: + padding_mask = (~make_pad_mask(olens)[:, None, :]).to(torch.bool).to(x.device) + else: + padding_mask = None + + for layer, block in enumerate(self.blocks): + x = block(x, mask=padding_mask, is_pad_mask=True) + + x = self.ln_post(x) + + if ilens is None: + return x + else: + return x, olens + + +import types +import time +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor +from torch import nn +from torch.cuda.amp import autocast +from funasr.metrics.compute_acc import compute_accuracy, th_accuracy +from funasr.losses.label_smoothing_loss import LabelSmoothingLoss +from funasr.train_utils.device_funcs import force_gatherable +from . import whisper_lib as whisper +from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank +from funasr.utils.datadir_writer import DatadirWriter + + +@tables.register("model_classes", "SenseVoiceL") +class SenseVoiceL(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + encoder = kwargs.get("kwargs") + encoder_conf = kwargs.get("encoder_conf", {}) + encoder_class = tables.encoder_classes.get(encoder) + encoder = encoder_class(**encoder_conf) + encoder_output_size = encoder.output_size() + + dims = kwargs.get("dims", {}) + dims = whisper.model.ModelDimensions(**dims) + model = whisper.model.Whisper(dims=dims) + + # encoder + del model.encoder + model.encoder = encoder + + # decoder + model.decoder.use_padmask = kwargs.get("use_padmask", True) + from .decoder import sense_voice_decode_forward + + model.decoder.forward = types.MethodType(sense_voice_decode_forward, model.decoder) + + self.model = model + + self.encoder_output_size = self.model.dims.n_audio_state + + self.activation_checkpoint = kwargs.get("activation_checkpoint", False) + self.ignore_id = kwargs.get("ignore_id", -1) + self.vocab_size = kwargs.get("vocab_size", -1) + self.length_normalized_loss = kwargs.get("length_normalized_loss", True) + self.criterion_att = LabelSmoothingLoss( + size=self.vocab_size, + padding_idx=self.ignore_id, + smoothing=kwargs.get("lsm_weight", 0.0), + normalize_length=self.length_normalized_loss, + ) + + specaug = kwargs.get("specaug", None) + if specaug is not None: + specaug_class = tables.specaug_classes.get(specaug) + specaug = specaug_class(**kwargs.get("specaug_conf", {})) + self.specaug = specaug + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + **kwargs, + ): + target_mask = kwargs.get("target_mask", None) + + if len(text_lengths.size()) > 1: + text_lengths = text_lengths[:, 0] + if len(speech_lengths.size()) > 1: + speech_lengths = speech_lengths[:, 0] + + batch_size = speech.shape[0] + + if self.activation_checkpoint: + from torch.utils.checkpoint import checkpoint + + encoder_out, encoder_out_lens = checkpoint( + self.encode, speech, speech_lengths, use_reentrant=False + ) + else: + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + + loss_att, acc_att, cer_att, wer_att = self._calc_att_loss( + encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask + ) + loss = loss_att + stats = {} + stats["acc"] = acc_att + stats["loss"] = torch.clone(loss.detach()) + stats["batch_size"] = batch_size + + # force_gatherable: to-device and to-tensor if scalar for DataParallel + if self.length_normalized_loss: + batch_size = int((text_lengths + 1).sum()) + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + return loss, stats, weight + + def encode( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + **kwargs, + ): + """Encoder. Note that this method is used by asr_inference.py + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + ind: int + """ + with autocast(False): + + # Data augmentation + if self.specaug is not None and self.training: + speech, speech_lengths = self.specaug(speech, speech_lengths) + + # Forward encoder + encoder_out, encoder_out_lens = self.model.encoder(speech.permute(0, 2, 1), speech_lengths) + + return encoder_out, encoder_out_lens + + def _calc_att_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + **kwargs, + ): + target_mask = kwargs.get("target_mask", None) + stats = {} + + # 1. Forward decoder + decoder_out = self.model.decoder( + x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens + ) + + # 2. Compute attention loss + mask = torch.ones_like(ys_pad) * (-1) + ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(torch.int64) + ys_pad_mask[ys_pad_mask == 0] = -1 + loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:]) + + with torch.no_grad(): + preds = torch.argmax(decoder_out, -1) + acc_att = compute_accuracy( + preds[:, :-1], ys_pad_mask[:, 1:], ignore_label=self.ignore_id + ) + + return loss_att, acc_att, None, None + + def inference( + self, + data_in, + data_lengths=None, + key: list = None, + tokenizer=None, + frontend=None, + **kwargs, + ): + if kwargs.get("batch_size", 1) > 1: + raise NotImplementedError("batch decoding is not implemented") + + if frontend is None and not hasattr(self, "frontend"): + frontend_class = tables.frontend_classes.get("WhisperFrontend") + frontend = frontend_class( + n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True) + ) + self.frontend = frontend + else: + frontend = frontend if frontend is not None else self.frontend + + meta_data = {} + if ( + isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank" + ): # fbank + speech, speech_lengths = data_in, data_lengths + if len(speech.shape) < 3: + speech = speech[None, :, :] + if speech_lengths is None: + speech_lengths = speech.shape[1] + else: + # extract fbank feats + time1 = time.perf_counter() + audio_sample_list = load_audio_text_image_video( + data_in, + fs=frontend.fs if hasattr(frontend, "fs") else 16000, + audio_fs=kwargs.get("fs", 16000), + data_type=kwargs.get("data_type", "sound"), + tokenizer=tokenizer, + ) + time2 = time.perf_counter() + meta_data["load_data"] = f"{time2 - time1:0.3f}" + speech, speech_lengths = extract_fbank( + audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend + ) + time3 = time.perf_counter() + meta_data["extract_feat"] = f"{time3 - time2:0.3f}" + frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10 + lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1 + meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000 + + speech = speech.to(device=kwargs["device"])[0, :, :] + speech_lengths = speech_lengths.to(device=kwargs["device"]) + + DecodingOptions = kwargs.get("DecodingOptions", {}) + task = DecodingOptions.get("task", "ASR") + if isinstance(task, str): + task = [task] + task = "".join([f"<|{x}|>" for x in task]) + initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}") + DecodingOptions["initial_prompt"] = initial_prompt + + language = DecodingOptions.get("language", None) + language = None if language == "auto" else language + DecodingOptions["language"] = language + + DecodingOptions["vocab_path"] = kwargs["tokenizer_conf"].get("vocab_path", None) + + if "without_timestamps" not in DecodingOptions: + DecodingOptions["without_timestamps"] = True + + options = whisper.DecodingOptions(**DecodingOptions) + + result = whisper.decode(self.model, speech, options) + text = f"{result.text}" + results = [] + result_i = {"key": key[0], "text": text} + + results.append(result_i) + + return results, meta_data