mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
add extract_token binary
This commit is contained in:
parent
49903ec044
commit
b372ab6d74
@ -290,6 +290,8 @@ class AutoModel:
|
||||
)
|
||||
time_speech_total = 0.0
|
||||
time_escape_total = 0.0
|
||||
count = 0
|
||||
log_interval = kwargs.get("log_interval", None)
|
||||
for beg_idx in range(0, num_samples, batch_size):
|
||||
end_idx = min(num_samples, beg_idx + batch_size)
|
||||
data_batch = data_list[beg_idx:end_idx]
|
||||
@ -325,8 +327,12 @@ class AutoModel:
|
||||
if pbar:
|
||||
pbar.update(batch_size)
|
||||
pbar.set_description(description)
|
||||
else:
|
||||
if log_interval is not None and count % log_interval == 0:
|
||||
logging.info(f"processed {count*batch_size}/{num_samples} samples: {key_batch[0]}")
|
||||
time_speech_total += batch_data_time
|
||||
time_escape_total += time_escape
|
||||
count += 1
|
||||
|
||||
if run_mode == "extract_token" and hasattr(model, "writer"):
|
||||
model.writer.close()
|
||||
|
||||
45
funasr/bin/extract_token.py
Normal file
45
funasr/bin/extract_token.py
Normal file
@ -0,0 +1,45 @@
|
||||
import os
|
||||
import hydra
|
||||
import logging
|
||||
from omegaconf import DictConfig, OmegaConf, ListConfig
|
||||
import torch
|
||||
import torch.distributed
|
||||
from funasr.auto.auto_model import AutoModel
|
||||
|
||||
|
||||
|
||||
@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    """Distributed token-extraction entry point (launched via ``torchrun``).

    Converts the Hydra config into plain Python containers, shards the work
    by distributed rank (each rank reads a pre-split input file with a
    two-digit rank suffix and writes into a rank-specific output
    sub-directory), then runs ``AutoModel.generate`` on this rank's shard.

    Args:
        cfg: Hydra/OmegaConf config; must provide at least ``input`` (path
            prefix of the pre-split file lists) and ``output_dir``.
    """

    def to_plain_list(cfg_item):
        # Recursively convert OmegaConf containers into plain lists/dicts so
        # downstream code does not depend on omegaconf types.
        if isinstance(cfg_item, ListConfig):
            return OmegaConf.to_container(cfg_item, resolve=True)
        elif isinstance(cfg_item, DictConfig):
            return {k: to_plain_list(v) for k, v in cfg_item.items()}
        else:
            return cfg_item

    kwargs = to_plain_list(cfg)

    # Fix: the original called get_rank()/get_world_size() without ever
    # initializing the default process group, which raises when the script
    # is started by torchrun. Initialize lazily from the launcher-provided
    # environment variables; a caller that already initialized is untouched.
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl")
    dist_rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    logging.basicConfig(
        level='INFO',
        format=f"[{os.uname()[1].split('.')[0]}]-[{dist_rank}/{world_size}] "
        f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    # LOCAL_RANK is set by torchrun; read it once, as an int (the original
    # read the env var twice, once as str and once as int).
    local_rank = int(os.environ["LOCAL_RANK"])
    # Shard selection: rank N consumes "<input>NN" and writes to
    # "<output_dir>/NN" so ranks never collide on files.
    kwargs["input"] = kwargs["input"] + f"{dist_rank:02d}"
    kwargs["output_dir"] = os.path.join(kwargs["output_dir"], f"{dist_rank:02d}")
    kwargs["device"] = "cuda"
    # Per-rank logging replaces the progress bar (see the AutoModel change
    # in this commit that honors log_interval when pbar is disabled).
    kwargs["disable_pbar"] = True
    logging.info("start to extract {}.".format(kwargs["input"]))
    logging.info("save to {}.".format(kwargs["output_dir"]))
    logging.info("using device cuda:{}.".format(local_rank))
    # Bind this process to its GPU before the model is constructed so that
    # kwargs["device"] = "cuda" resolves to the correct device.
    torch.cuda.set_device(local_rank)
    model = AutoModel(**kwargs)
    res = model.generate(input=kwargs["input"])
    print(res)


if __name__ == "__main__":
    main_hydra()
|
||||
Loading…
Reference in New Issue
Block a user