This commit is contained in:
游雁 2023-12-06 23:50:54 +08:00
parent d6aa93946e
commit 15868f6230
2 changed files with 22 additions and 68 deletions

View File

@ -46,7 +46,7 @@ def main(kwargs: DictConfig):
local_rank = int(os.environ.get('LOCAL_RANK', 0))
# Check if we are using DDP or FSDP
use_ddp = 'WORLD_SIZE' in os.environ
use_ddp = 'WORLD_SIZE' in os.environ and os.environ["WORLD_SIZE"] > 1
use_fsdp = kwargs.get("use_fsdp", None)
if use_ddp or use_fsdp:
dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
@ -109,7 +109,8 @@ def main(kwargs: DictConfig):
if use_ddp:
model = model.cuda(local_rank)
model = DDP(model, device_ids=[local_rank])
model = DDP(model, device_ids=[local_rank],
find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
elif use_fsdp:
model = FSDP(model).cuda(local_rank)
else:
@ -157,13 +158,6 @@ def main(kwargs: DictConfig):
torch.distributed.destroy_process_group()
def train(epoch, model, op):
pass
def val():
pass
if __name__ == "__main__":
main()

View File

@ -5,6 +5,7 @@ import logging
from tqdm import tqdm
from contextlib import nullcontext
import torch.distributed as dist
from funasr.torch_utils.recursive_op import recursive_average
class Trainer:
"""
@ -56,6 +57,8 @@ class Trainer:
self.start_epoch = 1
self.max_epoch = kwargs.get('max_epoch', 100)
self.local_rank = local_rank
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.use_ddp = use_ddp
self.use_fsdp = use_fsdp
self.device = torch.device("cuda", local_rank)
@ -113,7 +116,7 @@ class Trainer:
# self._validate_epoch(epoch)
if dist.get_rank() == 0:
self._save_checkpoint(epoch)
# self.scheduler.step()
self.scheduler.step()
def _train_epoch(self, epoch):
"""
@ -126,24 +129,34 @@ class Trainer:
dynamic_ncols=True)
# Set the number of steps for gradient accumulation
accumulation_steps = self.kwargs.get("accumulation_steps", 1)
accum_grad = self.kwargs.get("accum_grad", 1)
# Initialize the gradient accumulation
self.optim.zero_grad()
for batch_idx, batch in enumerate(self.dataloader_train):
batch = to_device(batch, self.device)
my_context = self.model.no_sync if batch_idx % accumulation_steps != 0 else nullcontext
my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
with my_context():
retval = self.model(**batch)
loss, stats, weight = retval
stats = {k: v for k, v in stats.items() if v is not None}
if self.use_ddp or self.use_fsdp:
# Apply weighted averaging for loss and stats
loss = (loss * weight.type(loss.dtype)).sum()
# if distributed, this method can also apply all_reduce()
stats, weight = recursive_average(stats, weight, distributed=True)
# Now weight is summation over all workers
loss /= weight
# Multiply world_size because DistributedDataParallel
# automatically normalizes the gradient by world_size.
loss *= self.world_size
# Scale the loss since we're not updating for every mini-batch
loss = loss / accumulation_steps
loss = loss / accum_grad
loss.backward()
# Perform an optimizer step only after accumulating enough gradients
if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == len(self.dataloader_train):
if (batch_idx + 1) % accum_grad == 0 or (batch_idx + 1) == len(self.dataloader_train):
# Perform gradient clipping if it is set
if self.kwargs.get("grad_clip", None) is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
@ -170,43 +183,6 @@ class Trainer:
f"Training Epoch: {epoch + 1}/{self.max_epoch}, step {batch_idx}/{len(self.dataloader_train)} (loss: {loss.detach().float()})")
pbar.close()
# def _train_epoch(self, epoch):
# """
# Defines the training process for a single epoch.
# Should be implemented with the actual model training steps.
#
# Args:
# epoch (int): The current epoch number.
# """
# self.model.train()
# pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch + 1}", total=len(self.dataloader_train), dynamic_ncols=True)
# for batch_idx, batch in enumerate(self.dataloader_train):
# batch = to_device(batch, "cpu")
# retval = self.model(**batch)
# loss, stats, weight = retval
# self.optim.zero_grad()
# loss.backward()
#
# # compute the gradient norm to check if it is normal or not
# grad_norm = torch.nn.utils.clip_grad_norm_(
# self.model.parameters(),
# max_norm=self.kwargs.get("grad_clip", 10.0),
# norm_type=self.kwargs.get("grad_clip_type", 2.0),
# )
# if not torch.isfinite(grad_norm):
# logging.warning(
# f"The grad norm is {grad_norm}. Skipping updating the model."
# )
# continue
# self.optim.step()
# self.scheduler.step()
# pbar.update(1)
# pbar.set_description(
# f"Training Epoch: {epoch + 1}/{self.max_epoch}, step {batch_idx}/{len(self.dataloader_train)} (loss: {loss.detach().float()})")
#
# pbar.close()
#
def _validate_epoch(self, epoch):
"""
@ -221,19 +197,3 @@ class Trainer:
for data, target in self.dataloader_val:
# Implement the model validation steps here
pass
# # Example usage
# if __name__ == "__main__":
# # Assuming the following objects have already been correctly created and initialized:
# # model, optim, scheduler, dataloader_train, and dataloader_val.
# trainer = Trainer(
# max_epoch=10,
# model=model,
# optim=optim,
# scheduler=scheduler,
# dataloader_train=dataloader_train,
# dataloader_val=dataloader_val,
# output_dir='path_to_save_model',
# resume='path_to_checkpoint_if_any'
# )
# trainer.run()