mirror of https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
fix bug
This commit is contained in:
parent 2a8d041806
commit 1c8b46a233
@@ -158,6 +158,8 @@ def main(**kwargs):
        time1 = time.perf_counter()

        for data_split_i in range(trainer.start_data_split_i, dataloader.data_split_num):
            time_slice_i = time.perf_counter()

            dataloader_tr, dataloader_val = dataloader.build_iter(
                epoch, data_split_i=data_split_i, start_step=trainer.start_step
            )
@@ -178,6 +180,14 @@ def main(**kwargs):
            torch.cuda.empty_cache()

            time_escaped = (time.perf_counter() - time1) / 3600.0
            logging.info(
                f"rank: {local_rank}, "
                f"time_escaped_epoch: {time_escaped:.3f} hours, "
                f"estimated to finish {dataloader.data_split_num} data_slices, remaining: {(dataloader.data_split_num-data_split_i)*time_escaped:.3f} hours"
                f"epoch: {((trainer.max_epoch - epoch - 1)*dataloader.data_split_num + dataloader.data_split_num-data_split_i)*time_escaped:.3f} hours\n"
            )

        trainer.start_data_split_i = 0
        trainer.validate_epoch(model=model, dataloader_val=dataloader_val, epoch=epoch + 1)
        scheduler.step()
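The log message above estimates remaining training time by scaling the hours elapsed in the current epoch by the number of data slices still to process, both within this epoch and across all remaining epochs. A minimal sketch of that arithmetic, with illustrative standalone names that are not part of the diff:

# Sketch of the remaining-time estimate logged above; names are illustrative.
def estimate_remaining_hours(elapsed_hours, data_split_num, data_split_i, max_epoch, epoch):
    # Slices left in the current epoch, scaled by the time spent so far.
    remaining_in_epoch = (data_split_num - data_split_i) * elapsed_hours
    # The same scaling applied to the slices left in all remaining epochs.
    remaining_total = (
        (max_epoch - epoch - 1) * data_split_num + (data_split_num - data_split_i)
    ) * elapsed_hours
    return remaining_in_epoch, remaining_total

print(estimate_remaining_hours(elapsed_hours=0.5, data_split_num=10, data_split_i=2, max_epoch=3, epoch=0))
# -> (4.0, 14.0)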
@@ -334,6 +334,7 @@ class CustomDistributedBufferDynamicBatchSampler(DistributedSampler):
        drop_last=False,
        is_training: bool = True,
        sort_size: int = 1024,
        start_step: int = 0,
        **kwargs,
    ):
@@ -364,12 +365,14 @@ class CustomDistributedBufferDynamicBatchSampler(DistributedSampler):
        self.sort_size = sort_size * num_replicas
        self.max_token_length = kwargs.get("max_token_length", 2048)
        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
        self.start_step = kwargs.get("start_step", 2048)
        self.batch_size_sample_max = kwargs.get("batch_size_sample_max", 200)

        super().__init__(
            dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, drop_last=drop_last
        )
        self.start_step = start_step
        self.batch_num = 1
        if self.start_step > 0:
            logging.info(f"Warning, start_step > 0, dataloader start from step: {self.start_step}")
        # super().__init__(
        #     dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, drop_last=drop_last
        # )

    def __iter__(self):
        if self.shuffle:
@@ -424,11 +427,11 @@ class CustomDistributedBufferDynamicBatchSampler(DistributedSampler):
            rank_batches[i % self.num_replicas].append(batch)

        # Assign all batches for the current rank directly
        final_batches = rank_batches[self.rank]  # [self.start_step :]
        final_batches = rank_batches[self.rank][self.start_step :]
        self.batch_num = len(final_batches)

        logging.info(
            f"rank: {self.rank}, dataloader start from step: {self.start_step}, batch_num: {self.batch_num}"
            f"rank: {self.rank}, dataloader start from step: {self.start_step}, batch_num: {rank_batches[self.rank]}, after: {self.batch_num}"
        )
        return iter(final_batches)
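The change to __iter__ slices off the first start_step batches assigned to the current rank, so a run resumed mid-epoch does not revisit batches it has already consumed. A self-contained sketch of that resume-by-slicing idea (the class below is illustrative, not FunASR's sampler):

# Illustrative sketch: skip the first `start_step` batches when resuming.
class ResumableBatchIterator:
    def __init__(self, batches, start_step=0):
        self.batches = batches        # precomputed list of batches for this rank
        self.start_step = start_step  # batches already consumed before the restart
        self.batch_num = len(batches) - start_step

    def __iter__(self):
        # Mirrors final_batches = rank_batches[self.rank][self.start_step:]
        return iter(self.batches[self.start_step:])

    def __len__(self):
        return self.batch_num

# Usage: three batches were planned, one was already trained on before the restart.
sampler = ResumableBatchIterator([[0, 1], [2, 3], [4, 5]], start_step=1)
print(list(sampler))  # [[2, 3], [4, 5]]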
@@ -49,14 +49,19 @@ class DataloaderMapStyle:
    def __init__(self, frontend=None, tokenizer=None, **kwargs):
        # dataset
        logging.info("Build dataloader")

        dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
        dataset_tr = dataset_class(
            kwargs.get("train_data_set_list"),
            frontend=frontend,
            tokenizer=tokenizer,
            is_training=True,
            **kwargs.get("dataset_conf"),
        )
        dataset_tr = None
        # split dataset
        self.data_split_num = kwargs["dataset_conf"].get("data_split_num", 1)
        if self.data_split_num == 1:
            dataset_tr = dataset_class(
                kwargs.get("train_data_set_list"),
                frontend=frontend,
                tokenizer=tokenizer,
                is_training=True,
                **kwargs.get("dataset_conf"),
            )
        dataset_val = dataset_class(
            kwargs.get("valid_data_set_list"),
            frontend=frontend,
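With this change the training dataset is built eagerly only when data_split_num is 1; for larger split counts the constructor keeps the pieces needed to build one slice at a time later. A rough sketch of that deferred-construction pattern (class and argument names here are hypothetical, not FunASR's API):

# Hypothetical sketch of deferring dataset construction to build_iter time.
class LazySplitDataloader:
    def __init__(self, dataset_class, train_list, data_split_num=1, **dataset_conf):
        self.dataset_class = dataset_class
        self.train_list = train_list
        self.data_split_num = data_split_num
        self.dataset_conf = dataset_conf
        self.dataset_tr = None
        if self.data_split_num == 1:
            # A single split is small enough to build up front.
            self.dataset_tr = dataset_class(train_list, **dataset_conf)

    def build_iter(self, data_split_i=0):
        if self.data_split_num > 1:
            # Build only the slice about to be trained on, keeping memory bounded.
            self.dataset_tr = self.dataset_class(
                self.train_list, data_split_i=data_split_i, **self.dataset_conf
            )
        return self.dataset_tr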
@@ -69,8 +74,6 @@ class DataloaderMapStyle:
        self.dataset_val = dataset_val
        self.kwargs = kwargs

        # split dataset
        self.data_split_num = kwargs["dataset_conf"].get("data_split_num", 1)
        self.dataset_class = dataset_class
        self.frontend = frontend
        self.tokenizer = tokenizer
@@ -167,6 +167,8 @@ class Trainer:
        Args:
            epoch (int): The epoch number at which the checkpoint is being saved.
        """
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
        step_in_epoch = None if step is None else step_in_epoch
        if self.use_deepspeed:
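The dist.barrier() guard here makes every DDP/FSDP rank reach the same point before the checkpoint is written, so rank 0 does not snapshot the model while other ranks are still mid-step. A small sketch of that pattern (the function name and arguments are illustrative; only the dist.barrier() calls reflect the diff):

import torch
import torch.distributed as dist

# Illustrative pattern: synchronize ranks around a rank-0 checkpoint write.
# Assumes the process group was already initialized via dist.init_process_group().
def save_checkpoint_synced(model, path, rank):
    if dist.is_available() and dist.is_initialized():
        dist.barrier()  # wait until every rank has finished its last step
    if rank == 0:
        torch.save(model.state_dict(), path)
    if dist.is_available() and dist.is_initialized():
        dist.barrier()  # keep other ranks from moving on before the save completes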
@@ -760,6 +762,10 @@ class Trainer:
            ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step_in_epoch")}'
        self.val_acc_step_or_eoch[ckpt_name] = self.val_acc_avg
        self.val_loss_step_or_eoch[ckpt_name] = self.val_loss_avg

        if self.use_ddp or self.use_fsdp or self.use_deepspeed:
            dist.barrier()

        model.train()

    def log(