Mirror of https://github.com/modelscope/FunASR, synced 2025-09-15 14:48:36 +08:00
funasr2
This commit is contained in:
parent d6aa93946e
commit 15868f6230
@@ -46,7 +46,7 @@ def main(kwargs: DictConfig):
     local_rank = int(os.environ.get('LOCAL_RANK', 0))
     # Check if we are using DDP or FSDP
-    use_ddp = 'WORLD_SIZE' in os.environ
+    use_ddp = 'WORLD_SIZE' in os.environ and os.environ["WORLD_SIZE"] > 1
     use_fsdp = kwargs.get("use_fsdp", None)
     if use_ddp or use_fsdp:
         dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
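Note: the old check treated any process with WORLD_SIZE set as distributed; the new one also requires more than one worker. Environment variables are strings, so a standalone version of this check would normally cast before comparing. A minimal sketch, with a helper name that is purely illustrative and not part of FunASR:

import os
import torch.distributed as dist

def maybe_init_distributed(backend="nccl"):
    """Initialize the process group only when a multi-process launcher set WORLD_SIZE > 1."""
    world_size = int(os.environ.get("WORLD_SIZE", "1"))  # environment values are strings
    if world_size > 1:
        dist.init_process_group(backend=backend, init_method="env://")
        return True
    return False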
@@ -109,7 +109,8 @@ def main(kwargs: DictConfig):
     if use_ddp:
         model = model.cuda(local_rank)
-        model = DDP(model, device_ids=[local_rank])
+        model = DDP(model, device_ids=[local_rank],
+                    find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
     elif use_fsdp:
         model = FSDP(model).cuda(local_rank)
     else:
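Note: find_unused_parameters makes DDP tolerate parameters that receive no gradient in a given forward pass (useful when parts of the model are conditionally skipped), at the cost of an extra graph traversal per iteration. A hedged sketch of wiring the flag through a nested config; the wrapper name is an assumption, while the train_conf key mirrors the diff above:

from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_ddp(model, local_rank, train_conf=None):
    """Wrap a model in DDP, honoring an optional find_unused_parameters flag from the config."""
    find_unused = (train_conf or {}).get("find_unused_parameters", False)
    return DDP(model.cuda(local_rank),
               device_ids=[local_rank],
               find_unused_parameters=find_unused)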
@@ -157,13 +158,6 @@ def main(kwargs: DictConfig):
         torch.distributed.destroy_process_group()
 
 
-def train(epoch, model, op):
-    pass
-
-
-def val():
-    pass
-
 
 if __name__ == "__main__":
     main()
@@ -5,6 +5,7 @@ import logging
 from tqdm import tqdm
 from contextlib import nullcontext
 import torch.distributed as dist
+from funasr.torch_utils.recursive_op import recursive_average
 
 class Trainer:
     """
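Note: the new import pulls in recursive_average, which the training-loop hunk below uses to average the loss statistics across workers. The real helper lives in funasr.torch_utils.recursive_op and handles nested structures; the flat, scalar-only sketch below only illustrates the idea (weighted sum, all_reduce, normalize) and is not FunASR's implementation:

import torch
import torch.distributed as dist

def weighted_average(stats, weight, distributed=False):
    """Toy, flat-dict version: weighted sum per stat, optional all_reduce, then normalize."""
    total_weight = weight.sum()
    summed = {k: (v * weight.type(v.dtype)).sum() for k, v in stats.items()}
    if distributed:
        dist.all_reduce(total_weight)          # total weight across all workers
        for v in summed.values():
            dist.all_reduce(v)                 # weighted sums across all workers
    return {k: v / total_weight for k, v in summed.items()}, total_weight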
@@ -56,6 +57,8 @@ class Trainer:
         self.start_epoch = 1
         self.max_epoch = kwargs.get('max_epoch', 100)
         self.local_rank = local_rank
+        self.rank = dist.get_rank()
+        self.world_size = dist.get_world_size()
         self.use_ddp = use_ddp
         self.use_fsdp = use_fsdp
         self.device = torch.device("cuda", local_rank)
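Note: dist.get_rank() and dist.get_world_size() only work after init_process_group() has run (done in the launcher change above). A small guarded sketch for code that may also run single-process; the helper name is illustrative:

import torch.distributed as dist

def rank_and_world_size():
    """Return (rank, world_size), falling back to (0, 1) when not running distributed."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank(), dist.get_world_size()
    return 0, 1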
@@ -113,7 +116,7 @@ class Trainer:
             # self._validate_epoch(epoch)
             if dist.get_rank() == 0:
                 self._save_checkpoint(epoch)
-            # self.scheduler.step()
+            self.scheduler.step()
 
     def _train_epoch(self, epoch):
         """
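Note: guarding _save_checkpoint behind dist.get_rank() == 0 keeps every worker from writing the same file, and the un-commented self.scheduler.step() appears to advance the learning-rate schedule once per epoch. A minimal sketch of a rank-0 save, assuming a simple {epoch, model, optim} file layout and a barrier that are not necessarily what FunASR's _save_checkpoint does:

import os
import torch
import torch.distributed as dist

def save_checkpoint_rank0(model, optim, epoch, output_dir):
    """Write one checkpoint from rank 0, then synchronize so no rank races ahead."""
    if dist.get_rank() == 0:
        state = model.module.state_dict() if hasattr(model, "module") else model.state_dict()
        torch.save({"epoch": epoch, "model": state, "optim": optim.state_dict()},
                   os.path.join(output_dir, f"epoch_{epoch}.pt"))
    dist.barrier()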
@@ -126,24 +129,34 @@ class Trainer:
                     dynamic_ncols=True)
 
         # Set the number of steps for gradient accumulation
-        accumulation_steps = self.kwargs.get("accumulation_steps", 1)
+        accum_grad = self.kwargs.get("accum_grad", 1)
         # Initialize the gradient accumulation
         self.optim.zero_grad()
 
         for batch_idx, batch in enumerate(self.dataloader_train):
            batch = to_device(batch, self.device)
 
-            my_context = self.model.no_sync if batch_idx % accumulation_steps != 0 else nullcontext
+            my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
            with my_context():
                retval = self.model(**batch)
                loss, stats, weight = retval
 
                stats = {k: v for k, v in stats.items() if v is not None}
+                if self.use_ddp or self.use_fsdp:
+                    # Apply weighted averaging for loss and stats
+                    loss = (loss * weight.type(loss.dtype)).sum()
+                    # if distributed, this method can also apply all_reduce()
+                    stats, weight = recursive_average(stats, weight, distributed=True)
+                    # Now weight is summation over all workers
+                    loss /= weight
+                    # Multiply world_size because DistributedDataParallel
+                    # automatically normalizes the gradient by world_size.
+                    loss *= self.world_size
                # Scale the loss since we're not updating for every mini-batch
-                loss = loss / accumulation_steps
+                loss = loss / accum_grad
                loss.backward()
 
            # Perform an optimizer step only after accumulating enough gradients
-            if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == len(self.dataloader_train):
+            if (batch_idx + 1) % accum_grad == 0 or (batch_idx + 1) == len(self.dataloader_train):
                # Perform gradient clipping if it is set
                if self.kwargs.get("grad_clip", None) is not None:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
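Note: the renamed accum_grad controls gradient accumulation: each micro-batch's loss is divided by accum_grad, backward() adds into the same gradient buffers, DDP's no_sync() suppresses the gradient all-reduce on the non-final micro-batches, and the optimizer steps only once enough micro-batches (or the end of the loader) have been seen. A stripped-down sketch of the pattern with standalone names, not the Trainer's own methods:

from contextlib import nullcontext
import torch

def train_one_epoch(model, optim, dataloader, accum_grad=4, grad_clip=None):
    """Gradient accumulation with DDP no_sync() on the non-final micro-batches."""
    optim.zero_grad()
    for batch_idx, batch in enumerate(dataloader):
        last_micro_batch = (batch_idx + 1) % accum_grad == 0 or (batch_idx + 1) == len(dataloader)
        # Only all-reduce gradients on the micro-batch that completes an optimizer step.
        sync_ctx = nullcontext if last_micro_batch else getattr(model, "no_sync", nullcontext)
        with sync_ctx():
            loss = model(**batch)                 # assume the model returns a scalar loss
            (loss / accum_grad).backward()        # scale so the sum matches one large batch
        if last_micro_batch:
            if grad_clip is not None:
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
                if not torch.isfinite(grad_norm):
                    optim.zero_grad()             # skip the update on a non-finite gradient
                    continue
            optim.step()
            optim.zero_grad()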
@@ -170,43 +183,6 @@ class Trainer:
                     f"Training Epoch: {epoch + 1}/{self.max_epoch}, step {batch_idx}/{len(self.dataloader_train)} (loss: {loss.detach().float()})")
 
         pbar.close()
 
-    # def _train_epoch(self, epoch):
-    #     """
-    #     Defines the training process for a single epoch.
-    #     Should be implemented with the actual model training steps.
-    #
-    #     Args:
-    #         epoch (int): The current epoch number.
-    #     """
-    #     self.model.train()
-    #     pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch + 1}", total=len(self.dataloader_train), dynamic_ncols=True)
-    #     for batch_idx, batch in enumerate(self.dataloader_train):
-    #         batch = to_device(batch, "cpu")
-    #         retval = self.model(**batch)
-    #         loss, stats, weight = retval
-    #         self.optim.zero_grad()
-    #         loss.backward()
-    #
-    #         # compute the gradient norm to check if it is normal or not
-    #         grad_norm = torch.nn.utils.clip_grad_norm_(
-    #             self.model.parameters(),
-    #             max_norm=self.kwargs.get("grad_clip", 10.0),
-    #             norm_type=self.kwargs.get("grad_clip_type", 2.0),
-    #         )
-    #         if not torch.isfinite(grad_norm):
-    #             logging.warning(
-    #                 f"The grad norm is {grad_norm}. Skipping updating the model."
-    #             )
-    #             continue
-    #         self.optim.step()
-    #         self.scheduler.step()
-    #         pbar.update(1)
-    #         pbar.set_description(
-    #             f"Training Epoch: {epoch + 1}/{self.max_epoch}, step {batch_idx}/{len(self.dataloader_train)} (loss: {loss.detach().float()})")
-    #
-    #     pbar.close()
-    #
-
     def _validate_epoch(self, epoch):
         """
@@ -221,19 +197,3 @@ class Trainer:
         for data, target in self.dataloader_val:
             # Implement the model validation steps here
             pass
-
-
-# # Example usage
-# if __name__ == "__main__":
-#     # Assuming the following objects have already been correctly created and initialized:
-#     # model, optim, scheduler, dataloader_train, and dataloader_val.
-#     trainer = Trainer(
-#         max_epoch=10,
-#         model=model,
-#         optim=optim,
-#         scheduler=scheduler,
-#         dataloader_train=dataloader_train,
-#         dataloader_val=dataloader_val,
-#         output_dir='path_to_save_model',
-#         resume='path_to_checkpoint_if_any'
-#     )
-#     trainer.run()