mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
modify the qformer adaptor (#1804)
Co-authored-by: nichongjia-2007 <nichongjia@gmail.com>
This commit is contained in:
parent
2175736ab0
commit
24af4286d5
@ -51,18 +51,40 @@ class EncoderProjectorQFormer(nn.Module):
|
||||
|
||||
self.linear = nn.Linear(configuration.hidden_size, self.llm_dim)
|
||||
self.norm = nn.LayerNorm(self.llm_dim, eps=1e-5)
|
||||
|
||||
self.second_per_frame = 0.333333
|
||||
self.second_stride = 0.333333
|
||||
|
||||
def forward(self, x, atts):
|
||||
query = self.query.expand(x.shape[0], -1, -1)
|
||||
def split_frames(self, speech_embeds):
|
||||
B, T, C = speech_embeds.shape
|
||||
kernel = round(T * self.second_per_frame / 30.0)
|
||||
stride = round(T * self.second_stride / 30.0)
|
||||
kernel = (1, kernel)
|
||||
stride = (1, stride)
|
||||
speech_embeds_tr = speech_embeds.transpose(1, 2).unsqueeze(2)
|
||||
speech_embeds_overlap = torch.nn.functional.unfold(speech_embeds_tr, kernel_size=kernel, dilation=1, padding=0, stride=stride)
|
||||
_, _, L = speech_embeds_overlap.shape
|
||||
speech_embeds_overlap = speech_embeds_overlap.view(B, -1, kernel[1], L)
|
||||
speech_embeds_overlap = torch.permute(speech_embeds_overlap, [0, 3, 2, 1])
|
||||
speech_embeds = speech_embeds_overlap.reshape(-1, kernel[1], C)
|
||||
speech_atts = torch.ones(speech_embeds.size()[:-1], dtype=torch.long, device=speech_embeds.device)
|
||||
return speech_embeds, speech_atts
|
||||
|
||||
def forward(self, x):
|
||||
B, T, C = x.size()
|
||||
encoder_out_feat, attention_mask = self.split_frames(x)
|
||||
query = self.query.expand(encoder_out_feat.shape[0], -1, -1)
|
||||
|
||||
|
||||
query_output = self.qformer(
|
||||
query_embeds=query,
|
||||
encoder_hidden_states=x,
|
||||
encoder_attention_mask=atts,
|
||||
encoder_hidden_states=encoder_out_feat,
|
||||
encoder_attention_mask=attention_mask,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
query_proj = self.norm(self.linear(query_output.last_hidden_state))
|
||||
query_proj = query_proj.view(B, -1, query_proj.size(2)).contiguous()
|
||||
|
||||
return query_proj
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user