嘉渊 2023-04-25 01:23:14 +08:00
parent 622c025e10
commit 4c3e502cb8


@@ -183,14 +183,14 @@ class Trainer:
                 raise RuntimeError(
                     "Require torch>=1.6.0 for Automatic Mixed Precision"
                 )
-            if trainer_options.sharded_ddp:
-                if fairscale is None:
-                    raise RuntimeError(
-                        "Requiring fairscale. Do 'pip install fairscale'"
-                    )
-                scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
-            else:
-                scaler = GradScaler()
+            # if trainer_options.sharded_ddp:
+            #     if fairscale is None:
+            #         raise RuntimeError(
+            #             "Requiring fairscale. Do 'pip install fairscale'"
+            #         )
+            #     scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
+            # else:
+            scaler = GradScaler()
         else:
             scaler = None
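
For context on the scaler configured in this hunk: fairscale's ShardedGradScaler is needed with a sharded (ZeRO-style) optimizer because each rank only sees its shard of the gradients and the overflow check has to be synchronized across ranks, while the plain torch.cuda.amp.GradScaler that the new code always builds is enough for ordinary non-sharded DDP. A minimal sketch of the standard AMP update that consumes such a scaler, with hypothetical model/optimizer/batch names:

import torch
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()  # same kind of object the trainer builds when use_amp is set

def amp_train_step(model, optimizer, batch, max_norm=5.0):
    optimizer.zero_grad()
    with autocast():  # run the forward pass in mixed precision
        loss = model(**batch)["loss"]  # assumes the model returns a dict with a "loss" entry
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.unscale_(optimizer)  # unscale before gradient clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    scaler.step(optimizer)  # skips the update if an overflow was detected
    scaler.update()  # adjust the loss scale for the next step
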
@@ -295,10 +295,10 @@ class Trainer:
                     )
                 elif isinstance(scheduler, AbsEpochStepScheduler):
                     scheduler.step()
-            if trainer_options.sharded_ddp:
-                for optimizer in optimizers:
-                    if isinstance(optimizer, fairscale.optim.oss.OSS):
-                        optimizer.consolidate_state_dict()
+            # if trainer_options.sharded_ddp:
+            #     for optimizer in optimizers:
+            #         if isinstance(optimizer, fairscale.optim.oss.OSS):
+            #             optimizer.consolidate_state_dict()
 
             if not distributed_option.distributed or distributed_option.dist_rank == 0:
                 # 3. Report the results
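
The consolidation step commented out above matters at checkpoint time: a fairscale OSS optimizer keeps only its own shard of the optimizer state on each rank, and consolidate_state_dict() gathers the shards onto one rank so a complete state_dict can be saved. A rough sketch of that pattern in isolation, assuming torch.distributed is already initialized and rank 0 writes the checkpoint:

import torch
import torch.distributed as dist
from fairscale.optim.oss import OSS

def save_full_optimizer_state(model, lr, path="optimizer.pt"):
    # OSS wraps a regular optimizer class; each rank holds only a shard of its state.
    optimizer = OSS(model.parameters(), optim=torch.optim.Adam, lr=lr)
    # ... training steps would happen here ...
    optimizer.consolidate_state_dict(recipient_rank=0)  # gather all shards onto rank 0
    if dist.get_rank() == 0:
        torch.save(optimizer.state_dict(), path)  # the full state is only valid on the recipient rank
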
@@ -306,8 +306,8 @@ class Trainer:
                 if train_summary_writer is not None:
                     reporter.tensorboard_add_scalar(train_summary_writer, key1="train")
                     reporter.tensorboard_add_scalar(valid_summary_writer, key1="valid")
-                if trainer_options.use_wandb:
-                    reporter.wandb_log()
+                # if trainer_options.use_wandb:
+                #     reporter.wandb_log()
 
                 # save tensorboard on oss
                 if trainer_options.use_pai and train_summary_writer is not None:
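
For reference, reporter.wandb_log() disabled here is the ESPnet-style hook that forwards the aggregated epoch statistics to Weights & Biases; the exact keys come from the reporter, but the underlying mechanism is a plain wandb.log call. A rough sketch of the equivalent direct call, with a hypothetical project name and made-up metric values:

import wandb

wandb.init(project="funasr-training")  # hypothetical project name
epoch_stats = {"train/loss": 0.42, "valid/loss": 0.51, "epoch": 3}  # normally produced by the reporter
wandb.log(epoch_stats)
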
@@ -412,25 +412,25 @@ class Trainer:
                         "The best model has been updated: " + ", ".join(_improved)
                     )
-                log_model = (
-                    trainer_options.wandb_model_log_interval > 0
-                    and iepoch % trainer_options.wandb_model_log_interval == 0
-                )
-                if log_model and trainer_options.use_wandb:
-                    import wandb
-
-                    logging.info("Logging Model on this epoch :::::")
-                    artifact = wandb.Artifact(
-                        name=f"model_{wandb.run.id}",
-                        type="model",
-                        metadata={"improved": _improved},
-                    )
-                    artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
-                    aliases = [
-                        f"epoch-{iepoch}",
-                        "best" if best_epoch == iepoch else "",
-                    ]
-                    wandb.log_artifact(artifact, aliases=aliases)
+                # log_model = (
+                #     trainer_options.wandb_model_log_interval > 0
+                #     and iepoch % trainer_options.wandb_model_log_interval == 0
+                # )
+                # if log_model and trainer_options.use_wandb:
+                #     import wandb
+                #
+                #     logging.info("Logging Model on this epoch :::::")
+                #     artifact = wandb.Artifact(
+                #         name=f"model_{wandb.run.id}",
+                #         type="model",
+                #         metadata={"improved": _improved},
+                #     )
+                #     artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
+                #     aliases = [
+                #         f"epoch-{iepoch}",
+                #         "best" if best_epoch == iepoch else "",
+                #     ]
+                #     wandb.log_artifact(artifact, aliases=aliases)
 
                 # 6. Remove the model files excluding n-best epoch and latest epoch
                 _removed = []
@@ -529,9 +529,9 @@ class Trainer:
         grad_clip = options.grad_clip
         grad_clip_type = options.grad_clip_type
         log_interval = options.log_interval
-        no_forward_run = options.no_forward_run
+        # no_forward_run = options.no_forward_run
         ngpu = options.ngpu
-        use_wandb = options.use_wandb
+        # use_wandb = options.use_wandb
         distributed = distributed_option.distributed
 
         if log_interval is None:
@@ -559,9 +559,9 @@ class Trainer:
                     break
 
             batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
-            if no_forward_run:
-                all_steps_are_invalid = False
-                continue
+            # if no_forward_run:
+            #     all_steps_are_invalid = False
+            #     continue
 
             with autocast(scaler is not None):
                 with reporter.measure_time("forward_time"):
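
The no_forward_run flag disabled in this and the later validation hunk is a debugging option: when enabled, the loop still pulls batches and moves them to the device but skips the model forward/backward, which makes it a cheap check of the data pipeline alone. A minimal sketch of that dry-run pattern outside the trainer, with hypothetical loader and device names:

import torch

def dry_run_data_pipeline(iterator, device="cuda" if torch.cuda.is_available() else "cpu"):
    # Iterate the loader and stage the tensors on the device, but never call the model,
    # mirroring the `if no_forward_run: continue` branch that this commit comments out.
    for batch in iterator:
        batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}
        continue  # forward/backward intentionally skipped
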
@@ -737,8 +737,8 @@ class Trainer:
                 logging.info(reporter.log_message(-log_interval, num_updates=num_updates))
                 if summary_writer is not None:
                     reporter.tensorboard_add_scalar(summary_writer, -log_interval)
-                if use_wandb:
-                    reporter.wandb_log()
+                # if use_wandb:
+                #     reporter.wandb_log()
 
             if max_update_stop:
                 break
@@ -760,7 +760,7 @@ class Trainer:
     ) -> None:
         assert check_argument_types()
         ngpu = options.ngpu
-        no_forward_run = options.no_forward_run
+        # no_forward_run = options.no_forward_run
         distributed = distributed_option.distributed
 
         model.eval()
@@ -776,8 +776,8 @@ class Trainer:
                     break
 
             batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
-            if no_forward_run:
-                continue
+            # if no_forward_run:
+            #     continue
 
             retval = model(**batch)
             if isinstance(retval, dict):