From b372ab6d74d3c0729c8ea49f04b26d244739e2fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BF=97=E6=B5=A9?= Date: Tue, 24 Sep 2024 22:27:25 +0800 Subject: [PATCH] add extract_token binary --- funasr/auto/auto_model.py | 6 +++++ funasr/bin/extract_token.py | 45 +++++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+) create mode 100644 funasr/bin/extract_token.py diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py index 7db1eb606..7d6b8ce59 100644 --- a/funasr/auto/auto_model.py +++ b/funasr/auto/auto_model.py @@ -290,6 +290,8 @@ class AutoModel: ) time_speech_total = 0.0 time_escape_total = 0.0 + count = 0 + log_interval = kwargs.get("log_interval", None) for beg_idx in range(0, num_samples, batch_size): end_idx = min(num_samples, beg_idx + batch_size) data_batch = data_list[beg_idx:end_idx] @@ -325,8 +327,12 @@ class AutoModel: if pbar: pbar.update(batch_size) pbar.set_description(description) + else: + if log_interval is not None and count % log_interval == 0: + logging.info(f"processed {count*batch_size}/{num_samples} samples: {key_batch[0]}") time_speech_total += batch_data_time time_escape_total += time_escape + count += 1 if run_mode == "extract_token" and hasattr(model, "writer"): model.writer.close() diff --git a/funasr/bin/extract_token.py b/funasr/bin/extract_token.py new file mode 100644 index 000000000..96696de8c --- /dev/null +++ b/funasr/bin/extract_token.py @@ -0,0 +1,45 @@ +import os +import hydra +import logging +from omegaconf import DictConfig, OmegaConf, ListConfig +import torch +import torch.distributed +from funasr.auto.auto_model import AutoModel + + + +@hydra.main(config_name=None, version_base=None) +def main_hydra(cfg: DictConfig): + def to_plain_list(cfg_item): + if isinstance(cfg_item, ListConfig): + return OmegaConf.to_container(cfg_item, resolve=True) + elif isinstance(cfg_item, DictConfig): + return {k: to_plain_list(v) for k, v in cfg_item.items()} + else: + return cfg_item + + kwargs = to_plain_list(cfg) + + dist_rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + logging.basicConfig( + level='INFO', + format=f"[{os.uname()[1].split('.')[0]}]-[{dist_rank}/{world_size}] " + f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + local_rank = os.environ["LOCAL_RANK"] + kwargs["input"] = kwargs["input"] + f"{dist_rank:02d}" + kwargs["output_dir"] = os.path.join(kwargs["output_dir"], f"{dist_rank:02d}") + kwargs["device"] = "cuda" + kwargs["disable_pbar"] = True + logging.info("start to extract {}.".format(kwargs["input"])) + logging.info("save to {}.".format(kwargs["output_dir"])) + logging.info("using device cuda:{}.".format(local_rank)) + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + model = AutoModel(**kwargs) + res = model.generate(input=kwargs["input"]) + print(res) + + +if __name__ == "__main__": + main_hydra()