mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Merge pull request #376 from alibaba-damo-academy/dev_wjm3
update saving model by num_update
This commit is contained in:
commit
975ac752f7
@ -467,7 +467,7 @@ class AbsTask(ABC):
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--batch_interval",
|
"--batch_interval",
|
||||||
type=int,
|
type=int,
|
||||||
default=10000,
|
default=-1,
|
||||||
help="The batch interval for saving model.",
|
help="The batch interval for saving model.",
|
||||||
)
|
)
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
|
|||||||
@ -571,8 +571,7 @@ class Trainer:
|
|||||||
#ouput dir
|
#ouput dir
|
||||||
output_dir = Path(options.output_dir)
|
output_dir = Path(options.output_dir)
|
||||||
#batch interval
|
#batch interval
|
||||||
batch_interval = options.batch_interval
|
batch_interval = options.batch_interval
|
||||||
assert batch_interval > 0
|
|
||||||
|
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
for iiter, (_, batch) in enumerate(
|
for iiter, (_, batch) in enumerate(
|
||||||
@ -580,14 +579,17 @@ class Trainer:
|
|||||||
):
|
):
|
||||||
assert isinstance(batch, dict), type(batch)
|
assert isinstance(batch, dict), type(batch)
|
||||||
|
|
||||||
if rank == 0:
|
if batch_interval > 0 and (not distributed_option.distributed or rank == 0):
|
||||||
if hasattr(model, "num_updates") or (hasattr(model, "module") and hasattr(model.module, "num_updates")):
|
if hasattr(model, "num_updates") or (hasattr(model, "module") and hasattr(model.module, "num_updates")):
|
||||||
num_batch_updates = model.get_num_updates() if hasattr(model,"num_updates") else model.module.get_num_updates()
|
num_batch_updates = model.get_num_updates() if hasattr(model,"num_updates") else model.module.get_num_updates()
|
||||||
if (num_batch_updates%batch_interval == 0) and (options.oss_bucket is not None) and options.use_pai:
|
if num_batch_updates % batch_interval == 0:
|
||||||
buffer = BytesIO()
|
if options.use_pai and options.oss_bucket is not None:
|
||||||
torch.save(model.state_dict(), buffer)
|
buffer = BytesIO()
|
||||||
options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}batch.pth"), buffer.getvalue())
|
torch.save(model.state_dict(), buffer)
|
||||||
|
options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}step.pb"), buffer.getvalue())
|
||||||
|
else:
|
||||||
|
torch.save(model.state_dict(), os.path.join(output_dir, f"{num_batch_updates}step.pb"))
|
||||||
|
|
||||||
if distributed:
|
if distributed:
|
||||||
torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
|
torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
|
||||||
if iterator_stop > 0:
|
if iterator_stop > 0:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user