mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
extract
This commit is contained in:
parent
925227e9f1
commit
20280f5db5
@ -59,7 +59,6 @@ def sense_voice_decode_forward(
|
||||
if use_output_layer is True,
|
||||
olens: (batch, )
|
||||
"""
|
||||
time0 = time.perf_counter()
|
||||
# import pdb;pdb.set_trace()
|
||||
use_padmask = self.use_padmask
|
||||
hlens = kwargs.get("hlens", None)
|
||||
@ -91,8 +90,7 @@ def sense_voice_decode_forward(
|
||||
|
||||
x = self.ln(x)
|
||||
x = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
|
||||
time1 = time.perf_counter()
|
||||
print(f"decoder: {time1 - time0:0.3f}")
|
||||
|
||||
return x
|
||||
|
||||
|
||||
|
||||
@ -1559,7 +1559,6 @@ class SenseVoiceEncoder(nn.Module):
|
||||
ilens: torch.Tensor = None,
|
||||
**kwargs,
|
||||
):
|
||||
time0 = time.perf_counter()
|
||||
use_padmask = self.use_padmask
|
||||
x = F.gelu(self.conv1(x))
|
||||
x = F.gelu(self.conv2(x))
|
||||
@ -1600,8 +1599,7 @@ class SenseVoiceEncoder(nn.Module):
|
||||
x = block(x, mask=padding_mask, position_ids=position_ids)
|
||||
|
||||
x = self.ln_post(x)
|
||||
time1 = time.perf_counter()
|
||||
print(f"encoder: {time1 - time0:0.3f}")
|
||||
|
||||
if ilens is None:
|
||||
return x
|
||||
else:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user