mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Dev gzf funasr2 (#1111)
* update funasr.text -> funasr.tokenizer fix bug export
This commit is contained in:
parent
23623f3cf1
commit
7dadb793e6
@ -34,8 +34,8 @@ from funasr.modules.beam_search.beam_search_transducer import Hypothesis as Hypo
|
|||||||
from funasr.modules.scorers.ctc import CTCPrefixScorer
|
from funasr.modules.scorers.ctc import CTCPrefixScorer
|
||||||
from funasr.modules.scorers.length_bonus import LengthBonus
|
from funasr.modules.scorers.length_bonus import LengthBonus
|
||||||
from funasr.build_utils.build_asr_model import frontend_choices
|
from funasr.build_utils.build_asr_model import frontend_choices
|
||||||
from funasr.text.build_tokenizer import build_tokenizer
|
from funasr.tokenizer.build_tokenizer import build_tokenizer
|
||||||
from funasr.text.token_id_converter import TokenIDConverter
|
from funasr.tokenizer.token_id_converter import TokenIDConverter
|
||||||
from funasr.torch_utils.device_funcs import to_device
|
from funasr.torch_utils.device_funcs import to_device
|
||||||
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
|
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
|
||||||
|
|
||||||
|
|||||||
@ -18,7 +18,7 @@ from funasr.build_utils.build_optimizer import build_optimizer
|
|||||||
from funasr.build_utils.build_scheduler import build_scheduler
|
from funasr.build_utils.build_scheduler import build_scheduler
|
||||||
from funasr.build_utils.build_trainer import build_trainer as build_trainer_modelscope
|
from funasr.build_utils.build_trainer import build_trainer as build_trainer_modelscope
|
||||||
from funasr.modules.lora.utils import mark_only_lora_as_trainable
|
from funasr.modules.lora.utils import mark_only_lora_as_trainable
|
||||||
from funasr.text.phoneme_tokenizer import g2p_choices
|
from funasr.tokenizer.phoneme_tokenizer import g2p_choices
|
||||||
from funasr.torch_utils.load_pretrained_model import load_pretrained_model
|
from funasr.torch_utils.load_pretrained_model import load_pretrained_model
|
||||||
from funasr.torch_utils.model_summary import model_summary
|
from funasr.torch_utils.model_summary import model_summary
|
||||||
from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
|
from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
|
||||||
|
|||||||
@ -9,9 +9,9 @@ from typing import Optional
|
|||||||
|
|
||||||
|
|
||||||
from funasr.utils.cli_utils import get_commandline_args
|
from funasr.utils.cli_utils import get_commandline_args
|
||||||
from funasr.text.build_tokenizer import build_tokenizer
|
from funasr.tokenizer.build_tokenizer import build_tokenizer
|
||||||
from funasr.text.cleaner import TextCleaner
|
from funasr.tokenizer.cleaner import TextCleaner
|
||||||
from funasr.text.phoneme_tokenizer import g2p_choices
|
from funasr.tokenizer.phoneme_tokenizer import g2p_choices
|
||||||
from funasr.utils.types import str2bool
|
from funasr.utils.types import str2bool
|
||||||
from funasr.utils.types import str_or_none
|
from funasr.utils.types import str_or_none
|
||||||
|
|
||||||
|
|||||||
@ -11,7 +11,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from funasr.build_utils.build_model_from_file import build_model_from_file
|
from funasr.build_utils.build_model_from_file import build_model_from_file
|
||||||
from funasr.models.frontend.wav_frontend import WavFrontend
|
from funasr.models.frontend.wav_frontend import WavFrontend
|
||||||
from funasr.text.token_id_converter import TokenIDConverter
|
from funasr.tokenizer.token_id_converter import TokenIDConverter
|
||||||
from funasr.torch_utils.device_funcs import to_device
|
from funasr.torch_utils.device_funcs import to_device
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -17,7 +17,7 @@ from funasr.build_utils.build_model import build_model
|
|||||||
from funasr.build_utils.build_optimizer import build_optimizer
|
from funasr.build_utils.build_optimizer import build_optimizer
|
||||||
from funasr.build_utils.build_scheduler import build_scheduler
|
from funasr.build_utils.build_scheduler import build_scheduler
|
||||||
from funasr.build_utils.build_trainer import build_trainer
|
from funasr.build_utils.build_trainer import build_trainer
|
||||||
from funasr.text.phoneme_tokenizer import g2p_choices
|
from funasr.tokenizer.phoneme_tokenizer import g2p_choices
|
||||||
from funasr.torch_utils.load_pretrained_model import load_pretrained_model
|
from funasr.torch_utils.load_pretrained_model import load_pretrained_model
|
||||||
from funasr.torch_utils.model_summary import model_summary
|
from funasr.torch_utils.model_summary import model_summary
|
||||||
from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
|
from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
|
||||||
|
|||||||
@ -1,29 +1,42 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
class BatchSampler(torch.utils.data.BatchSampler):
|
class BatchSampler(torch.utils.data.BatchSampler):
|
||||||
|
|
||||||
def __init__(self, dataset=None, args=None, drop_last=True, ):
|
def __init__(self, dataset, batch_size_type: str="example", batch_size: int=14, sort_size: int=30, drop_last: bool=False, shuffle: bool=True, **kwargs):
|
||||||
|
|
||||||
self.drop_last = drop_last
|
self.drop_last = drop_last
|
||||||
self.pre_idx = -1
|
self.pre_idx = -1
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
self.batch_size_type = args.batch_size_type
|
|
||||||
self.batch_size = args.batch_size
|
|
||||||
self.sort_size = args.sort_size
|
|
||||||
self.max_length_token = args.max_length_token
|
|
||||||
self.total_samples = len(dataset)
|
self.total_samples = len(dataset)
|
||||||
|
# self.batch_size_type = args.batch_size_type
|
||||||
|
# self.batch_size = args.batch_size
|
||||||
|
# self.sort_size = args.sort_size
|
||||||
|
# self.max_length_token = args.max_length_token
|
||||||
|
self.batch_size_type = batch_size_type
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.sort_size = sort_size
|
||||||
|
self.max_length_token = kwargs.get("max_length_token", 5000)
|
||||||
|
self.shuffle_idx = np.arange(self.total_samples)
|
||||||
|
self.shuffle = shuffle
|
||||||
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.total_samples
|
return self.total_samples
|
||||||
|
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
|
print("in sampler")
|
||||||
|
|
||||||
|
if self.shuffle:
|
||||||
|
np.random.shuffle(self.shuffle_idx)
|
||||||
|
|
||||||
batch = []
|
batch = []
|
||||||
max_token = 0
|
max_token = 0
|
||||||
num_sample = 0
|
num_sample = 0
|
||||||
|
|
||||||
iter_num = (self.total_samples-1) // self.sort_size + 1
|
iter_num = (self.total_samples-1) // self.sort_size + 1
|
||||||
|
print("iter_num: ", iter_num)
|
||||||
for iter in range(self.pre_idx + 1, iter_num):
|
for iter in range(self.pre_idx + 1, iter_num):
|
||||||
datalen_with_index = []
|
datalen_with_index = []
|
||||||
for i in range(self.sort_size):
|
for i in range(self.sort_size):
|
||||||
@ -31,30 +44,31 @@ class BatchSampler(torch.utils.data.BatchSampler):
|
|||||||
if idx >= self.total_samples:
|
if idx >= self.total_samples:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if self.batch_size_type == "example":
|
idx_map = self.shuffle_idx[idx]
|
||||||
sample_len_cur = 1
|
# prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
|
||||||
else:
|
sample_len_cur = self.dataset.indexed_dataset[idx_map]["source_len"] + \
|
||||||
idx_map = self.dataset.shuffle_idx[idx]
|
self.dataset.indexed_dataset[idx_map]["target_len"]
|
||||||
# prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
|
|
||||||
sample_len_cur = self.dataset.indexed_dataset[idx_map]["source_len"] + \
|
|
||||||
self.dataset.indexed_dataset[idx_map]["target_len"]
|
|
||||||
|
|
||||||
datalen_with_index.append([idx, sample_len_cur])
|
datalen_with_index.append([idx, sample_len_cur])
|
||||||
|
|
||||||
datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
|
datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
|
||||||
for item in datalen_with_index_sort:
|
for item in datalen_with_index_sort:
|
||||||
idx, sample_len_cur = item
|
idx, sample_len_cur_raw = item
|
||||||
if sample_len_cur > self.max_length_token:
|
if sample_len_cur_raw > self.max_length_token:
|
||||||
continue
|
continue
|
||||||
max_token_cur = max(max_token, sample_len_cur)
|
|
||||||
max_token_padding = (1 + num_sample) * max_token_cur
|
max_token_cur = max(max_token, sample_len_cur_raw)
|
||||||
|
max_token_padding = 1 + num_sample
|
||||||
|
if self.batch_size_type == 'token':
|
||||||
|
max_token_padding *= max_token_cur
|
||||||
if max_token_padding <= self.batch_size:
|
if max_token_padding <= self.batch_size:
|
||||||
batch.append(idx)
|
batch.append(idx)
|
||||||
max_token = max_token_cur
|
max_token = max_token_cur
|
||||||
num_sample += 1
|
num_sample += 1
|
||||||
else:
|
else:
|
||||||
yield batch
|
yield batch
|
||||||
max_token = sample_len_cur
|
|
||||||
num_sample = 1
|
|
||||||
batch = [idx]
|
batch = [idx]
|
||||||
|
max_token = sample_len_cur_raw
|
||||||
|
num_sample = 1
|
||||||
|
|
||||||
|
|
||||||
53
funasr/datasets/dataloader_fn.py
Normal file
53
funasr/datasets/dataloader_fn.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
|
||||||
|
import torch
|
||||||
|
from funasr.datasets.dataset_jsonl import AudioDataset
|
||||||
|
from funasr.datasets.data_sampler import BatchSampler
|
||||||
|
from funasr.models.frontend.wav_frontend import WavFrontend
|
||||||
|
from funasr.tokenizer.build_tokenizer import build_tokenizer
|
||||||
|
from funasr.tokenizer.token_id_converter import TokenIDConverter
|
||||||
|
collate_fn = None
|
||||||
|
# collate_fn = collate_fn,
|
||||||
|
|
||||||
|
jsonl = "/Users/zhifu/funasr_github/test_local/all_task_debug_len.jsonl"
|
||||||
|
|
||||||
|
frontend = WavFrontend()
|
||||||
|
token_type = 'char'
|
||||||
|
bpemodel = None
|
||||||
|
delimiter = None
|
||||||
|
space_symbol = "<space>"
|
||||||
|
non_linguistic_symbols = None
|
||||||
|
g2p_type = None
|
||||||
|
|
||||||
|
tokenizer = build_tokenizer(
|
||||||
|
token_type=token_type,
|
||||||
|
bpemodel=bpemodel,
|
||||||
|
delimiter=delimiter,
|
||||||
|
space_symbol=space_symbol,
|
||||||
|
non_linguistic_symbols=non_linguistic_symbols,
|
||||||
|
g2p_type=g2p_type,
|
||||||
|
)
|
||||||
|
token_list = ""
|
||||||
|
unk_symbol = "<unk>"
|
||||||
|
|
||||||
|
token_id_converter = TokenIDConverter(
|
||||||
|
token_list=token_list,
|
||||||
|
unk_symbol=unk_symbol,
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset = AudioDataset(jsonl, frontend=frontend, tokenizer=tokenizer)
|
||||||
|
batch_sampler = BatchSampler(dataset)
|
||||||
|
dataloader_tr = torch.utils.data.DataLoader(dataset,
|
||||||
|
collate_fn=dataset.collator,
|
||||||
|
batch_sampler=batch_sampler,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=0,
|
||||||
|
pin_memory=True)
|
||||||
|
|
||||||
|
print(len(dataset))
|
||||||
|
for i in range(3):
|
||||||
|
print(i)
|
||||||
|
for data in dataloader_tr:
|
||||||
|
print(len(data), data)
|
||||||
|
# data_iter = iter(dataloader_tr)
|
||||||
|
# data = next(data_iter)
|
||||||
|
pass
|
||||||
@ -1,12 +1,41 @@
|
|||||||
import torch
|
import torch
|
||||||
import json
|
import json
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
import numpy as np
|
||||||
|
import kaldiio
|
||||||
|
import librosa
|
||||||
|
|
||||||
class AudioDatasetJsonl(torch.utils.data.Dataset):
|
|
||||||
|
|
||||||
|
def load_audio(audio_path: str, fs: int=16000):
|
||||||
|
audio = None
|
||||||
|
if audio_path.startswith("oss:"):
|
||||||
|
pass
|
||||||
|
elif audio_path.startswith("odps:"):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
if ".ark:" in audio_path:
|
||||||
|
audio = kaldiio.load_mat(audio_path)
|
||||||
|
else:
|
||||||
|
audio, fs = librosa.load(audio_path, sr=fs)
|
||||||
|
return audio
|
||||||
|
|
||||||
|
def extract_features(data, date_type: str="sound", frontend=None):
|
||||||
|
if date_type == "sound":
|
||||||
|
feat, feats_lens = frontend(data, len(data))
|
||||||
|
feat = feat[0, :, :]
|
||||||
|
else:
|
||||||
|
feat, feats_lens = torch.from_numpy(data).to(torch.float32), torch.tensor([data.shape[0]]).to(torch.int32)
|
||||||
|
return feat, feats_lens
|
||||||
|
|
||||||
def __init__(self, path, data_parallel_rank=0, data_parallel_size=1):
|
|
||||||
|
|
||||||
|
class IndexedDatasetJsonl(torch.utils.data.Dataset):
|
||||||
|
|
||||||
|
def __init__(self, path):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
data_parallel_size = dist.get_world_size()
|
# data_parallel_size = dist.get_world_size()
|
||||||
|
data_parallel_size = 1
|
||||||
contents = []
|
contents = []
|
||||||
with open(path, encoding='utf-8') as fin:
|
with open(path, encoding='utf-8') as fin:
|
||||||
for line in fin:
|
for line in fin:
|
||||||
@ -31,7 +60,8 @@ class AudioDatasetJsonl(torch.utils.data.Dataset):
|
|||||||
self.contents = []
|
self.contents = []
|
||||||
total_num = len(contents)
|
total_num = len(contents)
|
||||||
num_per_rank = total_num // data_parallel_size
|
num_per_rank = total_num // data_parallel_size
|
||||||
rank = dist.get_rank()
|
# rank = dist.get_rank()
|
||||||
|
rank = 0
|
||||||
# import ipdb; ipdb.set_trace()
|
# import ipdb; ipdb.set_trace()
|
||||||
self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank]
|
self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank]
|
||||||
|
|
||||||
@ -41,3 +71,54 @@ class AudioDatasetJsonl(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
return self.contents[index]
|
return self.contents[index]
|
||||||
|
|
||||||
|
|
||||||
|
class AudioDataset(torch.utils.data.Dataset):
|
||||||
|
def __init__(self, path, frontend=None, tokenizer=None):
|
||||||
|
super().__init__()
|
||||||
|
self.indexed_dataset = IndexedDatasetJsonl(path)
|
||||||
|
self.frontend = frontend.forward
|
||||||
|
self.fs = 16000 if frontend is None else frontend.fs
|
||||||
|
self.data_type = "sound"
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.int_pad_value = -1
|
||||||
|
self.float_pad_value = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.indexed_dataset)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
item = self.indexed_dataset[index]
|
||||||
|
source = item["source"]
|
||||||
|
data_src = load_audio(source, fs=self.fs)
|
||||||
|
speech, speech_lengths = extract_features(data_src, self.data_type, self.frontend)
|
||||||
|
target = item["target"]
|
||||||
|
text = self.tokenizer.encode(target)
|
||||||
|
text_lengths = len(text)
|
||||||
|
text, text_lengths = torch.tensor(text, dtype=torch.int64), torch.tensor([text_lengths], dtype=torch.int32)
|
||||||
|
return {"speech": speech,
|
||||||
|
"speech_lengths": speech_lengths,
|
||||||
|
"text": text,
|
||||||
|
"text_lengths": text_lengths,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def collator(self, samples: list=None):
|
||||||
|
|
||||||
|
outputs = {}
|
||||||
|
for sample in samples:
|
||||||
|
for key in sample.keys():
|
||||||
|
if key not in outputs:
|
||||||
|
outputs[key] = []
|
||||||
|
outputs[key].append(sample[key])
|
||||||
|
|
||||||
|
for key, data_list in outputs.items():
|
||||||
|
if data_list[0].dtype.kind == "i":
|
||||||
|
pad_value = self.int_pad_value
|
||||||
|
else:
|
||||||
|
pad_value = self.float_pad_value
|
||||||
|
outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
|
||||||
|
return samples
|
||||||
@ -9,7 +9,7 @@ from torch.utils.data import DataLoader
|
|||||||
|
|
||||||
from funasr.datasets.large_datasets.dataset import Dataset
|
from funasr.datasets.large_datasets.dataset import Dataset
|
||||||
from funasr.iterators.abs_iter_factory import AbsIterFactory
|
from funasr.iterators.abs_iter_factory import AbsIterFactory
|
||||||
from funasr.text.abs_tokenizer import AbsTokenizer
|
from funasr.tokenizer.abs_tokenizer import AbsTokenizer
|
||||||
|
|
||||||
|
|
||||||
def read_symbol_table(symbol_table_file):
|
def read_symbol_table(symbol_table_file):
|
||||||
|
|||||||
@ -13,9 +13,9 @@ import scipy.signal
|
|||||||
import librosa
|
import librosa
|
||||||
import jieba
|
import jieba
|
||||||
|
|
||||||
from funasr.text.build_tokenizer import build_tokenizer
|
from funasr.tokenizer.build_tokenizer import build_tokenizer
|
||||||
from funasr.text.cleaner import TextCleaner
|
from funasr.tokenizer.cleaner import TextCleaner
|
||||||
from funasr.text.token_id_converter import TokenIDConverter
|
from funasr.tokenizer.token_id_converter import TokenIDConverter
|
||||||
|
|
||||||
|
|
||||||
class AbsPreprocessor(ABC):
|
class AbsPreprocessor(ABC):
|
||||||
|
|||||||
@ -11,9 +11,9 @@ import numpy as np
|
|||||||
import scipy.signal
|
import scipy.signal
|
||||||
import librosa
|
import librosa
|
||||||
|
|
||||||
from funasr.text.build_tokenizer import build_tokenizer
|
from funasr.tokenizer.build_tokenizer import build_tokenizer
|
||||||
from funasr.text.cleaner import TextCleaner
|
from funasr.tokenizer.cleaner import TextCleaner
|
||||||
from funasr.text.token_id_converter import TokenIDConverter
|
from funasr.tokenizer.token_id_converter import TokenIDConverter
|
||||||
|
|
||||||
|
|
||||||
class AbsPreprocessor(ABC):
|
class AbsPreprocessor(ABC):
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from funasr.models.e2e_asr_paraformer import Paraformer, BiCifParaformer, ParaformerOnline
|
from funasr.models.e2e_asr_paraformer import Paraformer, BiCifParaformer, ParaformerOnline
|
||||||
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
|
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
|
||||||
from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export
|
from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export
|
||||||
from funasr.export.models.e2e_asr_conformer import Conformer as Conformer_export
|
# from funasr.export.models.e2e_asr_conformer import Conformer as Conformer_export
|
||||||
|
|
||||||
from funasr.models.e2e_vad import E2EVadModel
|
from funasr.models.e2e_vad import E2EVadModel
|
||||||
from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export
|
from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export
|
||||||
@ -30,8 +30,8 @@ def get_model(model, export_config=None):
|
|||||||
return [encoder, decoder]
|
return [encoder, decoder]
|
||||||
elif isinstance(model, Paraformer):
|
elif isinstance(model, Paraformer):
|
||||||
return Paraformer_export(model, **export_config)
|
return Paraformer_export(model, **export_config)
|
||||||
elif isinstance(model, Conformer_export):
|
# elif isinstance(model, Conformer_export):
|
||||||
return Conformer_export(model, **export_config)
|
# return Conformer_export(model, **export_config)
|
||||||
elif isinstance(model, E2EVadModel):
|
elif isinstance(model, E2EVadModel):
|
||||||
return E2EVadModel_export(model, **export_config)
|
return E2EVadModel_export(model, **export_config)
|
||||||
elif isinstance(model, PunctuationModel):
|
elif isinstance(model, PunctuationModel):
|
||||||
|
|||||||
@ -1,69 +0,0 @@
|
|||||||
import os
|
|
||||||
import logging
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from funasr.export.utils.torch_function import MakePadMask
|
|
||||||
from funasr.export.utils.torch_function import sequence_mask
|
|
||||||
from funasr.models.encoder.conformer_encoder import ConformerEncoder
|
|
||||||
from funasr.models.decoder.transformer_decoder import TransformerDecoder
|
|
||||||
from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
|
|
||||||
from funasr.export.models.decoder.xformer_decoder import XformerDecoder as TransformerDecoder_export
|
|
||||||
|
|
||||||
class Conformer(nn.Module):
|
|
||||||
"""
|
|
||||||
export conformer into onnx format
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model,
|
|
||||||
max_seq_len=512,
|
|
||||||
feats_dim=560,
|
|
||||||
model_name='model',
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
onnx = False
|
|
||||||
if "onnx" in kwargs:
|
|
||||||
onnx = kwargs["onnx"]
|
|
||||||
if isinstance(model.encoder, ConformerEncoder):
|
|
||||||
self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
|
|
||||||
elif isinstance(model.decoder, TransformerDecoder):
|
|
||||||
self.decoder = TransformerDecoder_export(model.decoder, onnx=onnx)
|
|
||||||
|
|
||||||
self.feats_dim = feats_dim
|
|
||||||
self.model_name = model_name
|
|
||||||
|
|
||||||
if onnx:
|
|
||||||
self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
|
|
||||||
else:
|
|
||||||
self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
|
|
||||||
|
|
||||||
def _export_model(self, model, verbose, path):
|
|
||||||
dummy_input = model.get_dummy_inputs()
|
|
||||||
model_script = model
|
|
||||||
model_path = os.path.join(path, f'{model.model_name}.onnx')
|
|
||||||
if not os.path.exists(model_path):
|
|
||||||
torch.onnx.export(
|
|
||||||
model_script,
|
|
||||||
dummy_input,
|
|
||||||
model_path,
|
|
||||||
verbose=verbose,
|
|
||||||
opset_version=14,
|
|
||||||
input_names=model.get_input_names(),
|
|
||||||
output_names=model.get_output_names(),
|
|
||||||
dynamic_axes=model.get_dynamic_axes()
|
|
||||||
)
|
|
||||||
|
|
||||||
def _export_encoder_onnx(self, verbose, path):
|
|
||||||
model_encoder = self.encoder
|
|
||||||
self._export_model(model_encoder, verbose, path)
|
|
||||||
|
|
||||||
def _export_decoder_onnx(self, verbose, path):
|
|
||||||
model_decoder = self.decoder
|
|
||||||
self._export_model(model_decoder, verbose, path)
|
|
||||||
|
|
||||||
def _export_onnx(self, verbose, path):
|
|
||||||
self._export_encoder_onnx(verbose, path)
|
|
||||||
self._export_decoder_onnx(verbose, path)
|
|
||||||
@ -145,9 +145,12 @@ class WavFrontend(AbsFrontend):
|
|||||||
feats_lens.append(feat_length)
|
feats_lens.append(feat_length)
|
||||||
|
|
||||||
feats_lens = torch.as_tensor(feats_lens)
|
feats_lens = torch.as_tensor(feats_lens)
|
||||||
feats_pad = pad_sequence(feats,
|
if batch_size == 1:
|
||||||
batch_first=True,
|
feats_pad = feats[0][None, :, :]
|
||||||
padding_value=0.0)
|
else:
|
||||||
|
feats_pad = pad_sequence(feats,
|
||||||
|
batch_first=True,
|
||||||
|
padding_value=0.0)
|
||||||
return feats_pad, feats_lens
|
return feats_pad, feats_lens
|
||||||
|
|
||||||
def forward_fbank(
|
def forward_fbank(
|
||||||
|
|||||||
@ -76,7 +76,7 @@ from funasr.models.specaug.specaug import SpecAug
|
|||||||
from funasr.models.specaug.specaug import SpecAugLFR
|
from funasr.models.specaug.specaug import SpecAugLFR
|
||||||
from funasr.modules.subsampling import Conv1dSubsampling
|
from funasr.modules.subsampling import Conv1dSubsampling
|
||||||
from funasr.tasks.abs_task import AbsTask
|
from funasr.tasks.abs_task import AbsTask
|
||||||
from funasr.text.phoneme_tokenizer import g2p_choices
|
from funasr.tokenizer.phoneme_tokenizer import g2p_choices
|
||||||
from funasr.torch_utils.initialize import initialize
|
from funasr.torch_utils.initialize import initialize
|
||||||
from funasr.models.base_model import FunASRModel
|
from funasr.models.base_model import FunASRModel
|
||||||
from funasr.train.class_choices import ClassChoices
|
from funasr.train.class_choices import ClassChoices
|
||||||
|
|||||||
@ -25,7 +25,7 @@ from funasr.models.preencoder.sinc import LightweightSincConvs
|
|||||||
from funasr.models.specaug.abs_specaug import AbsSpecAug
|
from funasr.models.specaug.abs_specaug import AbsSpecAug
|
||||||
from funasr.models.specaug.specaug import SpecAug
|
from funasr.models.specaug.specaug import SpecAug
|
||||||
from funasr.tasks.abs_task import AbsTask
|
from funasr.tasks.abs_task import AbsTask
|
||||||
from funasr.text.phoneme_tokenizer import g2p_choices
|
from funasr.tokenizer.phoneme_tokenizer import g2p_choices
|
||||||
from funasr.torch_utils.initialize import initialize
|
from funasr.torch_utils.initialize import initialize
|
||||||
from funasr.train.class_choices import ClassChoices
|
from funasr.train.class_choices import ClassChoices
|
||||||
from funasr.train.trainer import Trainer
|
from funasr.train.trainer import Trainer
|
||||||
|
|||||||
@ -17,7 +17,7 @@ from funasr.train.abs_model import LanguageModel
|
|||||||
from funasr.models.seq_rnn_lm import SequentialRNNLM
|
from funasr.models.seq_rnn_lm import SequentialRNNLM
|
||||||
from funasr.models.transformer_lm import TransformerLM
|
from funasr.models.transformer_lm import TransformerLM
|
||||||
from funasr.tasks.abs_task import AbsTask
|
from funasr.tasks.abs_task import AbsTask
|
||||||
from funasr.text.phoneme_tokenizer import g2p_choices
|
from funasr.tokenizer.phoneme_tokenizer import g2p_choices
|
||||||
from funasr.torch_utils.initialize import initialize
|
from funasr.torch_utils.initialize import initialize
|
||||||
from funasr.train.class_choices import ClassChoices
|
from funasr.train.class_choices import ClassChoices
|
||||||
from funasr.train.trainer import Trainer
|
from funasr.train.trainer import Trainer
|
||||||
|
|||||||
@ -16,7 +16,7 @@ from funasr.train.abs_model import PunctuationModel
|
|||||||
from funasr.models.target_delay_transformer import TargetDelayTransformer
|
from funasr.models.target_delay_transformer import TargetDelayTransformer
|
||||||
from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
|
from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
|
||||||
from funasr.tasks.abs_task import AbsTask
|
from funasr.tasks.abs_task import AbsTask
|
||||||
from funasr.text.phoneme_tokenizer import g2p_choices
|
from funasr.tokenizer.phoneme_tokenizer import g2p_choices
|
||||||
from funasr.torch_utils.initialize import initialize
|
from funasr.torch_utils.initialize import initialize
|
||||||
from funasr.train.class_choices import ClassChoices
|
from funasr.train.class_choices import ClassChoices
|
||||||
from funasr.train.trainer import Trainer
|
from funasr.train.trainer import Trainer
|
||||||
|
|||||||
@ -71,7 +71,7 @@ from funasr.models.specaug.specaug import SpecAugLFR
|
|||||||
from funasr.models.base_model import FunASRModel
|
from funasr.models.base_model import FunASRModel
|
||||||
from funasr.modules.subsampling import Conv1dSubsampling
|
from funasr.modules.subsampling import Conv1dSubsampling
|
||||||
from funasr.tasks.abs_task import AbsTask
|
from funasr.tasks.abs_task import AbsTask
|
||||||
from funasr.text.phoneme_tokenizer import g2p_choices
|
from funasr.tokenizer.phoneme_tokenizer import g2p_choices
|
||||||
from funasr.torch_utils.initialize import initialize
|
from funasr.torch_utils.initialize import initialize
|
||||||
from funasr.train.class_choices import ClassChoices
|
from funasr.train.class_choices import ClassChoices
|
||||||
from funasr.train.trainer import Trainer
|
from funasr.train.trainer import Trainer
|
||||||
|
|||||||
@ -76,7 +76,7 @@ from funasr.models.specaug.specaug import SpecAug
|
|||||||
from funasr.models.specaug.specaug import SpecAugLFR
|
from funasr.models.specaug.specaug import SpecAugLFR
|
||||||
from funasr.modules.subsampling import Conv1dSubsampling
|
from funasr.modules.subsampling import Conv1dSubsampling
|
||||||
from funasr.tasks.abs_task import AbsTask
|
from funasr.tasks.abs_task import AbsTask
|
||||||
from funasr.text.phoneme_tokenizer import g2p_choices
|
from funasr.tokenizer.phoneme_tokenizer import g2p_choices
|
||||||
from funasr.torch_utils.initialize import initialize
|
from funasr.torch_utils.initialize import initialize
|
||||||
from funasr.models.base_model import FunASRModel
|
from funasr.models.base_model import FunASRModel
|
||||||
from funasr.train.class_choices import ClassChoices
|
from funasr.train.class_choices import ClassChoices
|
||||||
|
|||||||
@ -3,11 +3,11 @@ from typing import Iterable
|
|||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
|
|
||||||
from funasr.text.abs_tokenizer import AbsTokenizer
|
from funasr.tokenizer.abs_tokenizer import AbsTokenizer
|
||||||
from funasr.text.char_tokenizer import CharTokenizer
|
from funasr.tokenizer.char_tokenizer import CharTokenizer
|
||||||
from funasr.text.phoneme_tokenizer import PhonemeTokenizer
|
from funasr.tokenizer.phoneme_tokenizer import PhonemeTokenizer
|
||||||
from funasr.text.sentencepiece_tokenizer import SentencepiecesTokenizer
|
from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer
|
||||||
from funasr.text.word_tokenizer import WordTokenizer
|
from funasr.tokenizer.word_tokenizer import WordTokenizer
|
||||||
|
|
||||||
|
|
||||||
def build_tokenizer(
|
def build_tokenizer(
|
||||||
@ -5,7 +5,7 @@ from typing import Union
|
|||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
|
||||||
from funasr.text.abs_tokenizer import AbsTokenizer
|
from funasr.tokenizer.abs_tokenizer import AbsTokenizer
|
||||||
|
|
||||||
|
|
||||||
class CharTokenizer(AbsTokenizer):
|
class CharTokenizer(AbsTokenizer):
|
||||||
@ -10,7 +10,7 @@ import warnings
|
|||||||
# import g2p_en
|
# import g2p_en
|
||||||
import jamo
|
import jamo
|
||||||
|
|
||||||
from funasr.text.abs_tokenizer import AbsTokenizer
|
from funasr.tokenizer.abs_tokenizer import AbsTokenizer
|
||||||
|
|
||||||
|
|
||||||
g2p_choices = [
|
g2p_choices = [
|
||||||
@ -107,7 +107,7 @@ def pyopenjtalk_g2p_prosody(text: str, drop_unvoiced_vowels: bool = True) -> Lis
|
|||||||
List[str]: List of phoneme + prosody symbols.
|
List[str]: List of phoneme + prosody symbols.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> from funasr.text.phoneme_tokenizer import pyopenjtalk_g2p_prosody
|
>>> from funasr.tokenizer.phoneme_tokenizer import pyopenjtalk_g2p_prosody
|
||||||
>>> pyopenjtalk_g2p_prosody("こんにちは。")
|
>>> pyopenjtalk_g2p_prosody("こんにちは。")
|
||||||
['^', 'k', 'o', '[', 'N', 'n', 'i', 'ch', 'i', 'w', 'a', '$']
|
['^', 'k', 'o', '[', 'N', 'n', 'i', 'ch', 'i', 'w', 'a', '$']
|
||||||
|
|
||||||
@ -5,7 +5,7 @@ from typing import Union
|
|||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
|
|
||||||
from funasr.text.abs_tokenizer import AbsTokenizer
|
from funasr.tokenizer.abs_tokenizer import AbsTokenizer
|
||||||
|
|
||||||
|
|
||||||
class SentencepiecesTokenizer(AbsTokenizer):
|
class SentencepiecesTokenizer(AbsTokenizer):
|
||||||
@ -5,7 +5,7 @@ from typing import Union
|
|||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
|
||||||
from funasr.text.abs_tokenizer import AbsTokenizer
|
from funasr.tokenizer.abs_tokenizer import AbsTokenizer
|
||||||
|
|
||||||
|
|
||||||
class WordTokenizer(AbsTokenizer):
|
class WordTokenizer(AbsTokenizer):
|
||||||
Loading…
Reference in New Issue
Block a user