游雁 2024-06-09 00:21:44 +08:00
parent 2a8d041806
commit 1c8b46a233
4 changed files with 38 additions and 16 deletions

View File

@@ -158,6 +158,8 @@ def main(**kwargs):
time1 = time.perf_counter()
for data_split_i in range(trainer.start_data_split_i, dataloader.data_split_num):
time_slice_i = time.perf_counter()
dataloader_tr, dataloader_val = dataloader.build_iter(
epoch, data_split_i=data_split_i, start_step=trainer.start_step
)
@@ -178,6 +180,14 @@ def main(**kwargs):
torch.cuda.empty_cache()
time_escaped = (time.perf_counter() - time1) / 3600.0
logging.info(
f"rank: {local_rank}, "
f"time_escaped_epoch: {time_escaped:.3f} hours, "
f"estimated to finish {dataloader.data_split_num} data_slices, remaining: {(dataloader.data_split_num-data_split_i)*time_escaped:.3f} hours"
f"epoch: {((trainer.max_epoch - epoch - 1)*dataloader.data_split_num + dataloader.data_split_num-data_split_i)*time_escaped:.3f} hours\n"
)
trainer.start_data_split_i = 0
trainer.validate_epoch(model=model, dataloader_val=dataloader_val, epoch=epoch + 1)
scheduler.step()
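Note: the new logging call estimates remaining time by scaling the elapsed time by the number of data slices still to run. Below is a minimal sketch of the same arithmetic outside the training loop, with made-up values standing in for trainer.max_epoch, dataloader.data_split_num, and the measured time (all hypothetical):

```python
import time

# Hypothetical stand-ins for the trainer/dataloader state used in the loop above.
data_split_num = 8        # data slices per epoch
max_epoch = 10
epoch = 2                 # current epoch index
data_split_i = 3          # slice that just finished

time1 = time.perf_counter() - 2 * 3600.0               # pretend the epoch started 2 hours ago
time_escaped = (time.perf_counter() - time1) / 3600.0  # hours elapsed so far, as in the diff

# Same arithmetic as the new logging call: scale elapsed time by the slices
# still to run in this epoch, and by the slices left across remaining epochs.
remaining_epoch = (data_split_num - data_split_i) * time_escaped
remaining_total = ((max_epoch - epoch - 1) * data_split_num
                   + data_split_num - data_split_i) * time_escaped
print(f"remaining this epoch: {remaining_epoch:.3f} h, all epochs: {remaining_total:.3f} h")
```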

View File

@@ -334,6 +334,7 @@ class CustomDistributedBufferDynamicBatchSampler(DistributedSampler):
drop_last=False,
is_training: bool = True,
sort_size: int = 1024,
start_step: int = 0,
**kwargs,
):
@@ -364,12 +365,14 @@ class CustomDistributedBufferDynamicBatchSampler(DistributedSampler):
self.sort_size = sort_size * num_replicas
self.max_token_length = kwargs.get("max_token_length", 2048)
self.length_scale_source = kwargs.get("length_scale_source", 1.0)
self.start_step = kwargs.get("start_step", 2048)
self.batch_size_sample_max = kwargs.get("batch_size_sample_max", 200)
super().__init__(
dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, drop_last=drop_last
)
self.start_step = start_step
self.batch_num = 1
if self.start_step > 0:
logging.info(f"Warning, start_step > 0, dataloader start from step: {self.start_step}")
# super().__init__(
# dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, drop_last=drop_last
# )
def __iter__(self):
if self.shuffle:
@@ -424,11 +427,11 @@ class CustomDistributedBufferDynamicBatchSampler(DistributedSampler):
rank_batches[i % self.num_replicas].append(batch)
# Assign all batches for the current rank directly
final_batches = rank_batches[self.rank] # [self.start_step :]
final_batches = rank_batches[self.rank][self.start_step :]
self.batch_num = len(final_batches)
logging.info(
f"rank: {self.rank}, dataloader start from step: {self.start_step}, batch_num: {self.batch_num}"
f"rank: {self.rank}, dataloader start from step: {self.start_step}, batch_num: {rank_batches[self.rank]}, after: {self.batch_num}"
)
return iter(final_batches)
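Note: the key change here is that each rank's batch list is now sliced by start_step, so a resumed job skips the batches that were already consumed before the restart. A small self-contained sketch of the round-robin assignment plus resume slice, with toy batches in place of the sampler's length-bucketed ones:

```python
# Toy version of the per-rank batch assignment and resume slicing above.
num_replicas = 2
rank = 1
start_step = 3                       # batches already trained before the restart

batches = [[i] for i in range(20)]   # stand-in batches (really lists of sample indices)
rank_batches = [[] for _ in range(num_replicas)]
for i, batch in enumerate(batches):
    rank_batches[i % num_replicas].append(batch)   # round-robin across ranks

final_batches = rank_batches[rank][start_step:]    # drop the already-consumed steps
print(len(rank_batches[rank]), "->", len(final_batches))   # 10 -> 7
```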

View File

@@ -49,14 +49,19 @@ class DataloaderMapStyle:
def __init__(self, frontend=None, tokenizer=None, **kwargs):
# dataset
logging.info("Build dataloader")
dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
dataset_tr = dataset_class(
kwargs.get("train_data_set_list"),
frontend=frontend,
tokenizer=tokenizer,
is_training=True,
**kwargs.get("dataset_conf"),
)
dataset_tr = None
# split dataset
self.data_split_num = kwargs["dataset_conf"].get("data_split_num", 1)
if self.data_split_num == 1:
dataset_tr = dataset_class(
kwargs.get("train_data_set_list"),
frontend=frontend,
tokenizer=tokenizer,
is_training=True,
**kwargs.get("dataset_conf"),
)
dataset_val = dataset_class(
kwargs.get("valid_data_set_list"),
frontend=frontend,
@@ -69,8 +74,6 @@ class DataloaderMapStyle:
self.dataset_val = dataset_val
self.kwargs = kwargs
# split dataset
self.data_split_num = kwargs["dataset_conf"].get("data_split_num", 1)
self.dataset_class = dataset_class
self.frontend = frontend
self.tokenizer = tokenizer
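Note: the constructor now builds the full training dataset only when data_split_num == 1; otherwise it keeps dataset_class, frontend, and tokenizer around so each slice can be constructed later in build_iter. A minimal sketch of the same lazy-split pattern, with hypothetical class and keyword names:

```python
class LazySplitDataloader:
    """Builds one training-data slice at a time instead of the whole list up front."""

    def __init__(self, dataset_class, train_list, data_split_num=1, **dataset_conf):
        self.dataset_class = dataset_class
        self.train_list = train_list
        self.data_split_num = data_split_num
        self.dataset_conf = dataset_conf
        # Only materialize the full training set when there is a single slice.
        self.dataset_tr = (
            dataset_class(train_list, is_training=True, **dataset_conf)
            if data_split_num == 1
            else None
        )

    def build_iter(self, data_split_i=0):
        if self.data_split_num > 1:
            # Hypothetical keyword: pass the slice index so only that part is loaded.
            self.dataset_tr = self.dataset_class(
                self.train_list,
                is_training=True,
                data_split_i=data_split_i,
                **self.dataset_conf,
            )
        return self.dataset_tr
```

Deferring construction this way keeps memory bounded to a single slice's index lists when the training list is split.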

View File

@@ -167,6 +167,8 @@ class Trainer:
Args:
epoch (int): The epoch number at which the checkpoint is being saved.
"""
if self.use_ddp or self.use_fsdp:
dist.barrier()
step_in_epoch = None if step is None else step_in_epoch
if self.use_deepspeed:
@@ -760,6 +762,10 @@ class Trainer:
ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step_in_epoch")}'
self.val_acc_step_or_eoch[ckpt_name] = self.val_acc_avg
self.val_loss_step_or_eoch[ckpt_name] = self.val_loss_avg
if self.use_ddp or self.use_fsdp or self.use_deepspeed:
dist.barrier()
model.train()
def log(
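Note: the two added dist.barrier() calls keep ranks in step around checkpointing and before returning to training mode. A minimal sketch of the pattern with a hypothetical helper and save_fn, assuming the process group is already initialized whenever distributed is True:

```python
import torch.distributed as dist

def checkpoint_with_barrier(model, save_fn, distributed: bool):
    # Wait for every rank to reach the checkpoint point before rank 0 writes files.
    if distributed:
        dist.barrier()
    if not distributed or dist.get_rank() == 0:
        save_fn(model)
    # Re-synchronize so no rank resumes training while others are still saving,
    # then switch back to training mode, as the diff does after validation.
    if distributed:
        dist.barrier()
    model.train()
```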