FunASR/funasr/models/whisper_lid/lid_predictor.py
jmwang66 2acd24f015
update whisper lid (#1407)
* update whisper lid
2024-02-29 17:14:59 +08:00

25 lines
946 B
Python

from funasr.register import tables
from funasr.models.whisper_lid.eres2net.ResNet import ERes2Net, BasicBlockERes2Net, BasicBlockERes2Net_diff_AFF
@tables.register("lid_predictor_classes", "LidPredictor")
class LidPredictor(ERes2Net):
    """ERes2Net subclass registered under the name ``LidPredictor``.

    The class adds no behavior of its own: it supplies default constructor
    arguments and forwards all of them to ``ERes2Net.__init__``. The
    decorator registers it in the ``lid_predictor_classes`` table.

    NOTE(review): "lid" presumably stands for language identification,
    judging by the module path -- confirm.
    """

    def __init__(
        self,
        block=BasicBlockERes2Net,
        block_fuse=BasicBlockERes2Net_diff_AFF,
        num_blocks=None,
        m_channels=32,
        feat_dim=80,
        embedding_size=192,
        pooling_func='TSTP',
        two_emb_layer=False,
    ):
        """Build the predictor by delegating to ``ERes2Net.__init__``.

        Every argument is passed through unchanged under the same keyword;
        see ``ERes2Net`` for the meaning of each one.

        Args:
            block: Forwarded as ``block``; defaults to ``BasicBlockERes2Net``.
            block_fuse: Forwarded as ``block_fuse``; defaults to
                ``BasicBlockERes2Net_diff_AFF``.
            num_blocks: Forwarded as ``num_blocks``. ``None`` (the default)
                selects ``[3, 4, 6, 3]``.
            m_channels: Forwarded as ``m_channels``.
            feat_dim: Forwarded as ``feat_dim``.
            embedding_size: Forwarded as ``embedding_size``.
            pooling_func: Forwarded as ``pooling_func``.
            two_emb_layer: Forwarded as ``two_emb_layer``.
        """
        # A list literal in the signature would be created once and shared by
        # every instance; build a fresh one per call instead.
        if num_blocks is None:
            num_blocks = [3, 4, 6, 3]

        super().__init__(
            block=block,
            block_fuse=block_fuse,
            num_blocks=num_blocks,
            m_channels=m_channels,
            feat_dim=feat_dim,
            embedding_size=embedding_size,
            pooling_func=pooling_func,
            two_emb_layer=two_emb_layer,
        )