Merge remote-tracking branch 'refs/remotes/origin/main'

update contextual forward
shixian.shi 2023-11-23 20:40:15 +08:00
commit adc88bd9e7
167 changed files with 13757 additions and 1319 deletions

View File

@@ -1,7 +1,10 @@
 import argparse
 import tqdm
 import codecs
-import textgrid
+try:
+    import textgrid
+except ImportError:
+    raise ImportError("Please install textgrid first: pip install textgrid")
 import pdb

 class Segment(object):
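The same guarded import is repeated across several of these TextGrid scripts; it could be factored into one small helper. A minimal sketch, assuming a hypothetical shared function named require_package (not part of this commit):

import importlib

def require_package(name, install_hint=None):
    # Import `name` or fail with an ImportError that carries an install hint.
    try:
        return importlib.import_module(name)
    except ImportError as exc:
        hint = install_hint or "pip install {}".format(name)
        raise ImportError("Please install {} first: {}".format(name, hint)) from exc

# usage in the scripts above:
# textgrid = require_package("textgrid")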

View File

@@ -6,7 +6,10 @@ import argparse
 import codecs
 from distutils.util import strtobool
 from pathlib import Path
-import textgrid
+try:
+    import textgrid
+except ImportError:
+    raise ImportError("Please install textgrid first: pip install textgrid")
 import pdb

 class Segment(object):

View File

@@ -6,7 +6,10 @@ import argparse
 import codecs
 from distutils.util import strtobool
 from pathlib import Path
-import textgrid
+try:
+    import textgrid
+except ImportError:
+    raise ImportError("Please install textgrid first: pip install textgrid")
 import pdb

 class Segment(object):

View File

@@ -6,7 +6,10 @@ import argparse
 import codecs
 from distutils.util import strtobool
 from pathlib import Path
-import textgrid
+try:
+    import textgrid
+except ImportError:
+    raise ImportError("Please install textgrid first: pip install textgrid")
 import pdb

 class Segment(object):

View File

@@ -6,7 +6,10 @@ import argparse
 import codecs
 from distutils.util import strtobool
 from pathlib import Path
-import textgrid
+try:
+    import textgrid
+except ImportError:
+    raise ImportError("Please install textgrid first: pip install textgrid")
 import pdb

 def get_args():

View File

@@ -6,7 +6,12 @@ import argparse
 import codecs
 from distutils.util import strtobool
 from pathlib import Path
-import textgrid
+try:
+    import textgrid
+except ImportError:
+    raise ImportError("Please install textgrid first: pip install textgrid")
 import pdb
 import numpy as np
 import sys

View File

@@ -34,8 +34,8 @@ from funasr.modules.beam_search.beam_search_transducer import Hypothesis as Hypo
 from funasr.modules.scorers.ctc import CTCPrefixScorer
 from funasr.modules.scorers.length_bonus import LengthBonus
 from funasr.build_utils.build_asr_model import frontend_choices
-from funasr.text.build_tokenizer import build_tokenizer
-from funasr.text.token_id_converter import TokenIDConverter
+from funasr.tokenizer.build_tokenizer import build_tokenizer
+from funasr.tokenizer.token_id_converter import TokenIDConverter
 from funasr.torch_utils.device_funcs import to_device
 from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard

@@ -44,9 +44,9 @@ class Speech2Text:
     """Speech2Text class

     Examples:
-        >>> import soundfile
+        >>> import librosa
         >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
-        >>> audio, rate = soundfile.read("speech.wav")
+        >>> audio, rate = librosa.load("speech.wav")
         >>> speech2text(audio)
         [(text, token, token_int, hypothesis object), ...]

@@ -251,9 +251,9 @@ class Speech2TextParaformer:
     """Speech2Text class

     Examples:
-        >>> import soundfile
+        >>> import librosa
         >>> speech2text = Speech2TextParaformer("asr_config.yml", "asr.pb")
-        >>> audio, rate = soundfile.read("speech.wav")
+        >>> audio, rate = librosa.load("speech.wav")
         >>> speech2text(audio)
         [(text, token, token_int, hypothesis object), ...]

@@ -625,9 +625,9 @@ class Speech2TextParaformerOnline:
     """Speech2Text class

     Examples:
-        >>> import soundfile
+        >>> import librosa
        >>> speech2text = Speech2TextParaformerOnline("asr_config.yml", "asr.pth")
-        >>> audio, rate = soundfile.read("speech.wav")
+        >>> audio, rate = librosa.load("speech.wav")
         >>> speech2text(audio)
         [(text, token, token_int, hypothesis object), ...]

@@ -876,9 +876,9 @@ class Speech2TextUniASR:
     """Speech2Text class

     Examples:
-        >>> import soundfile
+        >>> import librosa
         >>> speech2text = Speech2TextUniASR("asr_config.yml", "asr.pb")
-        >>> audio, rate = soundfile.read("speech.wav")
+        >>> audio, rate = librosa.load("speech.wav")
         >>> speech2text(audio)
         [(text, token, token_int, hypothesis object), ...]

@@ -1106,9 +1106,9 @@ class Speech2TextMFCCA:
     """Speech2Text class

     Examples:
-        >>> import soundfile
+        >>> import librosa
         >>> speech2text = Speech2TextMFCCA("asr_config.yml", "asr.pb")
-        >>> audio, rate = soundfile.read("speech.wav")
+        >>> audio, rate = librosa.load("speech.wav")
         >>> speech2text(audio)
         [(text, token, token_int, hypothesis object), ...]

@@ -1637,9 +1637,9 @@ class Speech2TextSAASR:
     """Speech2Text class

     Examples:
-        >>> import soundfile
+        >>> import librosa
         >>> speech2text = Speech2TextSAASR("asr_config.yml", "asr.pb")
-        >>> audio, rate = soundfile.read("speech.wav")
+        >>> audio, rate = librosa.load("speech.wav")
         >>> speech2text(audio)
         [(text, token, token_int, hypothesis object), ...]

@@ -1885,9 +1885,9 @@ class Speech2TextWhisper:
     """Speech2Text class

     Examples:
-        >>> import soundfile
+        >>> import librosa
         >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
-        >>> audio, rate = soundfile.read("speech.wav")
+        >>> audio, rate = librosa.load("speech.wav")
         >>> speech2text(audio)
         [(text, token, token_int, hypothesis object), ...]
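A note on these docstring updates and the matching code changes below: unlike soundfile.read, which returns samples at the file's native rate, librosa.load resamples to 22050 Hz by default and returns mono float32. If the 16 kHz rate used elsewhere in this commit (for example, load_audio's fs=16000 in the new dataset code) is expected, the rate has to be passed explicitly. A minimal sketch ("speech.wav" is just the placeholder used in the docstrings):

import librosa

# keep the file's native sample rate ...
audio, rate = librosa.load("speech.wav", sr=None)
# ... or resample to 16 kHz explicitly
audio_16k, rate_16k = librosa.load("speech.wav", sr=16000)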

View File

@@ -20,7 +20,8 @@ from typing import Union
 import numpy as np
 import torch
 import torchaudio
-import soundfile
+# import librosa
+import librosa
 import yaml

 from funasr.bin.asr_infer import Speech2Text

@@ -1281,7 +1282,8 @@ def inference_paraformer_online(
    try:
        raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
    except:
-       raw_inputs = soundfile.read(data_path_and_name_and_type[0], dtype='float32')[0]
+       # raw_inputs = librosa.load(data_path_and_name_and_type[0], dtype='float32')[0]
+       raw_inputs, sr = librosa.load(data_path_and_name_and_type[0], dtype='float32')
        if raw_inputs.ndim == 2:
            raw_inputs = raw_inputs[:, 0]
    raw_inputs = torch.tensor(raw_inputs)

View File

@@ -18,7 +18,7 @@ from funasr.build_utils.build_optimizer import build_optimizer
 from funasr.build_utils.build_scheduler import build_scheduler
 from funasr.build_utils.build_trainer import build_trainer as build_trainer_modelscope
 from funasr.modules.lora.utils import mark_only_lora_as_trainable
-from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.tokenizer.phoneme_tokenizer import g2p_choices
 from funasr.torch_utils.load_pretrained_model import load_pretrained_model
 from funasr.torch_utils.model_summary import model_summary
 from funasr.torch_utils.pytorch_version import pytorch_cudnn_version

View File

@@ -27,11 +27,11 @@ class Speech2DiarizationEEND:
     """Speech2Diarlization class

     Examples:
-        >>> import soundfile
+        >>> import librosa
         >>> import numpy as np
         >>> speech2diar = Speech2DiarizationEEND("diar_sond_config.yml", "diar_sond.pb")
         >>> profile = np.load("profiles.npy")
-        >>> audio, rate = soundfile.read("speech.wav")
+        >>> audio, rate = librosa.load("speech.wav")
         >>> speech2diar(audio, profile)
         {"spk1": [(int, int), ...], ...}

@@ -109,11 +109,11 @@ class Speech2DiarizationSOND:
     """Speech2Xvector class

     Examples:
-        >>> import soundfile
+        >>> import librosa
         >>> import numpy as np
         >>> speech2diar = Speech2DiarizationSOND("diar_sond_config.yml", "diar_sond.pb")
         >>> profile = np.load("profiles.npy")
-        >>> audio, rate = soundfile.read("speech.wav")
+        >>> audio, rate = librosa.load("speech.wav")
         >>> speech2diar(audio, profile)
         {"spk1": [(int, int), ...], ...}

View File

@@ -15,7 +15,8 @@ from typing import Tuple
 from typing import Union

 import numpy as np
-import soundfile
+# import librosa
+import librosa
 import torch
 from scipy.signal import medfilt

@@ -144,7 +145,9 @@ def inference_sond(
        # read waveform file
        example = [load_bytes(x) if isinstance(x, bytes) else x
                   for x in example]
-       example = [soundfile.read(x)[0] if isinstance(x, str) else x
+       # example = [librosa.load(x)[0] if isinstance(x, str) else x
+       #            for x in example]
+       example = [librosa.load(x, dtype='float32')[0] if isinstance(x, str) else x
                   for x in example]
        # convert torch tensor to numpy array
        example = [x.numpy() if isinstance(example[0], torch.Tensor) else x

View File

@@ -20,9 +20,9 @@ class SpeechSeparator:
     """SpeechSeparator class

     Examples:
-        >>> import soundfile
+        >>> import librosa
         >>> speech_separator = MossFormer("ss_config.yml", "ss.pt")
-        >>> audio, rate = soundfile.read("speech.wav")
+        >>> audio, rate = librosa.load("speech.wav")
         >>> separated_wavs = speech_separator(audio)
     """

View File

@@ -13,7 +13,7 @@ from typing import Union
 import numpy as np
 import torch
-import soundfile as sf
+import librosa
 from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
 from funasr.torch_utils.set_all_random_seed import set_all_random_seed
 from funasr.utils import config_argparse

@@ -104,7 +104,12 @@ def inference_ss(
        ss_results = speech_separator(**batch)
        for spk in range(num_spks):
-           sf.write(os.path.join(output_path, keys[0] + '_s' + str(spk+1)+'.wav'), ss_results[spk], sample_rate)
+           # sf.write(os.path.join(output_path, keys[0] + '_s' + str(spk+1)+'.wav'), ss_results[spk], sample_rate)
+           try:
+               librosa.output.write_wav(os.path.join(output_path, keys[0] + '_s' + str(spk+1)+'.wav'), ss_results[spk], sample_rate)
+           except AttributeError:
+               print("To write wavs with librosa, install librosa<0.8.0 (librosa.output was removed in 0.8.0)")
+               raise
        torch.cuda.empty_cache()
        return ss_results
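librosa.output.write_wav was deprecated in librosa 0.7 and removed in 0.8.0, so the branch above only works on older librosa releases. If pinning librosa is not an option, one alternative is to write PCM WAV files with scipy, which is already used elsewhere in this file; write_wav_pcm16 below is an illustrative helper, not part of the diff:

import numpy as np
from scipy.io import wavfile

def write_wav_pcm16(path, samples, sample_rate):
    # Convert float samples in [-1, 1] to 16-bit PCM and write with scipy.
    pcm = (np.clip(samples, -1.0, 1.0) * 32767.0).astype(np.int16)
    wavfile.write(path, sample_rate, pcm)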

View File

@@ -22,9 +22,9 @@ class Speech2Xvector:
     """Speech2Xvector class

     Examples:
-        >>> import soundfile
+        >>> import librosa
         >>> speech2xvector = Speech2Xvector("sv_config.yml", "sv.pb")
-        >>> audio, rate = soundfile.read("speech.wav")
+        >>> audio, rate = librosa.load("speech.wav")
         >>> speech2xvector(audio)
         [(text, token, token_int, hypothesis object), ...]

View File

@@ -9,9 +9,9 @@ from typing import Optional

 from funasr.utils.cli_utils import get_commandline_args
-from funasr.text.build_tokenizer import build_tokenizer
-from funasr.text.cleaner import TextCleaner
-from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.tokenizer.build_tokenizer import build_tokenizer
+from funasr.tokenizer.cleaner import TextCleaner
+from funasr.tokenizer.phoneme_tokenizer import g2p_choices
 from funasr.utils.types import str2bool
 from funasr.utils.types import str_or_none

View File

@@ -11,7 +11,7 @@ import numpy as np
 import torch

 from funasr.build_utils.build_model_from_file import build_model_from_file
 from funasr.models.frontend.wav_frontend import WavFrontend
-from funasr.text.token_id_converter import TokenIDConverter
+from funasr.tokenizer.token_id_converter import TokenIDConverter
 from funasr.torch_utils.device_funcs import to_device

View File

@@ -17,7 +17,7 @@ from funasr.build_utils.build_model import build_model
 from funasr.build_utils.build_optimizer import build_optimizer
 from funasr.build_utils.build_scheduler import build_scheduler
 from funasr.build_utils.build_trainer import build_trainer
-from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.tokenizer.phoneme_tokenizer import g2p_choices
 from funasr.torch_utils.load_pretrained_model import load_pretrained_model
 from funasr.torch_utils.model_summary import model_summary
 from funasr.torch_utils.pytorch_version import pytorch_cudnn_version

View File

@@ -23,9 +23,9 @@ class Speech2VadSegment:
     """Speech2VadSegment class

     Examples:
-        >>> import soundfile
+        >>> import librosa
         >>> speech2segment = Speech2VadSegment("vad_config.yml", "vad.pt")
-        >>> audio, rate = soundfile.read("speech.wav")
+        >>> audio, rate = librosa.load("speech.wav")
         >>> speech2segment(audio)
         [[10, 230], [245, 450], ...]

@@ -118,9 +118,9 @@ class Speech2VadSegmentOnline(Speech2VadSegment):
     """Speech2VadSegmentOnline class

     Examples:
-        >>> import soundfile
+        >>> import librosa
         >>> speech2segment = Speech2VadSegmentOnline("vad_config.yml", "vad.pt")
-        >>> audio, rate = soundfile.read("speech.wav")
+        >>> audio, rate = librosa.load("speech.wav")
         >>> speech2segment(audio)
         [[10, 230], [245, 450], ...]

View File

@@ -246,14 +246,11 @@ class Trainer:
        for iepoch in range(start_epoch, trainer_options.max_epoch + 1):
            if iepoch != start_epoch:
                logging.info(
-                   "{}/{}epoch started. Estimated time to finish: {}".format(
+                   "{}/{}epoch started. Estimated time to finish: {} hours".format(
                        iepoch,
                        trainer_options.max_epoch,
-                       humanfriendly.format_timespan(
-                           (time.perf_counter() - start_time)
-                           / (iepoch - start_epoch)
-                           * (trainer_options.max_epoch - iepoch + 1)
-                       ),
+                       (time.perf_counter() - start_time) / 3600.0 / (iepoch - start_epoch) * (
+                           trainer_options.max_epoch - iepoch + 1),
                    )
                )
            else:
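The new expression estimates the remaining time as the average wall-clock seconds per completed epoch, scaled by the epochs still to run and converted to hours. The same arithmetic as a standalone sketch (eta_hours is an illustrative name, not from the diff):

import time

def eta_hours(start_time, start_epoch, current_epoch, max_epoch):
    elapsed = time.perf_counter() - start_time            # seconds spent so far
    per_epoch = elapsed / (current_epoch - start_epoch)   # seconds per finished epoch
    return per_epoch * (max_epoch - current_epoch + 1) / 3600.0

# logging.info("Estimated time to finish: {:.1f} hours".format(eta_hours(...)))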

View File

@@ -0,0 +1,74 @@
import torch
import numpy as np


class BatchSampler(torch.utils.data.BatchSampler):

    def __init__(self, dataset, batch_size_type: str="example", batch_size: int=14, sort_size: int=30, drop_last: bool=False, shuffle: bool=True, **kwargs):
        self.drop_last = drop_last
        self.pre_idx = -1
        self.dataset = dataset
        self.total_samples = len(dataset)
        # self.batch_size_type = args.batch_size_type
        # self.batch_size = args.batch_size
        # self.sort_size = args.sort_size
        # self.max_length_token = args.max_length_token
        self.batch_size_type = batch_size_type
        self.batch_size = batch_size
        self.sort_size = sort_size
        self.max_length_token = kwargs.get("max_length_token", 5000)
        self.shuffle_idx = np.arange(self.total_samples)
        self.shuffle = shuffle

    def __len__(self):
        return self.total_samples

    def __iter__(self):
        print("in sampler")
        if self.shuffle:
            np.random.shuffle(self.shuffle_idx)
        batch = []
        max_token = 0
        num_sample = 0
        # Walk the (shuffled) indices in windows of sort_size samples.
        iter_num = (self.total_samples - 1) // self.sort_size + 1
        print("iter_num: ", iter_num)
        for iter in range(self.pre_idx + 1, iter_num):
            # Collect (index, length) pairs for the current window.
            datalen_with_index = []
            for i in range(self.sort_size):
                idx = iter * self.sort_size + i
                if idx >= self.total_samples:
                    continue
                idx_map = self.shuffle_idx[idx]
                # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
                sample_len_cur = self.dataset.indexed_dataset[idx_map]["source_len"] + \
                                 self.dataset.indexed_dataset[idx_map]["target_len"]
                datalen_with_index.append([idx, sample_len_cur])
            # Sort the window by length so similar-length samples share a batch.
            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
            for item in datalen_with_index_sort:
                idx, sample_len_cur_raw = item
                if sample_len_cur_raw > self.max_length_token:
                    continue
                max_token_cur = max(max_token, sample_len_cur_raw)
                # Estimated padded size of the batch if this sample were added.
                max_token_padding = 1 + num_sample
                if self.batch_size_type == 'token':
                    max_token_padding *= max_token_cur
                if max_token_padding <= self.batch_size:
                    batch.append(idx)
                    max_token = max_token_cur
                    num_sample += 1
                else:
                    yield batch
                    batch = [idx]
                    max_token = sample_len_cur_raw
                    num_sample = 1
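For batch_size_type="token", the loop above packs each length-sorted window until (num_sample + 1) * max_token_cur would exceed batch_size. A simplified, self-contained re-implementation of just that packing rule, for illustration only (names and numbers are made up, and this is not the class above):

def pack_by_token(sorted_lengths, token_budget):
    # sorted_lengths: sample lengths within one sort window, ascending.
    batches, batch, max_len = [], [], 0
    for length in sorted_lengths:
        new_max = max(max_len, length)
        if (len(batch) + 1) * new_max <= token_budget:
            batch.append(length)
            max_len = new_max
        else:
            batches.append(batch)
            batch, max_len = [length], length
    if batch:
        batches.append(batch)
    return batches

# pack_by_token([90, 100, 110, 400], 400) -> [[90, 100, 110], [400]]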

View File

@@ -0,0 +1,53 @@
import torch

from funasr.datasets.dataset_jsonl import AudioDataset
from funasr.datasets.data_sampler import BatchSampler
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tokenizer.build_tokenizer import build_tokenizer
from funasr.tokenizer.token_id_converter import TokenIDConverter

collate_fn = None
# collate_fn = collate_fn,
jsonl = "/Users/zhifu/funasr_github/test_local/all_task_debug_len.jsonl"
frontend = WavFrontend()

token_type = 'char'
bpemodel = None
delimiter = None
space_symbol = "<space>"
non_linguistic_symbols = None
g2p_type = None
tokenizer = build_tokenizer(
    token_type=token_type,
    bpemodel=bpemodel,
    delimiter=delimiter,
    space_symbol=space_symbol,
    non_linguistic_symbols=non_linguistic_symbols,
    g2p_type=g2p_type,
)

token_list = ""
unk_symbol = "<unk>"
token_id_converter = TokenIDConverter(
    token_list=token_list,
    unk_symbol=unk_symbol,
)

dataset = AudioDataset(jsonl, frontend=frontend, tokenizer=tokenizer)
batch_sampler = BatchSampler(dataset)
dataloader_tr = torch.utils.data.DataLoader(dataset,
                                            collate_fn=dataset.collator,
                                            batch_sampler=batch_sampler,
                                            shuffle=False,
                                            num_workers=0,
                                            pin_memory=True)

print(len(dataset))
for i in range(3):
    print(i)
    for data in dataloader_tr:
        print(len(data), data)
        # data_iter = iter(dataloader_tr)
        # data = next(data_iter)
        pass

View File

@@ -16,8 +16,10 @@ from typing import Dict
 from typing import Mapping
 from typing import Tuple
 from typing import Union
-import h5py
+try:
+    import h5py
+except ImportError:
+    print("If you want to use the h5py dataset, please `pip install h5py` and try again")
 import humanfriendly
 import kaldiio
 import numpy as np
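One caveat with printing instead of raising: when the import fails, the name h5py is left undefined, so any later h5py.File(...) call dies with a NameError rather than a helpful message. A minimal sketch of a more defensive variant (open_h5 is an illustrative helper, not part of the diff):

try:
    import h5py
except ImportError:
    h5py = None  # checked before use

def open_h5(path):
    if h5py is None:
        raise ImportError("h5py is required for HDF5 datasets: pip install h5py")
    return h5py.File(path, "r")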

View File

@@ -0,0 +1,124 @@
import torch
import json
import torch.distributed as dist
import numpy as np
import kaldiio
import librosa


def load_audio(audio_path: str, fs: int=16000):
    audio = None
    if audio_path.startswith("oss:"):
        pass
    elif audio_path.startswith("odps:"):
        pass
    else:
        if ".ark:" in audio_path:
            audio = kaldiio.load_mat(audio_path)
        else:
            audio, fs = librosa.load(audio_path, sr=fs)
    return audio


def extract_features(data, data_type: str="sound", frontend=None):
    if data_type == "sound":
        feat, feats_lens = frontend(data, len(data))
        feat = feat[0, :, :]
    else:
        feat, feats_lens = torch.from_numpy(data).to(torch.float32), torch.tensor([data.shape[0]]).to(torch.int32)
    return feat, feats_lens


class IndexedDatasetJsonl(torch.utils.data.Dataset):

    def __init__(self, path):
        super().__init__()
        # data_parallel_size = dist.get_world_size()
        data_parallel_size = 1
        contents = []
        with open(path, encoding='utf-8') as fin:
            for line in fin:
                data = json.loads(line.strip())
                if "text" in data:  # for sft; append to the local list (self.contents is built below)
                    contents.append(data['text'])
                if "source" in data:  # for speech lab pretrain
                    prompt = data["prompt"]
                    source = data["source"]
                    target = data["target"]
                    source_len = data["source_len"]
                    target_len = data["target_len"]
                    contents.append({"source": source,
                                     "prompt": prompt,
                                     "target": target,
                                     "source_len": source_len,
                                     "target_len": target_len,
                                     }
                                    )
        self.contents = []
        total_num = len(contents)
        num_per_rank = total_num // data_parallel_size
        # rank = dist.get_rank()
        rank = 0
        # import ipdb; ipdb.set_trace()
        self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank]

    def __len__(self):
        return len(self.contents)

    def __getitem__(self, index):
        return self.contents[index]


class AudioDataset(torch.utils.data.Dataset):

    def __init__(self, path, frontend=None, tokenizer=None):
        super().__init__()
        self.indexed_dataset = IndexedDatasetJsonl(path)
        self.frontend = frontend.forward
        self.fs = 16000 if frontend is None else frontend.fs
        self.data_type = "sound"
        self.tokenizer = tokenizer
        self.int_pad_value = -1
        self.float_pad_value = 0.0

    def __len__(self):
        return len(self.indexed_dataset)

    def __getitem__(self, index):
        item = self.indexed_dataset[index]
        source = item["source"]
        data_src = load_audio(source, fs=self.fs)
        speech, speech_lengths = extract_features(data_src, self.data_type, self.frontend)

        target = item["target"]
        text = self.tokenizer.encode(target)
        text_lengths = len(text)
        text, text_lengths = torch.tensor(text, dtype=torch.int64), torch.tensor([text_lengths], dtype=torch.int32)

        return {"speech": speech,
                "speech_lengths": speech_lengths,
                "text": text,
                "text_lengths": text_lengths,
                }

    def collator(self, samples: list=None):
        outputs = {}
        for sample in samples:
            for key in sample.keys():
                if key not in outputs:
                    outputs[key] = []
                outputs[key].append(sample[key])

        for key, data_list in outputs.items():
            # torch dtypes have no numpy-style ".kind"; test integer vs. float directly
            if not torch.is_floating_point(data_list[0]):
                pad_value = self.int_pad_value
            else:
                pad_value = self.float_pad_value
            outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
        return outputs  # return the padded batch rather than the raw sample list
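For reference, torch.nn.utils.rnn.pad_sequence with batch_first=True stacks variable-length 1-D tensors into a (batch, max_len) tensor, right-padding with padding_value, which is what the collator relies on. A tiny illustration (values are made up):

import torch
from torch.nn.utils.rnn import pad_sequence

a = torch.tensor([5, 6, 7], dtype=torch.int64)
b = torch.tensor([8, 9], dtype=torch.int64)
batch = pad_sequence([a, b], batch_first=True, padding_value=-1)
# batch == tensor([[ 5,  6,  7],
#                  [ 8,  9, -1]])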

View File

@@ -14,7 +14,8 @@ import kaldiio
 import numpy as np
 import torch
 import torchaudio
-import soundfile
+# import librosa
+import librosa
 from torch.utils.data.dataset import IterableDataset
 import os.path

@@ -70,7 +71,8 @@ def load_wav(input):
    try:
        return torchaudio.load(input)[0].numpy()
    except:
-       waveform, _ = soundfile.read(input, dtype='float32')
+       # waveform, _ = librosa.load(input, dtype='float32')
+       waveform, _ = librosa.load(input, dtype='float32')
        if waveform.ndim == 2:
            waveform = waveform[:, 0]
        return np.expand_dims(waveform, axis=0)

View File

@@ -9,7 +9,7 @@ from torch.utils.data import DataLoader

 from funasr.datasets.large_datasets.dataset import Dataset
 from funasr.iterators.abs_iter_factory import AbsIterFactory
-from funasr.text.abs_tokenizer import AbsTokenizer
+from funasr.tokenizer.abs_tokenizer import AbsTokenizer


 def read_symbol_table(symbol_table_file):

View File

@@ -7,7 +7,8 @@ import torch
 import torch.distributed as dist
 import torchaudio
 import numpy as np
-import soundfile
+# import librosa
+import librosa
 from kaldiio import ReadHelper
 from torch.utils.data import IterableDataset

@@ -128,7 +129,8 @@ class AudioDataset(IterableDataset):
            try:
                waveform, sampling_rate = torchaudio.load(path)
            except:
-               waveform, sampling_rate = soundfile.read(path, dtype='float32')
+               # waveform, sampling_rate = librosa.load(path, dtype='float32')
+               waveform, sampling_rate = librosa.load(path, dtype='float32')
                if waveform.ndim == 2:
                    waveform = waveform[:, 0]
                waveform = np.expand_dims(waveform, axis=0)

View File

@@ -10,12 +10,12 @@ from typing import Union

 import numpy as np
 import scipy.signal
-import soundfile
+import librosa
 import jieba

-from funasr.text.build_tokenizer import build_tokenizer
-from funasr.text.cleaner import TextCleaner
-from funasr.text.token_id_converter import TokenIDConverter
+from funasr.tokenizer.build_tokenizer import build_tokenizer
+from funasr.tokenizer.cleaner import TextCleaner
+from funasr.tokenizer.token_id_converter import TokenIDConverter


 class AbsPreprocessor(ABC):

@@ -284,7 +284,7 @@ class CommonPreprocessor(AbsPreprocessor):
            if self.rirs is not None and self.rir_apply_prob >= np.random.random():
                rir_path = np.random.choice(self.rirs)
                if rir_path is not None:
-                   rir, _ = soundfile.read(
+                   rir, _ = librosa.load(
                        rir_path, dtype=np.float64, always_2d=True
                    )

@@ -310,28 +310,31 @@ class CommonPreprocessor(AbsPreprocessor):
                noise_db = np.random.uniform(
                    self.noise_db_low, self.noise_db_high
                )
-               with soundfile.SoundFile(noise_path) as f:
-                   if f.frames == nsamples:
-                       noise = f.read(dtype=np.float64, always_2d=True)
-                   elif f.frames < nsamples:
-                       offset = np.random.randint(0, nsamples - f.frames)
-                       # noise: (Time, Nmic)
-                       noise = f.read(dtype=np.float64, always_2d=True)
-                       # Repeat noise
-                       noise = np.pad(
-                           noise,
-                           [(offset, nsamples - f.frames - offset), (0, 0)],
-                           mode="wrap",
-                       )
-                   else:
-                       offset = np.random.randint(0, f.frames - nsamples)
-                       f.seek(offset)
-                       # noise: (Time, Nmic)
-                       noise = f.read(
-                           nsamples, dtype=np.float64, always_2d=True
-                       )
-                       if len(noise) != nsamples:
-                           raise RuntimeError(f"Something wrong: {noise_path}")
+               audio_data = librosa.load(noise_path, dtype='float32')[0][None, :]
+               frames = len(audio_data[0])
+               if frames == nsamples:
+                   noise = audio_data
+               elif frames < nsamples:
+                   offset = np.random.randint(0, nsamples - frames)
+                   # noise: (Time, Nmic)
+                   noise = audio_data
+                   # Repeat noise
+                   noise = np.pad(
+                       noise,
+                       [(offset, nsamples - frames - offset), (0, 0)],
+                       mode="wrap",
+                   )
+               else:
+                   noise = audio_data[:, nsamples]
+                   # offset = np.random.randint(0, frames - nsamples)
+                   # f.seek(offset)
+                   # noise: (Time, Nmic)
+                   # noise = f.read(
+                   #     nsamples, dtype=np.float64, always_2d=True
+                   # )
+                   # if len(noise) != nsamples:
+                   #     raise RuntimeError(f"Something wrong: {noise_path}")
                # noise: (Nmic, Time)
                noise = noise.T
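A caveat on this hunk and on the identical copy of the preprocessor below: librosa.load does not accept soundfile's always_2d keyword, so the rir branch above would fail at runtime, and with mono=False librosa returns audio shaped (channels, samples) rather than soundfile's (samples, channels). A minimal sketch of a 2-D load that matches soundfile's (Time, Nmic) layout, assuming that is what the downstream code still expects; load_2d is illustrative, not part of the diff:

import numpy as np
import librosa

def load_2d(path, sr=None):
    # mono=False keeps all channels; librosa returns (channels, samples) or (samples,)
    y, _ = librosa.load(path, sr=sr, mono=False, dtype=np.float64)
    if y.ndim == 1:
        y = y[np.newaxis, :]
    return y.T  # (samples, channels), like soundfile.read(..., always_2d=True)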

View File

@@ -9,11 +9,11 @@ from typing import Union

 import numpy as np
 import scipy.signal
-import soundfile
+import librosa

-from funasr.text.build_tokenizer import build_tokenizer
-from funasr.text.cleaner import TextCleaner
-from funasr.text.token_id_converter import TokenIDConverter
+from funasr.tokenizer.build_tokenizer import build_tokenizer
+from funasr.tokenizer.cleaner import TextCleaner
+from funasr.tokenizer.token_id_converter import TokenIDConverter


 class AbsPreprocessor(ABC):

@@ -275,7 +275,7 @@ class CommonPreprocessor(AbsPreprocessor):
            if self.rirs is not None and self.rir_apply_prob >= np.random.random():
                rir_path = np.random.choice(self.rirs)
                if rir_path is not None:
-                   rir, _ = soundfile.read(
+                   rir, _ = librosa.load(
                        rir_path, dtype=np.float64, always_2d=True
                    )

@@ -301,28 +301,30 @@ class CommonPreprocessor(AbsPreprocessor):
                noise_db = np.random.uniform(
                    self.noise_db_low, self.noise_db_high
                )
-               with soundfile.SoundFile(noise_path) as f:
-                   if f.frames == nsamples:
-                       noise = f.read(dtype=np.float64, always_2d=True)
-                   elif f.frames < nsamples:
-                       offset = np.random.randint(0, nsamples - f.frames)
-                       # noise: (Time, Nmic)
-                       noise = f.read(dtype=np.float64, always_2d=True)
-                       # Repeat noise
-                       noise = np.pad(
-                           noise,
-                           [(offset, nsamples - f.frames - offset), (0, 0)],
-                           mode="wrap",
-                       )
-                   else:
-                       offset = np.random.randint(0, f.frames - nsamples)
-                       f.seek(offset)
-                       # noise: (Time, Nmic)
-                       noise = f.read(
-                           nsamples, dtype=np.float64, always_2d=True
-                       )
-                       if len(noise) != nsamples:
-                           raise RuntimeError(f"Something wrong: {noise_path}")
+               audio_data = librosa.load(noise_path, dtype='float32')[0][None, :]
+               frames = len(audio_data[0])
+               if frames == nsamples:
+                   noise = audio_data
+               elif frames < nsamples:
+                   offset = np.random.randint(0, nsamples - frames)
+                   # noise: (Time, Nmic)
+                   noise = audio_data
+                   # Repeat noise
+                   noise = np.pad(
+                       noise,
+                       [(offset, nsamples - frames - offset), (0, 0)],
+                       mode="wrap",
+                   )
+               else:
+                   noise = audio_data[:, nsamples]
+                   # offset = np.random.randint(0, frames - nsamples)
+                   # f.seek(offset)
+                   # noise: (Time, Nmic)
+                   # noise = f.read(
+                   #     nsamples, dtype=np.float64, always_2d=True
+                   # )
+                   # if len(noise) != nsamples:
+                   #     raise RuntimeError(f"Something wrong: {noise_path}")
                # noise: (Nmic, Time)
                noise = noise.T

View File

@@ -1,151 +0,0 @@
import json
from typing import Union, Dict
from pathlib import Path
import os
import logging
import torch
from funasr.export.models import get_model
import numpy as np
import random
from funasr.utils.types import str2bool, str2triple_str
# torch_version = float(".".join(torch.__version__.split(".")[:2]))
# assert torch_version > 1.9
class ModelExport:
def __init__(
self,
cache_dir: Union[Path, str] = None,
onnx: bool = True,
device: str = "cpu",
quant: bool = True,
fallback_num: int = 0,
audio_in: str = None,
calib_num: int = 200,
model_revision: str = None,
):
self.set_all_random_seed(0)
self.cache_dir = cache_dir
self.export_config = dict(
feats_dim=560,
onnx=False,
)
self.onnx = onnx
self.device = device
self.quant = quant
self.fallback_num = fallback_num
self.frontend = None
self.audio_in = audio_in
self.calib_num = calib_num
self.model_revision = model_revision
def _export(
self,
model,
model_dir: str = None,
verbose: bool = False,
):
export_dir = model_dir
os.makedirs(export_dir, exist_ok=True)
self.export_config["model_name"] = "model"
model = get_model(
model,
self.export_config,
)
model.eval()
if self.onnx:
self._export_onnx(model, verbose, export_dir)
print("output dir: {}".format(export_dir))
def _export_onnx(self, model, verbose, path):
model._export_onnx(verbose, path)
def set_all_random_seed(self, seed: int):
random.seed(seed)
np.random.seed(seed)
torch.random.manual_seed(seed)
def parse_audio_in(self, audio_in):
wav_list, name_list = [], []
if audio_in.endswith(".scp"):
f = open(audio_in, 'r')
lines = f.readlines()[:self.calib_num]
for line in lines:
name, path = line.strip().split()
name_list.append(name)
wav_list.append(path)
else:
wav_list = [audio_in,]
name_list = ["test",]
return wav_list, name_list
def load_feats(self, audio_in: str = None):
import torchaudio
wav_list, name_list = self.parse_audio_in(audio_in)
feats = []
feats_len = []
for line in wav_list:
path = line.strip()
waveform, sampling_rate = torchaudio.load(path)
if sampling_rate != self.frontend.fs:
waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
new_freq=self.frontend.fs)(waveform)
fbank, fbank_len = self.frontend(waveform, [waveform.size(1)])
feats.append(fbank)
feats_len.append(fbank_len)
return feats, feats_len
def export(self,
mode: str = None,
):
if mode.startswith('conformer'):
from funasr.tasks.asr import ASRTask
config = os.path.join(model_dir, 'config.yaml')
model_file = os.path.join(model_dir, 'model.pb')
cmvn_file = os.path.join(model_dir, 'am.mvn')
model, asr_train_args = ASRTask.build_model_from_file(
config, model_file, cmvn_file, 'cpu'
)
self.frontend = model.frontend
self.export_config["feats_dim"] = 560
self._export(model, self.cache_dir)
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
# parser.add_argument('--model-name', type=str, required=True)
parser.add_argument('--model-name', type=str, action="append", required=True, default=[])
parser.add_argument('--export-dir', type=str, required=True)
parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
parser.add_argument('--model_revision', type=str, default=None, help='model_revision')
args = parser.parse_args()
export_model = ModelExport(
cache_dir=args.export_dir,
onnx=args.type == 'onnx',
device=args.device,
quant=args.quantize,
fallback_num=args.fallback_num,
audio_in=args.audio_in,
calib_num=args.calib_num,
model_revision=args.model_revision,
)
for model_name in args.model_name:
print("export model: {}".format(model_name))
export_model.export(model_name)

View File

@@ -1,7 +1,7 @@
 from funasr.models.e2e_asr_paraformer import Paraformer, BiCifParaformer, ParaformerOnline
 from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
 from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export
-from funasr.export.models.e2e_asr_conformer import Conformer as Conformer_export
+# from funasr.export.models.e2e_asr_conformer import Conformer as Conformer_export
 from funasr.models.e2e_vad import E2EVadModel
 from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export

@@ -30,8 +30,8 @@ def get_model(model, export_config=None):
        return [encoder, decoder]
    elif isinstance(model, Paraformer):
        return Paraformer_export(model, **export_config)
-   elif isinstance(model, Conformer_export):
-       return Conformer_export(model, **export_config)
+   # elif isinstance(model, Conformer_export):
+   #     return Conformer_export(model, **export_config)
    elif isinstance(model, E2EVadModel):
        return E2EVadModel_export(model, **export_config)
    elif isinstance(model, PunctuationModel):
View File

@@ -1,69 +0,0 @@
import os
import logging
import torch
import torch.nn as nn
from funasr.export.utils.torch_function import MakePadMask
from funasr.export.utils.torch_function import sequence_mask
from funasr.models.encoder.conformer_encoder import ConformerEncoder
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
from funasr.export.models.decoder.xformer_decoder import XformerDecoder as TransformerDecoder_export
class Conformer(nn.Module):
"""
export conformer into onnx format
"""
def __init__(
self,
model,
max_seq_len=512,
feats_dim=560,
model_name='model',
**kwargs,
):
super().__init__()
onnx = False
if "onnx" in kwargs:
onnx = kwargs["onnx"]
if isinstance(model.encoder, ConformerEncoder):
self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
elif isinstance(model.decoder, TransformerDecoder):
self.decoder = TransformerDecoder_export(model.decoder, onnx=onnx)
self.feats_dim = feats_dim
self.model_name = model_name
if onnx:
self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
else:
self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
def _export_model(self, model, verbose, path):
dummy_input = model.get_dummy_inputs()
model_script = model
model_path = os.path.join(path, f'{model.model_name}.onnx')
if not os.path.exists(model_path):
torch.onnx.export(
model_script,
dummy_input,
model_path,
verbose=verbose,
opset_version=14,
input_names=model.get_input_names(),
output_names=model.get_output_names(),
dynamic_axes=model.get_dynamic_axes()
)
def _export_encoder_onnx(self, verbose, path):
model_encoder = self.encoder
self._export_model(model_encoder, verbose, path)
def _export_decoder_onnx(self, verbose, path):
model_decoder = self.decoder
self._export_model(model_decoder, verbose, path)
def _export_onnx(self, verbose, path):
self._export_encoder_onnx(verbose, path)
self._export_decoder_onnx(verbose, path)

View File

@@ -1,403 +0,0 @@
"""Positional Encoding Module."""
import math
import torch
import torch.nn as nn
from funasr.modules.embedding import (
LegacyRelPositionalEncoding, PositionalEncoding, RelPositionalEncoding,
ScaledPositionalEncoding, StreamPositionalEncoding)
from funasr.modules.subsampling import (
Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6,
Conv2dSubsampling8)
from funasr.modules.subsampling_without_posenc import \
Conv2dSubsamplingWOPosEnc
from funasr.export.models.language_models.subsampling import (
OnnxConv2dSubsampling, OnnxConv2dSubsampling2, OnnxConv2dSubsampling6,
OnnxConv2dSubsampling8)
def get_pos_emb(pos_emb, max_seq_len=512, use_cache=True):
if isinstance(pos_emb, LegacyRelPositionalEncoding):
return OnnxLegacyRelPositionalEncoding(pos_emb, max_seq_len, use_cache)
elif isinstance(pos_emb, ScaledPositionalEncoding):
return OnnxScaledPositionalEncoding(pos_emb, max_seq_len, use_cache)
elif isinstance(pos_emb, RelPositionalEncoding):
return OnnxRelPositionalEncoding(pos_emb, max_seq_len, use_cache)
elif isinstance(pos_emb, PositionalEncoding):
return OnnxPositionalEncoding(pos_emb, max_seq_len, use_cache)
elif isinstance(pos_emb, StreamPositionalEncoding):
return OnnxStreamPositionalEncoding(pos_emb, max_seq_len, use_cache)
elif (isinstance(pos_emb, nn.Sequential) and len(pos_emb) == 0) or (
isinstance(pos_emb, Conv2dSubsamplingWOPosEnc)
):
return pos_emb
else:
raise ValueError("Embedding model is not supported.")
class Embedding(nn.Module):
def __init__(self, model, max_seq_len=512, use_cache=True):
super().__init__()
self.model = model
if not isinstance(model, nn.Embedding):
if isinstance(model, Conv2dSubsampling):
self.model = OnnxConv2dSubsampling(model)
self.model.out[-1] = get_pos_emb(model.out[-1], max_seq_len)
elif isinstance(model, Conv2dSubsampling2):
self.model = OnnxConv2dSubsampling2(model)
self.model.out[-1] = get_pos_emb(model.out[-1], max_seq_len)
elif isinstance(model, Conv2dSubsampling6):
self.model = OnnxConv2dSubsampling6(model)
self.model.out[-1] = get_pos_emb(model.out[-1], max_seq_len)
elif isinstance(model, Conv2dSubsampling8):
self.model = OnnxConv2dSubsampling8(model)
self.model.out[-1] = get_pos_emb(model.out[-1], max_seq_len)
else:
self.model[-1] = get_pos_emb(model[-1], max_seq_len)
def forward(self, x, mask=None):
if mask is None:
return self.model(x)
else:
return self.model(x, mask)
def _pre_hook(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
"""Perform pre-hook in load_state_dict for backward compatibility.
Note:
We saved self.pe until v.0.5.2 but we have omitted it later.
Therefore, we remove the item "pe" from `state_dict` for backward compatibility.
"""
k = prefix + "pe"
if k in state_dict:
state_dict.pop(k)
class OnnxPositionalEncoding(torch.nn.Module):
"""Positional encoding.
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_seq_len (int): Maximum input length.
reverse (bool): Whether to reverse the input position. Only for
the class LegacyRelPositionalEncoding. We remove it in the current
class RelPositionalEncoding.
"""
def __init__(self, model, max_seq_len=512, reverse=False, use_cache=True):
"""Construct an PositionalEncoding object."""
super(OnnxPositionalEncoding, self).__init__()
self.d_model = model.d_model
self.reverse = reverse
self.max_seq_len = max_seq_len
self.xscale = math.sqrt(self.d_model)
self._register_load_state_dict_pre_hook(_pre_hook)
self.pe = model.pe
self.use_cache = use_cache
self.model = model
if self.use_cache:
self.extend_pe()
else:
self.div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
def extend_pe(self):
"""Reset the positional encodings."""
pe_length = len(self.pe[0])
if self.max_seq_len < pe_length:
self.pe = self.pe[:, : self.max_seq_len]
else:
self.model.extend_pe(torch.tensor(0.0).expand(1, self.max_seq_len))
self.pe = self.model.pe
def _add_pe(self, x):
"""Computes positional encoding"""
if self.reverse:
position = torch.arange(
x.size(1) - 1, -1, -1.0, dtype=torch.float32
).unsqueeze(1)
else:
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
x = x * self.xscale
x[:, :, 0::2] += torch.sin(position * self.div_term)
x[:, :, 1::2] += torch.cos(position * self.div_term)
return x
def forward(self, x: torch.Tensor):
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
"""
if self.use_cache:
x = x * self.xscale + self.pe[:, : x.size(1)]
else:
x = self._add_pe(x)
return x
class OnnxScaledPositionalEncoding(OnnxPositionalEncoding):
"""Scaled positional encoding module.
See Sec. 3.2 https://arxiv.org/abs/1809.08895
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_seq_len (int): Maximum input length.
"""
def __init__(self, model, max_seq_len=512, use_cache=True):
"""Initialize class."""
super().__init__(model, max_seq_len, use_cache=use_cache)
self.alpha = torch.nn.Parameter(torch.tensor(1.0))
def reset_parameters(self):
"""Reset parameters."""
self.alpha.data = torch.tensor(1.0)
def _add_pe(self, x):
"""Computes positional encoding"""
if self.reverse:
position = torch.arange(
x.size(1) - 1, -1, -1.0, dtype=torch.float32
).unsqueeze(1)
else:
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
x = x * self.alpha
x[:, :, 0::2] += torch.sin(position * self.div_term)
x[:, :, 1::2] += torch.cos(position * self.div_term)
return x
def forward(self, x):
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
"""
if self.use_cache:
x = x + self.alpha * self.pe[:, : x.size(1)]
else:
x = self._add_pe(x)
return x
class OnnxLegacyRelPositionalEncoding(OnnxPositionalEncoding):
"""Relative positional encoding module (old version).
Details can be found in https://github.com/espnet/espnet/pull/2816.
See : Appendix B in https://arxiv.org/abs/1901.02860
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_seq_len (int): Maximum input length.
"""
def __init__(self, model, max_seq_len=512, use_cache=True):
"""Initialize class."""
super().__init__(model, max_seq_len, reverse=True, use_cache=use_cache)
def _get_pe(self, x):
"""Computes positional encoding"""
if self.reverse:
position = torch.arange(
x.size(1) - 1, -1, -1.0, dtype=torch.float32
).unsqueeze(1)
else:
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
pe = torch.zeros(x.shape)
pe[:, :, 0::2] += torch.sin(position * self.div_term)
pe[:, :, 1::2] += torch.cos(position * self.div_term)
return pe
def forward(self, x):
"""Compute positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
torch.Tensor: Positional embedding tensor (1, time, `*`).
"""
x = x * self.xscale
if self.use_cache:
pos_emb = self.pe[:, : x.size(1)]
else:
pos_emb = self._get_pe(x)
return x, pos_emb
class OnnxRelPositionalEncoding(torch.nn.Module):
"""Relative positional encoding module (new implementation).
Details can be found in https://github.com/espnet/espnet/pull/2816.
See : Appendix B in https://arxiv.org/abs/1901.02860
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_seq_len (int): Maximum input length.
"""
def __init__(self, model, max_seq_len=512, use_cache=True):
"""Construct an PositionalEncoding object."""
super(OnnxRelPositionalEncoding, self).__init__()
self.d_model = model.d_model
self.xscale = math.sqrt(self.d_model)
self.pe = None
self.use_cache = use_cache
if self.use_cache:
self.extend_pe(torch.tensor(0.0).expand(1, max_seq_len))
else:
self.div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
def extend_pe(self, x):
"""Reset the positional encodings."""
if self.pe is not None and self.pe.size(1) >= x.size(1) * 2 - 1:
# self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1
if self.pe.dtype != x.dtype or self.pe.device != x.device:
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` means to the position of query vecotr and `j` means the
# position of key vector. We use position relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x.size(1), self.d_model)
pe_negative = torch.zeros(x.size(1), self.d_model)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
pe_positive[:, 0::2] = torch.sin(position * div_term)
pe_positive[:, 1::2] = torch.cos(position * div_term)
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
# Reserve the order of positive indices and concat both positive and
# negative indices. This is used to support the shifting trick
# as in https://arxiv.org/abs/1901.02860
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative = pe_negative[1:].unsqueeze(0)
pe = torch.cat([pe_positive, pe_negative], dim=1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def _get_pe(self, x):
pe_positive = torch.zeros(x.size(1), self.d_model)
pe_negative = torch.zeros(x.size(1), self.d_model)
theta = (
torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) * self.div_term
)
pe_positive[:, 0::2] = torch.sin(theta)
pe_positive[:, 1::2] = torch.cos(theta)
pe_negative[:, 0::2] = -1 * torch.sin(theta)
pe_negative[:, 1::2] = torch.cos(theta)
# Reserve the order of positive indices and concat both positive and
# negative indices. This is used to support the shifting trick
# as in https://arxiv.org/abs/1901.02860
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative = pe_negative[1:].unsqueeze(0)
return torch.cat([pe_positive, pe_negative], dim=1)
def forward(self, x: torch.Tensor, use_cache=True):
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
"""
x = x * self.xscale
if self.use_cache:
pos_emb = self.pe[
:,
self.pe.size(1) // 2 - x.size(1) + 1 : self.pe.size(1) // 2 + x.size(1),
]
else:
pos_emb = self._get_pe(x)
return x, pos_emb
class OnnxStreamPositionalEncoding(torch.nn.Module):
"""Streaming Positional encoding."""
def __init__(self, model, max_seq_len=5000, use_cache=True):
"""Construct an PositionalEncoding object."""
super(StreamPositionalEncoding, self).__init__()
self.use_cache = use_cache
self.d_model = model.d_model
self.xscale = model.xscale
self.pe = model.pe
self.use_cache = use_cache
self.max_seq_len = max_seq_len
if self.use_cache:
self.extend_pe()
else:
self.div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
self._register_load_state_dict_pre_hook(_pre_hook)
def extend_pe(self):
"""Reset the positional encodings."""
pe_length = len(self.pe[0])
if self.max_seq_len < pe_length:
self.pe = self.pe[:, : self.max_seq_len]
else:
self.model.extend_pe(self.max_seq_len)
self.pe = self.model.pe
def _add_pe(self, x, start_idx):
position = torch.arange(start_idx, x.size(1), dtype=torch.float32).unsqueeze(1)
x = x * self.xscale
x[:, :, 0::2] += torch.sin(position * self.div_term)
x[:, :, 1::2] += torch.cos(position * self.div_term)
return x
def forward(self, x: torch.Tensor, start_idx: int = 0):
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
"""
if self.use_cache:
return x * self.xscale + self.pe[:, start_idx : start_idx + x.size(1)]
else:
return self._add_pe(x, start_idx)

View File

@@ -1,84 +0,0 @@
import os
import torch
import torch.nn as nn
class SequentialRNNLM(nn.Module):
def __init__(self, model, **kwargs):
super().__init__()
self.encoder = model.encoder
self.rnn = model.rnn
self.rnn_type = model.rnn_type
self.decoder = model.decoder
self.nlayers = model.nlayers
self.nhid = model.nhid
self.model_name = "seq_rnnlm"
def forward(self, y, hidden1, hidden2=None):
# batch_score function.
emb = self.encoder(y)
if self.rnn_type == "LSTM":
output, (hidden1, hidden2) = self.rnn(emb, (hidden1, hidden2))
else:
output, hidden1 = self.rnn(emb, hidden1)
decoded = self.decoder(
output.contiguous().view(output.size(0) * output.size(1), output.size(2))
)
if self.rnn_type == "LSTM":
return (
decoded.view(output.size(0), output.size(1), decoded.size(1)),
hidden1,
hidden2,
)
else:
return (
decoded.view(output.size(0), output.size(1), decoded.size(1)),
hidden1,
)
def get_dummy_inputs(self):
tgt = torch.LongTensor([0, 1]).unsqueeze(0)
hidden = torch.randn(self.nlayers, 1, self.nhid)
if self.rnn_type == "LSTM":
return (tgt, hidden, hidden)
else:
return (tgt, hidden)
def get_input_names(self):
if self.rnn_type == "LSTM":
return ["x", "in_hidden1", "in_hidden2"]
else:
return ["x", "in_hidden1"]
def get_output_names(self):
if self.rnn_type == "LSTM":
return ["y", "out_hidden1", "out_hidden2"]
else:
return ["y", "out_hidden1"]
def get_dynamic_axes(self):
ret = {
"x": {0: "x_batch", 1: "x_length"},
"y": {0: "y_batch"},
"in_hidden1": {1: "hidden1_batch"},
"out_hidden1": {1: "out_hidden1_batch"},
}
if self.rnn_type == "LSTM":
ret.update(
{
"in_hidden2": {1: "hidden2_batch"},
"out_hidden2": {1: "out_hidden2_batch"},
}
)
return ret
def get_model_config(self, path):
return {
"use_lm": True,
"model_path": os.path.join(path, f"{self.model_name}.onnx"),
"lm_type": "SequentialRNNLM",
"rnn_type": self.rnn_type,
"nhid": self.nhid,
"nlayers": self.nlayers,
}

View File

@@ -1,185 +0,0 @@
"""Subsampling layer definition."""
import torch
class OnnxConv2dSubsampling(torch.nn.Module):
"""Convolutional 2D subsampling (to 1/4 length).
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
pos_enc (torch.nn.Module): Custom position encoding layer.
"""
def __init__(self, model):
"""Construct an Conv2dSubsampling object."""
super().__init__()
self.conv = model.conv
self.out = model.out
def forward(self, x, x_mask):
"""Subsample x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 4.
torch.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 4.
"""
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
if x_mask is None:
return x, None
return x, x_mask[:, :-2:2][:, :-2:2]
def __getitem__(self, key):
"""Get item.
When reset_parameters() is called, if use_scaled_pos_enc is used,
return the positioning encoding.
"""
if key != -1:
raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
return self.out[key]
class OnnxConv2dSubsampling2(torch.nn.Module):
"""Convolutional 2D subsampling (to 1/2 length).
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
pos_enc (torch.nn.Module): Custom position encoding layer.
"""
def __init__(self, model):
"""Construct an Conv2dSubsampling object."""
super().__init__()
self.conv = model.conv
self.out = model.out
def forward(self, x, x_mask):
"""Subsample x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 2.
torch.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 2.
"""
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
if x_mask is None:
return x, None
return x, x_mask[:, :-2:2][:, :-2:1]
def __getitem__(self, key):
"""Get item.
When reset_parameters() is called, if use_scaled_pos_enc is used,
return the positioning encoding.
"""
if key != -1:
raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
return self.out[key]
class OnnxConv2dSubsampling6(torch.nn.Module):
"""Convolutional 2D subsampling (to 1/6 length).
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
pos_enc (torch.nn.Module): Custom position encoding layer.
"""
def __init__(self, model):
"""Construct an Conv2dSubsampling object."""
super().__init__()
self.conv = model.conv
self.out = model.out
def forward(self, x, x_mask):
"""Subsample x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 6.
torch.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 6.
"""
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
if x_mask is None:
return x, None
return x, x_mask[:, :-2:2][:, :-4:3]
class OnnxConv2dSubsampling8(torch.nn.Module):
"""Convolutional 2D subsampling (to 1/8 length).
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
pos_enc (torch.nn.Module): Custom position encoding layer.
"""
def __init__(self, model):
"""Construct an Conv2dSubsampling object."""
super().__init__()
self.conv = model.conv
self.out = model.out
def forward(self, x, x_mask):
"""Subsample x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 8.
torch.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 8.
"""
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
if x_mask is None:
return x, None
return x, x_mask[:, :-2:2][:, :-2:2][:, :-2:2]
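The wrappers above only reuse the trained `conv` and `out` layers, so they can be sanity-checked in plain PyTorch before export. A minimal sketch, assuming `funasr.modules.subsampling.Conv2dSubsampling8` follows the usual ESPnet-style `(idim, odim, dropout_rate)` constructor; mask handling is skipped by passing `None`:

```python
import torch
from funasr.modules.subsampling import Conv2dSubsampling8

# Assumed ESPnet-style constructor: Conv2dSubsampling8(idim, odim, dropout_rate).
base = Conv2dSubsampling8(80, 256, dropout_rate=0.0)
wrapper = OnnxConv2dSubsampling8(base)   # class defined above

x = torch.randn(2, 100, 80)              # (#batch, time, idim)
y, _ = wrapper(x, None)                  # None skips the mask subsampling branch
print(y.shape)                           # torch.Size([2, 11, 256]), roughly time // 8
```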

View File

@ -1,110 +0,0 @@
import os
import torch
import torch.nn as nn
from funasr.modules.vgg2l import VGG2L
from funasr.modules.attention import MultiHeadedAttention
from funasr.modules.subsampling import (
Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8)
from funasr.export.models.modules.encoder_layer import EncoderLayerConformer as OnnxEncoderLayer
from funasr.export.models.language_models.embed import Embedding
from funasr.export.models.modules.multihead_att import OnnxMultiHeadedAttention
from funasr.export.utils.torch_function import MakePadMask
class TransformerLM(nn.Module, AbsExportModel):
def __init__(self, model, max_seq_len=512, **kwargs):
super().__init__()
self.embed = Embedding(model.embed, max_seq_len)
self.encoder = model.encoder
self.decoder = model.decoder
self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
# replace the multi-head attention modules with the ONNX-friendly customized attention.
for i, d in enumerate(self.encoder.encoders):
# d is EncoderLayer
if isinstance(d.self_attn, MultiHeadedAttention):
d.self_attn = OnnxMultiHeadedAttention(d.self_attn)
self.encoder.encoders[i] = OnnxEncoderLayer(d)
self.model_name = "transformer_lm"
self.num_heads = self.encoder.encoders[0].self_attn.h
self.hidden_size = self.encoder.encoders[0].self_attn.linear_out.out_features
def prepare_mask(self, mask):
if len(mask.shape) == 2:
mask = mask[:, None, None, :]
elif len(mask.shape) == 3:
mask = mask[:, None, :]
mask = 1 - mask
return mask * -10000.0
def forward(self, y, cache):
feats_length = torch.ones(y.shape).sum(dim=-1).type(torch.long)
mask = self.make_pad_mask(feats_length) # (B, T)
mask = (y != 0) * mask
xs = self.embed(y)
# forward_one_step of Encoder
if isinstance(
self.encoder.embed,
(Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8, VGG2L),
):
xs, mask = self.encoder.embed(xs, mask)
else:
xs = self.encoder.embed(xs)
new_cache = []
mask = self.prepare_mask(mask)
for c, e in zip(cache, self.encoder.encoders):
xs, mask = e(xs, mask, c)
new_cache.append(xs)
if self.encoder.normalize_before:
xs = self.encoder.after_norm(xs)
h = self.decoder(xs[:, -1])
return h, new_cache
def get_dummy_inputs(self):
tgt = torch.LongTensor([1]).unsqueeze(0)
cache = [
torch.zeros((1, 1, self.encoder.encoders[0].size))
for _ in range(len(self.encoder.encoders))
]
return (tgt, cache)
def is_optimizable(self):
return True
def get_input_names(self):
return ["tgt"] + ["cache_%d" % i for i in range(len(self.encoder.encoders))]
def get_output_names(self):
return ["y"] + ["out_cache_%d" % i for i in range(len(self.encoder.encoders))]
def get_dynamic_axes(self):
ret = {"tgt": {0: "tgt_batch", 1: "tgt_length"}}
ret.update(
{
"cache_%d" % d: {0: "cache_%d_batch" % d, 1: "cache_%d_length" % d}
for d in range(len(self.encoder.encoders))
}
)
ret.update(
{
"out_cache_%d"
% d: {0: "out_cache_%d_batch" % d, 1: "out_cache_%d_length" % d}
for d in range(len(self.encoder.encoders))
}
)
return ret
def get_model_config(self, path):
return {
"use_lm": True,
"model_path": os.path.join(path, f"{self.model_name}.onnx"),
"lm_type": "TransformerLM",
"odim": self.encoder.encoders[0].size,
"nlayers": len(self.encoder.encoders),
}
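For context, the hooks above (`get_dummy_inputs`, `get_input_names`, `get_output_names`, `get_dynamic_axes`, `get_model_config`) are the pieces an espnet_onnx-style export loop consumes; note that `AbsExportModel` is expected to come from the export base module, whose import is not shown in this file. A hedged sketch of such an export call, not the exact FunASR driver:

```python
import os
import torch

def export_transformer_lm(lm_model, out_dir: str, opset: int = 13):
    """Export a wrapped TransformerLM to ONNX using its export hooks."""
    os.makedirs(out_dir, exist_ok=True)
    onnx_path = os.path.join(out_dir, f"{lm_model.model_name}.onnx")
    torch.onnx.export(
        lm_model,
        lm_model.get_dummy_inputs(),            # (tgt, cache)
        onnx_path,
        opset_version=opset,
        input_names=lm_model.get_input_names(),
        output_names=lm_model.get_output_names(),
        dynamic_axes=lm_model.get_dynamic_axes(),
    )
    return lm_model.get_model_config(out_dir)   # config consumed by the runtime
```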

View File

@ -4,7 +4,7 @@ from typing import List, Tuple, Union
import random
import numpy as np
-import soundfile
+import librosa
import librosa
import torch
@ -116,7 +116,7 @@ class SoundScpReader(collections.abc.Mapping):
def __getitem__(self, key):
wav = self.data[key]
if self.normalize:
-# soundfile.read normalizes data to [-1,1] if dtype is not given
+# librosa.load normalizes data to [-1,1] if dtype is not given
array, rate = librosa.load(
wav, sr=self.dest_sample_rate, mono=self.always_2d
)
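The swap above relies on the two loaders behaving similarly for normalized float audio. A rough comparison, using only documented behaviour of both libraries (the file path is hypothetical):

```python
import librosa
import soundfile

path = "example.wav"                     # hypothetical input

# soundfile.read: float64 in [-1, 1] by default, native sample rate preserved.
data_sf, rate_sf = soundfile.read(path)

# librosa.load: float32 in [-1, 1]; sr=None keeps the native rate (any other
# value triggers resampling), mono=False keeps the channel dimension.
data_lr, rate_lr = librosa.load(path, sr=None, mono=False)
```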

View File

@ -5,8 +5,12 @@ from typing import Tuple
from typing import Union
import torch
-from torch_complex import functional as FC
-from torch_complex.tensor import ComplexTensor
+try:
+    from torch_complex import functional as FC
+    from torch_complex.tensor import ComplexTensor
+except:
+    print("Please install torch_complex firstly")
EPS = torch.finfo(torch.double).eps
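The same optional-dependency guard is applied to several modules in this change. A sketch of a slightly stricter variant that fails loudly instead of printing and continuing (illustrative only, not the shipped code):

```python
try:
    from torch_complex import functional as FC
    from torch_complex.tensor import ComplexTensor
except ImportError as exc:                       # avoid a bare except
    raise ImportError(
        "torch_complex is required here: pip install torch_complex"
    ) from exc
```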

View File

@ -4,8 +4,11 @@ from typing import Tuple
from typing import Union
import torch
-from torch_complex.tensor import ComplexTensor
+try:
+    from torch_complex.tensor import ComplexTensor
+except:
+    print("Please install torch_complex firstly")
from funasr.modules.nets_utils import make_pad_mask
from funasr.layers.complex_utils import is_complex
from funasr.layers.inversible_interface import InversibleInterface

View File

@ -1,8 +1,10 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from rotary_embedding_torch import RotaryEmbedding
+try:
+    from rotary_embedding_torch import RotaryEmbedding
+except:
+    print("Please install rotary_embedding_torch by: \n pip install -U rotary_embedding_torch")
from funasr.modules.layer_norm import GlobalLayerNorm, CumulativeLayerNorm, ScaleNorm
from funasr.modules.embedding import ScaledSinuEmbedding
from funasr.modules.mossformer import FLASH_ShareA_FFConvM

View File

@ -6,12 +6,15 @@ import logging
import humanfriendly
import numpy as np
import torch
-from torch_complex.tensor import ComplexTensor
+try:
+    from torch_complex.tensor import ComplexTensor
+except:
+    print("Please install torch_complex firstly")
from funasr.layers.log_mel import LogMel
from funasr.layers.stft import Stft
from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.modules.frontends.frontend import Frontend
+from funasr.models.frontend.frontends_utils.frontend import Frontend
from funasr.utils.get_default_kwargs import get_default_kwargs
from funasr.modules.nets_utils import make_pad_mask

View File

@ -4,12 +4,12 @@ from typing import Tuple
import torch
from torch.nn import functional as F
-from funasr.modules.frontends.beamformer import apply_beamforming_vector
-from funasr.modules.frontends.beamformer import get_mvdr_vector
-from funasr.modules.frontends.beamformer import (
+from funasr.models.frontend.frontends_utils.beamformer import apply_beamforming_vector
+from funasr.models.frontend.frontends_utils.beamformer import get_mvdr_vector
+from funasr.models.frontend.frontends_utils.beamformer import (
get_power_spectral_density_matrix,  # noqa: H301
)
-from funasr.modules.frontends.mask_estimator import MaskEstimator
+from funasr.models.frontend.frontends_utils.mask_estimator import MaskEstimator
from torch_complex.tensor import ComplexTensor

View File

@ -4,7 +4,7 @@ from pytorch_wpe import wpe_one_iteration
import torch
from torch_complex.tensor import ComplexTensor
-from funasr.modules.frontends.mask_estimator import MaskEstimator
+from funasr.models.frontend.frontends_utils.mask_estimator import MaskEstimator
from funasr.modules.nets_utils import make_pad_mask

View File

@ -8,8 +8,8 @@ import torch
import torch.nn as nn
from torch_complex.tensor import ComplexTensor
-from funasr.modules.frontends.dnn_beamformer import DNN_Beamformer
-from funasr.modules.frontends.dnn_wpe import DNN_WPE
+from funasr.models.frontend.frontends_utils.dnn_beamformer import DNN_Beamformer
+from funasr.models.frontend.frontends_utils.dnn_wpe import DNN_WPE
class Frontend(nn.Module):

View File

@ -10,7 +10,7 @@ import humanfriendly
import torch
from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.modules.frontends.frontend import Frontend
+from funasr.models.frontend.frontends_utils.frontend import Frontend
from funasr.modules.nets_utils import pad_list
from funasr.utils.get_default_kwargs import get_default_kwargs

View File

@ -145,6 +145,9 @@ class WavFrontend(AbsFrontend):
feats_lens.append(feat_length)
feats_lens = torch.as_tensor(feats_lens)
+if batch_size == 1:
+    feats_pad = feats[0][None, :, :]
+else:
feats_pad = pad_sequence(feats,
batch_first=True,
padding_value=0.0)
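The new `batch_size == 1` branch simply avoids the padding call for single-utterance batches; for one item the two paths produce the same tensor. A small check, with hypothetical shapes:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

feats = [torch.randn(120, 80)]                      # one utterance, (T, D)
fast_path = feats[0][None, :, :]                    # (1, 120, 80)
padded = pad_sequence(feats, batch_first=True, padding_value=0.0)
assert torch.equal(fast_path, padded)               # identical for a batch of one
```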

View File

@ -9,7 +9,7 @@ import os
import sys
import numpy as np
import subprocess
-import soundfile as sf
+import librosa as sf
import io
from functools import lru_cache
@ -67,18 +67,18 @@ def load_wav(wav_rxfilename, start=0, end=None):
# input piped command
p = subprocess.Popen(wav_rxfilename[:-1], shell=True,
stdout=subprocess.PIPE)
-data, samplerate = sf.read(io.BytesIO(p.stdout.read()),
+data, samplerate = sf.load(io.BytesIO(p.stdout.read()),
dtype='float32')
# cannot seek
data = data[start:end]
elif wav_rxfilename == '-':
# stdin
-data, samplerate = sf.read(sys.stdin, dtype='float32')
+data, samplerate = sf.load(sys.stdin, dtype='float32')
# cannot seek
data = data[start:end]
else:
# normal wav file
-data, samplerate = sf.read(wav_rxfilename, start=start, stop=end)
+data, samplerate = sf.load(wav_rxfilename, start=start, stop=end)
return data, samplerate
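Note that `librosa.load` is not a drop-in for `soundfile.read` in the last branch: it has no `start`/`stop` arguments and resamples unless `sr=None` is passed. A closer equivalent, assuming `start`/`end` are sample indices as in the original `load_wav`:

```python
import librosa

def load_wav_range(path, start=0, end=None):
    """Load a wav at its native rate and slice by sample index."""
    data, samplerate = librosa.load(path, sr=None)  # sr=None: no resampling
    return data[start:end], samplerate
```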

View File

@ -16,7 +16,7 @@ Supports real-time streaming speech recognition, uses non-streaming models for e
#### Server Deployment
```shell
-cd funasr/runtime/python/websocket
+cd runtime/python/websocket
python funasr_wss_server.py --port 10095
```

View File

@ -17,7 +17,7 @@
##### Server Deployment
```shell
-cd funasr/runtime/python/websocket
+cd runtime/python/websocket
python funasr_wss_server.py --port 10095
```

View File

@ -76,7 +76,7 @@ from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
from funasr.modules.subsampling import Conv1dSubsampling
from funasr.tasks.abs_task import AbsTask
-from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.tokenizer.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
from funasr.models.base_model import FunASRModel
from funasr.train.class_choices import ClassChoices

View File

@ -25,7 +25,7 @@ from funasr.models.preencoder.sinc import LightweightSincConvs
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.tasks.abs_task import AbsTask
-from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.tokenizer.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer

View File

@ -17,7 +17,7 @@ from funasr.train.abs_model import LanguageModel
from funasr.models.seq_rnn_lm import SequentialRNNLM
from funasr.models.transformer_lm import TransformerLM
from funasr.tasks.abs_task import AbsTask
-from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.tokenizer.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer

View File

@ -16,7 +16,7 @@ from funasr.train.abs_model import PunctuationModel
from funasr.models.target_delay_transformer import TargetDelayTransformer
from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
from funasr.tasks.abs_task import AbsTask
-from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.tokenizer.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer

View File

@ -71,7 +71,7 @@ from funasr.models.specaug.specaug import SpecAugLFR
from funasr.models.base_model import FunASRModel
from funasr.modules.subsampling import Conv1dSubsampling
from funasr.tasks.abs_task import AbsTask
-from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.tokenizer.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer

View File

@ -76,7 +76,7 @@ from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
from funasr.modules.subsampling import Conv1dSubsampling
from funasr.tasks.abs_task import AbsTask
-from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.tokenizer.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
from funasr.models.base_model import FunASRModel
from funasr.train.class_choices import ClassChoices

View File

@ -3,11 +3,11 @@ from typing import Iterable
from typing import Union
-from funasr.text.abs_tokenizer import AbsTokenizer
-from funasr.text.char_tokenizer import CharTokenizer
-from funasr.text.phoneme_tokenizer import PhonemeTokenizer
-from funasr.text.sentencepiece_tokenizer import SentencepiecesTokenizer
-from funasr.text.word_tokenizer import WordTokenizer
+from funasr.tokenizer.abs_tokenizer import AbsTokenizer
+from funasr.tokenizer.char_tokenizer import CharTokenizer
+from funasr.tokenizer.phoneme_tokenizer import PhonemeTokenizer
+from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer
+from funasr.tokenizer.word_tokenizer import WordTokenizer
def build_tokenizer(
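After the `funasr.text` to `funasr.tokenizer` move, downstream imports change but usage stays the same. A hedged usage sketch, assuming `build_tokenizer` keeps the ESPnet-style `token_type` interface:

```python
from funasr.tokenizer.build_tokenizer import build_tokenizer

tokenizer = build_tokenizer(token_type="char")      # assumed signature
tokens = tokenizer.text2tokens("hello")             # ['h', 'e', 'l', 'l', 'o']
text = tokenizer.tokens2text(tokens)                # 'hello'
```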

View File

@ -5,7 +5,7 @@ from typing import Union
import warnings
-from funasr.text.abs_tokenizer import AbsTokenizer
+from funasr.tokenizer.abs_tokenizer import AbsTokenizer
class CharTokenizer(AbsTokenizer):

View File

@ -10,7 +10,7 @@ import warnings
# import g2p_en
import jamo
-from funasr.text.abs_tokenizer import AbsTokenizer
+from funasr.tokenizer.abs_tokenizer import AbsTokenizer
g2p_choices = [
@ -107,7 +107,7 @@ def pyopenjtalk_g2p_prosody(text: str, drop_unvoiced_vowels: bool = True) -> Lis
List[str]: List of phoneme + prosody symbols.
Examples:
->>> from funasr.text.phoneme_tokenizer import pyopenjtalk_g2p_prosody
+>>> from funasr.tokenizer.phoneme_tokenizer import pyopenjtalk_g2p_prosody
>>> pyopenjtalk_g2p_prosody("こんにちは。")
['^', 'k', 'o', '[', 'N', 'n', 'i', 'ch', 'i', 'w', 'a', '$']

View File

@ -5,7 +5,7 @@ from typing import Union
import sentencepiece as spm
-from funasr.text.abs_tokenizer import AbsTokenizer
+from funasr.tokenizer.abs_tokenizer import AbsTokenizer
class SentencepiecesTokenizer(AbsTokenizer):

View File

@ -5,7 +5,7 @@ from typing import Union
import warnings
-from funasr.text.abs_tokenizer import AbsTokenizer
+from funasr.tokenizer.abs_tokenizer import AbsTokenizer
class WordTokenizer(AbsTokenizer):

View File

@ -278,14 +278,11 @@ class Trainer:
for iepoch in range(start_epoch, trainer_options.max_epoch + 1):
if iepoch != start_epoch:
logging.info(
-"{}/{}epoch started. Estimated time to finish: {}".format(
+"{}/{}epoch started. Estimated time to finish: {} hours".format(
iepoch,
trainer_options.max_epoch,
-humanfriendly.format_timespan(
-(time.perf_counter() - start_time)
-/ (iepoch - start_epoch)
-* (trainer_options.max_epoch - iepoch + 1)
-),
+(time.perf_counter() - start_time) / 3600.0 / (iepoch - start_epoch) * (
+    trainer_options.max_epoch - iepoch + 1),
)
)
else:
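The new log line reports the estimate in hours rather than a `humanfriendly` timespan; the arithmetic is elapsed time divided by finished epochs, scaled by the epochs still to run. A worked example with hypothetical numbers:

```python
import time

start_time = time.perf_counter() - 7200.0       # pretend 2 hours have elapsed
start_epoch, iepoch, max_epoch = 1, 3, 11       # 2 epochs finished, 9 remaining

eta_hours = (time.perf_counter() - start_time) / 3600.0 \
    / (iepoch - start_epoch) * (max_epoch - iepoch + 1)
print(round(eta_hours, 1))                      # ~9.0 hours at one hour per epoch
```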

View File

@ -5,7 +5,7 @@ import struct
from typing import Any, Dict, List, Union
import torchaudio
-import soundfile
+import librosa
import numpy as np
import pkg_resources
from modelscope.utils.logger import get_logger
@ -139,7 +139,7 @@ def get_sr_from_wav(fname: str):
try:
audio, fs = torchaudio.load(fname)
except:
-audio, fs = soundfile.read(fname)
+audio, fs = librosa.load(fname)
break
if audio_type.rfind(".scp") >= 0:
with open(fname, encoding="utf-8") as f:

View File

@ -5,7 +5,7 @@ from multiprocessing import Pool
import kaldiio
import numpy as np
-import soundfile
+import librosa
import torch.distributed as dist
import torchaudio
@ -46,7 +46,7 @@ def wav2num_frame(wav_path, frontend_conf):
try:
waveform, sampling_rate = torchaudio.load(wav_path)
except:
-waveform, sampling_rate = soundfile.read(wav_path)
+waveform, sampling_rate = librosa.load(wav_path)
waveform = np.expand_dims(waveform, axis=0)
n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]
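The `n_frames` expression above converts a waveform length into low-frame-rate feature frames. A worked example with a hypothetical 16 kHz, 10 ms shift, LFR-6 configuration (one output frame every 60 ms):

```python
samples = 16000 * 5                                      # 5 s of 16 kHz audio
frontend_conf = {"frame_shift": 10, "lfr_n": 6, "lfr_m": 7, "n_mels": 80}

n_frames = (samples * 1000.0) / (
    16000 * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]
print(n_frames, feature_dim)                             # 83.33... frames, 560 dims
```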

View File

@ -12,7 +12,7 @@ import os
from typing import Any, Dict, List, Union
import numpy as np
-import soundfile as sf
+import librosa as sf
import torch
import torchaudio
import logging
@ -43,7 +43,7 @@ def sv_preprocess(inputs: Union[np.ndarray, list]):
for i in range(len(inputs)):
if isinstance(inputs[i], str):
file_bytes = File.read(inputs[i])
-data, fs = sf.read(io.BytesIO(file_bytes), dtype='float32')
+data, fs = sf.load(io.BytesIO(file_bytes), dtype='float32')
if len(data.shape) == 2:
data = data[:, 0]
data = torch.from_numpy(data).unsqueeze(0)

View File

@ -3,7 +3,7 @@ import codecs
import logging
import argparse
import numpy as np
-import edit_distance
+# import edit_distance
from itertools import zip_longest
@ -160,112 +160,112 @@ def time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocess
return res return res
class AverageShiftCalculator(): # class AverageShiftCalculator():
def __init__(self): # def __init__(self):
logging.warning("Calculating average shift.") # logging.warning("Calculating average shift.")
def __call__(self, file1, file2): # def __call__(self, file1, file2):
uttid_list1, ts_dict1 = self.read_timestamps(file1) # uttid_list1, ts_dict1 = self.read_timestamps(file1)
uttid_list2, ts_dict2 = self.read_timestamps(file2) # uttid_list2, ts_dict2 = self.read_timestamps(file2)
uttid_intersection = self._intersection(uttid_list1, uttid_list2) # uttid_intersection = self._intersection(uttid_list1, uttid_list2)
res = self.as_cal(uttid_intersection, ts_dict1, ts_dict2) # res = self.as_cal(uttid_intersection, ts_dict1, ts_dict2)
logging.warning("Average shift of {} and {}: {}.".format(file1, file2, str(res)[:8])) # logging.warning("Average shift of {} and {}: {}.".format(file1, file2, str(res)[:8]))
logging.warning("Following timestamp pair differs most: {}, detail:{}".format(self.max_shift, self.max_shift_uttid)) # logging.warning("Following timestamp pair differs most: {}, detail:{}".format(self.max_shift, self.max_shift_uttid))
#
def _intersection(self, list1, list2): # def _intersection(self, list1, list2):
set1 = set(list1) # set1 = set(list1)
set2 = set(list2) # set2 = set(list2)
if set1 == set2: # if set1 == set2:
logging.warning("Uttid same checked.") # logging.warning("Uttid same checked.")
return set1 # return set1
itsc = list(set1 & set2) # itsc = list(set1 & set2)
logging.warning("Uttid differs: file1 {}, file2 {}, lines same {}.".format(len(list1), len(list2), len(itsc))) # logging.warning("Uttid differs: file1 {}, file2 {}, lines same {}.".format(len(list1), len(list2), len(itsc)))
return itsc # return itsc
#
def read_timestamps(self, file): # def read_timestamps(self, file):
# read timestamps file in standard format # # read timestamps file in standard format
uttid_list = [] # uttid_list = []
ts_dict = {} # ts_dict = {}
with codecs.open(file, 'r') as fin: # with codecs.open(file, 'r') as fin:
for line in fin.readlines(): # for line in fin.readlines():
text = '' # text = ''
ts_list = [] # ts_list = []
line = line.rstrip() # line = line.rstrip()
uttid = line.split()[0] # uttid = line.split()[0]
uttid_list.append(uttid) # uttid_list.append(uttid)
body = " ".join(line.split()[1:]) # body = " ".join(line.split()[1:])
for pd in body.split(';'): # for pd in body.split(';'):
if not len(pd): continue # if not len(pd): continue
# pdb.set_trace() # # pdb.set_trace()
char, start, end = pd.lstrip(" ").split(' ') # char, start, end = pd.lstrip(" ").split(' ')
text += char + ',' # text += char + ','
ts_list.append((float(start), float(end))) # ts_list.append((float(start), float(end)))
# ts_lists.append(ts_list) # # ts_lists.append(ts_list)
ts_dict[uttid] = (text[:-1], ts_list) # ts_dict[uttid] = (text[:-1], ts_list)
logging.warning("File {} read done.".format(file)) # logging.warning("File {} read done.".format(file))
return uttid_list, ts_dict # return uttid_list, ts_dict
#
def _shift(self, filtered_timestamp_list1, filtered_timestamp_list2): # def _shift(self, filtered_timestamp_list1, filtered_timestamp_list2):
shift_time = 0 # shift_time = 0
for fts1, fts2 in zip(filtered_timestamp_list1, filtered_timestamp_list2): # for fts1, fts2 in zip(filtered_timestamp_list1, filtered_timestamp_list2):
shift_time += abs(fts1[0] - fts2[0]) + abs(fts1[1] - fts2[1]) # shift_time += abs(fts1[0] - fts2[0]) + abs(fts1[1] - fts2[1])
num_tokens = len(filtered_timestamp_list1) # num_tokens = len(filtered_timestamp_list1)
return shift_time, num_tokens # return shift_time, num_tokens
#
def as_cal(self, uttid_list, ts_dict1, ts_dict2): # # def as_cal(self, uttid_list, ts_dict1, ts_dict2):
# calculate average shift between timestamp1 and timestamp2 # # # calculate average shift between timestamp1 and timestamp2
# when characters differ, use edit distance alignment # # # when characters differ, use edit distance alignment
# and calculate the error between the same characters # # # and calculate the error between the same characters
self._accumlated_shift = 0 # # self._accumlated_shift = 0
self._accumlated_tokens = 0 # # self._accumlated_tokens = 0
self.max_shift = 0 # # self.max_shift = 0
self.max_shift_uttid = None # # self.max_shift_uttid = None
for uttid in uttid_list: # # for uttid in uttid_list:
(t1, ts1) = ts_dict1[uttid] # # (t1, ts1) = ts_dict1[uttid]
(t2, ts2) = ts_dict2[uttid] # # (t2, ts2) = ts_dict2[uttid]
_align, _align2, _align3 = [], [], [] # # _align, _align2, _align3 = [], [], []
fts1, fts2 = [], [] # # fts1, fts2 = [], []
_t1, _t2 = [], [] # # _t1, _t2 = [], []
sm = edit_distance.SequenceMatcher(t1.split(','), t2.split(',')) # # sm = edit_distance.SequenceMatcher(t1.split(','), t2.split(','))
s = sm.get_opcodes() # # s = sm.get_opcodes()
for j in range(len(s)): # # for j in range(len(s)):
if s[j][0] == "replace" or s[j][0] == "insert": # # if s[j][0] == "replace" or s[j][0] == "insert":
_align.append(0) # # _align.append(0)
if s[j][0] == "replace" or s[j][0] == "delete": # # if s[j][0] == "replace" or s[j][0] == "delete":
_align3.append(0) # # _align3.append(0)
elif s[j][0] == "equal": # # elif s[j][0] == "equal":
_align.append(1) # # _align.append(1)
_align3.append(1) # # _align3.append(1)
else: # # else:
continue # # continue
# use s to index t2 # # # use s to index t2
for a, ts , t in zip(_align, ts2, t2.split(',')): # # for a, ts , t in zip(_align, ts2, t2.split(',')):
if a: # # if a:
fts2.append(ts) # # fts2.append(ts)
_t2.append(t) # # _t2.append(t)
sm2 = edit_distance.SequenceMatcher(t2.split(','), t1.split(',')) # # sm2 = edit_distance.SequenceMatcher(t2.split(','), t1.split(','))
s = sm2.get_opcodes() # # s = sm2.get_opcodes()
for j in range(len(s)): # # for j in range(len(s)):
if s[j][0] == "replace" or s[j][0] == "insert": # # if s[j][0] == "replace" or s[j][0] == "insert":
_align2.append(0) # # _align2.append(0)
elif s[j][0] == "equal": # # elif s[j][0] == "equal":
_align2.append(1) # # _align2.append(1)
else: # # else:
continue # # continue
# use s2 tp index t1 # # # use s2 tp index t1
for a, ts, t in zip(_align3, ts1, t1.split(',')): # # for a, ts, t in zip(_align3, ts1, t1.split(',')):
if a: # # if a:
fts1.append(ts) # # fts1.append(ts)
_t1.append(t) # # _t1.append(t)
if len(fts1) == len(fts2): # # if len(fts1) == len(fts2):
shift_time, num_tokens = self._shift(fts1, fts2) # # shift_time, num_tokens = self._shift(fts1, fts2)
self._accumlated_shift += shift_time # # self._accumlated_shift += shift_time
self._accumlated_tokens += num_tokens # # self._accumlated_tokens += num_tokens
if shift_time/num_tokens > self.max_shift: # # if shift_time/num_tokens > self.max_shift:
self.max_shift = shift_time/num_tokens # # self.max_shift = shift_time/num_tokens
self.max_shift_uttid = uttid # # self.max_shift_uttid = uttid
else: # # else:
logging.warning("length mismatch") # # logging.warning("length mismatch")
return self._accumlated_shift / self._accumlated_tokens # # return self._accumlated_shift / self._accumlated_tokens
def convert_external_alphas(alphas_file, text_file, output_file):
@ -311,10 +311,10 @@ SUPPORTED_MODES = ['cal_aas', 'read_ext_alphas']
def main(args):
-if args.mode == 'cal_aas':
-asc = AverageShiftCalculator()
-asc(args.input, args.input2)
-elif args.mode == 'read_ext_alphas':
+# if args.mode == 'cal_aas':
+#     asc = AverageShiftCalculator()
+#     asc(args.input, args.input2)
+if args.mode == 'read_ext_alphas':
convert_external_alphas(args.input, args.input2, args.output)
else:
logging.error("Mode {} not in SUPPORTED_MODES: {}.".format(args.mode, SUPPORTED_MODES))

View File

@ -11,7 +11,7 @@ import librosa
import numpy as np
import torch
import torchaudio
-import soundfile
+import librosa
import torchaudio.compliance.kaldi as kaldi
@ -166,7 +166,7 @@ def compute_fbank(wav_file,
try:
waveform, audio_sr = torchaudio.load(wav_file)
except:
-waveform, audio_sr = soundfile.read(wav_file, dtype='float32')
+waveform, audio_sr = librosa.load(wav_file, dtype='float32')
if waveform.ndim == 2:
waveform = waveform[:, 0]
waveform = torch.tensor(np.expand_dims(waveform, axis=0))
@ -191,7 +191,7 @@ def wav2num_frame(wav_path, frontend_conf):
try:
waveform, sampling_rate = torchaudio.load(wav_path)
except:
-waveform, sampling_rate = soundfile.read(wav_path)
+waveform, sampling_rate = librosa.load(wav_path)
waveform = torch.tensor(np.expand_dims(waveform, axis=0))
speech_length = (waveform.shape[1] / sampling_rate) * 1000.
n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])

View File

@ -1,8 +1,11 @@
import os
from functools import lru_cache
from typing import Union
-import ffmpeg
+try:
+    import ffmpeg
+except:
+    print("Please install the ffmpeg CLI and the `ffmpeg-python` package.")
import numpy as np
import torch
import torch.nn.functional as F
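The guarded `ffmpeg` import refers to the `ffmpeg-python` bindings, which also need the `ffmpeg` CLI on the PATH. A hedged sketch of the usual decode-to-float32 recipe this import enables (it mirrors the common `load_audio` pattern, not necessarily the code in this file):

```python
import ffmpeg
import numpy as np

def load_audio(path: str, sr: int = 16000) -> np.ndarray:
    """Decode any audio file to mono float32 PCM at the requested rate."""
    out, _ = (
        ffmpeg.input(path)
        .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
        .run(capture_stdout=True, capture_stderr=True)
    )
    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
```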

View File

@ -1 +1 @@
-0.8.5
+0.8.6

View File

@ -94,7 +94,9 @@ Introduction to run_server.sh parameters:
--punc-quant: True for quantized PUNC model, False for non-quantized PUNC model. Default is True.
--itn-dir: modelscope model ID or local model path.
--port: Port number that the server listens on. Default is 10095.
---decoder-thread-num: Number of inference threads that the server starts. Default is 8.
+--decoder-thread-num: Size of the server-side thread pool, i.e. the maximum number of concurrent requests. Default is 8.
+--model-thread-num: Number of internal threads per recognition route, controlling the parallelism of the ONNX model. Default is 1.
+It is recommended that decoder-thread-num * model-thread-num equal the total number of threads.
--io-thread-num: Number of IO threads that the server starts. Default is 1.
--certfile <string>: SSL certificate file. Default is ../../../ssl_key/server.crt. To disable SSL, set it to 0.
--keyfile <string>: SSL key file. Default is ../../../ssl_key/server.key.

View File

@ -73,7 +73,9 @@ Introduction to run_server.sh parameters:
--punc-quant: True for quantized PUNC model, False for non-quantized PUNC model. Default is True.
--itn-dir: modelscope model ID or local model path.
--port: Port number that the server listens on. Default is 10095.
---decoder-thread-num: Number of inference threads that the server starts. Default is 8.
+--decoder-thread-num: Size of the server-side thread pool, i.e. the maximum number of concurrent requests. Default is 8.
+--model-thread-num: Number of internal threads per recognition route, controlling the parallelism of the ONNX model. Default is 1.
+It is recommended that decoder-thread-num * model-thread-num equal the total number of threads.
--io-thread-num: Number of IO threads that the server starts. Default is 1.
--certfile <string>: SSL certificate file. Default is ../../../ssl_key/server.crt. To disable SSL, set it to 0.
--keyfile <string>: SSL key file. Default is ../../../ssl_key/server.key.

View File

@ -158,7 +158,9 @@ nohup bash run_server.sh \
--punc-quant: True for the quantized PUNC model, False for the non-quantized PUNC model. Default is True.
--itn-dir: modelscope model ID or local model path.
--port: Port that the server listens on. Default is 10095.
---decoder-thread-num: Number of inference threads started by the server. Default is 8.
+--decoder-thread-num: Size of the server-side thread pool, i.e. the maximum number of concurrent requests. Default is 8.
+--model-thread-num: Number of internal threads per recognition route, controlling the parallelism of the ONNX model. Default is 1.
+It is recommended that decoder-thread-num * model-thread-num equal the total number of threads.
--io-thread-num: Number of IO threads started by the server. Default is 1.
--certfile: SSL certificate file. Default is ../../../ssl_key/server.crt. To disable SSL, set it to 0.
--keyfile: SSL key file. Default is ../../../ssl_key/server.key.

View File

@ -175,11 +175,14 @@ nohup bash run_server.sh \
--lm-dir: modelscope model ID or local model path.
--itn-dir: modelscope model ID or local model path.
--port: Port that the server listens on. Default is 10095.
---decoder-thread-num: Number of inference threads started by the server. Default is 8.
+--decoder-thread-num: Size of the server-side thread pool, i.e. the maximum number of concurrent requests. Default is 8.
+--model-thread-num: Number of internal threads per recognition route, controlling the parallelism of the ONNX model. Default is 1.
+It is recommended that decoder-thread-num * model-thread-num equal the total number of threads.
--io-thread-num: Number of IO threads started by the server. Default is 1.
--certfile: SSL certificate file. Default is ../../../ssl_key/server.crt. To disable SSL, set it to 0.
--keyfile: SSL key file. Default is ../../../ssl_key/server.key.
---hotword: Path to a hotword file, one hotword per line in the format "hotword weight" (e.g. 阿里巴巴 20). If the client also provides hotwords, the two lists are merged and used together.
+--hotword: Path to a hotword file, one hotword per line in the format "hotword weight" (e.g. 阿里巴巴 20).
+If the client also provides hotwords, the two lists are merged and used together; server-side hotwords take effect globally, while client-side hotwords only apply to the corresponding client.
```
### Shutting down the FunASR service

View File

@ -111,7 +111,9 @@ nohup bash run_server_2pass.sh \
--punc-quant: True for quantized PUNC model, False for non-quantized PUNC model. Default is True.
--itn-dir: modelscope model ID or local model path.
--port: Port number that the server listens on. Default is 10095.
---decoder-thread-num: Number of inference threads that the server starts. Default is 8.
+--decoder-thread-num: Size of the server-side thread pool, i.e. the maximum number of concurrent requests. Default is 8.
+--model-thread-num: Number of internal threads per recognition route, controlling the parallelism of the ONNX model. Default is 1.
+It is recommended that decoder-thread-num * model-thread-num equal the total number of threads.
--io-thread-num: Number of IO threads that the server starts. Default is 1.
--certfile <string>: SSL certificate file. Default is ../../../ssl_key/server.crt. To disable SSL, set it to 0.
--keyfile <string>: SSL key file. Default is ../../../ssl_key/server.key.

View File

@ -120,11 +120,14 @@ nohup bash run_server_2pass.sh \
--punc-quant: True for the quantized PUNC model, False for the non-quantized PUNC model. Default is True.
--itn-dir: modelscope model ID or local model path.
--port: Port that the server listens on. Default is 10095.
---decoder-thread-num: Number of inference threads started by the server. Default is 8.
+--decoder-thread-num: Size of the server-side thread pool, i.e. the maximum number of concurrent requests. Default is 8.
+--model-thread-num: Number of internal threads per recognition route, controlling the parallelism of the ONNX model. Default is 1.
+It is recommended that decoder-thread-num * model-thread-num equal the total number of threads.
--io-thread-num: Number of IO threads started by the server. Default is 1.
--certfile: SSL certificate file. Default is ../../../ssl_key/server.crt. To disable SSL, set it to 0.
--keyfile: SSL key file. Default is ../../../ssl_key/server.key.
---hotword: Path to a hotword file, one hotword per line in the format "hotword weight" (e.g. 阿里巴巴 20). If the client also provides hotwords, the two lists are merged and used together.
+--hotword: Path to a hotword file, one hotword per line in the format "hotword weight" (e.g. 阿里巴巴 20).
+If the client also provides hotwords, the two lists are merged and used together; server-side hotwords take effect globally, while client-side hotwords only apply to the corresponding client.
```
### Shutting down the FunASR service

View File

@ -16,7 +16,7 @@ git clone https://github.com/alibaba/FunASR.git && cd FunASR
### Install the requirements for server
```shell
-cd funasr/runtime/python/websocket
+cd runtime/python/websocket
pip install -r requirements_server.txt
```

View File

@ -53,13 +53,13 @@ parser.add_argument("--ncpu",
help="cpu cores") help="cpu cores")
parser.add_argument("--certfile", parser.add_argument("--certfile",
type=str, type=str,
default="../ssl_key/server.crt", default="../../ssl_key/server.crt",
required=False, required=False,
help="certfile for ssl") help="certfile for ssl")
parser.add_argument("--keyfile", parser.add_argument("--keyfile",
type=str, type=str,
default="../ssl_key/server.key", default="../../ssl_key/server.key",
required=False, required=False,
help="keyfile for ssl") help="keyfile for ssl")
args = parser.parse_args() args = parser.parse_args()

View File

@ -10,36 +10,36 @@ from setuptools import setup
requirements = {
"install": [
-"setuptools>=38.5.1",
+# "setuptools>=38.5.1",
"humanfriendly",
"scipy>=1.4.1",
"librosa",
-"jamo",  # For kss
+# "jamo",  # For kss
"PyYAML>=5.1.2",
-"soundfile>=0.12.1",
-"h5py>=3.1.0",
+# "soundfile>=0.12.1",
+# "h5py>=3.1.0",
"kaldiio>=2.17.0",
-"torch_complex",
-"nltk>=3.4.5",
+# "torch_complex",
+# "nltk>=3.4.5",
# ASR
-"sentencepiece",
+"sentencepiece",  # train
"jieba",
-"rotary_embedding_torch",
-"ffmpeg",
+# "rotary_embedding_torch",
+# "ffmpeg-python",
# TTS
-"pypinyin>=0.44.0",
-"espnet_tts_frontend",
+# "pypinyin>=0.44.0",
+# "espnet_tts_frontend",
# ENH
-"pytorch_wpe",
+# "pytorch_wpe",
"editdistance>=0.5.2",
"tensorboard",
-"g2p",
-"nara_wpe",
+# "g2p",
+# "nara_wpe",
# PAI
"oss2",
-"edit-distance",
-"textgrid",
-"protobuf",
+# "edit-distance",
+# "textgrid",
+# "protobuf",
"tqdm",
"hdbscan",
"umap",
@ -104,7 +104,7 @@ setup(
name="funasr",
version=version,
url="https://github.com/alibaba-damo-academy/FunASR.git",
-author="Speech Lab of DAMO Academy, Alibaba Group",
+author="Speech Lab of Alibaba Group",
author_email="funasr@list.alibaba-inc.com",
description="FunASR: A Fundamental End-to-End Speech Recognition Toolkit",
long_description=open(os.path.join(dirname, "README.md"), encoding="utf-8").read(),

View File

@ -0,0 +1,3 @@
> 1%
last 2 versions
not ie < 11

7
web-pages/.editorconfig Normal file
View File

@ -0,0 +1,7 @@
[*.{js,jsx,ts,tsx,vue}]
charset = utf-8
indent_style = space
indent_size = 4
end_of_line = lf
trim_trailing_whitespace = true
insert_final_newline = true

7
web-pages/.eslintignore Normal file
View File

@ -0,0 +1,7 @@
node_modules
dist
fonts
*.md
*.woff
*.ttf
public

24
web-pages/.eslintrc.js Normal file
View File

@ -0,0 +1,24 @@
module.exports = {
root: true,
env: {
node: true
},
extends: ['plugin:vue/essential', '@vue/standard'],
parserOptions: {
parser: '@babel/eslint-parser'
},
rules: {
'no-console': process.env.NODE_ENV === 'production' ? 'warn' : 'off',
'no-debugger': process.env.NODE_ENV === 'production' ? 'warn' : 'off',
'quote-props': 'off',
indent: 'off',
// "vue/script-indent": [
// "error",
// 4,
// {
// baseIndent: 1,
// },
// ],
'vue/order-in-components': 'error'
}
}

13
web-pages/babel.config.js Normal file
View File

@ -0,0 +1,13 @@
module.exports = {
presets: ['@vue/cli-plugin-babel/preset'],
plugins: [
[
'import',
{
libraryName: 'ant-design-vue',
libraryDirectory: 'es',
style: 'css'
}
]
]
}

12
web-pages/jsconfig.json Normal file
View File

@ -0,0 +1,12 @@
{
"compilerOptions": {
"target": "es5",
"module": "esnext",
"baseUrl": "./",
"moduleResolution": "node",
"paths": {
"@/*": ["src/*"]
},
"lib": ["esnext", "dom", "dom.iterable", "scripthost"]
}
}

11
web-pages/mock/index.js Normal file
View File

@ -0,0 +1,11 @@
import Mock from 'mockjs'
import getUserInfo from './user/getUserInfo.js'
import getMenuList from './user/getMenuList.js'
Mock.setup({
timeout: 500
})
console.log('Starting mock request data')
getUserInfo(Mock)
getMenuList(Mock)

View File

@ -0,0 +1 @@
export default Mock => {}

View File

@ -0,0 +1 @@
export default Mock => {}

3
web-pages/mock/util.js Normal file
View File

@ -0,0 +1,3 @@
export function randomNum (n, m) {
return Math.floor(Math.random() * (m - n + 1) + n)
}

9857
web-pages/package-lock.json generated Normal file

File diff suppressed because it is too large

44
web-pages/package.json Normal file
View File

@ -0,0 +1,44 @@
{
"name": "template-vue",
"version": "0.1.0",
"private": true,
"scripts": {
"dev": "vue-cli-service serve --env VUE_APP_frontendConfigUrl=/template-vue/frontend-config-dev.json",
"devmock": "vue-cli-service serve --env VUE_APP_frontendConfigUrl=/template-vue/frontend-config-devmock.json --env VUE_APP_env=devmock",
"example": "vue-cli-service build --env VUE_APP_frontendConfigUrl=/template-vue/frontend-config-dev.json --env NODE_ENV=production"
},
"dependencies": {
"ant-design-vue": "1.7.5",
"@liveqing/liveplayer": "2.7.10",
"axios": "0.19.2",
"core-js": "3.22.5",
"mockjs": "1.1.0",
"vue": "2.6.14",
"vue-router": "3.5.3",
"vuex": "3.6.2",
"swiper": "5.4.5",
"vuex-persistedstate": "3.0.1"
},
"devDependencies": {
"@babel/core": "7.17.10",
"@babel/eslint-parser": "7.17.0",
"@vue/cli-plugin-babel": "5.0.4",
"@vue/cli-plugin-eslint": "5.0.4",
"@vue/cli-plugin-router": "5.0.4",
"@vue/cli-plugin-vuex": "5.0.4",
"@vue/cli-service": "5.0.4",
"@vue/eslint-config-standard": "6.1.0",
"babel-plugin-component": "1.1.1",
"babel-plugin-import": "1.12.2",
"compression-webpack-plugin": "3.1.0",
"css-unicode-loader": "1.0.3",
"eslint": "7.32.0",
"eslint-plugin-import": "2.26.0",
"eslint-plugin-node": "11.1.0",
"eslint-plugin-promise": "5.2.0",
"eslint-plugin-vue": "8.7.1",
"sass": "1.51.0",
"sass-loader": "12.6.0",
"vue-template-compiler": "2.6.14"
}
}

File diff suppressed because one or more lines are too long

Binary file not shown.

Binary image added (4.3 KiB); preview not shown.

Some files were not shown because too many files have changed in this diff.