mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
decoding
This commit is contained in:
parent
f97f3e8dd5
commit
dce53b268b
@ -66,6 +66,7 @@ def main(**kwargs):
|
||||
# open tf32
|
||||
torch.backends.cuda.matmul.allow_tf32 = kwargs.get("enable_tf32", True)
|
||||
|
||||
rank = int(os.environ.get("RANK", 0))
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
|
||||
@ -83,7 +84,7 @@ def main(**kwargs):
|
||||
dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method="env://")
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
rank = dist.get_rank()
|
||||
# rank = dist.get_rank()
|
||||
|
||||
logging.info("Build model, frontend, tokenizer")
|
||||
device = kwargs.get("device", "cuda")
|
||||
|
||||
@ -78,7 +78,7 @@ class Trainer:
|
||||
output_dir (str): The directory where model checkpoints will be saved. Default is './'.
|
||||
resume (str, optional): The file path to a checkpoint to resume training from.
|
||||
"""
|
||||
self.rank = kwargs.get("rank", 0)
|
||||
self.rank = rank
|
||||
self.local_rank = local_rank
|
||||
self.world_size = world_size
|
||||
self.use_ddp = use_ddp
|
||||
|
||||
Loading…
Reference in New Issue
Block a user