FunASR/funasr/models/eend/eend_ola_dataloader.py
zhifu gao 861147c730
Dev gzf exp (#1654)
* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* bugfix

* update with main (#1631)

* update seaco finetune

* v1.0.24

---------

Co-authored-by: 维石 <shixian.shi@alibaba-inc.com>

* sensevoice

* sensevoice

* sensevoice

* update with main (#1638)

* update seaco finetune

* v1.0.24

* update rwkv template

---------

Co-authored-by: 维石 <shixian.shi@alibaba-inc.com>

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* sense voice

* whisper

* whisper

* update style

* update style

---------

Co-authored-by: 维石 <shixian.shi@alibaba-inc.com>
2024-04-24 16:03:38 +08:00

58 lines
1.7 KiB
Python

import logging
import kaldiio
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
def custom_collate(batch):
    """Collate ``(key, speech, speaker_label, order)`` samples into one batch.

    Samples are not padded or stacked; every field is kept as a list with one
    tensor per sample.

    Args:
        batch: Sequence of 4-tuples ``(key, speech, speaker_label, order)``
            whose last three items are array-like.

    Returns:
        A ``(keys, batch)`` pair: ``keys`` is a tuple of the sample keys and
        ``batch`` is a dict with the entries ``speech`` (float32 tensors),
        ``speaker_labels`` (float32 tensors) and ``orders`` (int64 tensors).
    """

    def _to_tensor(array, dtype):
        # np.copy yields a fresh array, so the tensor does not alias the input.
        return torch.from_numpy(np.copy(array)).to(dtype)

    keys, raw_speech, raw_labels, raw_orders = zip(*batch)
    batch = {
        "speech": [_to_tensor(item, torch.float32) for item in raw_speech],
        "speaker_labels": [_to_tensor(item, torch.float32) for item in raw_labels],
        "orders": [_to_tensor(item, torch.int64) for item in raw_orders],
    }
    return keys, batch
class EENDOLADataset(Dataset):
    """Map-style dataset backed by a whitespace-separated list file.

    Every non-blank line of ``data_file`` is split on whitespace and stored as
    one sample. ``__getitem__`` unpacks a sample into exactly three fields:
    key, speech path and speaker-label path.
    """

    def __init__(
        self,
        data_file,
    ):
        """Read the sample list from ``data_file``.

        Args:
            data_file: Path of the text file listing the samples, one per line.
        """
        self.data_file = data_file
        with open(data_file) as f:
            # Skip blank / whitespace-only lines (e.g. a trailing empty line):
            # they would otherwise become empty samples that inflate __len__
            # and break the three-way unpack in __getitem__.
            self.samples = [line.split() for line in f if line.strip()]
        logging.info("total samples: {}".format(len(self.samples)))

    def __len__(self):
        """Return the number of samples read from the list file."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load the sample at position ``idx``.

        Returns:
            A tuple ``(key, speech, speaker_label, order)``. ``speech`` and
            ``speaker_label`` are loaded with ``kaldiio.load_mat``;
            ``speaker_label`` is reshaped so that its first dimension equals
            ``speech.shape[0]``. ``order`` is a random permutation of
            ``range(speech.shape[0])`` drawn from the global NumPy RNG.

        Raises:
            ValueError: If the stored sample does not have exactly three
                fields.
        """
        key, speech_path, speaker_label_path = self.samples[idx]
        speech = kaldiio.load_mat(speech_path)
        speaker_label = kaldiio.load_mat(speaker_label_path).reshape(speech.shape[0], -1)
        order = np.arange(speech.shape[0])
        np.random.shuffle(order)
        return key, speech, speaker_label, order
class EENDOLADataLoader:
    """Wrapper that builds a ``DataLoader`` over an ``EENDOLADataset``."""

    def __init__(self, data_file, batch_size, shuffle=True, num_workers=8):
        """Create the dataset for ``data_file`` and wrap it in a ``DataLoader``.

        Args:
            data_file: Path handed to ``EENDOLADataset``.
            batch_size: Forwarded to ``DataLoader`` as ``batch_size``.
            shuffle: Forwarded to ``DataLoader`` as ``shuffle``.
            num_workers: Forwarded to ``DataLoader`` as ``num_workers``.
        """
        source = EENDOLADataset(data_file)
        loader_options = {
            "batch_size": batch_size,
            "collate_fn": custom_collate,
            "shuffle": shuffle,
            "num_workers": num_workers,
        }
        self.data_loader = DataLoader(source, **loader_options)

    def build_iter(self, epoch):
        """Return the wrapped ``DataLoader``.

        ``epoch`` is accepted but not used; the same loader object is returned
        on every call.
        """
        loader = self.data_loader
        return loader