Mirror of https://github.com/modelscope/FunASR, synced 2025-09-15 14:48:36 +08:00

Commit 5dd4495406 (parent 16001677ac): update
@@ -1170,7 +1170,7 @@ class MultiHeadAttentionRoPE(nn.Module):
         self.value = Linear(linear_units, linear_units)
         self.out = Linear(linear_units, linear_units)
         self.rotary_emb = RotaryEmbedding(
-            attention_heads,
+            linear_units // attention_heads,
             max_position_embeddings=kwargs.get("max_position_embeddings", 2048),
             base=kwargs.get("rope_theta", 10000),
         )
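
Note: the first positional argument to RotaryEmbedding is now the per-head dimension (linear_units // attention_heads) rather than the head count. RoPE builds its inverse-frequency table from the rotary dimension, so passing the number of heads would size the table wrong. A minimal sketch of the usual construction, assuming an HF-style RotaryEmbedding (names and values here are illustrative, not the repo's exact code):

    import torch

    # The inverse-frequency table has dim // 2 entries, one per rotated channel
    # pair, so `dim` must be the per-head size, not the number of heads.
    def rope_inv_freq(dim: int, base: float = 10000.0) -> torch.Tensor:
        return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

    head_dim = 1280 // 20                # e.g. linear_units=1280, attention_heads=20
    inv_freq = rope_inv_freq(head_dim)   # shape: (head_dim // 2,)
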
@@ -1234,7 +1234,7 @@ class MultiHeadAttentionSdpaRoPE(nn.Module):
         self.value = Linear(linear_units, linear_units)
         self.out = Linear(linear_units, linear_units)
         self.rotary_emb = RotaryEmbedding(
-            attention_heads,
+            linear_units // attention_heads,
             max_position_embeddings=kwargs.get("max_position_embeddings", 2048),
             base=kwargs.get("rope_theta", 10000),
         )
@@ -1302,7 +1302,7 @@ class MultiHeadAttentionFSMNRoPE(nn.Module):
         self.value = Linear(linear_units, linear_units)
         self.out = Linear(linear_units, linear_units)
         self.rotary_emb = RotaryEmbedding(
-            attention_heads,
+            linear_units // attention_heads,
             max_position_embeddings=kwargs.get("max_position_embeddings", 2048),
             base=kwargs.get("rope_theta", 10000),
         )
@@ -1360,11 +1360,11 @@ class MultiHeadAttentionFSMNRoPE(nn.Module):
         **kwargs,
     ):
 
-        n_batch, n_ctx, n_state = q.shape
-        scale = (n_state // self.n_head) ** -0.25
-        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
-        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
-        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
+        b, t, d = q.shape
+        scale = (d // self.attention_heads) ** -0.5
+        q = q.view(*q.shape[:2], self.attention_heads, -1).permute(0, 2, 1, 3)
+        k = k.view(*k.shape[:2], self.attention_heads, -1).permute(0, 2, 1, 3)
+        v = v.view(*v.shape[:2], self.attention_heads, -1).permute(0, 2, 1, 3)
 
         position_ids = kwargs.get("position_ids", None)
         kv_seq_len = v.shape[-2]
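
Note on the scaling change: the old Whisper-style code folded a d_head ** -0.25 factor into both q and k, so their product carried d_head ** -0.5. With RoPE applied to q and k after projection, the pre-scaling is dropped and a single d_head ** -0.5 factor is applied at score time instead; k also stays in (batch, heads, time, head_dim) layout so the rotation can be applied, with the transpose deferred to the matmul. The two conventions give identical scores, as this sketch (illustrative shapes, not the repo's code) shows:

    import torch

    b, h, t, d_head = 2, 4, 16, 64
    q = torch.randn(b, h, t, d_head)
    k = torch.randn(b, h, t, d_head)

    # Old convention: scale q and k each by d_head ** -0.25 before the matmul.
    s = d_head ** -0.25
    scores_old = (q * s) @ (k * s).transpose(-2, -1)

    # New convention: one d_head ** -0.5 factor on the raw scores.
    scores_new = (q @ k.transpose(-2, -1)) * d_head ** -0.5

    assert torch.allclose(scores_old, scores_new, atol=1e-5)
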
@@ -1398,7 +1398,7 @@ class MultiHeadAttentionFSMNSdpaRoPE(nn.Module):
         self.value = Linear(linear_units, linear_units)
         self.out = Linear(linear_units, linear_units)
         self.rotary_emb = RotaryEmbedding(
-            attention_heads,
+            linear_units // attention_heads,
             max_position_embeddings=kwargs.get("max_position_embeddings", 2048),
             base=kwargs.get("rope_theta", 10000),
         )
@@ -1457,11 +1457,11 @@ class MultiHeadAttentionFSMNSdpaRoPE(nn.Module):
         **kwargs,
     ):
         is_causal = kwargs.get("is_causal", False)
-        n_batch, n_ctx, n_state = q.shape
-        scale = (n_state // self.n_head) ** -0.5
-        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
-        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
-        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
+        b, t, d = q.shape
+        scale = (d // self.attention_heads) ** -0.5
+        q = q.view(*q.shape[:2], self.attention_heads, -1).permute(0, 2, 1, 3)
+        k = k.view(*k.shape[:2], self.attention_heads, -1).permute(0, 2, 1, 3)
+        v = v.view(*v.shape[:2], self.attention_heads, -1).permute(0, 2, 1, 3)
 
         position_ids = kwargs.get("position_ids", None)
         kv_seq_len = v.shape[-2]
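
In the Sdpa variant the same scale presumably ends up in torch.nn.functional.scaled_dot_product_attention rather than being folded into q/k. Since (d // attention_heads) ** -0.5 equals SDPA's default 1 / sqrt(head_dim), passing it explicitly (the scale= kwarg exists from PyTorch 2.1) and relying on the default are interchangeable, assuming the head dim is the last axis:

    import torch
    import torch.nn.functional as F

    b, h, t, d_head = 2, 4, 16, 64
    q, k, v = (torch.randn(b, h, t, d_head) for _ in range(3))

    # SDPA applies 1 / sqrt(head_dim) by default; an explicit scale= of the
    # same value (supported since PyTorch 2.1) gives the same output.
    out_default = F.scaled_dot_product_attention(q, k, v)
    out_explicit = F.scaled_dot_product_attention(q, k, v, scale=d_head ** -0.5)
    assert torch.allclose(out_default, out_explicit, atol=1e-6)
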
@@ -1517,9 +1517,8 @@ class EncoderLayerSANMLarge(nn.Module):
         mask: Optional[Tensor] = None,
         **kwargs,
     ):
-        is_pad_mask = kwargs.get("is_pad_mask", False)
 
-        x = x + self.attn(self.attn_ln(x), mask=mask, is_pad_mask=is_pad_mask)[0]
+        x = x + self.attn(self.attn_ln(x), mask=mask, **kwargs)[0]
 
         x = x + self.mlp(self.mlp_ln(x))
         return x
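
The layer now forwards **kwargs to attention instead of extracting is_pad_mask itself, so new per-call arguments such as position_ids reach the attention module without the layer naming each one. A hypothetical minimal layer showing the pattern (not the repo's class):

    import torch
    from torch import nn

    class PassthroughLayer(nn.Module):
        def __init__(self, attn: nn.Module, attn_ln: nn.Module):
            super().__init__()
            self.attn = attn
            self.attn_ln = attn_ln

        def forward(self, x: torch.Tensor, mask=None, **kwargs):
            # Forward everything (is_pad_mask, position_ids, ...) untouched;
            # the layer no longer needs to know which extras exist.
            return x + self.attn(self.attn_ln(x), mask=mask, **kwargs)[0]
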
@@ -1562,9 +1561,6 @@ class SenseVoiceEncoder(nn.Module):
 
         n_frames = x.size(1)
         max_pos = n_frames
-        # max_pos = self.positional_embedding.size(0)
-        # max_pos = n_frames if n_frames < max_pos else max_pos
-        # x = (x[:, :max_pos, :] + self.positional_embedding[None, :max_pos, :]).to(x.dtype)
 
         if ilens is not None:
             if self.downsample_rate == 4:
@@ -1589,8 +1585,13 @@ class SenseVoiceEncoder(nn.Module):
         else:
             padding_mask = None
 
+        device = x.device
+        seq_length = x.shape[1]
+        position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
+        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+
         for layer, block in enumerate(self.blocks):
-            x = block(x, mask=padding_mask, is_pad_mask=True)
+            x = block(x, mask=padding_mask, position_ids=position_ids)
 
         x = self.ln_post(x)
 
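
The encoder now builds position_ids once, as a (1, seq_length) arange that broadcasts over the batch, and hands it to every block so the RoPE modules can index their cos/sin caches. A small usage sketch with illustrative shapes (the indexing inside the rotary module is assumed, HF-style):

    import torch

    x = torch.randn(2, 100, 512)   # (batch, time, dim), illustrative
    seq_length = x.shape[1]
    position_ids = torch.arange(0, seq_length, dtype=torch.long, device=x.device)
    position_ids = position_ids.unsqueeze(0).view(-1, seq_length)  # (1, seq_length)
    # A typical rotary module then selects cos_cached[position_ids] and
    # sin_cached[position_ids] to rotate q and k at those positions.
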
@@ -1625,7 +1626,6 @@ class SenseVoiceL(nn.Module):
         encoder_conf = kwargs.get("encoder_conf", {})
         encoder_class = tables.encoder_classes.get(encoder)
         encoder = encoder_class(**encoder_conf)
-        encoder_output_size = encoder.output_size()
 
         dims = kwargs.get("dims", {})
         dims = whisper.model.ModelDimensions(**dims)