add extract_token binary

This commit is contained in:
志浩 2024-09-24 22:27:25 +08:00
parent 49903ec044
commit b372ab6d74
2 changed files with 51 additions and 0 deletions

View File

@ -290,6 +290,8 @@ class AutoModel:
)
time_speech_total = 0.0
time_escape_total = 0.0
count = 0
log_interval = kwargs.get("log_interval", None)
for beg_idx in range(0, num_samples, batch_size):
end_idx = min(num_samples, beg_idx + batch_size)
data_batch = data_list[beg_idx:end_idx]
@ -325,8 +327,12 @@ class AutoModel:
if pbar:
pbar.update(batch_size)
pbar.set_description(description)
else:
if log_interval is not None and count % log_interval == 0:
logging.info(f"processed {count*batch_size}/{num_samples} samples: {key_batch[0]}")
time_speech_total += batch_data_time
time_escape_total += time_escape
count += 1
if run_mode == "extract_token" and hasattr(model, "writer"):
model.writer.close()

View File

@ -0,0 +1,45 @@
import os
import hydra
import logging
from omegaconf import DictConfig, OmegaConf, ListConfig
import torch
import torch.distributed
from funasr.auto.auto_model import AutoModel
@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    """Distributed entry point: extract tokens for this rank's input shard.

    Expects to be launched under torchrun: reads ``LOCAL_RANK`` from the
    environment and queries torch.distributed for rank/world size.  Each
    rank appends its two-digit rank to ``input`` (shard suffix) and writes
    results into a per-rank subdirectory of ``output_dir``.
    """

    def to_plain_list(cfg_item):
        # Recursively convert OmegaConf containers to plain Python
        # lists/dicts so downstream code receives ordinary types.
        if isinstance(cfg_item, ListConfig):
            return OmegaConf.to_container(cfg_item, resolve=True)
        elif isinstance(cfg_item, DictConfig):
            return {k: to_plain_list(v) for k, v in cfg_item.items()}
        else:
            return cfg_item

    kwargs = to_plain_list(cfg)

    # NOTE(review): get_rank()/get_world_size() raise unless a process
    # group is initialized — confirm the launcher (torchrun + a prior
    # init_process_group somewhere) guarantees this before main_hydra runs.
    dist_rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    logging.basicConfig(
        level="INFO",
        format=f"[{os.uname()[1].split('.')[0]}]-[{dist_rank}/{world_size}] "
        f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    # Read LOCAL_RANK once and convert immediately (the original read the
    # env var twice and kept one copy as a str).
    local_rank = int(os.environ["LOCAL_RANK"])

    # Per-rank shard selection: input gets a two-digit rank suffix, output
    # goes into a matching per-rank subdirectory.
    kwargs["input"] = kwargs["input"] + f"{dist_rank:02d}"
    kwargs["output_dir"] = os.path.join(kwargs["output_dir"], f"{dist_rank:02d}")
    kwargs["device"] = "cuda"
    kwargs["disable_pbar"] = True  # progress bars interleave badly across ranks
    logging.info("start to extract {}.".format(kwargs["input"]))
    logging.info("save to {}.".format(kwargs["output_dir"]))
    logging.info("using device cuda:{}.".format(local_rank))

    torch.cuda.set_device(local_rank)
    model = AutoModel(**kwargs)
    res = model.generate(input=kwargs["input"])
    print(res)
# Script entry point: hydra parses CLI overrides from sys.argv and calls
# main_hydra with the composed config.
if __name__ == "__main__":
    main_hydra()