mirror of https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00

Commit b03c8a5c35 ("update"), parent a4bb21b888
@@ -25,6 +25,7 @@ class ModelDimensions:
     n_text_state: int
     n_text_head: int
     n_text_layer: int
+    att_type: str


 # class LayerNorm(nn.LayerNorm):
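For reference, a minimal sketch (not part of the commit) of how the new field behaves; the stand-in dataclass is trimmed to the fields visible in this hunk. Because att_type: str is declared without a default, code that builds ModelDimensions from a stored config dict now has to supply the key explicitly (e.g. "default" or "sdpa"):

    from dataclasses import dataclass

    @dataclass
    class ModelDimensions:  # trimmed stand-in; the real class carries the full set of Whisper dimensions
        n_text_state: int
        n_text_head: int
        n_text_layer: int
        att_type: str  # new field; note there is no default value

    dims = ModelDimensions(n_text_state=384, n_text_head=6, n_text_layer=4, att_type="sdpa")
    print(dims.att_type)  # -> sdpa
    # Omitting att_type, e.g. ModelDimensions(n_text_state=384, n_text_head=6, n_text_layer=4),
    # raises TypeError because the field is required.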
@@ -140,14 +141,90 @@ class MultiHeadAttention(nn.Module):
         return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()


+class MultiHeadAttentionSdpa(nn.Module):
+    def __init__(self, n_state: int, n_head: int):
+        super().__init__()
+        self.n_head = n_head
+        self.query = Linear(n_state, n_state)
+        self.key = Linear(n_state, n_state, bias=False)
+        self.value = Linear(n_state, n_state)
+        self.out = Linear(n_state, n_state)
+
+    def forward(
+        self,
+        x: Tensor,
+        xa: Optional[Tensor] = None,
+        mask: Optional[Tensor] = None,
+        kv_cache: Optional[dict] = None,
+        **kwargs,
+    ):
+        is_pad_mask = kwargs.get("is_pad_mask", False)
+
+        q = self.query(x)
+
+        if kv_cache is None or xa is None or self.key not in kv_cache:
+            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
+            # otherwise, perform key/value projections for self- or cross-attention as usual.
+            k = self.key(x if xa is None else xa)
+            v = self.value(x if xa is None else xa)
+        else:
+            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
+            k = kv_cache[self.key]
+            v = kv_cache[self.value]
+
+        # cross-attention must not be causally masked over the audio positions, and any
+        # causal structure for self-attention is expected to arrive via `mask`
+        wv, qk = self.qkv_attention(
+            q, k, v, mask, is_pad_mask=is_pad_mask, is_causal=False
+        )
+        return self.out(wv), qk
+
+    def qkv_attention(
+        self,
+        q: Tensor,
+        k: Tensor,
+        v: Tensor,
+        mask: Optional[Tensor] = None,
+        **kwargs,
+    ):
+        is_pad_mask = kwargs.get("is_pad_mask", False)
+        n_batch, n_ctx, n_state = q.shape
+        scale = (n_state // self.n_head) ** -0.5  # unused: SDPA applies this scaling internally
+        # reshape to (n_batch, n_head, n_ctx, d_head), the layout scaled_dot_product_attention expects
+        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
+        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
+        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
+
+        if mask is not None:
+            # convert a mask whose zero entries mark attendable positions into a boolean
+            # mask (True = attend), broadcastable over the head dimension
+            mask = mask.unsqueeze(1).eq(0)
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            attn_mask=mask,
+            dropout_p=0.0,
+            is_causal=kwargs.get("is_causal", True),
+        )
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(n_batch, n_ctx, n_state)
+        return attn_output, None
+
+
+att_type_dict = {
+    "default": MultiHeadAttention,
+    "sdpa": MultiHeadAttentionSdpa,
+}
+
+
 class ResidualAttentionBlock(nn.Module):
-    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
+    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, **kwargs):
         super().__init__()

-        self.attn = MultiHeadAttention(n_state, n_head)
+        att_type = kwargs.get("att_type", "default")
+        self.attn = att_type_dict[att_type](n_state, n_head)  # MultiHeadAttention(n_state, n_head)
         self.attn_ln = LayerNorm(n_state)

-        self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
+        self.cross_attn = (
+            att_type_dict[att_type](n_state, n_head) if cross_attention else None
+        )  # MultiHeadAttention(n_state, n_head) if cross_attention else None
         self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

         n_mlp = n_state * 4
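As a sanity check on the swap that att_type_dict enables, the self-contained sketch below (not from the repository) verifies that torch.nn.functional.scaled_dot_product_attention reproduces the manual softmax(q @ k^T / sqrt(d_head)) @ v computation of the original MultiHeadAttention, and that an additive float mask (0 / -inf), the boolean mask produced by .eq(0) in MultiHeadAttentionSdpa.qkv_attention, and is_causal=True all mask the same positions:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    n_batch, n_head, n_ctx, d_head = 2, 4, 16, 32
    q = torch.randn(n_batch, n_head, n_ctx, d_head)
    k = torch.randn(n_batch, n_head, n_ctx, d_head)
    v = torch.randn(n_batch, n_head, n_ctx, d_head)

    # manual attention, equivalent to MultiHeadAttention.qkv_attention
    qk = (q @ k.transpose(-2, -1)) * d_head**-0.5
    manual = qk.softmax(dim=-1) @ v
    fused = F.scaled_dot_product_attention(q, k, v)  # applies the same scaling internally
    assert torch.allclose(manual, fused, atol=1e-5)

    # the same causal mask in three equivalent forms
    float_mask = torch.full((n_ctx, n_ctx), float("-inf")).triu(1)  # additive, as built in Whisper
    bool_mask = float_mask.eq(0)                                    # True = attend, as in the Sdpa class
    out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)
    out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)
    out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    assert torch.allclose(out_float, out_bool, atol=1e-5)
    assert torch.allclose(out_float, out_causal, atol=1e-5)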
@@ -177,14 +254,17 @@ class ResidualAttentionBlock(nn.Module):


 class AudioEncoder(nn.Module):
-    def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
+    def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, **kwargs):
         super().__init__()
         self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, stride=2, padding=1)
         self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
         self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

         self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
-            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
+            [
+                ResidualAttentionBlock(n_state, n_head, att_type=kwargs.get("att_type", "default"))
+                for _ in range(n_layer)
+            ]
         )
         self.ln_post = LayerNorm(n_state)

@@ -209,14 +289,22 @@ class AudioEncoder(nn.Module):


 class TextDecoder(nn.Module):
-    def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
+    def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, **kwargs):
         super().__init__()

         self.token_embedding = nn.Embedding(n_vocab, n_state)
         self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

         self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
-            [ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
+            [
+                ResidualAttentionBlock(
+                    n_state,
+                    n_head,
+                    cross_attention=True,
+                    att_type=kwargs.get("att_type", "default"),
+                )
+                for _ in range(n_layer)
+            ]
         )
         self.ln = LayerNorm(n_state)

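AudioEncoder, TextDecoder, and ResidualAttentionBlock all thread the choice through kwargs.get("att_type", "default"), so existing call sites that never pass att_type keep building the original attention class. A tiny stand-alone sketch of that dispatch pattern (the class names here are placeholders, not the repository's):

    import torch.nn as nn

    class DefaultAttention(nn.Module):   # placeholder for MultiHeadAttention
        def __init__(self, n_state: int, n_head: int):
            super().__init__()

    class SdpaAttention(nn.Module):      # placeholder for MultiHeadAttentionSdpa
        def __init__(self, n_state: int, n_head: int):
            super().__init__()

    att_type_dict = {"default": DefaultAttention, "sdpa": SdpaAttention}

    class Block(nn.Module):
        def __init__(self, n_state: int, n_head: int, **kwargs):
            super().__init__()
            att_type = kwargs.get("att_type", "default")  # old call sites simply omit the kwarg
            self.attn = att_type_dict[att_type](n_state, n_head)

    print(type(Block(384, 6).attn).__name__)                    # DefaultAttention
    print(type(Block(384, 6, att_type="sdpa").attn).__name__)   # SdpaAttention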
@@ -253,6 +341,7 @@ class Whisper(nn.Module):
             self.dims.n_audio_state,
             self.dims.n_audio_head,
             self.dims.n_audio_layer,
+            att_type=self.dims.att_type,
         )
         self.decoder = TextDecoder(
             self.dims.n_vocab,
@@ -260,6 +349,7 @@ class Whisper(nn.Module):
             self.dims.n_text_state,
             self.dims.n_text_head,
             self.dims.n_text_layer,
+            att_type=self.dims.att_type,
         )
         # use the last half among the decoder layers for time alignment by default;
         # to use a specific set of heads, see `set_alignment_heads()` below.
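With the two hunks above, the choice propagates from ModelDimensions through Whisper into every block. The sketch below is illustrative only: apart from att_type and the n_text_* fields, the dictionary keys follow the upstream Whisper checkpoint layout and are assumptions rather than something this commit shows. Older config dicts without an "att_type" entry would need a fallback such as setdefault before constructing ModelDimensions:

    # hypothetical config dict in the shape of a Whisper checkpoint's "dims" entry
    dims_dict = {
        "n_mels": 80, "n_audio_ctx": 1500, "n_audio_state": 384,
        "n_audio_head": 6, "n_audio_layer": 4, "n_vocab": 51865,
        "n_text_ctx": 448, "n_text_state": 384, "n_text_head": 6, "n_text_layer": 4,
    }
    dims_dict.setdefault("att_type", "default")  # keep old configs loadable
    print(dims_dict["att_type"])                 # -> default
    # dims = ModelDimensions(**dims_dict); model = Whisper(dims)  # as defined in the file above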