diff --git a/funasr/bin/asr_trainer.py b/funasr/bin/asr_trainer.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/funasr/datasets/data_sampler.py b/funasr/datasets/data_sampler.py
new file mode 100644
index 000000000..2875d8d41
--- /dev/null
+++ b/funasr/datasets/data_sampler.py
@@ -0,0 +1,60 @@
+import torch
+
+class BatchSampler(torch.utils.data.BatchSampler):
+
+    def __init__(self, dataset=None, args=None, drop_last=True, ):
+
+        self.drop_last = drop_last
+        self.pre_idx = -1
+        self.dataset = dataset
+        self.batch_size_type = args.batch_size_type
+        self.batch_size = args.batch_size
+        self.sort_size = args.sort_size  # number of samples sorted together by length
+        self.max_length_token = args.max_length_token
+        self.total_samples = len(dataset)
+
+
+    def __len__(self):
+        return self.total_samples
+
+
+    def __iter__(self):
+        batch = []
+        max_token = 0
+        num_sample = 0
+
+        iter_num = (self.total_samples - 1) // self.sort_size + 1  # number of sort windows
+        for iter in range(self.pre_idx + 1, iter_num):
+            datalen_with_index = []
+            for i in range(self.sort_size):
+                idx = iter * self.sort_size + i
+                if idx >= self.total_samples:
+                    continue
+
+                if self.batch_size_type == "example":  # batch by sample count instead of token count
+                    sample_len_cur = 1
+                else:
+                    idx_map = self.dataset.shuffle_idx[idx]
+                    # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
+                    sample_len_cur = self.dataset.indexed_dataset[idx_map]["source_len"] + \
+                                     self.dataset.indexed_dataset[idx_map]["target_len"]
+
+                datalen_with_index.append([idx, sample_len_cur])
+
+            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
+            for item in datalen_with_index_sort:
+                idx, sample_len_cur = item
+                if sample_len_cur > self.max_length_token:
+                    continue
+                max_token_cur = max(max_token, sample_len_cur)
+                max_token_padding = (1 + num_sample) * max_token_cur  # padded size if this sample joins the batch
+                if max_token_padding <= self.batch_size:
+                    batch.append(idx)
+                    max_token = max_token_cur
+                    num_sample += 1
+                else:
+                    yield batch
+                    max_token = sample_len_cur
+                    num_sample = 1
+                    batch = [idx]
+
\ No newline at end of file
diff --git a/funasr/datasets/dataset_jsonl.py b/funasr/datasets/dataset_jsonl.py
new file mode 100644
index 000000000..283fbd976
--- /dev/null
+++ b/funasr/datasets/dataset_jsonl.py
@@ -0,0 +1,43 @@
+import torch
+import json
+import torch.distributed as dist
+
+class AudioDatasetJsonl(torch.utils.data.Dataset):
+
+    def __init__(self, path, data_parallel_rank=0, data_parallel_size=1):
+        super().__init__()
+        data_parallel_size = dist.get_world_size()  # overrides the argument with the actual world size
+        contents = []
+        with open(path, encoding='utf-8') as fin:
+            for line in fin:
+                data = json.loads(line.strip())
+                if "text" in data:  # for sft
+                    contents.append(data['text'])  # local list; self.contents is assigned after rank sharding
+                if "source" in data:  # for speech lab pretrain
+                    prompt = data["prompt"]
+                    source = data["source"]
+                    target = data["target"]
+                    source_len = data["source_len"]
+                    target_len = data["target_len"]
+
+                    contents.append({"source": source,
+                                     "prompt": prompt,
+                                     "target": target,
+                                     "source_len": source_len,
+                                     "target_len": target_len,
+                                     }
+                                    )
+
+        self.contents = []
+        total_num = len(contents)
+        num_per_rank = total_num // data_parallel_size
+        rank = dist.get_rank()
+        # import ipdb; ipdb.set_trace()
+        self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank]
+
+
+    def __len__(self):
+        return len(self.contents)
+
+    def __getitem__(self, index):
+        return self.contents[index]
diff --git a/funasr/models/frontend/s3prl.py b/funasr/models/frontend/s3prl.py
index fdeb1c576..00a997063 100644
--- a/funasr/models/frontend/s3prl.py
+++ b/funasr/models/frontend/s3prl.py
@@ -10,7 +10,7 @@
 import humanfriendly
 import torch
 from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.modules.frontends.frontend import Frontend
+from funasr.models.frontend.frontends_utils.frontend import Frontend
 from funasr.modules.nets_utils import pad_list
 from funasr.utils.get_default_kwargs import get_default_kwargs
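
Usage note (not part of the patch): below is a minimal sketch of how AudioDatasetJsonl and BatchSampler might be wired into a standard torch DataLoader. The args namespace, the train.jsonl path, and the collate_fn are illustrative assumptions. batch_size_type="example" is used because the token-based branch relies on shuffle_idx and indexed_dataset attributes that AudioDatasetJsonl does not provide, and torch.distributed must already be initialised since the dataset constructor calls dist.get_world_size() and dist.get_rank().

    # Illustrative sketch only; not part of the diff above.
    from types import SimpleNamespace

    import torch
    import torch.distributed as dist

    from funasr.datasets.data_sampler import BatchSampler
    from funasr.datasets.dataset_jsonl import AudioDatasetJsonl

    # Single-process setup so dist.get_world_size()/dist.get_rank() work (assumption).
    dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500",
                            world_size=1, rank=0)

    # Hypothetical hyper-parameters; only these four fields are read by BatchSampler.
    args = SimpleNamespace(
        batch_size_type="example",  # count samples per batch instead of padded tokens
        batch_size=16,              # max samples per batch in "example" mode
        sort_size=1024,             # window of samples considered together
        max_length_token=2048,      # samples longer than this are skipped
    )

    dataset = AudioDatasetJsonl("train.jsonl")  # hypothetical jsonl path
    batch_sampler = BatchSampler(dataset=dataset, args=args)

    def collate_fn(samples):
        # Placeholder: real training code would pad/stack audio and token tensors here.
        return samples

    loader = torch.utils.data.DataLoader(dataset,
                                         batch_sampler=batch_sampler,
                                         collate_fn=collate_fn)

    for batch in loader:
        pass  # each batch is a list of raw jsonl records, up to args.batch_size long

One behaviour worth noting: BatchSampler.__iter__ never yields the final accumulated batch, so the last few samples of each epoch are dropped regardless of the drop_last flag, and __len__ returns the number of samples rather than the number of batches.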