From 8795bf5bf1daac5a839f856a748d7e92cc4c5015 Mon Sep 17 00:00:00 2001
From: zhifu gao
Date: Tue, 23 Apr 2024 19:36:15 +0800
Subject: [PATCH] Dev gzf exp (#1649)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* sensevoice finetune

* bugfix

* update with main (#1631)

* update seaco finetune

* v1.0.24

* sensevoice

* update with main (#1638)

* update rwkv template

* sense voice

---------

Co-authored-by: 维石
---
 .../audio_datasets/espnet_samplers.py      |   8 +-
 funasr/datasets/audio_datasets/index_ds.py | 213 +++---------------
 funasr/datasets/audio_datasets/samplers.py |   8 +-
 3 files changed, 39 insertions(+), 190 deletions(-)

diff --git a/funasr/datasets/audio_datasets/espnet_samplers.py b/funasr/datasets/audio_datasets/espnet_samplers.py
index 4bb34f34d..6b38bc217 100644
--- a/funasr/datasets/audio_datasets/espnet_samplers.py
+++ b/funasr/datasets/audio_datasets/espnet_samplers.py
@@ -48,10 +48,10 @@ class EspnetStyleBatchSampler(DistributedSampler):
         except:
             rank = 0
             num_replicas = 1
-        if rank_split:
-            logging.info(f"Warning, rank_split: {rank_split}, batch and shuffle data in local rank")
-            rank = 0
-            num_replicas = 1
+        # if rank_split:
+        #     logging.info(f"Warning, rank_split: {rank_split}, batch and shuffle data in local rank")
+        #     rank = 0
+        #     num_replicas = 1
         self.rank = rank
         self.num_replicas = num_replicas
         self.dataset = dataset
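With the rank_split branch commented out, EspnetStyleBatchSampler always keeps
the rank and world size reported by torch.distributed, so batching and
shuffling are sharded globally rather than reset to a single local rank. The
helper below is a minimal sketch of the surviving setup logic, assuming only
the public torch.distributed API; detect_rank_and_replicas is a hypothetical
name, not a function in funasr.

import logging

import torch.distributed as dist


def detect_rank_and_replicas():
    # Hypothetical helper mirroring the sampler's remaining setup: use the
    # real distributed rank when a process group is initialized, otherwise
    # fall back to a single shard (dist.get_rank() raises in that case).
    try:
        rank = dist.get_rank()
        num_replicas = dist.get_world_size()
    except Exception:
        rank = 0
        num_replicas = 1
        logging.info("distributed is not initialized, only single shard")
    return rank, num_replicas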
diff --git a/funasr/datasets/audio_datasets/index_ds.py b/funasr/datasets/audio_datasets/index_ds.py
index de0d653ae..06bd4de16 100644
--- a/funasr/datasets/audio_datasets/index_ds.py
+++ b/funasr/datasets/audio_datasets/index_ds.py
@@ -10,67 +10,6 @@ import torch.distributed as dist
 from funasr.register import tables
 
-# @tables.register("index_ds_classes", "IndexDSJsonlRankSplit")
-# class IndexDSJsonlRankSplit(torch.utils.data.Dataset):
-#
-#     def __init__(self, path):
-#         super().__init__()
-#
-#         contents = []
-#         with open(path, encoding='utf-8') as fin:
-#             for line in fin:
-#                 data = json.loads(line.strip())
-#                 if "text" in data:  # for sft
-#                     self.contents.append(data['text'])
-#                 if "source" in data:  # for speech lab pretrain
-#                     prompt = data["prompt"]
-#                     source = data["source"]
-#                     target = data["target"]
-#                     source_len = data["source_len"]
-#                     target_len = data["target_len"]
-#
-#                     contents.append({"source": source,
-#                                      "prompt": prompt,
-#                                      "target": target,
-#                                      "source_len": source_len,
-#                                      "target_len": target_len,
-#                                      }
-#                                     )
-#
-#         self.contents = []
-#         total_num = len(contents)
-#         try:
-#             rank = dist.get_rank()
-#             world_size = dist.get_world_size()
-#         except:
-#             rank = 0
-#             world_size = 1
-#             logging.info("distributed is not initialized, only single shard")
-#         num_per_rank = total_num // world_size
-#
-#         # rank = 0
-#         # import ipdb; ipdb.set_trace()
-#         self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank]
-#
-#         logging.info("in rank: {}, num of samplers: {}, total_num of samplers across ranks: {}".format(rank, len(self.contents), len(contents)))
-#
-#     def __len__(self):
-#         return len(self.contents)
-#
-#     def __getitem__(self, index):
-#         try:
-#             data = self.contents[index]
-#         except:
-#             print(index)
-#         return data
-#
-#     def get_source_len(self, data_dict):
-#         return data_dict["source_len"]
-#
-#     def get_target_len(self, data_dict):
-#
-#         return data_dict["target_len"] if "target_len" in data_dict else 0
-
 @tables.register("index_ds_classes", "IndexDSJsonl")
 @tables.register("index_ds_classes", "IndexDSJsonlRankFull")
 @tables.register("index_ds_classes", "IndexDSJsonlRankSplit")
@@ -104,37 +43,39 @@ class IndexDSJsonlRankFull(torch.utils.data.Dataset):
 
         file_list = [path]
 
-        total_num = len(file_list)
-        try:
-            rank = dist.get_rank()
-            world_size = dist.get_world_size()
-        except:
-            rank = 0
-            world_size = 1
-            logging.info("distributed is not initialized, only single shard")
-
-        if not kwargs.get("rank_split", False):
-            logging.info(f"Warning, rank_split disenabled, batch and shuffle data in global")
-            rank = 0
-            world_size = 1
-
-        num_per_rank = total_num // world_size
-        if num_per_rank * world_size < total_num:
-            logging.info(f"Warning, jsonl file:{total_num} could not be divided by world_size: {world_size}, {path}")
-            total_num_needed = num_per_rank * world_size
-
-            extra_num = total_num_needed - total_num
-            file_list_tmp = random.choices(file_list, k=extra_num)
-            file_list += file_list_tmp
-            logging.info(f"Warning, after random choices: {file_list}")
-
-        file_list_rank = file_list[rank * num_per_rank:(rank + 1) * num_per_rank]
-
-        logging.info(
-            f"is_training: {is_training}, file_list_rank: {file_list_rank}")
+        # total_num = len(file_list)
+        # try:
+        #     rank = dist.get_rank()
+        #     world_size = dist.get_world_size()
+        # except:
+        #     rank = 0
+        #     world_size = 1
+        #     logging.info("distributed is not initialized, only single shard")
+        #
+        # if not kwargs.get("rank_split", False):
+        #     logging.info(f"Warning, rank_split disenabled, batch and shuffle data in global")
+        #     rank = 0
+        #     world_size = 1
+        #
+        # num_per_rank = total_num // world_size
+        # if num_per_rank * world_size < total_num:
+        #     logging.info(f"Warning, jsonl file:{total_num} could not be divided by world_size: {world_size}, {path}")
+        #     total_num_needed = num_per_rank * world_size
+        #
+        #     extra_num = total_num_needed - total_num
+        #     file_list_tmp = random.choices(file_list, k=extra_num)
+        #     file_list += file_list_tmp
+        #     logging.info(f"Warning, after random choices: {file_list}")
+        #
+        # file_list_rank = file_list[rank * num_per_rank:(rank + 1) * num_per_rank]
+        #
+        # logging.info(
+        #     f"is_training: {is_training}, file_list_rank: {file_list_rank}")
+        # contents = []
+        # for file_json in file_list_rank:
 
         contents = []
-        for file_json in file_list_rank:
+        for file_json in file_list:
             with open(file_json.strip(), encoding='utf-8') as fin:
                 for line in fin:
                     data = json.loads(line.strip())
@@ -187,95 +128,3 @@ class IndexDSJsonlRankFull(torch.utils.data.Dataset):
         return data_dict.get("target_len", 0)
 
 
-#
-# @tables.register("index_ds_classes", "IndexDSJsonlRankSplit")
-# class IndexDSJsonlRankSplit(torch.utils.data.Dataset):
-#
-#     def __init__(self, path: str, **kwargs):
-#         super().__init__()
-#         logging.info("building IndexDS")
-#         self.max_source_length = kwargs.get("max_source_length", 2048)
-#         self.min_source_length = kwargs.get("min_source_length", 0)
-#         self.max_target_length = kwargs.get("max_target_length", 2048)
-#         self.min_target_length = kwargs.get("min_target_length", 0)
-#
-#         data_split_num = kwargs.get("data_split_num", 1)
-#         data_split_i = kwargs.get("data_split_i", 0)
-#         if not kwargs.get("is_training", True):
-#             data_split_num = 1
-#             data_split_i = 0
-#         with open(path, encoding='utf-8') as fin:
-#             file_list_all = fin.readlines()
-#
-#         num_per_slice = len(file_list_all) // data_split_num
-#         file_list = file_list_all[data_split_i * num_per_slice:(data_split_i + 1) * num_per_slice]
-#         logging.info(f"data_split_num: {data_split_num}, data_split_i: {data_split_i}, file_list: {file_list}, file_list_all: {file_list_all}")
-#
-#
-#         total_num = len(file_list)
-#         try:
-#             rank = dist.get_rank()
-#             world_size = dist.get_world_size()
-#         except:
-#             rank = 0
-#             world_size = 1
-#             logging.info("distributed is not initialized, only single shard")
-#         num_per_rank = total_num // world_size
-#         if num_per_rank * world_size < total_num:
-#             logging.info(f"Warning, jsonl file:{total_num} could not be divided by world_size: {world_size}, {path}")
-#
-#         file_list_rank = file_list[rank * num_per_rank:(rank + 1) * num_per_rank]
-#
-#         contents = []
-#         for file_json in file_list_rank:
-#
-#             with open(file_json.strip(), encoding='utf-8') as fin:
-#                 for line in fin:
-#                     data = json.loads(line.strip())
-#                     if "text" in data:  # for sft
-#                         contents.append(data['text'])
-#                     if "source" in data:  # for speech lab pretrain
-#                         prompt = data.get("prompt", "")
-#                         source = data["source"].replace("/cpfs01", "/cpfs_speech/data")
-#                         target = data["target"]
-#                         source_len = data.get("source_len", 1)
-#                         target_len = data.get("target_len", 0)
-#
-#                         if source_len < self.min_source_length or source_len > self.max_source_length:
-#                             continue
-#                         if target_len < self.min_target_length or target_len > self.max_target_length:
-#                             continue
-#                         contents_i = {"source": source,
-#                                       "prompt": prompt,
-#                                       "target": target,
-#                                       "source_len": source_len,
-#                                       "target_len": target_len,
-#                                       }
-#                         text_language = data.get("text_language", None)
-#                         if text_language is not None:
-#                             contents_i["text_language"] = text_language
-#                         # audio_language = data.get("audio_language", None)
-#                         # if audio_language is not None:
-#                         #     contents_i["audio_language"] = audio_language
-#                         contents.append(contents_i)
-#
-#         self.contents = contents
-#
-#         logging.info(f"total_num: {len(self.contents)} of samplers in ranks: {rank}, file_list_rank: {file_list_rank}")
-#
-#     def __len__(self):
-#         return len(self.contents)
-#
-#     def __getitem__(self, index):
-#         try:
-#             data = self.contents[index]
-#         except:
-#             print(index)
-#         return data
-#
-#     def get_source_len(self, data_dict):
-#         return data_dict.get("source_len", 1)
-#
-#     def get_target_len(self, data_dict):
-#
-#         return data_dict.get("target_len", 0)
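The dataset-side change mirrors the sampler change above: IndexDSJsonlRankFull
no longer slices the jsonl file list by rank, and no longer pads the list with
random.choices when it does not divide evenly by world_size; every rank now
parses the full file_list, leaving per-rank sharding to the batch sampler. The
loop below is a minimal sketch of the surviving loading path, assuming jsonl
entries carry either a "text" field or "source"/"target" fields as in the
diff; load_jsonl_contents is a hypothetical name, and the real constructor
additionally filters entries by source/target length.

import json


def load_jsonl_contents(file_list):
    # Hypothetical reduction of the constructor's loop: every rank reads
    # every file, with no rank/world_size slicing at the dataset level.
    contents = []
    for file_json in file_list:
        with open(file_json.strip(), encoding="utf-8") as fin:
            for line in fin:
                data = json.loads(line.strip())
                if "text" in data:  # plain-text entries (sft)
                    contents.append(data["text"])
                if "source" in data:  # paired audio/text entries
                    contents.append(data)
    return contents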
kwargs.get("max_target_length", 2048) -# self.min_target_length = kwargs.get("min_target_length", 0) -# -# data_split_num = kwargs.get("data_split_num", 1) -# data_split_i = kwargs.get("data_split_i", 0) -# if not kwargs.get("is_training", True): -# data_split_num = 1 -# data_split_i = 0 -# with open(path, encoding='utf-8') as fin: -# file_list_all = fin.readlines() -# -# num_per_slice = len(file_list_all) // data_split_num -# file_list = file_list_all[data_split_i * num_per_slice:(data_split_i + 1) * num_per_slice] -# logging.info(f"data_split_num: {data_split_num}, data_split_i: {data_split_i}, file_list: {file_list}, file_list_all: {file_list_all}") -# -# -# total_num = len(file_list) -# try: -# rank = dist.get_rank() -# world_size = dist.get_world_size() -# except: -# rank = 0 -# world_size = 1 -# logging.info("distributed is not initialized, only single shard") -# num_per_rank = total_num // world_size -# if num_per_rank * world_size < total_num: -# logging.info(f"Warning, jsonl file:{total_num} could not be divided by world_size: {world_size}, {path}") -# -# file_list_rank = file_list[rank * num_per_rank:(rank + 1) * num_per_rank] -# -# contents = [] -# for file_json in file_list_rank: -# -# with open(file_json.strip(), encoding='utf-8') as fin: -# for line in fin: -# data = json.loads(line.strip()) -# if "text" in data: # for sft -# contents.append(data['text']) -# if "source" in data: # for speech lab pretrain -# prompt = data.get("prompt", "") -# source = data["source"].replace("/cpfs01", "/cpfs_speech/data") -# target = data["target"] -# source_len = data.get("source_len", 1) -# target_len = data.get("target_len", 0) -# -# if source_len < self.min_source_length or source_len > self.max_source_length: -# continue -# if target_len < self.min_target_length or target_len > self.max_target_length: -# continue -# contents_i = {"source": source, -# "prompt": prompt, -# "target": target, -# "source_len": source_len, -# "target_len": target_len, -# } -# text_language = data.get("text_language", None) -# if text_language is not None: -# contents_i["text_language"] = text_language -# # audio_language = data.get("audio_language", None) -# # if audio_language is not None: -# # contents_i["audio_language"] = audio_language -# contents.append(contents_i) -# -# self.contents = contents -# -# logging.info(f"total_num: {len(self.contents)} of samplers in ranks: {rank}, file_list_rank: {file_list_rank}") -# -# def __len__(self): -# return len(self.contents) -# -# def __getitem__(self, index): -# try: -# data = self.contents[index] -# except: -# print(index) -# return data -# -# def get_source_len(self, data_dict): -# return data_dict.get("source_len", 1) -# -# def get_target_len(self, data_dict): -# -# return data_dict.get("target_len", 0) diff --git a/funasr/datasets/audio_datasets/samplers.py b/funasr/datasets/audio_datasets/samplers.py index fdf630ee1..1394f7e8c 100644 --- a/funasr/datasets/audio_datasets/samplers.py +++ b/funasr/datasets/audio_datasets/samplers.py @@ -316,10 +316,10 @@ class CustomDistributedBufferDynamicBatchSampler(DistributedSampler): rank = 0 num_replicas = 1 - if rank_split: - logging.info(f"Warning, rank_split: {rank_split}, batch and shuffle data in local rank") - rank = 0 - num_replicas = 1 + # if rank_split: + # logging.info(f"Warning, rank_split: {rank_split}, batch and shuffle data in local rank") + # rank = 0 + # num_replicas = 1 self.rank = rank self.num_replicas = num_replicas