mirror of https://github.com/modelscope/FunASR (synced 2025-09-15 14:48:36 +08:00)

commit d3ff05837b (parent f68ae892be): deepspeed

@@ -133,7 +133,7 @@ def main(**kwargs):
kwargs["device"] = next(model.parameters()).device
trainer.device = kwargs["device"]

model, optim, scheduler = trainer.warp_optim_scheduler(model, kwargs)
model, optim, scheduler = trainer.warp_optim_scheduler(model, **kwargs)

# dataset
logging.info("Build dataloader")

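This hunk switches the call from passing the options dict as a single positional argument to unpacking it with **kwargs, which matches the def warp_optim_scheduler(self, model, **kwargs) signature that appears in a later hunk. A minimal sketch of why the distinction matters (a standalone toy function, not the trainer's actual code):

def warp_optim_scheduler(model, **kwargs):
    # Inside the callee, the keyword arguments are collected back into a dict.
    optim_conf = kwargs.get("optim_conf", {})
    return model, optim_conf

kwargs = {"optim_conf": {"lr": 1e-3}, "scheduler": "warmuplr"}

warp_optim_scheduler("model", **kwargs)   # OK: the dict is unpacked into keyword arguments
# warp_optim_scheduler("model", kwargs)   # TypeError: the dict becomes a second positional
#                                         # argument, which **kwargs cannot absorb
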
@@ -100,7 +100,9 @@ class MultiHeadedAttention(nn.Module):
n_batch = value.size(0)
if mask is not None:
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
min_value = -float(
"inf"
) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
scores = scores.masked_fill(mask, min_value)
self.attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0

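This and the following attention hunks (MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionCrossAtt, MultiHeadSelfAttention, and the two qk-based MultiHeadAttention variants) all apply the same replacement: the fill value for masked positions changes from the most negative finite number of the score dtype, obtained through numpy.finfo, to -inf. Softmax over a position filled with -inf is exactly 0, and in the modules that apply it, the masked_fill after the softmax still zeroes rows that are fully masked, where softmax over all -inf values would otherwise produce NaN. A minimal sketch of the pattern with made-up shapes, assuming PyTorch:

import torch

scores = torch.randn(2, 4, 5, 5)               # (batch, head, time1, time2)
mask = torch.ones(2, 1, 5, dtype=torch.bool)   # (batch, 1, time2); True = valid frame
mask[:, :, 3:] = False                         # last two frames are padding

mask = mask.unsqueeze(1).eq(0)                 # (batch, 1, 1, time2); True = masked
min_value = -float("inf")                      # previously numpy.finfo(dtype).min
scores = scores.masked_fill(mask, min_value)
attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)  # also clears NaNs in fully masked rows
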
@@ -269,7 +271,9 @@ class MultiHeadedAttentionSANM(nn.Module):

mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)

min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
min_value = -float(
"inf"
) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
scores = scores.masked_fill(mask, min_value)
self.attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0

@@ -673,7 +677,9 @@ class MultiHeadedAttentionCrossAtt(nn.Module):
n_batch = value.size(0)
if mask is not None:
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
min_value = -float(
"inf"
) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
# logging.info(
# "scores: {}, mask_size: {}".format(scores.size(), mask.size()))
scores = scores.masked_fill(mask, min_value)

@@ -858,7 +864,9 @@ class MultiHeadSelfAttention(nn.Module):

mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)

min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
min_value = -float(
"inf"
) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
scores = scores.masked_fill(mask, min_value)
self.attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0

@@ -146,7 +146,9 @@ class MultiHeadAttention(nn.Module):
qk = qk + mask[:n_ctx, :n_ctx]
else:
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min)
min_value = -float(
"inf"
) # min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min)
qk = qk.masked_fill(mask, min_value)

qk = qk.float()

@@ -112,7 +112,9 @@ class MultiHeadAttention(nn.Module):
qk = qk + mask[:n_ctx, :n_ctx]
else:
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min)
min_value = -float(
"inf"
) # min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min)
qk = qk.masked_fill(mask, min_value)

qk = qk.float()

@@ -78,7 +78,7 @@ class Trainer:
self.world_size = world_size
self.use_ddp = use_ddp
self.use_fsdp = use_fsdp
self.use_deepspeed = use_deepspeed

self.device = kwargs.get("device", "cuda")

self.output_dir = output_dir

@@ -137,6 +137,9 @@ class Trainer:
except:
self.writer = None

self.use_deepspeed = use_deepspeed
self.deepspeed_config = kwargs.get("deepspeed_config", "")

def save_checkpoint(
self,
epoch,

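The Trainer now records a use_deepspeed flag and a deepspeed_config path; a later hunk opens that path and parses it with json.load. A hypothetical sketch of producing such a config file from Python and handing the path to the trainer; the keys shown are standard DeepSpeed options and the kwargs names are illustrative, not values taken from this commit:

import json

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 1,
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 1},
}
with open("ds_config.json", "w") as fout:
    json.dump(ds_config, fout, indent=2)

# Illustrative only: the trainer later reads the path back via kwargs.get("deepspeed_config", "")
kwargs = {"use_deepspeed": True, "deepspeed_config": "ds_config.json"}
deepspeed_config = kwargs.get("deepspeed_config", "")
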
@@ -443,7 +446,8 @@ class Trainer:
iterator_stop = torch.tensor(0).to(self.device)

def forward_step(self, model, batch, loss_dict={}):
with maybe_autocast(self.use_fp16):
dtype = torch.bfloat16
with torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False):
retval = model(**batch)

loss, stats, weight = retval

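forward_step now runs the model under an explicit torch.cuda.amp.autocast region in bfloat16 (with the cast cache disabled) instead of the previous maybe_autocast wrapper. A minimal sketch of that autocast usage, assuming a CUDA device and using a stand-in module rather than the real model:

import torch
import torch.nn as nn

model = nn.Linear(8, 4).cuda()                       # stand-in for the real model
batch = {"input": torch.randn(2, 8, device="cuda")}

dtype = torch.bfloat16
with torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False):
    # Matmul-heavy ops run in bfloat16; parameters themselves stay in float32.
    out = model(batch["input"])

print(out.dtype)  # torch.bfloat16
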
@@ -465,7 +469,7 @@ class Trainer:
else:
loss.backward()

def update_step(self, model, optim, scheduler, scaler, batch_idx=0, loss_dict=loss_dict):
def update_step(self, model, optim, scheduler, scaler, batch_idx=0, loss_dict=None):

if self.use_deepspeed:
model.step()

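Two things change here. The default for loss_dict changes from loss_dict=loss_dict, which is evaluated once when the method is defined and therefore either raises a NameError or freezes whatever object that name referred to at class-definition time, to the usual None default resolved inside the body. And in the DeepSpeed branch, model.step() lets the DeepSpeed engine perform the optimizer step, learning-rate scheduler step, and gradient zeroing itself. A minimal sketch of why definition-time defaults are a trap, using the related mutable-default case (toy functions, not the trainer's code):

def bad_update_step(batch_idx=0, loss_dict={}):
    # The same dict object is reused for every call that omits loss_dict.
    loss_dict[batch_idx] = "seen"
    return loss_dict

print(bad_update_step(0))  # {0: 'seen'}
print(bad_update_step(1))  # {0: 'seen', 1: 'seen'}  <- state leaked from the first call

def good_update_step(batch_idx=0, loss_dict=None):
    loss_dict = {} if loss_dict is None else loss_dict  # fresh dict per call
    loss_dict[batch_idx] = "seen"
    return loss_dict
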
@@ -613,7 +617,7 @@ class Trainer:
loss = loss_dict["loss"].detach().cpu().item()
epoch = loss_dict["epoch"]
batch_idx = loss_dict["batch_idx"]
step_in_epoch = loss_dict["step_in_epoch"]
step_in_epoch = self.step_in_epoch
batch_num_epoch = loss_dict["batch_num_epoch"]
lr = loss_dict["lr"]

@@ -732,36 +736,18 @@ class Trainer:
"find_unused_parameters", False
),
)
# elif self.use_fsdp:
# # model = FSDP(model).cuda(local_rank)
#
# def custom_auto_wrap_policy(
# module: nn.Module,
# recurse: bool,
# nonwrapped_numel: int,
# # Additional custom arguments
# min_num_params: int = int(1e8),
# ) -> bool:
# # Decide whether to wrap the module based on custom logic
# is_large = unwrapped_params >= min_num_params
# requires_grad_uniform = len({p.requires_grad for p in module.parameters()}) == 1
# return is_large and requires_grad_uniform
#
# # Configure a custom `min_num_params`
# my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
# torch.cuda.set_device(local_rank)
# model = FSDP(
# model,
# auto_wrap_policy=custom_auto_wrap_policy,
# mixed_precision=None,
# device_id=torch.cuda.current_device(),
# )

else:
model = model.to(device=kwargs.get("device", "cuda"))

return model

def warp_optim_scheduler(self, model, **kwargs):
from funasr.optimizers import optim_classes
from funasr.schedulers import scheduler_classes
from omegaconf import OmegaConf, DictConfig
import json
import deepspeed

# optim
logging.info("Build optim")

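This hunk removes the long commented-out FSDP experiment, keeps what appears to be the tail of the DDP wrapping call (the find_unused_parameters argument is the only part visible above), falls back to a plain model.to(device) otherwise, and starts the warp_optim_scheduler that imports OmegaConf, json, and deepspeed for the DeepSpeed path below. For context, a generic sketch of that kind of DDP wrapping; the argument plumbing is assumed, not the commit's exact code:

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_model(model, local_rank, **kwargs):
    # Assumes torch.distributed.init_process_group() has already been called.
    find_unused = kwargs.get("train_conf", {}).get("find_unused_parameters", False)
    torch.cuda.set_device(local_rank)
    model = model.cuda(local_rank)
    return DDP(model, device_ids=[local_rank], find_unused_parameters=find_unused)
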
@@ -777,15 +763,16 @@ class Trainer:
scheduler_class = scheduler_classes.get(scheduler)
scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))

if use_deepspeed:
deepspeed_config = kwargs.get("deepspeed_config", "")
with open(deepspeed_config, "r") as fin:
if self.use_deepspeed:

args = OmegaConf.create({"deepspeed_config": self.deepspeed_config})
with open(self.deepspeed_config, "r") as fin:
ds_configs = json.load(fin)
if "optimizer" in ds_configs:
# NOTE(xcsong): Disable custom optimizer if it is set in ds_config,
# extremely useful when enable cpu_offload, DeepspeedCpuAdam
# could be 4~5x faster than torch native adam
deepspeed_config = None
optim = None
if "scheduler" in ds_configs:
scheduler = None
else:

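When the DeepSpeed JSON config already defines an optimizer (and/or scheduler) section, the code above drops the optimizer and scheduler it just built and lets DeepSpeed construct its own, which is what makes CPU offload with DeepSpeed's CPU Adam worthwhile per the NOTE comment. A hypothetical ds_config fragment with such sections and the corresponding check; the keys are standard DeepSpeed options, not values taken from this commit:

import json

ds_configs = json.loads("""
{
  "optimizer": {"type": "Adam", "params": {"lr": 5e-4}},
  "scheduler": {"type": "WarmupLR", "params": {"warmup_num_steps": 1000}}
}
""")

optim, scheduler = "torch-optim", "torch-scheduler"   # placeholders for the objects built above
if "optimizer" in ds_configs:
    optim = None        # DeepSpeed builds its own optimizer (e.g. CPU Adam when offloading)
if "scheduler" in ds_configs:
    scheduler = None    # DeepSpeed builds its own LR scheduler
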
@@ -793,7 +780,6 @@ class Trainer:
def scheduler(opt):
return scheduler_class(opt, **kwargs.get("scheduler_conf"))

args = OmegaConf.create({"deepspeed_config": deepspeed_config})
model, optimizer, _, scheduler = deepspeed.initialize(
args=args,
model=model,

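The diff ends inside the new deepspeed.initialize call; the config path now reaches it through the OmegaConf args object built from self.deepspeed_config rather than from a local variable. For context, a minimal sketch of how such an initialize call is typically completed; the keyword arguments beyond those visible above follow the public DeepSpeed API and are assumptions, not this commit's exact code:

import torch
import deepspeed
from omegaconf import OmegaConf

model = torch.nn.Linear(8, 4)     # stand-in for the real model
optim, scheduler = None, None     # None when the ds_config supplies its own (see previous hunk)

args = OmegaConf.create({"deepspeed_config": "ds_config.json"})
model, optimizer, _, scheduler = deepspeed.initialize(
    args=args,
    model=model,
    optimizer=optim,
    lr_scheduler=scheduler,
    model_parameters=model.parameters(),
)
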