Mirror of https://github.com/modelscope/FunASR (synced 2025-09-15 14:48:36 +08:00)

commit 4c3e502cb8
parent 622c025e10

    update
@@ -183,14 +183,14 @@ class Trainer:
                 raise RuntimeError(
                     "Require torch>=1.6.0 for Automatic Mixed Precision"
                 )
-            if trainer_options.sharded_ddp:
-                if fairscale is None:
-                    raise RuntimeError(
-                        "Requiring fairscale. Do 'pip install fairscale'"
-                    )
-                scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
-            else:
-                scaler = GradScaler()
+            # if trainer_options.sharded_ddp:
+            #     if fairscale is None:
+            #         raise RuntimeError(
+            #             "Requiring fairscale. Do 'pip install fairscale'"
+            #         )
+            #     scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
+            # else:
+            scaler = GradScaler()
         else:
             scaler = None

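For context on what the retained GradScaler is used for later in the training step, here is a minimal, hypothetical sketch of the standard torch.cuda.amp pattern (assumes a CUDA device; the model, optimizer, and loss below are placeholders, not FunASR's own):

    import torch
    from torch.cuda.amp import GradScaler, autocast

    model = torch.nn.Linear(8, 2).cuda()          # placeholder model
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scaler = GradScaler()                          # same object the diff keeps

    x = torch.randn(4, 8, device="cuda")
    y = torch.randint(0, 2, (4,), device="cuda")

    with autocast():                               # forward/loss in mixed precision
        loss = torch.nn.functional.cross_entropy(model(x), y)

    scaler.scale(loss).backward()                  # scale loss to avoid fp16 gradient underflow
    scaler.step(optimizer)                         # unscales grads, then calls optimizer.step()
    scaler.update()                                # adjust the scale factor for the next step
    optimizer.zero_grad()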
@@ -295,10 +295,10 @@ class Trainer:
                     )
                 elif isinstance(scheduler, AbsEpochStepScheduler):
                     scheduler.step()
-            if trainer_options.sharded_ddp:
-                for optimizer in optimizers:
-                    if isinstance(optimizer, fairscale.optim.oss.OSS):
-                        optimizer.consolidate_state_dict()
+            # if trainer_options.sharded_ddp:
+            #     for optimizer in optimizers:
+            #         if isinstance(optimizer, fairscale.optim.oss.OSS):
+            #             optimizer.consolidate_state_dict()

             if not distributed_option.distributed or distributed_option.dist_rank == 0:
                 # 3. Report the results
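The consolidate_state_dict() call being commented out matters when checkpointing a fairscale OSS optimizer, whose state is sharded across ranks. A rough sketch of that pattern (assumes an initialized process group; the names are illustrative, not FunASR's code):

    import torch
    import torch.distributed as dist
    from fairscale.optim.oss import OSS

    # Assumes dist.init_process_group(...) has already been called.
    model = torch.nn.Linear(8, 2)
    optimizer = OSS(model.parameters(), optim=torch.optim.Adam, lr=1e-3)

    # ... training steps ...

    # Each rank only holds its own shard of the optimizer state, so the state
    # must be gathered onto one rank before it can be saved as a whole.
    optimizer.consolidate_state_dict(recipient_rank=0)
    if dist.get_rank() == 0:
        torch.save(optimizer.state_dict(), "optimizer.pt")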
@@ -306,8 +306,8 @@ class Trainer:
                 if train_summary_writer is not None:
                     reporter.tensorboard_add_scalar(train_summary_writer, key1="train")
                     reporter.tensorboard_add_scalar(valid_summary_writer, key1="valid")
-                if trainer_options.use_wandb:
-                    reporter.wandb_log()
+                # if trainer_options.use_wandb:
+                #     reporter.wandb_log()

                 # save tensorboard on oss
                 if trainer_options.use_pai and train_summary_writer is not None:
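reporter.tensorboard_add_scalar() and reporter.wandb_log() are FunASR's own reporter helpers; presumably they forward to the underlying logging APIs, roughly like this hypothetical sketch (metric names and log_dir are placeholders):

    from torch.utils.tensorboard import SummaryWriter
    import wandb  # only needed if wandb logging is enabled

    writer = SummaryWriter(log_dir="exp/tensorboard")

    # Scalars keyed per split, similar to the key1="train"/"valid" calls above.
    writer.add_scalar("train/loss", 0.42, global_step=10)
    writer.add_scalar("valid/loss", 0.47, global_step=10)
    writer.flush()

    # wandb equivalent of reporter.wandb_log().
    wandb.init(project="funasr-demo")
    wandb.log({"train/loss": 0.42, "valid/loss": 0.47, "epoch": 10})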
@@ -412,25 +412,25 @@ class Trainer:
                         "The best model has been updated: " + ", ".join(_improved)
                     )

-                log_model = (
-                    trainer_options.wandb_model_log_interval > 0
-                    and iepoch % trainer_options.wandb_model_log_interval == 0
-                )
-                if log_model and trainer_options.use_wandb:
-                    import wandb
-
-                    logging.info("Logging Model on this epoch :::::")
-                    artifact = wandb.Artifact(
-                        name=f"model_{wandb.run.id}",
-                        type="model",
-                        metadata={"improved": _improved},
-                    )
-                    artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
-                    aliases = [
-                        f"epoch-{iepoch}",
-                        "best" if best_epoch == iepoch else "",
-                    ]
-                    wandb.log_artifact(artifact, aliases=aliases)
+                # log_model = (
+                #     trainer_options.wandb_model_log_interval > 0
+                #     and iepoch % trainer_options.wandb_model_log_interval == 0
+                # )
+                # if log_model and trainer_options.use_wandb:
+                #     import wandb
+                #
+                #     logging.info("Logging Model on this epoch :::::")
+                #     artifact = wandb.Artifact(
+                #         name=f"model_{wandb.run.id}",
+                #         type="model",
+                #         metadata={"improved": _improved},
+                #     )
+                #     artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
+                #     aliases = [
+                #         f"epoch-{iepoch}",
+                #         "best" if best_epoch == iepoch else "",
+                #     ]
+                #     wandb.log_artifact(artifact, aliases=aliases)

                 # 6. Remove the model files excluding n-best epoch and latest epoch
                 _removed = []
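If the model-artifact logging above were re-enabled, the uploaded checkpoint could later be pulled back from Weights & Biases. A hypothetical retrieval sketch (the entity/project/run-id path is a placeholder, not taken from the repo):

    import wandb

    api = wandb.Api()
    # "best" is one of the aliases attached in the diff above.
    artifact = api.artifact("my-entity/my-project/model_abc123:best", type="model")
    local_dir = artifact.download()   # directory containing the saved *.pb checkpoint
    print("checkpoint downloaded to", local_dir)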
@@ -529,9 +529,9 @@ class Trainer:
         grad_clip = options.grad_clip
         grad_clip_type = options.grad_clip_type
         log_interval = options.log_interval
-        no_forward_run = options.no_forward_run
+        # no_forward_run = options.no_forward_run
         ngpu = options.ngpu
-        use_wandb = options.use_wandb
+        # use_wandb = options.use_wandb
         distributed = distributed_option.distributed

         if log_interval is None:
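grad_clip and grad_clip_type configure gradient clipping; when AMP is active the gradients must be unscaled before they are clipped. A minimal sketch of that interaction as a hypothetical helper (not FunASR's exact code):

    import torch

    def clipped_amp_step(model, optimizer, scaler, grad_clip: float, grad_clip_type: float = 2.0):
        """Unscale AMP gradients, clip them, then step (illustrative helper)."""
        scaler.unscale_(optimizer)                       # bring grads back to true scale before clipping
        grad_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=grad_clip, norm_type=grad_clip_type
        )
        if torch.isfinite(grad_norm):                    # skip the update on inf/nan gradients
            scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()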
@@ -559,9 +559,9 @@ class Trainer:
                 break

             batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
-            if no_forward_run:
-                all_steps_are_invalid = False
-                continue
+            # if no_forward_run:
+            #     all_steps_are_invalid = False
+            #     continue

             with autocast(scaler is not None):
                 with reporter.measure_time("forward_time"):
@@ -737,8 +737,8 @@ class Trainer:
                 logging.info(reporter.log_message(-log_interval, num_updates=num_updates))
                 if summary_writer is not None:
                     reporter.tensorboard_add_scalar(summary_writer, -log_interval)
-                if use_wandb:
-                    reporter.wandb_log()
+                # if use_wandb:
+                #     reporter.wandb_log()

             if max_update_stop:
                 break
@@ -760,7 +760,7 @@ class Trainer:
     ) -> None:
         assert check_argument_types()
         ngpu = options.ngpu
-        no_forward_run = options.no_forward_run
+        # no_forward_run = options.no_forward_run
         distributed = distributed_option.distributed

         model.eval()
@@ -776,8 +776,8 @@ class Trainer:
                 break

             batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
-            if no_forward_run:
-                continue
+            # if no_forward_run:
+            #     continue

             retval = model(**batch)
             if isinstance(retval, dict):
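For reference, the surrounding validation pass follows the usual inference-mode pattern (model.eval() plus no autograd). A self-contained toy sketch under that assumption (the model and batches are placeholders, not FunASR's):

    import torch

    model = torch.nn.Linear(8, 2)
    batches = [{"x": torch.randn(4, 8)} for _ in range(3)]

    model.eval()                       # matches the model.eval() shown above
    with torch.no_grad():              # no gradient bookkeeping during validation
        for batch in batches:
            out = model(batch["x"])
            # FunASR's model(**batch) presumably returns a dict or tuple of stats;
            # here we just print a placeholder metric.
            print(out.mean().item())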