Mirror of https://github.com/modelscope/FunASR, synced 2025-09-15 14:48:36 +08:00

Commit adc88bd9e7
Merge remote-tracking branch 'refs/remotes/origin/main'
update contextual forward
@@ -1,7 +1,10 @@
import argparse
import tqdm
import codecs
import textgrid
try:
    import textgrid
except:
    raise "Please install textgrid firstly: pip install textgrid"
import pdb

class Segment(object):

@@ -6,7 +6,10 @@ import argparse
import codecs
from distutils.util import strtobool
from pathlib import Path
import textgrid
try:
    import textgrid
except:
    raise "Please install textgrid firstly: pip install textgrid"
import pdb

class Segment(object):

@@ -6,7 +6,10 @@ import argparse
import codecs
from distutils.util import strtobool
from pathlib import Path
import textgrid
try:
    import textgrid
except:
    raise "Please install textgrid firstly: pip install textgrid"
import pdb

class Segment(object):

@@ -6,7 +6,10 @@ import argparse
import codecs
from distutils.util import strtobool
from pathlib import Path
import textgrid
try:
    import textgrid
except:
    raise "Please install textgrid firstly: pip install textgrid"
import pdb

class Segment(object):

@@ -6,7 +6,10 @@ import argparse
import codecs
from distutils.util import strtobool
from pathlib import Path
import textgrid
try:
    import textgrid
except:
    raise "Please install textgrid firstly: pip install textgrid"
import pdb

def get_args():

@@ -6,7 +6,12 @@ import argparse
import codecs
from distutils.util import strtobool
from pathlib import Path
import textgrid

try:
    import textgrid
except:
    raise "Please install textgrid firstly: pip install textgrid"

import pdb
import numpy as np
import sys

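A side note on the guard repeated in the hunks above: in Python 3, raising a bare string is itself a TypeError, so the installation hint would never reach the user. A minimal sketch of the conventional optional-import guard, for reference only and not part of this commit:

try:
    import textgrid
except ImportError as exc:
    # Raise an exception *type* so the installation hint actually surfaces in Python 3.
    raise ImportError("Please install textgrid first: pip install textgrid") from exc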
@@ -34,8 +34,8 @@ from funasr.modules.beam_search.beam_search_transducer import Hypothesis as Hypo
from funasr.modules.scorers.ctc import CTCPrefixScorer
from funasr.modules.scorers.length_bonus import LengthBonus
from funasr.build_utils.build_asr_model import frontend_choices
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.token_id_converter import TokenIDConverter
from funasr.tokenizer.build_tokenizer import build_tokenizer
from funasr.tokenizer.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard

@@ -44,9 +44,9 @@ class Speech2Text:
    """Speech2Text class

    Examples:
        >>> import soundfile
        >>> import librosa
        >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> audio, rate = librosa.load("speech.wav")
        >>> speech2text(audio)
        [(text, token, token_int, hypothesis object), ...]

@@ -251,9 +251,9 @@ class Speech2TextParaformer:
    """Speech2Text class

    Examples:
        >>> import soundfile
        >>> import librosa
        >>> speech2text = Speech2TextParaformer("asr_config.yml", "asr.pb")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> audio, rate = librosa.load("speech.wav")
        >>> speech2text(audio)
        [(text, token, token_int, hypothesis object), ...]

@@ -625,9 +625,9 @@ class Speech2TextParaformerOnline:
    """Speech2Text class

    Examples:
        >>> import soundfile
        >>> import librosa
        >>> speech2text = Speech2TextParaformerOnline("asr_config.yml", "asr.pth")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> audio, rate = librosa.load("speech.wav")
        >>> speech2text(audio)
        [(text, token, token_int, hypothesis object), ...]

@@ -876,9 +876,9 @@ class Speech2TextUniASR:
    """Speech2Text class

    Examples:
        >>> import soundfile
        >>> import librosa
        >>> speech2text = Speech2TextUniASR("asr_config.yml", "asr.pb")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> audio, rate = librosa.load("speech.wav")
        >>> speech2text(audio)
        [(text, token, token_int, hypothesis object), ...]

@@ -1106,9 +1106,9 @@ class Speech2TextMFCCA:
    """Speech2Text class

    Examples:
        >>> import soundfile
        >>> import librosa
        >>> speech2text = Speech2TextMFCCA("asr_config.yml", "asr.pb")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> audio, rate = librosa.load("speech.wav")
        >>> speech2text(audio)
        [(text, token, token_int, hypothesis object), ...]

@@ -1637,9 +1637,9 @@ class Speech2TextSAASR:
    """Speech2Text class

    Examples:
        >>> import soundfile
        >>> import librosa
        >>> speech2text = Speech2TextSAASR("asr_config.yml", "asr.pb")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> audio, rate = librosa.load("speech.wav")
        >>> speech2text(audio)
        [(text, token, token_int, hypothesis object), ...]

@@ -1885,9 +1885,9 @@ class Speech2TextWhisper:
    """Speech2Text class

    Examples:
        >>> import soundfile
        >>> import librosa
        >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> audio, rate = librosa.load("speech.wav")
        >>> speech2text(audio)
        [(text, token, token_int, hypothesis object), ...]

@@ -20,7 +20,8 @@ from typing import Union
import numpy as np
import torch
import torchaudio
import soundfile
# import librosa
import librosa
import yaml

from funasr.bin.asr_infer import Speech2Text

@@ -1281,7 +1282,8 @@ def inference_paraformer_online(
try:
    raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
except:
    raw_inputs = soundfile.read(data_path_and_name_and_type[0], dtype='float32')[0]
    # raw_inputs = librosa.load(data_path_and_name_and_type[0], dtype='float32')[0]
    raw_inputs, sr = librosa.load(data_path_and_name_and_type[0], dtype='float32')
    if raw_inputs.ndim == 2:
        raw_inputs = raw_inputs[:, 0]
    raw_inputs = torch.tensor(raw_inputs)

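One behavioral difference worth keeping in mind when swapping soundfile.read for librosa.load, shown as a minimal sketch (the 16 kHz target below is an illustrative assumption, not taken from this commit): librosa.load resamples to 22050 Hz by default, while sr=None preserves the file's native rate.

import librosa

# sr=None keeps the native sampling rate; omitting it resamples to 22050 Hz.
audio_native, sr_native = librosa.load("speech.wav", sr=None)
# sr=16000 resamples to 16 kHz (an illustrative choice).
audio_16k, sr_16k = librosa.load("speech.wav", sr=16000)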
@@ -18,7 +18,7 @@ from funasr.build_utils.build_optimizer import build_optimizer
from funasr.build_utils.build_scheduler import build_scheduler
from funasr.build_utils.build_trainer import build_trainer as build_trainer_modelscope
from funasr.modules.lora.utils import mark_only_lora_as_trainable
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.tokenizer.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.load_pretrained_model import load_pretrained_model
from funasr.torch_utils.model_summary import model_summary
from funasr.torch_utils.pytorch_version import pytorch_cudnn_version

@@ -27,11 +27,11 @@ class Speech2DiarizationEEND:
    """Speech2Diarlization class

    Examples:
        >>> import soundfile
        >>> import librosa
        >>> import numpy as np
        >>> speech2diar = Speech2DiarizationEEND("diar_sond_config.yml", "diar_sond.pb")
        >>> profile = np.load("profiles.npy")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> audio, rate = librosa.load("speech.wav")
        >>> speech2diar(audio, profile)
        {"spk1": [(int, int), ...], ...}

@@ -109,11 +109,11 @@ class Speech2DiarizationSOND:
    """Speech2Xvector class

    Examples:
        >>> import soundfile
        >>> import librosa
        >>> import numpy as np
        >>> speech2diar = Speech2DiarizationSOND("diar_sond_config.yml", "diar_sond.pb")
        >>> profile = np.load("profiles.npy")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> audio, rate = librosa.load("speech.wav")
        >>> speech2diar(audio, profile)
        {"spk1": [(int, int), ...], ...}

@@ -15,7 +15,8 @@ from typing import Tuple
from typing import Union

import numpy as np
import soundfile
# import librosa
import librosa
import torch
from scipy.signal import medfilt

@@ -144,7 +145,9 @@ def inference_sond(
    # read waveform file
    example = [load_bytes(x) if isinstance(x, bytes) else x
               for x in example]
    example = [soundfile.read(x)[0] if isinstance(x, str) else x
    # example = [librosa.load(x)[0] if isinstance(x, str) else x
    #            for x in example]
    example = [librosa.load(x, dtype='float32')[0] if isinstance(x, str) else x
               for x in example]
    # convert torch tensor to numpy array
    example = [x.numpy() if isinstance(example[0], torch.Tensor) else x

@@ -20,9 +20,9 @@ class SpeechSeparator:
    """SpeechSeparator class

    Examples:
        >>> import soundfile
        >>> import librosa
        >>> speech_separator = MossFormer("ss_config.yml", "ss.pt")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> audio, rate = librosa.load("speech.wav")
        >>> separated_wavs = speech_separator(audio)

    """

@@ -13,7 +13,7 @@ from typing import Union

import numpy as np
import torch
import soundfile as sf
import librosa
from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse

@@ -104,7 +105,12 @@ def inference_ss(
    ss_results = speech_separator(**batch)

    for spk in range(num_spks):
        sf.write(os.path.join(output_path, keys[0] + '_s' + str(spk+1)+'.wav'), ss_results[spk], sample_rate)
        # sf.write(os.path.join(output_path, keys[0] + '_s' + str(spk+1)+'.wav'), ss_results[spk], sample_rate)
        try:
            librosa.output.write_wav(os.path.join(output_path, keys[0] + '_s' + str(spk+1)+'.wav'), ss_results[spk], sample_rate)
        except:
            print("To write wav by librosa, you should install librosa<=0.8.0")
            raise
    torch.cuda.empty_cache()
    return ss_results

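As context for the fallback above: librosa.output.write_wav was removed in librosa 0.8.0, so the try/except only succeeds on older librosa versions. A minimal sketch of the commonly used alternative, writing each separated stream with soundfile (the values below are invented placeholders standing in for output_path, keys, ss_results, and sample_rate from the loop above):

import os
import numpy as np
import soundfile as sf

output_path, key, sample_rate = "outputs", "utt1", 16000
ss_results = [np.zeros(sample_rate, dtype="float32") for _ in range(2)]

os.makedirs(output_path, exist_ok=True)
for spk, wav in enumerate(ss_results, start=1):
    # soundfile.write(path, data, samplerate) replaces librosa.output.write_wav
    # on current librosa/soundfile stacks.
    sf.write(os.path.join(output_path, f"{key}_s{spk}.wav"), wav, sample_rate)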
@@ -22,9 +22,9 @@ class Speech2Xvector:
    """Speech2Xvector class

    Examples:
        >>> import soundfile
        >>> import librosa
        >>> speech2xvector = Speech2Xvector("sv_config.yml", "sv.pb")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> audio, rate = librosa.load("speech.wav")
        >>> speech2xvector(audio)
        [(text, token, token_int, hypothesis object), ...]


@@ -9,9 +9,9 @@ from typing import Optional


from funasr.utils.cli_utils import get_commandline_args
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.cleaner import TextCleaner
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.tokenizer.build_tokenizer import build_tokenizer
from funasr.tokenizer.cleaner import TextCleaner
from funasr.tokenizer.phoneme_tokenizer import g2p_choices
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none


@@ -11,7 +11,7 @@ import numpy as np
import torch
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.text.token_id_converter import TokenIDConverter
from funasr.tokenizer.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device


@@ -17,7 +17,7 @@ from funasr.build_utils.build_model import build_model
from funasr.build_utils.build_optimizer import build_optimizer
from funasr.build_utils.build_scheduler import build_scheduler
from funasr.build_utils.build_trainer import build_trainer
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.tokenizer.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.load_pretrained_model import load_pretrained_model
from funasr.torch_utils.model_summary import model_summary
from funasr.torch_utils.pytorch_version import pytorch_cudnn_version

@@ -23,9 +23,9 @@ class Speech2VadSegment:
    """Speech2VadSegment class

    Examples:
        >>> import soundfile
        >>> import librosa
        >>> speech2segment = Speech2VadSegment("vad_config.yml", "vad.pt")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> audio, rate = librosa.load("speech.wav")
        >>> speech2segment(audio)
        [[10, 230], [245, 450], ...]

@@ -118,9 +118,9 @@ class Speech2VadSegmentOnline(Speech2VadSegment):
    """Speech2VadSegmentOnline class

    Examples:
        >>> import soundfile
        >>> import librosa
        >>> speech2segment = Speech2VadSegmentOnline("vad_config.yml", "vad.pt")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> audio, rate = librosa.load("speech.wav")
        >>> speech2segment(audio)
        [[10, 230], [245, 450], ...]

@@ -246,14 +246,11 @@ class Trainer:
        for iepoch in range(start_epoch, trainer_options.max_epoch + 1):
            if iepoch != start_epoch:
                logging.info(
                    "{}/{}epoch started. Estimated time to finish: {}".format(
                    "{}/{}epoch started. Estimated time to finish: {} hours".format(
                        iepoch,
                        trainer_options.max_epoch,
                        humanfriendly.format_timespan(
                            (time.perf_counter() - start_time)
                            / (iepoch - start_epoch)
                            * (trainer_options.max_epoch - iepoch + 1)
                        ),
                        (time.perf_counter() - start_time) / 3600.0 / (iepoch - start_epoch) * (
                                trainer_options.max_epoch - iepoch + 1),
                    )
                )
            else:

funasr/datasets/data_sampler.py (new file, 74 lines)
@@ -0,0 +1,74 @@
import torch

import numpy as np


class BatchSampler(torch.utils.data.BatchSampler):

    def __init__(self, dataset, batch_size_type: str="example", batch_size: int=14, sort_size: int=30, drop_last: bool=False, shuffle: bool=True, **kwargs):

        self.drop_last = drop_last
        self.pre_idx = -1
        self.dataset = dataset
        self.total_samples = len(dataset)
        # self.batch_size_type = args.batch_size_type
        # self.batch_size = args.batch_size
        # self.sort_size = args.sort_size
        # self.max_length_token = args.max_length_token
        self.batch_size_type = batch_size_type
        self.batch_size = batch_size
        self.sort_size = sort_size
        self.max_length_token = kwargs.get("max_length_token", 5000)
        self.shuffle_idx = np.arange(self.total_samples)
        self.shuffle = shuffle


    def __len__(self):
        return self.total_samples

    def __iter__(self):
        print("in sampler")

        if self.shuffle:
            np.random.shuffle(self.shuffle_idx)

        batch = []
        max_token = 0
        num_sample = 0

        iter_num = (self.total_samples-1) // self.sort_size + 1
        print("iter_num: ", iter_num)
        for iter in range(self.pre_idx + 1, iter_num):
            datalen_with_index = []
            for i in range(self.sort_size):
                idx = iter * self.sort_size + i
                if idx >= self.total_samples:
                    continue

                idx_map = self.shuffle_idx[idx]
                # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
                sample_len_cur = self.dataset.indexed_dataset[idx_map]["source_len"] + \
                                 self.dataset.indexed_dataset[idx_map]["target_len"]

                datalen_with_index.append([idx, sample_len_cur])

            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
            for item in datalen_with_index_sort:
                idx, sample_len_cur_raw = item
                if sample_len_cur_raw > self.max_length_token:
                    continue

                max_token_cur = max(max_token, sample_len_cur_raw)
                max_token_padding = 1 + num_sample
                if self.batch_size_type == 'token':
                    max_token_padding *= max_token_cur
                if max_token_padding <= self.batch_size:
                    batch.append(idx)
                    max_token = max_token_cur
                    num_sample += 1
                else:
                    yield batch
                    batch = [idx]
                    max_token = sample_len_cur_raw
                    num_sample = 1

funasr/datasets/dataloader_fn.py (new file, 53 lines)
@@ -0,0 +1,53 @@

import torch
from funasr.datasets.dataset_jsonl import AudioDataset
from funasr.datasets.data_sampler import BatchSampler
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tokenizer.build_tokenizer import build_tokenizer
from funasr.tokenizer.token_id_converter import TokenIDConverter
collate_fn = None
# collate_fn = collate_fn,

jsonl = "/Users/zhifu/funasr_github/test_local/all_task_debug_len.jsonl"

frontend = WavFrontend()
token_type = 'char'
bpemodel = None
delimiter = None
space_symbol = "<space>"
non_linguistic_symbols = None
g2p_type = None

tokenizer = build_tokenizer(
    token_type=token_type,
    bpemodel=bpemodel,
    delimiter=delimiter,
    space_symbol=space_symbol,
    non_linguistic_symbols=non_linguistic_symbols,
    g2p_type=g2p_type,
)
token_list = ""
unk_symbol = "<unk>"

token_id_converter = TokenIDConverter(
    token_list=token_list,
    unk_symbol=unk_symbol,
)

dataset = AudioDataset(jsonl, frontend=frontend, tokenizer=tokenizer)
batch_sampler = BatchSampler(dataset)
dataloader_tr = torch.utils.data.DataLoader(dataset,
                                            collate_fn=dataset.collator,
                                            batch_sampler=batch_sampler,
                                            shuffle=False,
                                            num_workers=0,
                                            pin_memory=True)

print(len(dataset))
for i in range(3):
    print(i)
    for data in dataloader_tr:
        print(len(data), data)
    # data_iter = iter(dataloader_tr)
    # data = next(data_iter)
    pass

@@ -16,8 +16,10 @@ from typing import Dict
from typing import Mapping
from typing import Tuple
from typing import Union

import h5py
try:
    import h5py
except:
    print("If you want use h5py dataset, please pip install h5py, and try it again")
import humanfriendly
import kaldiio
import numpy as np

funasr/datasets/dataset_jsonl.py (new file, 124 lines)
@@ -0,0 +1,124 @@
import torch
import json
import torch.distributed as dist
import numpy as np
import kaldiio
import librosa



def load_audio(audio_path: str, fs: int=16000):
    audio = None
    if audio_path.startswith("oss:"):
        pass
    elif audio_path.startswith("odps:"):
        pass
    else:
        if ".ark:" in audio_path:
            audio = kaldiio.load_mat(audio_path)
        else:
            audio, fs = librosa.load(audio_path, sr=fs)
    return audio

def extract_features(data, date_type: str="sound", frontend=None):
    if date_type == "sound":
        feat, feats_lens = frontend(data, len(data))
        feat = feat[0, :, :]
    else:
        feat, feats_lens = torch.from_numpy(data).to(torch.float32), torch.tensor([data.shape[0]]).to(torch.int32)
    return feat, feats_lens



class IndexedDatasetJsonl(torch.utils.data.Dataset):

    def __init__(self, path):
        super().__init__()
        # data_parallel_size = dist.get_world_size()
        data_parallel_size = 1
        contents = []
        with open(path, encoding='utf-8') as fin:
            for line in fin:
                data = json.loads(line.strip())
                if "text" in data:  # for sft
                    self.contents.append(data['text'])
                if "source" in data:  # for speech lab pretrain
                    prompt = data["prompt"]
                    source = data["source"]
                    target = data["target"]
                    source_len = data["source_len"]
                    target_len = data["target_len"]

                    contents.append({"source": source,
                                     "prompt": prompt,
                                     "target": target,
                                     "source_len": source_len,
                                     "target_len": target_len,
                                     }
                                    )

        self.contents = []
        total_num = len(contents)
        num_per_rank = total_num // data_parallel_size
        # rank = dist.get_rank()
        rank = 0
        # import ipdb; ipdb.set_trace()
        self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank]


    def __len__(self):
        return len(self.contents)

    def __getitem__(self, index):
        return self.contents[index]


class AudioDataset(torch.utils.data.Dataset):
    def __init__(self, path, frontend=None, tokenizer=None):
        super().__init__()
        self.indexed_dataset = IndexedDatasetJsonl(path)
        self.frontend = frontend.forward
        self.fs = 16000 if frontend is None else frontend.fs
        self.data_type = "sound"
        self.tokenizer = tokenizer
        self.int_pad_value = -1
        self.float_pad_value = 0.0




    def __len__(self):
        return len(self.indexed_dataset)

    def __getitem__(self, index):
        item = self.indexed_dataset[index]
        source = item["source"]
        data_src = load_audio(source, fs=self.fs)
        speech, speech_lengths = extract_features(data_src, self.data_type, self.frontend)
        target = item["target"]
        text = self.tokenizer.encode(target)
        text_lengths = len(text)
        text, text_lengths = torch.tensor(text, dtype=torch.int64), torch.tensor([text_lengths], dtype=torch.int32)
        return {"speech": speech,
                "speech_lengths": speech_lengths,
                "text": text,
                "text_lengths": text_lengths,
                }


    def collator(self, samples: list=None):

        outputs = {}
        for sample in samples:
            for key in sample.keys():
                if key not in outputs:
                    outputs[key] = []
                outputs[key].append(sample[key])

        for key, data_list in outputs.items():
            if data_list[0].dtype.kind == "i":
                pad_value = self.int_pad_value
            else:
                pad_value = self.float_pad_value
            outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
        return samples
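For orientation, a minimal sketch of the jsonl record shape that IndexedDatasetJsonl reads above, inferred from the keys it accesses; the field values are invented placeholders, not taken from the repository:

import json

record = {
    "source": "/data/wav/utt1.wav",  # audio path, or a kaldi ark offset such as "feats.ark:12"
    "prompt": "<ASR>",
    "target": "hello world",
    "source_len": 998,               # e.g. frame or sample count of the source
    "target_len": 2,                 # e.g. token count of the target
}
# One such JSON object per line in the .jsonl file.
print(json.dumps(record, ensure_ascii=False))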
@@ -14,7 +14,8 @@ import kaldiio
import numpy as np
import torch
import torchaudio
import soundfile
# import librosa
import librosa
from torch.utils.data.dataset import IterableDataset
import os.path

@@ -70,7 +71,8 @@ def load_wav(input):
    try:
        return torchaudio.load(input)[0].numpy()
    except:
        waveform, _ = soundfile.read(input, dtype='float32')
        # waveform, _ = librosa.load(input, dtype='float32')
        waveform, _ = librosa.load(input, dtype='float32')
        if waveform.ndim == 2:
            waveform = waveform[:, 0]
        return np.expand_dims(waveform, axis=0)

@@ -9,7 +9,7 @@ from torch.utils.data import DataLoader

from funasr.datasets.large_datasets.dataset import Dataset
from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.text.abs_tokenizer import AbsTokenizer
from funasr.tokenizer.abs_tokenizer import AbsTokenizer


def read_symbol_table(symbol_table_file):

@@ -7,7 +7,8 @@ import torch
import torch.distributed as dist
import torchaudio
import numpy as np
import soundfile
# import librosa
import librosa
from kaldiio import ReadHelper
from torch.utils.data import IterableDataset

@@ -128,7 +129,8 @@ class AudioDataset(IterableDataset):
    try:
        waveform, sampling_rate = torchaudio.load(path)
    except:
        waveform, sampling_rate = soundfile.read(path, dtype='float32')
        # waveform, sampling_rate = librosa.load(path, dtype='float32')
        waveform, sampling_rate = librosa.load(path, dtype='float32')
        if waveform.ndim == 2:
            waveform = waveform[:, 0]
        waveform = np.expand_dims(waveform, axis=0)

@@ -10,12 +10,12 @@ from typing import Union

import numpy as np
import scipy.signal
import soundfile
import librosa
import jieba

from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.cleaner import TextCleaner
from funasr.text.token_id_converter import TokenIDConverter
from funasr.tokenizer.build_tokenizer import build_tokenizer
from funasr.tokenizer.cleaner import TextCleaner
from funasr.tokenizer.token_id_converter import TokenIDConverter


class AbsPreprocessor(ABC):

@@ -284,7 +284,7 @@ class CommonPreprocessor(AbsPreprocessor):
        if self.rirs is not None and self.rir_apply_prob >= np.random.random():
            rir_path = np.random.choice(self.rirs)
            if rir_path is not None:
                rir, _ = soundfile.read(
                rir, _ = librosa.load(
                    rir_path, dtype=np.float64, always_2d=True
                )

@@ -310,28 +310,31 @@ class CommonPreprocessor(AbsPreprocessor):
                noise_db = np.random.uniform(
                    self.noise_db_low, self.noise_db_high
                )
                with soundfile.SoundFile(noise_path) as f:
                    if f.frames == nsamples:
                        noise = f.read(dtype=np.float64, always_2d=True)
                    elif f.frames < nsamples:
                        offset = np.random.randint(0, nsamples - f.frames)
                        # noise: (Time, Nmic)
                        noise = f.read(dtype=np.float64, always_2d=True)
                        # Repeat noise
                        noise = np.pad(
                            noise,
                            [(offset, nsamples - f.frames - offset), (0, 0)],
                            mode="wrap",
                        )
                    else:
                        offset = np.random.randint(0, f.frames - nsamples)
                        f.seek(offset)
                        # noise: (Time, Nmic)
                        noise = f.read(
                            nsamples, dtype=np.float64, always_2d=True
                        )
                        if len(noise) != nsamples:
                            raise RuntimeError(f"Something wrong: {noise_path}")

                audio_data = librosa.load(noise_path, dtype='float32')[0][None, :]
                frames = len(audio_data[0])
                if frames == nsamples:
                    noise = audio_data
                elif frames < nsamples:
                    offset = np.random.randint(0, nsamples - frames)
                    # noise: (Time, Nmic)
                    noise = audio_data
                    # Repeat noise
                    noise = np.pad(
                        noise,
                        [(offset, nsamples - frames - offset), (0, 0)],
                        mode="wrap",
                    )
                else:
                    noise = audio_data[:, nsamples]
                    # offset = np.random.randint(0, frames - nsamples)
                    # f.seek(offset)
                    # noise: (Time, Nmic)
                    # noise = f.read(
                    #     nsamples, dtype=np.float64, always_2d=True
                    # )
                    # if len(noise) != nsamples:
                    #     raise RuntimeError(f"Something wrong: {noise_path}")
                # noise: (Nmic, Time)
                noise = noise.T


@@ -9,11 +9,11 @@ from typing import Union

import numpy as np
import scipy.signal
import soundfile
import librosa

from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.cleaner import TextCleaner
from funasr.text.token_id_converter import TokenIDConverter
from funasr.tokenizer.build_tokenizer import build_tokenizer
from funasr.tokenizer.cleaner import TextCleaner
from funasr.tokenizer.token_id_converter import TokenIDConverter


class AbsPreprocessor(ABC):

@@ -275,7 +275,7 @@ class CommonPreprocessor(AbsPreprocessor):
        if self.rirs is not None and self.rir_apply_prob >= np.random.random():
            rir_path = np.random.choice(self.rirs)
            if rir_path is not None:
                rir, _ = soundfile.read(
                rir, _ = librosa.load(
                    rir_path, dtype=np.float64, always_2d=True
                )

@@ -301,28 +301,30 @@ class CommonPreprocessor(AbsPreprocessor):
                noise_db = np.random.uniform(
                    self.noise_db_low, self.noise_db_high
                )
                with soundfile.SoundFile(noise_path) as f:
                    if f.frames == nsamples:
                        noise = f.read(dtype=np.float64, always_2d=True)
                    elif f.frames < nsamples:
                        offset = np.random.randint(0, nsamples - f.frames)
                        # noise: (Time, Nmic)
                        noise = f.read(dtype=np.float64, always_2d=True)
                        # Repeat noise
                        noise = np.pad(
                            noise,
                            [(offset, nsamples - f.frames - offset), (0, 0)],
                            mode="wrap",
                        )
                    else:
                        offset = np.random.randint(0, f.frames - nsamples)
                        f.seek(offset)
                        # noise: (Time, Nmic)
                        noise = f.read(
                            nsamples, dtype=np.float64, always_2d=True
                        )
                        if len(noise) != nsamples:
                            raise RuntimeError(f"Something wrong: {noise_path}")
                audio_data = librosa.load(noise_path, dtype='float32')[0][None, :]
                frames = len(audio_data[0])
                if frames == nsamples:
                    noise = audio_data
                elif frames < nsamples:
                    offset = np.random.randint(0, nsamples - frames)
                    # noise: (Time, Nmic)
                    noise = audio_data
                    # Repeat noise
                    noise = np.pad(
                        noise,
                        [(offset, nsamples - frames - offset), (0, 0)],
                        mode="wrap",
                    )
                else:
                    noise = audio_data[:, nsamples]
                    # offset = np.random.randint(0, frames - nsamples)
                    # f.seek(offset)
                    # noise: (Time, Nmic)
                    # noise = f.read(
                    #     nsamples, dtype=np.float64, always_2d=True
                    # )
                    # if len(noise) != nsamples:
                    #     raise RuntimeError(f"Something wrong: {noise_path}")
                # noise: (Nmic, Time)
                noise = noise.T

@@ -1,151 +0,0 @@
|
||||
import json
|
||||
from typing import Union, Dict
|
||||
from pathlib import Path
|
||||
|
||||
import os
|
||||
import logging
|
||||
import torch
|
||||
|
||||
from funasr.export.models import get_model
|
||||
import numpy as np
|
||||
import random
|
||||
from funasr.utils.types import str2bool, str2triple_str
|
||||
# torch_version = float(".".join(torch.__version__.split(".")[:2]))
|
||||
# assert torch_version > 1.9
|
||||
|
||||
class ModelExport:
|
||||
def __init__(
|
||||
self,
|
||||
cache_dir: Union[Path, str] = None,
|
||||
onnx: bool = True,
|
||||
device: str = "cpu",
|
||||
quant: bool = True,
|
||||
fallback_num: int = 0,
|
||||
audio_in: str = None,
|
||||
calib_num: int = 200,
|
||||
model_revision: str = None,
|
||||
):
|
||||
self.set_all_random_seed(0)
|
||||
|
||||
self.cache_dir = cache_dir
|
||||
self.export_config = dict(
|
||||
feats_dim=560,
|
||||
onnx=False,
|
||||
)
|
||||
|
||||
self.onnx = onnx
|
||||
self.device = device
|
||||
self.quant = quant
|
||||
self.fallback_num = fallback_num
|
||||
self.frontend = None
|
||||
self.audio_in = audio_in
|
||||
self.calib_num = calib_num
|
||||
self.model_revision = model_revision
|
||||
|
||||
def _export(
|
||||
self,
|
||||
model,
|
||||
model_dir: str = None,
|
||||
verbose: bool = False,
|
||||
):
|
||||
|
||||
export_dir = model_dir
|
||||
os.makedirs(export_dir, exist_ok=True)
|
||||
|
||||
self.export_config["model_name"] = "model"
|
||||
model = get_model(
|
||||
model,
|
||||
self.export_config,
|
||||
)
|
||||
model.eval()
|
||||
|
||||
if self.onnx:
|
||||
self._export_onnx(model, verbose, export_dir)
|
||||
|
||||
print("output dir: {}".format(export_dir))
|
||||
|
||||
def _export_onnx(self, model, verbose, path):
|
||||
model._export_onnx(verbose, path)
|
||||
|
||||
def set_all_random_seed(self, seed: int):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
|
||||
def parse_audio_in(self, audio_in):
|
||||
|
||||
wav_list, name_list = [], []
|
||||
if audio_in.endswith(".scp"):
|
||||
f = open(audio_in, 'r')
|
||||
lines = f.readlines()[:self.calib_num]
|
||||
for line in lines:
|
||||
name, path = line.strip().split()
|
||||
name_list.append(name)
|
||||
wav_list.append(path)
|
||||
else:
|
||||
wav_list = [audio_in,]
|
||||
name_list = ["test",]
|
||||
return wav_list, name_list
|
||||
|
||||
def load_feats(self, audio_in: str = None):
|
||||
import torchaudio
|
||||
|
||||
wav_list, name_list = self.parse_audio_in(audio_in)
|
||||
feats = []
|
||||
feats_len = []
|
||||
for line in wav_list:
|
||||
path = line.strip()
|
||||
waveform, sampling_rate = torchaudio.load(path)
|
||||
if sampling_rate != self.frontend.fs:
|
||||
waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
|
||||
new_freq=self.frontend.fs)(waveform)
|
||||
fbank, fbank_len = self.frontend(waveform, [waveform.size(1)])
|
||||
feats.append(fbank)
|
||||
feats_len.append(fbank_len)
|
||||
return feats, feats_len
|
||||
|
||||
def export(self,
|
||||
mode: str = None,
|
||||
):
|
||||
|
||||
if mode.startswith('conformer'):
|
||||
from funasr.tasks.asr import ASRTask
|
||||
config = os.path.join(model_dir, 'config.yaml')
|
||||
model_file = os.path.join(model_dir, 'model.pb')
|
||||
cmvn_file = os.path.join(model_dir, 'am.mvn')
|
||||
model, asr_train_args = ASRTask.build_model_from_file(
|
||||
config, model_file, cmvn_file, 'cpu'
|
||||
)
|
||||
self.frontend = model.frontend
|
||||
self.export_config["feats_dim"] = 560
|
||||
|
||||
self._export(model, self.cache_dir)
|
||||
|
||||
if __name__ == '__main__':
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
# parser.add_argument('--model-name', type=str, required=True)
|
||||
parser.add_argument('--model-name', type=str, action="append", required=True, default=[])
|
||||
parser.add_argument('--export-dir', type=str, required=True)
|
||||
parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
|
||||
parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
|
||||
parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
|
||||
parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
|
||||
parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
|
||||
parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
|
||||
parser.add_argument('--model_revision', type=str, default=None, help='model_revision')
|
||||
args = parser.parse_args()
|
||||
|
||||
export_model = ModelExport(
|
||||
cache_dir=args.export_dir,
|
||||
onnx=args.type == 'onnx',
|
||||
device=args.device,
|
||||
quant=args.quantize,
|
||||
fallback_num=args.fallback_num,
|
||||
audio_in=args.audio_in,
|
||||
calib_num=args.calib_num,
|
||||
model_revision=args.model_revision,
|
||||
)
|
||||
for model_name in args.model_name:
|
||||
print("export model: {}".format(model_name))
|
||||
export_model.export(model_name)
|
||||
@@ -1,7 +1,7 @@
from funasr.models.e2e_asr_paraformer import Paraformer, BiCifParaformer, ParaformerOnline
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export
from funasr.export.models.e2e_asr_conformer import Conformer as Conformer_export
# from funasr.export.models.e2e_asr_conformer import Conformer as Conformer_export

from funasr.models.e2e_vad import E2EVadModel
from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export

@@ -30,8 +30,8 @@ def get_model(model, export_config=None):
        return [encoder, decoder]
    elif isinstance(model, Paraformer):
        return Paraformer_export(model, **export_config)
    elif isinstance(model, Conformer_export):
        return Conformer_export(model, **export_config)
    # elif isinstance(model, Conformer_export):
    #     return Conformer_export(model, **export_config)
    elif isinstance(model, E2EVadModel):
        return E2EVadModel_export(model, **export_config)
    elif isinstance(model, PunctuationModel):

@@ -1,69 +0,0 @@
|
||||
import os
|
||||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from funasr.export.utils.torch_function import MakePadMask
|
||||
from funasr.export.utils.torch_function import sequence_mask
|
||||
from funasr.models.encoder.conformer_encoder import ConformerEncoder
|
||||
from funasr.models.decoder.transformer_decoder import TransformerDecoder
|
||||
from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
|
||||
from funasr.export.models.decoder.xformer_decoder import XformerDecoder as TransformerDecoder_export
|
||||
|
||||
class Conformer(nn.Module):
|
||||
"""
|
||||
export conformer into onnx format
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
max_seq_len=512,
|
||||
feats_dim=560,
|
||||
model_name='model',
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
onnx = False
|
||||
if "onnx" in kwargs:
|
||||
onnx = kwargs["onnx"]
|
||||
if isinstance(model.encoder, ConformerEncoder):
|
||||
self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
|
||||
elif isinstance(model.decoder, TransformerDecoder):
|
||||
self.decoder = TransformerDecoder_export(model.decoder, onnx=onnx)
|
||||
|
||||
self.feats_dim = feats_dim
|
||||
self.model_name = model_name
|
||||
|
||||
if onnx:
|
||||
self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
|
||||
else:
|
||||
self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
|
||||
|
||||
def _export_model(self, model, verbose, path):
|
||||
dummy_input = model.get_dummy_inputs()
|
||||
model_script = model
|
||||
model_path = os.path.join(path, f'{model.model_name}.onnx')
|
||||
if not os.path.exists(model_path):
|
||||
torch.onnx.export(
|
||||
model_script,
|
||||
dummy_input,
|
||||
model_path,
|
||||
verbose=verbose,
|
||||
opset_version=14,
|
||||
input_names=model.get_input_names(),
|
||||
output_names=model.get_output_names(),
|
||||
dynamic_axes=model.get_dynamic_axes()
|
||||
)
|
||||
|
||||
def _export_encoder_onnx(self, verbose, path):
|
||||
model_encoder = self.encoder
|
||||
self._export_model(model_encoder, verbose, path)
|
||||
|
||||
def _export_decoder_onnx(self, verbose, path):
|
||||
model_decoder = self.decoder
|
||||
self._export_model(model_decoder, verbose, path)
|
||||
|
||||
def _export_onnx(self, verbose, path):
|
||||
self._export_encoder_onnx(verbose, path)
|
||||
self._export_decoder_onnx(verbose, path)
|
||||
@@ -1,403 +0,0 @@
|
||||
"""Positional Encoding Module."""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from funasr.modules.embedding import (
|
||||
LegacyRelPositionalEncoding, PositionalEncoding, RelPositionalEncoding,
|
||||
ScaledPositionalEncoding, StreamPositionalEncoding)
|
||||
from funasr.modules.subsampling import (
|
||||
Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6,
|
||||
Conv2dSubsampling8)
|
||||
from funasr.modules.subsampling_without_posenc import \
|
||||
Conv2dSubsamplingWOPosEnc
|
||||
|
||||
from funasr.export.models.language_models.subsampling import (
|
||||
OnnxConv2dSubsampling, OnnxConv2dSubsampling2, OnnxConv2dSubsampling6,
|
||||
OnnxConv2dSubsampling8)
|
||||
|
||||
|
||||
def get_pos_emb(pos_emb, max_seq_len=512, use_cache=True):
|
||||
if isinstance(pos_emb, LegacyRelPositionalEncoding):
|
||||
return OnnxLegacyRelPositionalEncoding(pos_emb, max_seq_len, use_cache)
|
||||
elif isinstance(pos_emb, ScaledPositionalEncoding):
|
||||
return OnnxScaledPositionalEncoding(pos_emb, max_seq_len, use_cache)
|
||||
elif isinstance(pos_emb, RelPositionalEncoding):
|
||||
return OnnxRelPositionalEncoding(pos_emb, max_seq_len, use_cache)
|
||||
elif isinstance(pos_emb, PositionalEncoding):
|
||||
return OnnxPositionalEncoding(pos_emb, max_seq_len, use_cache)
|
||||
elif isinstance(pos_emb, StreamPositionalEncoding):
|
||||
return OnnxStreamPositionalEncoding(pos_emb, max_seq_len, use_cache)
|
||||
elif (isinstance(pos_emb, nn.Sequential) and len(pos_emb) == 0) or (
|
||||
isinstance(pos_emb, Conv2dSubsamplingWOPosEnc)
|
||||
):
|
||||
return pos_emb
|
||||
else:
|
||||
raise ValueError("Embedding model is not supported.")
|
||||
|
||||
|
||||
class Embedding(nn.Module):
|
||||
def __init__(self, model, max_seq_len=512, use_cache=True):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
if not isinstance(model, nn.Embedding):
|
||||
if isinstance(model, Conv2dSubsampling):
|
||||
self.model = OnnxConv2dSubsampling(model)
|
||||
self.model.out[-1] = get_pos_emb(model.out[-1], max_seq_len)
|
||||
elif isinstance(model, Conv2dSubsampling2):
|
||||
self.model = OnnxConv2dSubsampling2(model)
|
||||
self.model.out[-1] = get_pos_emb(model.out[-1], max_seq_len)
|
||||
elif isinstance(model, Conv2dSubsampling6):
|
||||
self.model = OnnxConv2dSubsampling6(model)
|
||||
self.model.out[-1] = get_pos_emb(model.out[-1], max_seq_len)
|
||||
elif isinstance(model, Conv2dSubsampling8):
|
||||
self.model = OnnxConv2dSubsampling8(model)
|
||||
self.model.out[-1] = get_pos_emb(model.out[-1], max_seq_len)
|
||||
else:
|
||||
self.model[-1] = get_pos_emb(model[-1], max_seq_len)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
if mask is None:
|
||||
return self.model(x)
|
||||
else:
|
||||
return self.model(x, mask)
|
||||
|
||||
|
||||
def _pre_hook(
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
):
|
||||
"""Perform pre-hook in load_state_dict for backward compatibility.
|
||||
|
||||
Note:
|
||||
We saved self.pe until v.0.5.2 but we have omitted it later.
|
||||
Therefore, we remove the item "pe" from `state_dict` for backward compatibility.
|
||||
|
||||
"""
|
||||
k = prefix + "pe"
|
||||
if k in state_dict:
|
||||
state_dict.pop(k)
|
||||
|
||||
|
||||
class OnnxPositionalEncoding(torch.nn.Module):
|
||||
"""Positional encoding.
|
||||
|
||||
Args:
|
||||
d_model (int): Embedding dimension.
|
||||
dropout_rate (float): Dropout rate.
|
||||
max_seq_len (int): Maximum input length.
|
||||
reverse (bool): Whether to reverse the input position. Only for
|
||||
the class LegacyRelPositionalEncoding. We remove it in the current
|
||||
class RelPositionalEncoding.
|
||||
"""
|
||||
|
||||
def __init__(self, model, max_seq_len=512, reverse=False, use_cache=True):
|
||||
"""Construct an PositionalEncoding object."""
|
||||
super(OnnxPositionalEncoding, self).__init__()
|
||||
self.d_model = model.d_model
|
||||
self.reverse = reverse
|
||||
self.max_seq_len = max_seq_len
|
||||
self.xscale = math.sqrt(self.d_model)
|
||||
self._register_load_state_dict_pre_hook(_pre_hook)
|
||||
self.pe = model.pe
|
||||
self.use_cache = use_cache
|
||||
self.model = model
|
||||
if self.use_cache:
|
||||
self.extend_pe()
|
||||
else:
|
||||
self.div_term = torch.exp(
|
||||
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
||||
* -(math.log(10000.0) / self.d_model)
|
||||
)
|
||||
|
||||
def extend_pe(self):
|
||||
"""Reset the positional encodings."""
|
||||
pe_length = len(self.pe[0])
|
||||
if self.max_seq_len < pe_length:
|
||||
self.pe = self.pe[:, : self.max_seq_len]
|
||||
else:
|
||||
self.model.extend_pe(torch.tensor(0.0).expand(1, self.max_seq_len))
|
||||
self.pe = self.model.pe
|
||||
|
||||
def _add_pe(self, x):
|
||||
"""Computes positional encoding"""
|
||||
if self.reverse:
|
||||
position = torch.arange(
|
||||
x.size(1) - 1, -1, -1.0, dtype=torch.float32
|
||||
).unsqueeze(1)
|
||||
else:
|
||||
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
||||
|
||||
x = x * self.xscale
|
||||
x[:, :, 0::2] += torch.sin(position * self.div_term)
|
||||
x[:, :, 1::2] += torch.cos(position * self.div_term)
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""Add positional encoding.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (batch, time, `*`).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Encoded tensor (batch, time, `*`).
|
||||
"""
|
||||
if self.use_cache:
|
||||
x = x * self.xscale + self.pe[:, : x.size(1)]
|
||||
else:
|
||||
x = self._add_pe(x)
|
||||
return x
|
||||
|
||||
|
||||
class OnnxScaledPositionalEncoding(OnnxPositionalEncoding):
|
||||
"""Scaled positional encoding module.
|
||||
|
||||
See Sec. 3.2 https://arxiv.org/abs/1809.08895
|
||||
|
||||
Args:
|
||||
d_model (int): Embedding dimension.
|
||||
dropout_rate (float): Dropout rate.
|
||||
max_seq_len (int): Maximum input length.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, model, max_seq_len=512, use_cache=True):
|
||||
"""Initialize class."""
|
||||
super().__init__(model, max_seq_len, use_cache=use_cache)
|
||||
self.alpha = torch.nn.Parameter(torch.tensor(1.0))
|
||||
|
||||
def reset_parameters(self):
|
||||
"""Reset parameters."""
|
||||
self.alpha.data = torch.tensor(1.0)
|
||||
|
||||
def _add_pe(self, x):
|
||||
"""Computes positional encoding"""
|
||||
if self.reverse:
|
||||
position = torch.arange(
|
||||
x.size(1) - 1, -1, -1.0, dtype=torch.float32
|
||||
).unsqueeze(1)
|
||||
else:
|
||||
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
||||
|
||||
x = x * self.alpha
|
||||
x[:, :, 0::2] += torch.sin(position * self.div_term)
|
||||
x[:, :, 1::2] += torch.cos(position * self.div_term)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
"""Add positional encoding.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (batch, time, `*`).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Encoded tensor (batch, time, `*`).
|
||||
|
||||
"""
|
||||
if self.use_cache:
|
||||
x = x + self.alpha * self.pe[:, : x.size(1)]
|
||||
else:
|
||||
x = self._add_pe(x)
|
||||
return x
|
||||
|
||||
|
||||
class OnnxLegacyRelPositionalEncoding(OnnxPositionalEncoding):
|
||||
"""Relative positional encoding module (old version).
|
||||
|
||||
Details can be found in https://github.com/espnet/espnet/pull/2816.
|
||||
|
||||
See : Appendix B in https://arxiv.org/abs/1901.02860
|
||||
|
||||
Args:
|
||||
d_model (int): Embedding dimension.
|
||||
dropout_rate (float): Dropout rate.
|
||||
max_seq_len (int): Maximum input length.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, model, max_seq_len=512, use_cache=True):
|
||||
"""Initialize class."""
|
||||
super().__init__(model, max_seq_len, reverse=True, use_cache=use_cache)
|
||||
|
||||
def _get_pe(self, x):
|
||||
"""Computes positional encoding"""
|
||||
if self.reverse:
|
||||
position = torch.arange(
|
||||
x.size(1) - 1, -1, -1.0, dtype=torch.float32
|
||||
).unsqueeze(1)
|
||||
else:
|
||||
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
||||
|
||||
pe = torch.zeros(x.shape)
|
||||
pe[:, :, 0::2] += torch.sin(position * self.div_term)
|
||||
pe[:, :, 1::2] += torch.cos(position * self.div_term)
|
||||
return pe
|
||||
|
||||
def forward(self, x):
|
||||
"""Compute positional encoding.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (batch, time, `*`).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Encoded tensor (batch, time, `*`).
|
||||
torch.Tensor: Positional embedding tensor (1, time, `*`).
|
||||
|
||||
"""
|
||||
x = x * self.xscale
|
||||
if self.use_cache:
|
||||
pos_emb = self.pe[:, : x.size(1)]
|
||||
else:
|
||||
pos_emb = self._get_pe(x)
|
||||
return x, pos_emb
|
||||
|
||||
|
||||
class OnnxRelPositionalEncoding(torch.nn.Module):
|
||||
"""Relative positional encoding module (new implementation).
|
||||
Details can be found in https://github.com/espnet/espnet/pull/2816.
|
||||
See : Appendix B in https://arxiv.org/abs/1901.02860
|
||||
Args:
|
||||
d_model (int): Embedding dimension.
|
||||
dropout_rate (float): Dropout rate.
|
||||
max_seq_len (int): Maximum input length.
|
||||
"""
|
||||
|
||||
def __init__(self, model, max_seq_len=512, use_cache=True):
|
||||
"""Construct an PositionalEncoding object."""
|
||||
super(OnnxRelPositionalEncoding, self).__init__()
|
||||
self.d_model = model.d_model
|
||||
self.xscale = math.sqrt(self.d_model)
|
||||
self.pe = None
|
||||
self.use_cache = use_cache
|
||||
if self.use_cache:
|
||||
self.extend_pe(torch.tensor(0.0).expand(1, max_seq_len))
|
||||
else:
|
||||
self.div_term = torch.exp(
|
||||
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
||||
* -(math.log(10000.0) / self.d_model)
|
||||
)
|
||||
|
||||
def extend_pe(self, x):
|
||||
"""Reset the positional encodings."""
|
||||
if self.pe is not None and self.pe.size(1) >= x.size(1) * 2 - 1:
|
||||
# self.pe contains both positive and negative parts
|
||||
# the length of self.pe is 2 * input_len - 1
|
||||
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
||||
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
||||
return
|
||||
# Suppose `i` means to the position of query vecotr and `j` means the
|
||||
# position of key vector. We use position relative positions when keys
|
||||
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
||||
pe_positive = torch.zeros(x.size(1), self.d_model)
|
||||
pe_negative = torch.zeros(x.size(1), self.d_model)
|
||||
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
||||
div_term = torch.exp(
|
||||
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
||||
* -(math.log(10000.0) / self.d_model)
|
||||
)
|
||||
pe_positive[:, 0::2] = torch.sin(position * div_term)
|
||||
pe_positive[:, 1::2] = torch.cos(position * div_term)
|
||||
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
|
||||
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
|
||||
|
||||
# Reserve the order of positive indices and concat both positive and
|
||||
# negative indices. This is used to support the shifting trick
|
||||
# as in https://arxiv.org/abs/1901.02860
|
||||
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
|
||||
pe_negative = pe_negative[1:].unsqueeze(0)
|
||||
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
||||
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
||||
|
||||
def _get_pe(self, x):
|
||||
pe_positive = torch.zeros(x.size(1), self.d_model)
|
||||
pe_negative = torch.zeros(x.size(1), self.d_model)
|
||||
theta = (
|
||||
torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) * self.div_term
|
||||
)
|
||||
pe_positive[:, 0::2] = torch.sin(theta)
|
||||
pe_positive[:, 1::2] = torch.cos(theta)
|
||||
pe_negative[:, 0::2] = -1 * torch.sin(theta)
|
||||
pe_negative[:, 1::2] = torch.cos(theta)
|
||||
|
||||
# Reserve the order of positive indices and concat both positive and
|
||||
# negative indices. This is used to support the shifting trick
|
||||
# as in https://arxiv.org/abs/1901.02860
|
||||
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
|
||||
pe_negative = pe_negative[1:].unsqueeze(0)
|
||||
return torch.cat([pe_positive, pe_negative], dim=1)
|
||||
|
||||
def forward(self, x: torch.Tensor, use_cache=True):
|
||||
"""Add positional encoding.
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (batch, time, `*`).
|
||||
Returns:
|
||||
torch.Tensor: Encoded tensor (batch, time, `*`).
|
||||
"""
|
||||
x = x * self.xscale
|
||||
if self.use_cache:
|
||||
pos_emb = self.pe[
|
||||
:,
|
||||
self.pe.size(1) // 2 - x.size(1) + 1 : self.pe.size(1) // 2 + x.size(1),
|
||||
]
|
||||
else:
|
||||
pos_emb = self._get_pe(x)
|
||||
return x, pos_emb
|
||||
|
||||
|
||||
class OnnxStreamPositionalEncoding(torch.nn.Module):
|
||||
"""Streaming Positional encoding."""
|
||||
|
||||
def __init__(self, model, max_seq_len=5000, use_cache=True):
|
||||
"""Construct an PositionalEncoding object."""
|
||||
super(StreamPositionalEncoding, self).__init__()
|
||||
self.use_cache = use_cache
|
||||
self.d_model = model.d_model
|
||||
self.xscale = model.xscale
|
||||
self.pe = model.pe
|
||||
self.use_cache = use_cache
|
||||
self.max_seq_len = max_seq_len
|
||||
if self.use_cache:
|
||||
self.extend_pe()
|
||||
else:
|
||||
self.div_term = torch.exp(
|
||||
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
||||
* -(math.log(10000.0) / self.d_model)
|
||||
)
|
||||
self._register_load_state_dict_pre_hook(_pre_hook)
|
||||
|
||||
def extend_pe(self):
|
||||
"""Reset the positional encodings."""
|
||||
pe_length = len(self.pe[0])
|
||||
if self.max_seq_len < pe_length:
|
||||
self.pe = self.pe[:, : self.max_seq_len]
|
||||
else:
|
||||
self.model.extend_pe(self.max_seq_len)
|
||||
self.pe = self.model.pe
|
||||
|
||||
def _add_pe(self, x, start_idx):
|
||||
position = torch.arange(start_idx, x.size(1), dtype=torch.float32).unsqueeze(1)
|
||||
x = x * self.xscale
|
||||
x[:, :, 0::2] += torch.sin(position * self.div_term)
|
||||
x[:, :, 1::2] += torch.cos(position * self.div_term)
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor, start_idx: int = 0):
|
||||
"""Add positional encoding.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (batch, time, `*`).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Encoded tensor (batch, time, `*`).
|
||||
|
||||
"""
|
||||
if self.use_cache:
|
||||
return x * self.xscale + self.pe[:, start_idx : start_idx + x.size(1)]
|
||||
else:
|
||||
return self._add_pe(x, start_idx)
|
||||
@@ -1,84 +0,0 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class SequentialRNNLM(nn.Module):
|
||||
def __init__(self, model, **kwargs):
|
||||
super().__init__()
|
||||
self.encoder = model.encoder
|
||||
self.rnn = model.rnn
|
||||
self.rnn_type = model.rnn_type
|
||||
self.decoder = model.decoder
|
||||
self.nlayers = model.nlayers
|
||||
self.nhid = model.nhid
|
||||
self.model_name = "seq_rnnlm"
|
||||
|
||||
def forward(self, y, hidden1, hidden2=None):
|
||||
# batch_score function.
|
||||
emb = self.encoder(y)
|
||||
if self.rnn_type == "LSTM":
|
||||
output, (hidden1, hidden2) = self.rnn(emb, (hidden1, hidden2))
|
||||
else:
|
||||
output, hidden1 = self.rnn(emb, hidden1)
|
||||
|
||||
decoded = self.decoder(
|
||||
output.contiguous().view(output.size(0) * output.size(1), output.size(2))
|
||||
)
|
||||
if self.rnn_type == "LSTM":
|
||||
return (
|
||||
decoded.view(output.size(0), output.size(1), decoded.size(1)),
|
||||
hidden1,
|
||||
hidden2,
|
||||
)
|
||||
else:
|
||||
return (
|
||||
decoded.view(output.size(0), output.size(1), decoded.size(1)),
|
||||
hidden1,
|
||||
)
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
tgt = torch.LongTensor([0, 1]).unsqueeze(0)
|
||||
hidden = torch.randn(self.nlayers, 1, self.nhid)
|
||||
if self.rnn_type == "LSTM":
|
||||
return (tgt, hidden, hidden)
|
||||
else:
|
||||
return (tgt, hidden)
|
||||
|
||||
def get_input_names(self):
|
||||
if self.rnn_type == "LSTM":
|
||||
return ["x", "in_hidden1", "in_hidden2"]
|
||||
else:
|
||||
return ["x", "in_hidden1"]
|
||||
|
||||
def get_output_names(self):
|
||||
if self.rnn_type == "LSTM":
|
||||
return ["y", "out_hidden1", "out_hidden2"]
|
||||
else:
|
||||
return ["y", "out_hidden1"]
|
||||
|
||||
def get_dynamic_axes(self):
|
||||
ret = {
|
||||
"x": {0: "x_batch", 1: "x_length"},
|
||||
"y": {0: "y_batch"},
|
||||
"in_hidden1": {1: "hidden1_batch"},
|
||||
"out_hidden1": {1: "out_hidden1_batch"},
|
||||
}
|
||||
if self.rnn_type == "LSTM":
|
||||
ret.update(
|
||||
{
|
||||
"in_hidden2": {1: "hidden2_batch"},
|
||||
"out_hidden2": {1: "out_hidden2_batch"},
|
||||
}
|
||||
)
|
||||
return ret
|
||||
|
||||
def get_model_config(self, path):
|
||||
return {
|
||||
"use_lm": True,
|
||||
"model_path": os.path.join(path, f"{self.model_name}.onnx"),
|
||||
"lm_type": "SequentialRNNLM",
|
||||
"rnn_type": self.rnn_type,
|
||||
"nhid": self.nhid,
|
||||
"nlayers": self.nlayers,
|
||||
}
|
||||
@@ -1,185 +0,0 @@
|
||||
"""Subsampling layer definition."""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class OnnxConv2dSubsampling(torch.nn.Module):
|
||||
"""Convolutional 2D subsampling (to 1/4 length).
|
||||
|
||||
Args:
|
||||
idim (int): Input dimension.
|
||||
odim (int): Output dimension.
|
||||
dropout_rate (float): Dropout rate.
|
||||
pos_enc (torch.nn.Module): Custom position encoding layer.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, model):
|
||||
"""Construct an Conv2dSubsampling object."""
|
||||
super().__init__()
|
||||
self.conv = model.conv
|
||||
self.out = model.out
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
"""Subsample x.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (#batch, time, idim).
|
||||
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Subsampled tensor (#batch, time', odim),
|
||||
where time' = time // 4.
|
||||
torch.Tensor: Subsampled mask (#batch, 1, time'),
|
||||
where time' = time // 4.
|
||||
|
||||
"""
|
||||
x = x.unsqueeze(1) # (b, c, t, f)
|
||||
x = self.conv(x)
|
||||
b, c, t, f = x.size()
|
||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
if x_mask is None:
|
||||
return x, None
|
||||
return x, x_mask[:, :-2:2][:, :-2:2]
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Get item.
|
||||
|
||||
When reset_parameters() is called, if use_scaled_pos_enc is used,
|
||||
return the positioning encoding.
|
||||
|
||||
"""
|
||||
if key != -1:
|
||||
raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
|
||||
return self.out[key]
|
||||
|
||||
|
||||
class OnnxConv2dSubsampling2(torch.nn.Module):
|
||||
"""Convolutional 2D subsampling (to 1/2 length).
|
||||
|
||||
Args:
|
||||
idim (int): Input dimension.
|
||||
odim (int): Output dimension.
|
||||
dropout_rate (float): Dropout rate.
|
||||
pos_enc (torch.nn.Module): Custom position encoding layer.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, model):
|
||||
"""Construct an Conv2dSubsampling object."""
|
||||
super().__init__()
|
||||
self.conv = model.conv
|
||||
self.out = model.out
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
"""Subsample x.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (#batch, time, idim).
|
||||
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Subsampled tensor (#batch, time', odim),
|
||||
where time' = time // 2.
|
||||
torch.Tensor: Subsampled mask (#batch, 1, time'),
|
||||
where time' = time // 2.
|
||||
|
||||
"""
|
||||
x = x.unsqueeze(1) # (b, c, t, f)
|
||||
x = self.conv(x)
|
||||
b, c, t, f = x.size()
|
||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
if x_mask is None:
|
||||
return x, None
|
||||
return x, x_mask[:, :-2:2][:, :-2:1]
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Get item.
|
||||
|
||||
When reset_parameters() is called, if use_scaled_pos_enc is used,
|
||||
return the positioning encoding.
|
||||
|
||||
"""
|
||||
if key != -1:
|
||||
raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
|
||||
return self.out[key]
|
||||
|
||||
|
||||
class OnnxConv2dSubsampling6(torch.nn.Module):
|
||||
"""Convolutional 2D subsampling (to 1/6 length).
|
||||
|
||||
Args:
|
||||
idim (int): Input dimension.
|
||||
odim (int): Output dimension.
|
||||
dropout_rate (float): Dropout rate.
|
||||
pos_enc (torch.nn.Module): Custom position encoding layer.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, model):
|
||||
"""Construct an Conv2dSubsampling object."""
|
||||
super().__init__()
|
||||
self.conv = model.conv
|
||||
self.out = model.out
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
"""Subsample x.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (#batch, time, idim).
|
||||
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Subsampled tensor (#batch, time', odim),
|
||||
where time' = time // 6.
|
||||
torch.Tensor: Subsampled mask (#batch, 1, time'),
|
||||
where time' = time // 6.
|
||||
|
||||
"""
|
||||
x = x.unsqueeze(1) # (b, c, t, f)
|
||||
x = self.conv(x)
|
||||
b, c, t, f = x.size()
|
||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
if x_mask is None:
|
||||
return x, None
|
||||
return x, x_mask[:, :-2:2][:, :-4:3]
|
||||
|
||||
|
||||
class OnnxConv2dSubsampling8(torch.nn.Module):
|
||||
"""Convolutional 2D subsampling (to 1/8 length).
|
||||
|
||||
Args:
|
||||
idim (int): Input dimension.
|
||||
odim (int): Output dimension.
|
||||
dropout_rate (float): Dropout rate.
|
||||
pos_enc (torch.nn.Module): Custom position encoding layer.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, model):
|
||||
"""Construct an Conv2dSubsampling object."""
|
||||
super().__init__()
|
||||
self.conv = model.conv
|
||||
self.out = model.out
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
"""Subsample x.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (#batch, time, idim).
|
||||
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Subsampled tensor (#batch, time', odim),
|
||||
where time' = time // 8.
|
||||
torch.Tensor: Subsampled mask (#batch, 1, time'),
|
||||
where time' = time // 8.
|
||||
|
||||
"""
|
||||
x = x.unsqueeze(1) # (b, c, t, f)
|
||||
x = self.conv(x)
|
||||
b, c, t, f = x.size()
|
||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
if x_mask is None:
|
||||
return x, None
|
||||
return x, x_mask[:, :-2:2][:, :-2:2][:, :-2:2]
|
||||
@ -1,110 +0,0 @@
import os

import torch
import torch.nn as nn
from funasr.modules.vgg2l import VGG2L
from funasr.modules.attention import MultiHeadedAttention
from funasr.modules.subsampling import (
    Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8)

from funasr.export.models.modules.encoder_layer import EncoderLayerConformer as OnnxEncoderLayer
from funasr.export.models.language_models.embed import Embedding
from funasr.export.models.modules.multihead_att import OnnxMultiHeadedAttention

from funasr.export.utils.torch_function import MakePadMask

class TransformerLM(nn.Module, AbsExportModel):
|
||||
def __init__(self, model, max_seq_len=512, **kwargs):
|
||||
super().__init__()
|
||||
self.embed = Embedding(model.embed, max_seq_len)
|
||||
self.encoder = model.encoder
|
||||
self.decoder = model.decoder
|
||||
self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
|
||||
# replace multihead attention module into customized module.
|
||||
for i, d in enumerate(self.encoder.encoders):
|
||||
# d is EncoderLayer
|
||||
if isinstance(d.self_attn, MultiHeadedAttention):
|
||||
d.self_attn = OnnxMultiHeadedAttention(d.self_attn)
|
||||
self.encoder.encoders[i] = OnnxEncoderLayer(d)
|
||||
|
||||
self.model_name = "transformer_lm"
|
||||
self.num_heads = self.encoder.encoders[0].self_attn.h
|
||||
self.hidden_size = self.encoder.encoders[0].self_attn.linear_out.out_features
|
||||
|
||||
def prepare_mask(self, mask):
|
||||
if len(mask.shape) == 2:
|
||||
mask = mask[:, None, None, :]
|
||||
elif len(mask.shape) == 3:
|
||||
mask = mask[:, None, :]
|
||||
mask = 1 - mask
|
||||
return mask * -10000.0
|
||||
|
||||
def forward(self, y, cache):
|
||||
feats_length = torch.ones(y.shape).sum(dim=-1).type(torch.long)
|
||||
mask = self.make_pad_mask(feats_length) # (B, T)
|
||||
mask = (y != 0) * mask
|
||||
|
||||
xs = self.embed(y)
|
||||
# forward_one_step of Encoder
|
||||
if isinstance(
|
||||
self.encoder.embed,
|
||||
(Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8, VGG2L),
|
||||
):
|
||||
xs, mask = self.encoder.embed(xs, mask)
|
||||
else:
|
||||
xs = self.encoder.embed(xs)
|
||||
|
||||
new_cache = []
|
||||
mask = self.prepare_mask(mask)
|
||||
for c, e in zip(cache, self.encoder.encoders):
|
||||
xs, mask = e(xs, mask, c)
|
||||
new_cache.append(xs)
|
||||
|
||||
if self.encoder.normalize_before:
|
||||
xs = self.encoder.after_norm(xs)
|
||||
|
||||
h = self.decoder(xs[:, -1])
|
||||
return h, new_cache
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
tgt = torch.LongTensor([1]).unsqueeze(0)
|
||||
cache = [
|
||||
torch.zeros((1, 1, self.encoder.encoders[0].size))
|
||||
for _ in range(len(self.encoder.encoders))
|
||||
]
|
||||
return (tgt, cache)
|
||||
|
||||
def is_optimizable(self):
|
||||
return True
|
||||
|
||||
def get_input_names(self):
|
||||
return ["tgt"] + ["cache_%d" % i for i in range(len(self.encoder.encoders))]
|
||||
|
||||
def get_output_names(self):
|
||||
return ["y"] + ["out_cache_%d" % i for i in range(len(self.encoder.encoders))]
|
||||
|
||||
def get_dynamic_axes(self):
|
||||
ret = {"tgt": {0: "tgt_batch", 1: "tgt_length"}}
|
||||
ret.update(
|
||||
{
|
||||
"cache_%d" % d: {0: "cache_%d_batch" % d, 1: "cache_%d_length" % d}
|
||||
for d in range(len(self.encoder.encoders))
|
||||
}
|
||||
)
|
||||
ret.update(
|
||||
{
|
||||
"out_cache_%d"
|
||||
% d: {0: "out_cache_%d_batch" % d, 1: "out_cache_%d_length" % d}
|
||||
for d in range(len(self.encoder.encoders))
|
||||
}
|
||||
)
|
||||
return ret
|
||||
|
||||
def get_model_config(self, path):
|
||||
return {
|
||||
"use_lm": True,
|
||||
"model_path": os.path.join(path, f"{self.model_name}.onnx"),
|
||||
"lm_type": "TransformerLM",
|
||||
"odim": self.encoder.encoders[0].size,
|
||||
"nlayers": len(self.encoder.encoders),
|
||||
}
|
||||
@ -4,7 +4,7 @@ from typing import List, Tuple, Union

import random
import numpy as np
import soundfile
import librosa
import librosa

import torch
@ -116,7 +116,7 @@ class SoundScpReader(collections.abc.Mapping):
    def __getitem__(self, key):
        wav = self.data[key]
        if self.normalize:
            # soundfile.read normalizes data to [-1,1] if dtype is not given
            # librosa.load normalizes data to [-1,1] if dtype is not given
            array, rate = librosa.load(
                wav, sr=self.dest_sample_rate, mono=self.always_2d
            )

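For context on the `soundfile` → `librosa` switch above, here is a small sketch of the call this reader now relies on; the path and sample rates are placeholders. `librosa.load` returns float32 audio scaled to [-1, 1] and resamples when an explicit `sr` is given, which is what `dest_sample_rate` supplies here; `sr=None` would keep the file's native rate. Note that `librosa.load` has no `always_2d` argument; its `mono=` keyword controls channel downmixing.

```python
import librosa

# Hypothetical path and target rate, for illustration only.
wav_path = "speech.wav"

# Keep the native sample rate of the file.
audio_native, rate_native = librosa.load(wav_path, sr=None)
# Resample to 16 kHz and downmix to mono.
audio_16k, rate_16k = librosa.load(wav_path, sr=16000, mono=True)

print(audio_native.dtype, rate_native, audio_16k.shape, rate_16k)
```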
@ -5,8 +5,12 @@ from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from torch_complex import functional as FC
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
try:
|
||||
from torch_complex import functional as FC
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
except:
|
||||
print("Please install torch_complex firstly")
|
||||
|
||||
|
||||
|
||||
EPS = torch.finfo(torch.double).eps
|
||||
|
||||
@ -4,8 +4,11 @@ from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
|
||||
try:
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
except:
|
||||
print("Please install torch_complex firstly")
|
||||
from funasr.modules.nets_utils import make_pad_mask
|
||||
from funasr.layers.complex_utils import is_complex
|
||||
from funasr.layers.inversible_interface import InversibleInterface
|
||||
|
||||
@ -1,8 +1,10 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from rotary_embedding_torch import RotaryEmbedding
try:
    from rotary_embedding_torch import RotaryEmbedding
except:
    print("Please install rotary_embedding_torch by: \n pip install -U rotary_embedding_torch")
from funasr.modules.layer_norm import GlobalLayerNorm, CumulativeLayerNorm, ScaleNorm
from funasr.modules.embedding import ScaledSinuEmbedding
from funasr.modules.mossformer import FLASH_ShareA_FFConvM

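The try/except import above follows the optional-dependency pattern this commit applies to several modules. As an illustration only (this is not the code in the diff), a slightly stricter variant defers a clear error to first use instead of printing at import time:

```python
# Sketch of the optional-dependency pattern, with the error raised on first use.
try:
    from rotary_embedding_torch import RotaryEmbedding
except ImportError:
    RotaryEmbedding = None


def build_rotary(dim: int):
    """Hypothetical helper: fail loudly only when the optional feature is requested."""
    if RotaryEmbedding is None:
        raise ImportError(
            "rotary_embedding_torch is required: pip install -U rotary_embedding_torch"
        )
    return RotaryEmbedding(dim)
```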
@ -6,12 +6,15 @@ import logging
|
||||
import humanfriendly
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
try:
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
except:
|
||||
print("Please install torch_complex firstly")
|
||||
|
||||
from funasr.layers.log_mel import LogMel
|
||||
from funasr.layers.stft import Stft
|
||||
from funasr.models.frontend.abs_frontend import AbsFrontend
|
||||
from funasr.modules.frontends.frontend import Frontend
|
||||
from funasr.models.frontend.frontends_utils.frontend import Frontend
|
||||
from funasr.utils.get_default_kwargs import get_default_kwargs
|
||||
from funasr.modules.nets_utils import make_pad_mask
|
||||
|
||||
|
||||
@ -4,12 +4,12 @@ from typing import Tuple
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from funasr.modules.frontends.beamformer import apply_beamforming_vector
|
||||
from funasr.modules.frontends.beamformer import get_mvdr_vector
|
||||
from funasr.modules.frontends.beamformer import (
|
||||
from funasr.models.frontend.frontends_utils.beamformer import apply_beamforming_vector
|
||||
from funasr.models.frontend.frontends_utils.beamformer import get_mvdr_vector
|
||||
from funasr.models.frontend.frontends_utils.beamformer import (
|
||||
get_power_spectral_density_matrix, # noqa: H301
|
||||
)
|
||||
from funasr.modules.frontends.mask_estimator import MaskEstimator
|
||||
from funasr.models.frontend.frontends_utils.mask_estimator import MaskEstimator
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ from pytorch_wpe import wpe_one_iteration
|
||||
import torch
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
|
||||
from funasr.modules.frontends.mask_estimator import MaskEstimator
|
||||
from funasr.models.frontend.frontends_utils.mask_estimator import MaskEstimator
|
||||
from funasr.modules.nets_utils import make_pad_mask
|
||||
|
||||
|
||||
@ -8,8 +8,8 @@ import torch
|
||||
import torch.nn as nn
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
|
||||
from funasr.modules.frontends.dnn_beamformer import DNN_Beamformer
|
||||
from funasr.modules.frontends.dnn_wpe import DNN_WPE
|
||||
from funasr.models.frontend.frontends_utils.dnn_beamformer import DNN_Beamformer
|
||||
from funasr.models.frontend.frontends_utils.dnn_wpe import DNN_WPE
|
||||
|
||||
|
||||
class Frontend(nn.Module):
|
||||
@ -10,7 +10,7 @@ import humanfriendly
|
||||
import torch
|
||||
|
||||
from funasr.models.frontend.abs_frontend import AbsFrontend
|
||||
from funasr.modules.frontends.frontend import Frontend
|
||||
from funasr.models.frontend.frontends_utils.frontend import Frontend
|
||||
from funasr.modules.nets_utils import pad_list
|
||||
from funasr.utils.get_default_kwargs import get_default_kwargs
|
||||
|
||||
|
||||
@ -145,9 +145,12 @@ class WavFrontend(AbsFrontend):
            feats_lens.append(feat_length)

        feats_lens = torch.as_tensor(feats_lens)
        feats_pad = pad_sequence(feats,
                                 batch_first=True,
                                 padding_value=0.0)
        if batch_size == 1:
            feats_pad = feats[0][None, :, :]
        else:
            feats_pad = pad_sequence(feats,
                                     batch_first=True,
                                     padding_value=0.0)
        return feats_pad, feats_lens

    def forward_fbank(

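For context, the hunk above right-pads variable-length fbank features with `pad_sequence` and skips padding entirely for a single utterance. A minimal sketch with toy tensors (shapes are illustrative only):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Toy stand-ins for per-utterance fbank features: (num_frames, feat_dim).
feats = [torch.randn(120, 80), torch.randn(95, 80)]
feats_lens = torch.as_tensor([f.size(0) for f in feats])

if len(feats) == 1:
    # Single-utterance shortcut from the hunk above: no padding needed,
    # just add the batch dimension.
    feats_pad = feats[0][None, :, :]
else:
    # pad_sequence right-pads to the longest utterance: (batch, max_frames, feat_dim).
    feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)

print(feats_pad.shape, feats_lens)
```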
@ -9,7 +9,7 @@ import os
|
||||
import sys
|
||||
import numpy as np
|
||||
import subprocess
|
||||
import soundfile as sf
|
||||
import librosa as sf
|
||||
import io
|
||||
from functools import lru_cache
|
||||
|
||||
@ -67,18 +67,18 @@ def load_wav(wav_rxfilename, start=0, end=None):
|
||||
# input piped command
|
||||
p = subprocess.Popen(wav_rxfilename[:-1], shell=True,
|
||||
stdout=subprocess.PIPE)
|
||||
data, samplerate = sf.read(io.BytesIO(p.stdout.read()),
|
||||
data, samplerate = sf.load(io.BytesIO(p.stdout.read()),
|
||||
dtype='float32')
|
||||
# cannot seek
|
||||
data = data[start:end]
|
||||
elif wav_rxfilename == '-':
|
||||
# stdin
|
||||
data, samplerate = sf.read(sys.stdin, dtype='float32')
|
||||
data, samplerate = sf.load(sys.stdin, dtype='float32')
|
||||
# cannot seek
|
||||
data = data[start:end]
|
||||
else:
|
||||
# normal wav file
|
||||
data, samplerate = sf.read(wav_rxfilename, start=start, stop=end)
|
||||
data, samplerate = sf.load(wav_rxfilename, start=start, stop=end)
|
||||
return data, samplerate
|
||||
|
||||
|
||||
|
||||
@ -16,7 +16,7 @@ Supports real-time streaming speech recognition, uses non-streaming models for e
#### Server Deployment

```shell
cd funasr/runtime/python/websocket
cd runtime/python/websocket
python funasr_wss_server.py --port 10095
```

@ -161,4 +161,4 @@ If you want to train from scratch, usually for academic models, you can start tr
cd egs/aishell/paraformer
. ./run.sh --CUDA_VISIBLE_DEVICES="0,1" --gpu_num=2
```
More examples can be found in [docs](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_pipeline/quick_start.html)
More examples can be found in [docs](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_pipeline/quick_start.html)

@ -17,7 +17,7 @@

##### Server Deployment
```shell
cd funasr/runtime/python/websocket
cd runtime/python/websocket
python funasr_wss_server.py --port 10095
```

@ -161,4 +161,4 @@ cd egs/aishell/paraformer
. ./run.sh --CUDA_VISIBLE_DEVICES="0,1" --gpu_num=2
```

More examples can be found here ([click here](https://alibaba-damo-academy.github.io/FunASR/en/academic_recipe/asr_recipe.html))
More examples can be found here ([click here](https://alibaba-damo-academy.github.io/FunASR/en/academic_recipe/asr_recipe.html))

@ -76,7 +76,7 @@ from funasr.models.specaug.specaug import SpecAug
|
||||
from funasr.models.specaug.specaug import SpecAugLFR
|
||||
from funasr.modules.subsampling import Conv1dSubsampling
|
||||
from funasr.tasks.abs_task import AbsTask
|
||||
from funasr.text.phoneme_tokenizer import g2p_choices
|
||||
from funasr.tokenizer.phoneme_tokenizer import g2p_choices
|
||||
from funasr.torch_utils.initialize import initialize
|
||||
from funasr.models.base_model import FunASRModel
|
||||
from funasr.train.class_choices import ClassChoices
|
||||
|
||||
@ -25,7 +25,7 @@ from funasr.models.preencoder.sinc import LightweightSincConvs
|
||||
from funasr.models.specaug.abs_specaug import AbsSpecAug
|
||||
from funasr.models.specaug.specaug import SpecAug
|
||||
from funasr.tasks.abs_task import AbsTask
|
||||
from funasr.text.phoneme_tokenizer import g2p_choices
|
||||
from funasr.tokenizer.phoneme_tokenizer import g2p_choices
|
||||
from funasr.torch_utils.initialize import initialize
|
||||
from funasr.train.class_choices import ClassChoices
|
||||
from funasr.train.trainer import Trainer
|
||||
|
||||
@ -17,7 +17,7 @@ from funasr.train.abs_model import LanguageModel
|
||||
from funasr.models.seq_rnn_lm import SequentialRNNLM
|
||||
from funasr.models.transformer_lm import TransformerLM
|
||||
from funasr.tasks.abs_task import AbsTask
|
||||
from funasr.text.phoneme_tokenizer import g2p_choices
|
||||
from funasr.tokenizer.phoneme_tokenizer import g2p_choices
|
||||
from funasr.torch_utils.initialize import initialize
|
||||
from funasr.train.class_choices import ClassChoices
|
||||
from funasr.train.trainer import Trainer
|
||||
|
||||
@ -16,7 +16,7 @@ from funasr.train.abs_model import PunctuationModel
|
||||
from funasr.models.target_delay_transformer import TargetDelayTransformer
|
||||
from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
|
||||
from funasr.tasks.abs_task import AbsTask
|
||||
from funasr.text.phoneme_tokenizer import g2p_choices
|
||||
from funasr.tokenizer.phoneme_tokenizer import g2p_choices
|
||||
from funasr.torch_utils.initialize import initialize
|
||||
from funasr.train.class_choices import ClassChoices
|
||||
from funasr.train.trainer import Trainer
|
||||
|
||||
@ -71,7 +71,7 @@ from funasr.models.specaug.specaug import SpecAugLFR
|
||||
from funasr.models.base_model import FunASRModel
|
||||
from funasr.modules.subsampling import Conv1dSubsampling
|
||||
from funasr.tasks.abs_task import AbsTask
|
||||
from funasr.text.phoneme_tokenizer import g2p_choices
|
||||
from funasr.tokenizer.phoneme_tokenizer import g2p_choices
|
||||
from funasr.torch_utils.initialize import initialize
|
||||
from funasr.train.class_choices import ClassChoices
|
||||
from funasr.train.trainer import Trainer
|
||||
|
||||
@ -76,7 +76,7 @@ from funasr.models.specaug.specaug import SpecAug
|
||||
from funasr.models.specaug.specaug import SpecAugLFR
|
||||
from funasr.modules.subsampling import Conv1dSubsampling
|
||||
from funasr.tasks.abs_task import AbsTask
|
||||
from funasr.text.phoneme_tokenizer import g2p_choices
|
||||
from funasr.tokenizer.phoneme_tokenizer import g2p_choices
|
||||
from funasr.torch_utils.initialize import initialize
|
||||
from funasr.models.base_model import FunASRModel
|
||||
from funasr.train.class_choices import ClassChoices
|
||||
|
||||
@ -3,11 +3,11 @@ from typing import Iterable
from typing import Union


from funasr.text.abs_tokenizer import AbsTokenizer
from funasr.text.char_tokenizer import CharTokenizer
from funasr.text.phoneme_tokenizer import PhonemeTokenizer
from funasr.text.sentencepiece_tokenizer import SentencepiecesTokenizer
from funasr.text.word_tokenizer import WordTokenizer
from funasr.tokenizer.abs_tokenizer import AbsTokenizer
from funasr.tokenizer.char_tokenizer import CharTokenizer
from funasr.tokenizer.phoneme_tokenizer import PhonemeTokenizer
from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer
from funasr.tokenizer.word_tokenizer import WordTokenizer


def build_tokenizer(
@ -5,7 +5,7 @@ from typing import Union
|
||||
import warnings
|
||||
|
||||
|
||||
from funasr.text.abs_tokenizer import AbsTokenizer
|
||||
from funasr.tokenizer.abs_tokenizer import AbsTokenizer
|
||||
|
||||
|
||||
class CharTokenizer(AbsTokenizer):
|
||||
@ -10,7 +10,7 @@ import warnings
|
||||
# import g2p_en
|
||||
import jamo
|
||||
|
||||
from funasr.text.abs_tokenizer import AbsTokenizer
|
||||
from funasr.tokenizer.abs_tokenizer import AbsTokenizer
|
||||
|
||||
|
||||
g2p_choices = [
|
||||
@ -107,7 +107,7 @@ def pyopenjtalk_g2p_prosody(text: str, drop_unvoiced_vowels: bool = True) -> Lis
|
||||
List[str]: List of phoneme + prosody symbols.
|
||||
|
||||
Examples:
|
||||
>>> from funasr.text.phoneme_tokenizer import pyopenjtalk_g2p_prosody
|
||||
>>> from funasr.tokenizer.phoneme_tokenizer import pyopenjtalk_g2p_prosody
|
||||
>>> pyopenjtalk_g2p_prosody("こんにちは。")
|
||||
['^', 'k', 'o', '[', 'N', 'n', 'i', 'ch', 'i', 'w', 'a', '$']
|
||||
|
||||
@ -5,7 +5,7 @@ from typing import Union
|
||||
|
||||
import sentencepiece as spm
|
||||
|
||||
from funasr.text.abs_tokenizer import AbsTokenizer
|
||||
from funasr.tokenizer.abs_tokenizer import AbsTokenizer
|
||||
|
||||
|
||||
class SentencepiecesTokenizer(AbsTokenizer):
|
||||
@ -5,7 +5,7 @@ from typing import Union
|
||||
import warnings
|
||||
|
||||
|
||||
from funasr.text.abs_tokenizer import AbsTokenizer
|
||||
from funasr.tokenizer.abs_tokenizer import AbsTokenizer
|
||||
|
||||
|
||||
class WordTokenizer(AbsTokenizer):
|
||||
@ -278,14 +278,11 @@ class Trainer:
        for iepoch in range(start_epoch, trainer_options.max_epoch + 1):
            if iepoch != start_epoch:
                logging.info(
                    "{}/{}epoch started. Estimated time to finish: {}".format(
                    "{}/{}epoch started. Estimated time to finish: {} hours".format(
                        iepoch,
                        trainer_options.max_epoch,
                        humanfriendly.format_timespan(
                            (time.perf_counter() - start_time)
                            / (iepoch - start_epoch)
                            * (trainer_options.max_epoch - iepoch + 1)
                        ),
                        (time.perf_counter() - start_time) / 3600.0 / (iepoch - start_epoch) * (
                            trainer_options.max_epoch - iepoch + 1),
                    )
                )
            else:
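The hunk above swaps a `humanfriendly.format_timespan` string for a raw number of hours in the ETA log line. A small sketch of what the two variants print, with illustrative numbers only:

```python
import time

import humanfriendly

# Illustrative numbers: pretend 2 of 10 epochs finished after 5400 seconds.
start_time = time.perf_counter() - 5400.0
start_epoch, iepoch, max_epoch = 1, 3, 10

elapsed = time.perf_counter() - start_time
remaining_s = elapsed / (iepoch - start_epoch) * (max_epoch - iepoch + 1)

# Old log value: a human-readable span such as "6 hours".
print(humanfriendly.format_timespan(remaining_s))
# New log value: the same estimate as a plain number of hours.
print(round(remaining_s / 3600.0, 2), "hours")
```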
@ -5,7 +5,7 @@ import struct
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
import torchaudio
|
||||
import soundfile
|
||||
import librosa
|
||||
import numpy as np
|
||||
import pkg_resources
|
||||
from modelscope.utils.logger import get_logger
|
||||
@ -139,7 +139,7 @@ def get_sr_from_wav(fname: str):
|
||||
try:
|
||||
audio, fs = torchaudio.load(fname)
|
||||
except:
|
||||
audio, fs = soundfile.read(fname)
|
||||
audio, fs = librosa.load(fname)
|
||||
break
|
||||
if audio_type.rfind(".scp") >= 0:
|
||||
with open(fname, encoding="utf-8") as f:
|
||||
|
||||
@ -5,7 +5,7 @@ from multiprocessing import Pool
|
||||
|
||||
import kaldiio
|
||||
import numpy as np
|
||||
import soundfile
|
||||
import librosa
|
||||
import torch.distributed as dist
|
||||
import torchaudio
|
||||
|
||||
@ -46,7 +46,7 @@ def wav2num_frame(wav_path, frontend_conf):
|
||||
try:
|
||||
waveform, sampling_rate = torchaudio.load(wav_path)
|
||||
except:
|
||||
waveform, sampling_rate = soundfile.read(wav_path)
|
||||
waveform, sampling_rate = librosa.load(wav_path)
|
||||
waveform = np.expand_dims(waveform, axis=0)
|
||||
n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
|
||||
feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]
|
||||
|
||||
@ -12,7 +12,7 @@ import os
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import librosa as sf
|
||||
import torch
|
||||
import torchaudio
|
||||
import logging
|
||||
@ -43,7 +43,7 @@ def sv_preprocess(inputs: Union[np.ndarray, list]):
|
||||
for i in range(len(inputs)):
|
||||
if isinstance(inputs[i], str):
|
||||
file_bytes = File.read(inputs[i])
|
||||
data, fs = sf.read(io.BytesIO(file_bytes), dtype='float32')
|
||||
data, fs = sf.load(io.BytesIO(file_bytes), dtype='float32')
|
||||
if len(data.shape) == 2:
|
||||
data = data[:, 0]
|
||||
data = torch.from_numpy(data).unsqueeze(0)
|
||||
|
||||
@ -3,7 +3,7 @@ import codecs
|
||||
import logging
|
||||
import argparse
|
||||
import numpy as np
|
||||
import edit_distance
|
||||
# import edit_distance
|
||||
from itertools import zip_longest
|
||||
|
||||
|
||||
@ -160,112 +160,112 @@ def time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocess
|
||||
return res
|
||||
|
||||
|
||||
class AverageShiftCalculator():
|
||||
def __init__(self):
|
||||
logging.warning("Calculating average shift.")
|
||||
def __call__(self, file1, file2):
|
||||
uttid_list1, ts_dict1 = self.read_timestamps(file1)
|
||||
uttid_list2, ts_dict2 = self.read_timestamps(file2)
|
||||
uttid_intersection = self._intersection(uttid_list1, uttid_list2)
|
||||
res = self.as_cal(uttid_intersection, ts_dict1, ts_dict2)
|
||||
logging.warning("Average shift of {} and {}: {}.".format(file1, file2, str(res)[:8]))
|
||||
logging.warning("Following timestamp pair differs most: {}, detail:{}".format(self.max_shift, self.max_shift_uttid))
|
||||
|
||||
def _intersection(self, list1, list2):
|
||||
set1 = set(list1)
|
||||
set2 = set(list2)
|
||||
if set1 == set2:
|
||||
logging.warning("Uttid same checked.")
|
||||
return set1
|
||||
itsc = list(set1 & set2)
|
||||
logging.warning("Uttid differs: file1 {}, file2 {}, lines same {}.".format(len(list1), len(list2), len(itsc)))
|
||||
return itsc
|
||||
|
||||
def read_timestamps(self, file):
|
||||
# read timestamps file in standard format
|
||||
uttid_list = []
|
||||
ts_dict = {}
|
||||
with codecs.open(file, 'r') as fin:
|
||||
for line in fin.readlines():
|
||||
text = ''
|
||||
ts_list = []
|
||||
line = line.rstrip()
|
||||
uttid = line.split()[0]
|
||||
uttid_list.append(uttid)
|
||||
body = " ".join(line.split()[1:])
|
||||
for pd in body.split(';'):
|
||||
if not len(pd): continue
|
||||
# pdb.set_trace()
|
||||
char, start, end = pd.lstrip(" ").split(' ')
|
||||
text += char + ','
|
||||
ts_list.append((float(start), float(end)))
|
||||
# ts_lists.append(ts_list)
|
||||
ts_dict[uttid] = (text[:-1], ts_list)
|
||||
logging.warning("File {} read done.".format(file))
|
||||
return uttid_list, ts_dict
|
||||
|
||||
def _shift(self, filtered_timestamp_list1, filtered_timestamp_list2):
|
||||
shift_time = 0
|
||||
for fts1, fts2 in zip(filtered_timestamp_list1, filtered_timestamp_list2):
|
||||
shift_time += abs(fts1[0] - fts2[0]) + abs(fts1[1] - fts2[1])
|
||||
num_tokens = len(filtered_timestamp_list1)
|
||||
return shift_time, num_tokens
|
||||
|
||||
def as_cal(self, uttid_list, ts_dict1, ts_dict2):
|
||||
# calculate average shift between timestamp1 and timestamp2
|
||||
# when characters differ, use edit distance alignment
|
||||
# and calculate the error between the same characters
|
||||
self._accumlated_shift = 0
|
||||
self._accumlated_tokens = 0
|
||||
self.max_shift = 0
|
||||
self.max_shift_uttid = None
|
||||
for uttid in uttid_list:
|
||||
(t1, ts1) = ts_dict1[uttid]
|
||||
(t2, ts2) = ts_dict2[uttid]
|
||||
_align, _align2, _align3 = [], [], []
|
||||
fts1, fts2 = [], []
|
||||
_t1, _t2 = [], []
|
||||
sm = edit_distance.SequenceMatcher(t1.split(','), t2.split(','))
|
||||
s = sm.get_opcodes()
|
||||
for j in range(len(s)):
|
||||
if s[j][0] == "replace" or s[j][0] == "insert":
|
||||
_align.append(0)
|
||||
if s[j][0] == "replace" or s[j][0] == "delete":
|
||||
_align3.append(0)
|
||||
elif s[j][0] == "equal":
|
||||
_align.append(1)
|
||||
_align3.append(1)
|
||||
else:
|
||||
continue
|
||||
# use s to index t2
|
||||
for a, ts , t in zip(_align, ts2, t2.split(',')):
|
||||
if a:
|
||||
fts2.append(ts)
|
||||
_t2.append(t)
|
||||
sm2 = edit_distance.SequenceMatcher(t2.split(','), t1.split(','))
|
||||
s = sm2.get_opcodes()
|
||||
for j in range(len(s)):
|
||||
if s[j][0] == "replace" or s[j][0] == "insert":
|
||||
_align2.append(0)
|
||||
elif s[j][0] == "equal":
|
||||
_align2.append(1)
|
||||
else:
|
||||
continue
|
||||
# use s2 tp index t1
|
||||
for a, ts, t in zip(_align3, ts1, t1.split(',')):
|
||||
if a:
|
||||
fts1.append(ts)
|
||||
_t1.append(t)
|
||||
if len(fts1) == len(fts2):
|
||||
shift_time, num_tokens = self._shift(fts1, fts2)
|
||||
self._accumlated_shift += shift_time
|
||||
self._accumlated_tokens += num_tokens
|
||||
if shift_time/num_tokens > self.max_shift:
|
||||
self.max_shift = shift_time/num_tokens
|
||||
self.max_shift_uttid = uttid
|
||||
else:
|
||||
logging.warning("length mismatch")
|
||||
return self._accumlated_shift / self._accumlated_tokens
|
||||
# class AverageShiftCalculator():
|
||||
# def __init__(self):
|
||||
# logging.warning("Calculating average shift.")
|
||||
# def __call__(self, file1, file2):
|
||||
# uttid_list1, ts_dict1 = self.read_timestamps(file1)
|
||||
# uttid_list2, ts_dict2 = self.read_timestamps(file2)
|
||||
# uttid_intersection = self._intersection(uttid_list1, uttid_list2)
|
||||
# res = self.as_cal(uttid_intersection, ts_dict1, ts_dict2)
|
||||
# logging.warning("Average shift of {} and {}: {}.".format(file1, file2, str(res)[:8]))
|
||||
# logging.warning("Following timestamp pair differs most: {}, detail:{}".format(self.max_shift, self.max_shift_uttid))
|
||||
#
|
||||
# def _intersection(self, list1, list2):
|
||||
# set1 = set(list1)
|
||||
# set2 = set(list2)
|
||||
# if set1 == set2:
|
||||
# logging.warning("Uttid same checked.")
|
||||
# return set1
|
||||
# itsc = list(set1 & set2)
|
||||
# logging.warning("Uttid differs: file1 {}, file2 {}, lines same {}.".format(len(list1), len(list2), len(itsc)))
|
||||
# return itsc
|
||||
#
|
||||
# def read_timestamps(self, file):
|
||||
# # read timestamps file in standard format
|
||||
# uttid_list = []
|
||||
# ts_dict = {}
|
||||
# with codecs.open(file, 'r') as fin:
|
||||
# for line in fin.readlines():
|
||||
# text = ''
|
||||
# ts_list = []
|
||||
# line = line.rstrip()
|
||||
# uttid = line.split()[0]
|
||||
# uttid_list.append(uttid)
|
||||
# body = " ".join(line.split()[1:])
|
||||
# for pd in body.split(';'):
|
||||
# if not len(pd): continue
|
||||
# # pdb.set_trace()
|
||||
# char, start, end = pd.lstrip(" ").split(' ')
|
||||
# text += char + ','
|
||||
# ts_list.append((float(start), float(end)))
|
||||
# # ts_lists.append(ts_list)
|
||||
# ts_dict[uttid] = (text[:-1], ts_list)
|
||||
# logging.warning("File {} read done.".format(file))
|
||||
# return uttid_list, ts_dict
|
||||
#
|
||||
# def _shift(self, filtered_timestamp_list1, filtered_timestamp_list2):
|
||||
# shift_time = 0
|
||||
# for fts1, fts2 in zip(filtered_timestamp_list1, filtered_timestamp_list2):
|
||||
# shift_time += abs(fts1[0] - fts2[0]) + abs(fts1[1] - fts2[1])
|
||||
# num_tokens = len(filtered_timestamp_list1)
|
||||
# return shift_time, num_tokens
|
||||
#
|
||||
# # def as_cal(self, uttid_list, ts_dict1, ts_dict2):
|
||||
# # # calculate average shift between timestamp1 and timestamp2
|
||||
# # # when characters differ, use edit distance alignment
|
||||
# # # and calculate the error between the same characters
|
||||
# # self._accumlated_shift = 0
|
||||
# # self._accumlated_tokens = 0
|
||||
# # self.max_shift = 0
|
||||
# # self.max_shift_uttid = None
|
||||
# # for uttid in uttid_list:
|
||||
# # (t1, ts1) = ts_dict1[uttid]
|
||||
# # (t2, ts2) = ts_dict2[uttid]
|
||||
# # _align, _align2, _align3 = [], [], []
|
||||
# # fts1, fts2 = [], []
|
||||
# # _t1, _t2 = [], []
|
||||
# # sm = edit_distance.SequenceMatcher(t1.split(','), t2.split(','))
|
||||
# # s = sm.get_opcodes()
|
||||
# # for j in range(len(s)):
|
||||
# # if s[j][0] == "replace" or s[j][0] == "insert":
|
||||
# # _align.append(0)
|
||||
# # if s[j][0] == "replace" or s[j][0] == "delete":
|
||||
# # _align3.append(0)
|
||||
# # elif s[j][0] == "equal":
|
||||
# # _align.append(1)
|
||||
# # _align3.append(1)
|
||||
# # else:
|
||||
# # continue
|
||||
# # # use s to index t2
|
||||
# # for a, ts , t in zip(_align, ts2, t2.split(',')):
|
||||
# # if a:
|
||||
# # fts2.append(ts)
|
||||
# # _t2.append(t)
|
||||
# # sm2 = edit_distance.SequenceMatcher(t2.split(','), t1.split(','))
|
||||
# # s = sm2.get_opcodes()
|
||||
# # for j in range(len(s)):
|
||||
# # if s[j][0] == "replace" or s[j][0] == "insert":
|
||||
# # _align2.append(0)
|
||||
# # elif s[j][0] == "equal":
|
||||
# # _align2.append(1)
|
||||
# # else:
|
||||
# # continue
|
||||
# # # use s2 tp index t1
|
||||
# # for a, ts, t in zip(_align3, ts1, t1.split(',')):
|
||||
# # if a:
|
||||
# # fts1.append(ts)
|
||||
# # _t1.append(t)
|
||||
# # if len(fts1) == len(fts2):
|
||||
# # shift_time, num_tokens = self._shift(fts1, fts2)
|
||||
# # self._accumlated_shift += shift_time
|
||||
# # self._accumlated_tokens += num_tokens
|
||||
# # if shift_time/num_tokens > self.max_shift:
|
||||
# # self.max_shift = shift_time/num_tokens
|
||||
# # self.max_shift_uttid = uttid
|
||||
# # else:
|
||||
# # logging.warning("length mismatch")
|
||||
# # return self._accumlated_shift / self._accumlated_tokens
|
||||
|
||||
|
||||
def convert_external_alphas(alphas_file, text_file, output_file):
|
||||
@ -311,10 +311,10 @@ SUPPORTED_MODES = ['cal_aas', 'read_ext_alphas']
|
||||
|
||||
|
||||
def main(args):
|
||||
if args.mode == 'cal_aas':
|
||||
asc = AverageShiftCalculator()
|
||||
asc(args.input, args.input2)
|
||||
elif args.mode == 'read_ext_alphas':
|
||||
# if args.mode == 'cal_aas':
|
||||
# asc = AverageShiftCalculator()
|
||||
# asc(args.input, args.input2)
|
||||
if args.mode == 'read_ext_alphas':
|
||||
convert_external_alphas(args.input, args.input2, args.output)
|
||||
else:
|
||||
logging.error("Mode {} not in SUPPORTED_MODES: {}.".format(args.mode, SUPPORTED_MODES))
|
||||
|
||||
@ -11,7 +11,7 @@ import librosa
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
import soundfile
|
||||
import librosa
|
||||
import torchaudio.compliance.kaldi as kaldi
|
||||
|
||||
|
||||
@ -166,7 +166,7 @@ def compute_fbank(wav_file,
|
||||
try:
|
||||
waveform, audio_sr = torchaudio.load(wav_file)
|
||||
except:
|
||||
waveform, audio_sr = soundfile.read(wav_file, dtype='float32')
|
||||
waveform, audio_sr = librosa.load(wav_file, dtype='float32')
|
||||
if waveform.ndim == 2:
|
||||
waveform = waveform[:, 0]
|
||||
waveform = torch.tensor(np.expand_dims(waveform, axis=0))
|
||||
@ -191,7 +191,7 @@ def wav2num_frame(wav_path, frontend_conf):
|
||||
try:
|
||||
waveform, sampling_rate = torchaudio.load(wav_path)
|
||||
except:
|
||||
waveform, sampling_rate = soundfile.read(wav_path)
|
||||
waveform, sampling_rate = librosa.load(wav_path)
|
||||
waveform = torch.tensor(np.expand_dims(waveform, axis=0))
|
||||
speech_length = (waveform.shape[1] / sampling_rate) * 1000.
|
||||
n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
|
||||
|
||||
@ -1,8 +1,11 @@
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import Union
|
||||
try:
|
||||
import ffmpeg
|
||||
except:
|
||||
print("Please Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.")
|
||||
|
||||
import ffmpeg
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
@ -1 +1 @@
0.8.5
0.8.6

@ -94,7 +94,9 @@ Introduction to run_server.sh parameters:
--punc-quant: True for quantized PUNC model, False for non-quantized PUNC model. Default is True.
--itn-dir modelscope model ID or local model path.
--port: Port number that the server listens on. Default is 10095.
--decoder-thread-num: Number of inference threads that the server starts. Default is 8.
--decoder-thread-num: The number of thread pools on the server side that can handle concurrent requests. The default value is 8.
--model-thread-num: The number of internal threads for each recognition route to control the parallelism of the ONNX model.
                    The default value is 1. It is recommended that decoder-thread-num * model-thread-num equals the total number of threads.
--io-thread-num: Number of IO threads that the server starts. Default is 1.
--certfile <string>: SSL certificate file. Default is ../../../ssl_key/server.crt. If you want to disable SSL, set it to 0.
--keyfile <string>: SSL key file. Default is ../../../ssl_key/server.key.

@ -73,7 +73,9 @@ Introduction to run_server.sh parameters:
--punc-quant: True for quantized PUNC model, False for non-quantized PUNC model. Default is True.
--itn-dir modelscope model ID or local model path.
--port: Port number that the server listens on. Default is 10095.
--decoder-thread-num: Number of inference threads that the server starts. Default is 8.
--decoder-thread-num: The number of thread pools on the server side that can handle concurrent requests. The default value is 8.
--model-thread-num: The number of internal threads for each recognition route to control the parallelism of the ONNX model.
                    The default value is 1. It is recommended that decoder-thread-num * model-thread-num equals the total number of threads.
--io-thread-num: Number of IO threads that the server starts. Default is 1.
--certfile <string>: SSL certificate file. Default is ../../../ssl_key/server.crt. If you want to disable SSL, set it to 0.
--keyfile <string>: SSL key file. Default is ../../../ssl_key/server.key.

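As a rough illustration of the thread recommendation above (decoder-thread-num * model-thread-num ≈ total threads), here is a hypothetical helper, not part of FunASR, that derives a split from the machine's core count:

```python
import os

# Illustrative only: split the available cores between concurrent requests
# (decoder threads) and per-request ONNX parallelism (model threads).
total_threads = os.cpu_count() or 8
model_thread_num = 1  # per-route ONNX parallelism, matching the default above
decoder_thread_num = max(1, total_threads // model_thread_num)

print(f"--decoder-thread-num {decoder_thread_num} --model-thread-num {model_thread_num}")
```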
@ -158,7 +158,9 @@ nohup bash run_server.sh \
--punc-quant: True for a quantized PUNC model, False for a non-quantized PUNC model. Default is True.
--itn-dir: modelscope model ID or local model path.
--port: Port number that the server listens on. Default is 10095.
--decoder-thread-num: Number of inference threads that the server starts. Default is 8.
--decoder-thread-num: Number of server-side thread pools (the maximum number of concurrent requests). Default is 8.
--model-thread-num: Number of internal threads per recognition route (controls ONNX model parallelism). Default is 1.
                    It is recommended that decoder-thread-num * model-thread-num equals the total number of threads.
--io-thread-num: Number of IO threads that the server starts. Default is 1.
--certfile: SSL certificate file. Default is ../../../ssl_key/server.crt. To disable SSL, set it to 0.
--keyfile: SSL key file. Default is ../../../ssl_key/server.key.

@ -175,11 +175,14 @@ nohup bash run_server.sh \
--lm-dir: modelscope model ID or local model path.
--itn-dir: modelscope model ID or local model path.
--port: Port number that the server listens on. Default is 10095.
--decoder-thread-num: Number of inference threads that the server starts. Default is 8.
--decoder-thread-num: Number of server-side thread pools (the maximum number of concurrent requests). Default is 8.
--model-thread-num: Number of internal threads per recognition route (controls ONNX model parallelism). Default is 1.
                    It is recommended that decoder-thread-num * model-thread-num equals the total number of threads.
--io-thread-num: Number of IO threads that the server starts. Default is 1.
--certfile: SSL certificate file. Default is ../../../ssl_key/server.crt. To disable SSL, set it to 0.
--keyfile: SSL key file. Default is ../../../ssl_key/server.key.
--hotword: Path to the hotword file, one hotword per line in the format "hotword weight" (e.g. 阿里巴巴 20). If the client also provides hotwords, they are merged with the server-side ones.
--hotword: Path to the hotword file, one hotword per line in the format "hotword weight" (e.g. 阿里巴巴 20).
           If the client also provides hotwords, they are merged with the server-side ones; server hotwords take effect globally, while client hotwords apply only to that client.
```

### Shutting down the FunASR service

@ -111,7 +111,9 @@ nohup bash run_server_2pass.sh \
--punc-quant: True for quantized PUNC model, False for non-quantized PUNC model. Default is True.
--itn-dir modelscope model ID or local model path.
--port: Port number that the server listens on. Default is 10095.
--decoder-thread-num: Number of inference threads that the server starts. Default is 8.
--decoder-thread-num: The number of thread pools on the server side that can handle concurrent requests. The default value is 8.
--model-thread-num: The number of internal threads for each recognition route to control the parallelism of the ONNX model.
                    The default value is 1. It is recommended that decoder-thread-num * model-thread-num equals the total number of threads.
--io-thread-num: Number of IO threads that the server starts. Default is 1.
--certfile <string>: SSL certificate file. Default is ../../../ssl_key/server.crt. If you want to disable SSL, set it to 0.
--keyfile <string>: SSL key file. Default is ../../../ssl_key/server.key.

@ -120,11 +120,14 @@ nohup bash run_server_2pass.sh \
--punc-quant: True for a quantized PUNC model, False for a non-quantized PUNC model. Default is True.
--itn-dir: modelscope model ID or local model path.
--port: Port number that the server listens on. Default is 10095.
--decoder-thread-num: Number of inference threads that the server starts. Default is 8.
--decoder-thread-num: Number of server-side thread pools (the maximum number of concurrent requests). Default is 8.
--model-thread-num: Number of internal threads per recognition route (controls ONNX model parallelism). Default is 1.
                    It is recommended that decoder-thread-num * model-thread-num equals the total number of threads.
--io-thread-num: Number of IO threads that the server starts. Default is 1.
--certfile: SSL certificate file. Default is ../../../ssl_key/server.crt. To disable SSL, set it to 0.
--keyfile: SSL key file. Default is ../../../ssl_key/server.key.
--hotword: Path to the hotword file, one hotword per line in the format "hotword weight" (e.g. 阿里巴巴 20). If the client also provides hotwords, they are merged with the server-side ones.
--hotword: Path to the hotword file, one hotword per line in the format "hotword weight" (e.g. 阿里巴巴 20).
           If the client also provides hotwords, they are merged with the server-side ones; server hotwords take effect globally, while client hotwords apply only to that client.
```

### Shutting down the FunASR service

@ -16,7 +16,7 @@ git clone https://github.com/alibaba/FunASR.git && cd FunASR
|
||||
### Install the requirements for server
|
||||
|
||||
```shell
|
||||
cd funasr/runtime/python/websocket
|
||||
cd runtime/python/websocket
|
||||
pip install -r requirements_server.txt
|
||||
```
|
||||
|
||||
|
||||
@ -53,13 +53,13 @@ parser.add_argument("--ncpu",
|
||||
help="cpu cores")
|
||||
parser.add_argument("--certfile",
|
||||
type=str,
|
||||
default="../ssl_key/server.crt",
|
||||
default="../../ssl_key/server.crt",
|
||||
required=False,
|
||||
help="certfile for ssl")
|
||||
|
||||
parser.add_argument("--keyfile",
|
||||
type=str,
|
||||
default="../ssl_key/server.key",
|
||||
default="../../ssl_key/server.key",
|
||||
required=False,
|
||||
help="keyfile for ssl")
|
||||
args = parser.parse_args()
|
||||
|
||||
36
setup.py
36
setup.py
@ -10,36 +10,36 @@ from setuptools import setup
|
||||
|
||||
requirements = {
|
||||
"install": [
|
||||
"setuptools>=38.5.1",
|
||||
# "setuptools>=38.5.1",
|
||||
"humanfriendly",
|
||||
"scipy>=1.4.1",
|
||||
"librosa",
|
||||
"jamo", # For kss
|
||||
# "jamo", # For kss
|
||||
"PyYAML>=5.1.2",
|
||||
"soundfile>=0.12.1",
|
||||
"h5py>=3.1.0",
|
||||
# "soundfile>=0.12.1",
|
||||
# "h5py>=3.1.0",
|
||||
"kaldiio>=2.17.0",
|
||||
"torch_complex",
|
||||
"nltk>=3.4.5",
|
||||
# "torch_complex",
|
||||
# "nltk>=3.4.5",
|
||||
# ASR
|
||||
"sentencepiece",
|
||||
"sentencepiece", # train
|
||||
"jieba",
|
||||
"rotary_embedding_torch",
|
||||
"ffmpeg",
|
||||
# "rotary_embedding_torch",
|
||||
# "ffmpeg-python",
|
||||
# TTS
|
||||
"pypinyin>=0.44.0",
|
||||
"espnet_tts_frontend",
|
||||
# "pypinyin>=0.44.0",
|
||||
# "espnet_tts_frontend",
|
||||
# ENH
|
||||
"pytorch_wpe",
|
||||
# "pytorch_wpe",
|
||||
"editdistance>=0.5.2",
|
||||
"tensorboard",
|
||||
"g2p",
|
||||
"nara_wpe",
|
||||
# "g2p",
|
||||
# "nara_wpe",
|
||||
# PAI
|
||||
"oss2",
|
||||
"edit-distance",
|
||||
"textgrid",
|
||||
"protobuf",
|
||||
# "edit-distance",
|
||||
# "textgrid",
|
||||
# "protobuf",
|
||||
"tqdm",
|
||||
"hdbscan",
|
||||
"umap",
|
||||
@ -104,7 +104,7 @@ setup(
|
||||
name="funasr",
|
||||
version=version,
|
||||
url="https://github.com/alibaba-damo-academy/FunASR.git",
|
||||
author="Speech Lab of DAMO Academy, Alibaba Group",
|
||||
author="Speech Lab of Alibaba Group",
|
||||
author_email="funasr@list.alibaba-inc.com",
|
||||
description="FunASR: A Fundamental End-to-End Speech Recognition Toolkit",
|
||||
long_description=open(os.path.join(dirname, "README.md"), encoding="utf-8").read(),
|
||||
|
||||
3
web-pages/.browserslistrc
Normal file
3
web-pages/.browserslistrc
Normal file
@ -0,0 +1,3 @@
|
||||
> 1%
|
||||
last 2 versions
|
||||
not ie < 11
|
||||
7
web-pages/.editorconfig
Normal file
7
web-pages/.editorconfig
Normal file
@ -0,0 +1,7 @@
|
||||
[*.{js,jsx,ts,tsx,vue}]
|
||||
charset = utf-8
|
||||
indent_style = space
|
||||
indent_size = 4
|
||||
end_of_line = lf
|
||||
trim_trailing_whitespace = true
|
||||
insert_final_newline = true
|
||||
7
web-pages/.eslintignore
Normal file
7
web-pages/.eslintignore
Normal file
@ -0,0 +1,7 @@
|
||||
node_modules
|
||||
dist
|
||||
fonts
|
||||
*.md
|
||||
*.woff
|
||||
*.ttf
|
||||
public
|
||||
24
web-pages/.eslintrc.js
Normal file
24
web-pages/.eslintrc.js
Normal file
@ -0,0 +1,24 @@
|
||||
module.exports = {
|
||||
root: true,
|
||||
env: {
|
||||
node: true
|
||||
},
|
||||
extends: ['plugin:vue/essential', '@vue/standard'],
|
||||
parserOptions: {
|
||||
parser: '@babel/eslint-parser'
|
||||
},
|
||||
rules: {
|
||||
'no-console': process.env.NODE_ENV === 'production' ? 'warn' : 'off',
|
||||
'no-debugger': process.env.NODE_ENV === 'production' ? 'warn' : 'off',
|
||||
'quote-props': 'off',
|
||||
indent: 'off',
|
||||
// "vue/script-indent": [
|
||||
// "error",
|
||||
// 4,
|
||||
// {
|
||||
// baseIndent: 1,
|
||||
// },
|
||||
// ],
|
||||
'vue/order-in-components': 'error'
|
||||
}
|
||||
}
|
||||
13
web-pages/babel.config.js
Normal file
13
web-pages/babel.config.js
Normal file
@ -0,0 +1,13 @@
|
||||
module.exports = {
|
||||
presets: ['@vue/cli-plugin-babel/preset'],
|
||||
plugins: [
|
||||
[
|
||||
'import',
|
||||
{
|
||||
libraryName: 'ant-design-vue',
|
||||
libraryDirectory: 'es',
|
||||
style: 'css'
|
||||
}
|
||||
]
|
||||
]
|
||||
}
|
||||
12
web-pages/jsconfig.json
Normal file
12
web-pages/jsconfig.json
Normal file
@ -0,0 +1,12 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "es5",
|
||||
"module": "esnext",
|
||||
"baseUrl": "./",
|
||||
"moduleResolution": "node",
|
||||
"paths": {
|
||||
"@/*": ["src/*"]
|
||||
},
|
||||
"lib": ["esnext", "dom", "dom.iterable", "scripthost"]
|
||||
}
|
||||
}
|
||||
11
web-pages/mock/index.js
Normal file
11
web-pages/mock/index.js
Normal file
@ -0,0 +1,11 @@
|
||||
import Mock from 'mockjs'
|
||||
|
||||
import getUserInfo from './user/getUserInfo.js'
|
||||
import getMenuList from './user/getMenuList.js'
|
||||
Mock.setup({
|
||||
timeout: 500
|
||||
})
|
||||
|
||||
console.log('启动mock请求数据')
|
||||
getUserInfo(Mock)
|
||||
getMenuList(Mock)
|
||||
1
web-pages/mock/user/getMenuList.js
Normal file
1
web-pages/mock/user/getMenuList.js
Normal file
@ -0,0 +1 @@
|
||||
export default Mock => {}
|
||||
1
web-pages/mock/user/getUserInfo.js
Normal file
1
web-pages/mock/user/getUserInfo.js
Normal file
@ -0,0 +1 @@
|
||||
export default Mock => {}
|
||||
3
web-pages/mock/util.js
Normal file
3
web-pages/mock/util.js
Normal file
@ -0,0 +1,3 @@
|
||||
export function randomNum (n, m) {
|
||||
return Math.floor(Math.random() * (m - n + 1) + n)
|
||||
}
|
||||
9857
web-pages/package-lock.json
generated
Normal file
9857
web-pages/package-lock.json
generated
Normal file
File diff suppressed because it is too large
44
web-pages/package.json
Normal file
44
web-pages/package.json
Normal file
@ -0,0 +1,44 @@
|
||||
{
|
||||
"name": "template-vue",
|
||||
"version": "0.1.0",
|
||||
"private": true,
|
||||
"scripts": {
|
||||
"dev": "vue-cli-service serve --env VUE_APP_frontendConfigUrl=/template-vue/frontend-config-dev.json",
|
||||
"devmock": "vue-cli-service serve --env VUE_APP_frontendConfigUrl=/template-vue/frontend-config-devmock.json --env VUE_APP_env=devmock",
|
||||
"example": "vue-cli-service build --env VUE_APP_frontendConfigUrl=/template-vue/frontend-config-dev.json --env NODE_ENV=production"
|
||||
},
|
||||
"dependencies": {
|
||||
"ant-design-vue": "1.7.5",
|
||||
"@liveqing/liveplayer": "2.7.10",
|
||||
"axios": "0.19.2",
|
||||
"core-js": "3.22.5",
|
||||
"mockjs": "1.1.0",
|
||||
"vue": "2.6.14",
|
||||
"vue-router": "3.5.3",
|
||||
"vuex": "3.6.2",
|
||||
"swiper": "5.4.5",
|
||||
"vuex-persistedstate": "3.0.1"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@babel/core": "7.17.10",
|
||||
"@babel/eslint-parser": "7.17.0",
|
||||
"@vue/cli-plugin-babel": "5.0.4",
|
||||
"@vue/cli-plugin-eslint": "5.0.4",
|
||||
"@vue/cli-plugin-router": "5.0.4",
|
||||
"@vue/cli-plugin-vuex": "5.0.4",
|
||||
"@vue/cli-service": "5.0.4",
|
||||
"@vue/eslint-config-standard": "6.1.0",
|
||||
"babel-plugin-component": "1.1.1",
|
||||
"babel-plugin-import": "1.12.2",
|
||||
"compression-webpack-plugin": "3.1.0",
|
||||
"css-unicode-loader": "1.0.3",
|
||||
"eslint": "7.32.0",
|
||||
"eslint-plugin-import": "2.26.0",
|
||||
"eslint-plugin-node": "11.1.0",
|
||||
"eslint-plugin-promise": "5.2.0",
|
||||
"eslint-plugin-vue": "8.7.1",
|
||||
"sass": "1.51.0",
|
||||
"sass-loader": "12.6.0",
|
||||
"vue-template-compiler": "2.6.14"
|
||||
}
|
||||
}
|
||||
1
web-pages/public/decoder.js
Normal file
1
web-pages/public/decoder.js
Normal file
File diff suppressed because one or more lines are too long
BIN
web-pages/public/favicon.ico
Normal file
BIN
web-pages/public/favicon.ico
Normal file
Binary file not shown.
Some files were not shown because too many files have changed in this diff.