This commit is contained in:
游雁 2024-07-18 13:48:27 +08:00
parent a4bb21b888
commit b03c8a5c35

View File

@ -25,6 +25,7 @@ class ModelDimensions:
n_text_state: int
n_text_head: int
n_text_layer: int
att_type: str
# class LayerNorm(nn.LayerNorm):
@ -140,14 +141,90 @@ class MultiHeadAttention(nn.Module):
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
class MultiHeadAttentionSdpa(nn.Module):
def __init__(self, n_state: int, n_head: int):
super().__init__()
self.n_head = n_head
self.query = Linear(n_state, n_state)
self.key = Linear(n_state, n_state, bias=False)
self.value = Linear(n_state, n_state)
self.out = Linear(n_state, n_state)
def forward(
self,
x: Tensor,
xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None,
**kwargs,
):
is_pad_mask = kwargs.get("is_pad_mask", False)
q = self.query(x)
if kv_cache is None or xa is None or self.key not in kv_cache:
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
# otherwise, perform key/value projections for self- or cross-attention as usual.
k = self.key(x if xa is None else xa)
v = self.value(x if xa is None else xa)
else:
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
k = kv_cache[self.key]
v = kv_cache[self.value]
wv, qk = self.qkv_attention(
q, k, v, mask, is_pad_mask=is_pad_mask, is_causal=True if xa is not None else False
)
return self.out(wv), qk
def qkv_attention(
self,
q: Tensor,
k: Tensor,
v: Tensor,
mask: Optional[Tensor] = None,
**kwargs,
):
is_pad_mask = kwargs.get("is_pad_mask", False)
n_batch, n_ctx, n_state = q.shape
scale = (n_state // self.n_head) ** -0.5
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1)
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
if mask is not None:
mask = mask.unsqueeze(1).eq(0)
attn_output = torch.nn.functional.scaled_dot_product_attention(
q,
k,
v,
attn_mask=mask,
dropout_p=0.0,
is_causal=kwargs.get("is_causal", True),
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
return attn_output, None
att_type_dict = {
"default": MultiHeadAttention,
"sdpa": MultiHeadAttentionSdpa,
}
class ResidualAttentionBlock(nn.Module):
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, **kwargs):
super().__init__()
self.attn = MultiHeadAttention(n_state, n_head)
att_type = kwargs.get("att_type", "default")
self.attn = att_type_dict[att_type](n_state, n_head) # MultiHeadAttention(n_state, n_head)
self.attn_ln = LayerNorm(n_state)
self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
self.cross_attn = (
att_type_dict[att_type](n_state, n_head) if cross_attention else None
) # MultiHeadAttention(n_state, n_head) if cross_attention else None
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
n_mlp = n_state * 4
@ -177,14 +254,17 @@ class ResidualAttentionBlock(nn.Module):
class AudioEncoder(nn.Module):
def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, **kwargs):
super().__init__()
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, stride=2, padding=1)
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
[
ResidualAttentionBlock(n_state, n_head, att_type=kwargs.get("att_type", "default"))
for _ in range(n_layer)
]
)
self.ln_post = LayerNorm(n_state)
@ -209,14 +289,22 @@ class AudioEncoder(nn.Module):
class TextDecoder(nn.Module):
def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, **kwargs):
super().__init__()
self.token_embedding = nn.Embedding(n_vocab, n_state)
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
[
ResidualAttentionBlock(
n_state,
n_head,
cross_attention=True,
att_type=kwargs.get("att_type", "default"),
)
for _ in range(n_layer)
]
)
self.ln = LayerNorm(n_state)
@ -253,6 +341,7 @@ class Whisper(nn.Module):
self.dims.n_audio_state,
self.dims.n_audio_head,
self.dims.n_audio_layer,
att_type=self.dims.att_type,
)
self.decoder = TextDecoder(
self.dims.n_vocab,
@ -260,6 +349,7 @@ class Whisper(nn.Module):
self.dims.n_text_state,
self.dims.n_text_head,
self.dims.n_text_layer,
att_type=self.dims.att_type,
)
# use the last half among the decoder layers for time alignment by default;
# to use a specific set of heads, see `set_alignment_heads()` below.