mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
code update
This commit is contained in:
parent
3fcb5dcfed
commit
1233c0d3ff
@ -18,5 +18,5 @@ frontend = AutoFrontend(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-co
|
||||
fbanks = frontend(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", batch_size=2)
|
||||
|
||||
for batch_idx, fbank_dict in enumerate(fbanks):
|
||||
res = model(**fbank_dict)
|
||||
print(res)
|
||||
res = model(**fbank_dict)
|
||||
print(res)
|
||||
@ -309,10 +309,7 @@ class AutoModel:
|
||||
if not len(sorted_data):
|
||||
logging.info("decoding, utt: {}, empty speech".format(key))
|
||||
continue
|
||||
|
||||
|
||||
# if kwargs["device"] == "cpu":
|
||||
# batch_size = 0
|
||||
if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
|
||||
batch_size = max(batch_size, sorted_data[0][0][1] - sorted_data[0][0][0])
|
||||
|
||||
|
||||
@ -1,178 +1,180 @@
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from io import BytesIO
|
||||
from collections.abc import Sequence
|
||||
import torch
|
||||
import hydra
|
||||
import logging
|
||||
import argparse
|
||||
from io import BytesIO
|
||||
import torch.distributed as dist
|
||||
from collections.abc import Sequence
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from funasr.train_utils.set_all_random_seed import set_all_random_seed
|
||||
from funasr.models.lora.utils import mark_only_lora_as_trainable
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
|
||||
from funasr.register import tables
|
||||
from funasr.optimizers import optim_classes
|
||||
from funasr.train_utils.trainer import Trainer
|
||||
from funasr.schedulers import scheduler_classes
|
||||
from funasr.train_utils.load_pretrained_model import load_pretrained_model
|
||||
from funasr.train_utils.initialize import initialize
|
||||
from funasr.download.download_from_hub import download_model
|
||||
from funasr.models.lora.utils import mark_only_lora_as_trainable
|
||||
from funasr.train_utils.set_all_random_seed import set_all_random_seed
|
||||
from funasr.train_utils.load_pretrained_model import load_pretrained_model
|
||||
# from funasr.tokenizer.build_tokenizer import build_tokenizer
|
||||
# from funasr.tokenizer.token_id_converter import TokenIDConverter
|
||||
# from funasr.tokenizer.funtoken import build_tokenizer
|
||||
from funasr.train_utils.trainer import Trainer
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from funasr.download.download_from_hub import download_model
|
||||
from funasr.register import tables
|
||||
|
||||
|
||||
@hydra.main(config_name=None, version_base=None)
|
||||
def main_hydra(kwargs: DictConfig):
|
||||
if kwargs.get("debug", False):
|
||||
import pdb; pdb.set_trace()
|
||||
if kwargs.get("debug", False):
|
||||
import pdb; pdb.set_trace()
|
||||
|
||||
assert "model" in kwargs
|
||||
if "model_conf" not in kwargs:
|
||||
logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
|
||||
kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
|
||||
|
||||
assert "model" in kwargs
|
||||
if "model_conf" not in kwargs:
|
||||
logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
|
||||
kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
|
||||
|
||||
|
||||
main(**kwargs)
|
||||
main(**kwargs)
|
||||
|
||||
|
||||
def main(**kwargs):
|
||||
# preprocess_config(kwargs)
|
||||
# import pdb; pdb.set_trace()
|
||||
# set random seed
|
||||
tables.print()
|
||||
set_all_random_seed(kwargs.get("seed", 0))
|
||||
torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
|
||||
torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
|
||||
torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
|
||||
|
||||
local_rank = int(os.environ.get('LOCAL_RANK', 0))
|
||||
# Check if we are using DDP or FSDP
|
||||
use_ddp = 'WORLD_SIZE' in os.environ and int(os.environ["WORLD_SIZE"]) > 1
|
||||
use_fsdp = kwargs.get("use_fsdp", None)
|
||||
if use_ddp or use_fsdp:
|
||||
dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
# save config.yaml
|
||||
if (use_ddp or use_fsdp) and dist.get_rank() == 0 or not (use_ddp or use_fsdp) and local_rank == 0:
|
||||
os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
|
||||
yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
|
||||
OmegaConf.save(config=kwargs, f=yaml_file)
|
||||
logging.info("config.yaml is saved to: %s", yaml_file)
|
||||
# preprocess_config(kwargs)
|
||||
# import pdb; pdb.set_trace()
|
||||
# set random seed
|
||||
tables.print()
|
||||
set_all_random_seed(kwargs.get("seed", 0))
|
||||
torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
|
||||
torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
|
||||
torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
|
||||
|
||||
local_rank = int(os.environ.get('LOCAL_RANK', 0))
|
||||
# Check if we are using DDP or FSDP
|
||||
use_ddp = 'WORLD_SIZE' in os.environ and int(os.environ["WORLD_SIZE"]) > 1
|
||||
use_fsdp = kwargs.get("use_fsdp", None)
|
||||
if use_ddp or use_fsdp:
|
||||
dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
# save config.yaml
|
||||
if (use_ddp or use_fsdp) and dist.get_rank() == 0 or not (use_ddp or use_fsdp) and local_rank == 0:
|
||||
os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
|
||||
yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
|
||||
OmegaConf.save(config=kwargs, f=yaml_file)
|
||||
logging.info("config.yaml is saved to: %s", yaml_file)
|
||||
|
||||
tokenizer = kwargs.get("tokenizer", None)
|
||||
if tokenizer is not None:
|
||||
tokenizer_class = tables.tokenizer_classes.get(tokenizer)
|
||||
tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
|
||||
kwargs["tokenizer"] = tokenizer
|
||||
|
||||
# build frontend if frontend is none None
|
||||
frontend = kwargs.get("frontend", None)
|
||||
if frontend is not None:
|
||||
frontend_class = tables.frontend_classes.get(frontend)
|
||||
frontend = frontend_class(**kwargs["frontend_conf"])
|
||||
kwargs["frontend"] = frontend
|
||||
kwargs["input_size"] = frontend.output_size()
|
||||
|
||||
# import pdb;
|
||||
# pdb.set_trace()
|
||||
# build model
|
||||
model_class = tables.model_classes.get(kwargs["model"])
|
||||
model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
|
||||
tokenizer = kwargs.get("tokenizer", None)
|
||||
if tokenizer is not None:
|
||||
tokenizer_class = tables.tokenizer_classes.get(tokenizer)
|
||||
tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
|
||||
kwargs["tokenizer"] = tokenizer
|
||||
|
||||
# build frontend if frontend is none None
|
||||
frontend = kwargs.get("frontend", None)
|
||||
if frontend is not None:
|
||||
frontend_class = tables.frontend_classes.get(frontend)
|
||||
frontend = frontend_class(**kwargs["frontend_conf"])
|
||||
kwargs["frontend"] = frontend
|
||||
kwargs["input_size"] = frontend.output_size()
|
||||
|
||||
# import pdb;
|
||||
# pdb.set_trace()
|
||||
# build model
|
||||
model_class = tables.model_classes.get(kwargs["model"])
|
||||
model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
|
||||
|
||||
|
||||
|
||||
# init_param
|
||||
init_param = kwargs.get("init_param", None)
|
||||
if init_param is not None:
|
||||
if not isinstance(init_param, (list, tuple)):
|
||||
init_param = (init_param,)
|
||||
logging.info("init_param is not None: %s", init_param)
|
||||
for p in init_param:
|
||||
logging.info(f"Loading pretrained params from {p}")
|
||||
load_pretrained_model(
|
||||
model=model,
|
||||
init_param=p,
|
||||
ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
|
||||
oss_bucket=kwargs.get("oss_bucket", None),
|
||||
)
|
||||
else:
|
||||
initialize(model, kwargs.get("init", "kaiming_normal"))
|
||||
# init_param
|
||||
init_param = kwargs.get("init_param", None)
|
||||
if init_param is not None:
|
||||
if not isinstance(init_param, (list, tuple)):
|
||||
init_param = (init_param,)
|
||||
logging.info("init_param is not None: %s", init_param)
|
||||
for p in init_param:
|
||||
logging.info(f"Loading pretrained params from {p}")
|
||||
load_pretrained_model(
|
||||
model=model,
|
||||
init_param=p,
|
||||
ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
|
||||
oss_bucket=kwargs.get("oss_bucket", None),
|
||||
)
|
||||
else:
|
||||
initialize(model, kwargs.get("init", "kaiming_normal"))
|
||||
|
||||
|
||||
# freeze_param
|
||||
freeze_param = kwargs.get("freeze_param", None)
|
||||
if freeze_param is not None:
|
||||
freeze_param = eval(freeze_param)
|
||||
if isinstance(freeze_param, Sequence):
|
||||
freeze_param = (freeze_param,)
|
||||
logging.info("freeze_param is not None: %s", freeze_param)
|
||||
for t in freeze_param:
|
||||
for k, p in model.named_parameters():
|
||||
if k.startswith(t + ".") or k == t:
|
||||
logging.info(f"Setting {k}.requires_grad = False")
|
||||
p.requires_grad = False
|
||||
|
||||
# freeze_param
|
||||
freeze_param = kwargs.get("freeze_param", None)
|
||||
if freeze_param is not None:
|
||||
freeze_param = eval(freeze_param)
|
||||
if isinstance(freeze_param, Sequence):
|
||||
freeze_param = (freeze_param,)
|
||||
logging.info("freeze_param is not None: %s", freeze_param)
|
||||
for t in freeze_param:
|
||||
for k, p in model.named_parameters():
|
||||
if k.startswith(t + ".") or k == t:
|
||||
logging.info(f"Setting {k}.requires_grad = False")
|
||||
p.requires_grad = False
|
||||
|
||||
|
||||
if use_ddp:
|
||||
model = model.cuda(local_rank)
|
||||
model = DDP(model, device_ids=[local_rank],
|
||||
find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
|
||||
elif use_fsdp:
|
||||
model = FSDP(model).cuda(local_rank)
|
||||
else:
|
||||
model = model.to(device=kwargs.get("device", "cuda"))
|
||||
|
||||
|
||||
# optim
|
||||
optim = kwargs.get("optim", "adam")
|
||||
assert optim in optim_classes
|
||||
optim_class = optim_classes.get(optim)
|
||||
optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
|
||||
|
||||
# scheduler
|
||||
scheduler = kwargs.get("scheduler", "warmuplr")
|
||||
assert scheduler in scheduler_classes
|
||||
scheduler_class = scheduler_classes.get(scheduler)
|
||||
scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
|
||||
if use_ddp:
|
||||
model = model.cuda(local_rank)
|
||||
model = DDP(model, device_ids=[local_rank],
|
||||
find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
|
||||
elif use_fsdp:
|
||||
model = FSDP(model).cuda(local_rank)
|
||||
else:
|
||||
model = model.to(device=kwargs.get("device", "cuda"))
|
||||
|
||||
|
||||
# optim
|
||||
optim = kwargs.get("optim", "adam")
|
||||
assert optim in optim_classes
|
||||
optim_class = optim_classes.get(optim)
|
||||
optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
|
||||
|
||||
# scheduler
|
||||
scheduler = kwargs.get("scheduler", "warmuplr")
|
||||
assert scheduler in scheduler_classes
|
||||
scheduler_class = scheduler_classes.get(scheduler)
|
||||
scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
|
||||
|
||||
# import pdb;
|
||||
# pdb.set_trace()
|
||||
# dataset
|
||||
dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
|
||||
dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))
|
||||
# import pdb;
|
||||
# pdb.set_trace()
|
||||
# dataset
|
||||
dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
|
||||
dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))
|
||||
|
||||
# dataloader
|
||||
batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
|
||||
batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
|
||||
if batch_sampler is not None:
|
||||
batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
|
||||
dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
|
||||
collate_fn=dataset_tr.collator,
|
||||
batch_sampler=batch_sampler,
|
||||
num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
|
||||
pin_memory=True)
|
||||
|
||||
# dataloader
|
||||
batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
|
||||
batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
|
||||
if batch_sampler is not None:
|
||||
batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
|
||||
dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
|
||||
collate_fn=dataset_tr.collator,
|
||||
batch_sampler=batch_sampler,
|
||||
num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
|
||||
pin_memory=True)
|
||||
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
optim=optim,
|
||||
scheduler=scheduler,
|
||||
dataloader_train=dataloader_tr,
|
||||
dataloader_val=None,
|
||||
local_rank=local_rank,
|
||||
use_ddp=use_ddp,
|
||||
use_fsdp=use_fsdp,
|
||||
**kwargs.get("train_conf"),
|
||||
)
|
||||
trainer.run()
|
||||
|
||||
if use_ddp or use_fsdp:
|
||||
torch.distributed.destroy_process_group()
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
optim=optim,
|
||||
scheduler=scheduler,
|
||||
dataloader_train=dataloader_tr,
|
||||
dataloader_val=None,
|
||||
local_rank=local_rank,
|
||||
use_ddp=use_ddp,
|
||||
use_fsdp=use_fsdp,
|
||||
**kwargs.get("train_conf"),
|
||||
)
|
||||
trainer.run()
|
||||
|
||||
if use_ddp or use_fsdp:
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main_hydra()
|
||||
main_hydra()
|
||||
@ -1,102 +1,93 @@
|
||||
import torch
|
||||
import json
|
||||
import torch.distributed as dist
|
||||
import numpy as np
|
||||
import kaldiio
|
||||
import librosa
|
||||
import torchaudio
|
||||
import time
|
||||
import logging
|
||||
|
||||
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
|
||||
from funasr.register import tables
|
||||
from funasr.utils.load_utils import extract_fbank
|
||||
|
||||
|
||||
@tables.register("dataset_classes", "AudioDataset")
|
||||
class AudioDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
AudioDataset
|
||||
"""
|
||||
def __init__(self,
|
||||
path,
|
||||
index_ds: str = None,
|
||||
frontend=None,
|
||||
tokenizer=None,
|
||||
int_pad_value: int = -1,
|
||||
float_pad_value: float = 0.0,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
index_ds_class = tables.index_ds_classes.get(index_ds)
|
||||
self.index_ds = index_ds_class(path)
|
||||
preprocessor_speech = kwargs.get("preprocessor_speech", None)
|
||||
if preprocessor_speech:
|
||||
preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech)
|
||||
preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
|
||||
self.preprocessor_speech = preprocessor_speech
|
||||
preprocessor_text = kwargs.get("preprocessor_text", None)
|
||||
if preprocessor_text:
|
||||
preprocessor_text_class = tables.preprocessor_text_classes.get(preprocessor_text)
|
||||
preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
|
||||
self.preprocessor_text = preprocessor_text
|
||||
|
||||
self.frontend = frontend
|
||||
self.fs = 16000 if frontend is None else frontend.fs
|
||||
self.data_type = "sound"
|
||||
self.tokenizer = tokenizer
|
||||
"""
|
||||
AudioDataset
|
||||
"""
|
||||
def __init__(self,
|
||||
path,
|
||||
index_ds: str = None,
|
||||
frontend=None,
|
||||
tokenizer=None,
|
||||
int_pad_value: int = -1,
|
||||
float_pad_value: float = 0.0,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
index_ds_class = tables.index_ds_classes.get(index_ds)
|
||||
self.index_ds = index_ds_class(path)
|
||||
preprocessor_speech = kwargs.get("preprocessor_speech", None)
|
||||
if preprocessor_speech:
|
||||
preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech)
|
||||
preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
|
||||
self.preprocessor_speech = preprocessor_speech
|
||||
preprocessor_text = kwargs.get("preprocessor_text", None)
|
||||
if preprocessor_text:
|
||||
preprocessor_text_class = tables.preprocessor_text_classes.get(preprocessor_text)
|
||||
preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
|
||||
self.preprocessor_text = preprocessor_text
|
||||
|
||||
self.frontend = frontend
|
||||
self.fs = 16000 if frontend is None else frontend.fs
|
||||
self.data_type = "sound"
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
self.int_pad_value = int_pad_value
|
||||
self.float_pad_value = float_pad_value
|
||||
|
||||
def get_source_len(self, index):
|
||||
item = self.index_ds[index]
|
||||
return self.index_ds.get_source_len(item)
|
||||
|
||||
def get_target_len(self, index):
|
||||
item = self.index_ds[index]
|
||||
return self.index_ds.get_target_len(item)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.index_ds)
|
||||
|
||||
def __getitem__(self, index):
|
||||
item = self.index_ds[index]
|
||||
# import pdb;
|
||||
# pdb.set_trace()
|
||||
source = item["source"]
|
||||
data_src = load_audio(source, fs=self.fs)
|
||||
if self.preprocessor_speech:
|
||||
data_src = self.preprocessor_speech(data_src)
|
||||
speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend) # speech: [b, T, d]
|
||||
self.int_pad_value = int_pad_value
|
||||
self.float_pad_value = float_pad_value
|
||||
|
||||
def get_source_len(self, index):
|
||||
item = self.index_ds[index]
|
||||
return self.index_ds.get_source_len(item)
|
||||
|
||||
def get_target_len(self, index):
|
||||
item = self.index_ds[index]
|
||||
return self.index_ds.get_target_len(item)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.index_ds)
|
||||
|
||||
def __getitem__(self, index):
|
||||
item = self.index_ds[index]
|
||||
# import pdb;
|
||||
# pdb.set_trace()
|
||||
source = item["source"]
|
||||
data_src = load_audio(source, fs=self.fs)
|
||||
if self.preprocessor_speech:
|
||||
data_src = self.preprocessor_speech(data_src)
|
||||
speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend) # speech: [b, T, d]
|
||||
|
||||
target = item["target"]
|
||||
if self.preprocessor_text:
|
||||
target = self.preprocessor_text(target)
|
||||
ids = self.tokenizer.encode(target)
|
||||
ids_lengths = len(ids)
|
||||
text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32)
|
||||
target = item["target"]
|
||||
if self.preprocessor_text:
|
||||
target = self.preprocessor_text(target)
|
||||
ids = self.tokenizer.encode(target)
|
||||
ids_lengths = len(ids)
|
||||
text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32)
|
||||
|
||||
return {"speech": speech[0, :, :],
|
||||
"speech_lengths": speech_lengths,
|
||||
"text": text,
|
||||
"text_lengths": text_lengths,
|
||||
}
|
||||
|
||||
|
||||
def collator(self, samples: list=None):
|
||||
return {"speech": speech[0, :, :],
|
||||
"speech_lengths": speech_lengths,
|
||||
"text": text,
|
||||
"text_lengths": text_lengths,
|
||||
}
|
||||
|
||||
|
||||
def collator(self, samples: list=None):
|
||||
outputs = {}
|
||||
for sample in samples:
|
||||
for key in sample.keys():
|
||||
if key not in outputs:
|
||||
outputs[key] = []
|
||||
outputs[key].append(sample[key])
|
||||
|
||||
for key, data_list in outputs.items():
|
||||
if data_list[0].dtype == torch.int64:
|
||||
|
||||
outputs = {}
|
||||
for sample in samples:
|
||||
for key in sample.keys():
|
||||
if key not in outputs:
|
||||
outputs[key] = []
|
||||
outputs[key].append(sample[key])
|
||||
|
||||
for key, data_list in outputs.items():
|
||||
if data_list[0].dtype == torch.int64:
|
||||
|
||||
pad_value = self.int_pad_value
|
||||
else:
|
||||
pad_value = self.float_pad_value
|
||||
outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
|
||||
return outputs
|
||||
pad_value = self.int_pad_value
|
||||
else:
|
||||
pad_value = self.float_pad_value
|
||||
outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
|
||||
return outputs
|
||||
|
||||
|
||||
@ -1,64 +1,64 @@
|
||||
import torch
|
||||
import json
|
||||
import torch.distributed as dist
|
||||
import time
|
||||
import torch
|
||||
import logging
|
||||
import torch.distributed as dist
|
||||
|
||||
from funasr.register import tables
|
||||
|
||||
|
||||
@tables.register("index_ds_classes", "IndexDSJsonl")
|
||||
class IndexDSJsonl(torch.utils.data.Dataset):
|
||||
|
||||
def __init__(self, path):
|
||||
super().__init__()
|
||||
|
||||
contents = []
|
||||
with open(path, encoding='utf-8') as fin:
|
||||
for line in fin:
|
||||
data = json.loads(line.strip())
|
||||
if "text" in data: # for sft
|
||||
self.contents.append(data['text'])
|
||||
if "source" in data: # for speech lab pretrain
|
||||
prompt = data["prompt"]
|
||||
source = data["source"]
|
||||
target = data["target"]
|
||||
source_len = data["source_len"]
|
||||
target_len = data["target_len"]
|
||||
|
||||
def __init__(self, path):
|
||||
super().__init__()
|
||||
|
||||
contents = []
|
||||
with open(path, encoding='utf-8') as fin:
|
||||
for line in fin:
|
||||
data = json.loads(line.strip())
|
||||
if "text" in data: # for sft
|
||||
self.contents.append(data['text'])
|
||||
if "source" in data: # for speech lab pretrain
|
||||
prompt = data["prompt"]
|
||||
source = data["source"]
|
||||
target = data["target"]
|
||||
source_len = data["source_len"]
|
||||
target_len = data["target_len"]
|
||||
|
||||
contents.append({"source": source,
|
||||
"prompt": prompt,
|
||||
"target": target,
|
||||
"source_len": source_len,
|
||||
"target_len": target_len,
|
||||
}
|
||||
)
|
||||
|
||||
self.contents = []
|
||||
total_num = len(contents)
|
||||
try:
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
except:
|
||||
rank = 0
|
||||
world_size = 1
|
||||
logging.warning("distributed is not initialized, only single shard")
|
||||
num_per_rank = total_num // world_size
|
||||
|
||||
# rank = 0
|
||||
# import ipdb; ipdb.set_trace()
|
||||
self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank]
|
||||
|
||||
logging.info("in rank: {}, num of samplers: {}, total_num of samplers across ranks: {}".format(rank, len(self.contents), len(contents)))
|
||||
contents.append({"source": source,
|
||||
"prompt": prompt,
|
||||
"target": target,
|
||||
"source_len": source_len,
|
||||
"target_len": target_len,
|
||||
}
|
||||
)
|
||||
|
||||
self.contents = []
|
||||
total_num = len(contents)
|
||||
try:
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
except:
|
||||
rank = 0
|
||||
world_size = 1
|
||||
logging.warning("distributed is not initialized, only single shard")
|
||||
num_per_rank = total_num // world_size
|
||||
|
||||
# rank = 0
|
||||
# import ipdb; ipdb.set_trace()
|
||||
self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank]
|
||||
|
||||
logging.info("in rank: {}, num of samplers: {}, total_num of samplers across ranks: {}".format(rank, len(self.contents), len(contents)))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.contents)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.contents[index]
|
||||
|
||||
def get_source_len(self, data_dict):
|
||||
return data_dict["source_len"]
|
||||
def __len__(self):
|
||||
return len(self.contents)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.contents[index]
|
||||
|
||||
def get_source_len(self, data_dict):
|
||||
return data_dict["source_len"]
|
||||
|
||||
def get_target_len(self, data_dict):
|
||||
|
||||
return data_dict["target_len"] if "target_len" in data_dict else 0
|
||||
def get_target_len(self, data_dict):
|
||||
|
||||
return data_dict["target_len"] if "target_len" in data_dict else 0
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import torch
|
||||
|
||||
import numpy as np
|
||||
|
||||
from funasr.register import tables
|
||||
@ -7,74 +6,74 @@ from funasr.register import tables
|
||||
|
||||
@tables.register("batch_sampler_classes", "DynamicBatchLocalShuffleSampler")
|
||||
class BatchSampler(torch.utils.data.BatchSampler):
|
||||
|
||||
def __init__(self, dataset,
|
||||
batch_type: str = "example",
|
||||
batch_size: int = 100,
|
||||
buffer_size: int = 30,
|
||||
drop_last: bool = False,
|
||||
shuffle: bool = True,
|
||||
**kwargs):
|
||||
|
||||
self.drop_last = drop_last
|
||||
self.pre_idx = -1
|
||||
self.dataset = dataset
|
||||
self.total_samples = len(dataset)
|
||||
self.batch_type = batch_type
|
||||
self.batch_size = batch_size
|
||||
self.buffer_size = buffer_size
|
||||
self.max_token_length = kwargs.get("max_token_length", 5000)
|
||||
self.shuffle_idx = np.arange(self.total_samples)
|
||||
self.shuffle = shuffle
|
||||
|
||||
def __len__(self):
|
||||
return self.total_samples
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
np.random.seed(epoch)
|
||||
|
||||
def __iter__(self):
|
||||
|
||||
if self.shuffle:
|
||||
np.random.shuffle(self.shuffle_idx)
|
||||
|
||||
batch = []
|
||||
max_token = 0
|
||||
num_sample = 0
|
||||
|
||||
iter_num = (self.total_samples - 1) // self.buffer_size + 1
|
||||
# print("iter_num: ", iter_num)
|
||||
for iter in range(self.pre_idx + 1, iter_num):
|
||||
datalen_with_index = []
|
||||
for i in range(self.buffer_size):
|
||||
idx = iter * self.buffer_size + i
|
||||
if idx >= self.total_samples:
|
||||
continue
|
||||
|
||||
idx_map = self.shuffle_idx[idx]
|
||||
# prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
|
||||
sample_len_cur = self.dataset.get_source_len(idx_map) + \
|
||||
self.dataset.get_target_len(idx_map)
|
||||
|
||||
datalen_with_index.append([idx, sample_len_cur])
|
||||
|
||||
datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
|
||||
for item in datalen_with_index_sort:
|
||||
idx, sample_len_cur_raw = item
|
||||
if sample_len_cur_raw > self.max_token_length:
|
||||
continue
|
||||
|
||||
max_token_cur = max(max_token, sample_len_cur_raw)
|
||||
max_token_padding = 1 + num_sample
|
||||
if self.batch_type == 'length':
|
||||
max_token_padding *= max_token_cur
|
||||
if max_token_padding <= self.batch_size:
|
||||
batch.append(idx)
|
||||
max_token = max_token_cur
|
||||
num_sample += 1
|
||||
else:
|
||||
yield batch
|
||||
batch = [idx]
|
||||
max_token = sample_len_cur_raw
|
||||
num_sample = 1
|
||||
|
||||
def __init__(self, dataset,
|
||||
batch_type: str = "example",
|
||||
batch_size: int = 100,
|
||||
buffer_size: int = 30,
|
||||
drop_last: bool = False,
|
||||
shuffle: bool = True,
|
||||
**kwargs):
|
||||
|
||||
self.drop_last = drop_last
|
||||
self.pre_idx = -1
|
||||
self.dataset = dataset
|
||||
self.total_samples = len(dataset)
|
||||
self.batch_type = batch_type
|
||||
self.batch_size = batch_size
|
||||
self.buffer_size = buffer_size
|
||||
self.max_token_length = kwargs.get("max_token_length", 5000)
|
||||
self.shuffle_idx = np.arange(self.total_samples)
|
||||
self.shuffle = shuffle
|
||||
|
||||
def __len__(self):
|
||||
return self.total_samples
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
np.random.seed(epoch)
|
||||
|
||||
def __iter__(self):
|
||||
|
||||
if self.shuffle:
|
||||
np.random.shuffle(self.shuffle_idx)
|
||||
|
||||
batch = []
|
||||
max_token = 0
|
||||
num_sample = 0
|
||||
|
||||
iter_num = (self.total_samples - 1) // self.buffer_size + 1
|
||||
# print("iter_num: ", iter_num)
|
||||
for iter in range(self.pre_idx + 1, iter_num):
|
||||
datalen_with_index = []
|
||||
for i in range(self.buffer_size):
|
||||
idx = iter * self.buffer_size + i
|
||||
if idx >= self.total_samples:
|
||||
continue
|
||||
|
||||
idx_map = self.shuffle_idx[idx]
|
||||
# prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
|
||||
sample_len_cur = self.dataset.get_source_len(idx_map) + \
|
||||
self.dataset.get_target_len(idx_map)
|
||||
|
||||
datalen_with_index.append([idx, sample_len_cur])
|
||||
|
||||
datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
|
||||
for item in datalen_with_index_sort:
|
||||
idx, sample_len_cur_raw = item
|
||||
if sample_len_cur_raw > self.max_token_length:
|
||||
continue
|
||||
|
||||
max_token_cur = max(max_token, sample_len_cur_raw)
|
||||
max_token_padding = 1 + num_sample
|
||||
if self.batch_type == 'length':
|
||||
max_token_padding *= max_token_cur
|
||||
if max_token_padding <= self.batch_size:
|
||||
batch.append(idx)
|
||||
max_token = max_token_cur
|
||||
num_sample += 1
|
||||
else:
|
||||
yield batch
|
||||
batch = [idx]
|
||||
max_token = sample_len_cur_raw
|
||||
num_sample = 1
|
||||
|
||||
|
||||
@ -1,110 +1,111 @@
|
||||
import json
|
||||
import os
|
||||
import json
|
||||
from omegaconf import OmegaConf
|
||||
import torch
|
||||
|
||||
from funasr.download.name_maps_from_hub import name_maps_ms, name_maps_hf
|
||||
|
||||
|
||||
def download_model(**kwargs):
|
||||
model_hub = kwargs.get("model_hub", "ms")
|
||||
if model_hub == "ms":
|
||||
kwargs = download_from_ms(**kwargs)
|
||||
|
||||
return kwargs
|
||||
model_hub = kwargs.get("model_hub", "ms")
|
||||
if model_hub == "ms":
|
||||
kwargs = download_from_ms(**kwargs)
|
||||
|
||||
return kwargs
|
||||
|
||||
def download_from_ms(**kwargs):
    """Locate (or download) a model directory and merge its config into kwargs.

    Steps, in order:

    1. ``kwargs["model"]`` is mapped through ``name_maps_ms`` when it is a
       known alias.
    2. If the resulting path does not exist locally it is fetched with
       ``get_or_download_model_dir``; the directory is stored in
       ``kwargs["model_path"]``.
    3. If the directory holds both ``config.yaml`` and ``model.pb``, the yaml
       is loaded and merged with ``kwargs`` (``kwargs`` is the later merge
       argument), and the paths of the optional resource files that exist
       (``tokens.txt``/``tokens.json``, ``seg_dict``, ``bpe.model``,
       ``am.mvn``, ``jieba_usr_dict``) are written into the merged config.
    4. Otherwise, if ``configuration.json`` exists, its ``file_path_metas``
       entry is resolved against the model directory, overlaid with
       ``kwargs``, and merged with the config file named by ``cfg["config"]``.

    Returns:
        The merged configuration converted with ``OmegaConf.to_container``
        (``resolve=True``).
    """
    model_or_path = kwargs.get("model")
    if model_or_path in name_maps_ms:
        model_or_path = name_maps_ms[model_or_path]
    model_revision = kwargs.get("model_revision")
    if not os.path.exists(model_or_path):
        model_or_path = get_or_download_model_dir(
            model_or_path,
            model_revision,
            is_training=kwargs.get("is_training"),
            # Fix: this used to read kwargs.get("kwargs", True), so a
            # caller-supplied ``check_latest`` was silently ignored.
            check_latest=kwargs.get("check_latest", True),
        )
    kwargs["model_path"] = model_or_path

    config = os.path.join(model_or_path, "config.yaml")
    if os.path.exists(config) and os.path.exists(os.path.join(model_or_path, "model.pb")):
        config = OmegaConf.load(config)
        kwargs = OmegaConf.merge(config, kwargs)
        init_param = os.path.join(model_or_path, "model.pb")
        kwargs["init_param"] = init_param
        if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
            kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
        # tokens.json is checked second, so it overrides tokens.txt when both exist.
        if os.path.exists(os.path.join(model_or_path, "tokens.json")):
            kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.json")
        if os.path.exists(os.path.join(model_or_path, "seg_dict")):
            kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
        if os.path.exists(os.path.join(model_or_path, "bpe.model")):
            kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
        # The model entry always comes from the config file, not from the caller.
        kwargs["model"] = config["model"]
        if os.path.exists(os.path.join(model_or_path, "am.mvn")):
            kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
        if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
            kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")
    elif os.path.exists(os.path.join(model_or_path, "configuration.json")):
        with open(os.path.join(model_or_path, "configuration.json"), 'r', encoding='utf-8') as f:
            conf_json = json.load(f)
        cfg = {}
        add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
        # Caller-supplied arguments take precedence over the resolved file paths.
        cfg.update(kwargs)
        config = OmegaConf.load(cfg["config"])
        kwargs = OmegaConf.merge(config, cfg)
        kwargs["model"] = config["model"]
    return OmegaConf.to_container(kwargs, resolve=True)
|
||||
|
||||
def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg=None):
    """Resolve relative file names in ``file_path_metas`` against a model dir.

    Walks ``file_path_metas`` recursively. Each string value is joined onto
    ``model_or_path``; when the joined path exists it is stored in ``cfg``
    under the same key. Each dict value is resolved into a nested dict in
    ``cfg`` (created if absent). Values of any other type are ignored.

    Args:
        model_or_path: directory the relative file names are joined onto.
        file_path_metas: possibly nested mapping of key -> relative file name.
            Anything that is not a dict is ignored.
        cfg: dict to fill in place; a fresh dict is created when omitted.

    Returns:
        ``cfg`` (the dict passed in, or the newly created one).
    """
    # A literal ``{}`` default would be shared between calls.
    if cfg is None:
        cfg = {}

    if isinstance(file_path_metas, dict):
        for k, v in file_path_metas.items():
            if isinstance(v, str):
                p = os.path.join(model_or_path, v)
                if os.path.exists(p):
                    cfg[k] = p
            elif isinstance(v, dict):
                if k not in cfg:
                    cfg[k] = {}
                # Recurse without returning: a ``return`` here would skip every
                # key after the first nested dict and hand back the sub-dict.
                add_file_root_path(model_or_path, v, cfg[k])

    return cfg
|
||||
|
||||
|
||||
def get_or_download_model_dir(
        model,
        model_revision=None,
        is_training=False,
        check_latest=True,
):
    """Get the local model directory, downloading the model if necessary.

    Args:
        model (str): model id, or path to a local model directory/file.
        model_revision (str, optional): model version passed to
            ``snapshot_download`` as ``revision``.
        is_training (bool): selects the ``Invoke`` key reported in the user
            agent (``LOCAL_TRAINER`` when true, ``PIPELINE`` otherwise).
        check_latest (bool): when ``model`` exists locally, also ask
            ModelScope whether the local copy is the latest version.

    Returns:
        The model directory: ``model`` itself if it is a directory, its parent
        directory if it is a file, or the directory returned by
        ``snapshot_download`` when ``model`` does not exist locally.
    """
    from modelscope.hub.check_model import check_local_model_is_latest
    from modelscope.hub.snapshot_download import snapshot_download

    from modelscope.utils.constant import Invoke, ThirdParty

    key = Invoke.LOCAL_TRAINER if is_training else Invoke.PIPELINE
    user_agent = {
        Invoke.KEY: key,
        ThirdParty.KEY: "funasr"
    }

    if os.path.exists(model):
        # A local path is used as-is. Previously an existing path combined with
        # check_latest=False fell through to snapshot_download, which treated
        # the path as a remote model id.
        model_cache_dir = model if os.path.isdir(model) else os.path.dirname(model)
        if check_latest:
            try:
                check_local_model_is_latest(model_cache_dir, user_agent=user_agent)
            except Exception:
                # Best effort only: failing the freshness check must not stop
                # the local model from being used.
                print("could not check the latest version")
    else:
        model_cache_dir = snapshot_download(
            model,
            revision=model_revision,
            user_agent=user_agent,
        )
    return model_cache_dir
|
||||
@ -1,45 +1,47 @@
|
||||
from pathlib import Path
|
||||
import os
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from funasr.utils.types import str2bool
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: fetch a model and export it to ONNX if needed.

    The model directory is ``--model-name`` when that path exists locally;
    otherwise the name is downloaded with ModelScope ``snapshot_download``
    into ``--export-dir``. When ``--export`` is true and the expected ONNX
    file (``model.onnx``, or ``model_quant.onnx`` with ``--quantize``) is
    missing from the model directory, ``ModelExport`` is run on it.

    Raises:
        ValueError: if the model is neither a local path nor downloadable.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-name', type=str, required=True)
    parser.add_argument('--export-dir', type=str, required=True)
    parser.add_argument('--export', type=str2bool, default=True, help='whether to export model')
    parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
    parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
    parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
    parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
    parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
    parser.add_argument('--model_revision', type=str, default=None, help='model_revision')
    parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
    args = parser.parse_args()
    # NOTE(review): --type, --device, --fallback-num, --audio_in and --calib_num
    # are parsed but not used below; the export always runs with onnx=True on
    # "cpu". Confirm whether that is intended.

    model_dir = args.model_name
    if not Path(args.model_name).exists():
        from modelscope.hub.snapshot_download import snapshot_download
        try:
            model_dir = snapshot_download(args.model_name, cache_dir=args.export_dir, revision=args.model_revision)
        except Exception as err:
            # Raising a bare string is a TypeError in Python 3, which hid this
            # message; raise a real exception and keep the original cause.
            raise ValueError(
                "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
                    model_dir)
            ) from err

    if args.export:
        model_file = os.path.join(model_dir, 'model.onnx')
        if args.quantize:
            model_file = os.path.join(model_dir, 'model_quant.onnx')
        if not os.path.exists(model_file):
            print(".onnx is not exist, begin to export onnx")
            from funasr.bin.export_model import ModelExport
            export_model = ModelExport(
                cache_dir=args.export_dir,
                onnx=True,
                device="cpu",
                quant=args.quantize,
            )
            export_model.export(model_dir)


if __name__ == "__main__":
    main()
|
||||
@ -5,12 +5,12 @@ from funasr.register import tables
|
||||
|
||||
@tables.register("model_classes", "Branchformer")
class Branchformer(Transformer):
    """CTC-attention hybrid Encoder-Decoder model.

    Registered in ``tables`` under ``"model_classes"`` as ``"Branchformer"``.
    Adds nothing to ``Transformer``: the constructor forwards every argument
    to the parent class.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
|
||||
|
||||
@ -7,13 +7,13 @@ from funasr.register import tables
|
||||
|
||||
@tables.register("model_classes", "Conformer")
class Conformer(Transformer):
    """CTC-attention hybrid Encoder-Decoder model.

    Registered in ``tables`` under ``"model_classes"`` as ``"Conformer"``.
    Adds nothing to ``Transformer``: the constructor forwards every argument
    to the parent class.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
|
||||
|
||||
@ -5,12 +5,12 @@ from funasr.register import tables
|
||||
|
||||
@tables.register("model_classes", "EBranchformer")
class EBranchformer(Transformer):
    """CTC-attention hybrid Encoder-Decoder model.

    Registered in ``tables`` under ``"model_classes"`` as ``"EBranchformer"``.
    Adds nothing to ``Transformer``: the constructor forwards every argument
    to the parent class.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
|
||||
|
||||
@ -7,12 +7,12 @@ from funasr.register import tables
|
||||
|
||||
@tables.register("model_classes", "SANM")
class SANM(Transformer):
    """CTC-attention hybrid Encoder-Decoder model.

    Registered in ``tables`` under ``"model_classes"`` as ``"SANM"``.
    Adds nothing to ``Transformer``: the constructor forwards every argument
    to the parent class.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
|
||||
|
||||
@ -1,289 +1,287 @@
|
||||
|
||||
import math
|
||||
import torch
|
||||
import numpy as np
|
||||
import math
|
||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||
import logging
|
||||
import torch.nn.functional as F
|
||||
from funasr.models.scama.utils import sequence_mask
|
||||
|
||||
from funasr.models.scama.utils import sequence_mask
|
||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||
|
||||
|
||||
class overlap_chunk():
|
||||
"""
|
||||
Author: Speech Lab of DAMO Academy, Alibaba Group
|
||||
San-m: Memory equipped self-attention for end-to-end speech recognition
|
||||
https://arxiv.org/abs/2006.01713
|
||||
"""
|
||||
Author: Speech Lab of DAMO Academy, Alibaba Group
|
||||
San-m: Memory equipped self-attention for end-to-end speech recognition
|
||||
https://arxiv.org/abs/2006.01713
|
||||
|
||||
"""
|
||||
def __init__(self,
|
||||
chunk_size: tuple = (16,),
|
||||
stride: tuple = (10,),
|
||||
pad_left: tuple = (0,),
|
||||
encoder_att_look_back_factor: tuple = (1,),
|
||||
"""
|
||||
def __init__(self,
|
||||
chunk_size: tuple = (16,),
|
||||
stride: tuple = (10,),
|
||||
pad_left: tuple = (0,),
|
||||
encoder_att_look_back_factor: tuple = (1,),
|
||||
shfit_fsmn: int = 0,
|
||||
decoder_att_look_back_factor: tuple = (1,),
|
||||
):
|
||||
):
|
||||
|
||||
pad_left = self.check_chunk_size_args(chunk_size, pad_left)
|
||||
encoder_att_look_back_factor = self.check_chunk_size_args(chunk_size, encoder_att_look_back_factor)
|
||||
decoder_att_look_back_factor = self.check_chunk_size_args(chunk_size, decoder_att_look_back_factor)
|
||||
self.chunk_size, self.stride, self.pad_left, self.encoder_att_look_back_factor, self.decoder_att_look_back_factor \
|
||||
= chunk_size, stride, pad_left, encoder_att_look_back_factor, decoder_att_look_back_factor
|
||||
self.shfit_fsmn = shfit_fsmn
|
||||
self.x_add_mask = None
|
||||
self.x_rm_mask = None
|
||||
self.x_len = None
|
||||
self.mask_shfit_chunk = None
|
||||
self.mask_chunk_predictor = None
|
||||
self.mask_att_chunk_encoder = None
|
||||
self.mask_shift_att_chunk_decoder = None
|
||||
self.chunk_outs = None
|
||||
self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur \
|
||||
= None, None, None, None, None
|
||||
pad_left = self.check_chunk_size_args(chunk_size, pad_left)
|
||||
encoder_att_look_back_factor = self.check_chunk_size_args(chunk_size, encoder_att_look_back_factor)
|
||||
decoder_att_look_back_factor = self.check_chunk_size_args(chunk_size, decoder_att_look_back_factor)
|
||||
self.chunk_size, self.stride, self.pad_left, self.encoder_att_look_back_factor, self.decoder_att_look_back_factor \
|
||||
= chunk_size, stride, pad_left, encoder_att_look_back_factor, decoder_att_look_back_factor
|
||||
self.shfit_fsmn = shfit_fsmn
|
||||
self.x_add_mask = None
|
||||
self.x_rm_mask = None
|
||||
self.x_len = None
|
||||
self.mask_shfit_chunk = None
|
||||
self.mask_chunk_predictor = None
|
||||
self.mask_att_chunk_encoder = None
|
||||
self.mask_shift_att_chunk_decoder = None
|
||||
self.chunk_outs = None
|
||||
self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur \
|
||||
= None, None, None, None, None
|
||||
|
||||
def check_chunk_size_args(self, chunk_size, x):
    """Return ``x`` expanded to one entry per chunk size when it is too short.

    If ``x`` has fewer entries than ``chunk_size``, a list holding ``x[0]``
    repeated ``len(chunk_size)`` times is returned; otherwise ``x`` is
    returned unchanged.
    """
    wanted = len(chunk_size)
    if len(x) >= wanted:
        return x
    return [x[0]] * wanted
|
||||
|
||||
def get_chunk_size(self,
|
||||
ind: int = 0
|
||||
):
|
||||
# with torch.no_grad:
|
||||
chunk_size, stride, pad_left, encoder_att_look_back_factor, decoder_att_look_back_factor = \
|
||||
self.chunk_size[ind], self.stride[ind], self.pad_left[ind], self.encoder_att_look_back_factor[ind], self.decoder_att_look_back_factor[ind]
|
||||
self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur, self.decoder_att_look_back_factor_cur \
|
||||
= chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size + self.shfit_fsmn, decoder_att_look_back_factor
|
||||
return self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur
|
||||
def get_chunk_size(self,
|
||||
ind: int = 0
|
||||
):
|
||||
# with torch.no_grad:
|
||||
chunk_size, stride, pad_left, encoder_att_look_back_factor, decoder_att_look_back_factor = \
|
||||
self.chunk_size[ind], self.stride[ind], self.pad_left[ind], self.encoder_att_look_back_factor[ind], self.decoder_att_look_back_factor[ind]
|
||||
self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur, self.decoder_att_look_back_factor_cur \
|
||||
= chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size + self.shfit_fsmn, decoder_att_look_back_factor
|
||||
return self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur
|
||||
|
||||
def random_choice(self, training=True, decoding_ind=None):
|
||||
chunk_num = len(self.chunk_size)
|
||||
ind = 0
|
||||
if training and chunk_num > 1:
|
||||
ind = torch.randint(0, chunk_num, ()).cpu().item()
|
||||
if not training and decoding_ind is not None:
|
||||
ind = int(decoding_ind)
|
||||
def random_choice(self, training=True, decoding_ind=None):
|
||||
chunk_num = len(self.chunk_size)
|
||||
ind = 0
|
||||
if training and chunk_num > 1:
|
||||
ind = torch.randint(0, chunk_num, ()).cpu().item()
|
||||
if not training and decoding_ind is not None:
|
||||
ind = int(decoding_ind)
|
||||
|
||||
return ind
|
||||
return ind
|
||||
|
||||
|
||||
|
||||
|
||||
def gen_chunk_mask(self, x_len, ind=0, num_units=1, num_units_predictor=1):
|
||||
def gen_chunk_mask(self, x_len, ind=0, num_units=1, num_units_predictor=1):
|
||||
|
||||
with torch.no_grad():
|
||||
x_len = x_len.cpu().numpy()
|
||||
x_len_max = x_len.max()
|
||||
with torch.no_grad():
|
||||
x_len = x_len.cpu().numpy()
|
||||
x_len_max = x_len.max()
|
||||
|
||||
chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size_pad_shift = self.get_chunk_size(ind)
|
||||
shfit_fsmn = self.shfit_fsmn
|
||||
pad_right = chunk_size - stride - pad_left
|
||||
chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size_pad_shift = self.get_chunk_size(ind)
|
||||
shfit_fsmn = self.shfit_fsmn
|
||||
pad_right = chunk_size - stride - pad_left
|
||||
|
||||
chunk_num_batch = np.ceil(x_len/stride).astype(np.int32)
|
||||
x_len_chunk = (chunk_num_batch-1) * chunk_size_pad_shift + shfit_fsmn + pad_left + 0 + x_len - (chunk_num_batch-1) * stride
|
||||
x_len_chunk = x_len_chunk.astype(x_len.dtype)
|
||||
x_len_chunk_max = x_len_chunk.max()
|
||||
chunk_num_batch = np.ceil(x_len/stride).astype(np.int32)
|
||||
x_len_chunk = (chunk_num_batch-1) * chunk_size_pad_shift + shfit_fsmn + pad_left + 0 + x_len - (chunk_num_batch-1) * stride
|
||||
x_len_chunk = x_len_chunk.astype(x_len.dtype)
|
||||
x_len_chunk_max = x_len_chunk.max()
|
||||
|
||||
chunk_num = int(math.ceil(x_len_max/stride))
|
||||
dtype = np.int32
|
||||
max_len_for_x_mask_tmp = max(chunk_size, x_len_max + pad_left)
|
||||
x_add_mask = np.zeros([0, max_len_for_x_mask_tmp], dtype=dtype)
|
||||
x_rm_mask = np.zeros([max_len_for_x_mask_tmp, 0], dtype=dtype)
|
||||
mask_shfit_chunk = np.zeros([0, num_units], dtype=dtype)
|
||||
mask_chunk_predictor = np.zeros([0, num_units_predictor], dtype=dtype)
|
||||
mask_shift_att_chunk_decoder = np.zeros([0, 1], dtype=dtype)
|
||||
mask_att_chunk_encoder = np.zeros([0, chunk_num*chunk_size_pad_shift], dtype=dtype)
|
||||
for chunk_ids in range(chunk_num):
|
||||
# x_mask add
|
||||
fsmn_padding = np.zeros((shfit_fsmn, max_len_for_x_mask_tmp), dtype=dtype)
|
||||
x_mask_cur = np.diag(np.ones(chunk_size, dtype=np.float32))
|
||||
x_mask_pad_left = np.zeros((chunk_size, chunk_ids * stride), dtype=dtype)
|
||||
x_mask_pad_right = np.zeros((chunk_size, max_len_for_x_mask_tmp), dtype=dtype)
|
||||
x_cur_pad = np.concatenate([x_mask_pad_left, x_mask_cur, x_mask_pad_right], axis=1)
|
||||
x_cur_pad = x_cur_pad[:chunk_size, :max_len_for_x_mask_tmp]
|
||||
x_add_mask_fsmn = np.concatenate([fsmn_padding, x_cur_pad], axis=0)
|
||||
x_add_mask = np.concatenate([x_add_mask, x_add_mask_fsmn], axis=0)
|
||||
chunk_num = int(math.ceil(x_len_max/stride))
|
||||
dtype = np.int32
|
||||
max_len_for_x_mask_tmp = max(chunk_size, x_len_max + pad_left)
|
||||
x_add_mask = np.zeros([0, max_len_for_x_mask_tmp], dtype=dtype)
|
||||
x_rm_mask = np.zeros([max_len_for_x_mask_tmp, 0], dtype=dtype)
|
||||
mask_shfit_chunk = np.zeros([0, num_units], dtype=dtype)
|
||||
mask_chunk_predictor = np.zeros([0, num_units_predictor], dtype=dtype)
|
||||
mask_shift_att_chunk_decoder = np.zeros([0, 1], dtype=dtype)
|
||||
mask_att_chunk_encoder = np.zeros([0, chunk_num*chunk_size_pad_shift], dtype=dtype)
|
||||
for chunk_ids in range(chunk_num):
|
||||
# x_mask add
|
||||
fsmn_padding = np.zeros((shfit_fsmn, max_len_for_x_mask_tmp), dtype=dtype)
|
||||
x_mask_cur = np.diag(np.ones(chunk_size, dtype=np.float32))
|
||||
x_mask_pad_left = np.zeros((chunk_size, chunk_ids * stride), dtype=dtype)
|
||||
x_mask_pad_right = np.zeros((chunk_size, max_len_for_x_mask_tmp), dtype=dtype)
|
||||
x_cur_pad = np.concatenate([x_mask_pad_left, x_mask_cur, x_mask_pad_right], axis=1)
|
||||
x_cur_pad = x_cur_pad[:chunk_size, :max_len_for_x_mask_tmp]
|
||||
x_add_mask_fsmn = np.concatenate([fsmn_padding, x_cur_pad], axis=0)
|
||||
x_add_mask = np.concatenate([x_add_mask, x_add_mask_fsmn], axis=0)
|
||||
|
||||
# x_mask rm
|
||||
fsmn_padding = np.zeros((max_len_for_x_mask_tmp, shfit_fsmn),dtype=dtype)
|
||||
padding_mask_left = np.zeros((max_len_for_x_mask_tmp, pad_left),dtype=dtype)
|
||||
padding_mask_right = np.zeros((max_len_for_x_mask_tmp, pad_right), dtype=dtype)
|
||||
x_mask_cur = np.diag(np.ones(stride, dtype=dtype))
|
||||
x_mask_cur_pad_top = np.zeros((chunk_ids*stride, stride), dtype=dtype)
|
||||
x_mask_cur_pad_bottom = np.zeros((max_len_for_x_mask_tmp, stride), dtype=dtype)
|
||||
x_rm_mask_cur = np.concatenate([x_mask_cur_pad_top, x_mask_cur, x_mask_cur_pad_bottom], axis=0)
|
||||
x_rm_mask_cur = x_rm_mask_cur[:max_len_for_x_mask_tmp, :stride]
|
||||
x_rm_mask_cur_fsmn = np.concatenate([fsmn_padding, padding_mask_left, x_rm_mask_cur, padding_mask_right], axis=1)
|
||||
x_rm_mask = np.concatenate([x_rm_mask, x_rm_mask_cur_fsmn], axis=1)
|
||||
# x_mask rm
|
||||
fsmn_padding = np.zeros((max_len_for_x_mask_tmp, shfit_fsmn),dtype=dtype)
|
||||
padding_mask_left = np.zeros((max_len_for_x_mask_tmp, pad_left),dtype=dtype)
|
||||
padding_mask_right = np.zeros((max_len_for_x_mask_tmp, pad_right), dtype=dtype)
|
||||
x_mask_cur = np.diag(np.ones(stride, dtype=dtype))
|
||||
x_mask_cur_pad_top = np.zeros((chunk_ids*stride, stride), dtype=dtype)
|
||||
x_mask_cur_pad_bottom = np.zeros((max_len_for_x_mask_tmp, stride), dtype=dtype)
|
||||
x_rm_mask_cur = np.concatenate([x_mask_cur_pad_top, x_mask_cur, x_mask_cur_pad_bottom], axis=0)
|
||||
x_rm_mask_cur = x_rm_mask_cur[:max_len_for_x_mask_tmp, :stride]
|
||||
x_rm_mask_cur_fsmn = np.concatenate([fsmn_padding, padding_mask_left, x_rm_mask_cur, padding_mask_right], axis=1)
|
||||
x_rm_mask = np.concatenate([x_rm_mask, x_rm_mask_cur_fsmn], axis=1)
|
||||
|
||||
# fsmn_padding_mask
|
||||
pad_shfit_mask = np.zeros([shfit_fsmn, num_units], dtype=dtype)
|
||||
ones_1 = np.ones([chunk_size, num_units], dtype=dtype)
|
||||
mask_shfit_chunk_cur = np.concatenate([pad_shfit_mask, ones_1], axis=0)
|
||||
mask_shfit_chunk = np.concatenate([mask_shfit_chunk, mask_shfit_chunk_cur], axis=0)
|
||||
# fsmn_padding_mask
|
||||
pad_shfit_mask = np.zeros([shfit_fsmn, num_units], dtype=dtype)
|
||||
ones_1 = np.ones([chunk_size, num_units], dtype=dtype)
|
||||
mask_shfit_chunk_cur = np.concatenate([pad_shfit_mask, ones_1], axis=0)
|
||||
mask_shfit_chunk = np.concatenate([mask_shfit_chunk, mask_shfit_chunk_cur], axis=0)
|
||||
|
||||
# predictor mask
|
||||
zeros_1 = np.zeros([shfit_fsmn + pad_left, num_units_predictor], dtype=dtype)
|
||||
ones_2 = np.ones([stride, num_units_predictor], dtype=dtype)
|
||||
zeros_3 = np.zeros([chunk_size - stride - pad_left, num_units_predictor], dtype=dtype)
|
||||
ones_zeros = np.concatenate([ones_2, zeros_3], axis=0)
|
||||
mask_chunk_predictor_cur = np.concatenate([zeros_1, ones_zeros], axis=0)
|
||||
mask_chunk_predictor = np.concatenate([mask_chunk_predictor, mask_chunk_predictor_cur], axis=0)
|
||||
# predictor mask
|
||||
zeros_1 = np.zeros([shfit_fsmn + pad_left, num_units_predictor], dtype=dtype)
|
||||
ones_2 = np.ones([stride, num_units_predictor], dtype=dtype)
|
||||
zeros_3 = np.zeros([chunk_size - stride - pad_left, num_units_predictor], dtype=dtype)
|
||||
ones_zeros = np.concatenate([ones_2, zeros_3], axis=0)
|
||||
mask_chunk_predictor_cur = np.concatenate([zeros_1, ones_zeros], axis=0)
|
||||
mask_chunk_predictor = np.concatenate([mask_chunk_predictor, mask_chunk_predictor_cur], axis=0)
|
||||
|
||||
# encoder att mask
|
||||
zeros_1_top = np.zeros([shfit_fsmn, chunk_num*chunk_size_pad_shift], dtype=dtype)
|
||||
# encoder att mask
|
||||
zeros_1_top = np.zeros([shfit_fsmn, chunk_num*chunk_size_pad_shift], dtype=dtype)
|
||||
|
||||
zeros_2_num = max(chunk_ids - encoder_att_look_back_factor, 0)
|
||||
zeros_2 = np.zeros([chunk_size, zeros_2_num*chunk_size_pad_shift], dtype=dtype)
|
||||
zeros_2_num = max(chunk_ids - encoder_att_look_back_factor, 0)
|
||||
zeros_2 = np.zeros([chunk_size, zeros_2_num*chunk_size_pad_shift], dtype=dtype)
|
||||
|
||||
encoder_att_look_back_num = max(chunk_ids - zeros_2_num, 0)
|
||||
zeros_2_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
|
||||
ones_2_mid = np.ones([stride, stride], dtype=dtype)
|
||||
zeros_2_bottom = np.zeros([chunk_size-stride, stride], dtype=dtype)
|
||||
zeros_2_right = np.zeros([chunk_size, chunk_size-stride], dtype=dtype)
|
||||
ones_2 = np.concatenate([ones_2_mid, zeros_2_bottom], axis=0)
|
||||
ones_2 = np.concatenate([zeros_2_left, ones_2, zeros_2_right], axis=1)
|
||||
ones_2 = np.tile(ones_2, [1, encoder_att_look_back_num])
|
||||
encoder_att_look_back_num = max(chunk_ids - zeros_2_num, 0)
|
||||
zeros_2_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
|
||||
ones_2_mid = np.ones([stride, stride], dtype=dtype)
|
||||
zeros_2_bottom = np.zeros([chunk_size-stride, stride], dtype=dtype)
|
||||
zeros_2_right = np.zeros([chunk_size, chunk_size-stride], dtype=dtype)
|
||||
ones_2 = np.concatenate([ones_2_mid, zeros_2_bottom], axis=0)
|
||||
ones_2 = np.concatenate([zeros_2_left, ones_2, zeros_2_right], axis=1)
|
||||
ones_2 = np.tile(ones_2, [1, encoder_att_look_back_num])
|
||||
|
||||
zeros_3_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
|
||||
ones_3_right = np.ones([chunk_size, chunk_size], dtype=dtype)
|
||||
ones_3 = np.concatenate([zeros_3_left, ones_3_right], axis=1)
|
||||
zeros_3_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
|
||||
ones_3_right = np.ones([chunk_size, chunk_size], dtype=dtype)
|
||||
ones_3 = np.concatenate([zeros_3_left, ones_3_right], axis=1)
|
||||
|
||||
zeros_remain_num = max(chunk_num - 1 - chunk_ids, 0)
|
||||
zeros_remain = np.zeros([chunk_size, zeros_remain_num*chunk_size_pad_shift], dtype=dtype)
|
||||
zeros_remain_num = max(chunk_num - 1 - chunk_ids, 0)
|
||||
zeros_remain = np.zeros([chunk_size, zeros_remain_num*chunk_size_pad_shift], dtype=dtype)
|
||||
|
||||
ones2_bottom = np.concatenate([zeros_2, ones_2, ones_3, zeros_remain], axis=1)
|
||||
mask_att_chunk_encoder_cur = np.concatenate([zeros_1_top, ones2_bottom], axis=0)
|
||||
mask_att_chunk_encoder = np.concatenate([mask_att_chunk_encoder, mask_att_chunk_encoder_cur], axis=0)
|
||||
ones2_bottom = np.concatenate([zeros_2, ones_2, ones_3, zeros_remain], axis=1)
|
||||
mask_att_chunk_encoder_cur = np.concatenate([zeros_1_top, ones2_bottom], axis=0)
|
||||
mask_att_chunk_encoder = np.concatenate([mask_att_chunk_encoder, mask_att_chunk_encoder_cur], axis=0)
|
||||
|
||||
|
||||
# decoder fsmn_shift_att_mask
|
||||
zeros_1 = np.zeros([shfit_fsmn, 1])
|
||||
ones_1 = np.ones([chunk_size, 1])
|
||||
mask_shift_att_chunk_decoder_cur = np.concatenate([zeros_1, ones_1], axis=0)
|
||||
mask_shift_att_chunk_decoder = np.concatenate(
|
||||
[mask_shift_att_chunk_decoder, mask_shift_att_chunk_decoder_cur], axis=0)
|
||||
# decoder fsmn_shift_att_mask
|
||||
zeros_1 = np.zeros([shfit_fsmn, 1])
|
||||
ones_1 = np.ones([chunk_size, 1])
|
||||
mask_shift_att_chunk_decoder_cur = np.concatenate([zeros_1, ones_1], axis=0)
|
||||
mask_shift_att_chunk_decoder = np.concatenate(
|
||||
[mask_shift_att_chunk_decoder, mask_shift_att_chunk_decoder_cur], axis=0)
|
||||
|
||||
self.x_add_mask = x_add_mask[:x_len_chunk_max, :x_len_max+pad_left]
|
||||
self.x_len_chunk = x_len_chunk
|
||||
self.x_rm_mask = x_rm_mask[:x_len_max, :x_len_chunk_max]
|
||||
self.x_len = x_len
|
||||
self.mask_shfit_chunk = mask_shfit_chunk[:x_len_chunk_max, :]
|
||||
self.mask_chunk_predictor = mask_chunk_predictor[:x_len_chunk_max, :]
|
||||
self.mask_att_chunk_encoder = mask_att_chunk_encoder[:x_len_chunk_max, :x_len_chunk_max]
|
||||
self.mask_shift_att_chunk_decoder = mask_shift_att_chunk_decoder[:x_len_chunk_max, :]
|
||||
self.chunk_outs = (self.x_add_mask,
|
||||
self.x_len_chunk,
|
||||
self.x_rm_mask,
|
||||
self.x_len,
|
||||
self.mask_shfit_chunk,
|
||||
self.mask_chunk_predictor,
|
||||
self.mask_att_chunk_encoder,
|
||||
self.mask_shift_att_chunk_decoder)
|
||||
self.x_add_mask = x_add_mask[:x_len_chunk_max, :x_len_max+pad_left]
|
||||
self.x_len_chunk = x_len_chunk
|
||||
self.x_rm_mask = x_rm_mask[:x_len_max, :x_len_chunk_max]
|
||||
self.x_len = x_len
|
||||
self.mask_shfit_chunk = mask_shfit_chunk[:x_len_chunk_max, :]
|
||||
self.mask_chunk_predictor = mask_chunk_predictor[:x_len_chunk_max, :]
|
||||
self.mask_att_chunk_encoder = mask_att_chunk_encoder[:x_len_chunk_max, :x_len_chunk_max]
|
||||
self.mask_shift_att_chunk_decoder = mask_shift_att_chunk_decoder[:x_len_chunk_max, :]
|
||||
self.chunk_outs = (self.x_add_mask,
|
||||
self.x_len_chunk,
|
||||
self.x_rm_mask,
|
||||
self.x_len,
|
||||
self.mask_shfit_chunk,
|
||||
self.mask_chunk_predictor,
|
||||
self.mask_att_chunk_encoder,
|
||||
self.mask_shift_att_chunk_decoder)
|
||||
|
||||
return self.chunk_outs
|
||||
return self.chunk_outs
|
||||
|
||||
|
||||
def split_chunk(self, x, x_len, chunk_outs):
    """Project a padded batch onto its chunked layout.

    :param x: input tensor, unpacked as (b, t, d)
    :param x_len: per-sequence lengths, (b)
    :param chunk_outs: tuple handed to ``get_x_add_mask`` and
        ``get_x_len_chunk`` (they fall back to ``self.chunk_outs`` when None)
    :return: tuple ``(x_chunk, x_len_chunk)``
    """
    x = x[:, :x_len.max(), :]
    batch, frames, feat = x.size()
    # NOTE(review): make_pad_mask presumably flags padded positions, so its
    # negation keeps the valid frames -- confirm against its definition.
    keep = ~make_pad_mask(x_len, maxlen=frames)
    # In-place multiply on a slice of the argument: the caller's tensor is
    # modified as well.
    x *= keep.to(x.device)[:, :, None]

    add_mask = self.get_x_add_mask(chunk_outs, x.device, dtype=x.dtype)
    x_len_chunk = self.get_x_len_chunk(chunk_outs, x_len.device, dtype=x_len.dtype)

    # Left-pad the time axis by self.pad_left_cur zeros.
    x = F.pad(x, (0, 0, self.pad_left_cur, 0), "constant", 0.0)
    batch, frames, feat = x.size()

    # Time-major, flattened to (t, b * d), so a single matrix product with
    # the mask re-arranges every sequence at once.
    flat = torch.reshape(torch.transpose(x, 1, 0), [frames, -1])
    x_chunk = torch.mm(add_mask, flat)
    x_chunk = torch.reshape(x_chunk, [-1, batch, feat]).transpose(1, 0)

    return x_chunk, x_len_chunk
|
||||
|
||||
def remove_chunk(self, x_chunk, x_len_chunk, chunk_outs):
    """Map a chunked batch back through the ``x_rm_mask`` matrix.

    :param x_chunk: chunked tensor, unpacked as (b, t, d)
    :param x_len_chunk: per-sequence chunked lengths, (b)
    :param chunk_outs: tuple handed to ``get_x_rm_mask`` and ``get_x_len``
        (they fall back to ``self.chunk_outs`` when None)
    :return: tuple ``(x, x_len)``
    """
    x_chunk = x_chunk[:, :x_len_chunk.max(), :]
    batch, frames, feat = x_chunk.size()
    # NOTE(review): make_pad_mask presumably flags padded positions, so its
    # negation keeps the valid frames -- confirm against its definition.
    keep = ~make_pad_mask(x_len_chunk, maxlen=frames)
    # In-place multiply on a slice of the argument: the caller's tensor is
    # modified as well.
    x_chunk *= keep.to(x_chunk.device)[:, :, None]

    rm_mask = self.get_x_rm_mask(chunk_outs, x_chunk.device, dtype=x_chunk.dtype)
    x_len = self.get_x_len(chunk_outs, x_len_chunk.device, dtype=x_len_chunk.dtype)

    # Time-major, flattened to (t, b * d), then one matrix product with the
    # removal mask.
    flat = torch.reshape(torch.transpose(x_chunk, 1, 0), [frames, -1])
    x = torch.mm(rm_mask, flat)
    x = torch.reshape(x, [-1, batch, feat]).transpose(1, 0)

    return x, x_len
|
||||
|
||||
def get_x_add_mask(self, chunk_outs=None, device='cpu', idx=0, dtype=torch.float32):
    """Return entry ``idx`` of the chunk tuple as a torch tensor.

    The entry is taken from ``chunk_outs`` when given, otherwise from
    ``self.chunk_outs``. It is converted from numpy, cast to ``dtype`` and
    moved to ``device`` with gradient tracking disabled.
    """
    source = self.chunk_outs if chunk_outs is None else chunk_outs
    with torch.no_grad():
        return torch.from_numpy(source[idx]).type(dtype).to(device)
|
||||
|
||||
def get_x_len_chunk(self, chunk_outs=None, device='cpu', idx=1, dtype=torch.float32):
    """Return entry ``idx`` (default 1) of the chunk tuple as a torch tensor.

    The entry is taken from ``chunk_outs`` when given, otherwise from
    ``self.chunk_outs``. It is converted from numpy, cast to ``dtype`` and
    moved to ``device`` with gradient tracking disabled.
    """
    source = self.chunk_outs if chunk_outs is None else chunk_outs
    with torch.no_grad():
        return torch.from_numpy(source[idx]).type(dtype).to(device)
|
||||
|
||||
|
||||
def get_x_rm_mask(self, chunk_outs=None, device='cpu', idx=2, dtype=torch.float32):
    """Return entry ``idx`` (default 2) of the chunk tuple as a torch tensor.

    The entry is taken from ``chunk_outs`` when given, otherwise from
    ``self.chunk_outs``. It is converted from numpy, cast to ``dtype`` and
    moved to ``device`` with gradient tracking disabled.
    """
    source = self.chunk_outs if chunk_outs is None else chunk_outs
    with torch.no_grad():
        return torch.from_numpy(source[idx]).type(dtype).to(device)
|
||||
|
||||
def get_x_len(self, chunk_outs=None, device='cpu', idx=3, dtype=torch.float32):
    """Return entry ``idx`` (default 3) of the chunk tuple as a torch tensor.

    The entry is taken from ``chunk_outs`` when given, otherwise from
    ``self.chunk_outs``. It is converted from numpy, cast to ``dtype`` and
    moved to ``device`` with gradient tracking disabled.
    """
    source = self.chunk_outs if chunk_outs is None else chunk_outs
    with torch.no_grad():
        return torch.from_numpy(source[idx]).type(dtype).to(device)
|
||||
|
||||
|
||||
def get_mask_shfit_chunk(self, chunk_outs=None, device='cpu', batch_size=1, num_units=1, idx=4, dtype=torch.float32):
    """Return entry ``idx`` (default 4) of the chunk tuple, tiled, as a tensor.

    The 2-D entry gets a leading axis and is repeated ``batch_size`` times
    along it and ``num_units`` times along the last axis. The result is
    cast to ``dtype`` and moved to ``device`` with gradients disabled. The
    entry comes from ``chunk_outs`` when given, else ``self.chunk_outs``.
    """
    source = self.chunk_outs if chunk_outs is None else chunk_outs
    with torch.no_grad():
        tiled = np.tile(source[idx][None, :, :], (batch_size, 1, num_units))
        return torch.from_numpy(tiled).type(dtype).to(device)
|
||||
|
||||
def get_mask_chunk_predictor(self, chunk_outs=None, device='cpu', batch_size=1, num_units=1, idx=5, dtype=torch.float32):
    """Return entry ``idx`` (default 5) of the chunk tuple, tiled, as a tensor.

    The 2-D entry gets a leading axis and is repeated ``batch_size`` times
    along it and ``num_units`` times along the last axis. The result is
    cast to ``dtype`` and moved to ``device`` with gradients disabled. The
    entry comes from ``chunk_outs`` when given, else ``self.chunk_outs``.
    """
    source = self.chunk_outs if chunk_outs is None else chunk_outs
    with torch.no_grad():
        tiled = np.tile(source[idx][None, :, :], (batch_size, 1, num_units))
        return torch.from_numpy(tiled).type(dtype).to(device)
|
||||
|
||||
def get_mask_att_chunk_encoder(self, chunk_outs=None, device='cpu', batch_size=1, idx=6, dtype=torch.float32):
    """Return entry ``idx`` (default 6) of the chunk tuple, batched, as a tensor.

    The 2-D entry gets a leading axis and is repeated ``batch_size`` times
    along it. The result is cast to ``dtype`` and moved to ``device`` with
    gradients disabled. The entry comes from ``chunk_outs`` when given,
    else ``self.chunk_outs``.
    """
    source = self.chunk_outs if chunk_outs is None else chunk_outs
    with torch.no_grad():
        tiled = np.tile(source[idx][None, :, :], (batch_size, 1, 1))
        return torch.from_numpy(tiled).type(dtype).to(device)
|
||||
|
||||
def get_mask_shift_att_chunk_decoder(self, chunk_outs=None, device='cpu', batch_size=1, idx=7, dtype=torch.float32):
    """Return column 0 of entry ``idx`` (default 7), batched, as a tensor.

    The first column of the 2-D entry is given two leading axes and repeated
    ``batch_size`` times along the first, giving (batch_size, 1, rows). The
    result is cast to ``dtype`` and moved to ``device`` with gradients
    disabled. The entry comes from ``chunk_outs`` when given, else
    ``self.chunk_outs``.
    """
    source = self.chunk_outs if chunk_outs is None else chunk_outs
    with torch.no_grad():
        first_column = source[idx][None, None, :, 0]
        tiled = np.tile(first_column, (batch_size, 1, 1))
        return torch.from_numpy(tiled).type(dtype).to(device)
|
||||
|
||||
|
||||
|
||||
def build_scama_mask_for_cross_attention_decoder(
|
||||
predictor_alignments: torch.Tensor,
|
||||
predictor_alignments: torch.Tensor,
|
||||
encoder_sequence_length: torch.Tensor,
|
||||
chunk_size: int = 5,
|
||||
encoder_chunk_size: int = 5,
|
||||
@ -291,100 +289,100 @@ def build_scama_mask_for_cross_attention_decoder(
|
||||
attention_chunk_size: int = 1,
|
||||
attention_chunk_type: str = 'chunk',
|
||||
step=None,
|
||||
predictor_mask_chunk_hopping: torch.Tensor = None,
|
||||
decoder_att_look_back_factor: int = 1,
|
||||
mask_shift_att_chunk_decoder: torch.Tensor = None,
|
||||
target_length: torch.Tensor = None,
|
||||
is_training=True,
|
||||
predictor_mask_chunk_hopping: torch.Tensor = None,
|
||||
decoder_att_look_back_factor: int = 1,
|
||||
mask_shift_att_chunk_decoder: torch.Tensor = None,
|
||||
target_length: torch.Tensor = None,
|
||||
is_training=True,
|
||||
dtype: torch.dtype = torch.float32):
|
||||
with torch.no_grad():
|
||||
device = predictor_alignments.device
|
||||
batch_size, chunk_num = predictor_alignments.size()
|
||||
maximum_encoder_length = encoder_sequence_length.max().item()
|
||||
int_type = predictor_alignments.dtype
|
||||
if not is_training:
|
||||
target_length = predictor_alignments.sum(dim=-1).type(encoder_sequence_length.dtype)
|
||||
maximum_target_length = target_length.max()
|
||||
predictor_alignments_cumsum = torch.cumsum(predictor_alignments, dim=1)
|
||||
predictor_alignments_cumsum = predictor_alignments_cumsum[:, None, :].repeat(1, maximum_target_length, 1)
|
||||
|
||||
|
||||
index = torch.ones([batch_size, maximum_target_length], dtype=int_type).to(device)
|
||||
index = torch.cumsum(index, dim=1)
|
||||
index = index[:, :, None].repeat(1, 1, chunk_num)
|
||||
|
||||
index_div = torch.floor(torch.divide(predictor_alignments_cumsum, index)).type(int_type)
|
||||
index_div_bool_zeros = index_div == 0
|
||||
index_div_bool_zeros_count = torch.sum(index_div_bool_zeros.type(int_type), dim=-1) + 1
|
||||
|
||||
index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count, min=1, max=chunk_num)
|
||||
|
||||
index_div_bool_zeros_count *= chunk_size
|
||||
index_div_bool_zeros_count += attention_chunk_center_bias
|
||||
index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count-1, min=0, max=maximum_encoder_length)
|
||||
index_div_bool_zeros_count_ori = index_div_bool_zeros_count
|
||||
|
||||
index_div_bool_zeros_count = (torch.floor(index_div_bool_zeros_count / encoder_chunk_size)+1)*encoder_chunk_size
|
||||
max_len_chunk = math.ceil(maximum_encoder_length / encoder_chunk_size) * encoder_chunk_size
|
||||
|
||||
mask_flip, mask_flip2 = None, None
|
||||
if attention_chunk_size is not None:
|
||||
index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size
|
||||
index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk)
|
||||
index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device)
|
||||
mask_flip = 1 - index_div_bool_zeros_count_beg_mask
|
||||
attention_chunk_size2 = attention_chunk_size * (decoder_att_look_back_factor+1)
|
||||
index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size2
|
||||
|
||||
index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk)
|
||||
index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device)
|
||||
mask_flip2 = 1 - index_div_bool_zeros_count_beg_mask
|
||||
|
||||
mask = sequence_mask(index_div_bool_zeros_count, maxlen=max_len_chunk, dtype=dtype, device=device)
|
||||
|
||||
if predictor_mask_chunk_hopping is not None:
|
||||
b, k, t = mask.size()
|
||||
predictor_mask_chunk_hopping = predictor_mask_chunk_hopping[:, None, :, 0].repeat(1, k, 1)
|
||||
|
||||
mask_mask_flip = mask
|
||||
if mask_flip is not None:
|
||||
mask_mask_flip = mask_flip * mask
|
||||
|
||||
def _fn():
|
||||
mask_sliced = mask[:b, :k, encoder_chunk_size:t]
|
||||
zero_pad_right = torch.zeros([b, k, encoder_chunk_size], dtype=mask_sliced.dtype).to(device)
|
||||
mask_sliced = torch.cat([mask_sliced, zero_pad_right], dim=2)
|
||||
_, _, tt = predictor_mask_chunk_hopping.size()
|
||||
pad_right_p = max_len_chunk - tt
|
||||
predictor_mask_chunk_hopping_pad = torch.nn.functional.pad(predictor_mask_chunk_hopping, [0, pad_right_p], "constant", 0)
|
||||
masked = mask_sliced * predictor_mask_chunk_hopping_pad
|
||||
|
||||
mask_true = mask_mask_flip + masked
|
||||
return mask_true
|
||||
|
||||
mask = _fn() if t > chunk_size else mask_mask_flip
|
||||
|
||||
|
||||
|
||||
if mask_flip2 is not None:
|
||||
mask *= mask_flip2
|
||||
|
||||
mask_target = sequence_mask(target_length, maxlen=maximum_target_length, dtype=mask.dtype, device=device)
|
||||
mask = mask[:, :maximum_target_length, :] * mask_target[:, :, None]
|
||||
|
||||
|
||||
|
||||
mask_len = sequence_mask(encoder_sequence_length, maxlen=maximum_encoder_length, dtype=mask.dtype, device=device)
|
||||
mask = mask[:, :, :maximum_encoder_length] * mask_len[:, None, :]
|
||||
|
||||
|
||||
|
||||
|
||||
if attention_chunk_type == 'full':
|
||||
mask = torch.ones_like(mask).to(device)
|
||||
if mask_shift_att_chunk_decoder is not None:
|
||||
mask = mask * mask_shift_att_chunk_decoder
|
||||
mask = mask[:, :maximum_target_length, :maximum_encoder_length].type(dtype).to(device)
|
||||
with torch.no_grad():
|
||||
device = predictor_alignments.device
|
||||
batch_size, chunk_num = predictor_alignments.size()
|
||||
maximum_encoder_length = encoder_sequence_length.max().item()
|
||||
int_type = predictor_alignments.dtype
|
||||
if not is_training:
|
||||
target_length = predictor_alignments.sum(dim=-1).type(encoder_sequence_length.dtype)
|
||||
maximum_target_length = target_length.max()
|
||||
predictor_alignments_cumsum = torch.cumsum(predictor_alignments, dim=1)
|
||||
predictor_alignments_cumsum = predictor_alignments_cumsum[:, None, :].repeat(1, maximum_target_length, 1)
|
||||
|
||||
|
||||
index = torch.ones([batch_size, maximum_target_length], dtype=int_type).to(device)
|
||||
index = torch.cumsum(index, dim=1)
|
||||
index = index[:, :, None].repeat(1, 1, chunk_num)
|
||||
|
||||
index_div = torch.floor(torch.divide(predictor_alignments_cumsum, index)).type(int_type)
|
||||
index_div_bool_zeros = index_div == 0
|
||||
index_div_bool_zeros_count = torch.sum(index_div_bool_zeros.type(int_type), dim=-1) + 1
|
||||
|
||||
index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count, min=1, max=chunk_num)
|
||||
|
||||
index_div_bool_zeros_count *= chunk_size
|
||||
index_div_bool_zeros_count += attention_chunk_center_bias
|
||||
index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count-1, min=0, max=maximum_encoder_length)
|
||||
index_div_bool_zeros_count_ori = index_div_bool_zeros_count
|
||||
|
||||
index_div_bool_zeros_count = (torch.floor(index_div_bool_zeros_count / encoder_chunk_size)+1)*encoder_chunk_size
|
||||
max_len_chunk = math.ceil(maximum_encoder_length / encoder_chunk_size) * encoder_chunk_size
|
||||
|
||||
mask_flip, mask_flip2 = None, None
|
||||
if attention_chunk_size is not None:
|
||||
index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size
|
||||
index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk)
|
||||
index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device)
|
||||
mask_flip = 1 - index_div_bool_zeros_count_beg_mask
|
||||
attention_chunk_size2 = attention_chunk_size * (decoder_att_look_back_factor+1)
|
||||
index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size2
|
||||
|
||||
index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk)
|
||||
index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device)
|
||||
mask_flip2 = 1 - index_div_bool_zeros_count_beg_mask
|
||||
|
||||
mask = sequence_mask(index_div_bool_zeros_count, maxlen=max_len_chunk, dtype=dtype, device=device)
|
||||
|
||||
if predictor_mask_chunk_hopping is not None:
|
||||
b, k, t = mask.size()
|
||||
predictor_mask_chunk_hopping = predictor_mask_chunk_hopping[:, None, :, 0].repeat(1, k, 1)
|
||||
|
||||
mask_mask_flip = mask
|
||||
if mask_flip is not None:
|
||||
mask_mask_flip = mask_flip * mask
|
||||
|
||||
def _fn():
|
||||
mask_sliced = mask[:b, :k, encoder_chunk_size:t]
|
||||
zero_pad_right = torch.zeros([b, k, encoder_chunk_size], dtype=mask_sliced.dtype).to(device)
|
||||
mask_sliced = torch.cat([mask_sliced, zero_pad_right], dim=2)
|
||||
_, _, tt = predictor_mask_chunk_hopping.size()
|
||||
pad_right_p = max_len_chunk - tt
|
||||
predictor_mask_chunk_hopping_pad = torch.nn.functional.pad(predictor_mask_chunk_hopping, [0, pad_right_p], "constant", 0)
|
||||
masked = mask_sliced * predictor_mask_chunk_hopping_pad
|
||||
|
||||
mask_true = mask_mask_flip + masked
|
||||
return mask_true
|
||||
|
||||
mask = _fn() if t > chunk_size else mask_mask_flip
|
||||
|
||||
|
||||
|
||||
if mask_flip2 is not None:
|
||||
mask *= mask_flip2
|
||||
|
||||
mask_target = sequence_mask(target_length, maxlen=maximum_target_length, dtype=mask.dtype, device=device)
|
||||
mask = mask[:, :maximum_target_length, :] * mask_target[:, :, None]
|
||||
|
||||
|
||||
|
||||
mask_len = sequence_mask(encoder_sequence_length, maxlen=maximum_encoder_length, dtype=mask.dtype, device=device)
|
||||
mask = mask[:, :, :maximum_encoder_length] * mask_len[:, None, :]
|
||||
|
||||
|
||||
|
||||
|
||||
if attention_chunk_type == 'full':
|
||||
mask = torch.ones_like(mask).to(device)
|
||||
if mask_shift_att_chunk_decoder is not None:
|
||||
mask = mask * mask_shift_att_chunk_decoder
|
||||
mask = mask[:, :maximum_target_length, :maximum_encoder_length].type(dtype).to(device)
|
||||
|
||||
return mask
|
||||
return mask
|
||||
|
||||
|
||||
@ -1,29 +1,30 @@
|
||||
import os
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
import yaml
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
    """Build a mask that is 1 where the position index is below the length.

    A trailing axis of size ``maxlen`` is appended to ``lengths``; entry
    ``[..., j]`` is 1 when ``j < lengths[...]`` and 0 otherwise. ``maxlen``
    defaults to ``lengths.max()``. The mask is cast to ``dtype`` and moved
    to ``device`` only when ``device`` is given.
    """
    limit = lengths.max() if maxlen is None else maxlen
    positions = torch.arange(0, limit, 1).to(lengths.device)
    mask = (positions < lengths.unsqueeze(-1)).detach().type(dtype)
    if device is None:
        return mask
    return mask.to(device)
|
||||
|
||||
def apply_cmvn(inputs, mvn):
    """Apply the shift/scale statistics in ``mvn`` to ``inputs`` in place.

    Row 0 of ``mvn`` is subtracted from every frame and the result is
    multiplied by row 1; only the first ``dim`` columns of ``mvn`` are used.

    :param inputs: 2-D tensor, unpacked as (frame, dim); modified in place
    :param mvn: 2-D numpy array with at least two rows
    :return: ``inputs`` cast to ``torch.float32``
    """
    device = inputs.device
    dtype = inputs.dtype
    _, dim = inputs.shape  # also enforces that inputs is 2-D

    # A (1, dim) row broadcasts over all frames, so tiling the statistics
    # to (frame, dim) on the host is unnecessary. np.array() takes a small
    # owned copy, as np.tile did.
    shift = torch.from_numpy(np.array(mvn[0:1, :dim])).type(dtype).to(device)
    scale = torch.from_numpy(np.array(mvn[1:2, :dim])).type(dtype).to(device)

    inputs -= shift
    inputs *= scale

    return inputs.type(torch.float32)
|
||||
|
||||
|
||||
|
||||
@ -36,56 +37,56 @@ def drop_and_add(inputs: torch.Tensor,
|
||||
|
||||
|
||||
|
||||
outputs = F.dropout(outputs, p=dropout_rate, training=training, inplace=True)
|
||||
outputs *= stoch_layer_coeff
|
||||
outputs = F.dropout(outputs, p=dropout_rate, training=training, inplace=True)
|
||||
outputs *= stoch_layer_coeff
|
||||
|
||||
input_dim = inputs.size(-1)
|
||||
output_dim = outputs.size(-1)
|
||||
input_dim = inputs.size(-1)
|
||||
output_dim = outputs.size(-1)
|
||||
|
||||
if input_dim == output_dim:
|
||||
outputs += inputs
|
||||
return outputs
|
||||
if input_dim == output_dim:
|
||||
outputs += inputs
|
||||
return outputs
|
||||
|
||||
|
||||
def proc_tf_vocab(vocab_path):
    """Read a vocabulary file, one token per line, and ensure '<unk>' is present.

    Trailing whitespace is stripped from every line. '<unk>' is appended at
    the end when no line equals it.

    :param vocab_path: path of the UTF-8 vocabulary file
    :return: list of tokens in file order
    """
    tokens = []
    with open(vocab_path, encoding="utf-8") as vocab_file:
        for raw_line in vocab_file:
            tokens.append(raw_line.rstrip())
    if '<unk>' in tokens:
        return tokens
    tokens.append('<unk>')
    return tokens
|
||||
|
||||
|
||||
def gen_config_for_tfmodel(config_path, vocab_path, output_dir):
    """Write ``<output_dir>/config.yaml`` with the vocabulary embedded.

    The YAML file at ``config_path`` is loaded, its ``token_list`` key is
    set to the tokens returned by ``proc_tf_vocab(vocab_path)``, and the
    result is dumped without anchors/aliases.

    :param config_path: path of the source YAML config
    :param vocab_path: path of the vocabulary file
    :param output_dir: directory that receives ``config.yaml``; created
        when missing
    """
    token_list = proc_tf_vocab(vocab_path)
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    config['token_list'] = token_list

    # exist_ok avoids the check-then-create race: a separate exists() test
    # fails if another process creates the directory in between.
    os.makedirs(output_dir, exist_ok=True)

    with open(os.path.join(output_dir, "config.yaml"), "w", encoding="utf-8") as f:
        yaml_no_alias_safe_dump(config, f, indent=4, sort_keys=False)
|
||||
|
||||
|
||||
class NoAliasSafeDumper(yaml.SafeDumper):
    """``yaml.SafeDumper`` variant that never writes anchors or aliases."""

    def ignore_aliases(self, data):
        # Anchors/aliases make the dumped file look ugly, so every node is
        # reported as "ignore aliases".
        return True
|
||||
|
||||
|
||||
def yaml_no_alias_safe_dump(data, stream=None, **kwargs):
    """Safe-dump in yaml with no anchor/alias.

    Forwards to ``yaml.dump`` with ``allow_unicode=True`` and
    ``NoAliasSafeDumper`` as the dumper; extra keyword arguments are passed
    through unchanged. Returns whatever ``yaml.dump`` returns.
    """
    dumper_cls = NoAliasSafeDumper
    return yaml.dump(data, stream, allow_unicode=True, Dumper=dumper_cls, **kwargs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Command-line entry point: <config_path> <vocab_path> <output_dir>
    import sys

    config_path, vocab_path, output_dir = sys.argv[1], sys.argv[2], sys.argv[3]
    gen_config_for_tfmodel(config_path, vocab_path, output_dir)
|
||||
@ -541,20 +541,20 @@ class UniASR(FunASRModel):
|
||||
speech_lengths: (Batch, )
|
||||
"""
|
||||
# with autocast(False):
|
||||
# # 1. Extract feats
|
||||
# feats, feats_lengths = self._extract_feats(speech, speech_lengths)
|
||||
# # 1. Extract feats
|
||||
# feats, feats_lengths = self._extract_feats(speech, speech_lengths)
|
||||
#
|
||||
# # 2. Data augmentation
|
||||
# if self.specaug is not None and self.training:
|
||||
# feats, feats_lengths = self.specaug(feats, feats_lengths)
|
||||
# # 2. Data augmentation
|
||||
# if self.specaug is not None and self.training:
|
||||
# feats, feats_lengths = self.specaug(feats, feats_lengths)
|
||||
#
|
||||
# # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
|
||||
# if self.normalize is not None:
|
||||
# feats, feats_lengths = self.normalize(feats, feats_lengths)
|
||||
# # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
|
||||
# if self.normalize is not None:
|
||||
# feats, feats_lengths = self.normalize(feats, feats_lengths)
|
||||
|
||||
# Pre-encoder, e.g. used for raw input data
|
||||
# if self.preencoder is not None:
|
||||
# feats, feats_lengths = self.preencoder(feats, feats_lengths)
|
||||
# feats, feats_lengths = self.preencoder(feats, feats_lengths)
|
||||
encoder_out_rm, encoder_out_lens_rm = self.encoder.overlap_chunk_cls.remove_chunk(
|
||||
encoder_out,
|
||||
encoder_out_lens,
|
||||
@ -584,9 +584,9 @@ class UniASR(FunASRModel):
|
||||
|
||||
# # Post-encoder, e.g. NLU
|
||||
# if self.postencoder is not None:
|
||||
# encoder_out, encoder_out_lens = self.postencoder(
|
||||
# encoder_out, encoder_out_lens
|
||||
# )
|
||||
# encoder_out, encoder_out_lens = self.postencoder(
|
||||
# encoder_out, encoder_out_lens
|
||||
# )
|
||||
|
||||
assert encoder_out.size(0) == speech.size(0), (
|
||||
encoder_out.size(),
|
||||
|
||||
@ -3,15 +3,15 @@ from funasr.optimizers.fairseq_adam import FairseqAdam
|
||||
from funasr.optimizers.sgd import SGD
|
||||
|
||||
# Registry mapping an optimizer name to its class.
optim_classes = {
    "adam": torch.optim.Adam,
    "fairseq_adam": FairseqAdam,
    "adamw": torch.optim.AdamW,
    "sgd": SGD,
    "adadelta": torch.optim.Adadelta,
    "adagrad": torch.optim.Adagrad,
    "adamax": torch.optim.Adamax,
    "asgd": torch.optim.ASGD,
    "lbfgs": torch.optim.LBFGS,
    "rmsprop": torch.optim.RMSprop,
    "rprop": torch.optim.Rprop,
}
|
||||
@ -8,16 +8,16 @@ from funasr.schedulers.tri_stage_scheduler import TriStageLR
|
||||
from funasr.schedulers.warmup_lr import WarmupLR
|
||||
|
||||
# Registry mapping a learning-rate scheduler name to its class.
scheduler_classes = {
    "ReduceLROnPlateau": torch.optim.lr_scheduler.ReduceLROnPlateau,
    "lambdalr": torch.optim.lr_scheduler.LambdaLR,
    "steplr": torch.optim.lr_scheduler.StepLR,
    "multisteplr": torch.optim.lr_scheduler.MultiStepLR,
    "exponentiallr": torch.optim.lr_scheduler.ExponentialLR,
    "CosineAnnealingLR": torch.optim.lr_scheduler.CosineAnnealingLR,
    "noamlr": NoamLR,
    "warmuplr": WarmupLR,
    "tri_stage": TriStageLR,
    "cycliclr": torch.optim.lr_scheduler.CyclicLR,
    "onecyclelr": torch.optim.lr_scheduler.OneCycleLR,
    "CosineAnnealingWarmRestarts": torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
}
|
||||
|
||||
@ -1,100 +1,94 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import Iterable
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from typing import Iterable
|
||||
from typing import List
|
||||
from typing import Union
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
from abc import ABC
|
||||
from pathlib import Path
|
||||
from abc import abstractmethod
|
||||
from typing import Union, Iterable, List, Dict
|
||||
|
||||
|
||||
class AbsTokenizer(ABC):
    """Abstract interface for converting between text and token sequences."""

    @abstractmethod
    def text2tokens(self, line: str) -> List[str]:
        """Split one line of text into a list of tokens."""
        raise NotImplementedError()

    @abstractmethod
    def tokens2text(self, tokens: Iterable[str]) -> str:
        """Join an iterable of tokens back into text."""
        raise NotImplementedError()
|
||||
|
||||
|
||||
class BaseTokenizer(ABC):
|
||||
def __init__(self, token_list: Union[Path, str, Iterable[str]] = None,
|
||||
unk_symbol: str = "<unk>",
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
if token_list is not None:
|
||||
if isinstance(token_list, (Path, str)) and token_list.endswith(".txt"):
|
||||
token_list = Path(token_list)
|
||||
self.token_list_repr = str(token_list)
|
||||
self.token_list: List[str] = []
|
||||
|
||||
with token_list.open("r", encoding="utf-8") as f:
|
||||
for idx, line in enumerate(f):
|
||||
line = line.rstrip()
|
||||
self.token_list.append(line)
|
||||
elif isinstance(token_list, (Path, str)) and token_list.endswith(".json"):
|
||||
token_list = Path(token_list)
|
||||
self.token_list_repr = str(token_list)
|
||||
self.token_list: List[str] = []
|
||||
|
||||
with open(token_list, 'r', encoding='utf-8') as f:
|
||||
self.token_list = json.load(f)
|
||||
|
||||
|
||||
else:
|
||||
self.token_list: List[str] = list(token_list)
|
||||
self.token_list_repr = ""
|
||||
for i, t in enumerate(self.token_list):
|
||||
if i == 3:
|
||||
break
|
||||
self.token_list_repr += f"{t}, "
|
||||
self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
|
||||
|
||||
self.token2id: Dict[str, int] = {}
|
||||
for i, t in enumerate(self.token_list):
|
||||
if t in self.token2id:
|
||||
raise RuntimeError(f'Symbol "{t}" is duplicated')
|
||||
self.token2id[t] = i
|
||||
|
||||
self.unk_symbol = unk_symbol
|
||||
if self.unk_symbol not in self.token2id:
|
||||
raise RuntimeError(
|
||||
f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
|
||||
)
|
||||
self.unk_id = self.token2id[self.unk_symbol]
|
||||
|
||||
def encode(self, text):
|
||||
tokens = self.text2tokens(text)
|
||||
text_ints = self.tokens2ids(tokens)
|
||||
|
||||
return text_ints
|
||||
|
||||
def decode(self, text_ints):
|
||||
token = self.ids2tokens(text_ints)
|
||||
text = self.tokens2text(token)
|
||||
return text
|
||||
|
||||
def get_num_vocabulary_size(self) -> int:
|
||||
return len(self.token_list)
|
||||
|
||||
def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
|
||||
if isinstance(integers, np.ndarray) and integers.ndim != 1:
|
||||
raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
|
||||
return [self.token_list[i] for i in integers]
|
||||
|
||||
def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
|
||||
return [self.token2id.get(i, self.unk_id) for i in tokens]
|
||||
|
||||
@abstractmethod
|
||||
def text2tokens(self, line: str) -> List[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def tokens2text(self, tokens: Iterable[str]) -> str:
|
||||
raise NotImplementedError
|
||||
def __init__(self, token_list: Union[Path, str, Iterable[str]] = None,
|
||||
unk_symbol: str = "<unk>",
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
if token_list is not None:
|
||||
if isinstance(token_list, (Path, str)) and token_list.endswith(".txt"):
|
||||
token_list = Path(token_list)
|
||||
self.token_list_repr = str(token_list)
|
||||
self.token_list: List[str] = []
|
||||
|
||||
with token_list.open("r", encoding="utf-8") as f:
|
||||
for idx, line in enumerate(f):
|
||||
line = line.rstrip()
|
||||
self.token_list.append(line)
|
||||
elif isinstance(token_list, (Path, str)) and token_list.endswith(".json"):
|
||||
token_list = Path(token_list)
|
||||
self.token_list_repr = str(token_list)
|
||||
self.token_list: List[str] = []
|
||||
|
||||
with open(token_list, 'r', encoding='utf-8') as f:
|
||||
self.token_list = json.load(f)
|
||||
|
||||
|
||||
else:
|
||||
self.token_list: List[str] = list(token_list)
|
||||
self.token_list_repr = ""
|
||||
for i, t in enumerate(self.token_list):
|
||||
if i == 3:
|
||||
break
|
||||
self.token_list_repr += f"{t}, "
|
||||
self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
|
||||
|
||||
self.token2id: Dict[str, int] = {}
|
||||
for i, t in enumerate(self.token_list):
|
||||
if t in self.token2id:
|
||||
raise RuntimeError(f'Symbol "{t}" is duplicated')
|
||||
self.token2id[t] = i
|
||||
|
||||
self.unk_symbol = unk_symbol
|
||||
if self.unk_symbol not in self.token2id:
|
||||
raise RuntimeError(
|
||||
f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
|
||||
)
|
||||
self.unk_id = self.token2id[self.unk_symbol]
|
||||
|
||||
def encode(self, text):
    """Tokenize ``text`` with ``text2tokens`` and return the ids from ``tokens2ids``."""
    return self.tokens2ids(self.text2tokens(text))
|
||||
|
||||
def decode(self, text_ints):
    """Turn ids into tokens with ``ids2tokens`` and join them with ``tokens2text``."""
    return self.tokens2text(self.ids2tokens(text_ints))
|
||||
|
||||
def get_num_vocabulary_size(self) -> int:
    """Return the number of entries in ``self.token_list``."""
    vocabulary = self.token_list
    return len(vocabulary)
|
||||
|
||||
def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
    """Look up the token stored at each index of ``self.token_list``.

    Raises:
        ValueError: If ``integers`` is a NumPy array that is not 1-dimensional.
    """
    if isinstance(integers, np.ndarray):
        if integers.ndim != 1:
            raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
    tokens: List[str] = []
    for index in integers:
        tokens.append(self.token_list[index])
    return tokens
|
||||
|
||||
def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
    """Map each token to its vocabulary index.

    Tokens that are missing from ``self.token2id`` are replaced by
    ``self.unk_id``.
    """
    ids: List[int] = []
    for token in tokens:
        ids.append(self.token2id.get(token, self.unk_id))
    return ids
|
||||
|
||||
@abstractmethod
def text2tokens(self, line: str) -> List[str]:
    """Split one line of text into tokens; concrete tokenizers must override this."""
    raise NotImplementedError()
|
||||
|
||||
@abstractmethod
def tokens2text(self, tokens: Iterable[str]) -> str:
    """Join tokens back into a text string; concrete tokenizers must override this."""
    raise NotImplementedError()
|
||||
@ -1,233 +1,235 @@
|
||||
import torch
|
||||
import os
|
||||
from funasr.train_utils.device_funcs import to_device
|
||||
import logging
|
||||
import time
|
||||
import torch
|
||||
import logging
|
||||
from tqdm import tqdm
|
||||
from contextlib import nullcontext
|
||||
import torch.distributed as dist
|
||||
from contextlib import nullcontext
|
||||
|
||||
from funasr.train_utils.device_funcs import to_device
|
||||
from funasr.train_utils.recursive_op import recursive_average
|
||||
|
||||
|
||||
class Trainer:
    """
    A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
    and optionally resuming from a saved checkpoint.

    Attributes:
        max_epoch (int): Maximum number of epochs for training.
        model (torch.nn.Module): The model to be trained.
        optim (torch.optim.Optimizer): The optimizer to use for training.
        scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
        dataloader_train (torch.utils.data.DataLoader): DataLoader for the training dataset.
        dataloader_val (torch.utils.data.DataLoader): DataLoader for the validation dataset.
        output_dir (str): Directory where model checkpoints will be saved.
        resume (str, optional): Path to a checkpoint to resume training from.
        start_epoch (int): First epoch to run; 0 unless a checkpoint was loaded.
        rank (int): Global rank reported by torch.distributed, or 0 when it is not initialized.
        world_size (int): Number of distributed workers, or 1 when it is not initialized.
    """

    def __init__(self, model,
                 optim,
                 scheduler,
                 dataloader_train,
                 dataloader_val,
                 local_rank,
                 use_ddp=False,
                 use_fsdp=False,
                 **kwargs):
        """
        Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings.

        Args:
            model (torch.nn.Module): The model to be trained.
            optim (torch.optim.Optimizer): The optimizer to use for training.
            scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
            dataloader_train (torch.utils.data.DataLoader): The DataLoader for the training dataset.
            dataloader_val (torch.utils.data.DataLoader): The DataLoader for the validation dataset.
            local_rank (int): Local process rank; only local rank 0 refreshes the progress-bar text.
            use_ddp (bool): Enables the distributed loss/stat averaging in ``_train_epoch``.
            use_fsdp (bool): Same effect as ``use_ddp`` inside this class.
            **kwargs: Additional keyword arguments:
                max_epoch (int): The maximum number of epochs for training. Default is 100.
                output_dir (str): The directory where model checkpoints will be saved. Default is './'.
                resume (str, optional): The file path to a checkpoint to resume training from.
                accum_grad (int): Micro-batches accumulated per optimizer step. Default is 1.
                grad_clip (float, optional): Max gradient norm; clipping is skipped when absent.
                grad_clip_type (float): Norm type used for clipping. Default is 2.0.
        """

        self.model = model
        self.optim = optim
        self.scheduler = scheduler
        self.dataloader_train = dataloader_train
        self.dataloader_val = dataloader_val
        self.output_dir = kwargs.get('output_dir', './')
        self.resume = kwargs.get('resume', True)
        self.start_epoch = 0
        self.max_epoch = kwargs.get('max_epoch', 100)
        self.local_rank = local_rank
        self.use_ddp = use_ddp
        self.use_fsdp = use_fsdp
        self.device = next(model.parameters()).device
        self.kwargs = kwargs

        if self.resume:
            self._resume_checkpoint(self.resume)

        # Ask torch.distributed directly instead of catching every exception.
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        else:
            rank = 0
            world_size = 1
            logging.warning("distributed is not initialized, only single shard")
        self.rank = rank
        self.world_size = world_size

    def _save_checkpoint(self, epoch):
        """
        Saves a checkpoint containing the model's state, the optimizer's state,
        and the scheduler's state at the end of the given epoch. This method is
        intended to be called at the end of each epoch to save the training progress.

        Args:
            epoch (int): The epoch number at which the checkpoint is being saved.
        """
        state = {
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optim.state_dict(),
            'scheduler': self.scheduler.state_dict(),
        }
        # Create output directory if it does not exist
        os.makedirs(self.output_dir, exist_ok=True)
        filename = os.path.join(self.output_dir, f'model.e{epoch}.pb')
        torch.save(state, filename)
        print(f'Checkpoint saved to {filename}')

    def _resume_checkpoint(self, resume_path):
        """
        Resumes training from a checkpoint at the given file path.
        Loads the model's state, the optimizer's state, and the scheduler's state.

        Args:
            resume_path (str): The file path to the checkpoint to resume from.
        """
        # ``resume`` defaults to True. A bool is not a path, and
        # os.path.isfile(True) would be evaluated as file descriptor 1.
        is_path = isinstance(resume_path, (str, os.PathLike))
        if is_path and os.path.isfile(resume_path):
            # Deserialize on CPU; load_state_dict copies values into the
            # existing parameters and optimizer state.
            checkpoint = torch.load(resume_path, map_location="cpu")
            self.start_epoch = checkpoint['epoch'] + 1
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optim.load_state_dict(checkpoint['optimizer'])
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            print(f"Checkpoint loaded successfully from '{resume_path}' at (epoch {checkpoint['epoch']})")
        else:
            print(f"No checkpoint found at '{resume_path}', starting from scratch")

    def run(self):
        """
        Starts the training process, iterating over epochs, training the model,
        and saving checkpoints at the end of each epoch.
        """
        # NOTE(review): the upper bound is inclusive, so a fresh run covers
        # epochs 0..max_epoch (max_epoch + 1 epochs) - confirm this is intended.
        for epoch in range(self.start_epoch, self.max_epoch + 1):
            self._train_epoch(epoch)
            # Validation is currently disabled: self._validate_epoch(epoch)
            if self.rank == 0:
                self._save_checkpoint(epoch)
            # NOTE(review): _train_epoch already steps the scheduler after each
            # optimizer update; this is one extra step per epoch.
            self.scheduler.step()

    def _train_epoch(self, epoch):
        """
        Defines the training process for a single epoch with gradient accumulation.

        Args:
            epoch (int): The current epoch number.
        """
        self.model.train()
        num_batches = len(self.dataloader_train)
        pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch + 1}", total=num_batches,
                    dynamic_ncols=True)

        # Set the number of steps for gradient accumulation
        accum_grad = self.kwargs.get("accum_grad", 1)
        # no_sync() exists only on DDP/FSDP wrappers, not on a plain nn.Module.
        can_skip_sync = hasattr(self.model, "no_sync")
        # Initialize the gradient accumulation
        self.optim.zero_grad()
        speed_stats = {}
        time5 = time.perf_counter()
        for batch_idx, batch in enumerate(self.dataloader_train):
            time1 = time.perf_counter()
            speed_stats["data_load"] = f"{time1-time5:0.3f}"
            batch = to_device(batch, self.device)

            # The optimizer steps on the last micro-batch of each accumulation
            # window (and on the final batch of the epoch).
            is_update_step = (batch_idx + 1) % accum_grad == 0 or (batch_idx + 1) == num_batches
            # Gradients must be all-reduced on the micro-batch that precedes
            # the optimizer step, so only the other micro-batches skip sync.
            if can_skip_sync and not is_update_step:
                my_context = self.model.no_sync
            else:
                my_context = nullcontext
            with my_context():
                time2 = time.perf_counter()
                retval = self.model(**batch)
                time3 = time.perf_counter()
                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
                loss, stats, weight = retval
                stats = {k: v for k, v in stats.items() if v is not None}
                if self.use_ddp or self.use_fsdp:
                    # Apply weighted averaging for loss and stats
                    loss = (loss * weight.type(loss.dtype)).sum()
                    # if distributed, this method can also apply all_reduce()
                    stats, weight = recursive_average(stats, weight, distributed=True)
                    # Now weight is summation over all workers
                    loss /= weight
                    # Multiply world_size because DistributedDataParallel
                    # automatically normalizes the gradient by world_size.
                    loss *= self.world_size
                # Scale the loss since we're not updating for every mini-batch
                loss = loss / accum_grad
                loss.backward()
                time4 = time.perf_counter()
                speed_stats["backward_time"] = f"{time4 - time3:0.3f}"

            # Perform an optimizer step only after accumulating enough gradients
            if is_update_step:
                grads_are_finite = True
                # Perform gradient clipping if it is set
                if self.kwargs.get("grad_clip", None) is not None:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        max_norm=self.kwargs.get("grad_clip", 10.0),
                        norm_type=self.kwargs.get("grad_clip_type", 2.0),
                    )
                    if not torch.isfinite(grad_norm):
                        logging.warning(
                            f"The grad norm is {grad_norm}. Skipping updating the model."
                        )
                        grads_are_finite = False

                if grads_are_finite:
                    # Execute an optimization step (update model parameters)
                    self.optim.step()
                    self.scheduler.step()
                # Clear gradients for the next accumulation stage (also drops
                # non-finite gradients when the update was skipped).
                self.optim.zero_grad()

            # Timing and progress are recorded for every batch, including
            # batches whose update was skipped.
            total_time = f"{time.perf_counter() - time5:0.3f}"
            time5 = time.perf_counter()
            speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
            speed_stats["total_time"] = total_time

            pbar.update(1)
            if self.local_rank == 0:
                description = (
                    f"Epoch: {epoch + 1}/{self.max_epoch}, "
                    f"step {batch_idx}/{len(self.dataloader_train)}, "
                    f"{speed_stats}, "
                    f"(loss: {loss.detach().cpu().item():.3f}), "
                    f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
                )
                pbar.set_description(description)

        pbar.close()

    def _validate_epoch(self, epoch):
        """
        Defines the validation process for a single epoch.
        Should be implemented with the actual model validation steps.

        Args:
            epoch (int): The current epoch number.
        """
        self.model.eval()
        with torch.no_grad():
            for data, target in self.dataloader_val:
                # Implement the model validation steps here
                pass
|
||||
|
||||
@ -10,100 +10,100 @@ import time
|
||||
import logging
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
try:
|
||||
from funasr.download.file import download_from_url
|
||||
from funasr.download.file import download_from_url
|
||||
except:
|
||||
print("urllib is not installed, if you infer from url, please install it first.")
|
||||
print("urllib is not installed, if you infer from url, please install it first.")
|
||||
|
||||
|
||||
|
||||
def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs: int = 16000, data_type="sound", tokenizer=None, **kwargs):
    """Load one input, or a (nested) list of inputs, for inference.

    Args:
        data_or_path_or_list: A URL or local path (str), a text string, a
            NumPy array of samples, or a list/tuple of such items.
        fs: Target sampling rate; non-text data is resampled to it when it
            differs from ``audio_fs``.
        audio_fs: Sampling rate of the incoming audio. Overwritten with the
            rate reported by ``torchaudio.load`` when a sound file is read.
        data_type: ``"sound"``, ``"text"``, ``"image"``, ``"video"`` or ``None``;
            or a list/tuple of these when every sample holds one item per type.
        tokenizer: Object with an ``encode`` method, used for ``"text"`` data.
        **kwargs: Forwarded to recursive calls. If it contains ``"cache"``,
            ``kwargs["cache"]["is_final"]`` is set to True for file inputs.

    Returns:
        The loaded item, or a list of loaded items. With a list-valued
        ``data_type`` the result holds one list per data type.
    """
    if isinstance(data_or_path_or_list, (list, tuple)):
        if data_type is not None and isinstance(data_type, (list, tuple)):
            # Each sample is a tuple with one entry per data type; regroup the
            # loaded entries so that result[j] collects everything of type j.
            data_or_path_or_list_ret = [[] for _ in data_type]
            for sample in data_or_path_or_list:
                for j, (data_type_j, item_j) in enumerate(zip(data_type, sample)):
                    loaded_j = load_audio_text_image_video(item_j, fs=fs, audio_fs=audio_fs, data_type=data_type_j, tokenizer=tokenizer, **kwargs)
                    data_or_path_or_list_ret[j].append(loaded_j)

            return data_or_path_or_list_ret
        else:
            # The tokenizer is forwarded so that lists of text are encoded the
            # same way as a single text item.
            return [load_audio_text_image_video(audio, fs=fs, audio_fs=audio_fs, data_type=data_type, tokenizer=tokenizer, **kwargs) for audio in data_or_path_or_list]

    if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith('http'):  # download url to local file
        data_or_path_or_list = download_from_url(data_or_path_or_list)

    if isinstance(data_or_path_or_list, str) and os.path.exists(data_or_path_or_list):  # local file
        if data_type is None or data_type == "sound":
            data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list)
            # Keep only the first channel.
            data_or_path_or_list = data_or_path_or_list[0, :]
        elif data_type == "text" and tokenizer is not None:
            data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
        elif data_type == "image":  # not implemented: returned unchanged
            pass
        elif data_type == "video":  # not implemented: returned unchanged
            pass

        # if data_in is a file or url, set is_final=True
        if "cache" in kwargs:
            kwargs["cache"]["is_final"] = True
    elif isinstance(data_or_path_or_list, str) and data_type == "text" and tokenizer is not None:
        data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
    elif isinstance(data_or_path_or_list, np.ndarray):  # audio sample point
        data_or_path_or_list = torch.from_numpy(data_or_path_or_list).squeeze()  # [n_samples,]
    else:
        # Any other input is returned as-is.
        pass

    if audio_fs != fs and data_type != "text":
        resampler = torchaudio.transforms.Resample(audio_fs, fs)
        data_or_path_or_list = resampler(data_or_path_or_list[None, :])[0, :]
    return data_or_path_or_list
|
||||
|
||||
def load_bytes(input):
    """Decode raw 16-bit integer PCM bytes into float32 samples.

    The buffer is interpreted as ``np.int16`` and every sample is divided by
    2**15, so the result lies in [-1.0, 1.0).

    Args:
        input: Bytes-like object whose length is a multiple of 2.

    Returns:
        1-D ``np.float32`` array with one value per input sample.
    """
    pcm = np.frombuffer(input, dtype=np.int16)
    info = np.iinfo(pcm.dtype)
    abs_max = 2 ** (info.bits - 1)  # 32768 for int16
    offset = info.min + abs_max     # 0 for a signed integer type
    return (pcm.astype(np.float32) - offset) / abs_max
|
||||
|
||||
def extract_fbank(data, data_len = None, data_type: str="sound", frontend=None, **kwargs):
    """Batch the input and run it through ``frontend``.

    Args:
        data: A NumPy array, a torch tensor, or a list/tuple of arrays/tensors.
            A 1-D array or tensor gets a leading batch axis; a list/tuple is
            zero-padded into one batch with ``pad_sequence``.
        data_len: Lengths handed to ``frontend``. For array/tensor input it
            defaults to ``[data.shape[1]]``; for list/tuple input it is always
            rebuilt from the first dimension of each item.
        data_type: Not used by this function.
        frontend: Callable ``frontend(data, data_len, **kwargs)`` returning
            ``(features, feature_lengths)``.
        **kwargs: Forwarded to ``frontend``.

    Returns:
        ``(features, feature_lengths)`` cast to float32 and int32.
    """
    if isinstance(data, np.ndarray):
        data = torch.from_numpy(data)

    if isinstance(data, torch.Tensor):
        if len(data.shape) < 2:
            data = data[None, :]  # data: [batch, N]
        if data_len is None:
            data_len = [data.shape[1]]
    elif isinstance(data, (list, tuple)):
        tensors, data_len = [], []
        for item in data:
            if isinstance(item, np.ndarray):
                item = torch.from_numpy(item)
            tensors.append(item)
            data_len.append(item.shape[0])
        data = pad_sequence(tensors, batch_first=True)  # data: [batch, N]

    data, data_len = frontend(data, data_len, **kwargs)

    # A frontend that returns plain lengths gets them wrapped in a tensor.
    if isinstance(data_len, (list, tuple)):
        data_len = torch.tensor([data_len])
    return data.to(torch.float32), data_len.to(torch.int32)
|
||||
|
||||
|
||||
@ -1,31 +1,31 @@
|
||||
import torch
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
def slice_padding_fbank(speech, speech_lengths, vad_segments):
    """Cut ``speech`` into the given segments and zero-pad them into one batch.

    Args:
        speech: 2-D tensor; only row 0 is sliced.
        speech_lengths: Sequence whose first element caps every segment end.
        vad_segments: Iterable of segments; ``segment[0]`` holds the
            ``(begin, end)`` pair, each multiplied by 16 to obtain indices
            (presumably milliseconds at 16 kHz - confirm with the caller).

    Returns:
        ``(feats_pad, speech_lengths_pad)``: the padded segment batch and an
        int tensor with the unpadded length of each segment.
    """
    pieces = []
    piece_lengths = []
    for segment in vad_segments:
        begin, end = segment[0][0], segment[0][1]
        start_idx = int(begin*16)
        stop_idx = min(int(end*16), speech_lengths[0])
        pieces.append(speech[0, start_idx: stop_idx])
        piece_lengths.append(stop_idx-start_idx)
    feats_pad = pad_sequence(pieces, batch_first=True, padding_value=0.0)
    speech_lengths_pad = torch.Tensor(piece_lengths).int()
    return feats_pad, speech_lengths_pad
|
||||
|
||||
def slice_padding_audio_samples(speech, speech_lengths, vad_segments):
    """Cut ``speech`` into the given segments without padding.

    Args:
        speech: 1-D sliceable sequence of samples.
        speech_lengths: Upper bound applied to every segment end index.
        vad_segments: Iterable of segments; ``segment[0]`` holds the
            ``(begin, end)`` pair, each multiplied by 16 to obtain indices
            (presumably milliseconds at 16 kHz - confirm with the caller).

    Returns:
        ``(speech_list, speech_lengths_list)``: the slices and their lengths.
    """
    pieces = []
    piece_lengths = []
    for segment in vad_segments:
        begin, end = segment[0][0], segment[0][1]
        start_idx = int(begin * 16)
        stop_idx = min(int(end * 16), speech_lengths)
        pieces.append(speech[start_idx: stop_idx])
        piece_lengths.append(stop_idx - start_idx)

    return pieces, piece_lengths
|
||||
@ -17,8 +17,8 @@ args = parser.parse_args()
|
||||
|
||||
# Import only the backend that was requested, so an ONNX run does not need
# the libtorch runtime to be importable.
if args.backend == "onnx":
    from funasr.runtime.python.onnxruntime.funasr_onnx import Paraformer
else:
    from funasr.runtime.python.libtorch.funasr_torch import Paraformer

model = Paraformer(args.model_dir, batch_size=1, quantize=args.quantize, intra_op_num_threads=args.intra_op_num_threads)

# Kept as a plain handle because the following lines of the script read from it.
wav_file_f = open(args.wav_file, 'r')
|
||||
@ -26,23 +26,23 @@ wav_files = wav_file_f.readlines()
|
||||
|
||||
output_dir = args.output_dir
# exist_ok avoids the check-then-create race of a separate existence test.
os.makedirs(output_dir, exist_ok=True)
if os.name == 'nt':   # Windows
    newline = '\r\n'
else:   # Linux Mac
    newline = '\n'

# The context manager closes both files even if decoding one utterance fails.
with open(os.path.join(output_dir, "text"), "w", newline=newline) as text_f, \
        open(os.path.join(output_dir, "token"), "w", newline=newline) as token_f:
    for i, wav_path_i in enumerate(wav_files):
        # Each list line is "<utterance id> <wav path>".
        wav_name, wav_path = wav_path_i.strip().split()
        result = model(wav_path)
        text_i = "{} {}\n".format(wav_name, result[0]['preds'][0])
        token_i = "{} {}\n".format(wav_name, result[0]['preds'][1])
        # Flush after every utterance so partial results survive an interruption.
        text_f.write(text_i)
        text_f.flush()
        token_f.write(token_i)
        token_f.flush()
|
||||
|
||||
|
||||
|
||||
@ -16,8 +16,8 @@ args = parser.parse_args()
|
||||
|
||||
from funasr.runtime.python.libtorch.funasr_torch import Paraformer
|
||||
if args.backend == "onnx":
|
||||
from funasr.runtime.python.onnxruntime.funasr_onnx import Paraformer
|
||||
|
||||
from funasr.runtime.python.onnxruntime.funasr_onnx import Paraformer
|
||||
|
||||
model = Paraformer(args.model_dir, batch_size=1, quantize=args.quantize, intra_op_num_threads=args.intra_op_num_threads)
|
||||
|
||||
wav_file_f = open(args.wav_file, 'r')
|
||||
@ -28,28 +28,28 @@ total = 0.0
|
||||
num = 30
|
||||
wav_path = wav_files[0].split("\t")[1].strip() if "\t" in wav_files[0] else wav_files[0].split(" ")[1].strip()
|
||||
for i in range(num):
|
||||
beg_time = time.time()
|
||||
result = model(wav_path)
|
||||
end_time = time.time()
|
||||
duration = end_time-beg_time
|
||||
total += duration
|
||||
print(result)
|
||||
print("num: {}, time, {}, avg: {}, rtf: {}".format(len(wav_path), duration, total/(i+1), (total/(i+1))/5.53))
|
||||
beg_time = time.time()
|
||||
result = model(wav_path)
|
||||
end_time = time.time()
|
||||
duration = end_time-beg_time
|
||||
total += duration
|
||||
print(result)
|
||||
print("num: {}, time, {}, avg: {}, rtf: {}".format(len(wav_path), duration, total/(i+1), (total/(i+1))/5.53))
|
||||
|
||||
# infer time
|
||||
beg_time = time.time()
|
||||
for i, wav_path_i in enumerate(wav_files):
|
||||
wav_path = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip()
|
||||
result = model(wav_path)
|
||||
wav_path = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip()
|
||||
result = model(wav_path)
|
||||
end_time = time.time()
|
||||
duration = (end_time-beg_time)*1000
|
||||
print("total_time_comput_ms: {}".format(int(duration)))
|
||||
|
||||
duration_time = 0.0
|
||||
for i, wav_path_i in enumerate(wav_files):
|
||||
wav_path = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip()
|
||||
waveform, _ = librosa.load(wav_path, sr=16000)
|
||||
duration_time += len(waveform)/16.0
|
||||
wav_path = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip()
|
||||
waveform, _ = librosa.load(wav_path, sr=16000)
|
||||
duration_time += len(waveform)/16.0
|
||||
print("total_time_wav_ms: {}".format(int(duration_time)))
|
||||
|
||||
print("total_rtf: {:.5}".format(duration/duration_time))
|
||||
@ -17,8 +17,8 @@ args = parser.parse_args()
|
||||
|
||||
from funasr.runtime.python.libtorch.funasr_torch import Paraformer
|
||||
if args.backend == "onnx":
|
||||
from funasr.runtime.python.onnxruntime.funasr_onnx import Paraformer
|
||||
|
||||
from funasr.runtime.python.onnxruntime.funasr_onnx import Paraformer
|
||||
|
||||
model = Paraformer(args.model_dir, batch_size=args.batch_size, quantize=args.quantize, intra_op_num_threads=args.intra_op_num_threads)
|
||||
|
||||
wav_file_f = open(args.wav_file, 'r')
|
||||
@ -29,20 +29,20 @@ total = 0.0
|
||||
num = 30
|
||||
wav_path = wav_files[0].split("\t")[1].strip() if "\t" in wav_files[0] else wav_files[0].split(" ")[1].strip()
|
||||
for i in range(num):
|
||||
beg_time = time.time()
|
||||
result = model(wav_path)
|
||||
end_time = time.time()
|
||||
duration = end_time-beg_time
|
||||
total += duration
|
||||
print(result)
|
||||
print("num: {}, time, {}, avg: {}, rtf: {}".format(len(wav_path), duration, total/(i+1), (total/(i+1))/5.53))
|
||||
beg_time = time.time()
|
||||
result = model(wav_path)
|
||||
end_time = time.time()
|
||||
duration = end_time-beg_time
|
||||
total += duration
|
||||
print(result)
|
||||
print("num: {}, time, {}, avg: {}, rtf: {}".format(len(wav_path), duration, total/(i+1), (total/(i+1))/5.53))
|
||||
|
||||
# infer time
|
||||
wav_path = []
|
||||
beg_time = time.time()
|
||||
for i, wav_path_i in enumerate(wav_files):
|
||||
wav_path_i = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip()
|
||||
wav_path += [wav_path_i]
|
||||
wav_path_i = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip()
|
||||
wav_path += [wav_path_i]
|
||||
result = model(wav_path)
|
||||
end_time = time.time()
|
||||
duration = (end_time-beg_time)*1000
|
||||
@ -50,9 +50,9 @@ print("total_time_comput_ms: {}".format(int(duration)))
|
||||
|
||||
duration_time = 0.0
|
||||
for i, wav_path_i in enumerate(wav_files):
|
||||
wav_path = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip()
|
||||
waveform, _ = librosa.load(wav_path, sr=16000)
|
||||
duration_time += len(waveform)/16.0
|
||||
wav_path = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip()
|
||||
waveform, _ = librosa.load(wav_path, sr=16000)
|
||||
duration_time += len(waveform)/16.0
|
||||
print("total_time_wav_ms: {}".format(int(duration_time)))
|
||||
|
||||
print("total_rtf: {:.5}".format(duration/duration_time))
|
||||
Loading…
Reference in New Issue
Block a user