update cam++ for embed extract

This commit is contained in:
shixian.shi 2024-01-10 19:10:26 +08:00
parent d342c642fa
commit 668b830cb2
8 changed files with 689 additions and 87 deletions

View File

@ -0,0 +1,11 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
from funasr import AutoModel
model = AutoModel(model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common")
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
print(res)

View File

@ -159,6 +159,9 @@ class AutoModel:
tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
kwargs["tokenizer"] = tokenizer
kwargs["token_list"] = tokenizer.token_list
vocab_size = len(tokenizer.token_list)
else:
vocab_size = -1
# build frontend
frontend = kwargs.get("frontend", None)
@ -170,8 +173,7 @@ class AutoModel:
# build model
model_class = tables.model_classes.get(kwargs["model"].lower())
model = model_class(**kwargs, **kwargs["model_conf"],
vocab_size=len(tokenizer.token_list) if tokenizer is not None else -1)
model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
model.eval()
model.to(device)

View File

@ -1 +0,0 @@
from .campplus import CAMPPlus

View File

@ -7,6 +7,82 @@ import torch.utils.checkpoint as cp
from torch import nn
class BasicResBlock(nn.Module):
expansion = 1
def __init__(self, in_planes, planes, stride=1):
super(BasicResBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes,
planes,
kernel_size=3,
stride=(stride, 1),
padding=1,
bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes,
planes,
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes,
self.expansion * planes,
kernel_size=1,
stride=(stride, 1),
bias=False),
nn.BatchNorm2d(self.expansion * planes))
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = F.relu(out)
return out
class FCM(nn.Module):
def __init__(self,
block=BasicResBlock,
num_blocks=[2, 2],
m_channels=32,
feat_dim=80):
super(FCM, self).__init__()
self.in_planes = m_channels
self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(m_channels)
self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
self.layer2 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
self.conv2 = nn.Conv2d(m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(m_channels)
self.out_channels = m_channels * (feat_dim // 8)
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def forward(self, x):
x = x.unsqueeze(1)
out = F.relu(self.bn1(self.conv1(x)))
out = self.layer1(out)
out = self.layer2(out)
out = F.relu(self.bn2(self.conv2(out)))
shape = out.shape
out = out.reshape(shape[0], shape[1] * shape[2], shape[3])
return out
def get_nonlinear(config_str, channels):
nonlinear = nn.Sequential()
for name in config_str.split('-'):
@ -216,39 +292,3 @@ class DenseLayer(nn.Module):
return x
class BasicResBlock(nn.Module):
expansion = 1
def __init__(self, in_planes, planes, stride=1):
super(BasicResBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes,
planes,
kernel_size=3,
stride=(stride, 1),
padding=1,
bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes,
planes,
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes,
self.expansion * planes,
kernel_size=1,
stride=(stride, 1),
bias=False),
nn.BatchNorm2d(self.expansion * planes))
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = F.relu(out)
return out

View File

@ -1,54 +1,24 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import os
import time
import torch
import logging
import numpy as np
import torch.nn as nn
from collections import OrderedDict
from typing import Union, Dict, List, Tuple, Optional
import torch.nn.functional as F
from torch import nn
from funasr.models.campplus.layers import DenseLayer, StatsPool, TDNNLayer, CAMDenseTDNNBlock, TransitLayer, \
BasicResBlock, get_nonlinear
class FCM(nn.Module):
def __init__(self,
block=BasicResBlock,
num_blocks=[2, 2],
m_channels=32,
feat_dim=80):
super(FCM, self).__init__()
self.in_planes = m_channels
self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(m_channels)
self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
self.layer2 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
self.conv2 = nn.Conv2d(m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(m_channels)
self.out_channels = m_channels * (feat_dim // 8)
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def forward(self, x):
x = x.unsqueeze(1)
out = F.relu(self.bn1(self.conv1(x)))
out = self.layer1(out)
out = self.layer2(out)
out = F.relu(self.bn2(self.conv2(out)))
shape = out.shape
out = out.reshape(shape[0], shape[1] * shape[2], shape[3])
return out
from funasr.utils.load_utils import load_audio_text_image_video
from funasr.utils.datadir_writer import DatadirWriter
from funasr.register import tables
from funasr.models.campplus.components import DenseLayer, StatsPool, TDNNLayer, CAMDenseTDNNBlock, TransitLayer, \
BasicResBlock, get_nonlinear, FCM
from funasr.models.campplus.utils import extract_feature
@tables.register("model_classes", "CAMPPlus")
class CAMPPlus(nn.Module):
def __init__(self,
feat_dim=80,
@ -58,8 +28,9 @@ class CAMPPlus(nn.Module):
init_channels=128,
config_str='batchnorm-relu',
memory_efficient=True,
output_level='segment'):
super(CAMPPlus, self).__init__()
output_level='segment',
**kwargs,):
super().__init__()
self.head = FCM(feat_dim=feat_dim)
channels = self.head.out_channels
@ -123,3 +94,28 @@ class CAMPPlus(nn.Module):
if self.output_level == 'frame':
x = x.transpose(1, 2)
return x
def generate(self,
data_in,
data_lengths=None,
key: list=None,
tokenizer=None,
frontend=None,
**kwargs,
):
# extract fbank feats
meta_data = {}
time1 = time.perf_counter()
audio_sample_list = load_audio_text_image_video(data_in, fs=16000, audio_fs=kwargs.get("fs", 16000), data_type="sound")
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths = extract_feature(audio_sample_list)
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
meta_data["batch_data_time"] = np.array(speech_lengths).sum().item() / 16000.0
# import pdb; pdb.set_trace()
results = []
embeddings = self.forward(speech)
for embedding in embeddings:
results.append({"spk_embedding":embedding})
return results, meta_data

View File

@ -0,0 +1,23 @@
# This is an example that demonstrates how to configure a model file.
# You can modify the configuration according to your own requirements.
# to print the register_table:
# from funasr.register import tables
# tables.print()
# network architecture
model: CAMPPlus
model_conf:
feat_dim: 80
embedding_size: 192
growth_rate: 32
bn_size: 4
init_channels: 128
config_str: 'batchnorm-relu'
memory_efficient: True
output_level: 'segment'
# frontend related
frontend: WavFrontend
frontend_conf:
fs: 16000

View File

@ -0,0 +1,533 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import io
from typing import Union
import librosa as sf
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio.compliance.kaldi as Kaldi
from torch import nn
import contextlib
import os
import tempfile
from abc import ABCMeta, abstractmethod
from pathlib import Path
from typing import Generator, Union
import requests
def check_audio_list(audio: list):
audio_dur = 0
for i in range(len(audio)):
seg = audio[i]
assert seg[1] >= seg[0], 'modelscope error: Wrong time stamps.'
assert isinstance(seg[2], np.ndarray), 'modelscope error: Wrong data type.'
assert int(seg[1] * 16000) - int(
seg[0] * 16000
) == seg[2].shape[
0], 'modelscope error: audio data in list is inconsistent with time length.'
if i > 0:
assert seg[0] >= audio[
i - 1][1], 'modelscope error: Wrong time stamps.'
audio_dur += seg[1] - seg[0]
return audio_dur
# assert audio_dur > 5, 'modelscope error: The effective audio duration is too short.'
def sv_preprocess(inputs: Union[np.ndarray, list]):
output = []
for i in range(len(inputs)):
if isinstance(inputs[i], str):
file_bytes = File.read(inputs[i])
data, fs = sf.load(io.BytesIO(file_bytes), dtype='float32')
if len(data.shape) == 2:
data = data[:, 0]
data = torch.from_numpy(data).unsqueeze(0)
data = data.squeeze(0)
elif isinstance(inputs[i], np.ndarray):
assert len(
inputs[i].shape
) == 1, 'modelscope error: Input array should be [N, T]'
data = inputs[i]
if data.dtype in ['int16', 'int32', 'int64']:
data = (data / (1 << 15)).astype('float32')
else:
data = data.astype('float32')
data = torch.from_numpy(data)
else:
raise ValueError(
'modelscope error: The input type is restricted to audio address and nump array.'
)
output.append(data)
return output
def sv_chunk(vad_segments: list, fs = 16000) -> list:
config = {
'seg_dur': 1.5,
'seg_shift': 0.75,
}
def seg_chunk(seg_data):
seg_st = seg_data[0]
data = seg_data[2]
chunk_len = int(config['seg_dur'] * fs)
chunk_shift = int(config['seg_shift'] * fs)
last_chunk_ed = 0
seg_res = []
for chunk_st in range(0, data.shape[0], chunk_shift):
chunk_ed = min(chunk_st + chunk_len, data.shape[0])
if chunk_ed <= last_chunk_ed:
break
last_chunk_ed = chunk_ed
chunk_st = max(0, chunk_ed - chunk_len)
chunk_data = data[chunk_st:chunk_ed]
if chunk_data.shape[0] < chunk_len:
chunk_data = np.pad(chunk_data,
(0, chunk_len - chunk_data.shape[0]),
'constant')
seg_res.append([
chunk_st / fs + seg_st, chunk_ed / fs + seg_st,
chunk_data
])
return seg_res
segs = []
for i, s in enumerate(vad_segments):
segs.extend(seg_chunk(s))
return segs
def extract_feature(audio):
features = []
feature_lengths = []
for au in audio:
feature = Kaldi.fbank(
au.unsqueeze(0), num_mel_bins=80)
feature = feature - feature.mean(dim=0, keepdim=True)
features.append(feature.unsqueeze(0))
feature_lengths.append(au.shape[0])
features = torch.cat(features)
return features, feature_lengths
def postprocess(segments: list, vad_segments: list,
labels: np.ndarray, embeddings: np.ndarray) -> list:
assert len(segments) == len(labels)
labels = correct_labels(labels)
distribute_res = []
for i in range(len(segments)):
distribute_res.append([segments[i][0], segments[i][1], labels[i]])
# merge the same speakers chronologically
distribute_res = merge_seque(distribute_res)
# accquire speaker center
spk_embs = []
for i in range(labels.max() + 1):
spk_emb = embeddings[labels == i].mean(0)
spk_embs.append(spk_emb)
spk_embs = np.stack(spk_embs)
def is_overlapped(t1, t2):
if t1 > t2 + 1e-4:
return True
return False
# distribute the overlap region
for i in range(1, len(distribute_res)):
if is_overlapped(distribute_res[i - 1][1], distribute_res[i][0]):
p = (distribute_res[i][0] + distribute_res[i - 1][1]) / 2
distribute_res[i][0] = p
distribute_res[i - 1][1] = p
# smooth the result
distribute_res = smooth(distribute_res)
return distribute_res
def correct_labels(labels):
labels_id = 0
id2id = {}
new_labels = []
for i in labels:
if i not in id2id:
id2id[i] = labels_id
labels_id += 1
new_labels.append(id2id[i])
return np.array(new_labels)
def merge_seque(distribute_res):
res = [distribute_res[0]]
for i in range(1, len(distribute_res)):
if distribute_res[i][2] != res[-1][2] or distribute_res[i][
0] > res[-1][1]:
res.append(distribute_res[i])
else:
res[-1][1] = distribute_res[i][1]
return res
def smooth(res, mindur=1):
# short segments are assigned to nearest speakers.
for i in range(len(res)):
res[i][0] = round(res[i][0], 2)
res[i][1] = round(res[i][1], 2)
if res[i][1] - res[i][0] < mindur:
if i == 0:
res[i][2] = res[i + 1][2]
elif i == len(res) - 1:
res[i][2] = res[i - 1][2]
elif res[i][0] - res[i - 1][1] <= res[i + 1][0] - res[i][1]:
res[i][2] = res[i - 1][2]
else:
res[i][2] = res[i + 1][2]
# merge the speakers
res = merge_seque(res)
return res
def distribute_spk(sentence_list, sd_time_list):
sd_sentence_list = []
for d in sentence_list:
sentence_start = d['ts_list'][0][0]
sentence_end = d['ts_list'][-1][1]
sentence_spk = 0
max_overlap = 0
for sd_time in sd_time_list:
spk_st, spk_ed, spk = sd_time
spk_st = spk_st*1000
spk_ed = spk_ed*1000
overlap = max(
min(sentence_end, spk_ed) - max(sentence_start, spk_st), 0)
if overlap > max_overlap:
max_overlap = overlap
sentence_spk = spk
d['spk'] = sentence_spk
sd_sentence_list.append(d)
return sd_sentence_list
class Storage(metaclass=ABCMeta):
"""Abstract class of storage.
All backends need to implement two apis: ``read()`` and ``read_text()``.
``read()`` reads the file as a byte stream and ``read_text()`` reads
the file as texts.
"""
@abstractmethod
def read(self, filepath: str):
pass
@abstractmethod
def read_text(self, filepath: str):
pass
@abstractmethod
def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
pass
@abstractmethod
def write_text(self,
obj: str,
filepath: Union[str, Path],
encoding: str = 'utf-8') -> None:
pass
class LocalStorage(Storage):
"""Local hard disk storage"""
def read(self, filepath: Union[str, Path]) -> bytes:
"""Read data from a given ``filepath`` with 'rb' mode.
Args:
filepath (str or Path): Path to read data.
Returns:
bytes: Expected bytes object.
"""
with open(filepath, 'rb') as f:
content = f.read()
return content
def read_text(self,
filepath: Union[str, Path],
encoding: str = 'utf-8') -> str:
"""Read data from a given ``filepath`` with 'r' mode.
Args:
filepath (str or Path): Path to read data.
encoding (str): The encoding format used to open the ``filepath``.
Default: 'utf-8'.
Returns:
str: Expected text reading from ``filepath``.
"""
with open(filepath, 'r', encoding=encoding) as f:
value_buf = f.read()
return value_buf
def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
"""Write data to a given ``filepath`` with 'wb' mode.
Note:
``write`` will create a directory if the directory of ``filepath``
does not exist.
Args:
obj (bytes): Data to be written.
filepath (str or Path): Path to write data.
"""
dirname = os.path.dirname(filepath)
if dirname and not os.path.exists(dirname):
os.makedirs(dirname, exist_ok=True)
with open(filepath, 'wb') as f:
f.write(obj)
def write_text(self,
obj: str,
filepath: Union[str, Path],
encoding: str = 'utf-8') -> None:
"""Write data to a given ``filepath`` with 'w' mode.
Note:
``write_text`` will create a directory if the directory of
``filepath`` does not exist.
Args:
obj (str): Data to be written.
filepath (str or Path): Path to write data.
encoding (str): The encoding format used to open the ``filepath``.
Default: 'utf-8'.
"""
dirname = os.path.dirname(filepath)
if dirname and not os.path.exists(dirname):
os.makedirs(dirname, exist_ok=True)
with open(filepath, 'w', encoding=encoding) as f:
f.write(obj)
@contextlib.contextmanager
def as_local_path(
self,
filepath: Union[str,
Path]) -> Generator[Union[str, Path], None, None]:
"""Only for unified API and do nothing."""
yield filepath
class HTTPStorage(Storage):
"""HTTP and HTTPS storage."""
def read(self, url):
# TODO @wenmeng.zwm add progress bar if file is too large
r = requests.get(url)
r.raise_for_status()
return r.content
def read_text(self, url):
r = requests.get(url)
r.raise_for_status()
return r.text
@contextlib.contextmanager
def as_local_path(
self, filepath: str) -> Generator[Union[str, Path], None, None]:
"""Download a file from ``filepath``.
``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
can be called with ``with`` statement, and when exists from the
``with`` statement, the temporary path will be released.
Args:
filepath (str): Download a file from ``filepath``.
Examples:
>>> storage = HTTPStorage()
>>> # After existing from the ``with`` clause,
>>> # the path will be removed
>>> with storage.get_local_path('http://path/to/file') as path:
... # do something here
"""
try:
f = tempfile.NamedTemporaryFile(delete=False)
f.write(self.read(filepath))
f.close()
yield f.name
finally:
os.remove(f.name)
def write(self, obj: bytes, url: Union[str, Path]) -> None:
raise NotImplementedError('write is not supported by HTTP Storage')
def write_text(self,
obj: str,
url: Union[str, Path],
encoding: str = 'utf-8') -> None:
raise NotImplementedError(
'write_text is not supported by HTTP Storage')
class OSSStorage(Storage):
"""OSS storage."""
def __init__(self, oss_config_file=None):
# read from config file or env var
raise NotImplementedError(
'OSSStorage.__init__ to be implemented in the future')
def read(self, filepath):
raise NotImplementedError(
'OSSStorage.read to be implemented in the future')
def read_text(self, filepath, encoding='utf-8'):
raise NotImplementedError(
'OSSStorage.read_text to be implemented in the future')
@contextlib.contextmanager
def as_local_path(
self, filepath: str) -> Generator[Union[str, Path], None, None]:
"""Download a file from ``filepath``.
``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
can be called with ``with`` statement, and when exists from the
``with`` statement, the temporary path will be released.
Args:
filepath (str): Download a file from ``filepath``.
Examples:
>>> storage = OSSStorage()
>>> # After existing from the ``with`` clause,
>>> # the path will be removed
>>> with storage.get_local_path('http://path/to/file') as path:
... # do something here
"""
try:
f = tempfile.NamedTemporaryFile(delete=False)
f.write(self.read(filepath))
f.close()
yield f.name
finally:
os.remove(f.name)
def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
raise NotImplementedError(
'OSSStorage.write to be implemented in the future')
def write_text(self,
obj: str,
filepath: Union[str, Path],
encoding: str = 'utf-8') -> None:
raise NotImplementedError(
'OSSStorage.write_text to be implemented in the future')
G_STORAGES = {}
class File(object):
_prefix_to_storage: dict = {
'oss': OSSStorage,
'http': HTTPStorage,
'https': HTTPStorage,
'local': LocalStorage,
}
@staticmethod
def _get_storage(uri):
assert isinstance(uri,
str), f'uri should be str type, but got {type(uri)}'
if '://' not in uri:
# local path
storage_type = 'local'
else:
prefix, _ = uri.split('://')
storage_type = prefix
assert storage_type in File._prefix_to_storage, \
f'Unsupported uri {uri}, valid prefixs: '\
f'{list(File._prefix_to_storage.keys())}'
if storage_type not in G_STORAGES:
G_STORAGES[storage_type] = File._prefix_to_storage[storage_type]()
return G_STORAGES[storage_type]
@staticmethod
def read(uri: str) -> bytes:
"""Read data from a given ``filepath`` with 'rb' mode.
Args:
filepath (str or Path): Path to read data.
Returns:
bytes: Expected bytes object.
"""
storage = File._get_storage(uri)
return storage.read(uri)
@staticmethod
def read_text(uri: Union[str, Path], encoding: str = 'utf-8') -> str:
"""Read data from a given ``filepath`` with 'r' mode.
Args:
filepath (str or Path): Path to read data.
encoding (str): The encoding format used to open the ``filepath``.
Default: 'utf-8'.
Returns:
str: Expected text reading from ``filepath``.
"""
storage = File._get_storage(uri)
return storage.read_text(uri)
@staticmethod
def write(obj: bytes, uri: Union[str, Path]) -> None:
"""Write data to a given ``filepath`` with 'wb' mode.
Note:
``write`` will create a directory if the directory of ``filepath``
does not exist.
Args:
obj (bytes): Data to be written.
filepath (str or Path): Path to write data.
"""
storage = File._get_storage(uri)
return storage.write(obj, uri)
@staticmethod
def write_text(obj: str, uri: str, encoding: str = 'utf-8') -> None:
"""Write data to a given ``filepath`` with 'w' mode.
Note:
``write_text`` will create a directory if the directory of
``filepath`` does not exist.
Args:
obj (str): Data to be written.
filepath (str or Path): Path to write data.
encoding (str): The encoding format used to open the ``filepath``.
Default: 'utf-8'.
"""
storage = File._get_storage(uri)
return storage.write_text(obj, uri)
@contextlib.contextmanager
def as_local_path(uri: str) -> Generator[Union[str, Path], None, None]:
"""Only for unified API and do nothing."""
storage = File._get_storage(uri)
with storage.as_local_path(uri) as local_path:
yield local_path

View File

@ -447,7 +447,6 @@ class Paraformer(nn.Module):
frontend=None,
**kwargs,
):
# init beamsearch
is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
@ -475,7 +474,6 @@ class Paraformer(nn.Module):
meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
# Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
if isinstance(encoder_out, tuple):