mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update
This commit is contained in:
parent
16001677ac
commit
5dd4495406
@ -1170,7 +1170,7 @@ class MultiHeadAttentionRoPE(nn.Module):
|
||||
self.value = Linear(linear_units, linear_units)
|
||||
self.out = Linear(linear_units, linear_units)
|
||||
self.rotary_emb = RotaryEmbedding(
|
||||
attention_heads,
|
||||
linear_units // attention_heads,
|
||||
max_position_embeddings=kwargs.get("max_position_embeddings", 2048),
|
||||
base=kwargs.get("rope_theta", 10000),
|
||||
)
|
||||
@ -1234,7 +1234,7 @@ class MultiHeadAttentionSdpaRoPE(nn.Module):
|
||||
self.value = Linear(linear_units, linear_units)
|
||||
self.out = Linear(linear_units, linear_units)
|
||||
self.rotary_emb = RotaryEmbedding(
|
||||
attention_heads,
|
||||
linear_units // attention_heads,
|
||||
max_position_embeddings=kwargs.get("max_position_embeddings", 2048),
|
||||
base=kwargs.get("rope_theta", 10000),
|
||||
)
|
||||
@ -1302,7 +1302,7 @@ class MultiHeadAttentionFSMNRoPE(nn.Module):
|
||||
self.value = Linear(linear_units, linear_units)
|
||||
self.out = Linear(linear_units, linear_units)
|
||||
self.rotary_emb = RotaryEmbedding(
|
||||
attention_heads,
|
||||
linear_units // attention_heads,
|
||||
max_position_embeddings=kwargs.get("max_position_embeddings", 2048),
|
||||
base=kwargs.get("rope_theta", 10000),
|
||||
)
|
||||
@ -1360,11 +1360,11 @@ class MultiHeadAttentionFSMNRoPE(nn.Module):
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
n_batch, n_ctx, n_state = q.shape
|
||||
scale = (n_state // self.n_head) ** -0.25
|
||||
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
|
||||
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
|
||||
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
||||
b, t, d = q.shape
|
||||
scale = (d // self.attention_heads) ** -0.5
|
||||
q = q.view(*q.shape[:2], self.attention_heads, -1).permute(0, 2, 1, 3)
|
||||
k = k.view(*k.shape[:2], self.attention_heads, -1).permute(0, 2, 1, 3)
|
||||
v = v.view(*v.shape[:2], self.attention_heads, -1).permute(0, 2, 1, 3)
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
kv_seq_len = v.shape[-2]
|
||||
@ -1398,7 +1398,7 @@ class MultiHeadAttentionFSMNSdpaRoPE(nn.Module):
|
||||
self.value = Linear(linear_units, linear_units)
|
||||
self.out = Linear(linear_units, linear_units)
|
||||
self.rotary_emb = RotaryEmbedding(
|
||||
attention_heads,
|
||||
linear_units // attention_heads,
|
||||
max_position_embeddings=kwargs.get("max_position_embeddings", 2048),
|
||||
base=kwargs.get("rope_theta", 10000),
|
||||
)
|
||||
@ -1457,11 +1457,11 @@ class MultiHeadAttentionFSMNSdpaRoPE(nn.Module):
|
||||
**kwargs,
|
||||
):
|
||||
is_causal = kwargs.get("is_causal", False)
|
||||
n_batch, n_ctx, n_state = q.shape
|
||||
scale = (n_state // self.n_head) ** -0.5
|
||||
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
||||
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
||||
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
||||
b, t, d = q.shape
|
||||
scale = (d // self.attention_heads) ** -0.5
|
||||
q = q.view(*q.shape[:2], self.attention_heads, -1).permute(0, 2, 1, 3)
|
||||
k = k.view(*k.shape[:2], self.attention_heads, -1).permute(0, 2, 1, 3)
|
||||
v = v.view(*v.shape[:2], self.attention_heads, -1).permute(0, 2, 1, 3)
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
kv_seq_len = v.shape[-2]
|
||||
@ -1517,9 +1517,8 @@ class EncoderLayerSANMLarge(nn.Module):
|
||||
mask: Optional[Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
is_pad_mask = kwargs.get("is_pad_mask", False)
|
||||
|
||||
x = x + self.attn(self.attn_ln(x), mask=mask, is_pad_mask=is_pad_mask)[0]
|
||||
x = x + self.attn(self.attn_ln(x), mask=mask, **kwargs)[0]
|
||||
|
||||
x = x + self.mlp(self.mlp_ln(x))
|
||||
return x
|
||||
@ -1562,9 +1561,6 @@ class SenseVoiceEncoder(nn.Module):
|
||||
|
||||
n_frames = x.size(1)
|
||||
max_pos = n_frames
|
||||
# max_pos = self.positional_embedding.size(0)
|
||||
# max_pos = n_frames if n_frames < max_pos else max_pos
|
||||
# x = (x[:, :max_pos, :] + self.positional_embedding[None, :max_pos, :]).to(x.dtype)
|
||||
|
||||
if ilens is not None:
|
||||
if self.downsample_rate == 4:
|
||||
@ -1589,8 +1585,13 @@ class SenseVoiceEncoder(nn.Module):
|
||||
else:
|
||||
padding_mask = None
|
||||
|
||||
device = x.device
|
||||
seq_length = x.shape[1]
|
||||
position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
|
||||
for layer, block in enumerate(self.blocks):
|
||||
x = block(x, mask=padding_mask, is_pad_mask=True)
|
||||
x = block(x, mask=padding_mask, position_ids=position_ids)
|
||||
|
||||
x = self.ln_post(x)
|
||||
|
||||
@ -1625,7 +1626,6 @@ class SenseVoiceL(nn.Module):
|
||||
encoder_conf = kwargs.get("encoder_conf", {})
|
||||
encoder_class = tables.encoder_classes.get(encoder)
|
||||
encoder = encoder_class(**encoder_conf)
|
||||
encoder_output_size = encoder.output_size()
|
||||
|
||||
dims = kwargs.get("dims", {})
|
||||
dims = whisper.model.ModelDimensions(**dims)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user