code update

This commit is contained in:
shixian.shi 2024-01-15 20:34:47 +08:00
parent 3fcb5dcfed
commit 1233c0d3ff
24 changed files with 1391 additions and 1404 deletions

View File

@ -18,5 +18,5 @@ frontend = AutoFrontend(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-co
fbanks = frontend(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", batch_size=2)
for batch_idx, fbank_dict in enumerate(fbanks):
res = model(**fbank_dict)
print(res)
res = model(**fbank_dict)
print(res)

View File

@ -309,10 +309,7 @@ class AutoModel:
if not len(sorted_data):
logging.info("decoding, utt: {}, empty speech".format(key))
continue
# if kwargs["device"] == "cpu":
# batch_size = 0
if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
batch_size = max(batch_size, sorted_data[0][0][1] - sorted_data[0][0][0])

View File

@ -1,178 +1,180 @@
import argparse
import logging
import os
import sys
from io import BytesIO
from collections.abc import Sequence
import torch
import hydra
import logging
import argparse
from io import BytesIO
import torch.distributed as dist
from collections.abc import Sequence
from omegaconf import DictConfig, OmegaConf
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.models.lora.utils import mark_only_lora_as_trainable
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from funasr.register import tables
from funasr.optimizers import optim_classes
from funasr.train_utils.trainer import Trainer
from funasr.schedulers import scheduler_classes
from funasr.train_utils.load_pretrained_model import load_pretrained_model
from funasr.train_utils.initialize import initialize
from funasr.download.download_from_hub import download_model
from funasr.models.lora.utils import mark_only_lora_as_trainable
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model
# from funasr.tokenizer.build_tokenizer import build_tokenizer
# from funasr.tokenizer.token_id_converter import TokenIDConverter
# from funasr.tokenizer.funtoken import build_tokenizer
from funasr.train_utils.trainer import Trainer
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from funasr.download.download_from_hub import download_model
from funasr.register import tables
@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
    """Hydra entry point: complete the configuration if needed, then train.

    Args:
        kwargs: Configuration composed by Hydra. It must contain ``model``.
            When ``model_conf`` is absent, the configuration is completed by
            ``download_model`` before ``main`` is called.
    """
    if kwargs.get("debug", False):
        import pdb; pdb.set_trace()

    assert "model" in kwargs

    if "model_conf" not in kwargs:
        logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
        # Merge ``is_training`` into a single mapping before unpacking: passing
        # it as an explicit keyword next to ``**kwargs`` raises
        # "got multiple values for keyword argument" whenever the
        # configuration already defines ``is_training``.
        download_kwargs = {**kwargs, "is_training": kwargs.get("is_training", True)}
        kwargs = download_model(**download_kwargs)

    main(**kwargs)
def main(**kwargs):
    """Build every training component from ``kwargs`` and run the trainer.

    The steps are, in order: seed and cuDNN setup, optional process-group
    initialisation, saving ``config.yaml`` from the main process, building the
    tokenizer / frontend / model, loading or initialising parameters, freezing
    the requested parameters, wrapping the model (DDP / FSDP / single device),
    building the optimizer, scheduler, dataset and dataloader, and finally
    running ``Trainer``.

    Args:
        **kwargs: Flat training configuration. Keys read here include
            ``seed``, ``cudnn_*``, ``use_fsdp``, ``backend``, ``output_dir``,
            ``tokenizer`` / ``tokenizer_conf``, ``frontend`` / ``frontend_conf``,
            ``model`` / ``model_conf``, ``init_param``, ``init``,
            ``freeze_param``, ``device``, ``optim`` / ``optim_conf``,
            ``scheduler`` / ``scheduler_conf``, ``dataset`` / ``dataset_conf``,
            ``train_data_set_list`` and ``train_conf``.
    """
    # set random seed
    tables.print()
    set_all_random_seed(kwargs.get("seed", 0))
    torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
    torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
    torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)

    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    # Check if we are using DDP or FSDP
    use_ddp = 'WORLD_SIZE' in os.environ and int(os.environ["WORLD_SIZE"]) > 1
    use_fsdp = kwargs.get("use_fsdp", None)
    distributed = bool(use_ddp or use_fsdp)
    if distributed:
        dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
        torch.cuda.set_device(local_rank)

    # save config.yaml (only from the main process)
    is_main_process = dist.get_rank() == 0 if distributed else local_rank == 0
    if is_main_process:
        output_dir = kwargs.get("output_dir", "./")
        os.makedirs(output_dir, exist_ok=True)
        yaml_file = os.path.join(output_dir, "config.yaml")
        OmegaConf.save(config=kwargs, f=yaml_file)
        logging.info("config.yaml is saved to: %s", yaml_file)

    # build tokenizer if one is configured
    tokenizer = kwargs.get("tokenizer", None)
    if tokenizer is not None:
        tokenizer_class = tables.tokenizer_classes.get(tokenizer)
        tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
        kwargs["tokenizer"] = tokenizer

    # build frontend if one is configured
    frontend = kwargs.get("frontend", None)
    if frontend is not None:
        frontend_class = tables.frontend_classes.get(frontend)
        frontend = frontend_class(**kwargs["frontend_conf"])
        kwargs["frontend"] = frontend
        kwargs["input_size"] = frontend.output_size()

    # build model
    # NOTE(review): ``tokenizer.token_list`` raises AttributeError when no
    # tokenizer is configured -- confirm whether tokenizer-less training
    # should be supported here.
    model_class = tables.model_classes.get(kwargs["model"])
    model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))

    # init_param: load pretrained parameters, otherwise initialise from scratch
    init_param = kwargs.get("init_param", None)
    if init_param is not None:
        if not isinstance(init_param, (list, tuple)):
            init_param = (init_param,)
        logging.info("init_param is not None: %s", init_param)
        for p in init_param:
            logging.info(f"Loading pretrained params from {p}")
            load_pretrained_model(
                model=model,
                init_param=p,
                ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
                oss_bucket=kwargs.get("oss_bucket", None),
            )
    else:
        initialize(model, kwargs.get("init", "kaiming_normal"))

    # freeze_param
    freeze_param = kwargs.get("freeze_param", None)
    if freeze_param is not None:
        if isinstance(freeze_param, str):
            # SECURITY NOTE(review): eval() executes arbitrary code taken from
            # the configuration; only use with trusted configs. It is kept
            # for compatibility with existing config strings.
            freeze_param = eval(freeze_param)
        # Wrap a single name into a tuple. A list/tuple of names must be
        # iterated as-is: wrapping it (as the old ``isinstance(..., Sequence)``
        # check did) makes ``t + "."`` below fail with a TypeError.
        if isinstance(freeze_param, str) or not isinstance(freeze_param, Sequence):
            freeze_param = (freeze_param,)
        logging.info("freeze_param is not None: %s", freeze_param)
        for t in freeze_param:
            for k, p in model.named_parameters():
                if k.startswith(t + ".") or k == t:
                    logging.info(f"Setting {k}.requires_grad = False")
                    p.requires_grad = False

    # place / wrap the model
    if use_ddp:
        model = model.cuda(local_rank)
        model = DDP(model, device_ids=[local_rank],
                    find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
    elif use_fsdp:
        model = FSDP(model).cuda(local_rank)
    else:
        model = model.to(device=kwargs.get("device", "cuda"))

    # optim
    optim = kwargs.get("optim", "adam")
    assert optim in optim_classes
    optim_class = optim_classes.get(optim)
    # ``or {}``: a missing conf section would otherwise fail on ``**None``.
    optim = optim_class(model.parameters(), **(kwargs.get("optim_conf") or {}))

    # scheduler
    scheduler = kwargs.get("scheduler", "warmuplr")
    assert scheduler in scheduler_classes
    scheduler_class = scheduler_classes.get(scheduler)
    scheduler = scheduler_class(optim, **(kwargs.get("scheduler_conf") or {}))

    # dataset
    dataset_conf = kwargs.get("dataset_conf") or {}
    dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
    dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer,
                               **dataset_conf)

    # dataloader
    batch_sampler = dataset_conf.get("batch_sampler", "DynamicBatchLocalShuffleSampler")
    batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
    if batch_sampler is not None:
        batch_sampler = batch_sampler_class(dataset_tr, **dataset_conf)
    dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
                                                collate_fn=dataset_tr.collator,
                                                batch_sampler=batch_sampler,
                                                num_workers=dataset_conf.get("num_workers", 4),
                                                pin_memory=True)

    trainer = Trainer(
        model=model,
        optim=optim,
        scheduler=scheduler,
        dataloader_train=dataloader_tr,
        dataloader_val=None,
        local_rank=local_rank,
        use_ddp=use_ddp,
        use_fsdp=use_fsdp,
        **(kwargs.get("train_conf") or {}),
    )
    trainer.run()

    if distributed:
        torch.distributed.destroy_process_group()
# Start the Hydra-driven training entry point when run as a script.
if __name__ == "__main__":
    main_hydra()

View File

@ -1,102 +1,93 @@
import torch
import json
import torch.distributed as dist
import numpy as np
import kaldiio
import librosa
import torchaudio
import time
import logging
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.register import tables
from funasr.utils.load_utils import extract_fbank
@tables.register("dataset_classes", "AudioDataset")
class AudioDataset(torch.utils.data.Dataset):
    """Dataset that turns indexed (source, target) records into model inputs.

    Each item is loaded through an index dataset looked up in
    ``tables.index_ds_classes``; the source is converted to features with
    ``extract_fbank`` and the target is encoded with the tokenizer.
    """

    def __init__(self,
                 path,
                 index_ds: str = None,
                 frontend=None,
                 tokenizer=None,
                 int_pad_value: int = -1,
                 float_pad_value: float = 0.0,
                 **kwargs):
        """Create the dataset.

        Args:
            path: Passed to the index dataset constructor.
            index_ds: Name of the index dataset class in ``tables.index_ds_classes``.
            frontend: Feature frontend; its ``fs`` attribute is used as the
                sampling rate (16000 when no frontend is given).
            tokenizer: Object providing ``encode(text)``.
            int_pad_value: Padding value for ``torch.int64`` tensors in ``collator``.
            float_pad_value: Padding value for every other tensor in ``collator``.
            **kwargs: Optional ``preprocessor_speech`` / ``preprocessor_text``
                names with their ``*_conf`` keyword dictionaries.
        """
        super().__init__()
        index_ds_class = tables.index_ds_classes.get(index_ds)
        self.index_ds = index_ds_class(path)

        preprocessor_speech = kwargs.get("preprocessor_speech", None)
        if preprocessor_speech:
            speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech)
            # ``or {}``: a preprocessor configured without a conf section
            # would otherwise fail on ``**None``.
            preprocessor_speech = speech_class(**(kwargs.get("preprocessor_speech_conf") or {}))
        self.preprocessor_speech = preprocessor_speech

        preprocessor_text = kwargs.get("preprocessor_text", None)
        if preprocessor_text:
            text_class = tables.preprocessor_text_classes.get(preprocessor_text)
            preprocessor_text = text_class(**(kwargs.get("preprocessor_text_conf") or {}))
        self.preprocessor_text = preprocessor_text

        self.frontend = frontend
        self.fs = 16000 if frontend is None else frontend.fs
        self.data_type = "sound"
        self.tokenizer = tokenizer

        self.int_pad_value = int_pad_value
        self.float_pad_value = float_pad_value

    def get_source_len(self, index):
        """Return the source length of item ``index`` as reported by the index dataset."""
        item = self.index_ds[index]
        return self.index_ds.get_source_len(item)

    def get_target_len(self, index):
        """Return the target length of item ``index`` as reported by the index dataset."""
        item = self.index_ds[index]
        return self.index_ds.get_target_len(item)

    def __len__(self):
        return len(self.index_ds)

    def __getitem__(self, index):
        """Load, preprocess and featurise item ``index``.

        Returns:
            dict with ``speech``, ``speech_lengths``, ``text`` (int64 token
            ids) and ``text_lengths`` (int32, one element).
        """
        item = self.index_ds[index]

        source = item["source"]
        # NOTE(review): ``load_audio`` is not imported by any import line
        # visible for this file -- confirm where it is expected to come from.
        data_src = load_audio(source, fs=self.fs)
        if self.preprocessor_speech:
            data_src = self.preprocessor_speech(data_src)
        speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend)  # speech: [b, T, d]

        target = item["target"]
        if self.preprocessor_text:
            target = self.preprocessor_text(target)
        ids = self.tokenizer.encode(target)
        ids_lengths = len(ids)
        text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32)

        # Drop the leading batch axis of the features for a single sample.
        return {"speech": speech[0, :, :],
                "speech_lengths": speech_lengths,
                "text": text,
                "text_lengths": text_lengths,
                }

    def collator(self, samples: list = None):
        """Pad the per-sample tensors into batch tensors.

        Values are grouped by key and padded with ``pad_sequence``
        (``batch_first=True``); ``torch.int64`` tensors use
        ``int_pad_value`` and all other tensors use ``float_pad_value``.
        """
        outputs = {}
        for sample in samples:
            for key, value in sample.items():
                outputs.setdefault(key, []).append(value)

        for key, data_list in outputs.items():
            if data_list[0].dtype == torch.int64:
                pad_value = self.int_pad_value
            else:
                pad_value = self.float_pad_value
            outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
        return outputs

View File

@ -1,64 +1,64 @@
import torch
import json
import torch.distributed as dist
import time
import torch
import logging
import torch.distributed as dist
from funasr.register import tables
@tables.register("index_ds_classes", "IndexDSJsonl")
class IndexDSJsonl(torch.utils.data.Dataset):
    """Index dataset backed by a JSON-lines file, sharded across ranks."""

    def __init__(self, path):
        """Read ``path`` and keep this rank's contiguous share of the records.

        Args:
            path: JSON-lines file. A record with a ``"text"`` key contributes
                that text; a record with a ``"source"`` key contributes a dict
                with ``source``, ``prompt``, ``target``, ``source_len`` and
                ``target_len``.
        """
        super().__init__()

        contents = []
        with open(path, encoding='utf-8') as fin:
            for line in fin:
                line = line.strip()
                if not line:
                    # Blank lines (e.g. a trailing newline) are not JSON.
                    continue
                data = json.loads(line)
                if "text" in data:  # for sft
                    # Collect into the local list: ``self.contents`` does not
                    # exist yet at this point and is rebuilt from ``contents``
                    # below.
                    contents.append(data['text'])
                if "source" in data:  # for speech lab pretrain
                    contents.append({"source": data["source"],
                                     # Nothing in this class reads "prompt",
                                     # so a record without it is accepted.
                                     "prompt": data.get("prompt"),
                                     "target": data["target"],
                                     "source_len": data["source_len"],
                                     "target_len": data["target_len"],
                                     }
                                    )

        total_num = len(contents)
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        else:
            rank = 0
            world_size = 1
            logging.warning("distributed is not initialized, only single shard")

        # Every rank gets the same number of records; a remainder that does
        # not divide evenly is left out.
        num_per_rank = total_num // world_size
        self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank]

        logging.info("in rank: {}, num of samplers: {}, total_num of samplers across ranks: {}".format(rank, len(self.contents), len(contents)))

    def __len__(self):
        return len(self.contents)

    def __getitem__(self, index):
        return self.contents[index]

    def get_source_len(self, data_dict):
        """Return the ``source_len`` stored in a record."""
        return data_dict["source_len"]

    def get_target_len(self, data_dict):
        """Return the ``target_len`` stored in a record, or 0 when absent."""
        return data_dict["target_len"] if "target_len" in data_dict else 0

View File

@ -1,5 +1,4 @@
import torch
import numpy as np
from funasr.register import tables
@ -7,74 +6,74 @@ from funasr.register import tables
@tables.register("batch_sampler_classes", "DynamicBatchLocalShuffleSampler")
class BatchSampler(torch.utils.data.BatchSampler):
    """Batch sampler that sorts samples by length inside fixed-size buffers.

    Samples are visited buffer by buffer (``buffer_size`` positions of the
    optionally shuffled index array). Inside a buffer they are sorted by
    source + target length and packed into batches limited by ``batch_size``.
    """

    def __init__(self, dataset,
                 batch_type: str = "example",
                 batch_size: int = 100,
                 buffer_size: int = 30,
                 drop_last: bool = False,
                 shuffle: bool = True,
                 **kwargs):
        """Create the sampler.

        Args:
            dataset: Object providing ``__len__``, ``get_source_len(index)``
                and ``get_target_len(index)``.
            batch_type: ``'length'`` limits ``num_samples * longest_sample``
                by ``batch_size``; any other value limits the sample count.
            batch_size: Batch budget, interpreted according to ``batch_type``.
            buffer_size: Number of samples sorted together.
            drop_last: When true, the final batch left over at the end of the
                iteration is not yielded.
            shuffle: Shuffle the index array at the start of each iteration.
            **kwargs: ``max_token_length`` (default 5000); longer samples are
                skipped.
        """
        self.drop_last = drop_last
        self.pre_idx = -1
        self.dataset = dataset
        self.total_samples = len(dataset)
        self.batch_type = batch_type
        self.batch_size = batch_size
        self.buffer_size = buffer_size
        self.max_token_length = kwargs.get("max_token_length", 5000)
        self.shuffle_idx = np.arange(self.total_samples)
        self.shuffle = shuffle

    def __len__(self):
        # NOTE(review): this is the number of samples, not the number of batches.
        return self.total_samples

    def set_epoch(self, epoch):
        """Seed NumPy's global RNG with ``epoch``."""
        np.random.seed(epoch)

    def __iter__(self):
        """Yield lists of dataset indices, one list per batch."""
        if self.shuffle:
            np.random.shuffle(self.shuffle_idx)

        batch = []
        max_token = 0
        num_sample = 0

        num_buffers = (self.total_samples - 1) // self.buffer_size + 1
        for buffer_idx in range(self.pre_idx + 1, num_buffers):
            start = buffer_idx * self.buffer_size
            stop = min(start + self.buffer_size, self.total_samples)

            # Pair every sample with its length. The shuffled dataset index is
            # stored (not the buffer position) so that the yielded index is
            # the very sample whose length was measured.
            candidates = []
            for position in range(start, stop):
                sample_idx = int(self.shuffle_idx[position])
                sample_len = self.dataset.get_source_len(sample_idx) + \
                             self.dataset.get_target_len(sample_idx)
                candidates.append((sample_idx, sample_len))
            candidates.sort(key=lambda pair: pair[1])

            for sample_idx, sample_len in candidates:
                if sample_len > self.max_token_length:
                    continue

                max_token_cur = max(max_token, sample_len)
                max_token_padding = 1 + num_sample
                if self.batch_type == 'length':
                    max_token_padding *= max_token_cur

                if max_token_padding <= self.batch_size:
                    batch.append(sample_idx)
                    max_token = max_token_cur
                    num_sample += 1
                else:
                    # A single sample larger than the budget arrives with an
                    # empty batch; never hand an empty batch to the loader.
                    if batch:
                        yield batch
                    batch = [sample_idx]
                    max_token = sample_len
                    num_sample = 1

        # Flush the samples collected after the last yield; they were
        # previously lost on every iteration.
        if batch and not self.drop_last:
            yield batch

View File

@ -1,110 +1,111 @@
import json
import os
import json
from omegaconf import OmegaConf
import torch
from funasr.download.name_maps_from_hub import name_maps_ms, name_maps_hf
def download_model(**kwargs):
    """Complete ``kwargs`` with model files from the selected model hub.

    Only the ``"ms"`` hub (the default for ``model_hub``) is handled, by
    delegating to ``download_from_ms``. For any other ``model_hub`` value the
    keyword arguments are returned unchanged.
    """
    selected_hub = kwargs.get("model_hub", "ms")
    if selected_hub != "ms":
        return kwargs
    return download_from_ms(**kwargs)
def download_from_ms(**kwargs):
    """Resolve a model directory and merge its configuration into ``kwargs``.

    ``kwargs["model"]`` is mapped through ``name_maps_ms`` and, when it is not
    an existing path, resolved with ``get_or_download_model_dir``. Two
    directory layouts are recognised:

    * ``config.yaml`` + ``model.pb``: the YAML is merged with ``kwargs`` and
      the paths of the companion files found in the directory are filled in.
    * ``configuration.json``: its ``file_path_metas`` section is resolved with
      ``add_file_root_path`` and the referenced config is merged.

    Returns:
        The merged configuration converted to plain containers by
        ``OmegaConf.to_container(..., resolve=True)``.
    """
    model_or_path = kwargs.get("model")
    if model_or_path in name_maps_ms:
        model_or_path = name_maps_ms[model_or_path]
    model_revision = kwargs.get("model_revision")
    if not os.path.exists(model_or_path):
        # The option is named ``check_latest``; it used to be read from the
        # key "kwargs", so a caller-supplied value was silently ignored.
        model_or_path = get_or_download_model_dir(model_or_path, model_revision,
                                                  is_training=kwargs.get("is_training"),
                                                  check_latest=kwargs.get("check_latest", True))
    kwargs["model_path"] = model_or_path

    def in_model_dir(name):
        # Path of ``name`` inside the resolved model directory.
        return os.path.join(model_or_path, name)

    config = in_model_dir("config.yaml")
    if os.path.exists(config) and os.path.exists(in_model_dir("model.pb")):
        config = OmegaConf.load(config)
        kwargs = OmegaConf.merge(config, kwargs)
        kwargs["init_param"] = in_model_dir("model.pb")

        # (file name, config section, key). Order matters: tokens.json
        # overrides tokens.txt when both files exist.
        companion_files = (
            ("tokens.txt", "tokenizer_conf", "token_list"),
            ("tokens.json", "tokenizer_conf", "token_list"),
            ("seg_dict", "tokenizer_conf", "seg_dict"),
            ("bpe.model", "tokenizer_conf", "bpemodel"),
            ("am.mvn", "frontend_conf", "cmvn_file"),
        )
        for file_name, section, key in companion_files:
            file_path = in_model_dir(file_name)
            if os.path.exists(file_path):
                kwargs[section][key] = file_path

        kwargs["model"] = config["model"]
        if os.path.exists(in_model_dir("jieba_usr_dict")):
            kwargs["jieba_usr_dict"] = in_model_dir("jieba_usr_dict")
    elif os.path.exists(in_model_dir("configuration.json")):
        with open(in_model_dir("configuration.json"), 'r', encoding='utf-8') as f:
            conf_json = json.load(f)
        cfg = {}
        add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
        cfg.update(kwargs)
        config = OmegaConf.load(cfg["config"])
        kwargs = OmegaConf.merge(config, cfg)
        kwargs["model"] = config["model"]
    return OmegaConf.to_container(kwargs, resolve=True)
def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg=None):
    """Resolve relative file names against a model directory.

    Walks ``file_path_metas`` recursively. Every string value is joined with
    ``model_or_path`` and stored in ``cfg`` under the same key when the
    resulting path exists; every dict value is resolved into the nested dict
    ``cfg[key]`` (created when missing). Other value types are ignored.

    Args:
        model_or_path: Directory the file names are relative to.
        file_path_metas: Possibly nested mapping of keys to file names.
            Anything that is not a dict leaves ``cfg`` untouched.
        cfg: Dict updated in place. A new dict is created when omitted, so
            results are never shared between calls.

    Returns:
        ``cfg`` (the dict passed in, or the newly created one).
    """
    if cfg is None:
        cfg = {}

    if isinstance(file_path_metas, dict):
        for k, v in file_path_metas.items():
            if isinstance(v, str):
                p = os.path.join(model_or_path, v)
                if os.path.exists(p):
                    cfg[k] = p
            elif isinstance(v, dict):
                # Recurse without returning: returning here stopped the walk
                # at the first nested section and skipped all later keys.
                add_file_root_path(model_or_path, v, cfg.setdefault(k, {}))

    return cfg
def get_or_download_model_dir(
        model,
        model_revision=None,
        is_training=False,
        check_latest=True,
):
    """Get the local model directory, downloading the model if necessary.

    Args:
        model (str): Model id or path to a local model directory / file.
        model_revision (str, optional): Model version number, used when the
            model is downloaded.
        is_training (bool): Selects the ``Invoke`` key sent in the user agent
            (``LOCAL_TRAINER`` when true, ``PIPELINE`` otherwise).
        check_latest (bool): For an existing local path, ask the hub whether
            the local copy is the latest one. A failing check is reported and
            otherwise ignored.

    Returns:
        str: The local directory (the parent directory when ``model`` is a
        file), or the directory returned by ``snapshot_download``.
    """
    from modelscope.hub.check_model import check_local_model_is_latest
    from modelscope.hub.snapshot_download import snapshot_download

    from modelscope.utils.constant import Invoke, ThirdParty

    key = Invoke.LOCAL_TRAINER if is_training else Invoke.PIPELINE
    user_agent = {
        Invoke.KEY: key,
        ThirdParty.KEY: "funasr"
    }

    if os.path.exists(model):
        # An existing local path is always used as-is. With
        # ``check_latest=False`` it used to fall through to
        # ``snapshot_download`` and be treated as a hub model id.
        model_cache_dir = model if os.path.isdir(model) else os.path.dirname(model)
        if check_latest:
            try:
                check_local_model_is_latest(model_cache_dir, user_agent=user_agent)
            except Exception:
                # Best effort only; ``Exception`` (not a bare ``except``) so
                # that KeyboardInterrupt / SystemExit still propagate.
                print("could not check the latest version")
    else:
        model_cache_dir = snapshot_download(
            model,
            revision=model_revision,
            user_agent=user_agent)
    return model_cache_dir

View File

@ -1,45 +1,47 @@
from pathlib import Path
import os
import argparse
from pathlib import Path
from funasr.utils.types import str2bool
def main():
    """Make sure an ONNX export exists for the requested model.

    Command-line arguments are parsed with ``argparse``.  ``--model-name`` is
    used as a local path when it exists on disk; otherwise it is passed to
    ModelScope's ``snapshot_download`` (cached under ``--export-dir``, at
    ``--model_revision``).  When ``--export`` is true and the expected file
    (``model.onnx``, or ``model_quant.onnx`` with ``--quantize``) is missing
    from the model directory, the model is exported through
    ``funasr.bin.export_model.ModelExport``.

    ``--type``, ``--device``, ``--fallback-num``, ``--audio_in`` and
    ``--calib_num`` are accepted but not read anywhere in this function; the
    export itself always runs with ``device="cpu"``.

    Raises:
        ValueError: if ``--model-name`` is not an existing path and the
            ModelScope download fails.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-name', type=str, required=True)
    parser.add_argument('--export-dir', type=str, required=True)
    parser.add_argument('--export', type=str2bool, default=True, help='whether to export model')
    parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
    parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
    parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
    parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
    parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
    parser.add_argument('--model_revision', type=str, default=None, help='model_revision')
    parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
    args = parser.parse_args()

    model_dir = args.model_name
    if not Path(args.model_name).exists():
        # Imported lazily so a purely local run does not need modelscope.
        from modelscope.hub.snapshot_download import snapshot_download
        try:
            model_dir = snapshot_download(args.model_name, cache_dir=args.export_dir, revision=args.model_revision)
        except Exception as err:
            # The original raised a plain str, which Python 3 rejects with a
            # TypeError and thereby hid this message; raise a real exception
            # and keep the download error attached as the cause.
            raise ValueError(
                "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
                    model_dir)
            ) from err

    if args.export:
        model_file = os.path.join(model_dir, 'model.onnx')
        if args.quantize:
            model_file = os.path.join(model_dir, 'model_quant.onnx')
        if not os.path.exists(model_file):
            print(".onnx is not exist, begin to export onnx")
            from funasr.bin.export_model import ModelExport
            export_model = ModelExport(
                cache_dir=args.export_dir,
                onnx=True,
                device="cpu",
                quant=args.quantize,
            )
            export_model.export(model_dir)


if __name__ == "__main__":
    main()

View File

@ -5,12 +5,12 @@ from funasr.register import tables
@tables.register("model_classes", "Branchformer")
class Branchformer(Transformer):
    """CTC-attention hybrid Encoder-Decoder model.

    Decorated with ``tables.register("model_classes", "Branchformer")``.
    The class defines no behaviour of its own: every constructor argument is
    handed to ``Transformer`` unchanged.
    """

    def __init__(self, *args, **kwargs):
        # Pure pass-through to the parent constructor.
        super().__init__(*args, **kwargs)

View File

@ -7,13 +7,13 @@ from funasr.register import tables
@tables.register("model_classes", "Conformer")
class Conformer(Transformer):
    """CTC-attention hybrid Encoder-Decoder model.

    Decorated with ``tables.register("model_classes", "Conformer")``.
    The class defines no behaviour of its own: every constructor argument is
    handed to ``Transformer`` unchanged.
    """

    def __init__(self, *args, **kwargs):
        # Pure pass-through to the parent constructor.
        super().__init__(*args, **kwargs)

View File

@ -5,12 +5,12 @@ from funasr.register import tables
@tables.register("model_classes", "EBranchformer")
class EBranchformer(Transformer):
    """CTC-attention hybrid Encoder-Decoder model.

    Decorated with ``tables.register("model_classes", "EBranchformer")``.
    The class defines no behaviour of its own: every constructor argument is
    handed to ``Transformer`` unchanged.
    """

    def __init__(self, *args, **kwargs):
        # Pure pass-through to the parent constructor.
        super().__init__(*args, **kwargs)

View File

@ -7,12 +7,12 @@ from funasr.register import tables
@tables.register("model_classes", "SANM")
class SANM(Transformer):
    """CTC-attention hybrid Encoder-Decoder model.

    Decorated with ``tables.register("model_classes", "SANM")``.
    The class defines no behaviour of its own: every constructor argument is
    handed to ``Transformer`` unchanged.
    """

    def __init__(self, *args, **kwargs):
        # Pure pass-through to the parent constructor.
        super().__init__(*args, **kwargs)

View File

@ -1,289 +1,287 @@
import math
import torch
import numpy as np
import math
from funasr.models.transformer.utils.nets_utils import make_pad_mask
import logging
import torch.nn.functional as F
from funasr.models.scama.utils import sequence_mask
from funasr.models.scama.utils import sequence_mask
from funasr.models.transformer.utils.nets_utils import make_pad_mask
class overlap_chunk():
    """Build overlapping-chunk selection matrices and masks, and apply them.

    ``gen_chunk_mask`` produces numpy matrices for one chunk configuration;
    ``split_chunk`` multiplies a padded batch by the "add" matrix to lay the
    frames out chunk by chunk, and ``remove_chunk`` multiplies by the "rm"
    matrix to go back.  The ``get_*`` methods convert individual entries of
    the generated tuple to torch tensors.

    Author: Speech Lab of DAMO Academy, Alibaba Group
    San-m: Memory equipped self-attention for end-to-end speech recognition
    https://arxiv.org/abs/2006.01713
    """

    def __init__(self,
                 chunk_size: tuple = (16,),
                 stride: tuple = (10,),
                 pad_left: tuple = (0,),
                 encoder_att_look_back_factor: tuple = (1,),
                 shfit_fsmn: int = 0,
                 decoder_att_look_back_factor: tuple = (1,),
                 ):
        """Store the per-configuration chunk settings.

        Args:
            chunk_size: one chunk length per configuration.
            stride: one stride per configuration.
            pad_left: one left padding per configuration.
            encoder_att_look_back_factor: per-configuration number of earlier
                chunks left visible in the encoder attention mask.
            shfit_fsmn: number of zero rows/columns inserted in front of every
                chunk (spelling kept for interface compatibility).
            decoder_att_look_back_factor: per-configuration value, only stored
                and exposed as ``decoder_att_look_back_factor_cur``.

        Any of the tuple arguments that is shorter than ``chunk_size`` is
        replaced by its first element repeated ``len(chunk_size)`` times.
        """
        # ``stride`` is broadcast like the other per-chunk settings; before,
        # a single stride combined with several chunk sizes raised IndexError
        # in ``get_chunk_size`` for every index but 0.
        stride = self.check_chunk_size_args(chunk_size, stride)
        pad_left = self.check_chunk_size_args(chunk_size, pad_left)
        encoder_att_look_back_factor = self.check_chunk_size_args(chunk_size, encoder_att_look_back_factor)
        decoder_att_look_back_factor = self.check_chunk_size_args(chunk_size, decoder_att_look_back_factor)

        self.chunk_size = chunk_size
        self.stride = stride
        self.pad_left = pad_left
        self.encoder_att_look_back_factor = encoder_att_look_back_factor
        self.decoder_att_look_back_factor = decoder_att_look_back_factor
        self.shfit_fsmn = shfit_fsmn

        # Filled in by ``gen_chunk_mask``.
        self.x_add_mask = None
        self.x_len_chunk = None
        self.x_rm_mask = None
        self.x_len = None
        self.mask_shfit_chunk = None
        self.mask_chunk_predictor = None
        self.mask_att_chunk_encoder = None
        self.mask_shift_att_chunk_decoder = None
        self.chunk_outs = None

        # Filled in by ``get_chunk_size``.
        self.chunk_size_cur = None
        self.stride_cur = None
        self.pad_left_cur = None
        self.encoder_att_look_back_factor_cur = None
        self.chunk_size_pad_shift_cur = None
        self.decoder_att_look_back_factor_cur = None

    def check_chunk_size_args(self, chunk_size, x):
        """Return ``x``, broadcast from its first element if shorter than ``chunk_size``."""
        if len(x) < len(chunk_size):
            x = [x[0]] * len(chunk_size)
        return x

    def get_chunk_size(self,
                       ind: int = 0
                       ):
        """Select configuration ``ind`` and remember it as the current one.

        Returns:
            ``(chunk_size, stride, pad_left, encoder_att_look_back_factor,
            chunk_size + shfit_fsmn)`` for that configuration.  The decoder
            look-back factor is stored on the instance but not returned.
        """
        chunk_size = self.chunk_size[ind]
        stride = self.stride[ind]
        pad_left = self.pad_left[ind]
        encoder_att_look_back_factor = self.encoder_att_look_back_factor[ind]
        decoder_att_look_back_factor = self.decoder_att_look_back_factor[ind]

        self.chunk_size_cur = chunk_size
        self.stride_cur = stride
        self.pad_left_cur = pad_left
        self.encoder_att_look_back_factor_cur = encoder_att_look_back_factor
        self.chunk_size_pad_shift_cur = chunk_size + self.shfit_fsmn
        self.decoder_att_look_back_factor_cur = decoder_att_look_back_factor
        return (self.chunk_size_cur, self.stride_cur, self.pad_left_cur,
                self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur)

    def random_choice(self, training=True, decoding_ind=None):
        """Pick a configuration index.

        In training with more than one configuration the index is drawn with
        ``torch.randint``; outside training ``decoding_ind`` is used when it
        is given; in every other case the index is 0.
        """
        chunk_num = len(self.chunk_size)
        ind = 0
        if training and chunk_num > 1:
            ind = torch.randint(0, chunk_num, ()).cpu().item()
        if not training and decoding_ind is not None:
            ind = int(decoding_ind)

        return ind

    def gen_chunk_mask(self, x_len, ind=0, num_units=1, num_units_predictor=1):
        """Build every chunk matrix/mask for configuration ``ind``.

        Args:
            x_len: torch tensor of sequence lengths (converted with
                ``.cpu().numpy()``).
            ind: configuration index passed to ``get_chunk_size``.
            num_units: column count of the shift mask.
            num_units_predictor: column count of the predictor mask.

        Returns:
            Tuple of numpy arrays, also cached as ``self.chunk_outs``, in the
            order the ``get_*`` methods index it:
            ``(x_add_mask, x_len_chunk, x_rm_mask, x_len, mask_shfit_chunk,
            mask_chunk_predictor, mask_att_chunk_encoder,
            mask_shift_att_chunk_decoder)``.
        """
        with torch.no_grad():
            x_len = x_len.cpu().numpy()
            x_len_max = x_len.max()

            chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size_pad_shift = self.get_chunk_size(ind)
            shfit_fsmn = self.shfit_fsmn
            pad_right = chunk_size - stride - pad_left

            # Per-utterance chunk count and resulting chunked length.
            chunk_num_batch = np.ceil(x_len/stride).astype(np.int32)
            x_len_chunk = (chunk_num_batch-1) * chunk_size_pad_shift + shfit_fsmn + pad_left + 0 + x_len - (chunk_num_batch-1) * stride
            x_len_chunk = x_len_chunk.astype(x_len.dtype)
            x_len_chunk_max = x_len_chunk.max()

            chunk_num = int(math.ceil(x_len_max/stride))
            dtype = np.int32
            max_len_for_x_mask_tmp = max(chunk_size, x_len_max + pad_left)

            # Empty accumulators; one slab per chunk is appended in the loop.
            x_add_mask = np.zeros([0, max_len_for_x_mask_tmp], dtype=dtype)
            x_rm_mask = np.zeros([max_len_for_x_mask_tmp, 0], dtype=dtype)
            mask_shfit_chunk = np.zeros([0, num_units], dtype=dtype)
            mask_chunk_predictor = np.zeros([0, num_units_predictor], dtype=dtype)
            mask_shift_att_chunk_decoder = np.zeros([0, 1], dtype=dtype)
            mask_att_chunk_encoder = np.zeros([0, chunk_num*chunk_size_pad_shift], dtype=dtype)

            for chunk_ids in range(chunk_num):
                # x_mask add: identity block shifted right by chunk_ids * stride,
                # preceded by shfit_fsmn all-zero rows.
                fsmn_padding = np.zeros((shfit_fsmn, max_len_for_x_mask_tmp), dtype=dtype)
                x_mask_cur = np.diag(np.ones(chunk_size, dtype=np.float32))
                x_mask_pad_left = np.zeros((chunk_size, chunk_ids * stride), dtype=dtype)
                x_mask_pad_right = np.zeros((chunk_size, max_len_for_x_mask_tmp), dtype=dtype)
                x_cur_pad = np.concatenate([x_mask_pad_left, x_mask_cur, x_mask_pad_right], axis=1)
                x_cur_pad = x_cur_pad[:chunk_size, :max_len_for_x_mask_tmp]
                x_add_mask_fsmn = np.concatenate([fsmn_padding, x_cur_pad], axis=0)
                x_add_mask = np.concatenate([x_add_mask, x_add_mask_fsmn], axis=0)

                # x_mask rm: identity block of size stride shifted down by
                # chunk_ids * stride, framed by zero columns for the fsmn
                # shift and the left/right chunk padding.
                fsmn_padding = np.zeros((max_len_for_x_mask_tmp, shfit_fsmn), dtype=dtype)
                padding_mask_left = np.zeros((max_len_for_x_mask_tmp, pad_left), dtype=dtype)
                padding_mask_right = np.zeros((max_len_for_x_mask_tmp, pad_right), dtype=dtype)
                x_mask_cur = np.diag(np.ones(stride, dtype=dtype))
                x_mask_cur_pad_top = np.zeros((chunk_ids*stride, stride), dtype=dtype)
                x_mask_cur_pad_bottom = np.zeros((max_len_for_x_mask_tmp, stride), dtype=dtype)
                x_rm_mask_cur = np.concatenate([x_mask_cur_pad_top, x_mask_cur, x_mask_cur_pad_bottom], axis=0)
                x_rm_mask_cur = x_rm_mask_cur[:max_len_for_x_mask_tmp, :stride]
                x_rm_mask_cur_fsmn = np.concatenate([fsmn_padding, padding_mask_left, x_rm_mask_cur, padding_mask_right], axis=1)
                x_rm_mask = np.concatenate([x_rm_mask, x_rm_mask_cur_fsmn], axis=1)

                # fsmn_padding_mask: zeros for the shift rows, ones for the chunk.
                pad_shfit_mask = np.zeros([shfit_fsmn, num_units], dtype=dtype)
                ones_1 = np.ones([chunk_size, num_units], dtype=dtype)
                mask_shfit_chunk_cur = np.concatenate([pad_shfit_mask, ones_1], axis=0)
                mask_shfit_chunk = np.concatenate([mask_shfit_chunk, mask_shfit_chunk_cur], axis=0)

                # predictor mask: ones only on the stride rows of the chunk.
                zeros_1 = np.zeros([shfit_fsmn + pad_left, num_units_predictor], dtype=dtype)
                ones_2 = np.ones([stride, num_units_predictor], dtype=dtype)
                zeros_3 = np.zeros([chunk_size - stride - pad_left, num_units_predictor], dtype=dtype)
                ones_zeros = np.concatenate([ones_2, zeros_3], axis=0)
                mask_chunk_predictor_cur = np.concatenate([zeros_1, ones_zeros], axis=0)
                mask_chunk_predictor = np.concatenate([mask_chunk_predictor, mask_chunk_predictor_cur], axis=0)

                # encoder att mask: zero columns for chunks further back than
                # the look-back factor, a partial block per looked-back chunk,
                # a full block for the current chunk, zeros for later chunks.
                zeros_1_top = np.zeros([shfit_fsmn, chunk_num*chunk_size_pad_shift], dtype=dtype)

                zeros_2_num = max(chunk_ids - encoder_att_look_back_factor, 0)
                zeros_2 = np.zeros([chunk_size, zeros_2_num*chunk_size_pad_shift], dtype=dtype)

                encoder_att_look_back_num = max(chunk_ids - zeros_2_num, 0)
                zeros_2_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
                ones_2_mid = np.ones([stride, stride], dtype=dtype)
                zeros_2_bottom = np.zeros([chunk_size-stride, stride], dtype=dtype)
                zeros_2_right = np.zeros([chunk_size, chunk_size-stride], dtype=dtype)
                ones_2 = np.concatenate([ones_2_mid, zeros_2_bottom], axis=0)
                ones_2 = np.concatenate([zeros_2_left, ones_2, zeros_2_right], axis=1)
                ones_2 = np.tile(ones_2, [1, encoder_att_look_back_num])

                zeros_3_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
                ones_3_right = np.ones([chunk_size, chunk_size], dtype=dtype)
                ones_3 = np.concatenate([zeros_3_left, ones_3_right], axis=1)

                zeros_remain_num = max(chunk_num - 1 - chunk_ids, 0)
                zeros_remain = np.zeros([chunk_size, zeros_remain_num*chunk_size_pad_shift], dtype=dtype)

                ones2_bottom = np.concatenate([zeros_2, ones_2, ones_3, zeros_remain], axis=1)
                mask_att_chunk_encoder_cur = np.concatenate([zeros_1_top, ones2_bottom], axis=0)
                mask_att_chunk_encoder = np.concatenate([mask_att_chunk_encoder, mask_att_chunk_encoder_cur], axis=0)

                # decoder fsmn_shift_att_mask
                zeros_1 = np.zeros([shfit_fsmn, 1])
                ones_1 = np.ones([chunk_size, 1])
                mask_shift_att_chunk_decoder_cur = np.concatenate([zeros_1, ones_1], axis=0)
                mask_shift_att_chunk_decoder = np.concatenate(
                    [mask_shift_att_chunk_decoder, mask_shift_att_chunk_decoder_cur], axis=0)

            # Trim everything to the longest sequence actually present.
            self.x_add_mask = x_add_mask[:x_len_chunk_max, :x_len_max+pad_left]
            self.x_len_chunk = x_len_chunk
            self.x_rm_mask = x_rm_mask[:x_len_max, :x_len_chunk_max]
            self.x_len = x_len
            self.mask_shfit_chunk = mask_shfit_chunk[:x_len_chunk_max, :]
            self.mask_chunk_predictor = mask_chunk_predictor[:x_len_chunk_max, :]
            self.mask_att_chunk_encoder = mask_att_chunk_encoder[:x_len_chunk_max, :x_len_chunk_max]
            self.mask_shift_att_chunk_decoder = mask_shift_att_chunk_decoder[:x_len_chunk_max, :]
            self.chunk_outs = (self.x_add_mask,
                               self.x_len_chunk,
                               self.x_rm_mask,
                               self.x_len,
                               self.mask_shfit_chunk,
                               self.mask_chunk_predictor,
                               self.mask_att_chunk_encoder,
                               self.mask_shift_att_chunk_decoder)

        return self.chunk_outs

    def split_chunk(self, x, x_len, chunk_outs):
        """Rearrange a padded batch into its chunked layout.

        :param x: (b, t, d)
        :param x_len: (b)
        :param chunk_outs: tuple returned by ``gen_chunk_mask`` (or None to
            use the cached one)
        :return: ``(x_chunk, x_len_chunk)``
        """
        x = x[:, :x_len.max(), :]
        b, t, d = x.size()
        x_len_mask = (~make_pad_mask(x_len, maxlen=t)).to(
            x.device)
        # Out-of-place on purpose: ``x`` is a view of the caller's tensor and
        # ``*=`` would overwrite the caller's data.
        x = x * x_len_mask[:, :, None]

        x_add_mask = self.get_x_add_mask(chunk_outs, x.device, dtype=x.dtype)
        x_len_chunk = self.get_x_len_chunk(chunk_outs, x_len.device, dtype=x_len.dtype)
        # Left-pad the time axis by the current pad_left before selection.
        pad = (0, 0, self.pad_left_cur, 0)
        x = F.pad(x, pad, "constant", 0.0)
        b, t, d = x.size()
        # (b, t, d) -> (t, b*d) so one matmul selects frames for all items.
        x = torch.transpose(x, 1, 0)
        x = torch.reshape(x, [t, -1])
        x_chunk = torch.mm(x_add_mask, x)
        x_chunk = torch.reshape(x_chunk, [-1, b, d]).transpose(1, 0)

        return x_chunk, x_len_chunk

    def remove_chunk(self, x_chunk, x_len_chunk, chunk_outs):
        """Undo ``split_chunk``: map a chunked batch back with the "rm" matrix.

        :param x_chunk: (b, t, d) chunked batch
        :param x_len_chunk: (b) chunked lengths
        :param chunk_outs: tuple returned by ``gen_chunk_mask`` (or None to
            use the cached one)
        :return: ``(x, x_len)``
        """
        x_chunk = x_chunk[:, :x_len_chunk.max(), :]
        b, t, d = x_chunk.size()
        x_len_chunk_mask = (~make_pad_mask(x_len_chunk, maxlen=t)).to(
            x_chunk.device)
        # Out-of-place for the same reason as in ``split_chunk``.
        x_chunk = x_chunk * x_len_chunk_mask[:, :, None]

        x_rm_mask = self.get_x_rm_mask(chunk_outs, x_chunk.device, dtype=x_chunk.dtype)
        x_len = self.get_x_len(chunk_outs, x_len_chunk.device, dtype=x_len_chunk.dtype)
        x_chunk = torch.transpose(x_chunk, 1, 0)
        x_chunk = torch.reshape(x_chunk, [t, -1])
        x = torch.mm(x_rm_mask, x_chunk)
        x = torch.reshape(x, [-1, b, d]).transpose(1, 0)

        return x, x_len

    def _chunk_out(self, chunk_outs, idx):
        """Return entry ``idx`` of ``chunk_outs``, or of the cached tuple when it is None."""
        outs = self.chunk_outs if chunk_outs is None else chunk_outs
        return outs[idx]

    @staticmethod
    def _to_tensor(array, device, dtype):
        """Convert a numpy array to a torch tensor of ``dtype`` on ``device``."""
        with torch.no_grad():
            return torch.from_numpy(array).type(dtype).to(device)

    def get_x_add_mask(self, chunk_outs=None, device='cpu', idx=0, dtype=torch.float32):
        """Return the "add" selection matrix (entry ``idx``) as a tensor."""
        return self._to_tensor(self._chunk_out(chunk_outs, idx), device, dtype)

    def get_x_len_chunk(self, chunk_outs=None, device='cpu', idx=1, dtype=torch.float32):
        """Return the chunked lengths (entry ``idx``) as a tensor."""
        return self._to_tensor(self._chunk_out(chunk_outs, idx), device, dtype)

    def get_x_rm_mask(self, chunk_outs=None, device='cpu', idx=2, dtype=torch.float32):
        """Return the "rm" selection matrix (entry ``idx``) as a tensor."""
        return self._to_tensor(self._chunk_out(chunk_outs, idx), device, dtype)

    def get_x_len(self, chunk_outs=None, device='cpu', idx=3, dtype=torch.float32):
        """Return the original lengths (entry ``idx``) as a tensor."""
        return self._to_tensor(self._chunk_out(chunk_outs, idx), device, dtype)

    def get_mask_shfit_chunk(self, chunk_outs=None, device='cpu', batch_size=1, num_units=1, idx=4, dtype=torch.float32):
        """Return the shift mask tiled to ``(batch_size, rows, cols * num_units)``."""
        x = self._chunk_out(chunk_outs, idx)
        x = np.tile(x[None, :, :, ], [batch_size, 1, num_units])
        return self._to_tensor(x, device, dtype)

    def get_mask_chunk_predictor(self, chunk_outs=None, device='cpu', batch_size=1, num_units=1, idx=5, dtype=torch.float32):
        """Return the predictor mask tiled to ``(batch_size, rows, cols * num_units)``."""
        x = self._chunk_out(chunk_outs, idx)
        x = np.tile(x[None, :, :, ], [batch_size, 1, num_units])
        return self._to_tensor(x, device, dtype)

    def get_mask_att_chunk_encoder(self, chunk_outs=None, device='cpu', batch_size=1, idx=6, dtype=torch.float32):
        """Return the encoder attention mask repeated ``batch_size`` times on a new leading axis."""
        x = self._chunk_out(chunk_outs, idx)
        x = np.tile(x[None, :, :, ], [batch_size, 1, 1])
        return self._to_tensor(x, device, dtype)

    def get_mask_shift_att_chunk_decoder(self, chunk_outs=None, device='cpu', batch_size=1, idx=7, dtype=torch.float32):
        """Return column 0 of the decoder shift mask as a ``(batch_size, 1, rows)`` tensor."""
        x = self._chunk_out(chunk_outs, idx)
        x = np.tile(x[None, None, :, 0], [batch_size, 1, 1])
        return self._to_tensor(x, device, dtype)
def build_scama_mask_for_cross_attention_decoder(
predictor_alignments: torch.Tensor,
predictor_alignments: torch.Tensor,
encoder_sequence_length: torch.Tensor,
chunk_size: int = 5,
encoder_chunk_size: int = 5,
@ -291,100 +289,100 @@ def build_scama_mask_for_cross_attention_decoder(
attention_chunk_size: int = 1,
attention_chunk_type: str = 'chunk',
step=None,
predictor_mask_chunk_hopping: torch.Tensor = None,
decoder_att_look_back_factor: int = 1,
mask_shift_att_chunk_decoder: torch.Tensor = None,
target_length: torch.Tensor = None,
is_training=True,
predictor_mask_chunk_hopping: torch.Tensor = None,
decoder_att_look_back_factor: int = 1,
mask_shift_att_chunk_decoder: torch.Tensor = None,
target_length: torch.Tensor = None,
is_training=True,
dtype: torch.dtype = torch.float32):
with torch.no_grad():
device = predictor_alignments.device
batch_size, chunk_num = predictor_alignments.size()
maximum_encoder_length = encoder_sequence_length.max().item()
int_type = predictor_alignments.dtype
if not is_training:
target_length = predictor_alignments.sum(dim=-1).type(encoder_sequence_length.dtype)
maximum_target_length = target_length.max()
predictor_alignments_cumsum = torch.cumsum(predictor_alignments, dim=1)
predictor_alignments_cumsum = predictor_alignments_cumsum[:, None, :].repeat(1, maximum_target_length, 1)
index = torch.ones([batch_size, maximum_target_length], dtype=int_type).to(device)
index = torch.cumsum(index, dim=1)
index = index[:, :, None].repeat(1, 1, chunk_num)
index_div = torch.floor(torch.divide(predictor_alignments_cumsum, index)).type(int_type)
index_div_bool_zeros = index_div == 0
index_div_bool_zeros_count = torch.sum(index_div_bool_zeros.type(int_type), dim=-1) + 1
index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count, min=1, max=chunk_num)
index_div_bool_zeros_count *= chunk_size
index_div_bool_zeros_count += attention_chunk_center_bias
index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count-1, min=0, max=maximum_encoder_length)
index_div_bool_zeros_count_ori = index_div_bool_zeros_count
index_div_bool_zeros_count = (torch.floor(index_div_bool_zeros_count / encoder_chunk_size)+1)*encoder_chunk_size
max_len_chunk = math.ceil(maximum_encoder_length / encoder_chunk_size) * encoder_chunk_size
mask_flip, mask_flip2 = None, None
if attention_chunk_size is not None:
index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size
index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk)
index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device)
mask_flip = 1 - index_div_bool_zeros_count_beg_mask
attention_chunk_size2 = attention_chunk_size * (decoder_att_look_back_factor+1)
index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size2
index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk)
index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device)
mask_flip2 = 1 - index_div_bool_zeros_count_beg_mask
mask = sequence_mask(index_div_bool_zeros_count, maxlen=max_len_chunk, dtype=dtype, device=device)
if predictor_mask_chunk_hopping is not None:
b, k, t = mask.size()
predictor_mask_chunk_hopping = predictor_mask_chunk_hopping[:, None, :, 0].repeat(1, k, 1)
mask_mask_flip = mask
if mask_flip is not None:
mask_mask_flip = mask_flip * mask
def _fn():
mask_sliced = mask[:b, :k, encoder_chunk_size:t]
zero_pad_right = torch.zeros([b, k, encoder_chunk_size], dtype=mask_sliced.dtype).to(device)
mask_sliced = torch.cat([mask_sliced, zero_pad_right], dim=2)
_, _, tt = predictor_mask_chunk_hopping.size()
pad_right_p = max_len_chunk - tt
predictor_mask_chunk_hopping_pad = torch.nn.functional.pad(predictor_mask_chunk_hopping, [0, pad_right_p], "constant", 0)
masked = mask_sliced * predictor_mask_chunk_hopping_pad
mask_true = mask_mask_flip + masked
return mask_true
mask = _fn() if t > chunk_size else mask_mask_flip
if mask_flip2 is not None:
mask *= mask_flip2
mask_target = sequence_mask(target_length, maxlen=maximum_target_length, dtype=mask.dtype, device=device)
mask = mask[:, :maximum_target_length, :] * mask_target[:, :, None]
mask_len = sequence_mask(encoder_sequence_length, maxlen=maximum_encoder_length, dtype=mask.dtype, device=device)
mask = mask[:, :, :maximum_encoder_length] * mask_len[:, None, :]
if attention_chunk_type == 'full':
mask = torch.ones_like(mask).to(device)
if mask_shift_att_chunk_decoder is not None:
mask = mask * mask_shift_att_chunk_decoder
mask = mask[:, :maximum_target_length, :maximum_encoder_length].type(dtype).to(device)
with torch.no_grad():
device = predictor_alignments.device
batch_size, chunk_num = predictor_alignments.size()
maximum_encoder_length = encoder_sequence_length.max().item()
int_type = predictor_alignments.dtype
if not is_training:
target_length = predictor_alignments.sum(dim=-1).type(encoder_sequence_length.dtype)
maximum_target_length = target_length.max()
predictor_alignments_cumsum = torch.cumsum(predictor_alignments, dim=1)
predictor_alignments_cumsum = predictor_alignments_cumsum[:, None, :].repeat(1, maximum_target_length, 1)
index = torch.ones([batch_size, maximum_target_length], dtype=int_type).to(device)
index = torch.cumsum(index, dim=1)
index = index[:, :, None].repeat(1, 1, chunk_num)
index_div = torch.floor(torch.divide(predictor_alignments_cumsum, index)).type(int_type)
index_div_bool_zeros = index_div == 0
index_div_bool_zeros_count = torch.sum(index_div_bool_zeros.type(int_type), dim=-1) + 1
index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count, min=1, max=chunk_num)
index_div_bool_zeros_count *= chunk_size
index_div_bool_zeros_count += attention_chunk_center_bias
index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count-1, min=0, max=maximum_encoder_length)
index_div_bool_zeros_count_ori = index_div_bool_zeros_count
index_div_bool_zeros_count = (torch.floor(index_div_bool_zeros_count / encoder_chunk_size)+1)*encoder_chunk_size
max_len_chunk = math.ceil(maximum_encoder_length / encoder_chunk_size) * encoder_chunk_size
mask_flip, mask_flip2 = None, None
if attention_chunk_size is not None:
index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size
index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk)
index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device)
mask_flip = 1 - index_div_bool_zeros_count_beg_mask
attention_chunk_size2 = attention_chunk_size * (decoder_att_look_back_factor+1)
index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size2
index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk)
index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device)
mask_flip2 = 1 - index_div_bool_zeros_count_beg_mask
mask = sequence_mask(index_div_bool_zeros_count, maxlen=max_len_chunk, dtype=dtype, device=device)
if predictor_mask_chunk_hopping is not None:
b, k, t = mask.size()
predictor_mask_chunk_hopping = predictor_mask_chunk_hopping[:, None, :, 0].repeat(1, k, 1)
mask_mask_flip = mask
if mask_flip is not None:
mask_mask_flip = mask_flip * mask
def _fn():
mask_sliced = mask[:b, :k, encoder_chunk_size:t]
zero_pad_right = torch.zeros([b, k, encoder_chunk_size], dtype=mask_sliced.dtype).to(device)
mask_sliced = torch.cat([mask_sliced, zero_pad_right], dim=2)
_, _, tt = predictor_mask_chunk_hopping.size()
pad_right_p = max_len_chunk - tt
predictor_mask_chunk_hopping_pad = torch.nn.functional.pad(predictor_mask_chunk_hopping, [0, pad_right_p], "constant", 0)
masked = mask_sliced * predictor_mask_chunk_hopping_pad
mask_true = mask_mask_flip + masked
return mask_true
mask = _fn() if t > chunk_size else mask_mask_flip
if mask_flip2 is not None:
mask *= mask_flip2
mask_target = sequence_mask(target_length, maxlen=maximum_target_length, dtype=mask.dtype, device=device)
mask = mask[:, :maximum_target_length, :] * mask_target[:, :, None]
mask_len = sequence_mask(encoder_sequence_length, maxlen=maximum_encoder_length, dtype=mask.dtype, device=device)
mask = mask[:, :, :maximum_encoder_length] * mask_len[:, None, :]
if attention_chunk_type == 'full':
mask = torch.ones_like(mask).to(device)
if mask_shift_att_chunk_decoder is not None:
mask = mask * mask_shift_att_chunk_decoder
mask = mask[:, :maximum_target_length, :maximum_encoder_length].type(dtype).to(device)
return mask
return mask

View File

@ -1,29 +1,30 @@
import os
import torch
from torch.nn import functional as F
import yaml
import torch
import numpy as np
from torch.nn import functional as F
def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
    """Build a padding mask from a tensor of sequence lengths.

    Position ``i`` along the appended last axis is set when ``i < length``.

    Args:
        lengths: Tensor of sequence lengths.
        maxlen: Number of positions in the mask. Defaults to ``lengths.max()``.
        dtype: Dtype of the returned mask.
        device: Optional device to move the result to. When ``None`` the mask
            stays on ``lengths.device``.

    Returns:
        Tensor shaped ``lengths.shape + (maxlen,)`` holding ones for valid
        positions and zeros for padding, cast to ``dtype``.
    """
    if maxlen is None:
        maxlen = lengths.max()
    # Allocate the position index on the right device instead of building it
    # on the CPU and copying it over.
    row_vector = torch.arange(0, maxlen, 1, device=lengths.device)
    matrix = torch.unsqueeze(lengths, dim=-1)
    mask = (row_vector < matrix).detach().type(dtype)
    return mask.to(device) if device is not None else mask
def apply_cmvn(inputs, mvn):
    """Shift and scale 2-D features in place with per-dimension statistics.

    Args:
        inputs: 2-D tensor ``(frames, dim)``. It is modified in place.
        mvn: NumPy array whose row 0 is subtracted from every frame and whose
            row 1 then multiplies every frame; only the first ``dim`` columns
            are used.
            NOTE(review): row 1 is presumably an inverse standard deviation
            (it is multiplied, not divided) -- confirm against the producer
            of the statistics file.

    Returns:
        The normalised features cast to ``torch.float32``.
    """
    device = inputs.device
    dtype = inputs.dtype
    _, dim = inputs.shape
    # A (1, dim) row broadcasts over the frame axis, so the statistics do not
    # need to be tiled to (frames, dim) first.
    shift = torch.from_numpy(mvn[0:1, :dim]).type(dtype).to(device)
    scale = torch.from_numpy(mvn[1:2, :dim]).type(dtype).to(device)
    inputs -= shift
    inputs *= scale
    return inputs.type(torch.float32)
@ -36,56 +37,56 @@ def drop_and_add(inputs: torch.Tensor,
outputs = F.dropout(outputs, p=dropout_rate, training=training, inplace=True)
outputs *= stoch_layer_coeff
outputs = F.dropout(outputs, p=dropout_rate, training=training, inplace=True)
outputs *= stoch_layer_coeff
input_dim = inputs.size(-1)
output_dim = outputs.size(-1)
input_dim = inputs.size(-1)
output_dim = outputs.size(-1)
if input_dim == output_dim:
outputs += inputs
return outputs
if input_dim == output_dim:
outputs += inputs
return outputs
def proc_tf_vocab(vocab_path):
    """Read a one-token-per-line vocabulary file and ensure it has ``<unk>``.

    Args:
        vocab_path: Path of a UTF-8 text file with one token per line.

    Returns:
        List of tokens in file order with trailing whitespace stripped;
        ``'<unk>'`` is appended when the file does not already contain it.
    """
    with open(vocab_path, encoding="utf-8") as f:
        token_list = [line.rstrip() for line in f]
    if '<unk>' not in token_list:
        token_list.append('<unk>')
    return token_list
def gen_config_for_tfmodel(config_path, vocab_path, output_dir):
    """Write ``config.yaml`` with the vocabulary embedded as ``token_list``.

    Args:
        config_path: Path of the source YAML configuration.
        vocab_path: Path of the vocabulary file read by ``proc_tf_vocab``.
        output_dir: Directory that receives ``config.yaml``; created when
            missing.
    """
    token_list = proc_tf_vocab(vocab_path)
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    config['token_list'] = token_list

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_dir, exist_ok=True)

    with open(os.path.join(output_dir, "config.yaml"), "w", encoding="utf-8") as f:
        yaml_no_alias_safe_dump(config, f, indent=4, sort_keys=False)
class NoAliasSafeDumper(yaml.SafeDumper):
    """Safe YAML dumper that never emits anchors or aliases.

    Repeated objects are written out in full each time, which keeps the
    generated file readable.
    """

    def ignore_aliases(self, data):
        """Report every object as alias-exempt so no anchor is generated."""
        return True
def yaml_no_alias_safe_dump(data, stream=None, **kwargs):
    """Safe-dump ``data`` as YAML without anchors or aliases.

    Args:
        data: Object to serialise.
        stream: Optional writable stream, passed to ``yaml.dump`` unchanged.
        **kwargs: Extra keyword arguments forwarded to ``yaml.dump``.

    Returns:
        Whatever ``yaml.dump`` returns for the given ``stream``.
    """
    return yaml.dump(
        data, stream, allow_unicode=True, Dumper=NoAliasSafeDumper, **kwargs
    )
if __name__ == '__main__':
    # Command-line entry point: CONFIG_PATH VOCAB_PATH OUTPUT_DIR
    import sys

    if len(sys.argv) < 4:
        # Fail with a usage line instead of a bare IndexError.
        sys.exit(f"usage: {sys.argv[0]} CONFIG_PATH VOCAB_PATH OUTPUT_DIR")

    config_path = sys.argv[1]
    vocab_path = sys.argv[2]
    output_dir = sys.argv[3]
    gen_config_for_tfmodel(config_path, vocab_path, output_dir)

View File

@ -541,20 +541,20 @@ class UniASR(FunASRModel):
speech_lengths: (Batch, )
"""
# with autocast(False):
# # 1. Extract feats
# feats, feats_lengths = self._extract_feats(speech, speech_lengths)
# # 1. Extract feats
# feats, feats_lengths = self._extract_feats(speech, speech_lengths)
#
# # 2. Data augmentation
# if self.specaug is not None and self.training:
# feats, feats_lengths = self.specaug(feats, feats_lengths)
# # 2. Data augmentation
# if self.specaug is not None and self.training:
# feats, feats_lengths = self.specaug(feats, feats_lengths)
#
# # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
# if self.normalize is not None:
# feats, feats_lengths = self.normalize(feats, feats_lengths)
# # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
# if self.normalize is not None:
# feats, feats_lengths = self.normalize(feats, feats_lengths)
# Pre-encoder, e.g. used for raw input data
# if self.preencoder is not None:
# feats, feats_lengths = self.preencoder(feats, feats_lengths)
# feats, feats_lengths = self.preencoder(feats, feats_lengths)
encoder_out_rm, encoder_out_lens_rm = self.encoder.overlap_chunk_cls.remove_chunk(
encoder_out,
encoder_out_lens,
@ -584,9 +584,9 @@ class UniASR(FunASRModel):
# # Post-encoder, e.g. NLU
# if self.postencoder is not None:
# encoder_out, encoder_out_lens = self.postencoder(
# encoder_out, encoder_out_lens
# )
# encoder_out, encoder_out_lens = self.postencoder(
# encoder_out, encoder_out_lens
# )
assert encoder_out.size(0) == speech.size(0), (
encoder_out.size(),

View File

@ -3,15 +3,15 @@ from funasr.optimizers.fairseq_adam import FairseqAdam
from funasr.optimizers.sgd import SGD
# Registry mapping an optimizer name to its class. Each key appears exactly
# once; repeating a keyword inside dict(...) is a SyntaxError.
optim_classes = dict(
    adam=torch.optim.Adam,
    fairseq_adam=FairseqAdam,
    adamw=torch.optim.AdamW,
    sgd=SGD,
    adadelta=torch.optim.Adadelta,
    adagrad=torch.optim.Adagrad,
    adamax=torch.optim.Adamax,
    asgd=torch.optim.ASGD,
    lbfgs=torch.optim.LBFGS,
    rmsprop=torch.optim.RMSprop,
    rprop=torch.optim.Rprop,
)

View File

@ -8,16 +8,16 @@ from funasr.schedulers.tri_stage_scheduler import TriStageLR
from funasr.schedulers.warmup_lr import WarmupLR
# Registry mapping a scheduler name to its class. Key spelling (including the
# mixed-case entries) is kept as-is because the keys are looked up by name.
scheduler_classes = dict(
    ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
    lambdalr=torch.optim.lr_scheduler.LambdaLR,
    steplr=torch.optim.lr_scheduler.StepLR,
    multisteplr=torch.optim.lr_scheduler.MultiStepLR,
    exponentiallr=torch.optim.lr_scheduler.ExponentialLR,
    CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
    noamlr=NoamLR,
    warmuplr=WarmupLR,
    tri_stage=TriStageLR,
    cycliclr=torch.optim.lr_scheduler.CyclicLR,
    onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
    CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
)

View File

@ -1,100 +1,94 @@
from abc import ABC
from abc import abstractmethod
from typing import Iterable
from typing import List
from pathlib import Path
from typing import Dict
from typing import Iterable
from typing import List
from typing import Union
import json
import numpy as np
from abc import ABC
from pathlib import Path
from abc import abstractmethod
from typing import Union, Iterable, List, Dict
class AbsTokenizer(ABC):
    """Interface for converting between raw text and token sequences."""

    @abstractmethod
    def text2tokens(self, line: str) -> List[str]:
        """Split ``line`` into a list of tokens."""
        raise NotImplementedError

    @abstractmethod
    def tokens2text(self, tokens: Iterable[str]) -> str:
        """Join ``tokens`` back into a text string."""
        raise NotImplementedError
class BaseTokenizer(ABC):
    """Base tokenizer holding a vocabulary and id <-> token conversion.

    Subclasses implement ``text2tokens`` and ``tokens2text``; ``encode`` and
    ``decode`` combine them with the vocabulary lookup.
    """

    def __init__(self, token_list: Union[Path, str, Iterable[str]] = None,
                 unk_symbol: str = "<unk>",
                 **kwargs,
                 ):
        """Load the vocabulary and build the token -> id table.

        Args:
            token_list: Path (``str`` or ``Path``) ending in ``.txt`` (one
                token per line) or ``.json`` (a JSON list), or an iterable of
                tokens. When ``None`` no vocabulary attribute is set.
            unk_symbol: Token used for out-of-vocabulary lookups.
            **kwargs: Accepted and ignored.

        Raises:
            RuntimeError: If a token is duplicated or ``unk_symbol`` is not in
                the vocabulary.
        """
        if token_list is not None:
            # Compare on the string form: pathlib.Path has no ``endswith``, so
            # calling it directly fails for Path arguments.
            is_path = isinstance(token_list, (Path, str))
            path_name = str(token_list) if is_path else ""

            if is_path and path_name.endswith(".txt"):
                token_list = Path(token_list)
                self.token_list_repr = str(token_list)
                with token_list.open("r", encoding="utf-8") as f:
                    self.token_list: List[str] = [line.rstrip() for line in f]
            elif is_path and path_name.endswith(".json"):
                token_list = Path(token_list)
                self.token_list_repr = str(token_list)
                with open(token_list, 'r', encoding='utf-8') as f:
                    self.token_list: List[str] = json.load(f)
            else:
                self.token_list: List[str] = list(token_list)
                # Short preview: the first three tokens plus the vocabulary size.
                self.token_list_repr = ""
                for t in self.token_list[:3]:
                    self.token_list_repr += f"{t}, "
                self.token_list_repr += f"... (NVocab={(len(self.token_list))})"

            self.token2id: Dict[str, int] = {}
            for i, t in enumerate(self.token_list):
                if t in self.token2id:
                    raise RuntimeError(f'Symbol "{t}" is duplicated')
                self.token2id[t] = i

            self.unk_symbol = unk_symbol
            if self.unk_symbol not in self.token2id:
                raise RuntimeError(
                    f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
                )
            self.unk_id = self.token2id[self.unk_symbol]

    def encode(self, text):
        """Tokenize ``text`` and return the list of token ids."""
        tokens = self.text2tokens(text)
        text_ints = self.tokens2ids(tokens)
        return text_ints

    def decode(self, text_ints):
        """Convert token ids back into a text string."""
        token = self.ids2tokens(text_ints)
        text = self.tokens2text(token)
        return text

    def get_num_vocabulary_size(self) -> int:
        """Return the number of tokens in the vocabulary."""
        return len(self.token_list)

    def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
        """Look up the token for each id.

        Raises:
            ValueError: If ``integers`` is an ndarray that is not 1-D.
        """
        if isinstance(integers, np.ndarray) and integers.ndim != 1:
            raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
        return [self.token_list[i] for i in integers]

    def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
        """Look up the id of each token, using ``unk_id`` for unknown ones."""
        return [self.token2id.get(i, self.unk_id) for i in tokens]

    @abstractmethod
    def text2tokens(self, line: str) -> List[str]:
        """Split ``line`` into a list of tokens."""
        raise NotImplementedError

    @abstractmethod
    def tokens2text(self, tokens: Iterable[str]) -> str:
        """Join ``tokens`` back into a text string."""
        raise NotImplementedError

View File

@ -1,233 +1,235 @@
import torch
import os
from funasr.train_utils.device_funcs import to_device
import logging
import time
import torch
import logging
from tqdm import tqdm
from contextlib import nullcontext
import torch.distributed as dist
from contextlib import nullcontext
from funasr.train_utils.device_funcs import to_device
from funasr.train_utils.recursive_op import recursive_average
class Trainer:
"""
A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
and optionally resuming from a saved checkpoint.
"""
A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
and optionally resuming from a saved checkpoint.
Attributes:
max_epoch (int): Maximum number of epochs for training.
model (torch.nn.Module): The model to be trained.
optim (torch.optim.Optimizer): The optimizer to use for training.
scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
dataloader_train (torch.utils.data.DataLoader): DataLoader for the training dataset.
dataloader_val (torch.utils.data.DataLoader): DataLoader for the validation dataset.
output_dir (str): Directory where model checkpoints will be saved.
resume (str, optional): Path to a checkpoint to resume training from.
"""
def __init__(self, model,
optim,
scheduler,
dataloader_train,
dataloader_val,
local_rank,
use_ddp=False,
use_fsdp=False,
**kwargs):
"""
Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings.
Attributes:
max_epoch (int): Maximum number of epochs for training.
model (torch.nn.Module): The model to be trained.
optim (torch.optim.Optimizer): The optimizer to use for training.
scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
dataloader_train (torch.utils.data.DataLoader): DataLoader for the training dataset.
dataloader_val (torch.utils.data.DataLoader): DataLoader for the validation dataset.
output_dir (str): Directory where model checkpoints will be saved.
resume (str, optional): Path to a checkpoint to resume training from.
"""
def __init__(self, model,
optim,
scheduler,
dataloader_train,
dataloader_val,
local_rank,
use_ddp=False,
use_fsdp=False,
**kwargs):
"""
Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings.
Args:
model (torch.nn.Module): The model to be trained.
optim (torch.optim.Optimizer): The optimizer to use for training.
scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
dataloader_train (torch.utils.data.DataLoader): The DataLoader for the training dataset.
dataloader_val (torch.utils.data.DataLoader): The DataLoader for the validation dataset.
**kwargs: Additional keyword arguments:
max_epoch (int): The maximum number of epochs for training.
output_dir (str): The directory where model checkpoints will be saved. Default is './'.
resume (str, optional): The file path to a checkpoint to resume training from.
"""
self.model = model
self.optim = optim
self.scheduler = scheduler
self.dataloader_train = dataloader_train
self.dataloader_val = dataloader_val
self.output_dir = kwargs.get('output_dir', './')
self.resume = kwargs.get('resume', True)
self.start_epoch = 0
self.max_epoch = kwargs.get('max_epoch', 100)
self.local_rank = local_rank
self.use_ddp = use_ddp
self.use_fsdp = use_fsdp
self.device = next(model.parameters()).device
self.kwargs = kwargs
if self.resume:
self._resume_checkpoint(self.resume)
try:
rank = dist.get_rank()
world_size = dist.get_world_size()
except:
rank = 0
world_size = 1
logging.warning("distributed is not initialized, only single shard")
self.rank = rank
self.world_size = world_size
def _save_checkpoint(self, epoch):
"""
Saves a checkpoint containing the model's state, the optimizer's state,
and the scheduler's state at the end of the given epoch. This method is
intended to be called at the end of each epoch to save the training progress.
Args:
model (torch.nn.Module): The model to be trained.
optim (torch.optim.Optimizer): The optimizer to use for training.
scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
dataloader_train (torch.utils.data.DataLoader): The DataLoader for the training dataset.
dataloader_val (torch.utils.data.DataLoader): The DataLoader for the validation dataset.
**kwargs: Additional keyword arguments:
max_epoch (int): The maximum number of epochs for training.
output_dir (str): The directory where model checkpoints will be saved. Default is './'.
resume (str, optional): The file path to a checkpoint to resume training from.
"""
self.model = model
self.optim = optim
self.scheduler = scheduler
self.dataloader_train = dataloader_train
self.dataloader_val = dataloader_val
self.output_dir = kwargs.get('output_dir', './')
self.resume = kwargs.get('resume', True)
self.start_epoch = 0
self.max_epoch = kwargs.get('max_epoch', 100)
self.local_rank = local_rank
self.use_ddp = use_ddp
self.use_fsdp = use_fsdp
self.device = next(model.parameters()).device
self.kwargs = kwargs
if self.resume:
self._resume_checkpoint(self.resume)
try:
rank = dist.get_rank()
world_size = dist.get_world_size()
except:
rank = 0
world_size = 1
logging.warning("distributed is not initialized, only single shard")
self.rank = rank
self.world_size = world_size
def _save_checkpoint(self, epoch):
"""
Saves a checkpoint containing the model's state, the optimizer's state,
and the scheduler's state at the end of the given epoch. This method is
intended to be called at the end of each epoch to save the training progress.
Args:
epoch (int): The epoch number at which the checkpoint is being saved.
"""
state = {
'epoch': epoch,
'state_dict': self.model.state_dict(),
'optimizer': self.optim.state_dict(),
'scheduler': self.scheduler.state_dict(),
}
# Create output directory if it does not exist
os.makedirs(self.output_dir, exist_ok=True)
filename = os.path.join(self.output_dir, f'model.e{epoch}.pb')
torch.save(state, filename)
print(f'Checkpoint saved to {filename}')
def _resume_checkpoint(self, resume_path):
"""
Resumes training from a checkpoint at the given file path.
Loads the model's state, the optimizer's state, and the scheduler's state.
Args:
epoch (int): The epoch number at which the checkpoint is being saved.
"""
state = {
'epoch': epoch,
'state_dict': self.model.state_dict(),
'optimizer': self.optim.state_dict(),
'scheduler': self.scheduler.state_dict(),
}
# Create output directory if it does not exist
os.makedirs(self.output_dir, exist_ok=True)
filename = os.path.join(self.output_dir, f'model.e{epoch}.pb')
torch.save(state, filename)
print(f'Checkpoint saved to {filename}')
def _resume_checkpoint(self, resume_path):
"""
Resumes training from a checkpoint at the given file path.
Loads the model's state, the optimizer's state, and the scheduler's state.
Args:
resume_path (str): The file path to the checkpoint to resume from.
"""
if os.path.isfile(resume_path):
checkpoint = torch.load(resume_path)
self.start_epoch = checkpoint['epoch'] + 1
self.model.load_state_dict(checkpoint['state_dict'])
self.optim.load_state_dict(checkpoint['optimizer'])
self.scheduler.load_state_dict(checkpoint['scheduler'])
print(f"Checkpoint loaded successfully from '{resume_path}' at (epoch {checkpoint['epoch']})")
else:
print(f"No checkpoint found at '{resume_path}', starting from scratch")
def run(self):
"""
Starts the training process, iterating over epochs, training the model,
and saving checkpoints at the end of each epoch.
"""
for epoch in range(self.start_epoch, self.max_epoch + 1):
self._train_epoch(epoch)
# self._validate_epoch(epoch)
if self.rank == 0:
self._save_checkpoint(epoch)
self.scheduler.step()
def _train_epoch(self, epoch):
"""
Defines the training process for a single epoch with gradient accumulation.
Args:
epoch (int): The current epoch number.
"""
self.model.train()
pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch + 1}", total=len(self.dataloader_train),
dynamic_ncols=True)
# Set the number of steps for gradient accumulation
accum_grad = self.kwargs.get("accum_grad", 1)
# Initialize the gradient accumulation
self.optim.zero_grad()
speed_stats = {}
time5 = time.perf_counter()
for batch_idx, batch in enumerate(self.dataloader_train):
time1 = time.perf_counter()
speed_stats["data_load"] = f"{time1-time5:0.3f}"
# import pdb;
# pdb.set_trace()
batch = to_device(batch, self.device)
my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
with my_context():
time2 = time.perf_counter()
retval = self.model(**batch)
time3 = time.perf_counter()
speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
loss, stats, weight = retval
stats = {k: v for k, v in stats.items() if v is not None}
if self.use_ddp or self.use_fsdp:
# Apply weighted averaging for loss and stats
loss = (loss * weight.type(loss.dtype)).sum()
# if distributed, this method can also apply all_reduce()
stats, weight = recursive_average(stats, weight, distributed=True)
# Now weight is summation over all workers
loss /= weight
# Multiply world_size because DistributedDataParallel
# automatically normalizes the gradient by world_size.
loss *= self.world_size
# Scale the loss since we're not updating for every mini-batch
loss = loss / accum_grad
loss.backward()
time4 = time.perf_counter()
speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
# Perform an optimizer step only after accumulating enough gradients
if (batch_idx + 1) % accum_grad == 0 or (batch_idx + 1) == len(self.dataloader_train):
# Perform gradient clipping if it is set
if self.kwargs.get("grad_clip", None) is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
max_norm=self.kwargs.get("grad_clip", 10.0),
norm_type=self.kwargs.get("grad_clip_type", 2.0),
)
if not torch.isfinite(grad_norm):
logging.warning(
f"The grad norm is {grad_norm}. Skipping updating the model."
)
self.optim.zero_grad() # Reset gradients
continue
# Execute an optimization step (update model parameters)
self.optim.step()
self.scheduler.step()
# Clear gradients for the next accumulation stage
self.optim.zero_grad()
total_time = f"{time.perf_counter() - time5:0.3f}"
time5 = time.perf_counter()
speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
speed_stats["total_time"] = total_time
# import pdb;
# pdb.set_trace()
pbar.update(1)
if self.local_rank == 0:
description = (
f"Epoch: {epoch + 1}/{self.max_epoch}, "
f"step {batch_idx}/{len(self.dataloader_train)}, "
f"{speed_stats}, "
f"(loss: {loss.detach().cpu().item():.3f}), "
f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
)
pbar.set_description(description)
# if batch_idx == 2:
# break
pbar.close()
Args:
resume_path (str): The file path to the checkpoint to resume from.
"""
if os.path.isfile(resume_path):
checkpoint = torch.load(resume_path)
self.start_epoch = checkpoint['epoch'] + 1
self.model.load_state_dict(checkpoint['state_dict'])
self.optim.load_state_dict(checkpoint['optimizer'])
self.scheduler.load_state_dict(checkpoint['scheduler'])
print(f"Checkpoint loaded successfully from '{resume_path}' at (epoch {checkpoint['epoch']})")
else:
print(f"No checkpoint found at '{resume_path}', starting from scratch")
def run(self):
"""
Starts the training process, iterating over epochs, training the model,
and saving checkpoints at the end of each epoch.
"""
for epoch in range(self.start_epoch, self.max_epoch + 1):
self._train_epoch(epoch)
# self._validate_epoch(epoch)
if self.rank == 0:
self._save_checkpoint(epoch)
self.scheduler.step()
def _train_epoch(self, epoch):
"""
Defines the training process for a single epoch with gradient accumulation.
Args:
epoch (int): The current epoch number.
"""
self.model.train()
pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch + 1}", total=len(self.dataloader_train),
dynamic_ncols=True)
# Set the number of steps for gradient accumulation
accum_grad = self.kwargs.get("accum_grad", 1)
# Initialize the gradient accumulation
self.optim.zero_grad()
speed_stats = {}
time5 = time.perf_counter()
for batch_idx, batch in enumerate(self.dataloader_train):
time1 = time.perf_counter()
speed_stats["data_load"] = f"{time1-time5:0.3f}"
# import pdb;
# pdb.set_trace()
batch = to_device(batch, self.device)
my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
with my_context():
time2 = time.perf_counter()
retval = self.model(**batch)
time3 = time.perf_counter()
speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
loss, stats, weight = retval
stats = {k: v for k, v in stats.items() if v is not None}
if self.use_ddp or self.use_fsdp:
# Apply weighted averaging for loss and stats
loss = (loss * weight.type(loss.dtype)).sum()
# if distributed, this method can also apply all_reduce()
stats, weight = recursive_average(stats, weight, distributed=True)
# Now weight is summation over all workers
loss /= weight
# Multiply world_size because DistributedDataParallel
# automatically normalizes the gradient by world_size.
loss *= self.world_size
# Scale the loss since we're not updating for every mini-batch
loss = loss / accum_grad
loss.backward()
time4 = time.perf_counter()
speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
# Perform an optimizer step only after accumulating enough gradients
if (batch_idx + 1) % accum_grad == 0 or (batch_idx + 1) == len(self.dataloader_train):
# Perform gradient clipping if it is set
if self.kwargs.get("grad_clip", None) is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
max_norm=self.kwargs.get("grad_clip", 10.0),
norm_type=self.kwargs.get("grad_clip_type", 2.0),
)
if not torch.isfinite(grad_norm):
logging.warning(
f"The grad norm is {grad_norm}. Skipping updating the model."
)
self.optim.zero_grad() # Reset gradients
continue
# Execute an optimization step (update model parameters)
self.optim.step()
self.scheduler.step()
# Clear gradients for the next accumulation stage
self.optim.zero_grad()
total_time = f"{time.perf_counter() - time5:0.3f}"
time5 = time.perf_counter()
speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
speed_stats["total_time"] = total_time
# import pdb;
# pdb.set_trace()
pbar.update(1)
if self.local_rank == 0:
description = (
f"Epoch: {epoch + 1}/{self.max_epoch}, "
f"step {batch_idx}/{len(self.dataloader_train)}, "
f"{speed_stats}, "
f"(loss: {loss.detach().cpu().item():.3f}), "
f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
)
pbar.set_description(description)
# if batch_idx == 2:
# break
pbar.close()
def _validate_epoch(self, epoch):
    """
    Run the validation pass for a single epoch.

    This is still a placeholder: the model is switched to evaluation mode
    and the validation dataloader is iterated with autograd disabled, but
    no loss or metric is computed yet.

    Args:
        epoch (int): The current epoch number (unused by the placeholder).
    """
    self.model.eval()
    with torch.no_grad():
        for data, target in self.dataloader_val:
            # Implement the model validation steps here
            pass

View File

@ -10,100 +10,100 @@ import time
import logging
from torch.nn.utils.rnn import pad_sequence
try:
    from funasr.download.file import download_from_url
except ImportError:
    # URL download support is optional: only a failed import is tolerated here
    # (a bare ``except`` would also swallow KeyboardInterrupt / SystemExit).
    # NOTE(review): when this import fails, ``download_from_url`` stays
    # undefined and loading from an http(s) URL raises NameError later.
    print("urllib is not installed, if you infer from url, please install it first.")
def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs: int = 16000, data_type="sound", tokenizer=None, **kwargs):
    """Load one input (or a batch of inputs) and return it in model-ready form.

    Handling depends on the type of ``data_or_path_or_list``:

    * list / tuple: every element is loaded recursively. When ``data_type``
      is itself a list / tuple, each element is treated as one multi-field
      sample and the result is regrouped per field
      (``result[j]`` holds field ``j`` of every sample).
    * str starting with ``"http"``: downloaded with ``download_from_url``
      and then handled as a local file.
    * str naming an existing file: read with ``torchaudio.load`` (first
      channel only) when ``data_type`` is ``None`` or ``"sound"``; for
      ``"text"`` with a tokenizer the string itself is passed to
      ``tokenizer.encode``; ``"image"`` / ``"video"`` are not implemented and
      leave the path unchanged. ``kwargs["cache"]["is_final"]`` is set to
      ``True`` when a ``cache`` keyword is present.
    * any other str with ``data_type == "text"`` and a tokenizer: encoded
      with ``tokenizer.encode``.
    * ``np.ndarray``: converted to a squeezed ``torch.Tensor``.
    * anything else: returned unchanged.

    Tensor results are resampled from ``audio_fs`` to ``fs`` when the two
    rates differ and ``data_type`` is not ``"text"``.

    Args:
        data_or_path_or_list: Raw data, path, URL, or a list / tuple of those.
        fs: Target sampling rate.
        audio_fs: Sampling rate of the incoming audio; replaced by the rate
            reported by ``torchaudio.load`` when a file is read.
        data_type: ``"sound"``, ``"text"``, ``"image"``, ``"video"``, ``None``
            or a list / tuple of those for multi-field samples.
        tokenizer: Object exposing ``encode(str)``; used for text inputs.
        **kwargs: Forwarded to recursive calls; may carry a ``cache`` dict.

    Returns:
        The loaded item, a list of loaded items, or a list of per-field lists.
    """
    if isinstance(data_or_path_or_list, (list, tuple)):
        if data_type is not None and isinstance(data_type, (list, tuple)):
            # Multi-field samples: every sample is paired element-wise with
            # ``data_type`` and the outputs are collected per field.
            data_or_path_or_list_ret = [[] for _ in data_type]
            for sample in data_or_path_or_list:
                for j, (data_type_j, item_j) in enumerate(zip(data_type, sample)):
                    loaded_j = load_audio_text_image_video(item_j, fs=fs, audio_fs=audio_fs, data_type=data_type_j, tokenizer=tokenizer, **kwargs)
                    data_or_path_or_list_ret[j].append(loaded_j)
            return data_or_path_or_list_ret
        # NOTE(review): ``tokenizer`` is not forwarded on this path, so plain
        # lists of text are returned un-tokenized - confirm this is intended.
        return [load_audio_text_image_video(audio, fs=fs, audio_fs=audio_fs, data_type=data_type, **kwargs) for audio in data_or_path_or_list]

    if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith('http'):  # download url to local file
        data_or_path_or_list = download_from_url(data_or_path_or_list)

    if isinstance(data_or_path_or_list, str) and os.path.exists(data_or_path_or_list):  # local file
        if data_type is None or data_type == "sound":
            data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list)
            data_or_path_or_list = data_or_path_or_list[0, :]  # keep the first channel only
        elif data_type == "text" and tokenizer is not None:
            data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
        elif data_type == "image":  # not implemented yet
            pass
        elif data_type == "video":  # not implemented yet
            pass

        # if data_in is a file or url, set is_final=True
        if "cache" in kwargs:
            kwargs["cache"]["is_final"] = True
    elif isinstance(data_or_path_or_list, str) and data_type == "text" and tokenizer is not None:
        data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
    elif isinstance(data_or_path_or_list, np.ndarray):  # audio sample point
        data_or_path_or_list = torch.from_numpy(data_or_path_or_list).squeeze()  # [n_samples,]
    # Any other input type is unsupported and is returned as raw data.

    # Only tensors can be resampled. Unsupported inputs (for example a path
    # that could not be resolved) are returned as-is rather than failing on
    # the ``[None, :]`` indexing below.
    if audio_fs != fs and data_type != "text" and isinstance(data_or_path_or_list, torch.Tensor):
        resampler = torchaudio.transforms.Resample(audio_fs, fs)
        data_or_path_or_list = resampler(data_or_path_or_list[None, :])[0, :]
    return data_or_path_or_list
def load_bytes(input):
    """Decode raw 16-bit PCM bytes into a float32 waveform.

    The buffer is interpreted as native-endian ``int16`` samples and scaled
    by ``1 / 32768``, so the result lies in ``[-1.0, 1.0)``.

    Args:
        input: Bytes-like object holding int16 samples (the parameter name
            shadows the builtin but is kept for caller compatibility).

    Returns:
        np.ndarray: 1-D ``float32`` array with one value per sample.

    Raises:
        ValueError: Propagated from ``np.frombuffer`` when the buffer size
            is not a multiple of two bytes.
    """
    samples = np.frombuffer(input, dtype=np.int16)
    # The dtype is fixed to int16 above, so the former integer/float dtype
    # checks could never fail; for a signed type the offset term is also 0.
    abs_max = 2 ** (np.iinfo(samples.dtype).bits - 1)  # 32768 for int16
    return samples.astype(np.float32) / abs_max
def extract_fbank(data, data_len = None, data_type: str="sound", frontend=None, **kwargs):
    """Batch the input and run it through ``frontend``.

    Args:
        data: A ``np.ndarray`` / ``torch.Tensor`` (a 1-D input gets a leading
            batch axis), or a list / tuple of 1-D arrays / tensors, which are
            zero-padded into one batch with ``pad_sequence``.
        data_len: Optional lengths for array / tensor input. Defaults to
            ``[data.shape[1]]``. Ignored for list / tuple input, where the
            lengths are taken from the individual items.
        data_type: Unused here; kept for interface compatibility.
        frontend: Callable invoked as ``frontend(data, data_len, **kwargs)``
            that returns ``(features, feature_lengths)``.
        **kwargs: Forwarded to ``frontend``.

    Returns:
        tuple: ``(features, lengths)`` cast to ``float32`` and ``int32``.
        When the frontend returns lengths as a list / tuple they are wrapped
        as ``torch.tensor([lengths])``, which adds a leading axis.

    Raises:
        TypeError: If ``frontend`` is ``None``.
    """
    if frontend is None:
        # Same exception type as calling ``None`` would raise, but with an
        # actionable message.
        raise TypeError("extract_fbank() requires a callable 'frontend'")

    if isinstance(data, np.ndarray):
        data = torch.from_numpy(data)

    if isinstance(data, torch.Tensor):
        if len(data.shape) < 2:
            data = data[None, :]  # data: [batch, N]
        if data_len is None:
            data_len = [data.shape[1]]
    elif isinstance(data, (list, tuple)):
        data_list, data_len = [], []
        for data_i in data:
            if isinstance(data_i, np.ndarray):
                data_i = torch.from_numpy(data_i)
            data_list.append(data_i)
            data_len.append(data_i.shape[0])
        data = pad_sequence(data_list, batch_first=True)  # data: [batch, N]

    data, data_len = frontend(data, data_len, **kwargs)

    if isinstance(data_len, (list, tuple)):
        data_len = torch.tensor([data_len])
    return data.to(torch.float32), data_len.to(torch.int32)

View File

@ -1,31 +1,31 @@
import torch
from torch.nn.utils.rnn import pad_sequence
def slice_padding_fbank(speech, speech_lengths, vad_segments):
    """Cut VAD segments out of the first batch item and pad them into a batch.

    For every ``segment`` the bounds ``segment[0][0]`` / ``segment[0][1]``
    are multiplied by 16 to obtain indices along axis 1 of ``speech``; the
    end index is clipped to ``speech_lengths[0]``.
    NOTE(review): the factor 16 presumably converts milliseconds to samples
    at 16 kHz - confirm against the VAD output format.

    Args:
        speech: Tensor indexed as ``speech[0, start:end]``.
        speech_lengths: Sequence whose first entry is the valid length of
            ``speech[0]``.
        vad_segments: Iterable of segments; ``segment[0]`` holds
            ``(begin, end)``.

    Returns:
        tuple: ``(feats_pad, speech_lengths_pad)`` - the slices zero-padded
        to a common length (batch first) and their lengths as an ``int32``
        tensor.
    """
    max_len = int(speech_lengths[0])
    speech_list = []
    speech_lengths_list = []
    for segment in vad_segments:
        bed_idx = int(segment[0][0] * 16)
        end_idx = min(int(segment[0][1] * 16), max_len)
        speech_list.append(speech[0, bed_idx:end_idx])
        speech_lengths_list.append(end_idx - bed_idx)
    feats_pad = pad_sequence(speech_list, batch_first=True, padding_value=0.0)
    # Build the lengths as integers directly: going through the float32
    # ``torch.Tensor(...)`` constructor rounds values above 2**24.
    speech_lengths_pad = torch.tensor(speech_lengths_list, dtype=torch.int32)
    return feats_pad, speech_lengths_pad
def slice_padding_audio_samples(speech, speech_lengths, vad_segments):
    """Cut VAD segments out of a single sequence of audio samples.

    For every ``segment`` the bounds ``segment[0][0]`` / ``segment[0][1]``
    are multiplied by 16 to obtain sample indices; the end index is clipped
    to ``speech_lengths``. No padding is applied.
    NOTE(review): the factor 16 presumably converts milliseconds to samples
    at 16 kHz - confirm against the VAD output format.

    Args:
        speech: Sliceable sequence of samples (``speech[start:end]``).
        speech_lengths: Number of valid samples in ``speech``.
        vad_segments: Iterable of segments; ``segment[0]`` holds
            ``(begin, end)``.

    Returns:
        tuple: ``(speech_list, speech_lengths_list)`` - the slices and their
        lengths (``end - begin`` after clipping), in input order.
    """
    speech_list = []
    speech_lengths_list = []
    for segment in vad_segments:
        bed_idx = int(segment[0][0] * 16)
        end_idx = min(int(segment[0][1] * 16), speech_lengths)
        speech_list.append(speech[bed_idx:end_idx])
        speech_lengths_list.append(end_idx - bed_idx)
    return speech_list, speech_lengths_list

View File

@ -17,8 +17,8 @@ args = parser.parse_args()
# Import only the backend that was requested, so the onnx path does not
# depend on the libtorch runtime being importable.
if args.backend == "onnx":
    from funasr.runtime.python.onnxruntime.funasr_onnx import Paraformer
else:
    from funasr.runtime.python.libtorch.funasr_torch import Paraformer

model = Paraformer(args.model_dir, batch_size=1, quantize=args.quantize, intra_op_num_threads=args.intra_op_num_threads)

wav_file_f = open(args.wav_file, 'r')
@ -26,23 +26,23 @@ wav_files = wav_file_f.readlines()
output_dir = args.output_dir
# exist_ok avoids the check-then-create race of a separate exists() test.
os.makedirs(output_dir, exist_ok=True)

# Written "\n" characters are translated to this line ending.
newline = '\r\n' if os.name == 'nt' else '\n'  # Windows / Linux, Mac

# The context manager closes both files even if decoding raises midway.
with open(os.path.join(output_dir, "text"), "w", newline=newline) as text_f, \
        open(os.path.join(output_dir, "token"), "w", newline=newline) as token_f:
    for wav_path_i in wav_files:
        if not wav_path_i.strip():
            # Skip blank lines (e.g. a trailing newline in the list file),
            # which would otherwise fail the two-field unpacking below.
            continue
        wav_name, wav_path = wav_path_i.strip().split()
        result = model(wav_path)
        text_i = "{} {}\n".format(wav_name, result[0]['preds'][0])
        token_i = "{} {}\n".format(wav_name, result[0]['preds'][1])
        # Flush after every utterance so partial results survive a crash.
        text_f.write(text_i)
        text_f.flush()
        token_f.write(token_i)
        token_f.flush()

View File

@ -16,8 +16,8 @@ args = parser.parse_args()
# Import only the backend that was requested, so the onnx path does not
# depend on the libtorch runtime being importable.
if args.backend == "onnx":
    from funasr.runtime.python.onnxruntime.funasr_onnx import Paraformer
else:
    from funasr.runtime.python.libtorch.funasr_torch import Paraformer

model = Paraformer(args.model_dir, batch_size=1, quantize=args.quantize, intra_op_num_threads=args.intra_op_num_threads)

wav_file_f = open(args.wav_file, 'r')
@ -28,28 +28,28 @@ total = 0.0
num = 30
# Each list line is "<name><TAB or space><path>"; the path is field 1.
wav_path = wav_files[0].split("\t" if "\t" in wav_files[0] else " ")[1].strip()
# Decode the first utterance `num` times and report per-run timing.
# NOTE(review): "num" prints the character length of the path string, and
# 5.53 is presumably the duration in seconds of the reference wav - confirm.
for i in range(num):
    # perf_counter is monotonic and high-resolution, unlike wall-clock time().
    beg_time = time.perf_counter()
    result = model(wav_path)
    end_time = time.perf_counter()
    duration = end_time-beg_time
    total += duration
    print(result)
    print("num: {}, time, {}, avg: {}, rtf: {}".format(len(wav_path), duration, total/(i+1), (total/(i+1))/5.53))

# infer time
beg_time = time.perf_counter()
for wav_path_i in wav_files:
    wav_path = wav_path_i.split("\t" if "\t" in wav_path_i else " ")[1].strip()
    result = model(wav_path)
end_time = time.perf_counter()
duration = (end_time-beg_time)*1000
print("total_time_comput_ms: {}".format(int(duration)))

# Total audio duration in ms: samples at 16 kHz divided by 16 samples/ms.
duration_time = 0.0
for wav_path_i in wav_files:
    wav_path = wav_path_i.split("\t" if "\t" in wav_path_i else " ")[1].strip()
    waveform, _ = librosa.load(wav_path, sr=16000)
    duration_time += len(waveform)/16.0
print("total_time_wav_ms: {}".format(int(duration_time)))

print("total_rtf: {:.5}".format(duration/duration_time))

View File

@ -17,8 +17,8 @@ args = parser.parse_args()
# Import only the backend that was requested, so the onnx path does not
# depend on the libtorch runtime being importable.
if args.backend == "onnx":
    from funasr.runtime.python.onnxruntime.funasr_onnx import Paraformer
else:
    from funasr.runtime.python.libtorch.funasr_torch import Paraformer

model = Paraformer(args.model_dir, batch_size=args.batch_size, quantize=args.quantize, intra_op_num_threads=args.intra_op_num_threads)

wav_file_f = open(args.wav_file, 'r')
@ -29,20 +29,20 @@ total = 0.0
num = 30
# Each list line is "<name><TAB or space><path>"; the path is field 1.
wav_path = wav_files[0].split("\t" if "\t" in wav_files[0] else " ")[1].strip()
# Decode the first utterance `num` times and report per-run timing.
# NOTE(review): "num" prints the character length of the path string, and
# 5.53 is presumably the duration in seconds of the reference wav - confirm.
for i in range(num):
    # perf_counter is monotonic and high-resolution, unlike wall-clock time().
    beg_time = time.perf_counter()
    result = model(wav_path)
    end_time = time.perf_counter()
    duration = end_time-beg_time
    total += duration
    print(result)
    print("num: {}, time, {}, avg: {}, rtf: {}".format(len(wav_path), duration, total/(i+1), (total/(i+1))/5.53))

# infer time
# Building the path list stays inside the timed region, as before; all
# utterances are then decoded in a single batched call.
beg_time = time.perf_counter()
wav_path = [
    wav_path_i.split("\t" if "\t" in wav_path_i else " ")[1].strip()
    for wav_path_i in wav_files
]
result = model(wav_path)
end_time = time.perf_counter()
duration = (end_time-beg_time)*1000
@ -50,9 +50,9 @@ print("total_time_comput_ms: {}".format(int(duration)))
# Total audio duration in ms: samples at 16 kHz divided by 16 samples/ms.
duration_time = 0.0
for wav_path_i in wav_files:
    # Each list line is "<name><TAB or space><path>"; the path is field 1.
    wav_path = wav_path_i.split("\t" if "\t" in wav_path_i else " ")[1].strip()
    waveform, _ = librosa.load(wav_path, sr=16000)
    duration_time += len(waveform)/16.0
print("total_time_wav_ms: {}".format(int(duration_time)))

print("total_rtf: {:.5}".format(duration/duration_time))