diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index ab49c822c..c02a66ff7 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -176,15 +176,12 @@ def main(**kwargs):
     except:
         writer = None
 
-    # if use_ddp or use_fsdp:
-    #     context = Join([model])
-    # else:
-    #     context = nullcontext()
-    context = nullcontext()
+
     for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
         time1 = time.perf_counter()
-        with context:
-            dataloader_tr, dataloader_val = dataloader.build_iter(epoch)
+
+        for data_split_i in range(dataloader.data_split_num):
+            dataloader_tr, dataloader_val = dataloader.build_iter(epoch, data_split_i=data_split_i)
             trainer.train_epoch(
                 model=model,
                 optim=optim,
@@ -193,15 +190,17 @@ def main(**kwargs):
                 dataloader_train=dataloader_tr,
                 dataloader_val=dataloader_val,
                 epoch=epoch,
-                writer=writer
+                writer=writer,
+                data_split_i=data_split_i,
+                data_split_num=dataloader.data_split_num,
             )
-        with context:
-            trainer.validate_epoch(
-                model=model,
-                dataloader_val=dataloader_val,
-                epoch=epoch,
-                writer=writer
-            )
+
+        trainer.validate_epoch(
+            model=model,
+            dataloader_val=dataloader_val,
+            epoch=epoch,
+            writer=writer
+        )
 
         scheduler.step()
 
diff --git a/funasr/datasets/audio_datasets/index_ds.py b/funasr/datasets/audio_datasets/index_ds.py
index 3270531f5..de0d653ae 100644
--- a/funasr/datasets/audio_datasets/index_ds.py
+++ b/funasr/datasets/audio_datasets/index_ds.py
@@ -2,8 +2,9 @@ import os
 import json
 import torch
 import logging
-import concurrent.futures
+
 import librosa
+import random
 import torch.distributed as dist
 
 from funasr.register import tables
@@ -44,7 +45,7 @@ from funasr.register import tables
 #     except:
 #         rank = 0
 #         world_size = 1
-#         logging.warning("distributed is not initialized, only single shard")
+#         logging.info("distributed is not initialized, only single shard")
 #     num_per_rank = total_num // world_size
 #
 #     # rank = 0
@@ -72,6 +73,7 @@ from funasr.register import tables
 
 @tables.register("index_ds_classes", "IndexDSJsonl")
 @tables.register("index_ds_classes", "IndexDSJsonlRankFull")
+@tables.register("index_ds_classes", "IndexDSJsonlRankSplit")
 class IndexDSJsonlRankFull(torch.utils.data.Dataset):
 
     def __init__(self, path: str, **kwargs):
@@ -80,83 +82,27 @@ class IndexDSJsonlRankFull(torch.utils.data.Dataset):
         self.min_source_length = kwargs.get("min_source_length", 0)
         self.max_target_length = kwargs.get("max_target_length", 2048)
         self.min_target_length = kwargs.get("min_target_length", 0)
-        if isinstance(path, (list, tuple)):  # wav.scp, text.txt/text.trans
-            from funasr.datasets.audio_datasets.scp2jsonl import gen_jsonl_from_wav_text_list
-            jsonl_outdir = os.path.dirname(path[0])
-            jsonl_name = "datalist_train.jsonl" if kwargs.get("is_training", True) else "datalist_val.jsonl"
-            jsonl_file_out = os.path.join(jsonl_outdir, jsonl_name)
-            if not os.path.exists(jsonl_file_out):
-                print(f"datalist is: {path}, generate jsonl from it")
-                gen_jsonl_from_wav_text_list(path, jsonl_file_out=jsonl_file_out, **kwargs)
-            path = jsonl_file_out
-        contents = []
-        with open(path, encoding='utf-8') as fin:
-            for line in fin:
-                data = json.loads(line.strip())
-                if "text" in data:  # for sft
-                    contents.append(data['text'])
-                if "source" in data:  # for speech lab pretrain
-                    prompt = data.get("prompt", "")
-                    source = data["source"]
-                    target = data["target"]
-                    source_len = data.get("source_len", 1)
-                    target_len = data.get("target_len", 0)
-                    if "aishell" in source:
-                        target = target.replace(" ", "")
-                    if source_len < self.min_source_length or source_len > self.max_source_length:
-                        continue
-                    if target_len < self.min_target_length or target_len > self.max_target_length:
-                        continue
-                    contents_i = {"source": source,
-                                  "prompt": prompt,
-                                  "target": target,
-                                  "source_len": source_len,
-                                  "target_len": target_len,
-                                  }
-                    text_language = data.get("text_language", None)
-                    if text_language is not None:
-                        contents_i["text_language"] = text_language
-                    audio_language = data.get("audio_language", None)
-                    if audio_language is not None:
-                        contents_i["audio_language"] = audio_language
-                    contents.append(contents_i)
-
-        self.contents = contents
+        is_training = kwargs.get("is_training", True)
+        if not (path.endswith(".jsonl") or path.endswith(".json")):
+            # jsonl list file
+            data_split_num = kwargs.get("data_split_num", 1)
+            data_split_i = kwargs.get("data_split_i", 0)
+
+            if not is_training:
+                data_split_num = 1
+                data_split_i = 0
+            with open(path, encoding='utf-8') as fin:
+                file_list_all = fin.readlines()
+
+            num_per_slice = len(file_list_all) // data_split_num
+            file_list = file_list_all[data_split_i * num_per_slice:(data_split_i + 1) * num_per_slice]
+            logging.info(
+                f"is_training: {is_training}, data_split_num: {data_split_num}, data_split_i: {data_split_i}, \nfile_list: {file_list}, \nfile_list_all: {file_list_all}")
 
-        logging.info(
-            "total_num of samplers across ranks: {}".format(len(self.contents)))
-
-    def __len__(self):
-        return len(self.contents)
-
-    def __getitem__(self, index):
-        try:
-            data = self.contents[index]
-        except:
-            print(index)
-        return data
-
-    def get_source_len(self, data_dict):
-        return data_dict.get("source_len", 1)
-
-    def get_target_len(self, data_dict):
-
-        return data_dict.get("target_len", 0)
-
-
-@tables.register("index_ds_classes", "IndexDSJsonlRankSplit")
-class IndexDSJsonlRankSplit(torch.utils.data.Dataset):
-
-    def __init__(self, path: str, **kwargs):
-        super().__init__()
-        self.max_source_length = kwargs.get("max_source_length", 2048)
-        self.min_source_length = kwargs.get("min_source_length", 0)
-        self.max_target_length = kwargs.get("max_target_length", 2048)
-        self.min_target_length = kwargs.get("min_target_length", 0)
-
-        with open(path, encoding='utf-8') as fin:
-            file_list = fin.readlines()
+        else:
+            file_list = [path]
+
         total_num = len(file_list)
 
         try:
@@ -165,16 +111,31 @@ class IndexDSJsonlRankSplit(torch.utils.data.Dataset):
         except:
             rank = 0
             world_size = 1
-            logging.warning("distributed is not initialized, only single shard")
+            logging.info("distributed is not initialized, only single shard")
+
+        if not kwargs.get("rank_split", False):
+            logging.info("rank_split is disabled; batching and shuffling data globally")
+            rank = 0
+            world_size = 1
+
         num_per_rank = total_num // world_size
         if num_per_rank * world_size < total_num:
-            logging.warning(f"Warning, jsonl file:{total_num} could not be divided by world_size: {world_size}, {path}")
+            logging.info(f"jsonl file count {total_num} is not divisible by world_size {world_size}: {path}")
+            # round num_per_rank up and pad the file list with randomly
+            # re-sampled entries so that it divides evenly across ranks
+            num_per_rank += 1
+            total_num_needed = num_per_rank * world_size
+            extra_num = total_num_needed - total_num
+            file_list_tmp = random.choices(file_list, k=extra_num)
+            file_list += file_list_tmp
+            logging.info(f"after padding with random choices: {file_list}")
 
         file_list_rank = file_list[rank * num_per_rank:(rank + 1) * num_per_rank]
+        logging.info(
+            f"is_training: {is_training}, file_list_rank: {file_list_rank}")
+
         contents = []
         for file_json in file_list_rank:
-
             with open(file_json.strip(), encoding='utf-8') as fin:
                 for line in fin:
                     data = json.loads(line.strip())
@@ -182,41 +142,42 @@ class IndexDSJsonlRankSplit(torch.utils.data.Dataset):
                         contents.append(data['text'])
                     if "source" in data:  # for speech lab pretrain
                         prompt = data.get("prompt", "")
-                        source = data["source"].replace("/cpfs01", "/cpfs_speech/data")
+                        source = data["source"].replace("/cpfs01", "/cpfs_speech/data")  # only used on the Alibaba GPU cluster: .replace("/cpfs01", "/cpfs_speech/data")
                         target = data["target"]
                         source_len = data.get("source_len", 1)
                         target_len = data.get("target_len", 0)
-
+                        if "aishell" in source:
+                            target = target.replace(" ", "")
                         if source_len < self.min_source_length or source_len > self.max_source_length:
                             continue
                         if target_len < self.min_target_length or target_len > self.max_target_length:
                             continue
                         contents_i = {"source": source,
-                                          "prompt": prompt,
-                                          "target": target,
-                                          "source_len": source_len,
-                                          "target_len": target_len,
-                                          }
+                                      "prompt": prompt,
+                                      "target": target,
+                                      "source_len": source_len,
+                                      "target_len": target_len,
+                                      }
                         text_language = data.get("text_language", None)
                         if text_language is not None:
                             contents_i["text_language"] = text_language
-                        audio_language = data.get("audio_language", None)
-                        if audio_language is not None:
-                            contents_i["audio_language"] = audio_language
+                        # audio_language = data.get("audio_language", None)
+                        # if audio_language is not None:
+                        #     contents_i["audio_language"] = audio_language
                         contents.append(contents_i)
-
+
         self.contents = contents
 
-        logging.info(f"total_num: {len(self.contents)} of samplers in ranks: {rank}")
+        logging.info(
+            "total number of samples: {}, {}".format(len(self.contents), path))
 
     def __len__(self):
         return len(self.contents)
 
     def __getitem__(self, index):
-        try:
-            data = self.contents[index]
-        except:
-            print(index)
+
+        data = self.contents[index]
+
         return data
 
     def get_source_len(self, data_dict):
diff --git a/funasr/datasets/audio_datasets/samplers.py b/funasr/datasets/audio_datasets/samplers.py
index 108e68a48..fdf630ee1 100644
--- a/funasr/datasets/audio_datasets/samplers.py
+++ b/funasr/datasets/audio_datasets/samplers.py
@@ -301,6 +301,7 @@ class CustomDistributedBufferDynamicBatchSampler(DistributedSampler):
                  batch_type="token",
                  num_replicas=None,
                  rank=None,
+                 rank_split=False,
                  shuffle=True,
                  drop_last=False,
                  is_training: bool = True,
@@ -314,6 +315,12 @@ class CustomDistributedBufferDynamicBatchSampler(DistributedSampler):
         except:
             rank = 0
             num_replicas = 1
+
+        if rank_split:
+            logging.info(f"rank_split: {rank_split}; batching and shuffling data within each local rank")
+            rank = 0
+            num_replicas = 1
+
         self.rank = rank
         self.num_replicas = num_replicas
         self.dataset = dataset
diff --git a/funasr/datasets/dataloader_entry.py b/funasr/datasets/dataloader_entry.py
index abb28285a..70da722aa 100644
--- a/funasr/datasets/dataloader_entry.py
+++ b/funasr/datasets/dataloader_entry.py
@@ -40,7 +40,21 @@ class DataloaderMapStyle:
         self.dataset_val = dataset_val
         self.kwargs = kwargs
 
-    def build_iter(self, epoch=0):
+        # split dataset
+        self.data_split_num = kwargs["dataset_conf"].get("data_split_num", 1)
+        self.dataset_class = dataset_class
+        self.frontend = frontend
+        self.tokenizer = tokenizer
+        self.kwargs = kwargs
+
+    def build_iter(self, epoch=0, data_split_i=0, **kwargs):
+
+        # reload the dataset slice for this split
+        if self.data_split_num > 1:
+            del self.dataset_tr
+            self.dataset_tr = self.dataset_class(self.kwargs.get("train_data_set_list"), frontend=self.frontend, tokenizer=self.tokenizer,
+                                                 is_training=True, **self.kwargs.get("dataset_conf"), data_split_i=data_split_i)
+
         # dataloader
         batch_sampler = self.kwargs["dataset_conf"].get("batch_sampler", "BatchSampler")
         batch_sampler_val = None
diff --git a/funasr/models/sense_voice/decoder.py b/funasr/models/sense_voice/decoder.py
index 9087ea153..9fdb3bdf6 100644
--- a/funasr/models/sense_voice/decoder.py
+++ b/funasr/models/sense_voice/decoder.py
@@ -245,29 +245,7 @@ class SenseVoiceDecoder(nn.Module):
         self.register_buffer("mask", mask, persistent=False)
 
         self.use_padmask = kwargs.get("use_padmask", True)
 
-        # def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
-        #     """
-        #     x : torch.LongTensor, shape = (batch_size, <= n_ctx)
-        #         the text tokens
-        #     xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
-        #         the encoded audio features to be attended on
-        #     """
-        #     offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
-        #     x = (
-        #         self.token_embedding(x)
-        #         + self.positional_embedding[offset: offset + x.shape[-1]]
-        #     )
-        #     x = x.to(xa.dtype)
-        #
-        #     for block in self.blocks:
-        #         x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
-        #
-        #     x = self.ln(x)
-        #     logits = (
-        #         x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
-        #     ).float()
-        #
-        #     return logits
+
 
     def forward(
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index caaef38a6..3ee6885a0 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -252,6 +252,7 @@ class Trainer:
         dataloader_val=None,
         epoch=None,
         writer=None,
+        **kwargs,
     ):
         """
         Defines the training process for a single epoch with gradient accumulation.
@@ -374,6 +375,8 @@ class Trainer:
                     stats=stats,
                     writer=writer,
                     tag="train",
+                    data_split_i=kwargs.get("data_split_i", 0),
+                    data_split_num=kwargs.get("data_split_num", 1),
                 )
 
                 if (batch_idx + 1) % self.validate_interval == 0:
@@ -507,6 +510,9 @@ class Trainer:
         stats=None,
         writer=None,
         tag="train",
+        data_split_i=0,
+        data_split_num=1,
+        **kwargs,
    ):
 
         if (batch_idx + 1) % self.log_interval == 0:
@@ -526,6 +532,7 @@ class Trainer:
                 f"{tag}, "
                 f"rank: {self.local_rank}, "
                 f"epoch: {epoch}/{self.max_epoch}, "
+                f"data_slice: {data_split_i}/{data_split_num}, "
                 f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, "
                 f"(loss_avg_rank: {loss:.3f}), "
                 f"(loss_avg_epoch: {loss_avg_epoch:.3f}), "
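
Note for reviewers: below is a minimal standalone sketch of the slice-and-shard arithmetic this patch introduces. It is illustrative only; shard_file_list is a hypothetical helper, not FunASR API, and the real entry points are IndexDSJsonlRankFull and DataloaderMapStyle.build_iter above. Each epoch iterates data_split_num slices of the jsonl list file, and each slice is padded with randomly re-sampled entries so that it divides evenly across ranks.

import math
import random

def shard_file_list(path, data_split_num=1, data_split_i=0, rank=0, world_size=1):
    # the list file names one jsonl file per line
    with open(path, encoding="utf-8") as fin:
        file_list_all = [line.strip() for line in fin if line.strip()]

    # take the slice for the current data split (same arithmetic as IndexDSJsonlRankFull)
    num_per_slice = len(file_list_all) // data_split_num
    file_list = file_list_all[data_split_i * num_per_slice:(data_split_i + 1) * num_per_slice]

    # pad with randomly re-sampled entries so every rank sees the same number of files
    num_per_rank = math.ceil(len(file_list) / world_size)
    extra_num = num_per_rank * world_size - len(file_list)
    file_list += random.choices(file_list, k=extra_num)

    # contiguous shard for this rank
    return file_list[rank * num_per_rank:(rank + 1) * num_per_rank]

# usage: three slices per epoch across four ranks
# for data_split_i in range(3):
#     shard = shard_file_list("train_jsonl.list", data_split_num=3,
#                             data_split_i=data_split_i, rank=0, world_size=4)

Padding (rather than dropping the remainder) trades a handful of duplicated files per epoch for equal shard sizes on every rank, which keeps the distributed ranks in step.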