Merge pull request #74 from alibaba-damo-academy/dev_gzf

Dev gzf
Lizerui9926 authored 2023-02-08 19:13:57 +08:00, committed by GitHub
commit bcf6be4c90
24 changed files with 2373 additions and 1 deletion


@ -210,9 +210,18 @@ def inference_launch(**kwargs):
elif mode == "uniasr":
from funasr.bin.asr_inference_uniasr import inference_modelscope
return inference_modelscope(**kwargs)
elif mode == "uniasr_vad":
from funasr.bin.asr_inference_uniasr_vad import inference_modelscope
return inference_modelscope(**kwargs)
elif mode == "paraformer":
from funasr.bin.asr_inference_paraformer import inference_modelscope
return inference_modelscope(**kwargs)
elif mode == "paraformer_vad":
from funasr.bin.asr_inference_paraformer_vad import inference_modelscope
return inference_modelscope(**kwargs)
elif mode == "paraformer_punc":
logging.info("Unknown decoding mode: {}".format(mode))
return None
elif mode == "paraformer_vad_punc":
from funasr.bin.asr_inference_paraformer_vad_punc import inference_modelscope
return inference_modelscope(**kwargs)
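For orientation, a minimal, hedged sketch of how this dispatcher is typically driven: the `mode` string selects which backend's `inference_modelscope()` builds the pipeline, and the returned callable is then fed a data list. The import path `funasr.bin.asr_inference_launch`, the config/model paths, and the scp triple below are assumptions for illustration, not part of this diff.

```python
# Hedged sketch: dispatch on `mode`, then call the returned pipeline on a wav.scp.
from funasr.bin.asr_inference_launch import inference_launch  # assumed module path

pipeline = inference_launch(
    mode="uniasr",
    maxlenratio=0.0, minlenratio=0.0, batch_size=1, beam_size=10, ngpu=0,
    ctc_weight=0.0, lm_weight=0.0, penalty=0.0, log_level="INFO",
    asr_train_config="exp/config.yaml",   # placeholder paths
    asr_model_file="exp/model.pb",
    output_dir="exp/decode",
)
if pipeline is not None:  # unsupported modes log a message and return None
    results = pipeline([("wav.scp", "speech", "sound")], None)
```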


@ -0,0 +1,521 @@
#!/usr/bin/env python3
import json
import argparse
import logging
import sys
import time
from pathlib import Path
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
from typing import Any
from typing import List
import math
import numpy as np
import torch
from typeguard import check_argument_types
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
from funasr.modules.beam_search.beam_search import Hypothesis
from funasr.modules.scorers.ctc import CTCPrefixScorer
from funasr.modules.scorers.length_bonus import LengthBonus
from funasr.modules.subsampling import TooShortUttError
from funasr.tasks.asr import ASRTaskParaformer as ASRTask
from funasr.tasks.lm import LMTask
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tasks.vad import VADTask
from funasr.utils.timestamp_tools import time_stamp_lfr6
from funasr.bin.punctuation_infer import Text2Punc
from funasr.bin.asr_inference_paraformer_vad_punc import Speech2Text
from funasr.bin.asr_inference_paraformer_vad_punc import Speech2VadSegment
def inference(
maxlenratio: float,
minlenratio: float,
batch_size: int,
beam_size: int,
ngpu: int,
ctc_weight: float,
lm_weight: float,
penalty: float,
log_level: Union[int, str],
data_path_and_name_and_type,
asr_train_config: Optional[str],
asr_model_file: Optional[str],
cmvn_file: Optional[str] = None,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
lm_train_config: Optional[str] = None,
lm_file: Optional[str] = None,
token_type: Optional[str] = None,
key_file: Optional[str] = None,
word_lm_train_config: Optional[str] = None,
bpemodel: Optional[str] = None,
allow_variable_data_keys: bool = False,
streaming: bool = False,
output_dir: Optional[str] = None,
dtype: str = "float32",
seed: int = 0,
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
vad_infer_config: Optional[str] = None,
vad_model_file: Optional[str] = None,
vad_cmvn_file: Optional[str] = None,
time_stamp_writer: bool = False,
punc_infer_config: Optional[str] = None,
punc_model_file: Optional[str] = None,
**kwargs,
):
inference_pipeline = inference_modelscope(
maxlenratio=maxlenratio,
minlenratio=minlenratio,
batch_size=batch_size,
beam_size=beam_size,
ngpu=ngpu,
ctc_weight=ctc_weight,
lm_weight=lm_weight,
penalty=penalty,
log_level=log_level,
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
cmvn_file=cmvn_file,
raw_inputs=raw_inputs,
lm_train_config=lm_train_config,
lm_file=lm_file,
token_type=token_type,
key_file=key_file,
word_lm_train_config=word_lm_train_config,
bpemodel=bpemodel,
allow_variable_data_keys=allow_variable_data_keys,
streaming=streaming,
output_dir=output_dir,
dtype=dtype,
seed=seed,
ngram_weight=ngram_weight,
nbest=nbest,
num_workers=num_workers,
vad_infer_config=vad_infer_config,
vad_model_file=vad_model_file,
vad_cmvn_file=vad_cmvn_file,
time_stamp_writer=time_stamp_writer,
punc_infer_config=punc_infer_config,
punc_model_file=punc_model_file,
**kwargs,
)
return inference_pipeline(data_path_and_name_and_type, raw_inputs)
def inference_modelscope(
maxlenratio: float,
minlenratio: float,
batch_size: int,
beam_size: int,
ngpu: int,
ctc_weight: float,
lm_weight: float,
penalty: float,
log_level: Union[int, str],
# data_path_and_name_and_type,
asr_train_config: Optional[str],
asr_model_file: Optional[str],
cmvn_file: Optional[str] = None,
lm_train_config: Optional[str] = None,
lm_file: Optional[str] = None,
token_type: Optional[str] = None,
key_file: Optional[str] = None,
word_lm_train_config: Optional[str] = None,
bpemodel: Optional[str] = None,
allow_variable_data_keys: bool = False,
output_dir: Optional[str] = None,
dtype: str = "float32",
seed: int = 0,
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
vad_infer_config: Optional[str] = None,
vad_model_file: Optional[str] = None,
vad_cmvn_file: Optional[str] = None,
time_stamp_writer: bool = True,
punc_infer_config: Optional[str] = None,
punc_model_file: Optional[str] = None,
outputs_dict: Optional[bool] = True,
param_dict: dict = None,
**kwargs,
):
assert check_argument_types()
if word_lm_train_config is not None:
raise NotImplementedError("Word LM is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
# 1. Set random-seed
set_all_random_seed(seed)
# 2. Build speech2vadsegment
speech2vadsegment_kwargs = dict(
vad_infer_config=vad_infer_config,
vad_model_file=vad_model_file,
vad_cmvn_file=vad_cmvn_file,
device=device,
dtype=dtype,
)
# logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
# 3. Build speech2text
speech2text_kwargs = dict(
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
cmvn_file=cmvn_file,
lm_train_config=lm_train_config,
lm_file=lm_file,
token_type=token_type,
bpemodel=bpemodel,
device=device,
maxlenratio=maxlenratio,
minlenratio=minlenratio,
dtype=dtype,
beam_size=beam_size,
ctc_weight=ctc_weight,
lm_weight=lm_weight,
ngram_weight=ngram_weight,
penalty=penalty,
nbest=nbest,
)
speech2text = Speech2Text(**speech2text_kwargs)
text2punc = None
if punc_model_file is not None:
text2punc = Text2Punc(punc_infer_config, punc_model_file, device=device, dtype=dtype)
if output_dir is not None:
writer = DatadirWriter(output_dir)
ibest_writer = writer[f"1best_recog"]
ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)
def _forward(data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
fs: dict = None,
param_dict: dict = None,
):
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
loader = ASRTask.build_streaming_iterator(
data_path_and_name_and_type,
dtype=dtype,
fs=fs,
batch_size=1,
key_file=key_file,
num_workers=num_workers,
preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
finish_count = 0
file_count = 1
lfr_factor = 6
# 7. Start for-loop
asr_result_list = []
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
writer = None
if output_path is not None:
writer = DatadirWriter(output_path)
ibest_writer = writer[f"1best_recog"]
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
vad_results = speech2vadsegment(**batch)
fbanks, vadsegments = vad_results[0], vad_results[1]
for i, segments in enumerate(vadsegments):
result_segments = [["", [], [], ]]
for j, segment_idx in enumerate(segments):
bed_idx, end_idx = int(segment_idx[0] / 10), int(segment_idx[1] / 10)
segment = fbanks[:, bed_idx:end_idx, :].to(device)
speech_lengths = torch.Tensor([end_idx - bed_idx]).int().to(device)
batch = {"speech": segment, "speech_lengths": speech_lengths, "begin_time": vadsegments[i][j][0],
"end_time": vadsegments[i][j][1]}
results = speech2text(**batch)
if len(results) < 1:
continue
result_cur = [results[0][:-2]]
if j == 0:
result_segments = result_cur
else:
result_segments = [[result_segments[0][i] + result_cur[0][i] for i in range(len(result_cur[0]))]]
key = keys[0]
result = result_segments[0]
text, token, token_int = result[0], result[1], result[2]
time_stamp = None if len(result) < 4 else result[3]
postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
text_postprocessed = ""
time_stamp_postprocessed = ""
text_postprocessed_punc = postprocessed_result
if len(postprocessed_result) == 3:
text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \
postprocessed_result[1], \
postprocessed_result[2]
text_postprocessed_punc = ""
if len(word_lists) > 0 and text2punc is not None:
text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
item = {'key': key, 'value': text_postprocessed_punc}
if text_postprocessed != "":
item['text_postprocessed'] = text_postprocessed
if time_stamp_postprocessed != "":
item['time_stamp'] = time_stamp_postprocessed
asr_result_list.append(item)
finish_count += 1
# asr_utils.print_progress(finish_count / file_count)
if writer is not None:
# Write the result to each file
ibest_writer["token"][key] = " ".join(token)
ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["vad"][key] = "{}".format(vadsegments)
ibest_writer["text"][key] = text_postprocessed
ibest_writer["text_with_punc"][key] = text_postprocessed_punc
if time_stamp_postprocessed is not None:
ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)
logging.info("decoding, utt: {}, predictions: {}".format(key, text_postprocessed_punc))
return asr_result_list
return _forward
def get_parser():
parser = config_argparse.ArgumentParser(
description="ASR Decoding",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
required=False,
action="append",
)
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--asr_train_config",
type=str,
help="ASR training configuration",
)
group.add_argument(
"--asr_model_file",
type=str,
help="ASR model parameter file",
)
group.add_argument(
"--cmvn_file",
type=str,
help="Global cmvn file",
)
group.add_argument(
"--lm_train_config",
type=str,
help="LM training configuration",
)
group.add_argument(
"--lm_file",
type=str,
help="LM parameter file",
)
group.add_argument(
"--word_lm_train_config",
type=str,
help="Word LM training configuration",
)
group.add_argument(
"--word_lm_file",
type=str,
help="Word LM parameter file",
)
group.add_argument(
"--ngram_file",
type=str,
help="N-gram parameter file",
)
group.add_argument(
"--model_tag",
type=str,
help="Pretrained model tag. If specify this option, *_train_config and "
"*_file will be overwritten",
)
group = parser.add_argument_group("Beam-search related")
group.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
group.add_argument("--beam_size", type=int, default=20, help="Beam size")
group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
group.add_argument(
"--maxlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain max output length. "
"If maxlenratio=0.0 (default), it uses a end-detect "
"function "
"to automatically find maximum hypothesis lengths."
"If maxlenratio<0.0, its absolute value is interpreted"
"as a constant max output length",
)
group.add_argument(
"--minlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain min output length",
)
group.add_argument(
"--ctc_weight",
type=float,
default=0.5,
help="CTC weight in joint decoding",
)
group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
group.add_argument("--streaming", type=str2bool, default=False)
group.add_argument("--time_stamp_writer", type=str2bool, default=False)
group.add_argument(
"--frontend_conf",
default=None,
help="",
)
group.add_argument("--raw_inputs", type=list, default=None)
# example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
group = parser.add_argument_group("Text converter related")
group.add_argument(
"--token_type",
type=str_or_none,
default=None,
choices=["char", "bpe", None],
help="The token type for ASR model. "
"If not given, refers from the training args",
)
group.add_argument(
"--bpemodel",
type=str_or_none,
default=None,
help="The model path of sentencepiece. "
"If not given, refers from the training args",
)
group.add_argument(
"--vad_infer_config",
type=str,
help="VAD infer configuration",
)
group.add_argument(
"--vad_model_file",
type=str,
help="VAD model parameter file",
)
group.add_argument(
"--vad_cmvn_file",
type=str,
help="vad, Global cmvn file",
)
group.add_argument(
"--punc_infer_config",
type=str,
help="VAD infer configuration",
)
group.add_argument(
"--punc_model_file",
type=str,
help="VAD model parameter file",
)
return parser
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
kwargs = vars(args)
kwargs.pop("config", None)
inference(**kwargs)
if __name__ == "__main__":
main()
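A brief, hedged usage sketch for the file above: `inference_modelscope(...)` returns the inner `_forward` callable, which accepts either a `data_path_and_name_and_type` list or a raw waveform via `raw_inputs`. The diff does not show this file's path, so the import is left as a comment, and every model/config path is a placeholder.

```python
# Hedged sketch (not part of the diff). Import inference_modelscope from wherever this
# file lives in your checkout; all paths below are placeholders.
import soundfile
# from funasr.bin.<this_module> import inference_modelscope

pipeline = inference_modelscope(
    maxlenratio=0.0, minlenratio=0.0, batch_size=1, beam_size=1, ngpu=0,
    ctc_weight=0.0, lm_weight=0.0, penalty=0.0, log_level="INFO",
    asr_train_config="exp/asr/config.yaml", asr_model_file="exp/asr/model.pb",
    cmvn_file="exp/asr/am.mvn",
    vad_infer_config="exp/vad/vad.yaml", vad_model_file="exp/vad/vad.pb",
    vad_cmvn_file="exp/vad/vad.mvn",
    punc_infer_config="exp/punc/punc.yaml", punc_model_file="exp/punc/punc.pb",
)

waveform, _ = soundfile.read("speech.wav")   # 16 kHz mono assumed; resampling is up to the caller
results = pipeline(None, raw_inputs=waveform)
# -> [{'key': ..., 'value': '<punctuated text>', 'time_stamp': ...}]
```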


@ -0,0 +1,680 @@
#!/usr/bin/env python3
import argparse
import logging
import sys
from pathlib import Path
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
from typing import Any
import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.modules.beam_search.beam_search import BeamSearchScama as BeamSearch
from funasr.modules.beam_search.beam_search import Hypothesis
from funasr.modules.scorers.ctc import CTCPrefixScorer
from funasr.modules.scorers.length_bonus import LengthBonus
from funasr.modules.subsampling import TooShortUttError
from funasr.tasks.asr import ASRTaskUniASR as ASRTask
from funasr.tasks.lm import LMTask
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
header_colors = '\033[95m'
end_colors = '\033[0m'
class Speech2Text:
"""Speech2Text class
Examples:
>>> import soundfile
>>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
>>> audio, rate = soundfile.read("speech.wav")
>>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...]
"""
def __init__(
self,
asr_train_config: Union[Path, str] = None,
asr_model_file: Union[Path, str] = None,
cmvn_file: Union[Path, str] = None,
lm_train_config: Union[Path, str] = None,
lm_file: Union[Path, str] = None,
token_type: str = None,
bpemodel: str = None,
device: str = "cpu",
maxlenratio: float = 0.0,
minlenratio: float = 0.0,
dtype: str = "float32",
beam_size: int = 20,
ctc_weight: float = 0.5,
lm_weight: float = 1.0,
ngram_weight: float = 0.9,
penalty: float = 0.0,
nbest: int = 1,
token_num_relax: int = 1,
decoding_ind: int = 0,
decoding_mode: str = "model1",
frontend_conf: dict = None,
**kwargs,
):
assert check_argument_types()
# 1. Build ASR model
scorers = {}
asr_model, asr_train_args = ASRTask.build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, device
)
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
logging.info("asr_train_args: {}".format(asr_train_args))
asr_model.to(dtype=getattr(torch, dtype)).eval()
if decoding_mode == "model1":
decoder = asr_model.decoder
else:
decoder = asr_model.decoder2
if asr_model.ctc is not None:
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
scorers.update(
ctc=ctc
)
token_list = asr_model.token_list
scorers.update(
decoder=decoder,
length_bonus=LengthBonus(len(token_list)),
)
# 2. Build Language model
if lm_train_config is not None:
lm, lm_train_args = LMTask.build_model_from_file(
lm_train_config, lm_file, device
)
scorers["lm"] = lm.lm
# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
# 4. Build BeamSearch object
# transducer is not supported now
beam_search_transducer = None
weights = dict(
decoder=1.0 - ctc_weight,
ctc=ctc_weight,
lm=lm_weight,
ngram=ngram_weight,
length_bonus=penalty,
)
beam_search = BeamSearch(
beam_size=beam_size,
weights=weights,
scorers=scorers,
sos=asr_model.sos,
eos=asr_model.eos,
vocab_size=len(token_list),
token_list=token_list,
pre_beam_score_key=None if ctc_weight == 1.0 else "full",
)
beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
for scorer in scorers.values():
if isinstance(scorer, torch.nn.Module):
scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
# logging.info(f"Beam_search: {beam_search}")
logging.info(f"Decoding device={device}, dtype={dtype}")
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
if token_type is None:
token_type = asr_train_args.token_type
if bpemodel is None:
bpemodel = asr_train_args.bpemodel
if token_type is None:
tokenizer = None
elif token_type == "bpe":
if bpemodel is not None:
tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
else:
tokenizer = None
else:
tokenizer = build_tokenizer(token_type=token_type)
converter = TokenIDConverter(token_list=token_list)
logging.info(f"Text tokenizer: {tokenizer}")
self.asr_model = asr_model
self.asr_train_args = asr_train_args
self.converter = converter
self.tokenizer = tokenizer
self.beam_search = beam_search
self.beam_search_transducer = beam_search_transducer
self.maxlenratio = maxlenratio
self.minlenratio = minlenratio
self.device = device
self.dtype = dtype
self.nbest = nbest
self.token_num_relax = token_num_relax
self.decoding_ind = decoding_ind
self.decoding_mode = decoding_mode
self.frontend = frontend
@torch.no_grad()
def __call__(
self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
) -> List[
Tuple[
Optional[str],
List[str],
List[int],
Union[Hypothesis],
]
]:
"""Inference
Args:
speech: Input speech data
Returns:
text, token, token_int, hyp
"""
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
if self.frontend is not None:
feats, feats_len = self.frontend.forward(speech, speech_lengths)
feats = to_device(feats, device=self.device)
feats_len = feats_len.int()
self.asr_model.frontend = None
else:
feats = speech
feats_len = speech_lengths
lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
feats_raw = feats.clone().to(self.device)
batch = {"speech": feats, "speech_lengths": feats_len}
# a. To device
batch = to_device(batch, device=self.device)
# b. Forward Encoder
_, enc, enc_len = self.asr_model.encode(**batch, ind=self.decoding_ind)
if isinstance(enc, tuple):
enc = enc[0]
assert len(enc) == 1, len(enc)
if self.decoding_mode == "model1":
predictor_outs = self.asr_model.calc_predictor_mask(enc, enc_len)
else:
enc, enc_len = self.asr_model.encode2(enc, enc_len, feats_raw, feats_len, ind=self.decoding_ind)
predictor_outs = self.asr_model.calc_predictor_mask2(enc, enc_len)
scama_mask = predictor_outs[4]
pre_token_length = predictor_outs[1]
pre_acoustic_embeds = predictor_outs[0]
maxlen = pre_token_length.sum().item() + self.token_num_relax
minlen = max(0, pre_token_length.sum().item() - self.token_num_relax)
# c. Pass the encoder result to the beam search
nbest_hyps = self.beam_search(
x=enc[0], scama_mask=scama_mask, pre_acoustic_embeds=pre_acoustic_embeds, maxlenratio=self.maxlenratio,
minlenratio=self.minlenratio, maxlen=int(maxlen), minlen=int(minlen),
)
nbest_hyps = nbest_hyps[: self.nbest]
results = []
for hyp in nbest_hyps:
assert isinstance(hyp, (Hypothesis)), type(hyp)
# remove sos/eos and get results
last_pos = -1
if isinstance(hyp.yseq, list):
token_int = hyp.yseq[1:last_pos]
else:
token_int = hyp.yseq[1:last_pos].tolist()
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x != 0, token_int))
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
if self.tokenizer is not None:
text = self.tokenizer.tokens2text(token)
else:
text = None
results.append((text, token, token_int, hyp))
assert check_return_type(results)
return results
def inference(
maxlenratio: float,
minlenratio: float,
batch_size: int,
beam_size: int,
ngpu: int,
ctc_weight: float,
lm_weight: float,
penalty: float,
log_level: Union[int, str],
data_path_and_name_and_type,
asr_train_config: Optional[str],
asr_model_file: Optional[str],
ngram_file: Optional[str] = None,
cmvn_file: Optional[str] = None,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
lm_train_config: Optional[str] = None,
lm_file: Optional[str] = None,
token_type: Optional[str] = None,
key_file: Optional[str] = None,
word_lm_train_config: Optional[str] = None,
bpemodel: Optional[str] = None,
allow_variable_data_keys: bool = False,
streaming: bool = False,
output_dir: Optional[str] = None,
dtype: str = "float32",
seed: int = 0,
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
token_num_relax: int = 1,
decoding_ind: int = 0,
decoding_mode: str = "model1",
**kwargs,
):
inference_pipeline = inference_modelscope(
maxlenratio=maxlenratio,
minlenratio=minlenratio,
batch_size=batch_size,
beam_size=beam_size,
ngpu=ngpu,
ctc_weight=ctc_weight,
lm_weight=lm_weight,
penalty=penalty,
log_level=log_level,
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
cmvn_file=cmvn_file,
raw_inputs=raw_inputs,
lm_train_config=lm_train_config,
lm_file=lm_file,
token_type=token_type,
key_file=key_file,
word_lm_train_config=word_lm_train_config,
bpemodel=bpemodel,
allow_variable_data_keys=allow_variable_data_keys,
streaming=streaming,
output_dir=output_dir,
dtype=dtype,
seed=seed,
ngram_weight=ngram_weight,
ngram_file=ngram_file,
nbest=nbest,
num_workers=num_workers,
token_num_relax=token_num_relax,
decoding_ind=decoding_ind,
decoding_mode=decoding_mode,
**kwargs,
)
return inference_pipeline(data_path_and_name_and_type, raw_inputs)
def inference_modelscope(
maxlenratio: float,
minlenratio: float,
batch_size: int,
beam_size: int,
ngpu: int,
ctc_weight: float,
lm_weight: float,
penalty: float,
log_level: Union[int, str],
# data_path_and_name_and_type,
asr_train_config: Optional[str],
asr_model_file: Optional[str],
ngram_file: Optional[str] = None,
cmvn_file: Optional[str] = None,
# raw_inputs: Union[np.ndarray, torch.Tensor] = None,
lm_train_config: Optional[str] = None,
lm_file: Optional[str] = None,
token_type: Optional[str] = None,
key_file: Optional[str] = None,
word_lm_train_config: Optional[str] = None,
bpemodel: Optional[str] = None,
allow_variable_data_keys: bool = False,
streaming: bool = False,
output_dir: Optional[str] = None,
dtype: str = "float32",
seed: int = 0,
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
token_num_relax: int = 1,
decoding_ind: int = 0,
decoding_mode: str = "model1",
param_dict: dict = None,
**kwargs,
):
assert check_argument_types()
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
if word_lm_train_config is not None:
raise NotImplementedError("Word LM is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
# 1. Set random-seed
set_all_random_seed(seed)
# 2. Build speech2text
speech2text_kwargs = dict(
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
cmvn_file=cmvn_file,
lm_train_config=lm_train_config,
lm_file=lm_file,
ngram_file=ngram_file,
token_type=token_type,
bpemodel=bpemodel,
device=device,
maxlenratio=maxlenratio,
minlenratio=minlenratio,
dtype=dtype,
beam_size=beam_size,
ctc_weight=ctc_weight,
lm_weight=lm_weight,
ngram_weight=ngram_weight,
penalty=penalty,
nbest=nbest,
streaming=streaming,
token_num_relax=token_num_relax,
decoding_ind=decoding_ind,
decoding_mode=decoding_mode,
)
speech2text = Speech2Text(**speech2text_kwargs)
def _forward(data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
fs: dict = None,
param_dict: dict = None,
):
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
loader = ASRTask.build_streaming_iterator(
data_path_and_name_and_type,
dtype=dtype,
fs=fs,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
finish_count = 0
file_count = 1
# 7. Start for-loop
# FIXME(kamo): The output format should be discussed about
asr_result_list = []
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
if output_path is not None:
writer = DatadirWriter(output_path)
else:
writer = None
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
#batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
# N-best list of (text, token, token_int, hyp_object)
try:
results = speech2text(**batch)
except TooShortUttError as e:
logging.warning(f"Utterance {keys} {e}")
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
results = [[" ", ["sil"], [2], hyp]] * nbest
# Only supporting batch_size==1
key = keys[0]
logging.info(f"Utterance: {key}")
for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
# Create a directory: outdir/{n}best_recog
if writer is not None:
ibest_writer = writer[f"{n}best_recog"]
# Write the result to each file
ibest_writer["token"][key] = " ".join(token)
# ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["score"][key] = str(hyp.score)
if text is not None:
text_postprocessed = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
asr_result_list.append(item)
finish_count += 1
asr_utils.print_progress(finish_count / file_count)
if writer is not None:
ibest_writer["text"][key] = text
return asr_result_list
return _forward
def get_parser():
parser = config_argparse.ArgumentParser(
description="ASR Decoding",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
required=False,
action="append",
)
group.add_argument("--raw_inputs", type=list, default=None)
# example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--asr_train_config",
type=str,
help="ASR training configuration",
)
group.add_argument(
"--asr_model_file",
type=str,
help="ASR model parameter file",
)
group.add_argument(
"--cmvn_file",
type=str,
help="Global cmvn file",
)
group.add_argument(
"--lm_train_config",
type=str,
help="LM training configuration",
)
group.add_argument(
"--lm_file",
type=str,
help="LM parameter file",
)
group.add_argument(
"--word_lm_train_config",
type=str,
help="Word LM training configuration",
)
group.add_argument(
"--word_lm_file",
type=str,
help="Word LM parameter file",
)
group.add_argument(
"--ngram_file",
type=str,
help="N-gram parameter file",
)
group.add_argument(
"--model_tag",
type=str,
help="Pretrained model tag. If specify this option, *_train_config and "
"*_file will be overwritten",
)
group = parser.add_argument_group("Beam-search related")
group.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
group.add_argument("--beam_size", type=int, default=20, help="Beam size")
group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
group.add_argument(
"--maxlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain max output length. "
"If maxlenratio=0.0 (default), it uses a end-detect "
"function "
"to automatically find maximum hypothesis lengths."
"If maxlenratio<0.0, its absolute value is interpreted"
"as a constant max output length",
)
group.add_argument(
"--minlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain min output length",
)
group.add_argument(
"--ctc_weight",
type=float,
default=0.5,
help="CTC weight in joint decoding",
)
group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
group.add_argument("--streaming", type=str2bool, default=False)
group = parser.add_argument_group("Text converter related")
group.add_argument(
"--token_type",
type=str_or_none,
default=None,
choices=["char", "bpe", None],
help="The token type for ASR model. "
"If not given, refers from the training args",
)
group.add_argument(
"--bpemodel",
type=str_or_none,
default=None,
help="The model path of sentencepiece. "
"If not given, refers from the training args",
)
group.add_argument("--token_num_relax", type=int, default=1, help="")
group.add_argument("--decoding_ind", type=int, default=0, help="")
group.add_argument("--decoding_mode", type=str, default="model1", help="")
group.add_argument(
"--ctc_weight2",
type=float,
default=0.0,
help="CTC weight in joint decoding",
)
return parser
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
kwargs = vars(args)
kwargs.pop("config", None)
inference(**kwargs)
if __name__ == "__main__":
main()

funasr/export/README.md (new file, +50 lines)

@ -0,0 +1,50 @@
## Environments
- funasr 0.1.7
- python 3.7
- torch 1.11.0
- modelscope 1.2.0
## Install modelscope and funasr
The installation is the same as [funasr](../../README.md)
## Export ONNX format model
Export model from modelscope
```python
from funasr.export.export_model import ASRModelExportParaformer
output_dir = "../export" # onnx/torchscripts model save path
export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=True)
export_model.export_from_modelscope('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
```
Export model from local path
```python
from funasr.export.export_model import ASRModelExportParaformer
output_dir = "../export" # onnx/torchscripts model save path
export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=True)
export_model.export_from_local('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
```
## Export TorchScript format model
Export model from modelscope
```python
from funasr.export.export_model import ASRModelExportParaformer
output_dir = "../export" # onnx/torchscripts model save path
export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=False)
export_model.export_from_modelscope('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
```
Export model from local path
```python
from funasr.export.export_model import ASRModelExportParaformer
output_dir = "../export" # onnx/torchscripts model save path
export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=False)
export_model.export_from_local('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
```
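Once exported, the ONNX model can be run directly with onnxruntime. The sketch below is illustrative: onnxruntime is not a stated dependency of this repo, the output path is derived from `cache_dir`/`tag_name` above, the input/output names follow `get_input_names()`/`get_output_names()` of the Paraformer export wrapper, and the random features merely stand in for real LFR+CMVN fbank features (dim 560).

```python
import numpy as np
import onnxruntime as ort  # assumed to be installed separately

model_path = "../export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/model.onnx"
sess = ort.InferenceSession(model_path)

feats = np.random.randn(1, 100, 560).astype(np.float32)   # placeholder LFR+CMVN fbank features
feats_len = np.array([100], dtype=np.int32)
logits, token_num = sess.run(["logits", "token_num"],
                             {"speech": feats, "speech_lengths": feats_len})
token_ids = logits.argmax(axis=-1)[0][: int(token_num[0])]  # greedy ids; map to text with the model's vocab
```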

@ -0,0 +1,120 @@
from typing import Union, Dict
from pathlib import Path
from typeguard import check_argument_types
import os
import logging
import torch
from funasr.bin.asr_inference_paraformer import Speech2Text
from funasr.export.models import get_model
class ASRModelExportParaformer:
def __init__(self, cache_dir: Union[Path, str] = None, onnx: bool = True):
assert check_argument_types()
if cache_dir is None:
cache_dir = Path.home() / "cache" / "export"
self.cache_dir = Path(cache_dir)
self.export_config = dict(
feats_dim=560,
onnx=False,
)
logging.info("output dir: {}".format(self.cache_dir))
self.onnx = onnx
def export(
self,
model: Speech2Text,
tag_name: str = None,
verbose: bool = False,
):
export_dir = self.cache_dir / tag_name.replace(' ', '-')
os.makedirs(export_dir, exist_ok=True)
# export encoder1
self.export_config["model_name"] = "model"
model = get_model(
model,
self.export_config,
)
if self.onnx:
self._export_onnx(model, verbose, export_dir)
else:
self._export_torchscripts(model, verbose, export_dir)
logging.info("output dir: {}".format(export_dir))
def _export_torchscripts(self, model, verbose, path, enc_size=None):
if enc_size:
dummy_input = model.get_dummy_inputs(enc_size)
else:
dummy_input = model.get_dummy_inputs_txt()
# model_script = torch.jit.script(model)
model_script = torch.jit.trace(model, dummy_input)
model_script.save(os.path.join(path, f'{model.model_name}.torchscripts'))
def export_from_modelscope(
self,
tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
):
from funasr.tasks.asr import ASRTaskParaformer as ASRTask
from modelscope.hub.snapshot_download import snapshot_download
model_dir = snapshot_download(tag_name, cache_dir=self.cache_dir)
asr_train_config = os.path.join(model_dir, 'config.yaml')
asr_model_file = os.path.join(model_dir, 'model.pb')
cmvn_file = os.path.join(model_dir, 'am.mvn')
model, asr_train_args = ASRTask.build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, 'cpu'
)
self.export(model, tag_name)
def export_from_local(
self,
tag_name: str = '/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
):
from funasr.tasks.asr import ASRTaskParaformer as ASRTask
model_dir = tag_name
asr_train_config = os.path.join(model_dir, 'config.yaml')
asr_model_file = os.path.join(model_dir, 'model.pb')
cmvn_file = os.path.join(model_dir, 'am.mvn')
model, asr_train_args = ASRTask.build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, 'cpu'
)
self.export(model, tag_name)
def _export_onnx(self, model, verbose, path, enc_size=None):
if enc_size:
dummy_input = model.get_dummy_inputs(enc_size)
else:
dummy_input = model.get_dummy_inputs()
# model_script = torch.jit.script(model)
model_script = model #torch.jit.trace(model)
torch.onnx.export(
model_script,
dummy_input,
os.path.join(path, f'{model.model_name}.onnx'),
verbose=verbose,
opset_version=12,
input_names=model.get_input_names(),
output_names=model.get_output_names(),
dynamic_axes=model.get_dynamic_axes()
)
if __name__ == '__main__':
output_dir = "../export"
export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=False)
export_model.export_from_modelscope('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
# export_model.export_from_local('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
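For completeness, a hedged sketch of consuming the TorchScript artifact written by `_export_torchscripts()`. The path mirrors `cache_dir / tag_name` used in `export()`, and the dummy tensors only illustrate the expected shapes (`feats_dim=560`); whether arbitrary lengths work depends on how the model was traced.

```python
# Hedged sketch: load and run the exported TorchScript model; calling convention matches
# Paraformer.forward(speech, speech_lengths) in funasr/export/models/e2e_asr_paraformer.py.
import torch

ts_path = "../export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/model.torchscripts"
ts_model = torch.jit.load(ts_path)
feats = torch.randn(1, 100, 560)                   # placeholder LFR+CMVN fbank features
feats_len = torch.tensor([100], dtype=torch.int32)
log_probs, token_num = ts_model(feats, feats_len)
```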


@ -0,0 +1,91 @@
# from .ctc import CTC
# from .joint_network import JointNetwork
#
# # encoder
# from espnet2.asr.encoder.rnn_encoder import RNNEncoder as espnetRNNEncoder
# from espnet2.asr.encoder.vgg_rnn_encoder import VGGRNNEncoder as espnetVGGRNNEncoder
# from espnet2.asr.encoder.contextual_block_transformer_encoder import ContextualBlockTransformerEncoder as espnetContextualTransformer
# from espnet2.asr.encoder.contextual_block_conformer_encoder import ContextualBlockConformerEncoder as espnetContextualConformer
# from espnet2.asr.encoder.transformer_encoder import TransformerEncoder as espnetTransformerEncoder
# from espnet2.asr.encoder.conformer_encoder import ConformerEncoder as espnetConformerEncoder
# from funasr.export.models.encoder.rnn import RNNEncoder
# from funasr.export.models.encoders import TransformerEncoder
# from funasr.export.models.encoders import ConformerEncoder
# from funasr.export.models.encoder.contextual_block_xformer import ContextualBlockXformerEncoder
#
# # decoder
# from espnet2.asr.decoder.rnn_decoder import RNNDecoder as espnetRNNDecoder
# from espnet2.asr.transducer.transducer_decoder import TransducerDecoder as espnetTransducerDecoder
# from funasr.export.models.decoder.rnn import (
# RNNDecoder
# )
# from funasr.export.models.decoders import XformerDecoder
# from funasr.export.models.decoders import TransducerDecoder
#
# # lm
# from espnet2.lm.seq_rnn_lm import SequentialRNNLM as espnetSequentialRNNLM
# from espnet2.lm.transformer_lm import TransformerLM as espnetTransformerLM
# from .language_models.seq_rnn import SequentialRNNLM
# from .language_models.transformer import TransformerLM
#
# # frontend
# from espnet2.asr.frontend.s3prl import S3prlFrontend as espnetS3PRLModel
# from .frontends.s3prl import S3PRLModel
#
# from espnet2.asr.encoder.sanm_encoder import SANMEncoder_tf, SANMEncoderChunkOpt_tf
# from espnet_onnx.export.asr.models.encoders.transformer_sanm import TransformerEncoderSANM_tf
# from espnet2.asr.decoder.transformer_decoder import FsmnDecoderSCAMAOpt_tf
# from funasr.export.models.decoders import XformerDecoderSANM
from funasr.models.e2e_asr_paraformer import Paraformer
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
def get_model(model, export_config=None):
if isinstance(model, Paraformer):
return Paraformer_export(model, **export_config)
else:
raise ValueError("Model type {} is not supported for export".format(type(model)))
# def get_encoder(model, frontend, preencoder, predictor=None, export_config=None):
# if isinstance(model, espnetRNNEncoder) or isinstance(model, espnetVGGRNNEncoder):
# return RNNEncoder(model, frontend, preencoder, **export_config)
# elif isinstance(model, espnetContextualTransformer) or isinstance(model, espnetContextualConformer):
# return ContextualBlockXformerEncoder(model, **export_config)
# elif isinstance(model, espnetTransformerEncoder):
# return TransformerEncoder(model, frontend, preencoder, **export_config)
# elif isinstance(model, espnetConformerEncoder):
# return ConformerEncoder(model, frontend, preencoder, **export_config)
# elif isinstance(model, SANMEncoder_tf) or isinstance(model, SANMEncoderChunkOpt_tf):
# return TransformerEncoderSANM_tf(model, frontend, preencoder, predictor, **export_config)
# else:
# raise "The model is not exist!"
#
# def get_decoder(model, export_config):
# if isinstance(model, espnetRNNDecoder):
# return RNNDecoder(model, **export_config)
# elif isinstance(model, espnetTransducerDecoder):
# return TransducerDecoder(model, **export_config)
# elif isinstance(model, FsmnDecoderSCAMAOpt_tf):
# return XformerDecoderSANM(model, **export_config)
# else:
# return XformerDecoder(model, **export_config)
#
#
# def get_lm(model, export_config):
# if isinstance(model, espnetSequentialRNNLM):
# return SequentialRNNLM(model, **export_config)
# elif isinstance(model, espnetTransformerLM):
# return TransformerLM(model, **export_config)
#
#
# def get_frontend_models(model, export_config):
# if isinstance(model, espnetS3PRLModel):
# return S3PRLModel(model, **export_config)
# else:
# return None
#

@ -0,0 +1,159 @@
import os
import torch
import torch.nn as nn
from funasr.export.utils.torch_function import MakePadMask
from funasr.export.utils.torch_function import sequence_mask
from funasr.modules.attention import MultiHeadedAttentionSANMDecoder
from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANMDecoder as MultiHeadedAttentionSANMDecoder_export
from funasr.modules.attention import MultiHeadedAttentionCrossAtt
from funasr.export.models.modules.multihead_att import MultiHeadedAttentionCrossAtt as MultiHeadedAttentionCrossAtt_export
from funasr.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr.export.models.modules.feedforward import PositionwiseFeedForwardDecoderSANM as PositionwiseFeedForwardDecoderSANM_export
from funasr.export.models.modules.decoder_layer import DecoderLayerSANM as DecoderLayerSANM_export
class ParaformerSANMDecoder(nn.Module):
def __init__(self, model,
max_seq_len=512,
model_name='decoder',
onnx: bool = True,):
super().__init__()
# self.embed = model.embed #Embedding(model.embed, max_seq_len)
self.model = model
if onnx:
self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
else:
self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
for i, d in enumerate(self.model.decoders):
if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
d.src_attn = MultiHeadedAttentionCrossAtt_export(d.src_attn)
self.model.decoders[i] = DecoderLayerSANM_export(d)
if self.model.decoders2 is not None:
for i, d in enumerate(self.model.decoders2):
if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
self.model.decoders2[i] = DecoderLayerSANM_export(d)
for i, d in enumerate(self.model.decoders3):
if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
self.model.decoders3[i] = DecoderLayerSANM_export(d)
self.output_layer = model.output_layer
self.after_norm = model.after_norm
self.model_name = model_name
def prepare_mask(self, mask):
mask_3d_btd = mask[:, :, None]
if len(mask.shape) == 2:
mask_4d_bhlt = 1 - mask[:, None, None, :]
elif len(mask.shape) == 3:
mask_4d_bhlt = 1 - mask[:, None, :]
mask_4d_bhlt = mask_4d_bhlt * -10000.0
return mask_3d_btd, mask_4d_bhlt
def forward(
self,
hs_pad: torch.Tensor,
hlens: torch.Tensor,
ys_in_pad: torch.Tensor,
ys_in_lens: torch.Tensor,
):
tgt = ys_in_pad
tgt_mask = self.make_pad_mask(ys_in_lens)
tgt_mask, _ = self.prepare_mask(tgt_mask)
# tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
memory = hs_pad
memory_mask = self.make_pad_mask(hlens)
_, memory_mask = self.prepare_mask(memory_mask)
# memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
x = tgt
x, tgt_mask, memory, memory_mask, _ = self.model.decoders(
x, tgt_mask, memory, memory_mask
)
if self.model.decoders2 is not None:
x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
x, tgt_mask, memory, memory_mask
)
x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
x, tgt_mask, memory, memory_mask
)
x = self.after_norm(x)
x = self.output_layer(x)
return x, ys_in_lens
def get_dummy_inputs(self, enc_size):
tgt = torch.LongTensor([0]).unsqueeze(0)
memory = torch.randn(1, 100, enc_size)
pre_acoustic_embeds = torch.randn(1, 1, enc_size)
cache_num = len(self.model.decoders) + len(self.model.decoders2)
cache = [
torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
for _ in range(cache_num)
]
return (tgt, memory, pre_acoustic_embeds, cache)
def is_optimizable(self):
return True
def get_input_names(self):
cache_num = len(self.model.decoders) + len(self.model.decoders2)
return ['tgt', 'memory', 'pre_acoustic_embeds'] \
+ ['cache_%d' % i for i in range(cache_num)]
def get_output_names(self):
cache_num = len(self.model.decoders) + len(self.model.decoders2)
return ['y'] \
+ ['out_cache_%d' % i for i in range(cache_num)]
def get_dynamic_axes(self):
ret = {
'tgt': {
0: 'tgt_batch',
1: 'tgt_length'
},
'memory': {
0: 'memory_batch',
1: 'memory_length'
},
'pre_acoustic_embeds': {
0: 'acoustic_embeds_batch',
1: 'acoustic_embeds_length',
}
}
cache_num = len(self.model.decoders) + len(self.model.decoders2)
ret.update({
'cache_%d' % d: {
0: 'cache_%d_batch' % d,
2: 'cache_%d_length' % d
}
for d in range(cache_num)
})
return ret
def get_model_config(self, path):
return {
"dec_type": "XformerDecoder",
"model_path": os.path.join(path, f'{self.model_name}.onnx'),
"n_layers": len(self.model.decoders) + len(self.model.decoders2),
"odim": self.model.decoders[0].size
}


@ -0,0 +1,102 @@
import logging
import torch
import torch.nn as nn
from funasr.export.utils.torch_function import MakePadMask
from funasr.export.utils.torch_function import sequence_mask
from funasr.models.encoder.sanm_encoder import SANMEncoder
from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
from funasr.models.predictor.cif import CifPredictorV2
from funasr.export.models.predictor.cif import CifPredictorV2 as CifPredictorV2_export
from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder
from funasr.export.models.decoder.sanm_decoder import ParaformerSANMDecoder as ParaformerSANMDecoder_export
class Paraformer(nn.Module):
"""
Author: Speech Lab, Alibaba Group, China
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2206.08317
"""
def __init__(
self,
model,
max_seq_len=512,
feats_dim=560,
model_name='model',
**kwargs,
):
super().__init__()
onnx = False
if "onnx" in kwargs:
onnx = kwargs["onnx"]
if isinstance(model.encoder, SANMEncoder):
self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
if isinstance(model.predictor, CifPredictorV2):
self.predictor = CifPredictorV2_export(model.predictor)
if isinstance(model.decoder, ParaformerSANMDecoder):
self.decoder = ParaformerSANMDecoder_export(model.decoder, onnx=onnx)
self.feats_dim = feats_dim
self.model_name = model_name
if onnx:
self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
else:
self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
):
# a. To device
batch = {"speech": speech, "speech_lengths": speech_lengths}
# batch = to_device(batch, device=self.device)
enc, enc_len = self.encoder(**batch)
mask = self.make_pad_mask(enc_len)[:, None, :]
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
pre_token_length = pre_token_length.round().long()
decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
decoder_out = torch.log_softmax(decoder_out, dim=-1)
# sample_ids = decoder_out.argmax(dim=-1)
return decoder_out, pre_token_length
def get_dummy_inputs(self):
speech = torch.randn(2, 30, self.feats_dim)
speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
return (speech, speech_lengths)
def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
import numpy as np
fbank = np.loadtxt(txt_file)
fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
return (speech, speech_lengths)
def get_input_names(self):
return ['speech', 'speech_lengths']
def get_output_names(self):
return ['logits', 'token_num']
def get_dynamic_axes(self):
return {
'speech': {
0: 'batch_size',
1: 'feats_length'
},
'speech_lengths': {
0: 'batch_size',
},
'logits': {
0: 'batch_size',
1: 'logits_length'
},
}

@ -0,0 +1,109 @@
import torch
import torch.nn as nn
from funasr.export.utils.torch_function import MakePadMask
from funasr.export.utils.torch_function import sequence_mask
from funasr.modules.attention import MultiHeadedAttentionSANM
from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANM as MultiHeadedAttentionSANM_export
from funasr.export.models.modules.encoder_layer import EncoderLayerSANM as EncoderLayerSANM_export
from funasr.modules.positionwise_feed_forward import PositionwiseFeedForward
from funasr.export.models.modules.feedforward import PositionwiseFeedForward as PositionwiseFeedForward_export
class SANMEncoder(nn.Module):
def __init__(
self,
model,
max_seq_len=512,
feats_dim=560,
model_name='encoder',
onnx: bool = True,
):
super().__init__()
self.embed = model.embed
self.model = model
self.feats_dim = feats_dim
self._output_size = model._output_size
if onnx:
self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
else:
self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
if hasattr(model, 'encoders0'):
for i, d in enumerate(self.model.encoders0):
if isinstance(d.self_attn, MultiHeadedAttentionSANM):
d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
if isinstance(d.feed_forward, PositionwiseFeedForward):
d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
self.model.encoders0[i] = EncoderLayerSANM_export(d)
for i, d in enumerate(self.model.encoders):
if isinstance(d.self_attn, MultiHeadedAttentionSANM):
d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
if isinstance(d.feed_forward, PositionwiseFeedForward):
d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
self.model.encoders[i] = EncoderLayerSANM_export(d)
self.model_name = model_name
self.num_heads = model.encoders[0].self_attn.h
self.hidden_size = model.encoders[0].self_attn.linear_out.out_features
def prepare_mask(self, mask):
mask_3d_btd = mask[:, :, None]
if len(mask.shape) == 2:
mask_4d_bhlt = 1 - mask[:, None, None, :]
elif len(mask.shape) == 3:
mask_4d_bhlt = 1 - mask[:, None, :]
mask_4d_bhlt = mask_4d_bhlt * -10000.0
return mask_3d_btd, mask_4d_bhlt
def forward(self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
):
speech = speech * self._output_size ** 0.5
mask = self.make_pad_mask(speech_lengths)
mask = self.prepare_mask(mask)
if self.embed is None:
xs_pad = speech
else:
xs_pad = self.embed(speech)
encoder_outs = self.model.encoders0(xs_pad, mask)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
encoder_outs = self.model.encoders(xs_pad, mask)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
xs_pad = self.model.after_norm(xs_pad)
return xs_pad, speech_lengths
def get_output_size(self):
return self.model.encoders[0].size
def get_dummy_inputs(self):
feats = torch.randn(1, 100, self.feats_dim)
return (feats)
def get_input_names(self):
return ['feats']
def get_output_names(self):
return ['encoder_out', 'encoder_out_lens', 'predictor_weight']
def get_dynamic_axes(self):
return {
'feats': {
1: 'feats_length'
},
'encoder_out': {
1: 'enc_out_length'
},
'predictor_weight':{
1: 'pre_out_length'
}
}

@ -0,0 +1,43 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
from torch import nn
class DecoderLayerSANM(nn.Module):
def __init__(
self,
model
):
super().__init__()
self.self_attn = model.self_attn
self.src_attn = model.src_attn
self.feed_forward = model.feed_forward
self.norm1 = model.norm1
self.norm2 = model.norm2 if hasattr(model, 'norm2') else None
self.norm3 = model.norm3 if hasattr(model, 'norm3') else None
self.size = model.size
def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
residual = tgt
tgt = self.norm1(tgt)
tgt = self.feed_forward(tgt)
x = tgt
if self.self_attn is not None:
tgt = self.norm2(tgt)
x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
x = residual + x
if self.src_attn is not None:
residual = x
x = self.norm3(x)
x = residual + self.src_attn(x, memory, memory_mask)
return x, tgt_mask, memory, memory_mask, cache


@ -0,0 +1,37 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
from torch import nn
class EncoderLayerSANM(nn.Module):
def __init__(
self,
model,
):
"""Construct an EncoderLayer object."""
super().__init__()
self.self_attn = model.self_attn
self.feed_forward = model.feed_forward
self.norm1 = model.norm1
self.norm2 = model.norm2
self.size = model.size
def forward(self, x, mask):
residual = x
x = self.norm1(x)
x = self.self_attn(x, mask)
if x.size(2) == residual.size(2):
x = x + residual
residual = x
x = self.norm2(x)
x = self.feed_forward(x)
if x.size(2) == residual.size(2):
x = x + residual
return x, mask


@ -0,0 +1,31 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
class PositionwiseFeedForward(nn.Module):
def __init__(self, model):
super().__init__()
self.w_1 = model.w_1
self.w_2 = model.w_2
self.activation = model.activation
def forward(self, x):
x = self.activation(self.w_1(x))
x = self.w_2(x)
return x
class PositionwiseFeedForwardDecoderSANM(nn.Module):
def __init__(self, model):
super().__init__()
self.w_1 = model.w_1
self.w_2 = model.w_2
self.activation = model.activation
self.norm = model.norm
def forward(self, x):
x = self.activation(self.w_1(x))
x = self.w_2(self.norm(x))
return x


@ -0,0 +1,135 @@
import os
import math
import torch
import torch.nn as nn
class MultiHeadedAttentionSANM(nn.Module):
def __init__(self, model):
super().__init__()
self.d_k = model.d_k
self.h = model.h
self.linear_out = model.linear_out
self.linear_q_k_v = model.linear_q_k_v
self.fsmn_block = model.fsmn_block
self.pad_fn = model.pad_fn
self.attn = None
self.all_head_size = self.h * self.d_k
def forward(self, x, mask):
mask_3d_btd, mask_4d_bhlt = mask
q_h, k_h, v_h, v = self.forward_qkv(x)
fsmn_memory = self.forward_fsmn(v, mask_3d_btd)
q_h = q_h * self.d_k**(-0.5)
scores = torch.matmul(q_h, k_h.transpose(-2, -1))
att_outs = self.forward_attention(v_h, scores, mask_4d_bhlt)
return att_outs + fsmn_memory
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.h, self.d_k)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward_qkv(self, x):
q_k_v = self.linear_q_k_v(x)
q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
q_h = self.transpose_for_scores(q)
k_h = self.transpose_for_scores(k)
v_h = self.transpose_for_scores(v)
return q_h, k_h, v_h, v
def forward_fsmn(self, inputs, mask):
# b, t, d = inputs.size()
# mask = torch.reshape(mask, (b, -1, 1))
inputs = inputs * mask
x = inputs.transpose(1, 2)
x = self.pad_fn(x)
x = self.fsmn_block(x)
x = x.transpose(1, 2)
x = x + inputs
x = x * mask
return x
def forward_attention(self, value, scores, mask):
scores = scores + mask
self.attn = torch.softmax(scores, dim=-1)
context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return self.linear_out(context_layer) # (batch, time1, d_model)
class MultiHeadedAttentionSANMDecoder(nn.Module):
def __init__(self, model):
super().__init__()
self.fsmn_block = model.fsmn_block
self.pad_fn = model.pad_fn
self.kernel_size = model.kernel_size
self.attn = None
def forward(self, inputs, mask, cache=None):
# b, t, d = inputs.size()
# mask = torch.reshape(mask, (b, -1, 1))
inputs = inputs * mask
x = inputs.transpose(1, 2)
if cache is None:
x = self.pad_fn(x)
else:
x = torch.cat((cache[:, :, 1:], x), dim=2)
cache = x
x = self.fsmn_block(x)
x = x.transpose(1, 2)
x = x + inputs
x = x * mask
return x, cache
class MultiHeadedAttentionCrossAtt(nn.Module):
def __init__(self, model):
super().__init__()
self.d_k = model.d_k
self.h = model.h
self.linear_q = model.linear_q
self.linear_k_v = model.linear_k_v
self.linear_out = model.linear_out
self.attn = None
self.all_head_size = self.h * self.d_k
def forward(self, x, memory, memory_mask):
q, k, v = self.forward_qkv(x, memory)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, memory_mask)
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.h, self.d_k)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward_qkv(self, x, memory):
q = self.linear_q(x)
k_v = self.linear_k_v(memory)
k, v = torch.split(k_v, int(self.h * self.d_k), dim=-1)
q = self.transpose_for_scores(q)
k = self.transpose_for_scores(k)
v = self.transpose_for_scores(v)
return q, k, v
def forward_attention(self, value, scores, mask):
scores = scores + mask
self.attn = torch.softmax(scores, dim=-1)
context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return self.linear_out(context_layer) # (batch, time1, d_model)
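For clarity, a small sketch of the mask pair that MultiHeadedAttentionSANM.forward expects, built with the same convention as prepare_mask above: a multiplicative (B, T, 1) mask for the FSMN branch and an additive (B, 1, T) mask of -10000.0 on padded positions for the attention scores. The helper name and the exact shape of mask_3d_btd are assumptions based on that convention.

import torch

def build_sanm_masks(lengths: torch.Tensor):
    # lengths: (B,) tensor of valid frame counts
    max_len = int(lengths.max())
    mask = (torch.arange(max_len)[None, :] < lengths[:, None]).float()  # (B, T), 1 = valid
    mask_3d_btd = mask[:, :, None]                       # multiplies the value / FSMN stream
    mask_4d_bhlt = (1.0 - mask[:, None, :]) * -10000.0   # added to the attention scores
    return mask_3d_btd, mask_4d_bhlt

masks = build_sanm_masks(torch.tensor([6, 21]))  # passed as the `mask` tuple argument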

View File

@ -0,0 +1,168 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
from torch import nn
import logging
import numpy as np
def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
if maxlen is None:
maxlen = lengths.max()
row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
matrix = torch.unsqueeze(lengths, dim=-1)
mask = row_vector < matrix
mask = mask.detach()
return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
class CifPredictorV2(nn.Module):
def __init__(self, model):
super().__init__()
self.pad = model.pad
self.cif_conv1d = model.cif_conv1d
self.cif_output = model.cif_output
self.threshold = model.threshold
self.smooth_factor = model.smooth_factor
self.noise_threshold = model.noise_threshold
self.tail_threshold = model.tail_threshold
def forward(self, hidden: torch.Tensor,
mask: torch.Tensor,
):
h = hidden
context = h.transpose(1, 2)
queries = self.pad(context)
output = torch.relu(self.cif_conv1d(queries))
output = output.transpose(1, 2)
output = self.cif_output(output)
alphas = torch.sigmoid(output)
alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
mask = mask.transpose(-1, -2).float()
alphas = alphas * mask
alphas = alphas.squeeze(-1)
token_num = alphas.sum(-1)
acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
return acoustic_embeds, token_num, alphas, cif_peak
def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
b, t, d = hidden.size()
tail_threshold = self.tail_threshold
zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
ones_t = torch.ones_like(zeros_t)
mask_1 = torch.cat([mask, zeros_t], dim=1)
mask_2 = torch.cat([ones_t, mask], dim=1)
mask = mask_2 - mask_1
tail_threshold = mask * tail_threshold
alphas = torch.cat([alphas, tail_threshold], dim=1)
zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
hidden = torch.cat([hidden, zeros], dim=1)
token_num = alphas.sum(dim=-1)
token_num_floor = torch.floor(token_num)
return hidden, alphas, token_num_floor
@torch.jit.script
def cif(hidden, alphas, threshold: float):
batch_size, len_time, hidden_size = hidden.size()
threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
# loop vars
integrate = torch.zeros([batch_size], device=hidden.device)
frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
# intermediate vars along time
list_fires = []
list_frames = []
for t in range(len_time):
alpha = alphas[:, t]
distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate
integrate += alpha
list_fires.append(integrate)
fire_place = integrate >= threshold
integrate = torch.where(fire_place,
integrate - torch.ones([batch_size], device=hidden.device),
integrate)
cur = torch.where(fire_place,
distribution_completion,
alpha)
remainds = alpha - cur
frame += cur[:, None] * hidden[:, t, :]
list_frames.append(frame)
frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
remainds[:, None] * hidden[:, t, :],
frame)
fires = torch.stack(list_fires, 1)
frames = torch.stack(list_frames, 1)
list_ls = []
len_labels = torch.round(alphas.sum(-1)).int()
max_label_len = len_labels.max()
for b in range(batch_size):
fire = fires[b, :]
l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze())
pad_l = torch.zeros([int(max_label_len - l.size(0)), int(hidden_size)], device=hidden.device)
list_ls.append(torch.cat([l, pad_l], 0))
return torch.stack(list_ls, 0), fires
# NOTE: both tests below pass integer arguments to CifPredictorV2, while the wrapper
# above expects a trained predictor module, so they will not run against the class
# defined in this file as written.
def CifPredictorV2_test():
x = torch.rand([2, 21, 2])
x_len = torch.IntTensor([6, 21])
mask = sequence_mask(x_len, maxlen=x.size(1), dtype=x.dtype)
x = x * mask[:, :, None]
predictor_scripts = torch.jit.script(CifPredictorV2(2, 1, 1))
# cif_output, cif_length, alphas, cif_peak = predictor_scripts(x, mask=mask[:, None, :])
predictor_scripts.save('test.pt')
loaded = torch.jit.load('test.pt')
cif_output, cif_length, alphas, cif_peak = loaded(x, mask=mask[:, None, :])
# print(cif_output)
print(predictor_scripts.code)
# predictor = CifPredictorV2(2, 1, 1)
# cif_output, cif_length, alphas, cif_peak = predictor(x, mask=mask[:, None, :])
print(cif_output)
def CifPredictorV2_export_test():
x = torch.rand([2, 21, 2])
x_len = torch.IntTensor([6, 21])
mask = sequence_mask(x_len, maxlen=x.size(1), dtype=x.dtype)
x = x * mask[:, :, None]
# predictor_scripts = torch.jit.script(CifPredictorV2(2, 1, 1))
# cif_output, cif_length, alphas, cif_peak = predictor_scripts(x, mask=mask[:, None, :])
predictor = CifPredictorV2(2, 1, 1)
predictor_trace = torch.jit.trace(predictor, (x, mask[:, None, :]))
predictor_trace.save('test_trace.pt')
loaded = torch.jit.load('test_trace.pt')
x = torch.rand([3, 30, 2])
x_len = torch.IntTensor([6, 20, 30])
mask = sequence_mask(x_len, maxlen=x.size(1), dtype=x.dtype)
x = x * mask[:, :, None]
cif_output, cif_length, alphas, cif_peak = loaded(x, mask=mask[:, None, :])
print(cif_output)
# print(predictor_trace.code)
# predictor = CifPredictorV2(2, 1, 1)
# cif_output, cif_length, alphas, cif_peak = predictor(x, mask=mask[:, None, :])
# print(cif_output)
if __name__ == '__main__':
# CifPredictorV2_test()
CifPredictorV2_export_test()
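To make the firing rule in cif() above concrete, here is a tiny self-contained trace of the integrate-and-fire loop with made-up weights and threshold 1.0; the real function additionally accumulates the weighted hidden vectors into the emitted acoustic embeddings.

weights = [0.4, 0.7, 0.6, 0.5]   # illustrative alphas for one utterance
integrate, fired_at = 0.0, []
for t, alpha in enumerate(weights):
    integrate += alpha
    if integrate >= 1.0:          # a token boundary fires here
        fired_at.append(t)
        integrate -= 1.0          # the overflow seeds the next token
print(fired_at)  # [1, 3] -> two tokens, matching round(sum(weights)) == 2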

View File

@ -0,0 +1,20 @@
import onnxruntime
import numpy as np
if __name__ == '__main__':
onnx_path = "/mnt/workspace/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/model.onnx"
sess = onnxruntime.InferenceSession(onnx_path)
input_name = [nd.name for nd in sess.get_inputs()]
output_name = [nd.name for nd in sess.get_outputs()]
def _get_feed_dict(feats_length):
return {'speech': np.zeros((1, feats_length, 560), dtype=np.float32), 'speech_lengths': np.array([feats_length,], dtype=np.int32)}
def _run(feed_dict):
output = sess.run(output_name, input_feed=feed_dict)
for name, value in zip(output_name, output):
print('{}: {}'.format(name, value.shape))
_run(_get_feed_dict(100))
_run(_get_feed_dict(200))

View File

@ -0,0 +1,17 @@
import torch
import numpy as np
if __name__ == '__main__':
onnx_path = "/mnt/workspace/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/model.torchscripts"
loaded = torch.jit.load(onnx_path)
x = torch.rand([2, 21, 560])
x_len = torch.IntTensor([6, 21])
res = loaded(x, x_len)
print(res[0].size(), res[1])
x = torch.rand([5, 50, 560])
x_len = torch.IntTensor([6, 21, 10, 30, 50])
res = loaded(x, x_len)
print(res[0].size(), res[1])

View File

View File

@ -0,0 +1,80 @@
from typing import Optional
import torch
import torch.nn as nn
import numpy as np
class MakePadMask(nn.Module):
def __init__(self, max_seq_len=512, flip=True):
super().__init__()
if flip:
self.mask_pad = torch.Tensor(1 - np.tri(max_seq_len)).type(torch.bool)
else:
self.mask_pad = torch.Tensor(np.tri(max_seq_len)).type(torch.bool)
def forward(self, lengths, xs=None, length_dim=-1, maxlen=None):
"""Make mask tensor containing indices of padded part.
This implementation creates the same mask tensor as the original make_pad_mask,
but in a form that can be exported to ONNX.
The dimension of xs should be 2 or 3.
"""
if length_dim == 0:
raise ValueError("length_dim cannot be 0: {}".format(length_dim))
if xs is not None and len(xs.shape) == 3:
if length_dim == 1:
lengths = lengths.unsqueeze(1).expand(
*xs.transpose(1, 2).shape[:2])
else:
lengths = lengths.unsqueeze(1).expand(*xs.shape[:2])
if maxlen is not None:
m = maxlen
elif xs is not None:
m = xs.shape[-1]
else:
m = torch.max(lengths)
mask = self.mask_pad[lengths - 1][..., :m].type(torch.float32)
if length_dim == 1:
return mask.transpose(1, 2)
else:
return mask
class sequence_mask(nn.Module):
def __init__(self, max_seq_len=512, flip=True):
super().__init__()
def forward(self, lengths, max_seq_len=None, dtype=torch.float32, device=None):
if max_seq_len is None:
max_seq_len = lengths.max()
row_vector = torch.arange(0, max_seq_len, 1).to(lengths.device)
matrix = torch.unsqueeze(lengths, dim=-1)
mask = row_vector < matrix
return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
def normalize(input: torch.Tensor, p: float = 2.0, dim: int = 1, out: Optional[torch.Tensor] = None) -> torch.Tensor:
if out is None:
denom = input.norm(p, dim, keepdim=True).expand_as(input)
return input / denom
else:
denom = input.norm(p, dim, keepdim=True).expand_as(input)
return torch.div(input, denom, out=out)
def subsequent_mask(size: torch.Tensor):
return torch.ones(size, size).tril()
def MakePadMask_test():
feats_length = torch.tensor([10]).type(torch.long)
mask_fn = MakePadMask()
mask = mask_fn(feats_length)
print(mask)
if __name__ == '__main__':
MakePadMask_test()
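A quick illustrative check of MakePadMask (not part of this patch): with the default settings it returns a float mask with 1.0 on the padded positions, truncated to the longest length in the batch.

import torch

mask_fn = MakePadMask(max_seq_len=16)
print(mask_fn(torch.tensor([2, 4])))
# tensor([[0., 0., 1., 1.],
#         [0., 0., 0., 0.]])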

View File

@ -293,7 +293,7 @@ class SANMEncoder(AbsEncoder):
position embedded tensor and mask
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
xs_pad *= self.output_size()**0.5
xs_pad = xs_pad * self.output_size()**0.5
if self.embed is None:
xs_pad = xs_pad
elif (