# FunASR/funasr/bin/extract_token.py
# Last modified: 2024-09-24 22:27:25 +08:00
import os
import hydra
import logging
from omegaconf import DictConfig, OmegaConf, ListConfig
import torch
import torch.distributed
from funasr.auto.auto_model import AutoModel
@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    """Extract tokens with a FunASR AutoModel, sharded across distributed ranks.

    Each rank processes its own input shard ``<input><rank:02d>`` and writes
    results under ``<output_dir>/<rank:02d>``. Must be launched under a
    distributed launcher (e.g. torchrun) with the default process group
    already initialized and ``LOCAL_RANK`` set in the environment.

    Args:
        cfg: Hydra/OmegaConf config; must contain at least ``input`` and
            ``output_dir``. All keys are forwarded to ``AutoModel``.
    """
    # OmegaConf containers -> plain dict/list recursively; replaces the
    # original hand-rolled recursive converter with the library equivalent.
    kwargs = OmegaConf.to_container(cfg, resolve=True)

    dist_rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    logging.basicConfig(
        level="INFO",
        format=f"[{os.uname()[1].split('.')[0]}]-[{dist_rank}/{world_size}] "
        f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    # Read LOCAL_RANK once and convert once (original read the env var twice,
    # keeping it as a str in one place and converting to int in another).
    local_rank = int(os.environ["LOCAL_RANK"])
    # Pin this process to its GPU before any CUDA work / model construction,
    # so the bare "cuda" device below resolves to the right card.
    torch.cuda.set_device(local_rank)

    # Shard input/output by global rank: each rank gets its own file and dir.
    kwargs["input"] = kwargs["input"] + f"{dist_rank:02d}"
    kwargs["output_dir"] = os.path.join(kwargs["output_dir"], f"{dist_rank:02d}")
    kwargs["device"] = "cuda"  # current CUDA device was pinned above
    kwargs["disable_pbar"] = True  # per-rank progress bars would interleave

    logging.info("start to extract {}.".format(kwargs["input"]))
    logging.info("save to {}.".format(kwargs["output_dir"]))
    logging.info("using device cuda:{}.".format(local_rank))

    model = AutoModel(**kwargs)
    res = model.generate(input=kwargs["input"])
    print(res)
# Script entry point: Hydra parses CLI overrides into the cfg passed to
# main_hydra (decorated with @hydra.main above).
if __name__ == "__main__":
    main_hydra()