FunASR/funasr/models/mossformer/mossformer_decoder.py
zhifu gao 861147c730
Dev gzf exp (#1654)
Co-authored-by: 维石 <shixian.shi@alibaba-inc.com>

2024-04-24 16:03:38 +08:00


import torch
import torch.nn as nn


class MossFormerDecoder(nn.ConvTranspose1d):
    """A decoder layer that consists of ConvTranspose1d.

    Arguments
    ---------
    kernel_size : int
        Length of filters.
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.

    Example
    ---------
    >>> x = torch.randn(2, 100, 1000)
    >>> decoder = MossFormerDecoder(kernel_size=4, in_channels=100, out_channels=1)
    >>> h = decoder(x)
    >>> h.shape
    torch.Size([2, 1003])
    """

    def __init__(self, *args, **kwargs):
        super(MossFormerDecoder, self).__init__(*args, **kwargs)

    def forward(self, x):
        """Return the decoded output.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor with dimensionality [B, N, L].
                where, B = Batchsize,
                       N = number of filters
                       L = time points
        """
        if x.dim() not in [2, 3]:
            # Instances have no __name__ attribute; read it from the class.
            raise RuntimeError(
                "{} accepts a 2D/3D tensor as input".format(type(self).__name__)
            )
        # A 2D input gets a singleton dimension inserted at index 1
        # before the transposed convolution.
        x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))

        # If squeezing everything would leave a single axis, drop only dim 1
        # so the batch axis survives; otherwise drop all singleton axes.
        if torch.squeeze(x).dim() == 1:
            x = torch.squeeze(x, dim=1)
        else:
            x = torch.squeeze(x)
        return x
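
The docstring example reports an output length of 1003 for an input of 1000 time points with `kernel_size=4`. As a minimal sketch, the helper below reproduces that number from the output-length formula given in the PyTorch documentation for `ConvTranspose1d`. The function name `conv_transpose1d_out_len` is hypothetical and not part of FunASR or PyTorch, and the `kernel_size=16, stride=8` call is an illustrative setting only, not a configuration taken from this file.

```python
def conv_transpose1d_out_len(
    l_in, kernel_size, stride=1, padding=0, output_padding=0, dilation=1
):
    """Output length of a 1D transposed convolution.

    Hypothetical helper (not part of FunASR or PyTorch) implementing
    the formula documented for torch.nn.ConvTranspose1d.
    """
    return (
        (l_in - 1) * stride
        - 2 * padding
        + dilation * (kernel_size - 1)
        + output_padding
        + 1
    )


# Docstring case: 1000 time points, kernel_size=4, default stride of 1.
print(conv_transpose1d_out_len(1000, kernel_size=4))  # 1003

# Illustrative setting only: a strided decoder upsamples the time axis.
print(conv_transpose1d_out_len(1000, kernel_size=16, stride=8))  # 8008
```

With `out_channels=1`, the channel axis of the `[B, 1, L_out]` result is squeezed away in `forward`, which is why the docstring shows `torch.Size([2, 1003])` rather than a 3D shape.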