funasr v2 setup (#1106)

* funasr v2 setup
This commit is contained in:
zhifu gao 2023-11-22 00:36:35 +08:00 committed by GitHub
parent 99a6d81160
commit b57b98364f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
50 changed files with 417 additions and 1197 deletions

View File

@ -1,7 +1,10 @@
import argparse import argparse
import tqdm import tqdm
import codecs import codecs
import textgrid try:
import textgrid
except:
raise "Please install textgrid firstly: pip install textgrid"
import pdb import pdb
class Segment(object): class Segment(object):

View File

@ -6,7 +6,10 @@ import argparse
import codecs import codecs
from distutils.util import strtobool from distutils.util import strtobool
from pathlib import Path from pathlib import Path
import textgrid try:
import textgrid
except:
raise "Please install textgrid firstly: pip install textgrid"
import pdb import pdb
class Segment(object): class Segment(object):

View File

@ -6,7 +6,10 @@ import argparse
import codecs import codecs
from distutils.util import strtobool from distutils.util import strtobool
from pathlib import Path from pathlib import Path
import textgrid try:
import textgrid
except:
raise "Please install textgrid firstly: pip install textgrid"
import pdb import pdb
class Segment(object): class Segment(object):

View File

@ -6,7 +6,10 @@ import argparse
import codecs import codecs
from distutils.util import strtobool from distutils.util import strtobool
from pathlib import Path from pathlib import Path
import textgrid try:
import textgrid
except:
raise "Please install textgrid firstly: pip install textgrid"
import pdb import pdb
class Segment(object): class Segment(object):

View File

@ -6,7 +6,10 @@ import argparse
import codecs import codecs
from distutils.util import strtobool from distutils.util import strtobool
from pathlib import Path from pathlib import Path
import textgrid try:
import textgrid
except:
raise "Please install textgrid firstly: pip install textgrid"
import pdb import pdb
def get_args(): def get_args():

View File

@ -6,7 +6,12 @@ import argparse
import codecs import codecs
from distutils.util import strtobool from distutils.util import strtobool
from pathlib import Path from pathlib import Path
import textgrid
try:
import textgrid
except:
raise "Please install textgrid firstly: pip install textgrid"
import pdb import pdb
import numpy as np import numpy as np
import sys import sys

View File

@ -44,9 +44,9 @@ class Speech2Text:
"""Speech2Text class """Speech2Text class
Examples: Examples:
>>> import soundfile >>> import librosa
>>> speech2text = Speech2Text("asr_config.yml", "asr.pb") >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
>>> audio, rate = soundfile.read("speech.wav") >>> audio, rate = librosa.load("speech.wav")
>>> speech2text(audio) >>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...] [(text, token, token_int, hypothesis object), ...]
@ -251,9 +251,9 @@ class Speech2TextParaformer:
"""Speech2Text class """Speech2Text class
Examples: Examples:
>>> import soundfile >>> import librosa
>>> speech2text = Speech2TextParaformer("asr_config.yml", "asr.pb") >>> speech2text = Speech2TextParaformer("asr_config.yml", "asr.pb")
>>> audio, rate = soundfile.read("speech.wav") >>> audio, rate = librosa.load("speech.wav")
>>> speech2text(audio) >>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...] [(text, token, token_int, hypothesis object), ...]
@ -625,9 +625,9 @@ class Speech2TextParaformerOnline:
"""Speech2Text class """Speech2Text class
Examples: Examples:
>>> import soundfile >>> import librosa
>>> speech2text = Speech2TextParaformerOnline("asr_config.yml", "asr.pth") >>> speech2text = Speech2TextParaformerOnline("asr_config.yml", "asr.pth")
>>> audio, rate = soundfile.read("speech.wav") >>> audio, rate = librosa.load("speech.wav")
>>> speech2text(audio) >>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...] [(text, token, token_int, hypothesis object), ...]
@ -876,9 +876,9 @@ class Speech2TextUniASR:
"""Speech2Text class """Speech2Text class
Examples: Examples:
>>> import soundfile >>> import librosa
>>> speech2text = Speech2TextUniASR("asr_config.yml", "asr.pb") >>> speech2text = Speech2TextUniASR("asr_config.yml", "asr.pb")
>>> audio, rate = soundfile.read("speech.wav") >>> audio, rate = librosa.load("speech.wav")
>>> speech2text(audio) >>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...] [(text, token, token_int, hypothesis object), ...]
@ -1106,9 +1106,9 @@ class Speech2TextMFCCA:
"""Speech2Text class """Speech2Text class
Examples: Examples:
>>> import soundfile >>> import librosa
>>> speech2text = Speech2TextMFCCA("asr_config.yml", "asr.pb") >>> speech2text = Speech2TextMFCCA("asr_config.yml", "asr.pb")
>>> audio, rate = soundfile.read("speech.wav") >>> audio, rate = librosa.load("speech.wav")
>>> speech2text(audio) >>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...] [(text, token, token_int, hypothesis object), ...]
@ -1637,9 +1637,9 @@ class Speech2TextSAASR:
"""Speech2Text class """Speech2Text class
Examples: Examples:
>>> import soundfile >>> import librosa
>>> speech2text = Speech2TextSAASR("asr_config.yml", "asr.pb") >>> speech2text = Speech2TextSAASR("asr_config.yml", "asr.pb")
>>> audio, rate = soundfile.read("speech.wav") >>> audio, rate = librosa.load("speech.wav")
>>> speech2text(audio) >>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...] [(text, token, token_int, hypothesis object), ...]
@ -1885,9 +1885,9 @@ class Speech2TextWhisper:
"""Speech2Text class """Speech2Text class
Examples: Examples:
>>> import soundfile >>> import librosa
>>> speech2text = Speech2Text("asr_config.yml", "asr.pb") >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
>>> audio, rate = soundfile.read("speech.wav") >>> audio, rate = librosa.load("speech.wav")
>>> speech2text(audio) >>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...] [(text, token, token_int, hypothesis object), ...]

View File

@ -20,7 +20,8 @@ from typing import Union
import numpy as np import numpy as np
import torch import torch
import torchaudio import torchaudio
import soundfile # import librosa
import librosa
import yaml import yaml
from funasr.bin.asr_infer import Speech2Text from funasr.bin.asr_infer import Speech2Text
@ -1281,7 +1282,8 @@ def inference_paraformer_online(
try: try:
raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0] raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
except: except:
raw_inputs = soundfile.read(data_path_and_name_and_type[0], dtype='float32')[0] # raw_inputs = librosa.load(data_path_and_name_and_type[0], dtype='float32')[0]
raw_inputs, sr = librosa.load(data_path_and_name_and_type[0], dtype='float32')
if raw_inputs.ndim == 2: if raw_inputs.ndim == 2:
raw_inputs = raw_inputs[:, 0] raw_inputs = raw_inputs[:, 0]
raw_inputs = torch.tensor(raw_inputs) raw_inputs = torch.tensor(raw_inputs)

View File

@ -27,11 +27,11 @@ class Speech2DiarizationEEND:
"""Speech2Diarlization class """Speech2Diarlization class
Examples: Examples:
>>> import soundfile >>> import librosa
>>> import numpy as np >>> import numpy as np
>>> speech2diar = Speech2DiarizationEEND("diar_sond_config.yml", "diar_sond.pb") >>> speech2diar = Speech2DiarizationEEND("diar_sond_config.yml", "diar_sond.pb")
>>> profile = np.load("profiles.npy") >>> profile = np.load("profiles.npy")
>>> audio, rate = soundfile.read("speech.wav") >>> audio, rate = librosa.load("speech.wav")
>>> speech2diar(audio, profile) >>> speech2diar(audio, profile)
{"spk1": [(int, int), ...], ...} {"spk1": [(int, int), ...], ...}
@ -109,11 +109,11 @@ class Speech2DiarizationSOND:
"""Speech2Xvector class """Speech2Xvector class
Examples: Examples:
>>> import soundfile >>> import librosa
>>> import numpy as np >>> import numpy as np
>>> speech2diar = Speech2DiarizationSOND("diar_sond_config.yml", "diar_sond.pb") >>> speech2diar = Speech2DiarizationSOND("diar_sond_config.yml", "diar_sond.pb")
>>> profile = np.load("profiles.npy") >>> profile = np.load("profiles.npy")
>>> audio, rate = soundfile.read("speech.wav") >>> audio, rate = librosa.load("speech.wav")
>>> speech2diar(audio, profile) >>> speech2diar(audio, profile)
{"spk1": [(int, int), ...], ...} {"spk1": [(int, int), ...], ...}

View File

@ -15,7 +15,8 @@ from typing import Tuple
from typing import Union from typing import Union
import numpy as np import numpy as np
import soundfile # import librosa
import librosa
import torch import torch
from scipy.signal import medfilt from scipy.signal import medfilt
@ -144,7 +145,9 @@ def inference_sond(
# read waveform file # read waveform file
example = [load_bytes(x) if isinstance(x, bytes) else x example = [load_bytes(x) if isinstance(x, bytes) else x
for x in example] for x in example]
example = [soundfile.read(x)[0] if isinstance(x, str) else x # example = [librosa.load(x)[0] if isinstance(x, str) else x
# for x in example]
example = [librosa.load(x, dtype='float32')[0] if isinstance(x, str) else x
for x in example] for x in example]
# convert torch tensor to numpy array # convert torch tensor to numpy array
example = [x.numpy() if isinstance(example[0], torch.Tensor) else x example = [x.numpy() if isinstance(example[0], torch.Tensor) else x

View File

@ -20,9 +20,9 @@ class SpeechSeparator:
"""SpeechSeparator class """SpeechSeparator class
Examples: Examples:
>>> import soundfile >>> import librosa
>>> speech_separator = MossFormer("ss_config.yml", "ss.pt") >>> speech_separator = MossFormer("ss_config.yml", "ss.pt")
>>> audio, rate = soundfile.read("speech.wav") >>> audio, rate = librosa.load("speech.wav")
>>> separated_wavs = speech_separator(audio) >>> separated_wavs = speech_separator(audio)
""" """

View File

@ -13,7 +13,7 @@ from typing import Union
import numpy as np import numpy as np
import torch import torch
import soundfile as sf import librosa
from funasr.build_utils.build_streaming_iterator import build_streaming_iterator from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.torch_utils.set_all_random_seed import set_all_random_seed from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse from funasr.utils import config_argparse
@ -104,7 +104,12 @@ def inference_ss(
ss_results = speech_separator(**batch) ss_results = speech_separator(**batch)
for spk in range(num_spks): for spk in range(num_spks):
sf.write(os.path.join(output_path, keys[0] + '_s' + str(spk+1)+'.wav'), ss_results[spk], sample_rate) # sf.write(os.path.join(output_path, keys[0] + '_s' + str(spk+1)+'.wav'), ss_results[spk], sample_rate)
try:
librosa.output.write_wav(os.path.join(output_path, keys[0] + '_s' + str(spk+1)+'.wav'), ss_results[spk], sample_rate)
except:
print("To write wav by librosa, you should install librosa<=0.8.0")
raise
torch.cuda.empty_cache() torch.cuda.empty_cache()
return ss_results return ss_results

View File

@ -22,9 +22,9 @@ class Speech2Xvector:
"""Speech2Xvector class """Speech2Xvector class
Examples: Examples:
>>> import soundfile >>> import librosa
>>> speech2xvector = Speech2Xvector("sv_config.yml", "sv.pb") >>> speech2xvector = Speech2Xvector("sv_config.yml", "sv.pb")
>>> audio, rate = soundfile.read("speech.wav") >>> audio, rate = librosa.load("speech.wav")
>>> speech2xvector(audio) >>> speech2xvector(audio)
[(text, token, token_int, hypothesis object), ...] [(text, token, token_int, hypothesis object), ...]

View File

@ -23,9 +23,9 @@ class Speech2VadSegment:
"""Speech2VadSegment class """Speech2VadSegment class
Examples: Examples:
>>> import soundfile >>> import librosa
>>> speech2segment = Speech2VadSegment("vad_config.yml", "vad.pt") >>> speech2segment = Speech2VadSegment("vad_config.yml", "vad.pt")
>>> audio, rate = soundfile.read("speech.wav") >>> audio, rate = librosa.load("speech.wav")
>>> speech2segment(audio) >>> speech2segment(audio)
[[10, 230], [245, 450], ...] [[10, 230], [245, 450], ...]
@ -118,9 +118,9 @@ class Speech2VadSegmentOnline(Speech2VadSegment):
"""Speech2VadSegmentOnline class """Speech2VadSegmentOnline class
Examples: Examples:
>>> import soundfile >>> import librosa
>>> speech2segment = Speech2VadSegmentOnline("vad_config.yml", "vad.pt") >>> speech2segment = Speech2VadSegmentOnline("vad_config.yml", "vad.pt")
>>> audio, rate = soundfile.read("speech.wav") >>> audio, rate = librosa.load("speech.wav")
>>> speech2segment(audio) >>> speech2segment(audio)
[[10, 230], [245, 450], ...] [[10, 230], [245, 450], ...]

View File

@ -246,14 +246,11 @@ class Trainer:
for iepoch in range(start_epoch, trainer_options.max_epoch + 1): for iepoch in range(start_epoch, trainer_options.max_epoch + 1):
if iepoch != start_epoch: if iepoch != start_epoch:
logging.info( logging.info(
"{}/{}epoch started. Estimated time to finish: {}".format( "{}/{}epoch started. Estimated time to finish: {} hours".format(
iepoch, iepoch,
trainer_options.max_epoch, trainer_options.max_epoch,
humanfriendly.format_timespan( (time.perf_counter() - start_time) / 3600.0 / (iepoch - start_epoch) * (
(time.perf_counter() - start_time) trainer_options.max_epoch - iepoch + 1),
/ (iepoch - start_epoch)
* (trainer_options.max_epoch - iepoch + 1)
),
) )
) )
else: else:

View File

@ -0,0 +1,60 @@
import torch
class BatchSampler(torch.utils.data.BatchSampler):
    """Batch sampler that sorts samples by length inside fixed-size windows.

    The dataset is walked in consecutive windows of ``sort_size`` indices.
    Inside each window the indices are ordered by sample length and packed
    greedily into batches whose padded size
    ``(number of samples) * (longest sample)`` stays within ``batch_size``.

    Args:
        dataset: Object supporting ``len()``. Unless ``batch_size_type`` is
            ``"example"`` it must also expose ``shuffle_idx`` and
            ``indexed_dataset`` (entries with ``"source_len"`` and
            ``"target_len"``).
        args: Namespace providing ``batch_size_type``, ``batch_size``,
            ``sort_size`` and ``max_length_token``.
        drop_last: When true (the default) the final, not-yet-full batch is
            discarded; when false it is yielded as well.
    """

    def __init__(self, dataset=None, args=None, drop_last=True, ):
        # NOTE: the parent constructor is deliberately not called (as in the
        # original code); every attribute used here is set explicitly.
        self.drop_last = drop_last
        self.pre_idx = -1
        self.dataset = dataset
        self.batch_size_type = args.batch_size_type
        self.batch_size = args.batch_size
        self.sort_size = args.sort_size
        self.max_length_token = args.max_length_token
        self.total_samples = len(dataset)

    def __len__(self):
        """Return the number of samples (not the number of batches)."""
        return self.total_samples

    def _sample_length(self, idx):
        """Return the length used for sorting and budgeting sample ``idx``."""
        if self.batch_size_type == "example":
            return 1
        entry = self.dataset.indexed_dataset[self.dataset.shuffle_idx[idx]]
        return entry["source_len"] + entry["target_len"]

    def __iter__(self):
        """Yield lists of dataset indices, one list per batch."""
        batch = []
        max_token = 0
        num_sample = 0

        num_windows = (self.total_samples - 1) // self.sort_size + 1
        for window in range(self.pre_idx + 1, num_windows):
            start = window * self.sort_size
            stop = min(start + self.sort_size, self.total_samples)
            by_length = sorted(
                ([idx, self._sample_length(idx)] for idx in range(start, stop)),
                key=lambda pair: pair[1],
            )

            for idx, sample_len_cur in by_length:
                # Samples longer than the hard limit are skipped entirely.
                if sample_len_cur > self.max_length_token:
                    continue

                max_token_cur = max(max_token, sample_len_cur)
                max_token_padding = (1 + num_sample) * max_token_cur
                if max_token_padding <= self.batch_size:
                    batch.append(idx)
                    max_token = max_token_cur
                    num_sample += 1
                else:
                    # Never hand out an empty batch (happens when the very
                    # first sample alone exceeds the budget).
                    if batch:
                        yield batch
                    max_token = sample_len_cur
                    num_sample = 1
                    batch = [idx]

        # Honour ``drop_last``: previously the trailing batch was always lost.
        if batch and not self.drop_last:
            yield batch

View File

@ -16,8 +16,10 @@ from typing import Dict
from typing import Mapping from typing import Mapping
from typing import Tuple from typing import Tuple
from typing import Union from typing import Union
try:
import h5py import h5py
except:
print("If you want use h5py dataset, please pip install h5py, and try it again")
import humanfriendly import humanfriendly
import kaldiio import kaldiio
import numpy as np import numpy as np

View File

@ -0,0 +1,43 @@
import torch
import json
import torch.distributed as dist
class AudioDatasetJsonl(torch.utils.data.Dataset):
    """In-memory dataset backed by a JSON-lines file, sharded per rank.

    Every line is parsed as a JSON object. Objects with a ``"text"`` key
    contribute that text; objects with a ``"source"`` key contribute a dict
    with ``source``, ``prompt``, ``target``, ``source_len`` and
    ``target_len``. The resulting list is split into equally sized
    contiguous shards (any remainder is dropped) and only this rank's shard
    is kept.

    Args:
        path: Path of the ``.jsonl`` file (read as UTF-8).
        data_parallel_rank: Rank to use when ``torch.distributed`` is not
            initialised.
        data_parallel_size: World size to use when ``torch.distributed`` is
            not initialised.
    """

    def __init__(self, path, data_parallel_rank=0, data_parallel_size=1):
        super().__init__()
        # An initialised process group stays authoritative (as before);
        # otherwise fall back to the explicit arguments instead of failing.
        if dist.is_available() and dist.is_initialized():
            data_parallel_size = dist.get_world_size()
            data_parallel_rank = dist.get_rank()

        contents = []
        with open(path, encoding='utf-8') as fin:
            for line in fin:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                data = json.loads(line)
                if "text" in data:  # for sft
                    # Collected locally: ``self.contents`` does not exist yet.
                    contents.append(data['text'])
                if "source" in data:  # for speech lab pretrain
                    contents.append({"source": data["source"],
                                     "prompt": data["prompt"],
                                     "target": data["target"],
                                     "source_len": data["source_len"],
                                     "target_len": data["target_len"],
                                     })

        num_per_rank = len(contents) // data_parallel_size
        start = data_parallel_rank * num_per_rank
        self.contents = contents[start:start + num_per_rank]

    def __len__(self):
        """Return the number of items in this rank's shard."""
        return len(self.contents)

    def __getitem__(self, index):
        """Return the item at ``index`` of this rank's shard."""
        return self.contents[index]

View File

@ -14,7 +14,8 @@ import kaldiio
import numpy as np import numpy as np
import torch import torch
import torchaudio import torchaudio
import soundfile # import librosa
import librosa
from torch.utils.data.dataset import IterableDataset from torch.utils.data.dataset import IterableDataset
import os.path import os.path
@ -70,7 +71,8 @@ def load_wav(input):
try: try:
return torchaudio.load(input)[0].numpy() return torchaudio.load(input)[0].numpy()
except: except:
waveform, _ = soundfile.read(input, dtype='float32') # waveform, _ = librosa.load(input, dtype='float32')
waveform, _ = librosa.load(input, dtype='float32')
if waveform.ndim == 2: if waveform.ndim == 2:
waveform = waveform[:, 0] waveform = waveform[:, 0]
return np.expand_dims(waveform, axis=0) return np.expand_dims(waveform, axis=0)

View File

@ -7,7 +7,8 @@ import torch
import torch.distributed as dist import torch.distributed as dist
import torchaudio import torchaudio
import numpy as np import numpy as np
import soundfile # import librosa
import librosa
from kaldiio import ReadHelper from kaldiio import ReadHelper
from torch.utils.data import IterableDataset from torch.utils.data import IterableDataset
@ -128,7 +129,8 @@ class AudioDataset(IterableDataset):
try: try:
waveform, sampling_rate = torchaudio.load(path) waveform, sampling_rate = torchaudio.load(path)
except: except:
waveform, sampling_rate = soundfile.read(path, dtype='float32') # waveform, sampling_rate = librosa.load(path, dtype='float32')
waveform, sampling_rate = librosa.load(path, dtype='float32')
if waveform.ndim == 2: if waveform.ndim == 2:
waveform = waveform[:, 0] waveform = waveform[:, 0]
waveform = np.expand_dims(waveform, axis=0) waveform = np.expand_dims(waveform, axis=0)

View File

@ -10,7 +10,7 @@ from typing import Union
import numpy as np import numpy as np
import scipy.signal import scipy.signal
import soundfile import librosa
import jieba import jieba
from funasr.text.build_tokenizer import build_tokenizer from funasr.text.build_tokenizer import build_tokenizer
@ -284,7 +284,7 @@ class CommonPreprocessor(AbsPreprocessor):
if self.rirs is not None and self.rir_apply_prob >= np.random.random(): if self.rirs is not None and self.rir_apply_prob >= np.random.random():
rir_path = np.random.choice(self.rirs) rir_path = np.random.choice(self.rirs)
if rir_path is not None: if rir_path is not None:
rir, _ = soundfile.read( rir, _ = librosa.load(
rir_path, dtype=np.float64, always_2d=True rir_path, dtype=np.float64, always_2d=True
) )
@ -310,28 +310,31 @@ class CommonPreprocessor(AbsPreprocessor):
noise_db = np.random.uniform( noise_db = np.random.uniform(
self.noise_db_low, self.noise_db_high self.noise_db_low, self.noise_db_high
) )
with soundfile.SoundFile(noise_path) as f:
if f.frames == nsamples: audio_data = librosa.load(noise_path, dtype='float32')[0][None, :]
noise = f.read(dtype=np.float64, always_2d=True) frames = len(audio_data[0])
elif f.frames < nsamples: if frames == nsamples:
offset = np.random.randint(0, nsamples - f.frames) noise = audio_data
# noise: (Time, Nmic) elif frames < nsamples:
noise = f.read(dtype=np.float64, always_2d=True) offset = np.random.randint(0, nsamples - frames)
# Repeat noise # noise: (Time, Nmic)
noise = np.pad( noise = audio_data
noise, # Repeat noise
[(offset, nsamples - f.frames - offset), (0, 0)], noise = np.pad(
mode="wrap", noise,
) [(offset, nsamples - frames - offset), (0, 0)],
else: mode="wrap",
offset = np.random.randint(0, f.frames - nsamples) )
f.seek(offset) else:
# noise: (Time, Nmic) noise = audio_data[:, nsamples]
noise = f.read( # offset = np.random.randint(0, frames - nsamples)
nsamples, dtype=np.float64, always_2d=True # f.seek(offset)
) # noise: (Time, Nmic)
if len(noise) != nsamples: # noise = f.read(
raise RuntimeError(f"Something wrong: {noise_path}") # nsamples, dtype=np.float64, always_2d=True
# )
# if len(noise) != nsamples:
# raise RuntimeError(f"Something wrong: {noise_path}")
# noise: (Nmic, Time) # noise: (Nmic, Time)
noise = noise.T noise = noise.T

View File

@ -9,7 +9,7 @@ from typing import Union
import numpy as np import numpy as np
import scipy.signal import scipy.signal
import soundfile import librosa
from funasr.text.build_tokenizer import build_tokenizer from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.cleaner import TextCleaner from funasr.text.cleaner import TextCleaner
@ -275,7 +275,7 @@ class CommonPreprocessor(AbsPreprocessor):
if self.rirs is not None and self.rir_apply_prob >= np.random.random(): if self.rirs is not None and self.rir_apply_prob >= np.random.random():
rir_path = np.random.choice(self.rirs) rir_path = np.random.choice(self.rirs)
if rir_path is not None: if rir_path is not None:
rir, _ = soundfile.read( rir, _ = librosa.load(
rir_path, dtype=np.float64, always_2d=True rir_path, dtype=np.float64, always_2d=True
) )
@ -301,28 +301,30 @@ class CommonPreprocessor(AbsPreprocessor):
noise_db = np.random.uniform( noise_db = np.random.uniform(
self.noise_db_low, self.noise_db_high self.noise_db_low, self.noise_db_high
) )
with soundfile.SoundFile(noise_path) as f: audio_data = librosa.load(noise_path, dtype='float32')[0][None, :]
if f.frames == nsamples: frames = len(audio_data[0])
noise = f.read(dtype=np.float64, always_2d=True) if frames == nsamples:
elif f.frames < nsamples: noise = audio_data
offset = np.random.randint(0, nsamples - f.frames) elif frames < nsamples:
# noise: (Time, Nmic) offset = np.random.randint(0, nsamples - frames)
noise = f.read(dtype=np.float64, always_2d=True) # noise: (Time, Nmic)
# Repeat noise noise = audio_data
noise = np.pad( # Repeat noise
noise, noise = np.pad(
[(offset, nsamples - f.frames - offset), (0, 0)], noise,
mode="wrap", [(offset, nsamples - frames - offset), (0, 0)],
) mode="wrap",
else: )
offset = np.random.randint(0, f.frames - nsamples) else:
f.seek(offset) noise = audio_data[:, nsamples]
# noise: (Time, Nmic) # offset = np.random.randint(0, frames - nsamples)
noise = f.read( # f.seek(offset)
nsamples, dtype=np.float64, always_2d=True # noise: (Time, Nmic)
) # noise = f.read(
if len(noise) != nsamples: # nsamples, dtype=np.float64, always_2d=True
raise RuntimeError(f"Something wrong: {noise_path}") # )
# if len(noise) != nsamples:
# raise RuntimeError(f"Something wrong: {noise_path}")
# noise: (Nmic, Time) # noise: (Nmic, Time)
noise = noise.T noise = noise.T

View File

@ -1,151 +0,0 @@
import json
from typing import Union, Dict
from pathlib import Path
import os
import logging
import torch
from funasr.export.models import get_model
import numpy as np
import random
from funasr.utils.types import str2bool, str2triple_str
# torch_version = float(".".join(torch.__version__.split(".")[:2]))
# assert torch_version > 1.9
class ModelExport:
    """Helper that builds a conformer ASR model and exports it (ONNX).

    Args:
        cache_dir: Directory the exported model is written to.
        onnx: Export through the ONNX path when true.
        device: Stored on the instance; not used by the methods shown here.
        quant: Stored on the instance; not used by the methods shown here.
        fallback_num: Stored on the instance; not used by the methods here.
        audio_in: A wav path or a ``.scp`` list; stored on the instance.
        calib_num: Maximum number of ``.scp`` lines read by
            ``parse_audio_in``.
        model_revision: Stored on the instance; not used by the methods here.
    """

    def __init__(
        self,
        cache_dir: Union[Path, str] = None,
        onnx: bool = True,
        device: str = "cpu",
        quant: bool = True,
        fallback_num: int = 0,
        audio_in: str = None,
        calib_num: int = 200,
        model_revision: str = None,
    ):
        self.set_all_random_seed(0)

        self.cache_dir = cache_dir
        self.export_config = dict(
            feats_dim=560,
            onnx=False,
        )

        self.onnx = onnx
        self.device = device
        self.quant = quant
        self.fallback_num = fallback_num
        self.frontend = None
        self.audio_in = audio_in
        self.calib_num = calib_num
        self.model_revision = model_revision

    def _export(
        self,
        model,
        model_dir: str = None,
        verbose: bool = False,
    ):
        """Wrap ``model`` with ``get_model`` and export it into ``model_dir``."""
        export_dir = model_dir
        os.makedirs(export_dir, exist_ok=True)

        self.export_config["model_name"] = "model"
        model = get_model(
            model,
            self.export_config,
        )
        model.eval()

        if self.onnx:
            self._export_onnx(model, verbose, export_dir)

        print("output dir: {}".format(export_dir))

    def _export_onnx(self, model, verbose, path):
        """Delegate the ONNX export to the wrapped model."""
        model._export_onnx(verbose, path)

    def set_all_random_seed(self, seed: int):
        """Seed Python's, NumPy's and PyTorch's random generators."""
        random.seed(seed)
        np.random.seed(seed)
        torch.random.manual_seed(seed)

    def parse_audio_in(self, audio_in):
        """Return ``(wav_list, name_list)`` for a wav path or ``.scp`` file.

        A ``.scp`` file holds ``name path`` pairs, one per line; at most
        ``self.calib_num`` lines are read. Any other value is treated as a
        single wav path named ``"test"``.
        """
        wav_list, name_list = [], []
        if audio_in.endswith(".scp"):
            # Context manager so the handle is closed (it used to leak).
            with open(audio_in, 'r') as f:
                lines = f.readlines()[:self.calib_num]
            for line in lines:
                name, path = line.strip().split()
                name_list.append(name)
                wav_list.append(path)
        else:
            wav_list = [audio_in, ]
            name_list = ["test", ]
        return wav_list, name_list

    def load_feats(self, audio_in: str = None):
        """Load audio, resample to the frontend rate and compute features.

        Returns:
            ``(feats, feats_len)``: per-file outputs of ``self.frontend``.
        """
        import torchaudio

        wav_list, name_list = self.parse_audio_in(audio_in)
        feats = []
        feats_len = []
        for line in wav_list:
            path = line.strip()
            waveform, sampling_rate = torchaudio.load(path)
            if sampling_rate != self.frontend.fs:
                waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
                                                          new_freq=self.frontend.fs)(waveform)
            fbank, fbank_len = self.frontend(waveform, [waveform.size(1)])
            feats.append(fbank)
            feats_len.append(fbank_len)
        return feats, feats_len

    def export(self,
               mode: str = None,
               model_dir: Union[Path, str] = None,
               ):
        """Build the model found in ``model_dir`` and export it.

        Args:
            mode: Model type; only values starting with ``"conformer"`` are
                supported.
            model_dir: Directory holding ``config.yaml``, ``model.pb`` and
                ``am.mvn``. Defaults to ``self.cache_dir``.

        Raises:
            ValueError: If ``mode`` is missing or not a conformer mode.
        """
        # Fail clearly instead of hitting an unbound ``model`` further down.
        if mode is None or not mode.startswith('conformer'):
            raise ValueError("Unsupported export mode: {!r}".format(mode))

        if model_dir is None:
            # The original body read an undefined ``model_dir`` (NameError).
            # NOTE(review): assuming the model files live in ``cache_dir`` -
            # confirm against the intended directory layout.
            model_dir = self.cache_dir

        from funasr.tasks.asr import ASRTask
        config = os.path.join(model_dir, 'config.yaml')
        model_file = os.path.join(model_dir, 'model.pb')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        model, asr_train_args = ASRTask.build_model_from_file(
            config, model_file, cmvn_file, 'cpu'
        )
        self.frontend = model.frontend
        self.export_config["feats_dim"] = 560

        self._export(model, self.cache_dir)
if __name__ == '__main__':
    # Command-line entry point: export every model given via --model-name.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument('--model-name', type=str, action="append", required=True, default=[])
    cli.add_argument('--export-dir', type=str, required=True)
    cli.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
    cli.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
    cli.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
    cli.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
    cli.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
    cli.add_argument('--calib_num', type=int, default=200, help='calib max num')
    cli.add_argument('--model_revision', type=str, default=None, help='model_revision')
    opts = cli.parse_args()

    use_onnx = opts.type == 'onnx'
    exporter = ModelExport(
        cache_dir=opts.export_dir,
        onnx=use_onnx,
        device=opts.device,
        quant=opts.quantize,
        fallback_num=opts.fallback_num,
        audio_in=opts.audio_in,
        calib_num=opts.calib_num,
        model_revision=opts.model_revision,
    )
    # --model-name may be repeated; each value is exported in turn.
    for name in opts.model_name:
        print("export model: {}".format(name))
        exporter.export(name)

View File

@ -1,403 +0,0 @@
"""Positional Encoding Module."""
import math
import torch
import torch.nn as nn
from funasr.modules.embedding import (
LegacyRelPositionalEncoding, PositionalEncoding, RelPositionalEncoding,
ScaledPositionalEncoding, StreamPositionalEncoding)
from funasr.modules.subsampling import (
Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6,
Conv2dSubsampling8)
from funasr.modules.subsampling_without_posenc import \
Conv2dSubsamplingWOPosEnc
from funasr.export.models.language_models.subsampling import (
OnnxConv2dSubsampling, OnnxConv2dSubsampling2, OnnxConv2dSubsampling6,
OnnxConv2dSubsampling8)
def get_pos_emb(pos_emb, max_seq_len=512, use_cache=True):
    """Return the ONNX-export counterpart of a positional-encoding module.

    Args:
        pos_emb: Positional-encoding module taken from the source model.
        max_seq_len (int): Maximum sequence length handed to the wrapper.
        use_cache (bool): Whether the wrapper should use a precomputed table.

    Returns:
        The matching ``Onnx*PositionalEncoding`` wrapper, or ``pos_emb``
        itself when it is an empty ``nn.Sequential`` or a
        ``Conv2dSubsamplingWOPosEnc``.

    Raises:
        ValueError: If the module type is not supported.
    """
    # The order of the checks is kept exactly as it was; it may matter if
    # some of these classes derive from one another.
    if isinstance(pos_emb, LegacyRelPositionalEncoding):
        return OnnxLegacyRelPositionalEncoding(pos_emb, max_seq_len, use_cache=use_cache)
    if isinstance(pos_emb, ScaledPositionalEncoding):
        return OnnxScaledPositionalEncoding(pos_emb, max_seq_len, use_cache=use_cache)
    if isinstance(pos_emb, RelPositionalEncoding):
        return OnnxRelPositionalEncoding(pos_emb, max_seq_len, use_cache)
    if isinstance(pos_emb, PositionalEncoding):
        # Must be a keyword: the third positional parameter of
        # OnnxPositionalEncoding is ``reverse``, not ``use_cache``.
        return OnnxPositionalEncoding(pos_emb, max_seq_len, use_cache=use_cache)
    if isinstance(pos_emb, StreamPositionalEncoding):
        return OnnxStreamPositionalEncoding(pos_emb, max_seq_len, use_cache)
    if (isinstance(pos_emb, nn.Sequential) and len(pos_emb) == 0) or (
        isinstance(pos_emb, Conv2dSubsamplingWOPosEnc)
    ):
        # Nothing to wrap.
        return pos_emb
    raise ValueError("Embedding model is not supported.")
class Embedding(nn.Module):
    """Export wrapper for an input embedding / subsampling front-end.

    A plain ``nn.Embedding`` is used unchanged. A ``Conv2dSubsampling*``
    module is replaced by its ``OnnxConv2dSubsampling*`` counterpart and the
    last entry of its ``out`` container is swapped for the export-friendly
    positional encoding. For any other model its last element is swapped
    in place.

    Args:
        model: Embedding module of the source model.
        max_seq_len (int): Maximum sequence length for the positional encoding.
        use_cache (bool): Forwarded to ``get_pos_emb`` (it used to be
            accepted but silently ignored).
    """

    def __init__(self, model, max_seq_len=512, use_cache=True):
        super().__init__()
        self.model = model
        if isinstance(model, nn.Embedding):
            return

        # (source class, ONNX wrapper) pairs, checked in the original order.
        wrappers = (
            (Conv2dSubsampling, OnnxConv2dSubsampling),
            (Conv2dSubsampling2, OnnxConv2dSubsampling2),
            (Conv2dSubsampling6, OnnxConv2dSubsampling6),
            (Conv2dSubsampling8, OnnxConv2dSubsampling8),
        )
        for source_cls, onnx_cls in wrappers:
            if isinstance(model, source_cls):
                self.model = onnx_cls(model)
                self.model.out[-1] = get_pos_emb(model.out[-1], max_seq_len, use_cache)
                return

        self.model[-1] = get_pos_emb(model[-1], max_seq_len, use_cache)

    def forward(self, x, mask=None):
        """Run the wrapped model, passing ``mask`` only when it is given."""
        if mask is None:
            return self.model(x)
        return self.model(x, mask)
def _pre_hook(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
"""Perform pre-hook in load_state_dict for backward compatibility.
Note:
We saved self.pe until v.0.5.2 but we have omitted it later.
Therefore, we remove the item "pe" from `state_dict` for backward compatibility.
"""
k = prefix + "pe"
if k in state_dict:
state_dict.pop(k)
class OnnxPositionalEncoding(torch.nn.Module):
    """Export wrapper around an existing positional-encoding module.

    Args:
        model: Source positional-encoding module; its ``d_model`` and ``pe``
            attributes are read and its ``extend_pe`` may be called.
        max_seq_len (int): Maximum input length.
        reverse (bool): Whether to reverse the input position. Only for
            the class LegacyRelPositionalEncoding.
        use_cache (bool): Use the precomputed ``pe`` table when true;
            otherwise compute the sinusoids on every call.
    """

    def __init__(self, model, max_seq_len=512, reverse=False, use_cache=True):
        """Construct an OnnxPositionalEncoding object."""
        super().__init__()
        self.d_model = model.d_model
        self.reverse = reverse
        self.max_seq_len = max_seq_len
        self.xscale = math.sqrt(self.d_model)
        self._register_load_state_dict_pre_hook(_pre_hook)
        self.pe = model.pe
        self.use_cache = use_cache
        self.model = model
        if not self.use_cache:
            # Frequencies for on-the-fly sinusoid computation.
            self.div_term = torch.exp(
                torch.arange(0, self.d_model, 2, dtype=torch.float32)
                * -(math.log(10000.0) / self.d_model)
            )
        else:
            self.extend_pe()

    def extend_pe(self):
        """Fit the cached table to ``max_seq_len`` positions."""
        cached_len = len(self.pe[0])
        if self.max_seq_len < cached_len:
            # Table is longer than needed: keep the leading positions only.
            self.pe = self.pe[:, : self.max_seq_len]
            return
        # Otherwise ask the wrapped module to extend its table and reuse it.
        self.model.extend_pe(torch.tensor(0.0).expand(1, self.max_seq_len))
        self.pe = self.model.pe

    def _add_pe(self, x):
        """Scale ``x`` and add sinusoids computed on the fly."""
        steps = x.size(1)
        if self.reverse:
            position = torch.arange(steps - 1, -1, -1.0, dtype=torch.float32)
        else:
            position = torch.arange(0, steps, dtype=torch.float32)
        angles = position.unsqueeze(1) * self.div_term

        x = x * self.xscale
        x[:, :, 0::2] += torch.sin(angles)
        x[:, :, 1::2] += torch.cos(angles)
        return x

    def forward(self, x: torch.Tensor):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        if not self.use_cache:
            return self._add_pe(x)
        return x * self.xscale + self.pe[:, : x.size(1)]
class OnnxScaledPositionalEncoding(OnnxPositionalEncoding):
    """Scaled positional encoding module for export.

    See Sec. 3.2 https://arxiv.org/abs/1809.08895

    The encoding is added as ``x + alpha * pe``.

    Args:
        model: Source positional-encoding module.
        max_seq_len (int): Maximum input length.
        use_cache (bool): Use the precomputed table when true.
    """

    def __init__(self, model, max_seq_len=512, use_cache=True):
        """Initialize class."""
        super().__init__(model, max_seq_len, use_cache=use_cache)
        # NOTE(review): alpha starts at 1.0 and is not copied from ``model``;
        # confirm that a trained alpha is restored elsewhere.
        self.alpha = torch.nn.Parameter(torch.tensor(1.0))

    def reset_parameters(self):
        """Reset parameters."""
        self.alpha.data = torch.tensor(1.0)

    def _add_pe(self, x):
        """Add ``alpha``-scaled sinusoids computed on the fly.

        Matches the cached branch of ``forward`` (``x + alpha * pe``). The
        previous version scaled ``x`` instead of the encoding, which only
        agrees with the cached branch while ``alpha == 1``.
        """
        steps = x.size(1)
        if self.reverse:
            position = torch.arange(steps - 1, -1, -1.0, dtype=torch.float32)
        else:
            position = torch.arange(0, steps, dtype=torch.float32)
        angles = position.unsqueeze(1) * self.div_term

        # Work on a copy so the caller's tensor is left untouched.
        encoded = x.clone()
        encoded[:, :, 0::2] += self.alpha * torch.sin(angles)
        encoded[:, :, 1::2] += self.alpha * torch.cos(angles)
        return encoded

    def forward(self, x):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        if self.use_cache:
            return x + self.alpha * self.pe[:, : x.size(1)]
        return self._add_pe(x)
class OnnxLegacyRelPositionalEncoding(OnnxPositionalEncoding):
    """Relative positional encoding module (old version).

    Details can be found in https://github.com/espnet/espnet/pull/2816.
    See : Appendix B in https://arxiv.org/abs/1901.02860

    Args:
        model: Source positional-encoding module passed to the parent wrapper.
        max_seq_len (int): Maximum input length.
        use_cache (bool): Use the precomputed table if True, otherwise
            compute the sinusoids on every call.
    """

    def __init__(self, model, max_seq_len=512, use_cache=True):
        """Initialize class (always with ``reverse=True``)."""
        super().__init__(model, max_seq_len, reverse=True, use_cache=use_cache)

    def _get_pe(self, x):
        """Build the positional embedding on the fly.

        The result has the same shape as ``x`` (it is allocated with
        ``torch.zeros(x.shape)``).
        """
        steps = x.size(1)
        if self.reverse:
            position = torch.arange(steps - 1, -1, -1.0, dtype=torch.float32)
        else:
            position = torch.arange(0, steps, dtype=torch.float32)
        angles = position.unsqueeze(1) * self.div_term
        pe = torch.zeros(x.shape)
        pe[:, :, 0::2] += torch.sin(angles)
        pe[:, :, 1::2] += torch.cos(angles)
        return pe

    def forward(self, x):
        """Compute positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Input scaled by ``self.xscale`` (batch, time, `*`).
            torch.Tensor: Positional embedding; a slice of the cached table
                when caching is enabled, otherwise the output of ``_get_pe``.
        """
        x = x * self.xscale
        pos_emb = self.pe[:, : x.size(1)] if self.use_cache else self._get_pe(x)
        return x, pos_emb
class OnnxRelPositionalEncoding(torch.nn.Module):
    """Relative positional encoding module (new implementation).

    Details can be found in https://github.com/espnet/espnet/pull/2816.
    See : Appendix B in https://arxiv.org/abs/1901.02860

    Args:
        model: Source module; only its ``d_model`` attribute is read.
        max_seq_len (int): Maximum input length used to size the cached table.
        use_cache (bool): Precompute the table if True, otherwise compute
            the embedding on every call.
    """

    def __init__(self, model, max_seq_len=512, use_cache=True):
        """Construct an PositionalEncoding object."""
        super(OnnxRelPositionalEncoding, self).__init__()
        self.d_model = model.d_model
        self.xscale = math.sqrt(self.d_model)
        self.use_cache = use_cache
        self.pe = None
        if use_cache:
            # A dummy tensor with ``max_seq_len`` columns sizes the table.
            self.extend_pe(torch.tensor(0.0).expand(1, max_seq_len))
            return
        exponents = torch.arange(0, self.d_model, 2, dtype=torch.float32)
        self.div_term = torch.exp(exponents * -(math.log(10000.0) / self.d_model))

    def extend_pe(self, x):
        """Reset the positional encodings.

        Does nothing but a dtype/device alignment when the existing table
        already holds at least ``2 * x.size(1) - 1`` positions.
        """
        length = x.size(1)
        if self.pe is not None and self.pe.size(1) >= length * 2 - 1:
            # self.pe contains both positive and negative parts;
            # the length of self.pe is 2 * input_len - 1.
            if self.pe.dtype != x.dtype or self.pe.device != x.device:
                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
            return
        # Suppose `i` is the position of the query vector and `j` the
        # position of the key vector. Positive relative positions are used
        # when keys are to the left (i>j), negative ones otherwise (i<j).
        position = torch.arange(0, length, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        angles = position * div_term
        pe_positive = torch.zeros(length, self.d_model)
        pe_negative = torch.zeros(length, self.d_model)
        pe_positive[:, 0::2] = torch.sin(angles)
        pe_positive[:, 1::2] = torch.cos(angles)
        pe_negative[:, 0::2] = torch.sin(-1 * angles)
        pe_negative[:, 1::2] = torch.cos(-1 * angles)
        # Reverse the order of positive indices and concatenate positive and
        # negative parts. This supports the shifting trick described in
        # https://arxiv.org/abs/1901.02860
        table = torch.cat(
            [torch.flip(pe_positive, [0]).unsqueeze(0), pe_negative[1:].unsqueeze(0)],
            dim=1,
        )
        self.pe = table.to(device=x.device, dtype=x.dtype)

    def _get_pe(self, x):
        """Compute the (1, 2 * time - 1, d_model) embedding without a cache."""
        length = x.size(1)
        theta = (
            torch.arange(0, length, dtype=torch.float32).unsqueeze(1) * self.div_term
        )
        sin_theta = torch.sin(theta)
        cos_theta = torch.cos(theta)
        pe_positive = torch.zeros(length, self.d_model)
        pe_negative = torch.zeros(length, self.d_model)
        pe_positive[:, 0::2] = sin_theta
        pe_positive[:, 1::2] = cos_theta
        pe_negative[:, 0::2] = -1 * sin_theta
        pe_negative[:, 1::2] = cos_theta
        # Same layout as the cached table: flipped positive part followed by
        # the negative part without its position-0 row.
        return torch.cat(
            [torch.flip(pe_positive, [0]).unsqueeze(0), pe_negative[1:].unsqueeze(0)],
            dim=1,
        )

    def forward(self, x: torch.Tensor, use_cache=True):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
            use_cache: Unused here; the instance attribute ``self.use_cache``
                decides which path is taken.

        Returns:
            torch.Tensor: Input scaled by ``self.xscale`` (batch, time, `*`).
            torch.Tensor: Positional embedding with ``2 * time - 1`` positions.
        """
        x = x * self.xscale
        if not self.use_cache:
            return x, self._get_pe(x)
        center = self.pe.size(1) // 2
        pos_emb = self.pe[:, center - x.size(1) + 1 : center + x.size(1)]
        return x, pos_emb
class OnnxStreamPositionalEncoding(torch.nn.Module):
    """Streaming Positional encoding.

    Args:
        model: Source positional-encoding module; its ``d_model``, ``xscale``
            and ``pe`` attributes are read, and it is kept as ``self.model``
            so ``extend_pe`` can delegate to it.
        max_seq_len (int): Length the cached table is fitted to.
        use_cache (bool): Use the precomputed table if True, otherwise
            compute the sinusoids on every call.
    """

    def __init__(self, model, max_seq_len=5000, use_cache=True):
        """Construct an PositionalEncoding object."""
        # The original passed an undefined class name to super().
        super().__init__()
        # extend_pe() delegates to the wrapped model, so it must be stored
        # before extend_pe() is called below.
        self.model = model
        self.d_model = model.d_model
        self.xscale = model.xscale
        self.pe = model.pe
        self.use_cache = use_cache
        self.max_seq_len = max_seq_len
        if self.use_cache:
            self.extend_pe()
        else:
            self.div_term = torch.exp(
                torch.arange(0, self.d_model, 2, dtype=torch.float32)
                * -(math.log(10000.0) / self.d_model)
            )
        self._register_load_state_dict_pre_hook(_pre_hook)

    def extend_pe(self):
        """Fit the cached table ``self.pe`` to ``self.max_seq_len``."""
        pe_length = len(self.pe[0])
        if self.max_seq_len < pe_length:
            self.pe = self.pe[:, : self.max_seq_len]
        else:
            # NOTE(review): the wrapped model's extend_pe signature is not
            # visible here; confirm it accepts a bare length.
            self.model.extend_pe(self.max_seq_len)
            self.pe = self.model.pe

    def _add_pe(self, x, start_idx):
        """Scale ``x`` and add sinusoids for positions starting at ``start_idx``.

        Covers positions ``start_idx .. start_idx + x.size(1) - 1`` so that it
        matches the slice used by the cached path in ``forward``.
        """
        position = torch.arange(
            start_idx, start_idx + x.size(1), dtype=torch.float32
        ).unsqueeze(1)
        angles = position * self.div_term
        x = x * self.xscale
        x[:, :, 0::2] += torch.sin(angles)
        x[:, :, 1::2] += torch.cos(angles)
        return x

    def forward(self, x: torch.Tensor, start_idx: int = 0):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
            start_idx (int): Position of the first frame of ``x``.

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        if self.use_cache:
            return x * self.xscale + self.pe[:, start_idx : start_idx + x.size(1)]
        else:
            return self._add_pe(x, start_idx)

View File

@ -1,84 +0,0 @@
import os
import torch
import torch.nn as nn
class SequentialRNNLM(nn.Module):
    """Export wrapper around a sequential RNN language model.

    Reuses the ``encoder``, ``rnn`` and ``decoder`` sub-modules of ``model``
    together with its ``rnn_type``, ``nlayers`` and ``nhid`` attributes.
    """

    def __init__(self, model, **kwargs):
        super().__init__()
        self.model_name = "seq_rnnlm"
        self.encoder = model.encoder
        self.rnn = model.rnn
        self.rnn_type = model.rnn_type
        self.decoder = model.decoder
        self.nlayers = model.nlayers
        self.nhid = model.nhid

    def forward(self, y, hidden1, hidden2=None):
        """Score tokens ``y`` given the recurrent state.

        Returns:
            ``(scores, hidden1, hidden2)`` for an LSTM, ``(scores, hidden1)``
            for any other ``rnn_type``.
        """
        # batch_score function.
        is_lstm = self.rnn_type == "LSTM"
        emb = self.encoder(y)
        if is_lstm:
            output, (hidden1, hidden2) = self.rnn(emb, (hidden1, hidden2))
        else:
            output, hidden1 = self.rnn(emb, hidden1)
        # Flatten the first two dimensions for the decoder, then restore them.
        flat = output.contiguous().view(
            output.size(0) * output.size(1), output.size(2)
        )
        decoded = self.decoder(flat)
        scores = decoded.view(output.size(0), output.size(1), decoded.size(1))
        if is_lstm:
            return (scores, hidden1, hidden2)
        return (scores, hidden1)

    def get_dummy_inputs(self):
        """Return example inputs: a (1, 2) token tensor and random state(s)."""
        tgt = torch.LongTensor([0, 1]).unsqueeze(0)
        hidden = torch.randn(self.nlayers, 1, self.nhid)
        if self.rnn_type == "LSTM":
            # The same tensor object is used for both LSTM states.
            return (tgt, hidden, hidden)
        return (tgt, hidden)

    def get_input_names(self):
        """Return the input names, one extra state name for an LSTM."""
        names = ["x", "in_hidden1"]
        if self.rnn_type == "LSTM":
            names.append("in_hidden2")
        return names

    def get_output_names(self):
        """Return the output names, one extra state name for an LSTM."""
        names = ["y", "out_hidden1"]
        if self.rnn_type == "LSTM":
            names.append("out_hidden2")
        return names

    def get_dynamic_axes(self):
        """Return the dynamic-axis mapping keyed by input/output name."""
        axes = {
            "x": {0: "x_batch", 1: "x_length"},
            "y": {0: "y_batch"},
            "in_hidden1": {1: "hidden1_batch"},
            "out_hidden1": {1: "out_hidden1_batch"},
        }
        if self.rnn_type == "LSTM":
            axes["in_hidden2"] = {1: "hidden2_batch"}
            axes["out_hidden2"] = {1: "out_hidden2_batch"}
        return axes

    def get_model_config(self, path):
        """Return the configuration dict for the model stored under ``path``."""
        model_path = os.path.join(path, f"{self.model_name}.onnx")
        return {
            "use_lm": True,
            "model_path": model_path,
            "lm_type": "SequentialRNNLM",
            "rnn_type": self.rnn_type,
            "nhid": self.nhid,
            "nlayers": self.nlayers,
        }

View File

@ -1,185 +0,0 @@
"""Subsampling layer definition."""
import torch
class OnnxConv2dSubsampling(torch.nn.Module):
"""Convolutional 2D subsampling (to 1/4 length).
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
pos_enc (torch.nn.Module): Custom position encoding layer.
"""
def __init__(self, model):
"""Construct an Conv2dSubsampling object."""
super().__init__()
self.conv = model.conv
self.out = model.out
def forward(self, x, x_mask):
"""Subsample x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 4.
torch.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 4.
"""
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
if x_mask is None:
return x, None
return x, x_mask[:, :-2:2][:, :-2:2]
def __getitem__(self, key):
"""Get item.
When reset_parameters() is called, if use_scaled_pos_enc is used,
return the positioning encoding.
"""
if key != -1:
raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
return self.out[key]
class OnnxConv2dSubsampling2(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/2 length).

    Args:
        model: Source subsampling module; its ``conv`` and ``out``
            sub-modules are reused as-is.
    """

    def __init__(self, model):
        """Construct an Conv2dSubsampling object."""
        super().__init__()
        self.conv = model.conv
        self.out = model.out

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask, or None. It is sliced along
                dimension 1.
                NOTE(review): the upstream docstring gave (#batch, 1, time),
                but the slicing below acts on dimension 1 -- confirm the
                expected mask layout.

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim).
            torch.Tensor: Subsampled mask, or None if ``x_mask`` is None.
        """
        feats = self.conv(x.unsqueeze(1))  # (b, c, t, f)
        b, c, t, f = feats.size()
        projected = self.out(feats.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return projected, None
        return projected, x_mask[:, :-2:2][:, :-2:1]

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
        return the positioning encoding.
        """
        if key == -1:
            return self.out[key]
        raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
class OnnxConv2dSubsampling6(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/6 length).

    Args:
        model: Source subsampling module; its ``conv`` and ``out``
            sub-modules are reused as-is.
    """

    def __init__(self, model):
        """Construct an Conv2dSubsampling object."""
        super().__init__()
        self.conv = model.conv
        self.out = model.out

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask, or None. It is strided along
                dimension 1.
                NOTE(review): the upstream docstring gave (#batch, 1, time),
                but the slicing below acts on dimension 1 -- confirm the
                expected mask layout.

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim).
            torch.Tensor: Subsampled mask, or None if ``x_mask`` is None.
        """
        feats = self.conv(x.unsqueeze(1))  # (b, c, t, f)
        b, c, t, f = feats.size()
        projected = self.out(feats.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return projected, None
        return projected, x_mask[:, :-2:2][:, :-4:3]
class OnnxConv2dSubsampling8(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/8 length).

    Args:
        model: Source subsampling module; its ``conv`` and ``out``
            sub-modules are reused as-is.
    """

    def __init__(self, model):
        """Construct an Conv2dSubsampling object."""
        super().__init__()
        self.conv = model.conv
        self.out = model.out

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask, or None. It is strided three
                times along dimension 1.
                NOTE(review): the upstream docstring gave (#batch, 1, time),
                but the slicing below acts on dimension 1 -- confirm the
                expected mask layout.

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim).
            torch.Tensor: Subsampled mask, or None if ``x_mask`` is None.
        """
        feats = self.conv(x.unsqueeze(1))  # (b, c, t, f)
        b, c, t, f = feats.size()
        projected = self.out(feats.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return projected, None
        return projected, x_mask[:, :-2:2][:, :-2:2][:, :-2:2]

View File

@ -1,110 +0,0 @@
import os
import torch
import torch.nn as nn
from funasr.modules.vgg2l import import VGG2L
from funasr.modules.attention import MultiHeadedAttention
from funasr.modules.subsampling import (
Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8)
from funasr.export.models.modules.encoder_layer import EncoderLayerConformer as OnnxEncoderLayer
from funasr.export.models.language_models.embed import Embedding
from funasr.export.models.modules.multihead_att import OnnxMultiHeadedAttention
from funasr.export.utils.torch_function import MakePadMask
class TransformerLM(nn.Module, AbsExportModel):
    """Export wrapper around a transformer language model.

    Reuses the ``encoder`` and ``decoder`` of ``model``, wraps ``model.embed``
    in the export ``Embedding`` and replaces every encoder layer with its
    export counterpart.

    NOTE(review): ``AbsExportModel`` is not imported by the import block
    directly above this class -- confirm it is in scope.
    """

    def __init__(self, model, max_seq_len=512, **kwargs):
        super().__init__()
        self.model_name = "transformer_lm"
        self.embed = Embedding(model.embed, max_seq_len)
        self.encoder = model.encoder
        self.decoder = model.decoder
        self.make_pad_mask = MakePadMask(max_seq_len, flip=False)

        # replace multihead attention module into customized module.
        layers = self.encoder.encoders
        for idx, layer in enumerate(layers):
            # layer is EncoderLayer
            if isinstance(layer.self_attn, MultiHeadedAttention):
                layer.self_attn = OnnxMultiHeadedAttention(layer.self_attn)
            layers[idx] = OnnxEncoderLayer(layer)

        first_attn = self.encoder.encoders[0].self_attn
        self.num_heads = first_attn.h
        self.hidden_size = first_attn.linear_out.out_features

    def prepare_mask(self, mask):
        """Turn a 0/1 mask into an additive mask (0 kept, -10000 masked).

        2-D and 3-D masks first receive extra singleton dimensions; masks of
        any other rank are only rescaled.
        """
        rank = len(mask.shape)
        if rank == 2:
            mask = mask[:, None, None, :]
        elif rank == 3:
            mask = mask[:, None, :]
        return (1 - mask) * -10000.0

    def forward(self, y, cache):
        """Run one scoring step.

        Args:
            y: Token tensor; entries equal to 0 are masked out.
            cache: One cached tensor per encoder layer.

        Returns:
            Decoder output for the last position and the list of per-layer
            outputs to be used as the next cache.
        """
        # Every row is treated as full length: the length is the size of the
        # last dimension of ``y``.
        feats_length = torch.ones(y.shape).sum(dim=-1).type(torch.long)
        mask = self.make_pad_mask(feats_length)  # (B, T)
        mask = (y != 0) * mask

        xs = self.embed(y)
        # forward_one_step of Encoder
        encoder_embed = self.encoder.embed
        subsampling_types = (
            Conv2dSubsampling,
            Conv2dSubsampling6,
            Conv2dSubsampling8,
            VGG2L,
        )
        if isinstance(encoder_embed, subsampling_types):
            xs, mask = encoder_embed(xs, mask)
        else:
            xs = encoder_embed(xs)

        mask = self.prepare_mask(mask)
        new_cache = []
        for layer_cache, layer in zip(cache, self.encoder.encoders):
            xs, mask = layer(xs, mask, layer_cache)
            new_cache.append(xs)

        if self.encoder.normalize_before:
            xs = self.encoder.after_norm(xs)

        h = self.decoder(xs[:, -1])
        return h, new_cache

    def get_dummy_inputs(self):
        """Return a (1, 1) token tensor and one zero cache per encoder layer."""
        layers = self.encoder.encoders
        tgt = torch.LongTensor([1]).unsqueeze(0)
        cache = [torch.zeros((1, 1, layers[0].size)) for _ in range(len(layers))]
        return (tgt, cache)

    def is_optimizable(self):
        return True

    def get_input_names(self):
        """Return ``tgt`` followed by one ``cache_<i>`` name per layer."""
        n_layers = len(self.encoder.encoders)
        return ["tgt"] + ["cache_%d" % i for i in range(n_layers)]

    def get_output_names(self):
        """Return ``y`` followed by one ``out_cache_<i>`` name per layer."""
        n_layers = len(self.encoder.encoders)
        return ["y"] + ["out_cache_%d" % i for i in range(n_layers)]

    def get_dynamic_axes(self):
        """Return the dynamic-axis mapping for ``tgt`` and every cache tensor."""
        n_layers = len(self.encoder.encoders)
        axes = {"tgt": {0: "tgt_batch", 1: "tgt_length"}}
        # Input caches first, then output caches, both with dynamic
        # batch (0) and length (1) axes.
        for prefix in ("cache", "out_cache"):
            for d in range(n_layers):
                axes["%s_%d" % (prefix, d)] = {
                    0: "%s_%d_batch" % (prefix, d),
                    1: "%s_%d_length" % (prefix, d),
                }
        return axes

    def get_model_config(self, path):
        """Return the configuration dict for the model stored under ``path``."""
        layers = self.encoder.encoders
        return {
            "use_lm": True,
            "model_path": os.path.join(path, f"{self.model_name}.onnx"),
            "lm_type": "TransformerLM",
            "odim": layers[0].size,
            "nlayers": len(layers),
        }

View File

@ -4,7 +4,7 @@ from typing import List, Tuple, Union
import random import random
import numpy as np import numpy as np
import soundfile import librosa
import librosa import librosa
import torch import torch
@ -116,7 +116,7 @@ class SoundScpReader(collections.abc.Mapping):
def __getitem__(self, key): def __getitem__(self, key):
wav = self.data[key] wav = self.data[key]
if self.normalize: if self.normalize:
# soundfile.read normalizes data to [-1,1] if dtype is not given # librosa.load normalizes data to [-1,1] if dtype is not given
array, rate = librosa.load( array, rate = librosa.load(
wav, sr=self.dest_sample_rate, mono=self.always_2d wav, sr=self.dest_sample_rate, mono=self.always_2d
) )

View File

@ -5,8 +5,12 @@ from typing import Tuple
from typing import Union from typing import Union
import torch import torch
from torch_complex import functional as FC try:
from torch_complex.tensor import ComplexTensor from torch_complex import functional as FC
from torch_complex.tensor import ComplexTensor
except:
print("Please install torch_complex firstly")
EPS = torch.finfo(torch.double).eps EPS = torch.finfo(torch.double).eps

View File

@ -4,8 +4,11 @@ from typing import Tuple
from typing import Union from typing import Union
import torch import torch
from torch_complex.tensor import ComplexTensor
try:
from torch_complex.tensor import ComplexTensor
except:
print("Please install torch_complex firstly")
from funasr.modules.nets_utils import make_pad_mask from funasr.modules.nets_utils import make_pad_mask
from funasr.layers.complex_utils import is_complex from funasr.layers.complex_utils import is_complex
from funasr.layers.inversible_interface import InversibleInterface from funasr.layers.inversible_interface import InversibleInterface

View File

@ -1,8 +1,10 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
try:
from rotary_embedding_torch import RotaryEmbedding from rotary_embedding_torch import RotaryEmbedding
except:
print("Please install rotary_embedding_torch by: \n pip install -U rotary_embedding_torch")
from funasr.modules.layer_norm import GlobalLayerNorm, CumulativeLayerNorm, ScaleNorm from funasr.modules.layer_norm import GlobalLayerNorm, CumulativeLayerNorm, ScaleNorm
from funasr.modules.embedding import ScaledSinuEmbedding from funasr.modules.embedding import ScaledSinuEmbedding
from funasr.modules.mossformer import FLASH_ShareA_FFConvM from funasr.modules.mossformer import FLASH_ShareA_FFConvM

View File

@ -6,12 +6,15 @@ import logging
import humanfriendly import humanfriendly
import numpy as np import numpy as np
import torch import torch
from torch_complex.tensor import ComplexTensor try:
from torch_complex.tensor import ComplexTensor
except:
print("Please install torch_complex firstly")
from funasr.layers.log_mel import LogMel from funasr.layers.log_mel import LogMel
from funasr.layers.stft import Stft from funasr.layers.stft import Stft
from funasr.models.frontend.abs_frontend import AbsFrontend from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.modules.frontends.frontend import Frontend from funasr.models.frontend.frontends_utils.frontend import Frontend
from funasr.utils.get_default_kwargs import get_default_kwargs from funasr.utils.get_default_kwargs import get_default_kwargs
from funasr.modules.nets_utils import make_pad_mask from funasr.modules.nets_utils import make_pad_mask

View File

@ -4,12 +4,12 @@ from typing import Tuple
import torch import torch
from torch.nn import functional as F from torch.nn import functional as F
from funasr.modules.frontends.beamformer import apply_beamforming_vector from funasr.models.frontend.frontends_utils.beamformer import apply_beamforming_vector
from funasr.modules.frontends.beamformer import get_mvdr_vector from funasr.models.frontend.frontends_utils.beamformer import get_mvdr_vector
from funasr.modules.frontends.beamformer import ( from funasr.models.frontend.frontends_utils.beamformer import (
get_power_spectral_density_matrix, # noqa: H301 get_power_spectral_density_matrix, # noqa: H301
) )
from funasr.modules.frontends.mask_estimator import MaskEstimator from funasr.models.frontend.frontends_utils.mask_estimator import MaskEstimator
from torch_complex.tensor import ComplexTensor from torch_complex.tensor import ComplexTensor

View File

@ -4,7 +4,7 @@ from pytorch_wpe import wpe_one_iteration
import torch import torch
from torch_complex.tensor import ComplexTensor from torch_complex.tensor import ComplexTensor
from funasr.modules.frontends.mask_estimator import MaskEstimator from funasr.models.frontend.frontends_utils.mask_estimator import MaskEstimator
from funasr.modules.nets_utils import make_pad_mask from funasr.modules.nets_utils import make_pad_mask

View File

@ -8,8 +8,8 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch_complex.tensor import ComplexTensor from torch_complex.tensor import ComplexTensor
from funasr.modules.frontends.dnn_beamformer import DNN_Beamformer from funasr.models.frontend.frontends_utils.dnn_beamformer import DNN_Beamformer
from funasr.modules.frontends.dnn_wpe import DNN_WPE from funasr.models.frontend.frontends_utils.dnn_wpe import DNN_WPE
class Frontend(nn.Module): class Frontend(nn.Module):

View File

@ -10,7 +10,7 @@ import humanfriendly
import torch import torch
from funasr.models.frontend.abs_frontend import AbsFrontend from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.modules.frontends.frontend import Frontend from funasr.models.frontend.frontends_utils.frontend import Frontend
from funasr.modules.nets_utils import pad_list from funasr.modules.nets_utils import pad_list
from funasr.utils.get_default_kwargs import get_default_kwargs from funasr.utils.get_default_kwargs import get_default_kwargs

View File

@ -9,7 +9,7 @@ import os
import sys import sys
import numpy as np import numpy as np
import subprocess import subprocess
import soundfile as sf import librosa as sf
import io import io
from functools import lru_cache from functools import lru_cache
@ -67,18 +67,18 @@ def load_wav(wav_rxfilename, start=0, end=None):
# input piped command # input piped command
p = subprocess.Popen(wav_rxfilename[:-1], shell=True, p = subprocess.Popen(wav_rxfilename[:-1], shell=True,
stdout=subprocess.PIPE) stdout=subprocess.PIPE)
data, samplerate = sf.read(io.BytesIO(p.stdout.read()), data, samplerate = sf.load(io.BytesIO(p.stdout.read()),
dtype='float32') dtype='float32')
# cannot seek # cannot seek
data = data[start:end] data = data[start:end]
elif wav_rxfilename == '-': elif wav_rxfilename == '-':
# stdin # stdin
data, samplerate = sf.read(sys.stdin, dtype='float32') data, samplerate = sf.load(sys.stdin, dtype='float32')
# cannot seek # cannot seek
data = data[start:end] data = data[start:end]
else: else:
# normal wav file # normal wav file
data, samplerate = sf.read(wav_rxfilename, start=start, stop=end) data, samplerate = sf.load(wav_rxfilename, start=start, stop=end)
return data, samplerate return data, samplerate

View File

@ -278,14 +278,11 @@ class Trainer:
for iepoch in range(start_epoch, trainer_options.max_epoch + 1): for iepoch in range(start_epoch, trainer_options.max_epoch + 1):
if iepoch != start_epoch: if iepoch != start_epoch:
logging.info( logging.info(
"{}/{}epoch started. Estimated time to finish: {}".format( "{}/{}epoch started. Estimated time to finish: {} hours".format(
iepoch, iepoch,
trainer_options.max_epoch, trainer_options.max_epoch,
humanfriendly.format_timespan( (time.perf_counter() - start_time) / 3600.0 / (iepoch - start_epoch) * (
(time.perf_counter() - start_time) trainer_options.max_epoch - iepoch + 1),
/ (iepoch - start_epoch)
* (trainer_options.max_epoch - iepoch + 1)
),
) )
) )
else: else:

View File

@ -5,7 +5,7 @@ import struct
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Union
import torchaudio import torchaudio
import soundfile import librosa
import numpy as np import numpy as np
import pkg_resources import pkg_resources
from modelscope.utils.logger import get_logger from modelscope.utils.logger import get_logger
@ -139,7 +139,7 @@ def get_sr_from_wav(fname: str):
try: try:
audio, fs = torchaudio.load(fname) audio, fs = torchaudio.load(fname)
except: except:
audio, fs = soundfile.read(fname) audio, fs = librosa.load(fname)
break break
if audio_type.rfind(".scp") >= 0: if audio_type.rfind(".scp") >= 0:
with open(fname, encoding="utf-8") as f: with open(fname, encoding="utf-8") as f:

View File

@ -5,7 +5,7 @@ from multiprocessing import Pool
import kaldiio import kaldiio
import numpy as np import numpy as np
import soundfile import librosa
import torch.distributed as dist import torch.distributed as dist
import torchaudio import torchaudio
@ -46,7 +46,7 @@ def wav2num_frame(wav_path, frontend_conf):
try: try:
waveform, sampling_rate = torchaudio.load(wav_path) waveform, sampling_rate = torchaudio.load(wav_path)
except: except:
waveform, sampling_rate = soundfile.read(wav_path) waveform, sampling_rate = librosa.load(wav_path)
waveform = np.expand_dims(waveform, axis=0) waveform = np.expand_dims(waveform, axis=0)
n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"]) n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"] feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]

View File

@ -12,7 +12,7 @@ import os
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Union
import numpy as np import numpy as np
import soundfile as sf import librosa as sf
import torch import torch
import torchaudio import torchaudio
import logging import logging
@ -43,7 +43,7 @@ def sv_preprocess(inputs: Union[np.ndarray, list]):
for i in range(len(inputs)): for i in range(len(inputs)):
if isinstance(inputs[i], str): if isinstance(inputs[i], str):
file_bytes = File.read(inputs[i]) file_bytes = File.read(inputs[i])
data, fs = sf.read(io.BytesIO(file_bytes), dtype='float32') data, fs = sf.load(io.BytesIO(file_bytes), dtype='float32')
if len(data.shape) == 2: if len(data.shape) == 2:
data = data[:, 0] data = data[:, 0]
data = torch.from_numpy(data).unsqueeze(0) data = torch.from_numpy(data).unsqueeze(0)

View File

@ -3,7 +3,7 @@ import codecs
import logging import logging
import argparse import argparse
import numpy as np import numpy as np
import edit_distance # import edit_distance
from itertools import zip_longest from itertools import zip_longest
@ -160,112 +160,112 @@ def time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocess
return res return res
class AverageShiftCalculator(): # class AverageShiftCalculator():
def __init__(self): # def __init__(self):
logging.warning("Calculating average shift.") # logging.warning("Calculating average shift.")
def __call__(self, file1, file2): # def __call__(self, file1, file2):
uttid_list1, ts_dict1 = self.read_timestamps(file1) # uttid_list1, ts_dict1 = self.read_timestamps(file1)
uttid_list2, ts_dict2 = self.read_timestamps(file2) # uttid_list2, ts_dict2 = self.read_timestamps(file2)
uttid_intersection = self._intersection(uttid_list1, uttid_list2) # uttid_intersection = self._intersection(uttid_list1, uttid_list2)
res = self.as_cal(uttid_intersection, ts_dict1, ts_dict2) # res = self.as_cal(uttid_intersection, ts_dict1, ts_dict2)
logging.warning("Average shift of {} and {}: {}.".format(file1, file2, str(res)[:8])) # logging.warning("Average shift of {} and {}: {}.".format(file1, file2, str(res)[:8]))
logging.warning("Following timestamp pair differs most: {}, detail:{}".format(self.max_shift, self.max_shift_uttid)) # logging.warning("Following timestamp pair differs most: {}, detail:{}".format(self.max_shift, self.max_shift_uttid))
#
def _intersection(self, list1, list2): # def _intersection(self, list1, list2):
set1 = set(list1) # set1 = set(list1)
set2 = set(list2) # set2 = set(list2)
if set1 == set2: # if set1 == set2:
logging.warning("Uttid same checked.") # logging.warning("Uttid same checked.")
return set1 # return set1
itsc = list(set1 & set2) # itsc = list(set1 & set2)
logging.warning("Uttid differs: file1 {}, file2 {}, lines same {}.".format(len(list1), len(list2), len(itsc))) # logging.warning("Uttid differs: file1 {}, file2 {}, lines same {}.".format(len(list1), len(list2), len(itsc)))
return itsc # return itsc
#
def read_timestamps(self, file): # def read_timestamps(self, file):
# read timestamps file in standard format # # read timestamps file in standard format
uttid_list = [] # uttid_list = []
ts_dict = {} # ts_dict = {}
with codecs.open(file, 'r') as fin: # with codecs.open(file, 'r') as fin:
for line in fin.readlines(): # for line in fin.readlines():
text = '' # text = ''
ts_list = [] # ts_list = []
line = line.rstrip() # line = line.rstrip()
uttid = line.split()[0] # uttid = line.split()[0]
uttid_list.append(uttid) # uttid_list.append(uttid)
body = " ".join(line.split()[1:]) # body = " ".join(line.split()[1:])
for pd in body.split(';'): # for pd in body.split(';'):
if not len(pd): continue # if not len(pd): continue
# pdb.set_trace() # # pdb.set_trace()
char, start, end = pd.lstrip(" ").split(' ') # char, start, end = pd.lstrip(" ").split(' ')
text += char + ',' # text += char + ','
ts_list.append((float(start), float(end))) # ts_list.append((float(start), float(end)))
# ts_lists.append(ts_list) # # ts_lists.append(ts_list)
ts_dict[uttid] = (text[:-1], ts_list) # ts_dict[uttid] = (text[:-1], ts_list)
logging.warning("File {} read done.".format(file)) # logging.warning("File {} read done.".format(file))
return uttid_list, ts_dict # return uttid_list, ts_dict
#
def _shift(self, filtered_timestamp_list1, filtered_timestamp_list2): # def _shift(self, filtered_timestamp_list1, filtered_timestamp_list2):
shift_time = 0 # shift_time = 0
for fts1, fts2 in zip(filtered_timestamp_list1, filtered_timestamp_list2): # for fts1, fts2 in zip(filtered_timestamp_list1, filtered_timestamp_list2):
shift_time += abs(fts1[0] - fts2[0]) + abs(fts1[1] - fts2[1]) # shift_time += abs(fts1[0] - fts2[0]) + abs(fts1[1] - fts2[1])
num_tokens = len(filtered_timestamp_list1) # num_tokens = len(filtered_timestamp_list1)
return shift_time, num_tokens # return shift_time, num_tokens
#
def as_cal(self, uttid_list, ts_dict1, ts_dict2):
    """Calculate the average shift between two sets of timestamps.

    For every utterance the two token sequences are aligned with
    ``edit_distance.SequenceMatcher``; only tokens tagged ``equal`` keep
    their timestamps, and the shift is measured between those survivors.

    Side effects: resets and updates ``self._accumlated_shift``,
    ``self._accumlated_tokens``, ``self.max_shift`` and
    ``self.max_shift_uttid``.

    Args:
        uttid_list: Utterance ids to evaluate; each must be a key of both
            dictionaries.
        ts_dict1: Mapping ``uttid -> (text, ts_list)`` where ``text`` is a
            comma-joined token string and ``ts_list`` a list of
            ``(start, end)`` pairs.
        ts_dict2: Same structure as ``ts_dict1`` for the second system.

    Returns:
        Accumulated shift divided by accumulated token count, or
        ``float("nan")`` when no token could be aligned at all (the
        original code raised ``ZeroDivisionError`` in that case).
    """
    self._accumlated_shift = 0
    self._accumlated_tokens = 0
    self.max_shift = 0
    self.max_shift_uttid = None
    for uttid in uttid_list:
        t1, ts1 = ts_dict1[uttid]
        t2, ts2 = ts_dict2[uttid]
        tokens1 = t1.split(',')
        tokens2 = t2.split(',')
        # keep1 / keep2 flag, per token of t1 / t2, whether the token was
        # matched. NOTE(review): this assumes each opcode covers exactly one
        # token, as the original one-flag-per-opcode logic did - confirm
        # against the edit_distance package.
        keep1, keep2 = [], []
        opcodes = edit_distance.SequenceMatcher(tokens1, tokens2).get_opcodes()
        for opcode in opcodes:
            tag = opcode[0]
            if tag == "equal":
                keep1.append(1)
                keep2.append(1)
                continue
            if tag in ("replace", "insert"):
                keep2.append(0)
            if tag in ("replace", "delete"):
                keep1.append(0)
        # Use the flags to pick the timestamps of matched tokens only.
        fts1 = [ts for keep, ts, _ in zip(keep1, ts1, tokens1) if keep]
        fts2 = [ts for keep, ts, _ in zip(keep2, ts2, tokens2) if keep]
        if len(fts1) != len(fts2):
            logging.warning("length mismatch")
            continue
        if not fts1:
            # Nothing aligned: skip to avoid dividing by zero below.
            logging.warning("No aligned tokens for {}, skipped.".format(uttid))
            continue
        shift_time, num_tokens = self._shift(fts1, fts2)
        self._accumlated_shift += shift_time
        self._accumlated_tokens += num_tokens
        utt_shift = shift_time / num_tokens
        if utt_shift > self.max_shift:
            self.max_shift = utt_shift
            self.max_shift_uttid = uttid
    if not self._accumlated_tokens:
        logging.warning("No aligned tokens found, average shift is undefined.")
        return float("nan")
    return self._accumlated_shift / self._accumlated_tokens
def convert_external_alphas(alphas_file, text_file, output_file): def convert_external_alphas(alphas_file, text_file, output_file):
@ -311,10 +311,10 @@ SUPPORTED_MODES = ['cal_aas', 'read_ext_alphas']
def main(args):
    """Run the tool selected by ``args.mode``.

    ``'cal_aas'`` builds an ``AverageShiftCalculator`` and calls it with
    ``args.input`` and ``args.input2``; ``'read_ext_alphas'`` calls
    ``convert_external_alphas`` with ``args.input``, ``args.input2`` and
    ``args.output``. Any other mode is reported through ``logging.error``.
    """
    selected = args.mode
    if selected == 'cal_aas':
        calculator = AverageShiftCalculator()
        calculator(args.input, args.input2)
        return
    if selected == 'read_ext_alphas':
        convert_external_alphas(args.input, args.input2, args.output)
        return
    logging.error("Mode {} not in SUPPORTED_MODES: {}.".format(args.mode, SUPPORTED_MODES))

View File

@ -11,7 +11,7 @@ import librosa
import numpy as np import numpy as np
import torch import torch
import torchaudio import torchaudio
import soundfile import librosa
import torchaudio.compliance.kaldi as kaldi import torchaudio.compliance.kaldi as kaldi
@ -166,7 +166,7 @@ def compute_fbank(wav_file,
try: try:
waveform, audio_sr = torchaudio.load(wav_file) waveform, audio_sr = torchaudio.load(wav_file)
except: except:
waveform, audio_sr = soundfile.read(wav_file, dtype='float32') waveform, audio_sr = librosa.load(wav_file, dtype='float32')
if waveform.ndim == 2: if waveform.ndim == 2:
waveform = waveform[:, 0] waveform = waveform[:, 0]
waveform = torch.tensor(np.expand_dims(waveform, axis=0)) waveform = torch.tensor(np.expand_dims(waveform, axis=0))
@ -191,7 +191,7 @@ def wav2num_frame(wav_path, frontend_conf):
try: try:
waveform, sampling_rate = torchaudio.load(wav_path) waveform, sampling_rate = torchaudio.load(wav_path)
except: except:
waveform, sampling_rate = soundfile.read(wav_path) waveform, sampling_rate = librosa.load(wav_path)
waveform = torch.tensor(np.expand_dims(waveform, axis=0)) waveform = torch.tensor(np.expand_dims(waveform, axis=0))
speech_length = (waveform.shape[1] / sampling_rate) * 1000. speech_length = (waveform.shape[1] / sampling_rate) * 1000.
n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"]) n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])

View File

@ -1,8 +1,11 @@
import os import os
from functools import lru_cache from functools import lru_cache
from typing import Union from typing import Union
try:
import ffmpeg
except:
print("Please Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.")
import ffmpeg
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F

View File

@ -10,36 +10,36 @@ from setuptools import setup
requirements = { requirements = {
"install": [ "install": [
"setuptools>=38.5.1", # "setuptools>=38.5.1",
"humanfriendly", "humanfriendly",
"scipy>=1.4.1", "scipy>=1.4.1",
"librosa", "librosa",
"jamo", # For kss # "jamo", # For kss
"PyYAML>=5.1.2", "PyYAML>=5.1.2",
"soundfile>=0.12.1", # "soundfile>=0.12.1",
"h5py>=3.1.0", # "h5py>=3.1.0",
"kaldiio>=2.17.0", "kaldiio>=2.17.0",
"torch_complex", # "torch_complex",
"nltk>=3.4.5", # "nltk>=3.4.5",
# ASR # ASR
"sentencepiece", "sentencepiece", # train
"jieba", "jieba",
"rotary_embedding_torch", # "rotary_embedding_torch",
"ffmpeg", # "ffmpeg-python",
# TTS # TTS
"pypinyin>=0.44.0", # "pypinyin>=0.44.0",
"espnet_tts_frontend", # "espnet_tts_frontend",
# ENH # ENH
"pytorch_wpe", # "pytorch_wpe",
"editdistance>=0.5.2", "editdistance>=0.5.2",
"tensorboard", "tensorboard",
"g2p", # "g2p",
"nara_wpe", # "nara_wpe",
# PAI # PAI
"oss2", "oss2",
"edit-distance", # "edit-distance",
"textgrid", # "textgrid",
"protobuf", # "protobuf",
"tqdm", "tqdm",
"hdbscan", "hdbscan",
"umap", "umap",
@ -104,7 +104,7 @@ setup(
name="funasr", name="funasr",
version=version, version=version,
url="https://github.com/alibaba-damo-academy/FunASR.git", url="https://github.com/alibaba-damo-academy/FunASR.git",
author="Speech Lab of DAMO Academy, Alibaba Group", author="Speech Lab of Alibaba Group",
author_email="funasr@list.alibaba-inc.com", author_email="funasr@list.alibaba-inc.com",
description="FunASR: A Fundamental End-to-End Speech Recognition Toolkit", description="FunASR: A Fundamental End-to-End Speech Recognition Toolkit",
long_description=open(os.path.join(dirname, "README.md"), encoding="utf-8").read(), long_description=open(os.path.join(dirname, "README.md"), encoding="utf-8").read(),