Mirror of https://github.com/modelscope/FunASR, synced 2025-09-15 14:48:36 +08:00
Merge branch 'main' of github.com:alibaba-damo-academy/FunASR
add
commit 8a47091504
@@ -582,10 +582,16 @@ class Trainer:
             if num_batch_updates % batch_interval == 0:
                 if options.use_pai and options.oss_bucket is not None:
                     buffer = BytesIO()
-                    torch.save(model.state_dict(), buffer)
+                    if hasattr(model, "module"):
+                        torch.save(model.module.state_dict(), buffer)
+                    else:
+                        torch.save(model.state_dict(), buffer)
                     options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}step.pb"), buffer.getvalue())
                 else:
-                    torch.save(model.state_dict(), os.path.join(output_dir, f"{num_batch_updates}step.pb"))
+                    if hasattr(model, "module"):
+                        torch.save(model.module.state_dict(), os.path.join(output_dir, f"{num_batch_updates}step.pb"))
+                    else:
+                        torch.save(model.state_dict(), os.path.join(output_dir, f"{num_batch_updates}step.pb"))

             if distributed:
                 torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
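The change checks for the "module" attribute that torch.nn.DataParallel and
DistributedDataParallel add when they wrap a network, and saves the inner
model's state_dict so the checkpoint keys are not prefixed with "module.".
Below is a minimal, standalone sketch of that same pattern; the
save_checkpoint helper and the toy nn.Linear model are illustrative only and
are not part of FunASR's Trainer.

import os
from io import BytesIO

import torch
import torch.nn as nn


def save_checkpoint(model, output_dir, step, oss_bucket=None):
    """Save a state_dict, unwrapping a (Distributed)DataParallel model if present."""
    # Wrapped models expose the real network as model.module; saving that
    # keeps parameter keys free of the "module." prefix, so the checkpoint
    # loads cleanly into an unwrapped model later.
    if hasattr(model, "module"):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()

    if oss_bucket is not None:
        # Serialize in memory and upload, as in the OSS branch of the diff above.
        buffer = BytesIO()
        torch.save(state_dict, buffer)
        oss_bucket.put_object(os.path.join(output_dir, f"{step}step.pb"), buffer.getvalue())
    else:
        torch.save(state_dict, os.path.join(output_dir, f"{step}step.pb"))


if __name__ == "__main__":
    # Toy usage: writes ./100step.pb to the local filesystem.
    save_checkpoint(nn.Linear(4, 2), output_dir=".", step=100)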