mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
46 lines
1.5 KiB
Python
46 lines
1.5 KiB
Python
import os
|
|
import hydra
|
|
import logging
|
|
from omegaconf import DictConfig, OmegaConf, ListConfig
|
|
import torch
|
|
import torch.distributed
|
|
from funasr.auto.auto_model import AutoModel
|
|
|
|
|
|
|
|
@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    """Run ``AutoModel.generate`` on this process's shard of the input.

    The Hydra config is converted to plain Python containers and then
    adjusted per process before being passed to ``AutoModel``:

    * ``input`` gets the zero-padded global rank appended (``<input>NN``);
      presumably the input was pre-split into per-rank shards with that
      naming -- confirm against the launcher script.
    * ``output_dir`` becomes the ``NN`` sub-directory of the configured one.
    * ``device`` is forced to ``"cuda"`` and ``disable_pbar`` to ``True``.

    The CUDA device is selected from the ``LOCAL_RANK`` environment variable
    (defaulting to ``0`` when unset). The generation result is printed.

    Args:
        cfg: Hydra/OmegaConf configuration; must provide ``input`` and
            ``output_dir`` in addition to whatever ``AutoModel`` needs.

    Raises:
        KeyError: if ``input`` or ``output_dir`` is missing from the config.
    """

    def to_plain_list(cfg_item):
        """Recursively turn OmegaConf containers into plain list/dict."""
        if isinstance(cfg_item, ListConfig):
            return OmegaConf.to_container(cfg_item, resolve=True)
        elif isinstance(cfg_item, DictConfig):
            return {k: to_plain_list(v) for k, v in cfg_item.items()}
        else:
            return cfg_item

    kwargs = to_plain_list(cfg)

    # get_rank()/get_world_size() raise when the default process group has
    # not been initialized, and nothing in this script initializes it. Use
    # the process group when one exists; otherwise fall back to the RANK /
    # WORLD_SIZE environment variables (0 / 1 for a single-process run).
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        dist_rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
    else:
        dist_rank = int(os.environ.get("RANK", 0))
        world_size = int(os.environ.get("WORLD_SIZE", 1))

    logging.basicConfig(
        level='INFO',
        format=f"[{os.uname()[1].split('.')[0]}]-[{dist_rank}/{world_size}] "
        f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    # Read once; kept as a string so the log line renders exactly as given.
    local_rank = os.environ.get("LOCAL_RANK", "0")

    kwargs["input"] = kwargs["input"] + f"{dist_rank:02d}"
    kwargs["output_dir"] = os.path.join(kwargs["output_dir"], f"{dist_rank:02d}")
    kwargs["device"] = "cuda"
    kwargs["disable_pbar"] = True

    logging.info("start to extract %s.", kwargs["input"])
    logging.info("save to %s.", kwargs["output_dir"])
    logging.info("using device cuda:%s.", local_rank)

    # Select this process's GPU before the model is constructed so that the
    # generic "cuda" device set above refers to it.
    torch.cuda.set_device(int(local_rank))

    model = AutoModel(**kwargs)
    res = model.generate(input=kwargs["input"])
    print(res)
|
|
|
|
|
|
# Script entry point: the ``@hydra.main`` decorator on ``main_hydra`` builds
# the ``cfg`` argument, so the function is called without arguments here.
if __name__ == "__main__":
    main_hydra()
|