funasr v2 setup (#1106)

* funasr v2 setup
This commit is contained in:
zhifu gao 2023-11-22 00:36:35 +08:00 committed by GitHub
parent 99a6d81160
commit b57b98364f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
50 changed files with 417 additions and 1197 deletions

View File

@ -1,7 +1,10 @@
import argparse
import tqdm
import codecs
import textgrid
try:
import textgrid
except:
raise "Please install textgrid firstly: pip install textgrid"
import pdb
class Segment(object):

View File

@ -6,7 +6,10 @@ import argparse
import codecs
from distutils.util import strtobool
from pathlib import Path
import textgrid
try:
import textgrid
except:
raise "Please install textgrid firstly: pip install textgrid"
import pdb
class Segment(object):

View File

@ -6,7 +6,10 @@ import argparse
import codecs
from distutils.util import strtobool
from pathlib import Path
import textgrid
try:
import textgrid
except:
raise "Please install textgrid firstly: pip install textgrid"
import pdb
class Segment(object):

View File

@ -6,7 +6,10 @@ import argparse
import codecs
from distutils.util import strtobool
from pathlib import Path
import textgrid
try:
import textgrid
except:
raise "Please install textgrid firstly: pip install textgrid"
import pdb
class Segment(object):

View File

@ -6,7 +6,10 @@ import argparse
import codecs
from distutils.util import strtobool
from pathlib import Path
import textgrid
try:
import textgrid
except:
raise "Please install textgrid firstly: pip install textgrid"
import pdb
def get_args():

View File

@ -6,7 +6,12 @@ import argparse
import codecs
from distutils.util import strtobool
from pathlib import Path
import textgrid
try:
import textgrid
except:
raise "Please install textgrid firstly: pip install textgrid"
import pdb
import numpy as np
import sys

View File

@ -44,9 +44,9 @@ class Speech2Text:
"""Speech2Text class
Examples:
>>> import soundfile
>>> import librosa
>>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
>>> audio, rate = soundfile.read("speech.wav")
>>> audio, rate = librosa.load("speech.wav")
>>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...]
@ -251,9 +251,9 @@ class Speech2TextParaformer:
"""Speech2Text class
Examples:
>>> import soundfile
>>> import librosa
>>> speech2text = Speech2TextParaformer("asr_config.yml", "asr.pb")
>>> audio, rate = soundfile.read("speech.wav")
>>> audio, rate = librosa.load("speech.wav")
>>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...]
@ -625,9 +625,9 @@ class Speech2TextParaformerOnline:
"""Speech2Text class
Examples:
>>> import soundfile
>>> import librosa
>>> speech2text = Speech2TextParaformerOnline("asr_config.yml", "asr.pth")
>>> audio, rate = soundfile.read("speech.wav")
>>> audio, rate = librosa.load("speech.wav")
>>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...]
@ -876,9 +876,9 @@ class Speech2TextUniASR:
"""Speech2Text class
Examples:
>>> import soundfile
>>> import librosa
>>> speech2text = Speech2TextUniASR("asr_config.yml", "asr.pb")
>>> audio, rate = soundfile.read("speech.wav")
>>> audio, rate = librosa.load("speech.wav")
>>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...]
@ -1106,9 +1106,9 @@ class Speech2TextMFCCA:
"""Speech2Text class
Examples:
>>> import soundfile
>>> import librosa
>>> speech2text = Speech2TextMFCCA("asr_config.yml", "asr.pb")
>>> audio, rate = soundfile.read("speech.wav")
>>> audio, rate = librosa.load("speech.wav")
>>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...]
@ -1637,9 +1637,9 @@ class Speech2TextSAASR:
"""Speech2Text class
Examples:
>>> import soundfile
>>> import librosa
>>> speech2text = Speech2TextSAASR("asr_config.yml", "asr.pb")
>>> audio, rate = soundfile.read("speech.wav")
>>> audio, rate = librosa.load("speech.wav")
>>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...]
@ -1885,9 +1885,9 @@ class Speech2TextWhisper:
"""Speech2Text class
Examples:
>>> import soundfile
>>> import librosa
>>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
>>> audio, rate = soundfile.read("speech.wav")
>>> audio, rate = librosa.load("speech.wav")
>>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...]

View File

@ -20,7 +20,8 @@ from typing import Union
import numpy as np
import torch
import torchaudio
import soundfile
# import librosa
import librosa
import yaml
from funasr.bin.asr_infer import Speech2Text
@ -1281,7 +1282,8 @@ def inference_paraformer_online(
try:
raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
except:
raw_inputs = soundfile.read(data_path_and_name_and_type[0], dtype='float32')[0]
# raw_inputs = librosa.load(data_path_and_name_and_type[0], dtype='float32')[0]
raw_inputs, sr = librosa.load(data_path_and_name_and_type[0], dtype='float32')
if raw_inputs.ndim == 2:
raw_inputs = raw_inputs[:, 0]
raw_inputs = torch.tensor(raw_inputs)

View File

@ -27,11 +27,11 @@ class Speech2DiarizationEEND:
"""Speech2Diarlization class
Examples:
>>> import soundfile
>>> import librosa
>>> import numpy as np
>>> speech2diar = Speech2DiarizationEEND("diar_sond_config.yml", "diar_sond.pb")
>>> profile = np.load("profiles.npy")
>>> audio, rate = soundfile.read("speech.wav")
>>> audio, rate = librosa.load("speech.wav")
>>> speech2diar(audio, profile)
{"spk1": [(int, int), ...], ...}
@ -109,11 +109,11 @@ class Speech2DiarizationSOND:
"""Speech2Xvector class
Examples:
>>> import soundfile
>>> import librosa
>>> import numpy as np
>>> speech2diar = Speech2DiarizationSOND("diar_sond_config.yml", "diar_sond.pb")
>>> profile = np.load("profiles.npy")
>>> audio, rate = soundfile.read("speech.wav")
>>> audio, rate = librosa.load("speech.wav")
>>> speech2diar(audio, profile)
{"spk1": [(int, int), ...], ...}

View File

@ -15,7 +15,8 @@ from typing import Tuple
from typing import Union
import numpy as np
import soundfile
# import librosa
import librosa
import torch
from scipy.signal import medfilt
@ -144,7 +145,9 @@ def inference_sond(
# read waveform file
example = [load_bytes(x) if isinstance(x, bytes) else x
for x in example]
example = [soundfile.read(x)[0] if isinstance(x, str) else x
# example = [librosa.load(x)[0] if isinstance(x, str) else x
# for x in example]
example = [librosa.load(x, dtype='float32')[0] if isinstance(x, str) else x
for x in example]
# convert torch tensor to numpy array
example = [x.numpy() if isinstance(example[0], torch.Tensor) else x

View File

@ -20,9 +20,9 @@ class SpeechSeparator:
"""SpeechSeparator class
Examples:
>>> import soundfile
>>> import librosa
>>> speech_separator = MossFormer("ss_config.yml", "ss.pt")
>>> audio, rate = soundfile.read("speech.wav")
>>> audio, rate = librosa.load("speech.wav")
>>> separated_wavs = speech_separator(audio)
"""

View File

@ -13,7 +13,7 @@ from typing import Union
import numpy as np
import torch
import soundfile as sf
import librosa
from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
@ -104,7 +104,12 @@ def inference_ss(
ss_results = speech_separator(**batch)
for spk in range(num_spks):
sf.write(os.path.join(output_path, keys[0] + '_s' + str(spk+1)+'.wav'), ss_results[spk], sample_rate)
# sf.write(os.path.join(output_path, keys[0] + '_s' + str(spk+1)+'.wav'), ss_results[spk], sample_rate)
try:
librosa.output.write_wav(os.path.join(output_path, keys[0] + '_s' + str(spk+1)+'.wav'), ss_results[spk], sample_rate)
except:
print("To write wav by librosa, you should install librosa<=0.8.0")
raise
torch.cuda.empty_cache()
return ss_results

View File

@ -22,9 +22,9 @@ class Speech2Xvector:
"""Speech2Xvector class
Examples:
>>> import soundfile
>>> import librosa
>>> speech2xvector = Speech2Xvector("sv_config.yml", "sv.pb")
>>> audio, rate = soundfile.read("speech.wav")
>>> audio, rate = librosa.load("speech.wav")
>>> speech2xvector(audio)
[(text, token, token_int, hypothesis object), ...]

View File

@ -23,9 +23,9 @@ class Speech2VadSegment:
"""Speech2VadSegment class
Examples:
>>> import soundfile
>>> import librosa
>>> speech2segment = Speech2VadSegment("vad_config.yml", "vad.pt")
>>> audio, rate = soundfile.read("speech.wav")
>>> audio, rate = librosa.load("speech.wav")
>>> speech2segment(audio)
[[10, 230], [245, 450], ...]
@ -118,9 +118,9 @@ class Speech2VadSegmentOnline(Speech2VadSegment):
"""Speech2VadSegmentOnline class
Examples:
>>> import soundfile
>>> import librosa
>>> speech2segment = Speech2VadSegmentOnline("vad_config.yml", "vad.pt")
>>> audio, rate = soundfile.read("speech.wav")
>>> audio, rate = librosa.load("speech.wav")
>>> speech2segment(audio)
[[10, 230], [245, 450], ...]

View File

@ -246,14 +246,11 @@ class Trainer:
for iepoch in range(start_epoch, trainer_options.max_epoch + 1):
if iepoch != start_epoch:
logging.info(
"{}/{}epoch started. Estimated time to finish: {}".format(
"{}/{}epoch started. Estimated time to finish: {} hours".format(
iepoch,
trainer_options.max_epoch,
humanfriendly.format_timespan(
(time.perf_counter() - start_time)
/ (iepoch - start_epoch)
* (trainer_options.max_epoch - iepoch + 1)
),
(time.perf_counter() - start_time) / 3600.0 / (iepoch - start_epoch) * (
trainer_options.max_epoch - iepoch + 1),
)
)
else:

View File

@ -0,0 +1,60 @@
import torch
class BatchSampler(torch.utils.data.BatchSampler):
def __init__(self, dataset=None, args=None, drop_last=True, ):
self.drop_last = drop_last
self.pre_idx = -1
self.dataset = dataset
self.batch_size_type = args.batch_size_type
self.batch_size = args.batch_size
self.sort_size = args.sort_size
self.max_length_token = args.max_length_token
self.total_samples = len(dataset)
def __len__(self):
return self.total_samples
def __iter__(self):
batch = []
max_token = 0
num_sample = 0
iter_num = (self.total_samples-1) // self.sort_size + 1
for iter in range(self.pre_idx + 1, iter_num):
datalen_with_index = []
for i in range(self.sort_size):
idx = iter * self.sort_size + i
if idx >= self.total_samples:
continue
if self.batch_size_type == "example":
sample_len_cur = 1
else:
idx_map = self.dataset.shuffle_idx[idx]
# prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
sample_len_cur = self.dataset.indexed_dataset[idx_map]["source_len"] + \
self.dataset.indexed_dataset[idx_map]["target_len"]
datalen_with_index.append([idx, sample_len_cur])
datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
for item in datalen_with_index_sort:
idx, sample_len_cur = item
if sample_len_cur > self.max_length_token:
continue
max_token_cur = max(max_token, sample_len_cur)
max_token_padding = (1 + num_sample) * max_token_cur
if max_token_padding <= self.batch_size:
batch.append(idx)
max_token = max_token_cur
num_sample += 1
else:
yield batch
max_token = sample_len_cur
num_sample = 1
batch = [idx]

View File

@ -16,8 +16,10 @@ from typing import Dict
from typing import Mapping
from typing import Tuple
from typing import Union
import h5py
try:
import h5py
except:
print("If you want use h5py dataset, please pip install h5py, and try it again")
import humanfriendly
import kaldiio
import numpy as np

View File

@ -0,0 +1,43 @@
import torch
import json
import torch.distributed as dist
class AudioDatasetJsonl(torch.utils.data.Dataset):
def __init__(self, path, data_parallel_rank=0, data_parallel_size=1):
super().__init__()
data_parallel_size = dist.get_world_size()
contents = []
with open(path, encoding='utf-8') as fin:
for line in fin:
data = json.loads(line.strip())
if "text" in data: # for sft
self.contents.append(data['text'])
if "source" in data: # for speech lab pretrain
prompt = data["prompt"]
source = data["source"]
target = data["target"]
source_len = data["source_len"]
target_len = data["target_len"]
contents.append({"source": source,
"prompt": prompt,
"target": target,
"source_len": source_len,
"target_len": target_len,
}
)
self.contents = []
total_num = len(contents)
num_per_rank = total_num // data_parallel_size
rank = dist.get_rank()
# import ipdb; ipdb.set_trace()
self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank]
def __len__(self):
return len(self.contents)
def __getitem__(self, index):
return self.contents[index]

View File

@ -14,7 +14,8 @@ import kaldiio
import numpy as np
import torch
import torchaudio
import soundfile
# import librosa
import librosa
from torch.utils.data.dataset import IterableDataset
import os.path
@ -70,7 +71,8 @@ def load_wav(input):
try:
return torchaudio.load(input)[0].numpy()
except:
waveform, _ = soundfile.read(input, dtype='float32')
# waveform, _ = librosa.load(input, dtype='float32')
waveform, _ = librosa.load(input, dtype='float32')
if waveform.ndim == 2:
waveform = waveform[:, 0]
return np.expand_dims(waveform, axis=0)

View File

@ -7,7 +7,8 @@ import torch
import torch.distributed as dist
import torchaudio
import numpy as np
import soundfile
# import librosa
import librosa
from kaldiio import ReadHelper
from torch.utils.data import IterableDataset
@ -128,7 +129,8 @@ class AudioDataset(IterableDataset):
try:
waveform, sampling_rate = torchaudio.load(path)
except:
waveform, sampling_rate = soundfile.read(path, dtype='float32')
# waveform, sampling_rate = librosa.load(path, dtype='float32')
waveform, sampling_rate = librosa.load(path, dtype='float32')
if waveform.ndim == 2:
waveform = waveform[:, 0]
waveform = np.expand_dims(waveform, axis=0)

View File

@ -10,7 +10,7 @@ from typing import Union
import numpy as np
import scipy.signal
import soundfile
import librosa
import jieba
from funasr.text.build_tokenizer import build_tokenizer
@ -284,7 +284,7 @@ class CommonPreprocessor(AbsPreprocessor):
if self.rirs is not None and self.rir_apply_prob >= np.random.random():
rir_path = np.random.choice(self.rirs)
if rir_path is not None:
rir, _ = soundfile.read(
rir, _ = librosa.load(
rir_path, dtype=np.float64, always_2d=True
)
@ -310,28 +310,31 @@ class CommonPreprocessor(AbsPreprocessor):
noise_db = np.random.uniform(
self.noise_db_low, self.noise_db_high
)
with soundfile.SoundFile(noise_path) as f:
if f.frames == nsamples:
noise = f.read(dtype=np.float64, always_2d=True)
elif f.frames < nsamples:
offset = np.random.randint(0, nsamples - f.frames)
# noise: (Time, Nmic)
noise = f.read(dtype=np.float64, always_2d=True)
# Repeat noise
noise = np.pad(
noise,
[(offset, nsamples - f.frames - offset), (0, 0)],
mode="wrap",
)
else:
offset = np.random.randint(0, f.frames - nsamples)
f.seek(offset)
# noise: (Time, Nmic)
noise = f.read(
nsamples, dtype=np.float64, always_2d=True
)
if len(noise) != nsamples:
raise RuntimeError(f"Something wrong: {noise_path}")
audio_data = librosa.load(noise_path, dtype='float32')[0][None, :]
frames = len(audio_data[0])
if frames == nsamples:
noise = audio_data
elif frames < nsamples:
offset = np.random.randint(0, nsamples - frames)
# noise: (Time, Nmic)
noise = audio_data
# Repeat noise
noise = np.pad(
noise,
[(offset, nsamples - frames - offset), (0, 0)],
mode="wrap",
)
else:
noise = audio_data[:, nsamples]
# offset = np.random.randint(0, frames - nsamples)
# f.seek(offset)
# noise: (Time, Nmic)
# noise = f.read(
# nsamples, dtype=np.float64, always_2d=True
# )
# if len(noise) != nsamples:
# raise RuntimeError(f"Something wrong: {noise_path}")
# noise: (Nmic, Time)
noise = noise.T

View File

@ -9,7 +9,7 @@ from typing import Union
import numpy as np
import scipy.signal
import soundfile
import librosa
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.cleaner import TextCleaner
@ -275,7 +275,7 @@ class CommonPreprocessor(AbsPreprocessor):
if self.rirs is not None and self.rir_apply_prob >= np.random.random():
rir_path = np.random.choice(self.rirs)
if rir_path is not None:
rir, _ = soundfile.read(
rir, _ = librosa.load(
rir_path, dtype=np.float64, always_2d=True
)
@ -301,28 +301,30 @@ class CommonPreprocessor(AbsPreprocessor):
noise_db = np.random.uniform(
self.noise_db_low, self.noise_db_high
)
with soundfile.SoundFile(noise_path) as f:
if f.frames == nsamples:
noise = f.read(dtype=np.float64, always_2d=True)
elif f.frames < nsamples:
offset = np.random.randint(0, nsamples - f.frames)
# noise: (Time, Nmic)
noise = f.read(dtype=np.float64, always_2d=True)
# Repeat noise
noise = np.pad(
noise,
[(offset, nsamples - f.frames - offset), (0, 0)],
mode="wrap",
)
else:
offset = np.random.randint(0, f.frames - nsamples)
f.seek(offset)
# noise: (Time, Nmic)
noise = f.read(
nsamples, dtype=np.float64, always_2d=True
)
if len(noise) != nsamples:
raise RuntimeError(f"Something wrong: {noise_path}")
audio_data = librosa.load(noise_path, dtype='float32')[0][None, :]
frames = len(audio_data[0])
if frames == nsamples:
noise = audio_data
elif frames < nsamples:
offset = np.random.randint(0, nsamples - frames)
# noise: (Time, Nmic)
noise = audio_data
# Repeat noise
noise = np.pad(
noise,
[(offset, nsamples - frames - offset), (0, 0)],
mode="wrap",
)
else:
noise = audio_data[:, nsamples]
# offset = np.random.randint(0, frames - nsamples)
# f.seek(offset)
# noise: (Time, Nmic)
# noise = f.read(
# nsamples, dtype=np.float64, always_2d=True
# )
# if len(noise) != nsamples:
# raise RuntimeError(f"Something wrong: {noise_path}")
# noise: (Nmic, Time)
noise = noise.T

View File

@ -1,151 +0,0 @@
import json
from typing import Union, Dict
from pathlib import Path
import os
import logging
import torch
from funasr.export.models import get_model
import numpy as np
import random
from funasr.utils.types import str2bool, str2triple_str
# torch_version = float(".".join(torch.__version__.split(".")[:2]))
# assert torch_version > 1.9
class ModelExport:
def __init__(
self,
cache_dir: Union[Path, str] = None,
onnx: bool = True,
device: str = "cpu",
quant: bool = True,
fallback_num: int = 0,
audio_in: str = None,
calib_num: int = 200,
model_revision: str = None,
):
self.set_all_random_seed(0)
self.cache_dir = cache_dir
self.export_config = dict(
feats_dim=560,
onnx=False,
)
self.onnx = onnx
self.device = device
self.quant = quant
self.fallback_num = fallback_num
self.frontend = None
self.audio_in = audio_in
self.calib_num = calib_num
self.model_revision = model_revision
def _export(
self,
model,
model_dir: str = None,
verbose: bool = False,
):
export_dir = model_dir
os.makedirs(export_dir, exist_ok=True)
self.export_config["model_name"] = "model"
model = get_model(
model,
self.export_config,
)
model.eval()
if self.onnx:
self._export_onnx(model, verbose, export_dir)
print("output dir: {}".format(export_dir))
def _export_onnx(self, model, verbose, path):
model._export_onnx(verbose, path)
def set_all_random_seed(self, seed: int):
random.seed(seed)
np.random.seed(seed)
torch.random.manual_seed(seed)
def parse_audio_in(self, audio_in):
wav_list, name_list = [], []
if audio_in.endswith(".scp"):
f = open(audio_in, 'r')
lines = f.readlines()[:self.calib_num]
for line in lines:
name, path = line.strip().split()
name_list.append(name)
wav_list.append(path)
else:
wav_list = [audio_in,]
name_list = ["test",]
return wav_list, name_list
def load_feats(self, audio_in: str = None):
import torchaudio
wav_list, name_list = self.parse_audio_in(audio_in)
feats = []
feats_len = []
for line in wav_list:
path = line.strip()
waveform, sampling_rate = torchaudio.load(path)
if sampling_rate != self.frontend.fs:
waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
new_freq=self.frontend.fs)(waveform)
fbank, fbank_len = self.frontend(waveform, [waveform.size(1)])
feats.append(fbank)
feats_len.append(fbank_len)
return feats, feats_len
def export(self,
mode: str = None,
):
if mode.startswith('conformer'):
from funasr.tasks.asr import ASRTask
config = os.path.join(model_dir, 'config.yaml')
model_file = os.path.join(model_dir, 'model.pb')
cmvn_file = os.path.join(model_dir, 'am.mvn')
model, asr_train_args = ASRTask.build_model_from_file(
config, model_file, cmvn_file, 'cpu'
)
self.frontend = model.frontend
self.export_config["feats_dim"] = 560
self._export(model, self.cache_dir)
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
# parser.add_argument('--model-name', type=str, required=True)
parser.add_argument('--model-name', type=str, action="append", required=True, default=[])
parser.add_argument('--export-dir', type=str, required=True)
parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
parser.add_argument('--model_revision', type=str, default=None, help='model_revision')
args = parser.parse_args()
export_model = ModelExport(
cache_dir=args.export_dir,
onnx=args.type == 'onnx',
device=args.device,
quant=args.quantize,
fallback_num=args.fallback_num,
audio_in=args.audio_in,
calib_num=args.calib_num,
model_revision=args.model_revision,
)
for model_name in args.model_name:
print("export model: {}".format(model_name))
export_model.export(model_name)

View File

@ -1,403 +0,0 @@
"""Positional Encoding Module."""
import math
import torch
import torch.nn as nn
from funasr.modules.embedding import (
LegacyRelPositionalEncoding, PositionalEncoding, RelPositionalEncoding,
ScaledPositionalEncoding, StreamPositionalEncoding)
from funasr.modules.subsampling import (
Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6,
Conv2dSubsampling8)
from funasr.modules.subsampling_without_posenc import \
Conv2dSubsamplingWOPosEnc
from funasr.export.models.language_models.subsampling import (
OnnxConv2dSubsampling, OnnxConv2dSubsampling2, OnnxConv2dSubsampling6,
OnnxConv2dSubsampling8)
def get_pos_emb(pos_emb, max_seq_len=512, use_cache=True):
if isinstance(pos_emb, LegacyRelPositionalEncoding):
return OnnxLegacyRelPositionalEncoding(pos_emb, max_seq_len, use_cache)
elif isinstance(pos_emb, ScaledPositionalEncoding):
return OnnxScaledPositionalEncoding(pos_emb, max_seq_len, use_cache)
elif isinstance(pos_emb, RelPositionalEncoding):
return OnnxRelPositionalEncoding(pos_emb, max_seq_len, use_cache)
elif isinstance(pos_emb, PositionalEncoding):
return OnnxPositionalEncoding(pos_emb, max_seq_len, use_cache)
elif isinstance(pos_emb, StreamPositionalEncoding):
return OnnxStreamPositionalEncoding(pos_emb, max_seq_len, use_cache)
elif (isinstance(pos_emb, nn.Sequential) and len(pos_emb) == 0) or (
isinstance(pos_emb, Conv2dSubsamplingWOPosEnc)
):
return pos_emb
else:
raise ValueError("Embedding model is not supported.")
class Embedding(nn.Module):
def __init__(self, model, max_seq_len=512, use_cache=True):
super().__init__()
self.model = model
if not isinstance(model, nn.Embedding):
if isinstance(model, Conv2dSubsampling):
self.model = OnnxConv2dSubsampling(model)
self.model.out[-1] = get_pos_emb(model.out[-1], max_seq_len)
elif isinstance(model, Conv2dSubsampling2):
self.model = OnnxConv2dSubsampling2(model)
self.model.out[-1] = get_pos_emb(model.out[-1], max_seq_len)
elif isinstance(model, Conv2dSubsampling6):
self.model = OnnxConv2dSubsampling6(model)
self.model.out[-1] = get_pos_emb(model.out[-1], max_seq_len)
elif isinstance(model, Conv2dSubsampling8):
self.model = OnnxConv2dSubsampling8(model)
self.model.out[-1] = get_pos_emb(model.out[-1], max_seq_len)
else:
self.model[-1] = get_pos_emb(model[-1], max_seq_len)
def forward(self, x, mask=None):
if mask is None:
return self.model(x)
else:
return self.model(x, mask)
def _pre_hook(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
"""Perform pre-hook in load_state_dict for backward compatibility.
Note:
We saved self.pe until v.0.5.2 but we have omitted it later.
Therefore, we remove the item "pe" from `state_dict` for backward compatibility.
"""
k = prefix + "pe"
if k in state_dict:
state_dict.pop(k)
class OnnxPositionalEncoding(torch.nn.Module):
"""Positional encoding.
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_seq_len (int): Maximum input length.
reverse (bool): Whether to reverse the input position. Only for
the class LegacyRelPositionalEncoding. We remove it in the current
class RelPositionalEncoding.
"""
def __init__(self, model, max_seq_len=512, reverse=False, use_cache=True):
"""Construct an PositionalEncoding object."""
super(OnnxPositionalEncoding, self).__init__()
self.d_model = model.d_model
self.reverse = reverse
self.max_seq_len = max_seq_len
self.xscale = math.sqrt(self.d_model)
self._register_load_state_dict_pre_hook(_pre_hook)
self.pe = model.pe
self.use_cache = use_cache
self.model = model
if self.use_cache:
self.extend_pe()
else:
self.div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
def extend_pe(self):
"""Reset the positional encodings."""
pe_length = len(self.pe[0])
if self.max_seq_len < pe_length:
self.pe = self.pe[:, : self.max_seq_len]
else:
self.model.extend_pe(torch.tensor(0.0).expand(1, self.max_seq_len))
self.pe = self.model.pe
def _add_pe(self, x):
"""Computes positional encoding"""
if self.reverse:
position = torch.arange(
x.size(1) - 1, -1, -1.0, dtype=torch.float32
).unsqueeze(1)
else:
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
x = x * self.xscale
x[:, :, 0::2] += torch.sin(position * self.div_term)
x[:, :, 1::2] += torch.cos(position * self.div_term)
return x
def forward(self, x: torch.Tensor):
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
"""
if self.use_cache:
x = x * self.xscale + self.pe[:, : x.size(1)]
else:
x = self._add_pe(x)
return x
class OnnxScaledPositionalEncoding(OnnxPositionalEncoding):
"""Scaled positional encoding module.
See Sec. 3.2 https://arxiv.org/abs/1809.08895
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_seq_len (int): Maximum input length.
"""
def __init__(self, model, max_seq_len=512, use_cache=True):
"""Initialize class."""
super().__init__(model, max_seq_len, use_cache=use_cache)
self.alpha = torch.nn.Parameter(torch.tensor(1.0))
def reset_parameters(self):
"""Reset parameters."""
self.alpha.data = torch.tensor(1.0)
def _add_pe(self, x):
"""Computes positional encoding"""
if self.reverse:
position = torch.arange(
x.size(1) - 1, -1, -1.0, dtype=torch.float32
).unsqueeze(1)
else:
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
x = x * self.alpha
x[:, :, 0::2] += torch.sin(position * self.div_term)
x[:, :, 1::2] += torch.cos(position * self.div_term)
return x
def forward(self, x):
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
"""
if self.use_cache:
x = x + self.alpha * self.pe[:, : x.size(1)]
else:
x = self._add_pe(x)
return x
class OnnxLegacyRelPositionalEncoding(OnnxPositionalEncoding):
"""Relative positional encoding module (old version).
Details can be found in https://github.com/espnet/espnet/pull/2816.
See : Appendix B in https://arxiv.org/abs/1901.02860
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_seq_len (int): Maximum input length.
"""
def __init__(self, model, max_seq_len=512, use_cache=True):
"""Initialize class."""
super().__init__(model, max_seq_len, reverse=True, use_cache=use_cache)
def _get_pe(self, x):
"""Computes positional encoding"""
if self.reverse:
position = torch.arange(
x.size(1) - 1, -1, -1.0, dtype=torch.float32
).unsqueeze(1)
else:
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
pe = torch.zeros(x.shape)
pe[:, :, 0::2] += torch.sin(position * self.div_term)
pe[:, :, 1::2] += torch.cos(position * self.div_term)
return pe
def forward(self, x):
"""Compute positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
torch.Tensor: Positional embedding tensor (1, time, `*`).
"""
x = x * self.xscale
if self.use_cache:
pos_emb = self.pe[:, : x.size(1)]
else:
pos_emb = self._get_pe(x)
return x, pos_emb
class OnnxRelPositionalEncoding(torch.nn.Module):
"""Relative positional encoding module (new implementation).
Details can be found in https://github.com/espnet/espnet/pull/2816.
See : Appendix B in https://arxiv.org/abs/1901.02860
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_seq_len (int): Maximum input length.
"""
def __init__(self, model, max_seq_len=512, use_cache=True):
"""Construct an PositionalEncoding object."""
super(OnnxRelPositionalEncoding, self).__init__()
self.d_model = model.d_model
self.xscale = math.sqrt(self.d_model)
self.pe = None
self.use_cache = use_cache
if self.use_cache:
self.extend_pe(torch.tensor(0.0).expand(1, max_seq_len))
else:
self.div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
def extend_pe(self, x):
"""Reset the positional encodings."""
if self.pe is not None and self.pe.size(1) >= x.size(1) * 2 - 1:
# self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1
if self.pe.dtype != x.dtype or self.pe.device != x.device:
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` means to the position of query vecotr and `j` means the
# position of key vector. We use position relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x.size(1), self.d_model)
pe_negative = torch.zeros(x.size(1), self.d_model)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
pe_positive[:, 0::2] = torch.sin(position * div_term)
pe_positive[:, 1::2] = torch.cos(position * div_term)
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
# Reserve the order of positive indices and concat both positive and
# negative indices. This is used to support the shifting trick
# as in https://arxiv.org/abs/1901.02860
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative = pe_negative[1:].unsqueeze(0)
pe = torch.cat([pe_positive, pe_negative], dim=1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def _get_pe(self, x):
pe_positive = torch.zeros(x.size(1), self.d_model)
pe_negative = torch.zeros(x.size(1), self.d_model)
theta = (
torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) * self.div_term
)
pe_positive[:, 0::2] = torch.sin(theta)
pe_positive[:, 1::2] = torch.cos(theta)
pe_negative[:, 0::2] = -1 * torch.sin(theta)
pe_negative[:, 1::2] = torch.cos(theta)
# Reserve the order of positive indices and concat both positive and
# negative indices. This is used to support the shifting trick
# as in https://arxiv.org/abs/1901.02860
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative = pe_negative[1:].unsqueeze(0)
return torch.cat([pe_positive, pe_negative], dim=1)
def forward(self, x: torch.Tensor, use_cache=True):
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
"""
x = x * self.xscale
if self.use_cache:
pos_emb = self.pe[
:,
self.pe.size(1) // 2 - x.size(1) + 1 : self.pe.size(1) // 2 + x.size(1),
]
else:
pos_emb = self._get_pe(x)
return x, pos_emb
class OnnxStreamPositionalEncoding(torch.nn.Module):
"""Streaming Positional encoding."""
def __init__(self, model, max_seq_len=5000, use_cache=True):
"""Construct an PositionalEncoding object."""
super(StreamPositionalEncoding, self).__init__()
self.use_cache = use_cache
self.d_model = model.d_model
self.xscale = model.xscale
self.pe = model.pe
self.use_cache = use_cache
self.max_seq_len = max_seq_len
if self.use_cache:
self.extend_pe()
else:
self.div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
self._register_load_state_dict_pre_hook(_pre_hook)
def extend_pe(self):
"""Reset the positional encodings."""
pe_length = len(self.pe[0])
if self.max_seq_len < pe_length:
self.pe = self.pe[:, : self.max_seq_len]
else:
self.model.extend_pe(self.max_seq_len)
self.pe = self.model.pe
def _add_pe(self, x, start_idx):
position = torch.arange(start_idx, x.size(1), dtype=torch.float32).unsqueeze(1)
x = x * self.xscale
x[:, :, 0::2] += torch.sin(position * self.div_term)
x[:, :, 1::2] += torch.cos(position * self.div_term)
return x
def forward(self, x: torch.Tensor, start_idx: int = 0):
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
"""
if self.use_cache:
return x * self.xscale + self.pe[:, start_idx : start_idx + x.size(1)]
else:
return self._add_pe(x, start_idx)

View File

@ -1,84 +0,0 @@
import os
import torch
import torch.nn as nn
class SequentialRNNLM(nn.Module):
def __init__(self, model, **kwargs):
super().__init__()
self.encoder = model.encoder
self.rnn = model.rnn
self.rnn_type = model.rnn_type
self.decoder = model.decoder
self.nlayers = model.nlayers
self.nhid = model.nhid
self.model_name = "seq_rnnlm"
def forward(self, y, hidden1, hidden2=None):
# batch_score function.
emb = self.encoder(y)
if self.rnn_type == "LSTM":
output, (hidden1, hidden2) = self.rnn(emb, (hidden1, hidden2))
else:
output, hidden1 = self.rnn(emb, hidden1)
decoded = self.decoder(
output.contiguous().view(output.size(0) * output.size(1), output.size(2))
)
if self.rnn_type == "LSTM":
return (
decoded.view(output.size(0), output.size(1), decoded.size(1)),
hidden1,
hidden2,
)
else:
return (
decoded.view(output.size(0), output.size(1), decoded.size(1)),
hidden1,
)
def get_dummy_inputs(self):
tgt = torch.LongTensor([0, 1]).unsqueeze(0)
hidden = torch.randn(self.nlayers, 1, self.nhid)
if self.rnn_type == "LSTM":
return (tgt, hidden, hidden)
else:
return (tgt, hidden)
def get_input_names(self):
if self.rnn_type == "LSTM":
return ["x", "in_hidden1", "in_hidden2"]
else:
return ["x", "in_hidden1"]
def get_output_names(self):
if self.rnn_type == "LSTM":
return ["y", "out_hidden1", "out_hidden2"]
else:
return ["y", "out_hidden1"]
def get_dynamic_axes(self):
ret = {
"x": {0: "x_batch", 1: "x_length"},
"y": {0: "y_batch"},
"in_hidden1": {1: "hidden1_batch"},
"out_hidden1": {1: "out_hidden1_batch"},
}
if self.rnn_type == "LSTM":
ret.update(
{
"in_hidden2": {1: "hidden2_batch"},
"out_hidden2": {1: "out_hidden2_batch"},
}
)
return ret
def get_model_config(self, path):
return {
"use_lm": True,
"model_path": os.path.join(path, f"{self.model_name}.onnx"),
"lm_type": "SequentialRNNLM",
"rnn_type": self.rnn_type,
"nhid": self.nhid,
"nlayers": self.nlayers,
}

View File

@ -1,185 +0,0 @@
"""Subsampling layer definition."""
import torch
class OnnxConv2dSubsampling(torch.nn.Module):
"""Convolutional 2D subsampling (to 1/4 length).
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
pos_enc (torch.nn.Module): Custom position encoding layer.
"""
def __init__(self, model):
"""Construct an Conv2dSubsampling object."""
super().__init__()
self.conv = model.conv
self.out = model.out
def forward(self, x, x_mask):
"""Subsample x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 4.
torch.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 4.
"""
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
if x_mask is None:
return x, None
return x, x_mask[:, :-2:2][:, :-2:2]
def __getitem__(self, key):
"""Get item.
When reset_parameters() is called, if use_scaled_pos_enc is used,
return the positioning encoding.
"""
if key != -1:
raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
return self.out[key]
class OnnxConv2dSubsampling2(torch.nn.Module):
"""Convolutional 2D subsampling (to 1/2 length).
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
pos_enc (torch.nn.Module): Custom position encoding layer.
"""
def __init__(self, model):
"""Construct an Conv2dSubsampling object."""
super().__init__()
self.conv = model.conv
self.out = model.out
def forward(self, x, x_mask):
"""Subsample x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 2.
torch.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 2.
"""
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
if x_mask is None:
return x, None
return x, x_mask[:, :-2:2][:, :-2:1]
def __getitem__(self, key):
"""Get item.
When reset_parameters() is called, if use_scaled_pos_enc is used,
return the positioning encoding.
"""
if key != -1:
raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
return self.out[key]
class OnnxConv2dSubsampling6(torch.nn.Module):
"""Convolutional 2D subsampling (to 1/6 length).
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
pos_enc (torch.nn.Module): Custom position encoding layer.
"""
def __init__(self, model):
"""Construct an Conv2dSubsampling object."""
super().__init__()
self.conv = model.conv
self.out = model.out
def forward(self, x, x_mask):
"""Subsample x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 6.
torch.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 6.
"""
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
if x_mask is None:
return x, None
return x, x_mask[:, :-2:2][:, :-4:3]
class OnnxConv2dSubsampling8(torch.nn.Module):
"""Convolutional 2D subsampling (to 1/8 length).
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
pos_enc (torch.nn.Module): Custom position encoding layer.
"""
def __init__(self, model):
"""Construct an Conv2dSubsampling object."""
super().__init__()
self.conv = model.conv
self.out = model.out
def forward(self, x, x_mask):
"""Subsample x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 8.
torch.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 8.
"""
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
if x_mask is None:
return x, None
return x, x_mask[:, :-2:2][:, :-2:2][:, :-2:2]

View File

@ -1,110 +0,0 @@
import os
import torch
import torch.nn as nn
from funasr.modules.vgg2l import import VGG2L
from funasr.modules.attention import MultiHeadedAttention
from funasr.modules.subsampling import (
Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8)
from funasr.export.models.modules.encoder_layer import EncoderLayerConformer as OnnxEncoderLayer
from funasr.export.models.language_models.embed import Embedding
from funasr.export.models.modules.multihead_att import OnnxMultiHeadedAttention
from funasr.export.utils.torch_function import MakePadMask
class TransformerLM(nn.Module, AbsExportModel):
def __init__(self, model, max_seq_len=512, **kwargs):
super().__init__()
self.embed = Embedding(model.embed, max_seq_len)
self.encoder = model.encoder
self.decoder = model.decoder
self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
# replace multihead attention module into customized module.
for i, d in enumerate(self.encoder.encoders):
# d is EncoderLayer
if isinstance(d.self_attn, MultiHeadedAttention):
d.self_attn = OnnxMultiHeadedAttention(d.self_attn)
self.encoder.encoders[i] = OnnxEncoderLayer(d)
self.model_name = "transformer_lm"
self.num_heads = self.encoder.encoders[0].self_attn.h
self.hidden_size = self.encoder.encoders[0].self_attn.linear_out.out_features
def prepare_mask(self, mask):
if len(mask.shape) == 2:
mask = mask[:, None, None, :]
elif len(mask.shape) == 3:
mask = mask[:, None, :]
mask = 1 - mask
return mask * -10000.0
def forward(self, y, cache):
feats_length = torch.ones(y.shape).sum(dim=-1).type(torch.long)
mask = self.make_pad_mask(feats_length) # (B, T)
mask = (y != 0) * mask
xs = self.embed(y)
# forward_one_step of Encoder
if isinstance(
self.encoder.embed,
(Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8, VGG2L),
):
xs, mask = self.encoder.embed(xs, mask)
else:
xs = self.encoder.embed(xs)
new_cache = []
mask = self.prepare_mask(mask)
for c, e in zip(cache, self.encoder.encoders):
xs, mask = e(xs, mask, c)
new_cache.append(xs)
if self.encoder.normalize_before:
xs = self.encoder.after_norm(xs)
h = self.decoder(xs[:, -1])
return h, new_cache
def get_dummy_inputs(self):
tgt = torch.LongTensor([1]).unsqueeze(0)
cache = [
torch.zeros((1, 1, self.encoder.encoders[0].size))
for _ in range(len(self.encoder.encoders))
]
return (tgt, cache)
def is_optimizable(self):
return True
def get_input_names(self):
return ["tgt"] + ["cache_%d" % i for i in range(len(self.encoder.encoders))]
def get_output_names(self):
return ["y"] + ["out_cache_%d" % i for i in range(len(self.encoder.encoders))]
def get_dynamic_axes(self):
ret = {"tgt": {0: "tgt_batch", 1: "tgt_length"}}
ret.update(
{
"cache_%d" % d: {0: "cache_%d_batch" % d, 1: "cache_%d_length" % d}
for d in range(len(self.encoder.encoders))
}
)
ret.update(
{
"out_cache_%d"
% d: {0: "out_cache_%d_batch" % d, 1: "out_cache_%d_length" % d}
for d in range(len(self.encoder.encoders))
}
)
return ret
def get_model_config(self, path):
return {
"use_lm": True,
"model_path": os.path.join(path, f"{self.model_name}.onnx"),
"lm_type": "TransformerLM",
"odim": self.encoder.encoders[0].size,
"nlayers": len(self.encoder.encoders),
}

View File

@ -4,7 +4,7 @@ from typing import List, Tuple, Union
import random
import numpy as np
import soundfile
import librosa
import librosa
import torch
@ -116,7 +116,7 @@ class SoundScpReader(collections.abc.Mapping):
def __getitem__(self, key):
wav = self.data[key]
if self.normalize:
# soundfile.read normalizes data to [-1,1] if dtype is not given
# librosa.load normalizes data to [-1,1] if dtype is not given
array, rate = librosa.load(
wav, sr=self.dest_sample_rate, mono=self.always_2d
)

View File

@ -5,8 +5,12 @@ from typing import Tuple
from typing import Union
import torch
from torch_complex import functional as FC
from torch_complex.tensor import ComplexTensor
try:
from torch_complex import functional as FC
from torch_complex.tensor import ComplexTensor
except:
print("Please install torch_complex firstly")
EPS = torch.finfo(torch.double).eps

View File

@ -4,8 +4,11 @@ from typing import Tuple
from typing import Union
import torch
from torch_complex.tensor import ComplexTensor
try:
from torch_complex.tensor import ComplexTensor
except:
print("Please install torch_complex firstly")
from funasr.modules.nets_utils import make_pad_mask
from funasr.layers.complex_utils import is_complex
from funasr.layers.inversible_interface import InversibleInterface

View File

@ -1,8 +1,10 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from rotary_embedding_torch import RotaryEmbedding
try:
from rotary_embedding_torch import RotaryEmbedding
except:
print("Please install rotary_embedding_torch by: \n pip install -U rotary_embedding_torch")
from funasr.modules.layer_norm import GlobalLayerNorm, CumulativeLayerNorm, ScaleNorm
from funasr.modules.embedding import ScaledSinuEmbedding
from funasr.modules.mossformer import FLASH_ShareA_FFConvM

View File

@ -6,12 +6,15 @@ import logging
import humanfriendly
import numpy as np
import torch
from torch_complex.tensor import ComplexTensor
try:
from torch_complex.tensor import ComplexTensor
except:
print("Please install torch_complex firstly")
from funasr.layers.log_mel import LogMel
from funasr.layers.stft import Stft
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.modules.frontends.frontend import Frontend
from funasr.models.frontend.frontends_utils.frontend import Frontend
from funasr.utils.get_default_kwargs import get_default_kwargs
from funasr.modules.nets_utils import make_pad_mask

View File

@ -4,12 +4,12 @@ from typing import Tuple
import torch
from torch.nn import functional as F
from funasr.modules.frontends.beamformer import apply_beamforming_vector
from funasr.modules.frontends.beamformer import get_mvdr_vector
from funasr.modules.frontends.beamformer import (
from funasr.models.frontend.frontends_utils.beamformer import apply_beamforming_vector
from funasr.models.frontend.frontends_utils.beamformer import get_mvdr_vector
from funasr.models.frontend.frontends_utils.beamformer import (
get_power_spectral_density_matrix, # noqa: H301
)
from funasr.modules.frontends.mask_estimator import MaskEstimator
from funasr.models.frontend.frontends_utils.mask_estimator import MaskEstimator
from torch_complex.tensor import ComplexTensor

View File

@ -4,7 +4,7 @@ from pytorch_wpe import wpe_one_iteration
import torch
from torch_complex.tensor import ComplexTensor
from funasr.modules.frontends.mask_estimator import MaskEstimator
from funasr.models.frontend.frontends_utils.mask_estimator import MaskEstimator
from funasr.modules.nets_utils import make_pad_mask

View File

@ -8,8 +8,8 @@ import torch
import torch.nn as nn
from torch_complex.tensor import ComplexTensor
from funasr.modules.frontends.dnn_beamformer import DNN_Beamformer
from funasr.modules.frontends.dnn_wpe import DNN_WPE
from funasr.models.frontend.frontends_utils.dnn_beamformer import DNN_Beamformer
from funasr.models.frontend.frontends_utils.dnn_wpe import DNN_WPE
class Frontend(nn.Module):

View File

@ -10,7 +10,7 @@ import humanfriendly
import torch
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.modules.frontends.frontend import Frontend
from funasr.models.frontend.frontends_utils.frontend import Frontend
from funasr.modules.nets_utils import pad_list
from funasr.utils.get_default_kwargs import get_default_kwargs

View File

@ -9,7 +9,7 @@ import os
import sys
import numpy as np
import subprocess
import soundfile as sf
import librosa as sf
import io
from functools import lru_cache
@ -67,18 +67,18 @@ def load_wav(wav_rxfilename, start=0, end=None):
# input piped command
p = subprocess.Popen(wav_rxfilename[:-1], shell=True,
stdout=subprocess.PIPE)
data, samplerate = sf.read(io.BytesIO(p.stdout.read()),
data, samplerate = sf.load(io.BytesIO(p.stdout.read()),
dtype='float32')
# cannot seek
data = data[start:end]
elif wav_rxfilename == '-':
# stdin
data, samplerate = sf.read(sys.stdin, dtype='float32')
data, samplerate = sf.load(sys.stdin, dtype='float32')
# cannot seek
data = data[start:end]
else:
# normal wav file
data, samplerate = sf.read(wav_rxfilename, start=start, stop=end)
data, samplerate = sf.load(wav_rxfilename, start=start, stop=end)
return data, samplerate

View File

@ -278,14 +278,11 @@ class Trainer:
for iepoch in range(start_epoch, trainer_options.max_epoch + 1):
if iepoch != start_epoch:
logging.info(
"{}/{}epoch started. Estimated time to finish: {}".format(
"{}/{}epoch started. Estimated time to finish: {} hours".format(
iepoch,
trainer_options.max_epoch,
humanfriendly.format_timespan(
(time.perf_counter() - start_time)
/ (iepoch - start_epoch)
* (trainer_options.max_epoch - iepoch + 1)
),
(time.perf_counter() - start_time) / 3600.0 / (iepoch - start_epoch) * (
trainer_options.max_epoch - iepoch + 1),
)
)
else:

View File

@ -5,7 +5,7 @@ import struct
from typing import Any, Dict, List, Union
import torchaudio
import soundfile
import librosa
import numpy as np
import pkg_resources
from modelscope.utils.logger import get_logger
@ -139,7 +139,7 @@ def get_sr_from_wav(fname: str):
try:
audio, fs = torchaudio.load(fname)
except:
audio, fs = soundfile.read(fname)
audio, fs = librosa.load(fname)
break
if audio_type.rfind(".scp") >= 0:
with open(fname, encoding="utf-8") as f:

View File

@ -5,7 +5,7 @@ from multiprocessing import Pool
import kaldiio
import numpy as np
import soundfile
import librosa
import torch.distributed as dist
import torchaudio
@ -46,7 +46,7 @@ def wav2num_frame(wav_path, frontend_conf):
try:
waveform, sampling_rate = torchaudio.load(wav_path)
except:
waveform, sampling_rate = soundfile.read(wav_path)
waveform, sampling_rate = librosa.load(wav_path)
waveform = np.expand_dims(waveform, axis=0)
n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]

View File

@ -12,7 +12,7 @@ import os
from typing import Any, Dict, List, Union
import numpy as np
import soundfile as sf
import librosa as sf
import torch
import torchaudio
import logging
@ -43,7 +43,7 @@ def sv_preprocess(inputs: Union[np.ndarray, list]):
for i in range(len(inputs)):
if isinstance(inputs[i], str):
file_bytes = File.read(inputs[i])
data, fs = sf.read(io.BytesIO(file_bytes), dtype='float32')
data, fs = sf.load(io.BytesIO(file_bytes), dtype='float32')
if len(data.shape) == 2:
data = data[:, 0]
data = torch.from_numpy(data).unsqueeze(0)

View File

@ -3,7 +3,7 @@ import codecs
import logging
import argparse
import numpy as np
import edit_distance
# import edit_distance
from itertools import zip_longest
@ -160,112 +160,112 @@ def time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocess
return res
class AverageShiftCalculator():
def __init__(self):
logging.warning("Calculating average shift.")
def __call__(self, file1, file2):
uttid_list1, ts_dict1 = self.read_timestamps(file1)
uttid_list2, ts_dict2 = self.read_timestamps(file2)
uttid_intersection = self._intersection(uttid_list1, uttid_list2)
res = self.as_cal(uttid_intersection, ts_dict1, ts_dict2)
logging.warning("Average shift of {} and {}: {}.".format(file1, file2, str(res)[:8]))
logging.warning("Following timestamp pair differs most: {}, detail:{}".format(self.max_shift, self.max_shift_uttid))
def _intersection(self, list1, list2):
set1 = set(list1)
set2 = set(list2)
if set1 == set2:
logging.warning("Uttid same checked.")
return set1
itsc = list(set1 & set2)
logging.warning("Uttid differs: file1 {}, file2 {}, lines same {}.".format(len(list1), len(list2), len(itsc)))
return itsc
def read_timestamps(self, file):
# read timestamps file in standard format
uttid_list = []
ts_dict = {}
with codecs.open(file, 'r') as fin:
for line in fin.readlines():
text = ''
ts_list = []
line = line.rstrip()
uttid = line.split()[0]
uttid_list.append(uttid)
body = " ".join(line.split()[1:])
for pd in body.split(';'):
if not len(pd): continue
# pdb.set_trace()
char, start, end = pd.lstrip(" ").split(' ')
text += char + ','
ts_list.append((float(start), float(end)))
# ts_lists.append(ts_list)
ts_dict[uttid] = (text[:-1], ts_list)
logging.warning("File {} read done.".format(file))
return uttid_list, ts_dict
def _shift(self, filtered_timestamp_list1, filtered_timestamp_list2):
shift_time = 0
for fts1, fts2 in zip(filtered_timestamp_list1, filtered_timestamp_list2):
shift_time += abs(fts1[0] - fts2[0]) + abs(fts1[1] - fts2[1])
num_tokens = len(filtered_timestamp_list1)
return shift_time, num_tokens
def as_cal(self, uttid_list, ts_dict1, ts_dict2):
# calculate average shift between timestamp1 and timestamp2
# when characters differ, use edit distance alignment
# and calculate the error between the same characters
self._accumlated_shift = 0
self._accumlated_tokens = 0
self.max_shift = 0
self.max_shift_uttid = None
for uttid in uttid_list:
(t1, ts1) = ts_dict1[uttid]
(t2, ts2) = ts_dict2[uttid]
_align, _align2, _align3 = [], [], []
fts1, fts2 = [], []
_t1, _t2 = [], []
sm = edit_distance.SequenceMatcher(t1.split(','), t2.split(','))
s = sm.get_opcodes()
for j in range(len(s)):
if s[j][0] == "replace" or s[j][0] == "insert":
_align.append(0)
if s[j][0] == "replace" or s[j][0] == "delete":
_align3.append(0)
elif s[j][0] == "equal":
_align.append(1)
_align3.append(1)
else:
continue
# use s to index t2
for a, ts , t in zip(_align, ts2, t2.split(',')):
if a:
fts2.append(ts)
_t2.append(t)
sm2 = edit_distance.SequenceMatcher(t2.split(','), t1.split(','))
s = sm2.get_opcodes()
for j in range(len(s)):
if s[j][0] == "replace" or s[j][0] == "insert":
_align2.append(0)
elif s[j][0] == "equal":
_align2.append(1)
else:
continue
# use s2 tp index t1
for a, ts, t in zip(_align3, ts1, t1.split(',')):
if a:
fts1.append(ts)
_t1.append(t)
if len(fts1) == len(fts2):
shift_time, num_tokens = self._shift(fts1, fts2)
self._accumlated_shift += shift_time
self._accumlated_tokens += num_tokens
if shift_time/num_tokens > self.max_shift:
self.max_shift = shift_time/num_tokens
self.max_shift_uttid = uttid
else:
logging.warning("length mismatch")
return self._accumlated_shift / self._accumlated_tokens
# class AverageShiftCalculator():
# def __init__(self):
# logging.warning("Calculating average shift.")
# def __call__(self, file1, file2):
# uttid_list1, ts_dict1 = self.read_timestamps(file1)
# uttid_list2, ts_dict2 = self.read_timestamps(file2)
# uttid_intersection = self._intersection(uttid_list1, uttid_list2)
# res = self.as_cal(uttid_intersection, ts_dict1, ts_dict2)
# logging.warning("Average shift of {} and {}: {}.".format(file1, file2, str(res)[:8]))
# logging.warning("Following timestamp pair differs most: {}, detail:{}".format(self.max_shift, self.max_shift_uttid))
#
# def _intersection(self, list1, list2):
# set1 = set(list1)
# set2 = set(list2)
# if set1 == set2:
# logging.warning("Uttid same checked.")
# return set1
# itsc = list(set1 & set2)
# logging.warning("Uttid differs: file1 {}, file2 {}, lines same {}.".format(len(list1), len(list2), len(itsc)))
# return itsc
#
# def read_timestamps(self, file):
# # read timestamps file in standard format
# uttid_list = []
# ts_dict = {}
# with codecs.open(file, 'r') as fin:
# for line in fin.readlines():
# text = ''
# ts_list = []
# line = line.rstrip()
# uttid = line.split()[0]
# uttid_list.append(uttid)
# body = " ".join(line.split()[1:])
# for pd in body.split(';'):
# if not len(pd): continue
# # pdb.set_trace()
# char, start, end = pd.lstrip(" ").split(' ')
# text += char + ','
# ts_list.append((float(start), float(end)))
# # ts_lists.append(ts_list)
# ts_dict[uttid] = (text[:-1], ts_list)
# logging.warning("File {} read done.".format(file))
# return uttid_list, ts_dict
#
# def _shift(self, filtered_timestamp_list1, filtered_timestamp_list2):
# shift_time = 0
# for fts1, fts2 in zip(filtered_timestamp_list1, filtered_timestamp_list2):
# shift_time += abs(fts1[0] - fts2[0]) + abs(fts1[1] - fts2[1])
# num_tokens = len(filtered_timestamp_list1)
# return shift_time, num_tokens
#
# # def as_cal(self, uttid_list, ts_dict1, ts_dict2):
# # # calculate average shift between timestamp1 and timestamp2
# # # when characters differ, use edit distance alignment
# # # and calculate the error between the same characters
# # self._accumlated_shift = 0
# # self._accumlated_tokens = 0
# # self.max_shift = 0
# # self.max_shift_uttid = None
# # for uttid in uttid_list:
# # (t1, ts1) = ts_dict1[uttid]
# # (t2, ts2) = ts_dict2[uttid]
# # _align, _align2, _align3 = [], [], []
# # fts1, fts2 = [], []
# # _t1, _t2 = [], []
# # sm = edit_distance.SequenceMatcher(t1.split(','), t2.split(','))
# # s = sm.get_opcodes()
# # for j in range(len(s)):
# # if s[j][0] == "replace" or s[j][0] == "insert":
# # _align.append(0)
# # if s[j][0] == "replace" or s[j][0] == "delete":
# # _align3.append(0)
# # elif s[j][0] == "equal":
# # _align.append(1)
# # _align3.append(1)
# # else:
# # continue
# # # use s to index t2
# # for a, ts , t in zip(_align, ts2, t2.split(',')):
# # if a:
# # fts2.append(ts)
# # _t2.append(t)
# # sm2 = edit_distance.SequenceMatcher(t2.split(','), t1.split(','))
# # s = sm2.get_opcodes()
# # for j in range(len(s)):
# # if s[j][0] == "replace" or s[j][0] == "insert":
# # _align2.append(0)
# # elif s[j][0] == "equal":
# # _align2.append(1)
# # else:
# # continue
# # # use s2 tp index t1
# # for a, ts, t in zip(_align3, ts1, t1.split(',')):
# # if a:
# # fts1.append(ts)
# # _t1.append(t)
# # if len(fts1) == len(fts2):
# # shift_time, num_tokens = self._shift(fts1, fts2)
# # self._accumlated_shift += shift_time
# # self._accumlated_tokens += num_tokens
# # if shift_time/num_tokens > self.max_shift:
# # self.max_shift = shift_time/num_tokens
# # self.max_shift_uttid = uttid
# # else:
# # logging.warning("length mismatch")
# # return self._accumlated_shift / self._accumlated_tokens
def convert_external_alphas(alphas_file, text_file, output_file):
@ -311,10 +311,10 @@ SUPPORTED_MODES = ['cal_aas', 'read_ext_alphas']
def main(args):
if args.mode == 'cal_aas':
asc = AverageShiftCalculator()
asc(args.input, args.input2)
elif args.mode == 'read_ext_alphas':
# if args.mode == 'cal_aas':
# asc = AverageShiftCalculator()
# asc(args.input, args.input2)
if args.mode == 'read_ext_alphas':
convert_external_alphas(args.input, args.input2, args.output)
else:
logging.error("Mode {} not in SUPPORTED_MODES: {}.".format(args.mode, SUPPORTED_MODES))

View File

@ -11,7 +11,7 @@ import librosa
import numpy as np
import torch
import torchaudio
import soundfile
import librosa
import torchaudio.compliance.kaldi as kaldi
@ -166,7 +166,7 @@ def compute_fbank(wav_file,
try:
waveform, audio_sr = torchaudio.load(wav_file)
except:
waveform, audio_sr = soundfile.read(wav_file, dtype='float32')
waveform, audio_sr = librosa.load(wav_file, dtype='float32')
if waveform.ndim == 2:
waveform = waveform[:, 0]
waveform = torch.tensor(np.expand_dims(waveform, axis=0))
@ -191,7 +191,7 @@ def wav2num_frame(wav_path, frontend_conf):
try:
waveform, sampling_rate = torchaudio.load(wav_path)
except:
waveform, sampling_rate = soundfile.read(wav_path)
waveform, sampling_rate = librosa.load(wav_path)
waveform = torch.tensor(np.expand_dims(waveform, axis=0))
speech_length = (waveform.shape[1] / sampling_rate) * 1000.
n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])

View File

@ -1,8 +1,11 @@
import os
from functools import lru_cache
from typing import Union
try:
import ffmpeg
except:
print("Please Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.")
import ffmpeg
import numpy as np
import torch
import torch.nn.functional as F

View File

@ -10,36 +10,36 @@ from setuptools import setup
requirements = {
"install": [
"setuptools>=38.5.1",
# "setuptools>=38.5.1",
"humanfriendly",
"scipy>=1.4.1",
"librosa",
"jamo", # For kss
# "jamo", # For kss
"PyYAML>=5.1.2",
"soundfile>=0.12.1",
"h5py>=3.1.0",
# "soundfile>=0.12.1",
# "h5py>=3.1.0",
"kaldiio>=2.17.0",
"torch_complex",
"nltk>=3.4.5",
# "torch_complex",
# "nltk>=3.4.5",
# ASR
"sentencepiece",
"sentencepiece", # train
"jieba",
"rotary_embedding_torch",
"ffmpeg",
# "rotary_embedding_torch",
# "ffmpeg-python",
# TTS
"pypinyin>=0.44.0",
"espnet_tts_frontend",
# "pypinyin>=0.44.0",
# "espnet_tts_frontend",
# ENH
"pytorch_wpe",
# "pytorch_wpe",
"editdistance>=0.5.2",
"tensorboard",
"g2p",
"nara_wpe",
# "g2p",
# "nara_wpe",
# PAI
"oss2",
"edit-distance",
"textgrid",
"protobuf",
# "edit-distance",
# "textgrid",
# "protobuf",
"tqdm",
"hdbscan",
"umap",
@ -104,7 +104,7 @@ setup(
name="funasr",
version=version,
url="https://github.com/alibaba-damo-academy/FunASR.git",
author="Speech Lab of DAMO Academy, Alibaba Group",
author="Speech Lab of Alibaba Group",
author_email="funasr@list.alibaba-inc.com",
description="FunASR: A Fundamental End-to-End Speech Recognition Toolkit",
long_description=open(os.path.join(dirname, "README.md"), encoding="utf-8").read(),