diff --git a/egs/alimeeting/modular_sa_asr/local/make_textgrid_rttm.py b/egs/alimeeting/modular_sa_asr/local/make_textgrid_rttm.py index f83c572d8..3b6373c78 100755 --- a/egs/alimeeting/modular_sa_asr/local/make_textgrid_rttm.py +++ b/egs/alimeeting/modular_sa_asr/local/make_textgrid_rttm.py @@ -1,7 +1,10 @@ import argparse import tqdm import codecs -import textgrid +try: + import textgrid +except ImportError: + raise ImportError("Please install textgrid first: pip install textgrid") import pdb class Segment(object): diff --git a/egs/alimeeting/modular_sa_asr/local/meeting_speaker_number_process.py b/egs/alimeeting/modular_sa_asr/local/meeting_speaker_number_process.py index 1b09d0af7..8dc98903c 100755 --- a/egs/alimeeting/modular_sa_asr/local/meeting_speaker_number_process.py +++ b/egs/alimeeting/modular_sa_asr/local/meeting_speaker_number_process.py @@ -6,7 +6,10 @@ import argparse import codecs from distutils.util import strtobool from pathlib import Path -import textgrid +try: + import textgrid +except ImportError: + raise ImportError("Please install textgrid first: pip install textgrid") import pdb class Segment(object): diff --git a/egs/alimeeting/sa_asr/local/alimeeting_process_overlap_force.py b/egs/alimeeting/sa_asr/local/alimeeting_process_overlap_force.py index 8ece75706..769003d30 100755 --- a/egs/alimeeting/sa_asr/local/alimeeting_process_overlap_force.py +++ b/egs/alimeeting/sa_asr/local/alimeeting_process_overlap_force.py @@ -6,7 +6,10 @@ import argparse import codecs from distutils.util import strtobool from pathlib import Path -import textgrid +try: + import textgrid +except ImportError: + raise ImportError("Please install textgrid first: pip install textgrid") import pdb class Segment(object): diff --git a/egs/alimeeting/sa_asr/local/alimeeting_process_textgrid.py b/egs/alimeeting/sa_asr/local/alimeeting_process_textgrid.py index 81c19659a..b6d01572e 100755 --- a/egs/alimeeting/sa_asr/local/alimeeting_process_textgrid.py +++ b/egs/alimeeting/sa_asr/local/alimeeting_process_textgrid.py @@ -6,7 +6,10 @@ import argparse import codecs from distutils.util import strtobool from pathlib import Path -import textgrid +try: + import textgrid +except ImportError: + raise ImportError("Please install textgrid first: pip install textgrid") import pdb class Segment(object): diff --git a/egs/alimeeting/sa_asr/local/process_sot_fifo_textchar2spk.py b/egs/alimeeting/sa_asr/local/process_sot_fifo_textchar2spk.py index 488344fb0..c26ba32fe 100755 --- a/egs/alimeeting/sa_asr/local/process_sot_fifo_textchar2spk.py +++ b/egs/alimeeting/sa_asr/local/process_sot_fifo_textchar2spk.py @@ -6,7 +6,10 @@ import argparse import codecs from distutils.util import strtobool from pathlib import Path -import textgrid +try: + import textgrid +except ImportError: + raise ImportError("Please install textgrid first: pip install textgrid") import pdb def get_args(): diff --git a/egs/alimeeting/sa_asr/local/process_textgrid_to_single_speaker_wav.py b/egs/alimeeting/sa_asr/local/process_textgrid_to_single_speaker_wav.py index fdf246090..b72ddc9ef 100755 --- a/egs/alimeeting/sa_asr/local/process_textgrid_to_single_speaker_wav.py +++ b/egs/alimeeting/sa_asr/local/process_textgrid_to_single_speaker_wav.py @@ -6,7 +6,12 @@ import argparse import codecs from distutils.util import strtobool from pathlib import Path -import textgrid + +try: + import textgrid +except ImportError: + raise ImportError("Please install textgrid first: pip install textgrid") + import pdb import numpy as np import sys diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py index 7015eb8e2..a1cede19a 100644 --- a/funasr/bin/asr_infer.py +++ b/funasr/bin/asr_infer.py @@
-34,8 +34,8 @@ from funasr.modules.beam_search.beam_search_transducer import Hypothesis as Hypo from funasr.modules.scorers.ctc import CTCPrefixScorer from funasr.modules.scorers.length_bonus import LengthBonus from funasr.build_utils.build_asr_model import frontend_choices -from funasr.text.build_tokenizer import build_tokenizer -from funasr.text.token_id_converter import TokenIDConverter +from funasr.tokenizer.build_tokenizer import build_tokenizer +from funasr.tokenizer.token_id_converter import TokenIDConverter from funasr.torch_utils.device_funcs import to_device from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard @@ -44,9 +44,9 @@ class Speech2Text: """Speech2Text class Examples: - >>> import soundfile + >>> import librosa >>> speech2text = Speech2Text("asr_config.yml", "asr.pb") - >>> audio, rate = soundfile.read("speech.wav") + >>> audio, rate = librosa.load("speech.wav") >>> speech2text(audio) [(text, token, token_int, hypothesis object), ...] @@ -251,9 +251,9 @@ class Speech2TextParaformer: """Speech2Text class Examples: - >>> import soundfile + >>> import librosa >>> speech2text = Speech2TextParaformer("asr_config.yml", "asr.pb") - >>> audio, rate = soundfile.read("speech.wav") + >>> audio, rate = librosa.load("speech.wav") >>> speech2text(audio) [(text, token, token_int, hypothesis object), ...] @@ -625,9 +625,9 @@ class Speech2TextParaformerOnline: """Speech2Text class Examples: - >>> import soundfile + >>> import librosa >>> speech2text = Speech2TextParaformerOnline("asr_config.yml", "asr.pth") - >>> audio, rate = soundfile.read("speech.wav") + >>> audio, rate = librosa.load("speech.wav") >>> speech2text(audio) [(text, token, token_int, hypothesis object), ...] @@ -876,9 +876,9 @@ class Speech2TextUniASR: """Speech2Text class Examples: - >>> import soundfile + >>> import librosa >>> speech2text = Speech2TextUniASR("asr_config.yml", "asr.pb") - >>> audio, rate = soundfile.read("speech.wav") + >>> audio, rate = librosa.load("speech.wav") >>> speech2text(audio) [(text, token, token_int, hypothesis object), ...] @@ -1106,9 +1106,9 @@ class Speech2TextMFCCA: """Speech2Text class Examples: - >>> import soundfile + >>> import librosa >>> speech2text = Speech2TextMFCCA("asr_config.yml", "asr.pb") - >>> audio, rate = soundfile.read("speech.wav") + >>> audio, rate = librosa.load("speech.wav") >>> speech2text(audio) [(text, token, token_int, hypothesis object), ...] @@ -1637,9 +1637,9 @@ class Speech2TextSAASR: """Speech2Text class Examples: - >>> import soundfile + >>> import librosa >>> speech2text = Speech2TextSAASR("asr_config.yml", "asr.pb") - >>> audio, rate = soundfile.read("speech.wav") + >>> audio, rate = librosa.load("speech.wav") >>> speech2text(audio) [(text, token, token_int, hypothesis object), ...] @@ -1885,9 +1885,9 @@ class Speech2TextWhisper: """Speech2Text class Examples: - >>> import soundfile + >>> import librosa >>> speech2text = Speech2Text("asr_config.yml", "asr.pb") - >>> audio, rate = soundfile.read("speech.wav") + >>> audio, rate = librosa.load("speech.wav") >>> speech2text(audio) [(text, token, token_int, hypothesis object), ...] 
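# Note on the docstring examples above: soundfile.read returns audio at the file's
# native sampling rate, while librosa.load resamples to 22050 Hz by default and
# downmixes to mono. To keep the old behaviour, pass sr=None (native rate) or the
# rate the frontend expects. A minimal sketch, assuming a 16 kHz model:
import librosa

audio, rate = librosa.load("speech.wav", sr=None)    # native rate, like soundfile.read
audio_16k, _ = librosa.load("speech.wav", sr=16000)  # resampled to the assumed 16 kHz frontend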
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py index e1a32c57c..7dd27fc71 100644 --- a/funasr/bin/asr_inference_launch.py +++ b/funasr/bin/asr_inference_launch.py @@ -20,7 +20,8 @@ from typing import Union import numpy as np import torch import torchaudio -import soundfile +# import librosa +import librosa import yaml from funasr.bin.asr_infer import Speech2Text @@ -1281,7 +1282,8 @@ def inference_paraformer_online( try: raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0] except: - raw_inputs = soundfile.read(data_path_and_name_and_type[0], dtype='float32')[0] + # raw_inputs = librosa.load(data_path_and_name_and_type[0], dtype='float32')[0] + raw_inputs, sr = librosa.load(data_path_and_name_and_type[0], dtype='float32') if raw_inputs.ndim == 2: raw_inputs = raw_inputs[:, 0] raw_inputs = torch.tensor(raw_inputs) diff --git a/funasr/export/models/language_models/__init__.py b/funasr/bin/asr_trainer.py similarity index 100% rename from funasr/export/models/language_models/__init__.py rename to funasr/bin/asr_trainer.py diff --git a/funasr/bin/build_trainer.py b/funasr/bin/build_trainer.py index bda83ec87..c03bdf39e 100644 --- a/funasr/bin/build_trainer.py +++ b/funasr/bin/build_trainer.py @@ -18,7 +18,7 @@ from funasr.build_utils.build_optimizer import build_optimizer from funasr.build_utils.build_scheduler import build_scheduler from funasr.build_utils.build_trainer import build_trainer as build_trainer_modelscope from funasr.modules.lora.utils import mark_only_lora_as_trainable -from funasr.text.phoneme_tokenizer import g2p_choices +from funasr.tokenizer.phoneme_tokenizer import g2p_choices from funasr.torch_utils.load_pretrained_model import load_pretrained_model from funasr.torch_utils.model_summary import model_summary from funasr.torch_utils.pytorch_version import pytorch_cudnn_version diff --git a/funasr/bin/diar_infer.py b/funasr/bin/diar_infer.py index 6fc1da162..bb40f5e41 100755 --- a/funasr/bin/diar_infer.py +++ b/funasr/bin/diar_infer.py @@ -27,11 +27,11 @@ class Speech2DiarizationEEND: """Speech2Diarlization class Examples: - >>> import soundfile + >>> import librosa >>> import numpy as np >>> speech2diar = Speech2DiarizationEEND("diar_sond_config.yml", "diar_sond.pb") >>> profile = np.load("profiles.npy") - >>> audio, rate = soundfile.read("speech.wav") + >>> audio, rate = librosa.load("speech.wav") >>> speech2diar(audio, profile) {"spk1": [(int, int), ...], ...} @@ -109,11 +109,11 @@ class Speech2DiarizationSOND: """Speech2Xvector class Examples: - >>> import soundfile + >>> import librosa >>> import numpy as np >>> speech2diar = Speech2DiarizationSOND("diar_sond_config.yml", "diar_sond.pb") >>> profile = np.load("profiles.npy") - >>> audio, rate = soundfile.read("speech.wav") + >>> audio, rate = librosa.load("speech.wav") >>> speech2diar(audio, profile) {"spk1": [(int, int), ...], ...} diff --git a/funasr/bin/diar_inference_launch.py b/funasr/bin/diar_inference_launch.py index b655df544..f5a11b152 100755 --- a/funasr/bin/diar_inference_launch.py +++ b/funasr/bin/diar_inference_launch.py @@ -15,7 +15,8 @@ from typing import Tuple from typing import Union import numpy as np -import soundfile +# import librosa +import librosa import torch from scipy.signal import medfilt @@ -144,7 +145,9 @@ def inference_sond( # read waveform file example = [load_bytes(x) if isinstance(x, bytes) else x for x in example] - example = [soundfile.read(x)[0] if isinstance(x, str) else x + # example = [librosa.load(x)[0] if isinstance(x, str) else x 
+ # for x in example] + example = [librosa.load(x, dtype='float32')[0] if isinstance(x, str) else x for x in example] # convert torch tensor to numpy array example = [x.numpy() if isinstance(example[0], torch.Tensor) else x diff --git a/funasr/bin/ss_infer.py b/funasr/bin/ss_infer.py index 483967b37..a3eca115d 100644 --- a/funasr/bin/ss_infer.py +++ b/funasr/bin/ss_infer.py @@ -20,9 +20,9 @@ class SpeechSeparator: """SpeechSeparator class Examples: - >>> import soundfile + >>> import librosa >>> speech_separator = MossFormer("ss_config.yml", "ss.pt") - >>> audio, rate = soundfile.read("speech.wav") + >>> audio, rate = librosa.load("speech.wav") >>> separated_wavs = speech_separator(audio) """ diff --git a/funasr/bin/ss_inference_launch.py b/funasr/bin/ss_inference_launch.py index 64503a0a5..0c0241913 100644 --- a/funasr/bin/ss_inference_launch.py +++ b/funasr/bin/ss_inference_launch.py @@ -13,7 +13,7 @@ from typing import Union import numpy as np import torch -import soundfile as sf +import librosa from funasr.build_utils.build_streaming_iterator import build_streaming_iterator from funasr.torch_utils.set_all_random_seed import set_all_random_seed from funasr.utils import config_argparse @@ -104,7 +104,12 @@ def inference_ss( ss_results = speech_separator(**batch) for spk in range(num_spks): - sf.write(os.path.join(output_path, keys[0] + '_s' + str(spk+1)+'.wav'), ss_results[spk], sample_rate) + # sf.write(os.path.join(output_path, keys[0] + '_s' + str(spk+1)+'.wav'), ss_results[spk], sample_rate) + try: + librosa.output.write_wav(os.path.join(output_path, keys[0] + '_s' + str(spk+1)+'.wav'), ss_results[spk], sample_rate) + except: + print("To write wav by librosa, you should install librosa<=0.8.0") + raise torch.cuda.empty_cache() return ss_results diff --git a/funasr/bin/sv_infer.py b/funasr/bin/sv_infer.py index 346440af0..19cfc2e66 100755 --- a/funasr/bin/sv_infer.py +++ b/funasr/bin/sv_infer.py @@ -22,9 +22,9 @@ class Speech2Xvector: """Speech2Xvector class Examples: - >>> import soundfile + >>> import librosa >>> speech2xvector = Speech2Xvector("sv_config.yml", "sv.pb") - >>> audio, rate = soundfile.read("speech.wav") + >>> audio, rate = librosa.load("speech.wav") >>> speech2xvector(audio) [(text, token, token_int, hypothesis object), ...] 
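# Note on the fallback above: librosa.output.write_wav was deprecated in librosa 0.7.0 and
# removed in 0.8.0, so the hint should ask for librosa<0.8.0. A version-independent
# alternative is scipy, which this patch already imports elsewhere; a minimal sketch with a
# hypothetical helper (note the argument order differs from soundfile.write):
import numpy as np
from scipy.io import wavfile

def write_wav(path: str, data: np.ndarray, sample_rate: int) -> None:
    # scipy expects (path, rate, data); float32 arrays are written as IEEE-float WAV
    wavfile.write(path, sample_rate, data.astype(np.float32))

# usage mirroring the loop above (names taken from the hunk):
# write_wav(os.path.join(output_path, keys[0] + '_s' + str(spk + 1) + '.wav'), ss_results[spk], sample_rate)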
diff --git a/funasr/bin/tokenize_text.py b/funasr/bin/tokenize_text.py index 6ec83a89e..674c1b977 100755 --- a/funasr/bin/tokenize_text.py +++ b/funasr/bin/tokenize_text.py @@ -9,9 +9,9 @@ from typing import Optional from funasr.utils.cli_utils import get_commandline_args -from funasr.text.build_tokenizer import build_tokenizer -from funasr.text.cleaner import TextCleaner -from funasr.text.phoneme_tokenizer import g2p_choices +from funasr.tokenizer.build_tokenizer import build_tokenizer +from funasr.tokenizer.cleaner import TextCleaner +from funasr.tokenizer.phoneme_tokenizer import g2p_choices from funasr.utils.types import str2bool from funasr.utils.types import str_or_none diff --git a/funasr/bin/tp_infer.py b/funasr/bin/tp_infer.py index ede579c7b..cfe534f6c 100644 --- a/funasr/bin/tp_infer.py +++ b/funasr/bin/tp_infer.py @@ -11,7 +11,7 @@ import numpy as np import torch from funasr.build_utils.build_model_from_file import build_model_from_file from funasr.models.frontend.wav_frontend import WavFrontend -from funasr.text.token_id_converter import TokenIDConverter +from funasr.tokenizer.token_id_converter import TokenIDConverter from funasr.torch_utils.device_funcs import to_device diff --git a/funasr/bin/train.py b/funasr/bin/train.py index f5d10c4ac..6aebf8a56 100755 --- a/funasr/bin/train.py +++ b/funasr/bin/train.py @@ -17,7 +17,7 @@ from funasr.build_utils.build_model import build_model from funasr.build_utils.build_optimizer import build_optimizer from funasr.build_utils.build_scheduler import build_scheduler from funasr.build_utils.build_trainer import build_trainer -from funasr.text.phoneme_tokenizer import g2p_choices +from funasr.tokenizer.phoneme_tokenizer import g2p_choices from funasr.torch_utils.load_pretrained_model import load_pretrained_model from funasr.torch_utils.model_summary import model_summary from funasr.torch_utils.pytorch_version import pytorch_cudnn_version diff --git a/funasr/bin/vad_infer.py b/funasr/bin/vad_infer.py index 73e1f3f3b..57638739c 100644 --- a/funasr/bin/vad_infer.py +++ b/funasr/bin/vad_infer.py @@ -23,9 +23,9 @@ class Speech2VadSegment: """Speech2VadSegment class Examples: - >>> import soundfile + >>> import librosa >>> speech2segment = Speech2VadSegment("vad_config.yml", "vad.pt") - >>> audio, rate = soundfile.read("speech.wav") + >>> audio, rate = librosa.load("speech.wav") >>> speech2segment(audio) [[10, 230], [245, 450], ...] @@ -118,9 +118,9 @@ class Speech2VadSegmentOnline(Speech2VadSegment): """Speech2VadSegmentOnline class Examples: - >>> import soundfile + >>> import librosa >>> speech2segment = Speech2VadSegmentOnline("vad_config.yml", "vad.pt") - >>> audio, rate = soundfile.read("speech.wav") + >>> audio, rate = librosa.load("speech.wav") >>> speech2segment(audio) [[10, 230], [245, 450], ...] diff --git a/funasr/build_utils/build_trainer.py b/funasr/build_utils/build_trainer.py index 03aa7802d..498d05d4f 100644 --- a/funasr/build_utils/build_trainer.py +++ b/funasr/build_utils/build_trainer.py @@ -246,14 +246,11 @@ class Trainer: for iepoch in range(start_epoch, trainer_options.max_epoch + 1): if iepoch != start_epoch: logging.info( - "{}/{}epoch started. Estimated time to finish: {}".format( + "{}/{}epoch started. 
Estimated time to finish: {} hours".format( iepoch, trainer_options.max_epoch, - humanfriendly.format_timespan( - (time.perf_counter() - start_time) - / (iepoch - start_epoch) - * (trainer_options.max_epoch - iepoch + 1) - ), + (time.perf_counter() - start_time) / 3600.0 / (iepoch - start_epoch) * ( + trainer_options.max_epoch - iepoch + 1), ) ) else: diff --git a/funasr/datasets/data_sampler.py b/funasr/datasets/data_sampler.py new file mode 100644 index 000000000..6b3407c31 --- /dev/null +++ b/funasr/datasets/data_sampler.py @@ -0,0 +1,74 @@ +import torch + +import numpy as np + +class BatchSampler(torch.utils.data.BatchSampler): + + def __init__(self, dataset, batch_size_type: str="example", batch_size: int=14, sort_size: int=30, drop_last: bool=False, shuffle: bool=True, **kwargs): + + self.drop_last = drop_last + self.pre_idx = -1 + self.dataset = dataset + self.total_samples = len(dataset) + # self.batch_size_type = args.batch_size_type + # self.batch_size = args.batch_size + # self.sort_size = args.sort_size + # self.max_length_token = args.max_length_token + self.batch_size_type = batch_size_type + self.batch_size = batch_size + self.sort_size = sort_size + self.max_length_token = kwargs.get("max_length_token", 5000) + self.shuffle_idx = np.arange(self.total_samples) + self.shuffle = shuffle + + + def __len__(self): + return self.total_samples + + def __iter__(self): + print("in sampler") + + if self.shuffle: + np.random.shuffle(self.shuffle_idx) + + batch = [] + max_token = 0 + num_sample = 0 + + iter_num = (self.total_samples-1) // self.sort_size + 1 + print("iter_num: ", iter_num) + for iter in range(self.pre_idx + 1, iter_num): + datalen_with_index = [] + for i in range(self.sort_size): + idx = iter * self.sort_size + i + if idx >= self.total_samples: + continue + + idx_map = self.shuffle_idx[idx] + # prompt = self.dataset.indexed_dataset[idx_map]["prompt"] + sample_len_cur = self.dataset.indexed_dataset[idx_map]["source_len"] + \ + self.dataset.indexed_dataset[idx_map]["target_len"] + + datalen_with_index.append([idx, sample_len_cur]) + + datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1]) + for item in datalen_with_index_sort: + idx, sample_len_cur_raw = item + if sample_len_cur_raw > self.max_length_token: + continue + + max_token_cur = max(max_token, sample_len_cur_raw) + max_token_padding = 1 + num_sample + if self.batch_size_type == 'token': + max_token_padding *= max_token_cur + if max_token_padding <= self.batch_size: + batch.append(idx) + max_token = max_token_cur + num_sample += 1 + else: + yield batch + batch = [idx] + max_token = sample_len_cur_raw + num_sample = 1 + + \ No newline at end of file diff --git a/funasr/datasets/dataloader_fn.py b/funasr/datasets/dataloader_fn.py new file mode 100644 index 000000000..8e8e4235f --- /dev/null +++ b/funasr/datasets/dataloader_fn.py @@ -0,0 +1,53 @@ + +import torch +from funasr.datasets.dataset_jsonl import AudioDataset +from funasr.datasets.data_sampler import BatchSampler +from funasr.models.frontend.wav_frontend import WavFrontend +from funasr.tokenizer.build_tokenizer import build_tokenizer +from funasr.tokenizer.token_id_converter import TokenIDConverter +collate_fn = None +# collate_fn = collate_fn, + +jsonl = "/Users/zhifu/funasr_github/test_local/all_task_debug_len.jsonl" + +frontend = WavFrontend() +token_type = 'char' +bpemodel = None +delimiter = None +space_symbol = "" +non_linguistic_symbols = None +g2p_type = None + +tokenizer = build_tokenizer( + token_type=token_type, + 
bpemodel=bpemodel, + delimiter=delimiter, + space_symbol=space_symbol, + non_linguistic_symbols=non_linguistic_symbols, + g2p_type=g2p_type, +) +token_list = "" +unk_symbol = "" + +token_id_converter = TokenIDConverter( + token_list=token_list, + unk_symbol=unk_symbol, +) + +dataset = AudioDataset(jsonl, frontend=frontend, tokenizer=tokenizer) +batch_sampler = BatchSampler(dataset) +dataloader_tr = torch.utils.data.DataLoader(dataset, + collate_fn=dataset.collator, + batch_sampler=batch_sampler, + shuffle=False, + num_workers=0, + pin_memory=True) + +print(len(dataset)) +for i in range(3): + print(i) + for data in dataloader_tr: + print(len(data), data) +# data_iter = iter(dataloader_tr) +# data = next(data_iter) +pass diff --git a/funasr/datasets/dataset.py b/funasr/datasets/dataset.py index 407f6aa9e..673a9b2c9 100644 --- a/funasr/datasets/dataset.py +++ b/funasr/datasets/dataset.py @@ -16,8 +16,10 @@ from typing import Dict from typing import Mapping from typing import Tuple from typing import Union - -import h5py +try: + import h5py +except: + print("If you want use h5py dataset, please pip install h5py, and try it again") import humanfriendly import kaldiio import numpy as np diff --git a/funasr/datasets/dataset_jsonl.py b/funasr/datasets/dataset_jsonl.py new file mode 100644 index 000000000..72d9a99d3 --- /dev/null +++ b/funasr/datasets/dataset_jsonl.py @@ -0,0 +1,124 @@ +import torch +import json +import torch.distributed as dist +import numpy as np +import kaldiio +import librosa + + + +def load_audio(audio_path: str, fs: int=16000): + audio = None + if audio_path.startswith("oss:"): + pass + elif audio_path.startswith("odps:"): + pass + else: + if ".ark:" in audio_path: + audio = kaldiio.load_mat(audio_path) + else: + audio, fs = librosa.load(audio_path, sr=fs) + return audio + +def extract_features(data, date_type: str="sound", frontend=None): + if date_type == "sound": + feat, feats_lens = frontend(data, len(data)) + feat = feat[0, :, :] + else: + feat, feats_lens = torch.from_numpy(data).to(torch.float32), torch.tensor([data.shape[0]]).to(torch.int32) + return feat, feats_lens + + + +class IndexedDatasetJsonl(torch.utils.data.Dataset): + + def __init__(self, path): + super().__init__() + # data_parallel_size = dist.get_world_size() + data_parallel_size = 1 + contents = [] + with open(path, encoding='utf-8') as fin: + for line in fin: + data = json.loads(line.strip()) + if "text" in data: # for sft + self.contents.append(data['text']) + if "source" in data: # for speech lab pretrain + prompt = data["prompt"] + source = data["source"] + target = data["target"] + source_len = data["source_len"] + target_len = data["target_len"] + + contents.append({"source": source, + "prompt": prompt, + "target": target, + "source_len": source_len, + "target_len": target_len, + } + ) + + self.contents = [] + total_num = len(contents) + num_per_rank = total_num // data_parallel_size + # rank = dist.get_rank() + rank = 0 + # import ipdb; ipdb.set_trace() + self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank] + + + def __len__(self): + return len(self.contents) + + def __getitem__(self, index): + return self.contents[index] + + +class AudioDataset(torch.utils.data.Dataset): + def __init__(self, path, frontend=None, tokenizer=None): + super().__init__() + self.indexed_dataset = IndexedDatasetJsonl(path) + self.frontend = frontend.forward + self.fs = 16000 if frontend is None else frontend.fs + self.data_type = "sound" + self.tokenizer = tokenizer + self.int_pad_value = -1 + 
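# Note on funasr/datasets/data_sampler.py above: BatchSampler.__iter__ only yields when the
# next sample would overflow the current batch, so the final partial batch is silently dropped
# and drop_last is never consulted (and __len__ returns the sample count, not the batch count).
# A minimal, self-contained sketch of the intended batching pattern, not the class itself:
from typing import Iterator, List, Sequence

def batch_indices(lengths: Sequence[int], batch_size: int, drop_last: bool = False) -> Iterator[List[int]]:
    batch: List[int] = []
    # sort indices by length, as the sampler does within each sort_size window
    for idx, _ in sorted(enumerate(lengths), key=lambda x: x[1]):
        if len(batch) >= batch_size:
            yield batch
            batch = []
        batch.append(idx)
    if batch and not drop_last:  # flush the last partial batch
        yield batch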
self.float_pad_value = 0.0 + + + + + def __len__(self): + return len(self.indexed_dataset) + + def __getitem__(self, index): + item = self.indexed_dataset[index] + source = item["source"] + data_src = load_audio(source, fs=self.fs) + speech, speech_lengths = extract_features(data_src, self.data_type, self.frontend) + target = item["target"] + text = self.tokenizer.encode(target) + text_lengths = len(text) + text, text_lengths = torch.tensor(text, dtype=torch.int64), torch.tensor([text_lengths], dtype=torch.int32) + return {"speech": speech, + "speech_lengths": speech_lengths, + "text": text, + "text_lengths": text_lengths, + } + + + def collator(self, samples: list=None): + + outputs = {} + for sample in samples: + for key in sample.keys(): + if key not in outputs: + outputs[key] = [] + outputs[key].append(sample[key]) + + for key, data_list in outputs.items(): + if data_list[0].dtype.kind == "i": + pad_value = self.int_pad_value + else: + pad_value = self.float_pad_value + outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value) + return samples \ No newline at end of file diff --git a/funasr/datasets/iterable_dataset.py b/funasr/datasets/iterable_dataset.py index 6398e0cf7..b2cc283e1 100644 --- a/funasr/datasets/iterable_dataset.py +++ b/funasr/datasets/iterable_dataset.py @@ -14,7 +14,8 @@ import kaldiio import numpy as np import torch import torchaudio -import soundfile +# import librosa +import librosa from torch.utils.data.dataset import IterableDataset import os.path @@ -70,7 +71,8 @@ def load_wav(input): try: return torchaudio.load(input)[0].numpy() except: - waveform, _ = soundfile.read(input, dtype='float32') + # waveform, _ = librosa.load(input, dtype='float32') + waveform, _ = librosa.load(input, dtype='float32') if waveform.ndim == 2: waveform = waveform[:, 0] return np.expand_dims(waveform, axis=0) diff --git a/funasr/datasets/large_datasets/build_dataloader.py b/funasr/datasets/large_datasets/build_dataloader.py index 6c2da2ab5..134b20a65 100644 --- a/funasr/datasets/large_datasets/build_dataloader.py +++ b/funasr/datasets/large_datasets/build_dataloader.py @@ -9,7 +9,7 @@ from torch.utils.data import DataLoader from funasr.datasets.large_datasets.dataset import Dataset from funasr.iterators.abs_iter_factory import AbsIterFactory -from funasr.text.abs_tokenizer import AbsTokenizer +from funasr.tokenizer.abs_tokenizer import AbsTokenizer def read_symbol_table(symbol_table_file): diff --git a/funasr/datasets/large_datasets/dataset.py b/funasr/datasets/large_datasets/dataset.py index adfe4f6d9..d3489c199 100644 --- a/funasr/datasets/large_datasets/dataset.py +++ b/funasr/datasets/large_datasets/dataset.py @@ -7,7 +7,8 @@ import torch import torch.distributed as dist import torchaudio import numpy as np -import soundfile +# import librosa +import librosa from kaldiio import ReadHelper from torch.utils.data import IterableDataset @@ -128,7 +129,8 @@ class AudioDataset(IterableDataset): try: waveform, sampling_rate = torchaudio.load(path) except: - waveform, sampling_rate = soundfile.read(path, dtype='float32') + # waveform, sampling_rate = librosa.load(path, dtype='float32') + waveform, sampling_rate = librosa.load(path, dtype='float32') if waveform.ndim == 2: waveform = waveform[:, 0] waveform = np.expand_dims(waveform, axis=0) diff --git a/funasr/datasets/preprocessor.py b/funasr/datasets/preprocessor.py index 9b5c4e77e..b303418bd 100644 --- a/funasr/datasets/preprocessor.py +++ b/funasr/datasets/preprocessor.py @@ -10,12 +10,12 @@ 
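# Note on funasr/datasets/dataset_jsonl.py above: collator builds the padded batch in
# `outputs` but returns the unpadded `samples`, and for torch tensors dtype.is_floating_point
# is the torch-native check (dtype.kind is a NumPy attribute). Separately,
# IndexedDatasetJsonl.__init__ appends to self.contents for "text" entries before
# self.contents is defined. A minimal corrected sketch of the collator:
import torch

def collator(samples, int_pad_value: int = -1, float_pad_value: float = 0.0):
    outputs = {}
    for sample in samples:
        for key, value in sample.items():
            outputs.setdefault(key, []).append(value)
    for key, data_list in outputs.items():
        pad_value = float_pad_value if data_list[0].dtype.is_floating_point else int_pad_value
        outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
    return outputs  # return the padded batch, not the raw samples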
from typing import Union import numpy as np import scipy.signal -import soundfile +import librosa import jieba -from funasr.text.build_tokenizer import build_tokenizer -from funasr.text.cleaner import TextCleaner -from funasr.text.token_id_converter import TokenIDConverter +from funasr.tokenizer.build_tokenizer import build_tokenizer +from funasr.tokenizer.cleaner import TextCleaner +from funasr.tokenizer.token_id_converter import TokenIDConverter class AbsPreprocessor(ABC): @@ -284,7 +284,7 @@ class CommonPreprocessor(AbsPreprocessor): if self.rirs is not None and self.rir_apply_prob >= np.random.random(): rir_path = np.random.choice(self.rirs) if rir_path is not None: - rir, _ = soundfile.read( + rir, _ = librosa.load( rir_path, dtype=np.float64, always_2d=True ) @@ -310,28 +310,31 @@ class CommonPreprocessor(AbsPreprocessor): noise_db = np.random.uniform( self.noise_db_low, self.noise_db_high ) - with soundfile.SoundFile(noise_path) as f: - if f.frames == nsamples: - noise = f.read(dtype=np.float64, always_2d=True) - elif f.frames < nsamples: - offset = np.random.randint(0, nsamples - f.frames) - # noise: (Time, Nmic) - noise = f.read(dtype=np.float64, always_2d=True) - # Repeat noise - noise = np.pad( - noise, - [(offset, nsamples - f.frames - offset), (0, 0)], - mode="wrap", - ) - else: - offset = np.random.randint(0, f.frames - nsamples) - f.seek(offset) - # noise: (Time, Nmic) - noise = f.read( - nsamples, dtype=np.float64, always_2d=True - ) - if len(noise) != nsamples: - raise RuntimeError(f"Something wrong: {noise_path}") + + audio_data = librosa.load(noise_path, dtype='float32')[0][None, :] + frames = len(audio_data[0]) + if frames == nsamples: + noise = audio_data + elif frames < nsamples: + offset = np.random.randint(0, nsamples - frames) + # noise: (Time, Nmic) + noise = audio_data + # Repeat noise + noise = np.pad( + noise, + [(offset, nsamples - frames - offset), (0, 0)], + mode="wrap", + ) + else: + noise = audio_data[:, nsamples] + # offset = np.random.randint(0, frames - nsamples) + # f.seek(offset) + # noise: (Time, Nmic) + # noise = f.read( + # nsamples, dtype=np.float64, always_2d=True + # ) + # if len(noise) != nsamples: + # raise RuntimeError(f"Something wrong: {noise_path}") # noise: (Nmic, Time) noise = noise.T diff --git a/funasr/datasets/small_datasets/preprocessor.py b/funasr/datasets/small_datasets/preprocessor.py index 0ebf325e0..01a8c6ff2 100644 --- a/funasr/datasets/small_datasets/preprocessor.py +++ b/funasr/datasets/small_datasets/preprocessor.py @@ -9,11 +9,11 @@ from typing import Union import numpy as np import scipy.signal -import soundfile +import librosa -from funasr.text.build_tokenizer import build_tokenizer -from funasr.text.cleaner import TextCleaner -from funasr.text.token_id_converter import TokenIDConverter +from funasr.tokenizer.build_tokenizer import build_tokenizer +from funasr.tokenizer.cleaner import TextCleaner +from funasr.tokenizer.token_id_converter import TokenIDConverter class AbsPreprocessor(ABC): @@ -275,7 +275,7 @@ class CommonPreprocessor(AbsPreprocessor): if self.rirs is not None and self.rir_apply_prob >= np.random.random(): rir_path = np.random.choice(self.rirs) if rir_path is not None: - rir, _ = soundfile.read( + rir, _ = librosa.load( rir_path, dtype=np.float64, always_2d=True ) @@ -301,28 +301,30 @@ class CommonPreprocessor(AbsPreprocessor): noise_db = np.random.uniform( self.noise_db_low, self.noise_db_high ) - with soundfile.SoundFile(noise_path) as f: - if f.frames == nsamples: - noise = f.read(dtype=np.float64, 
always_2d=True) - elif f.frames < nsamples: - offset = np.random.randint(0, nsamples - f.frames) - # noise: (Time, Nmic) - noise = f.read(dtype=np.float64, always_2d=True) - # Repeat noise - noise = np.pad( - noise, - [(offset, nsamples - f.frames - offset), (0, 0)], - mode="wrap", - ) - else: - offset = np.random.randint(0, f.frames - nsamples) - f.seek(offset) - # noise: (Time, Nmic) - noise = f.read( - nsamples, dtype=np.float64, always_2d=True - ) - if len(noise) != nsamples: - raise RuntimeError(f"Something wrong: {noise_path}") + audio_data = librosa.load(noise_path, dtype='float32')[0][None, :] + frames = len(audio_data[0]) + if frames == nsamples: + noise = audio_data + elif frames < nsamples: + offset = np.random.randint(0, nsamples - frames) + # noise: (Time, Nmic) + noise = audio_data + # Repeat noise + noise = np.pad( + noise, + [(offset, nsamples - frames - offset), (0, 0)], + mode="wrap", + ) + else: + noise = audio_data[:, nsamples] + # offset = np.random.randint(0, frames - nsamples) + # f.seek(offset) + # noise: (Time, Nmic) + # noise = f.read( + # nsamples, dtype=np.float64, always_2d=True + # ) + # if len(noise) != nsamples: + # raise RuntimeError(f"Something wrong: {noise_path}") # noise: (Nmic, Time) noise = noise.T diff --git a/funasr/export/export_conformer.py b/funasr/export/export_conformer.py deleted file mode 100644 index 4980775a8..000000000 --- a/funasr/export/export_conformer.py +++ /dev/null @@ -1,151 +0,0 @@ -import json -from typing import Union, Dict -from pathlib import Path - -import os -import logging -import torch - -from funasr.export.models import get_model -import numpy as np -import random -from funasr.utils.types import str2bool, str2triple_str -# torch_version = float(".".join(torch.__version__.split(".")[:2])) -# assert torch_version > 1.9 - -class ModelExport: - def __init__( - self, - cache_dir: Union[Path, str] = None, - onnx: bool = True, - device: str = "cpu", - quant: bool = True, - fallback_num: int = 0, - audio_in: str = None, - calib_num: int = 200, - model_revision: str = None, - ): - self.set_all_random_seed(0) - - self.cache_dir = cache_dir - self.export_config = dict( - feats_dim=560, - onnx=False, - ) - - self.onnx = onnx - self.device = device - self.quant = quant - self.fallback_num = fallback_num - self.frontend = None - self.audio_in = audio_in - self.calib_num = calib_num - self.model_revision = model_revision - - def _export( - self, - model, - model_dir: str = None, - verbose: bool = False, - ): - - export_dir = model_dir - os.makedirs(export_dir, exist_ok=True) - - self.export_config["model_name"] = "model" - model = get_model( - model, - self.export_config, - ) - model.eval() - - if self.onnx: - self._export_onnx(model, verbose, export_dir) - - print("output dir: {}".format(export_dir)) - - def _export_onnx(self, model, verbose, path): - model._export_onnx(verbose, path) - - def set_all_random_seed(self, seed: int): - random.seed(seed) - np.random.seed(seed) - torch.random.manual_seed(seed) - - def parse_audio_in(self, audio_in): - - wav_list, name_list = [], [] - if audio_in.endswith(".scp"): - f = open(audio_in, 'r') - lines = f.readlines()[:self.calib_num] - for line in lines: - name, path = line.strip().split() - name_list.append(name) - wav_list.append(path) - else: - wav_list = [audio_in,] - name_list = ["test",] - return wav_list, name_list - - def load_feats(self, audio_in: str = None): - import torchaudio - - wav_list, name_list = self.parse_audio_in(audio_in) - feats = [] - feats_len = [] - for line in wav_list: 
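# Note on the noise/RIR hunks in funasr/datasets/preprocessor.py and
# funasr/datasets/small_datasets/preprocessor.py above: librosa.load has no always_2d argument
# (that is a soundfile.read parameter), it resamples to 22050 Hz unless sr=None is passed, and
# the rewritten noise branch pads the channel axis and indexes a single sample
# (audio_data[:, nsamples]) instead of slicing a random nsamples-long window. For the RIR,
# librosa.load(rir_path, sr=None, mono=False) returns (channels, samples) for multichannel files.
# A minimal sketch that keeps the original (Time, Nmic) layout; the helper name is illustrative:
import librosa
import numpy as np

def load_noise_segment(noise_path: str, nsamples: int) -> np.ndarray:
    audio, _ = librosa.load(noise_path, sr=None, mono=True)  # keep the native sampling rate
    audio = audio.astype(np.float64)[:, None]                # (Time, Nmic=1), like always_2d=True
    frames = audio.shape[0]
    if frames == nsamples:
        noise = audio
    elif frames < nsamples:
        offset = np.random.randint(0, nsamples - frames)
        # repeat the short noise to cover nsamples, as the original soundfile branch did
        noise = np.pad(audio, [(offset, nsamples - frames - offset), (0, 0)], mode="wrap")
    else:
        offset = np.random.randint(0, frames - nsamples)
        noise = audio[offset:offset + nsamples]
    return noise  # (Time, Nmic); the caller transposes to (Nmic, Time) as before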
- path = line.strip() - waveform, sampling_rate = torchaudio.load(path) - if sampling_rate != self.frontend.fs: - waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate, - new_freq=self.frontend.fs)(waveform) - fbank, fbank_len = self.frontend(waveform, [waveform.size(1)]) - feats.append(fbank) - feats_len.append(fbank_len) - return feats, feats_len - - def export(self, - mode: str = None, - ): - - if mode.startswith('conformer'): - from funasr.tasks.asr import ASRTask - config = os.path.join(model_dir, 'config.yaml') - model_file = os.path.join(model_dir, 'model.pb') - cmvn_file = os.path.join(model_dir, 'am.mvn') - model, asr_train_args = ASRTask.build_model_from_file( - config, model_file, cmvn_file, 'cpu' - ) - self.frontend = model.frontend - self.export_config["feats_dim"] = 560 - - self._export(model, self.cache_dir) - -if __name__ == '__main__': - import argparse - parser = argparse.ArgumentParser() - # parser.add_argument('--model-name', type=str, required=True) - parser.add_argument('--model-name', type=str, action="append", required=True, default=[]) - parser.add_argument('--export-dir', type=str, required=True) - parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]') - parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]') - parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model') - parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number') - parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]') - parser.add_argument('--calib_num', type=int, default=200, help='calib max num') - parser.add_argument('--model_revision', type=str, default=None, help='model_revision') - args = parser.parse_args() - - export_model = ModelExport( - cache_dir=args.export_dir, - onnx=args.type == 'onnx', - device=args.device, - quant=args.quantize, - fallback_num=args.fallback_num, - audio_in=args.audio_in, - calib_num=args.calib_num, - model_revision=args.model_revision, - ) - for model_name in args.model_name: - print("export model: {}".format(model_name)) - export_model.export(model_name) diff --git a/funasr/export/models/__init__.py b/funasr/export/models/__init__.py index 94447dca3..b7b0889dc 100644 --- a/funasr/export/models/__init__.py +++ b/funasr/export/models/__init__.py @@ -1,7 +1,7 @@ from funasr.models.e2e_asr_paraformer import Paraformer, BiCifParaformer, ParaformerOnline from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export -from funasr.export.models.e2e_asr_conformer import Conformer as Conformer_export +# from funasr.export.models.e2e_asr_conformer import Conformer as Conformer_export from funasr.models.e2e_vad import E2EVadModel from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export @@ -30,8 +30,8 @@ def get_model(model, export_config=None): return [encoder, decoder] elif isinstance(model, Paraformer): return Paraformer_export(model, **export_config) - elif isinstance(model, Conformer_export): - return Conformer_export(model, **export_config) + # elif isinstance(model, Conformer_export): + # return Conformer_export(model, **export_config) elif isinstance(model, E2EVadModel): return E2EVadModel_export(model, **export_config) elif isinstance(model, PunctuationModel): diff --git a/funasr/export/models/e2e_asr_conformer.py b/funasr/export/models/e2e_asr_conformer.py deleted file mode 100644 
index 45feda5fd..000000000 --- a/funasr/export/models/e2e_asr_conformer.py +++ /dev/null @@ -1,69 +0,0 @@ -import os -import logging -import torch -import torch.nn as nn - -from funasr.export.utils.torch_function import MakePadMask -from funasr.export.utils.torch_function import sequence_mask -from funasr.models.encoder.conformer_encoder import ConformerEncoder -from funasr.models.decoder.transformer_decoder import TransformerDecoder -from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export -from funasr.export.models.decoder.xformer_decoder import XformerDecoder as TransformerDecoder_export - -class Conformer(nn.Module): - """ - export conformer into onnx format - """ - - def __init__( - self, - model, - max_seq_len=512, - feats_dim=560, - model_name='model', - **kwargs, - ): - super().__init__() - onnx = False - if "onnx" in kwargs: - onnx = kwargs["onnx"] - if isinstance(model.encoder, ConformerEncoder): - self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx) - elif isinstance(model.decoder, TransformerDecoder): - self.decoder = TransformerDecoder_export(model.decoder, onnx=onnx) - - self.feats_dim = feats_dim - self.model_name = model_name - - if onnx: - self.make_pad_mask = MakePadMask(max_seq_len, flip=False) - else: - self.make_pad_mask = sequence_mask(max_seq_len, flip=False) - - def _export_model(self, model, verbose, path): - dummy_input = model.get_dummy_inputs() - model_script = model - model_path = os.path.join(path, f'{model.model_name}.onnx') - if not os.path.exists(model_path): - torch.onnx.export( - model_script, - dummy_input, - model_path, - verbose=verbose, - opset_version=14, - input_names=model.get_input_names(), - output_names=model.get_output_names(), - dynamic_axes=model.get_dynamic_axes() - ) - - def _export_encoder_onnx(self, verbose, path): - model_encoder = self.encoder - self._export_model(model_encoder, verbose, path) - - def _export_decoder_onnx(self, verbose, path): - model_decoder = self.decoder - self._export_model(model_decoder, verbose, path) - - def _export_onnx(self, verbose, path): - self._export_encoder_onnx(verbose, path) - self._export_decoder_onnx(verbose, path) \ No newline at end of file diff --git a/funasr/export/models/language_models/embed.py b/funasr/export/models/language_models/embed.py deleted file mode 100644 index 57748f2ea..000000000 --- a/funasr/export/models/language_models/embed.py +++ /dev/null @@ -1,403 +0,0 @@ -"""Positional Encoding Module.""" - -import math - -import torch -import torch.nn as nn -from funasr.modules.embedding import ( - LegacyRelPositionalEncoding, PositionalEncoding, RelPositionalEncoding, - ScaledPositionalEncoding, StreamPositionalEncoding) -from funasr.modules.subsampling import ( - Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6, - Conv2dSubsampling8) -from funasr.modules.subsampling_without_posenc import \ - Conv2dSubsamplingWOPosEnc - -from funasr.export.models.language_models.subsampling import ( - OnnxConv2dSubsampling, OnnxConv2dSubsampling2, OnnxConv2dSubsampling6, - OnnxConv2dSubsampling8) - - -def get_pos_emb(pos_emb, max_seq_len=512, use_cache=True): - if isinstance(pos_emb, LegacyRelPositionalEncoding): - return OnnxLegacyRelPositionalEncoding(pos_emb, max_seq_len, use_cache) - elif isinstance(pos_emb, ScaledPositionalEncoding): - return OnnxScaledPositionalEncoding(pos_emb, max_seq_len, use_cache) - elif isinstance(pos_emb, RelPositionalEncoding): - return OnnxRelPositionalEncoding(pos_emb, max_seq_len, use_cache) - elif 
isinstance(pos_emb, PositionalEncoding): - return OnnxPositionalEncoding(pos_emb, max_seq_len, use_cache) - elif isinstance(pos_emb, StreamPositionalEncoding): - return OnnxStreamPositionalEncoding(pos_emb, max_seq_len, use_cache) - elif (isinstance(pos_emb, nn.Sequential) and len(pos_emb) == 0) or ( - isinstance(pos_emb, Conv2dSubsamplingWOPosEnc) - ): - return pos_emb - else: - raise ValueError("Embedding model is not supported.") - - -class Embedding(nn.Module): - def __init__(self, model, max_seq_len=512, use_cache=True): - super().__init__() - self.model = model - if not isinstance(model, nn.Embedding): - if isinstance(model, Conv2dSubsampling): - self.model = OnnxConv2dSubsampling(model) - self.model.out[-1] = get_pos_emb(model.out[-1], max_seq_len) - elif isinstance(model, Conv2dSubsampling2): - self.model = OnnxConv2dSubsampling2(model) - self.model.out[-1] = get_pos_emb(model.out[-1], max_seq_len) - elif isinstance(model, Conv2dSubsampling6): - self.model = OnnxConv2dSubsampling6(model) - self.model.out[-1] = get_pos_emb(model.out[-1], max_seq_len) - elif isinstance(model, Conv2dSubsampling8): - self.model = OnnxConv2dSubsampling8(model) - self.model.out[-1] = get_pos_emb(model.out[-1], max_seq_len) - else: - self.model[-1] = get_pos_emb(model[-1], max_seq_len) - - def forward(self, x, mask=None): - if mask is None: - return self.model(x) - else: - return self.model(x, mask) - - -def _pre_hook( - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, -): - """Perform pre-hook in load_state_dict for backward compatibility. - - Note: - We saved self.pe until v.0.5.2 but we have omitted it later. - Therefore, we remove the item "pe" from `state_dict` for backward compatibility. - - """ - k = prefix + "pe" - if k in state_dict: - state_dict.pop(k) - - -class OnnxPositionalEncoding(torch.nn.Module): - """Positional encoding. - - Args: - d_model (int): Embedding dimension. - dropout_rate (float): Dropout rate. - max_seq_len (int): Maximum input length. - reverse (bool): Whether to reverse the input position. Only for - the class LegacyRelPositionalEncoding. We remove it in the current - class RelPositionalEncoding. 
- """ - - def __init__(self, model, max_seq_len=512, reverse=False, use_cache=True): - """Construct an PositionalEncoding object.""" - super(OnnxPositionalEncoding, self).__init__() - self.d_model = model.d_model - self.reverse = reverse - self.max_seq_len = max_seq_len - self.xscale = math.sqrt(self.d_model) - self._register_load_state_dict_pre_hook(_pre_hook) - self.pe = model.pe - self.use_cache = use_cache - self.model = model - if self.use_cache: - self.extend_pe() - else: - self.div_term = torch.exp( - torch.arange(0, self.d_model, 2, dtype=torch.float32) - * -(math.log(10000.0) / self.d_model) - ) - - def extend_pe(self): - """Reset the positional encodings.""" - pe_length = len(self.pe[0]) - if self.max_seq_len < pe_length: - self.pe = self.pe[:, : self.max_seq_len] - else: - self.model.extend_pe(torch.tensor(0.0).expand(1, self.max_seq_len)) - self.pe = self.model.pe - - def _add_pe(self, x): - """Computes positional encoding""" - if self.reverse: - position = torch.arange( - x.size(1) - 1, -1, -1.0, dtype=torch.float32 - ).unsqueeze(1) - else: - position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) - - x = x * self.xscale - x[:, :, 0::2] += torch.sin(position * self.div_term) - x[:, :, 1::2] += torch.cos(position * self.div_term) - return x - - def forward(self, x: torch.Tensor): - """Add positional encoding. - - Args: - x (torch.Tensor): Input tensor (batch, time, `*`). - - Returns: - torch.Tensor: Encoded tensor (batch, time, `*`). - """ - if self.use_cache: - x = x * self.xscale + self.pe[:, : x.size(1)] - else: - x = self._add_pe(x) - return x - - -class OnnxScaledPositionalEncoding(OnnxPositionalEncoding): - """Scaled positional encoding module. - - See Sec. 3.2 https://arxiv.org/abs/1809.08895 - - Args: - d_model (int): Embedding dimension. - dropout_rate (float): Dropout rate. - max_seq_len (int): Maximum input length. - - """ - - def __init__(self, model, max_seq_len=512, use_cache=True): - """Initialize class.""" - super().__init__(model, max_seq_len, use_cache=use_cache) - self.alpha = torch.nn.Parameter(torch.tensor(1.0)) - - def reset_parameters(self): - """Reset parameters.""" - self.alpha.data = torch.tensor(1.0) - - def _add_pe(self, x): - """Computes positional encoding""" - if self.reverse: - position = torch.arange( - x.size(1) - 1, -1, -1.0, dtype=torch.float32 - ).unsqueeze(1) - else: - position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) - - x = x * self.alpha - x[:, :, 0::2] += torch.sin(position * self.div_term) - x[:, :, 1::2] += torch.cos(position * self.div_term) - return x - - def forward(self, x): - """Add positional encoding. - - Args: - x (torch.Tensor): Input tensor (batch, time, `*`). - - Returns: - torch.Tensor: Encoded tensor (batch, time, `*`). - - """ - if self.use_cache: - x = x + self.alpha * self.pe[:, : x.size(1)] - else: - x = self._add_pe(x) - return x - - -class OnnxLegacyRelPositionalEncoding(OnnxPositionalEncoding): - """Relative positional encoding module (old version). - - Details can be found in https://github.com/espnet/espnet/pull/2816. - - See : Appendix B in https://arxiv.org/abs/1901.02860 - - Args: - d_model (int): Embedding dimension. - dropout_rate (float): Dropout rate. - max_seq_len (int): Maximum input length. 
- - """ - - def __init__(self, model, max_seq_len=512, use_cache=True): - """Initialize class.""" - super().__init__(model, max_seq_len, reverse=True, use_cache=use_cache) - - def _get_pe(self, x): - """Computes positional encoding""" - if self.reverse: - position = torch.arange( - x.size(1) - 1, -1, -1.0, dtype=torch.float32 - ).unsqueeze(1) - else: - position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) - - pe = torch.zeros(x.shape) - pe[:, :, 0::2] += torch.sin(position * self.div_term) - pe[:, :, 1::2] += torch.cos(position * self.div_term) - return pe - - def forward(self, x): - """Compute positional encoding. - - Args: - x (torch.Tensor): Input tensor (batch, time, `*`). - - Returns: - torch.Tensor: Encoded tensor (batch, time, `*`). - torch.Tensor: Positional embedding tensor (1, time, `*`). - - """ - x = x * self.xscale - if self.use_cache: - pos_emb = self.pe[:, : x.size(1)] - else: - pos_emb = self._get_pe(x) - return x, pos_emb - - -class OnnxRelPositionalEncoding(torch.nn.Module): - """Relative positional encoding module (new implementation). - Details can be found in https://github.com/espnet/espnet/pull/2816. - See : Appendix B in https://arxiv.org/abs/1901.02860 - Args: - d_model (int): Embedding dimension. - dropout_rate (float): Dropout rate. - max_seq_len (int): Maximum input length. - """ - - def __init__(self, model, max_seq_len=512, use_cache=True): - """Construct an PositionalEncoding object.""" - super(OnnxRelPositionalEncoding, self).__init__() - self.d_model = model.d_model - self.xscale = math.sqrt(self.d_model) - self.pe = None - self.use_cache = use_cache - if self.use_cache: - self.extend_pe(torch.tensor(0.0).expand(1, max_seq_len)) - else: - self.div_term = torch.exp( - torch.arange(0, self.d_model, 2, dtype=torch.float32) - * -(math.log(10000.0) / self.d_model) - ) - - def extend_pe(self, x): - """Reset the positional encodings.""" - if self.pe is not None and self.pe.size(1) >= x.size(1) * 2 - 1: - # self.pe contains both positive and negative parts - # the length of self.pe is 2 * input_len - 1 - if self.pe.dtype != x.dtype or self.pe.device != x.device: - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - # Suppose `i` means to the position of query vecotr and `j` means the - # position of key vector. We use position relative positions when keys - # are to the left (i>j) and negative relative positions otherwise (i