Merge pull request #6 from alibaba-damo-academy/dev

update funasr 0.1.3
This commit is contained in:
zhifu gao 2022-12-03 16:55:32 +08:00 committed by GitHub
commit fd278298f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 2712 additions and 7 deletions

View File

@@ -65,6 +65,7 @@ for dset in ${test_sets}; do
${decode_cmd} --max-jobs-run "${inference_nj}" JOB=1:"${inference_nj}" "${_logdir}"/asr_inference.JOB.log \
python -m funasr.bin.modelscope_infer \
--model_name ${model_name} \
+ --model_revision ${model_revision} \
--wav_list ${_logdir}/keys.JOB.scp \
--output_file ${_logdir}/text.JOB \
--gpuid_list ${gpuid_list} \

View File

@@ -0,0 +1,687 @@
#!/usr/bin/env python3
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import argparse
import logging
import sys
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.modules.beam_search.batch_beam_search import BatchBeamSearch
from funasr.modules.beam_search.batch_beam_search_online_sim import BatchBeamSearchOnlineSim
from funasr.modules.beam_search.beam_search import BeamSearch
from funasr.modules.beam_search.beam_search import Hypothesis
from funasr.modules.scorers.ctc import CTCPrefixScorer
from funasr.modules.scorers.length_bonus import LengthBonus
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
from funasr.modules.subsampling import TooShortUttError
from funasr.tasks.asr import ASRTask
from funasr.tasks.lm import LMTask
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from modelscope.utils.logger import get_logger
logger = get_logger()
header_colors = '\033[95m'
end_colors = '\033[0m'
global_asr_language: str = 'zh-cn'
global_sample_rate: Union[int, Dict[Any, int]] = {
'audio_fs': 16000,
'model_fs': 16000
}
class Speech2Text:
"""Speech2Text class
Examples:
>>> import soundfile
>>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
>>> audio, rate = soundfile.read("speech.wav")
>>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...]
"""
def __init__(
self,
asr_train_config: Union[Path, str] = None,
asr_model_file: Union[Path, str] = None,
lm_train_config: Union[Path, str] = None,
lm_file: Union[Path, str] = None,
token_type: str = None,
bpemodel: str = None,
device: str = "cpu",
maxlenratio: float = 0.0,
minlenratio: float = 0.0,
batch_size: int = 1,
dtype: str = "float32",
beam_size: int = 20,
ctc_weight: float = 0.5,
lm_weight: float = 1.0,
ngram_weight: float = 0.9,
penalty: float = 0.0,
nbest: int = 1,
streaming: bool = False,
frontend_conf: dict = None,
**kwargs,
):
assert check_argument_types()
# 1. Build ASR model
scorers = {}
asr_model, asr_train_args = ASRTask.build_model_from_file(
asr_train_config, asr_model_file, device
)
if asr_model.frontend is None and frontend_conf is not None:
frontend = WavFrontend(**frontend_conf)
asr_model.frontend = frontend
asr_model.to(dtype=getattr(torch, dtype)).eval()
decoder = asr_model.decoder
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
token_list = asr_model.token_list
scorers.update(
decoder=decoder,
ctc=ctc,
length_bonus=LengthBonus(len(token_list)),
)
# 2. Build Language model
if lm_train_config is not None:
lm, lm_train_args = LMTask.build_model_from_file(
lm_train_config, lm_file, device
)
scorers["lm"] = lm.lm
# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
# 4. Build BeamSearch object
# transducer is not supported now
beam_search_transducer = None
weights = dict(
decoder=1.0 - ctc_weight,
ctc=ctc_weight,
lm=lm_weight,
ngram=ngram_weight,
length_bonus=penalty,
)
beam_search = BeamSearch(
beam_size=beam_size,
weights=weights,
scorers=scorers,
sos=asr_model.sos,
eos=asr_model.eos,
vocab_size=len(token_list),
token_list=token_list,
pre_beam_score_key=None if ctc_weight == 1.0 else "full",
)
# TODO(karita): make all scorers batchified
if batch_size == 1:
non_batch = [
k
for k, v in beam_search.full_scorers.items()
if not isinstance(v, BatchScorerInterface)
]
if len(non_batch) == 0:
if streaming:
beam_search.__class__ = BatchBeamSearchOnlineSim
beam_search.set_streaming_config(asr_train_config)
logging.info(
"BatchBeamSearchOnlineSim implementation is selected."
)
else:
beam_search.__class__ = BatchBeamSearch
logging.info("BatchBeamSearch implementation is selected.")
else:
logging.warning(
f"As non-batch scorers {non_batch} are found, "
f"fall back to non-batch implementation."
)
beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
for scorer in scorers.values():
if isinstance(scorer, torch.nn.Module):
scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
logging.info(f"Beam_search: {beam_search}")
logging.info(f"Decoding device={device}, dtype={dtype}")
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
if token_type is None:
token_type = asr_train_args.token_type
if bpemodel is None:
bpemodel = asr_train_args.bpemodel
if token_type is None:
tokenizer = None
elif token_type == "bpe":
if bpemodel is not None:
tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
else:
tokenizer = None
else:
tokenizer = build_tokenizer(token_type=token_type)
converter = TokenIDConverter(token_list=token_list)
logging.info(f"Text tokenizer: {tokenizer}")
self.asr_model = asr_model
self.asr_train_args = asr_train_args
self.converter = converter
self.tokenizer = tokenizer
self.beam_search = beam_search
self.beam_search_transducer = beam_search_transducer
self.maxlenratio = maxlenratio
self.minlenratio = minlenratio
self.device = device
self.dtype = dtype
self.nbest = nbest
@torch.no_grad()
def __call__(
self, speech: Union[torch.Tensor, np.ndarray]
) -> List[
Tuple[
Optional[str],
List[str],
List[int],
Union[Hypothesis],
]
]:
"""Inference
Args:
speech: Input speech data
Returns:
text, token, token_int, hyp
"""
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
# data: (Nsamples,) -> (1, Nsamples)
speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
lfr_factor = max(1, (speech.size()[-1] // 80) - 1)
# lengths: (1,)
lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
batch = {"speech": speech, "speech_lengths": lengths}
# a. To device
batch = to_device(batch, device=self.device)
# b. Forward Encoder
enc, _ = self.asr_model.encode(**batch)
if isinstance(enc, tuple):
enc = enc[0]
assert len(enc) == 1, len(enc)
# c. Pass the encoder result to the beam search
nbest_hyps = self.beam_search(
x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
)
nbest_hyps = nbest_hyps[: self.nbest]
results = []
for hyp in nbest_hyps:
assert isinstance(hyp, (Hypothesis)), type(hyp)
# remove sos/eos and get results
last_pos = -1
if isinstance(hyp.yseq, list):
token_int = hyp.yseq[1:last_pos]
else:
token_int = hyp.yseq[1:last_pos].tolist()
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x != 0, token_int))
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
if self.tokenizer is not None:
text = self.tokenizer.tokens2text(token)
else:
text = None
results.append((text, token, token_int, hyp))
assert check_return_type(results)
return results
def inference(
maxlenratio: float,
minlenratio: float,
batch_size: int,
dtype: str,
beam_size: int,
ngpu: int,
seed: int,
ctc_weight: float,
lm_weight: float,
ngram_weight: float,
penalty: float,
nbest: int,
num_workers: int,
log_level: Union[int, str],
data_path_and_name_and_type: list,
audio_lists: Union[List[Any], bytes],
key_file: Optional[str],
asr_train_config: Optional[str],
asr_model_file: Optional[str],
lm_train_config: Optional[str],
lm_file: Optional[str],
word_lm_train_config: Optional[str],
token_type: Optional[str],
bpemodel: Optional[str],
output_dir: Optional[str],
allow_variable_data_keys: bool,
streaming: bool,
frontend_conf: dict = None,
fs: Union[dict, int] = 16000,
**kwargs,
) -> List[Any]:
assert check_argument_types()
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
if word_lm_train_config is not None:
raise NotImplementedError("Word LM is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
if ngpu >= 1:
device = "cuda"
else:
device = "cpu"
features_type: str = data_path_and_name_and_type[1]
hop_length: int = 160
sr: int = 16000
if isinstance(fs, int):
sr = fs
else:
if 'model_fs' in fs and fs['model_fs'] is not None:
sr = fs['model_fs']
if features_type != 'sound':
frontend_conf = None
if frontend_conf is not None:
if 'hop_length' in frontend_conf:
hop_length = frontend_conf['hop_length']
finish_count = 0
file_count = 1
if isinstance(audio_lists, bytes):
file_count = 1
else:
file_count = len(audio_lists)
if len(data_path_and_name_and_type) >= 3 and frontend_conf is not None:
mvn_file = data_path_and_name_and_type[2]
mvn_data = wav_utils.extract_CMVN_featrures(mvn_file)
frontend_conf['mvn_data'] = mvn_data
# 1. Set random-seed
set_all_random_seed(seed)
# 2. Build speech2text
speech2text_kwargs = dict(
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
lm_train_config=lm_train_config,
lm_file=lm_file,
token_type=token_type,
bpemodel=bpemodel,
device=device,
maxlenratio=maxlenratio,
minlenratio=minlenratio,
dtype=dtype,
beam_size=beam_size,
ctc_weight=ctc_weight,
lm_weight=lm_weight,
ngram_weight=ngram_weight,
penalty=penalty,
nbest=nbest,
streaming=streaming,
frontend_conf=frontend_conf,
)
speech2text = Speech2Text(**speech2text_kwargs)
data_path_and_name_and_type_new = [
audio_lists, data_path_and_name_and_type[0], data_path_and_name_and_type[1]
]
# 3. Build data-iterator
loader = ASRTask.build_streaming_iterator_modelscope(
data_path_and_name_and_type_new,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
sample_rate=fs
)
# 7. Start for-loop
# FIXME(kamo): The output format should be discussed
asr_result_list = []
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
# N-best list of (text, token, token_int, hyp_object)
try:
results = speech2text(**batch)
except TooShortUttError as e:
logging.warning(f"Utterance {keys} {e}")
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
results = [[" ", ["<space>"], [2], hyp]] * nbest
# Only supporting batch_size==1
key = keys[0]
for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
if text is not None:
text_postprocessed = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
asr_result_list.append(item)
finish_count += 1
asr_utils.print_progress(finish_count / file_count)
return asr_result_list
def set_parameters(language: str = None,
sample_rate: Union[int, Dict[Any, int]] = None):
if language is not None:
global global_asr_language
global_asr_language = language
if sample_rate is not None:
global global_sample_rate
global_sample_rate = sample_rate
def asr_inference(maxlenratio: float,
minlenratio: float,
beam_size: int,
ngpu: int,
ctc_weight: float,
lm_weight: float,
penalty: float,
name_and_type: list,
audio_lists: Union[List[Any], bytes],
asr_train_config: Optional[str],
asr_model_file: Optional[str],
nbest: int = 1,
num_workers: int = 1,
log_level: Union[int, str] = 'INFO',
batch_size: int = 1,
dtype: str = 'float32',
seed: int = 0,
key_file: Optional[str] = None,
lm_train_config: Optional[str] = None,
lm_file: Optional[str] = None,
word_lm_train_config: Optional[str] = None,
word_lm_file: Optional[str] = None,
ngram_file: Optional[str] = None,
ngram_weight: float = 0.9,
model_tag: Optional[str] = None,
token_type: Optional[str] = None,
bpemodel: Optional[str] = None,
allow_variable_data_keys: bool = False,
transducer_conf: Optional[dict] = None,
streaming: bool = False,
frontend_conf: dict = None,
fs: Union[dict, int] = None,
lang: Optional[str] = None,
outputdir: Optional[str] = None):
if lang is not None:
global global_asr_language
global_asr_language = lang
if fs is not None:
global global_sample_rate
global_sample_rate = fs
# force use CPU if data type is bytes
if isinstance(audio_lists, bytes):
num_workers = 0
ngpu = 0
return inference(output_dir=outputdir,
maxlenratio=maxlenratio,
minlenratio=minlenratio,
batch_size=batch_size,
dtype=dtype,
beam_size=beam_size,
ngpu=ngpu,
seed=seed,
ctc_weight=ctc_weight,
lm_weight=lm_weight,
ngram_weight=ngram_weight,
penalty=penalty,
nbest=nbest,
num_workers=num_workers,
log_level=log_level,
data_path_and_name_and_type=name_and_type,
audio_lists=audio_lists,
key_file=key_file,
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
lm_train_config=lm_train_config,
lm_file=lm_file,
word_lm_train_config=word_lm_train_config,
word_lm_file=word_lm_file,
ngram_file=ngram_file,
model_tag=model_tag,
token_type=token_type,
bpemodel=bpemodel,
allow_variable_data_keys=allow_variable_data_keys,
transducer_conf=transducer_conf,
streaming=streaming,
frontend_conf=frontend_conf)
def get_parser():
parser = config_argparse.ArgumentParser(
description="ASR Decoding",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument(
"--gpuid_list",
type=str,
default="",
help="The visible gpus",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
required=True,
action="append",
)
group.add_argument("--audio_lists", type=list,
default=[{'key':'EdevDEWdIYQ_0021',
'file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--asr_train_config",
type=str,
help="ASR training configuration",
)
group.add_argument(
"--asr_model_file",
type=str,
help="ASR model parameter file",
)
group.add_argument(
"--lm_train_config",
type=str,
help="LM training configuration",
)
group.add_argument(
"--lm_file",
type=str,
help="LM parameter file",
)
group.add_argument(
"--word_lm_train_config",
type=str,
help="Word LM training configuration",
)
group.add_argument(
"--word_lm_file",
type=str,
help="Word LM parameter file",
)
group.add_argument(
"--ngram_file",
type=str,
help="N-gram parameter file",
)
group.add_argument(
"--model_tag",
type=str,
help="Pretrained model tag. If specify this option, *_train_config and "
"*_file will be overwritten",
)
group = parser.add_argument_group("Beam-search related")
group.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
group.add_argument("--beam_size", type=int, default=20, help="Beam size")
group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
group.add_argument(
"--maxlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain max output length. "
"If maxlenratio=0.0 (default), it uses a end-detect "
"function "
"to automatically find maximum hypothesis lengths."
"If maxlenratio<0.0, its absolute value is interpreted"
"as a constant max output length",
)
group.add_argument(
"--minlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain min output length",
)
group.add_argument(
"--ctc_weight",
type=float,
default=0.5,
help="CTC weight in joint decoding",
)
group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
group.add_argument("--streaming", type=str2bool, default=False)
group = parser.add_argument_group("Text converter related")
group.add_argument(
"--token_type",
type=str_or_none,
default=None,
choices=["char", "bpe", None],
help="The token type for ASR model. "
"If not given, refers from the training args",
)
group.add_argument(
"--bpemodel",
type=str_or_none,
default=None,
help="The model path of sentencepiece. "
"If not given, refers from the training args",
)
return parser
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
kwargs = vars(args)
kwargs.pop("config", None)
inference(**kwargs)
if __name__ == "__main__":
main()
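
For a sense of how this entry point is driven, a minimal sketch of calling asr_inference directly (the module path and file paths below are illustrative; in the recipe above, funasr.bin.modelscope_infer makes this call for you):

# a minimal sketch, assuming this file is importable as funasr.bin.asr_inference_modelscope
from funasr.bin.asr_inference_modelscope import asr_inference

results = asr_inference(
    maxlenratio=0.0, minlenratio=0.0, beam_size=20, ngpu=0,
    ctc_weight=0.5, lm_weight=1.0, penalty=0.0,
    name_and_type=['speech', 'sound'],                    # (data name, data type)
    audio_lists=[{'key': 'utt1', 'file': 'speech.wav'}],  # hypothetical utterance
    asr_train_config='exp/asr/config.yaml',               # hypothetical paths
    asr_model_file='exp/asr/model.pth')
print(results)  # [{'key': 'utt1', 'value': '...'}]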

View File

@@ -0,0 +1,686 @@
#!/usr/bin/env python3
import argparse
import logging
import sys
import time
from pathlib import Path
from typing import Any
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import List
from typing import Dict
import numpy as np
import torch
from typeguard import check_argument_types
from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
from funasr.modules.beam_search.beam_search import Hypothesis
from funasr.modules.scorers.ctc import CTCPrefixScorer
from funasr.modules.scorers.length_bonus import LengthBonus
from funasr.modules.subsampling import TooShortUttError
from funasr.tasks.asr import ASRTaskParaformer as ASRTask
from funasr.tasks.lm import LMTask
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from modelscope.utils.logger import get_logger
logger = get_logger()
header_colors = '\033[95m'
end_colors = '\033[0m'
global_asr_language: str = 'zh-cn'
global_sample_rate: Union[int, Dict[Any, int]] = {
'audio_fs': 16000,
'model_fs': 16000
}
class Speech2Text:
"""Speech2Text class
Examples:
>>> import soundfile
>>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
>>> audio, rate = soundfile.read("speech.wav")
>>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...]
"""
def __init__(
self,
asr_train_config: Union[Path, str] = None,
asr_model_file: Union[Path, str] = None,
lm_train_config: Union[Path, str] = None,
lm_file: Union[Path, str] = None,
token_type: str = None,
bpemodel: str = None,
device: str = "cpu",
maxlenratio: float = 0.0,
minlenratio: float = 0.0,
dtype: str = "float32",
beam_size: int = 20,
ctc_weight: float = 0.5,
lm_weight: float = 1.0,
ngram_weight: float = 0.9,
penalty: float = 0.0,
nbest: int = 1,
frontend_conf: dict = None,
**kwargs,
):
assert check_argument_types()
# 1. Build ASR model
scorers = {}
asr_model, asr_train_args = ASRTask.build_model_from_file(
asr_train_config, asr_model_file, device
)
if asr_model.frontend is None and frontend_conf is not None:
frontend = WavFrontend(**frontend_conf)
asr_model.frontend = frontend
asr_model.to(dtype=getattr(torch, dtype)).eval()
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
token_list = asr_model.token_list
scorers.update(
ctc=ctc,
length_bonus=LengthBonus(len(token_list)),
)
# 2. Build Language model
if lm_train_config is not None:
lm, lm_train_args = LMTask.build_model_from_file(
lm_train_config, lm_file, device
)
scorers["lm"] = lm.lm
# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
# 4. Build BeamSearch object
# transducer is not supported now
beam_search_transducer = None
weights = dict(
decoder=1.0 - ctc_weight,
ctc=ctc_weight,
lm=lm_weight,
ngram=ngram_weight,
length_bonus=penalty,
)
beam_search = BeamSearch(
beam_size=beam_size,
weights=weights,
scorers=scorers,
sos=asr_model.sos,
eos=asr_model.eos,
vocab_size=len(token_list),
token_list=token_list,
pre_beam_score_key=None if ctc_weight == 1.0 else "full",
)
beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
for scorer in scorers.values():
if isinstance(scorer, torch.nn.Module):
scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
logging.info(f"Beam_search: {beam_search}")
logging.info(f"Decoding device={device}, dtype={dtype}")
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
if token_type is None:
token_type = asr_train_args.token_type
if bpemodel is None:
bpemodel = asr_train_args.bpemodel
if token_type is None:
tokenizer = None
elif token_type == "bpe":
if bpemodel is not None:
tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
else:
tokenizer = None
else:
tokenizer = build_tokenizer(token_type=token_type)
converter = TokenIDConverter(token_list=token_list)
logging.info(f"Text tokenizer: {tokenizer}")
self.asr_model = asr_model
self.asr_train_args = asr_train_args
self.converter = converter
self.tokenizer = tokenizer
self.beam_search = beam_search
self.beam_search_transducer = beam_search_transducer
self.maxlenratio = maxlenratio
self.minlenratio = minlenratio
self.device = device
self.dtype = dtype
self.nbest = nbest
@torch.no_grad()
def __call__(
self, speech: Union[torch.Tensor, np.ndarray]
):
"""Inference
Args:
speech: Input speech data
Returns:
text, token, token_int, hyp
"""
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
# data: (Nsamples,) -> (1, Nsamples)
speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
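# heuristic LFR stacking factor from the last input dimension, assuming
# 80-dim fbank per stacked frame (e.g. 560-dim LFR features -> 6);
# consumed below for RTF accounting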
lfr_factor = max(1, (speech.size()[-1] // 80) - 1)
# lengths: (1,)
lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
batch = {"speech": speech, "speech_lengths": lengths}
# a. To device
batch = to_device(batch, device=self.device)
# b. Forward Encoder
enc, enc_len = self.asr_model.encode(**batch)
if isinstance(enc, tuple):
enc = enc[0]
assert len(enc) == 1, len(enc)
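# Paraformer decoding: the CIF-style predictor estimates one acoustic
# embedding per output token, and the decoder scores all tokens in a
# single non-autoregressive pass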
predictor_outs = self.asr_model.calc_predictor(enc, enc_len)
pre_acoustic_embeds, pre_token_length = predictor_outs[0], predictor_outs[1]
pre_token_length = torch.tensor([pre_acoustic_embeds.size(1)], device=pre_acoustic_embeds.device)
decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
nbest_hyps = self.beam_search(
x=enc[0], am_scores=decoder_out[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
)
nbest_hyps = nbest_hyps[: self.nbest]
results = []
for hyp in nbest_hyps:
assert isinstance(hyp, (Hypothesis)), type(hyp)
# remove sos/eos and get results
last_pos = -1
if isinstance(hyp.yseq, list):
token_int = hyp.yseq[1:last_pos]
else:
token_int = hyp.yseq[1:last_pos].tolist()
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x != 0, token_int))
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
if self.tokenizer is not None:
text = self.tokenizer.tokens2text(token)
else:
text = None
results.append((text, token, token_int, hyp, speech.size(1), lfr_factor))
# assert check_return_type(results)
return results
def inference(
maxlenratio: float,
minlenratio: float,
batch_size: int,
dtype: str,
beam_size: int,
ngpu: int,
seed: int,
ctc_weight: float,
lm_weight: float,
ngram_weight: float,
penalty: float,
nbest: int,
num_workers: int,
log_level: Union[int, str],
data_path_and_name_and_type: list,
audio_lists: Union[List[Any], bytes],
key_file: Optional[str],
asr_train_config: Optional[str],
asr_model_file: Optional[str],
lm_train_config: Optional[str],
lm_file: Optional[str],
word_lm_train_config: Optional[str],
model_tag: Optional[str],
token_type: Optional[str],
bpemodel: Optional[str],
output_dir: Optional[str],
allow_variable_data_keys: bool,
frontend_conf: dict = None,
fs: Union[dict, int] = 16000,
**kwargs,
) -> List[Any]:
assert check_argument_types()
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
if word_lm_train_config is not None:
raise NotImplementedError("Word LM is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
if ngpu >= 1:
device = "cuda"
else:
device = "cpu"
# data_path_and_name_and_type = data_path_and_name_and_type[0]
features_type: str = data_path_and_name_and_type[1]
hop_length: int = 160
sr: int = 16000
if isinstance(fs, int):
sr = fs
else:
if 'model_fs' in fs and fs['model_fs'] is not None:
sr = fs['model_fs']
if features_type != 'sound':
frontend_conf = None
if frontend_conf is not None:
if 'hop_length' in frontend_conf:
hop_length = frontend_conf['hop_length']
finish_count = 0
file_count = 1
if isinstance(audio_lists, bytes):
file_count = 1
else:
file_count = len(audio_lists)
if len(data_path_and_name_and_type) >= 3 and frontend_conf is not None:
mvn_file = data_path_and_name_and_type[2]
mvn_data = wav_utils.extract_CMVN_featrures(mvn_file)
frontend_conf['mvn_data'] = mvn_data
# 1. Set random-seed
set_all_random_seed(seed)
# 2. Build speech2text
speech2text_kwargs = dict(
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
lm_train_config=lm_train_config,
lm_file=lm_file,
token_type=token_type,
bpemodel=bpemodel,
device=device,
maxlenratio=maxlenratio,
minlenratio=minlenratio,
dtype=dtype,
beam_size=beam_size,
ctc_weight=ctc_weight,
lm_weight=lm_weight,
ngram_weight=ngram_weight,
penalty=penalty,
nbest=nbest,
frontend_conf=frontend_conf,
)
speech2text = Speech2Text(**speech2text_kwargs)
data_path_and_name_and_type_new = [
audio_lists, data_path_and_name_and_type[0], data_path_and_name_and_type[1]
]
# 3. Build data-iterator
loader = ASRTask.build_streaming_iterator_modelscope(
data_path_and_name_and_type_new,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
sample_rate=fs
)
forward_time_total = 0.0
length_total = 0.0
asr_result_list = []
# 7. Start for-loop
# FIXME(kamo): The output format should be discussed
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
logging.info("decoding, utt_id: {}".format(keys))
# N-best list of (text, token, token_int, hyp_object)
try:
time_beg = time.time()
results = speech2text(**batch)
time_end = time.time()
forward_time = time_end - time_beg
lfr_factor = results[0][-1]
length = results[0][-2]
results = [results[0][:-2]]
forward_time_total += forward_time
length_total += length
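# rtf estimate: length * lfr_factor approximates the number of 10 ms
# frames, so the factor of 100 converts frames to seconds (this assumes
# kaldi_ark feature input; for raw wav the summary below divides by sr)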
logging.info(
"decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".
format(length, forward_time, 100 * forward_time / (length * lfr_factor)))
except TooShortUttError as e:
logging.warning(f"Utterance {keys} {e}")
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
results = [[" ", ["<space>"], [2], hyp]] * nbest
# Only supporting batch_size==1
key = keys[0]
for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
if text is not None:
text_postprocessed = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
asr_result_list.append(item)
logging.info("decoding, predictions: {}".format(text))
finish_count += 1
asr_utils.print_progress(finish_count / file_count)
logging.info("decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".
format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor)))
if features_type == 'sound':
# data format is wav
length_total_seconds = length_total / sr
length_total_bytes = length_total * 2
else:
# data format is kaldi_ark
length_total_seconds = length_total * hop_length / sr
length_total_bytes = length_total * hop_length * 2
logger.info(
header_colors + # noqa: *
'decoding, feature length total: {} bytes, forward_time total: {:.4f}s, rtf avg: {:.4f}'
.format(length_total_bytes, forward_time_total, forward_time_total /
length_total_seconds) + end_colors)
return asr_result_list
def set_parameters(language: str = None,
sample_rate: Union[int, Dict[Any, int]] = None):
if language is not None:
global global_asr_language
global_asr_language = language
if sample_rate is not None:
global global_sample_rate
global_sample_rate = sample_rate
def asr_inference(maxlenratio: float,
minlenratio: float,
beam_size: int,
ngpu: int,
ctc_weight: float,
lm_weight: float,
penalty: float,
name_and_type: list,
audio_lists: Union[List[Any], bytes],
asr_train_config: Optional[str],
asr_model_file: Optional[str],
nbest: int = 1,
num_workers: int = 1,
log_level: Union[int, str] = 'INFO',
batch_size: int = 1,
dtype: str = 'float32',
seed: int = 0,
key_file: Optional[str] = None,
lm_train_config: Optional[str] = None,
lm_file: Optional[str] = None,
word_lm_train_config: Optional[str] = None,
word_lm_file: Optional[str] = None,
ngram_file: Optional[str] = None,
ngram_weight: float = 0.9,
model_tag: Optional[str] = None,
token_type: Optional[str] = None,
bpemodel: Optional[str] = None,
allow_variable_data_keys: bool = False,
transducer_conf: Optional[dict] = None,
streaming: bool = False,
frontend_conf: dict = None,
fs: Union[dict, int] = None,
lang: Optional[str] = None,
outputdir: Optional[str] = None):
if lang is not None:
global global_asr_language
global_asr_language = lang
if fs is not None:
global global_sample_rate
global_sample_rate = fs
# force use CPU if data type is bytes
if isinstance(audio_lists, bytes):
num_workers = 0
ngpu = 0
return inference(output_dir=outputdir,
maxlenratio=maxlenratio,
minlenratio=minlenratio,
batch_size=batch_size,
dtype=dtype,
beam_size=beam_size,
ngpu=ngpu,
seed=seed,
ctc_weight=ctc_weight,
lm_weight=lm_weight,
ngram_weight=ngram_weight,
penalty=penalty,
nbest=nbest,
num_workers=num_workers,
log_level=log_level,
data_path_and_name_and_type=name_and_type,
audio_lists=audio_lists,
key_file=key_file,
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
lm_train_config=lm_train_config,
lm_file=lm_file,
word_lm_train_config=word_lm_train_config,
word_lm_file=word_lm_file,
ngram_file=ngram_file,
model_tag=model_tag,
token_type=token_type,
bpemodel=bpemodel,
allow_variable_data_keys=allow_variable_data_keys,
transducer_conf=transducer_conf,
streaming=streaming,
frontend_conf=frontend_conf)
def get_parser():
parser = config_argparse.ArgumentParser(
description="ASR Decoding",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
required=True,
action="append",
)
group.add_argument("--audio_lists", type=list, default=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--asr_train_config",
type=str,
help="ASR training configuration",
)
group.add_argument(
"--asr_model_file",
type=str,
help="ASR model parameter file",
)
group.add_argument(
"--lm_train_config",
type=str,
help="LM training configuration",
)
group.add_argument(
"--lm_file",
type=str,
help="LM parameter file",
)
group.add_argument(
"--word_lm_train_config",
type=str,
help="Word LM training configuration",
)
group.add_argument(
"--word_lm_file",
type=str,
help="Word LM parameter file",
)
group.add_argument(
"--ngram_file",
type=str,
help="N-gram parameter file",
)
group.add_argument(
"--model_tag",
type=str,
help="Pretrained model tag. If specify this option, *_train_config and "
"*_file will be overwritten",
)
group = parser.add_argument_group("Beam-search related")
group.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
group.add_argument("--beam_size", type=int, default=20, help="Beam size")
group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
group.add_argument(
"--maxlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain max output length. "
"If maxlenratio=0.0 (default), it uses a end-detect "
"function "
"to automatically find maximum hypothesis lengths."
"If maxlenratio<0.0, its absolute value is interpreted"
"as a constant max output length",
)
group.add_argument(
"--minlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain min output length",
)
group.add_argument(
"--ctc_weight",
type=float,
default=0.5,
help="CTC weight in joint decoding",
)
group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
group.add_argument("--streaming", type=str2bool, default=False)
group.add_argument(
"--asr_model_config",
default=None,
help="",
)
group = parser.add_argument_group("Text converter related")
group.add_argument(
"--token_type",
type=str_or_none,
default=None,
choices=["char", "bpe", None],
help="The token type for ASR model. "
"If not given, refers from the training args",
)
group.add_argument(
"--bpemodel",
type=str_or_none,
default=None,
help="The model path of sentencepiece. "
"If not given, refers from the training args",
)
return parser
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
kwargs = vars(args)
kwargs.pop("config", None)
inference(**kwargs)
if __name__ == "__main__":
main()

View File

@@ -15,6 +15,10 @@ if __name__ == '__main__':
type=str,
default="speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
help="model name in modelscope")
+ parser.add_argument("--model_revision",
+ type=str,
+ default="v1.0.3",
+ help="model revision in modelscope")
parser.add_argument("--local_model_path",
type=str,
default=None,
@@ -62,7 +66,8 @@ if __name__ == '__main__':
if args.local_model_path is None:
inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
- model="damo/{}".format(args.model_name))
+ model="damo/{}".format(args.model_name),
+ model_revision=args.model_revision)
else:
inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
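
With the revision pinned, the resulting ModelScope call looks like this (a minimal sketch; the input wav path is a placeholder):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    model_revision="v1.0.3")
rec_result = inference_pipeline(audio_in="asr_example.wav")  # placeholder wav
print(rec_result)  # {'text': '...'}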

View File

@@ -0,0 +1,349 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
"""Iterable dataset module."""
import copy
from io import StringIO
from pathlib import Path
from typing import Callable, Collection, Dict, Iterator, Tuple, Union
import kaldiio
import numpy as np
import soundfile
import torch
from funasr.datasets.dataset import ESPnetDataset
from torch.utils.data.dataset import IterableDataset
from typeguard import check_argument_types
from funasr.utils import wav_utils
def load_kaldi(input):
retval = kaldiio.load_mat(input)
if isinstance(retval, tuple):
assert len(retval) == 2, len(retval)
if isinstance(retval[0], int) and isinstance(retval[1], np.ndarray):
# sound scp case
rate, array = retval
elif isinstance(retval[1], int) and isinstance(retval[0], np.ndarray):
# Extended ark format case
array, rate = retval
else:
raise RuntimeError(
f'Unexpected type: {type(retval[0])}, {type(retval[1])}')
# Multichannel wave file
# array: (NSample, Channel) or (NSample,)
else:
# Normal ark case
assert isinstance(retval, np.ndarray), type(retval)
array = retval
return array
DATA_TYPES = {
'sound':
lambda x: soundfile.read(x)[0],
'kaldi_ark':
load_kaldi,
'npy':
np.load,
'text_int':
lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.int64, delimiter=' '),
'csv_int':
lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.int64, delimiter=','),
'text_float':
lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.float32, delimiter=' '
),
'csv_float':
lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.float32, delimiter=','
),
'text':
lambda x: x,
}
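# e.g. DATA_TYPES['text_int']('1 2 3') -> array([1, 2, 3]), and
# DATA_TYPES['sound']('a.wav') -> the sample array from soundfile.read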
class IterableESPnetDatasetModelScope(IterableDataset):
"""Pytorch Dataset class for ESPNet.
Examples:
>>> dataset = IterableESPnetDataset([('wav.scp', 'input', 'sound'),
... ('token_int', 'output', 'text_int')],
... )
>>> for uid, data in dataset:
... data
{'input': per_utt_array, 'output': per_utt_array}
"""
def __init__(self,
path_name_type_list: Collection[Tuple[any, str, str]],
preprocess: Callable[[str, Dict[str, np.ndarray]],
Dict[str, np.ndarray]] = None,
float_dtype: str = 'float32',
int_dtype: str = 'long',
key_file: str = None,
sample_rate: Union[dict, int] = 16000):
assert check_argument_types()
if len(path_name_type_list) == 0:
raise ValueError(
'1 or more elements are required for "path_name_type_list"')
self.preprocess = preprocess
self.float_dtype = float_dtype
self.int_dtype = int_dtype
self.key_file = key_file
self.sample_rate = sample_rate
self.debug_info = {}
non_iterable_list = []
self.path_name_type_list = []
path_list = path_name_type_list[0]
name = path_name_type_list[1]
_type = path_name_type_list[2]
if name in self.debug_info:
raise RuntimeError(f'"{name}" is duplicated for data-key')
self.debug_info[name] = path_list, _type
# for path, name, _type in path_name_type_list:
for path in path_list:
self.path_name_type_list.append((path, name, _type))
if len(non_iterable_list) != 0:
# Some types don't support iterable mode
self.non_iterable_dataset = ESPnetDataset(
path_name_type_list=non_iterable_list,
preprocess=preprocess,
float_dtype=float_dtype,
int_dtype=int_dtype,
)
else:
self.non_iterable_dataset = None
self.apply_utt2category = False
def has_name(self, name) -> bool:
return name in self.debug_info
def names(self) -> Tuple[str, ...]:
return tuple(self.debug_info)
def __repr__(self):
_mes = self.__class__.__name__
_mes += '('
for name, (path, _type) in self.debug_info.items():
_mes += f'\n {name}: {{"path": "{path}", "type": "{_type}"}}'
_mes += f'\n preprocess: {self.preprocess})'
return _mes
def __iter__(
self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
torch.set_printoptions(profile='default')
count = len(self.path_name_type_list)
for idx in range(count):
# 2. Load the entry from each line and create a dict
data = {}
# 2.a. Load data in a streaming fashion
# value: /home/fsc/code/MaaS/MaaS-lib-nls-asr/data/test/audios/asr_example.wav
value = self.path_name_type_list[idx][0]['file']
uid = self.path_name_type_list[idx][0]['key']
# name: speech
name = self.path_name_type_list[idx][1]
_type = self.path_name_type_list[idx][2]
func = DATA_TYPES[_type]
array = func(value)
# 2.b. audio resample
if _type == 'sound':
audio_sr: int = 16000
model_sr: int = 16000
if isinstance(self.sample_rate, int):
model_sr = self.sample_rate
else:
if 'audio_sr' in self.sample_rate:
audio_sr = self.sample_rate['audio_sr']
if 'model_sr' in self.sample_rate:
model_sr = self.sample_rate['model_sr']
array = wav_utils.torch_resample(array, audio_sr, model_sr)
# array: [ 1.25122070e-03 ... ]
data[name] = array
# 3. [Option] Apply preprocessing
# e.g. espnet2.train.preprocessor:CommonPreprocessor
if self.preprocess is not None:
data = self.preprocess(uid, data)
# data: {'speech': array([ 1.25122070e-03 ... 6.10351562e-03])}
# 4. Force data-precision
for name in data:
# value is np.ndarray data
value = data[name]
if not isinstance(value, np.ndarray):
raise RuntimeError(
f'All values must be converted to np.ndarray object '
f'by preprocessing, but "{name}" is still {type(value)}.'
)
# Cast to desired type
if value.dtype.kind == 'f':
value = value.astype(self.float_dtype)
elif value.dtype.kind == 'i':
value = value.astype(self.int_dtype)
else:
raise NotImplementedError(
f'Not supported dtype: {value.dtype}')
data[name] = value
yield uid, data
if count == 0:
raise RuntimeError('No iteration')
class IterableESPnetBytesModelScope(IterableDataset):
"""Pytorch audio bytes class for ESPNet.
Examples:
>>> dataset = IterableESPnetBytes([('audio bytes', 'input', 'sound'),
... ('token_int', 'output', 'text_int')],
... )
>>> for uid, data in dataset:
... data
{'input': per_utt_array, 'output': per_utt_array}
"""
def __init__(self,
path_name_type_list: Collection[Tuple[any, str, str]],
preprocess: Callable[[str, Dict[str, np.ndarray]],
Dict[str, np.ndarray]] = None,
float_dtype: str = 'float32',
int_dtype: str = 'long',
key_file: str = None,
sample_rate: Union[dict, int] = 16000):
assert check_argument_types()
if len(path_name_type_list) == 0:
raise ValueError(
'1 or more elements are required for "path_name_type_list"')
self.preprocess = preprocess
self.float_dtype = float_dtype
self.int_dtype = int_dtype
self.key_file = key_file
self.sample_rate = sample_rate
self.debug_info = {}
non_iterable_list = []
self.path_name_type_list = []
audio_data = path_name_type_list[0]
name = path_name_type_list[1]
_type = path_name_type_list[2]
if name in self.debug_info:
raise RuntimeError(f'"{name}" is duplicated for data-key')
self.debug_info[name] = audio_data, _type
self.path_name_type_list.append((audio_data, name, _type))
if len(non_iterable_list) != 0:
# Some types don't support iterable mode
self.non_iterable_dataset = ESPnetDataset(
path_name_type_list=non_iterable_list,
preprocess=preprocess,
float_dtype=float_dtype,
int_dtype=int_dtype,
)
else:
self.non_iterable_dataset = None
self.apply_utt2category = False
if float_dtype == 'float32':
self.np_dtype = np.float32
def has_name(self, name) -> bool:
return name in self.debug_info
def names(self) -> Tuple[str, ...]:
return tuple(self.debug_info)
def __repr__(self):
_mes = self.__class__.__name__
_mes += '('
for name, (path, _type) in self.debug_info.items():
_mes += f'\n {name}: {{"path": "{path}", "type": "{_type}"}}'
_mes += f'\n preprocess: {self.preprocess})'
return _mes
def __iter__(
self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
torch.set_printoptions(profile='default')
# 2. Load the entry from each line and create a dict
data = {}
# 2.a. Load data in a streaming fashion
value = self.path_name_type_list[0][0]
uid = 'pcm_data'
# name: speech
name = self.path_name_type_list[0][1]
_type = self.path_name_type_list[0][2]
func = DATA_TYPES[_type]
# array: [ 1.25122070e-03 ... ]
# data[name] = np.frombuffer(value, dtype=self.np_dtype)
# 2.b. byte(PCM16) to float32
middle_data = np.frombuffer(value, dtype=np.int16)
middle_data = np.asarray(middle_data)
if middle_data.dtype.kind not in 'iu':
raise TypeError("'middle_data' must be an array of integers")
dtype = np.dtype('float32')
if dtype.kind != 'f':
raise TypeError("'dtype' must be a floating point type")
i = np.iinfo(middle_data.dtype)
abs_max = 2**(i.bits - 1)
offset = i.min + abs_max
array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max,
dtype=self.np_dtype)
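# e.g. int16 PCM 32767 maps to ~0.99997 and -32768 to -1.0 after the
# offset/abs_max normalization above (offset is 0 for int16)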
# 2.c. audio resample
if _type == 'sound':
audio_sr: int = 16000
model_sr: int = 16000
if isinstance(self.sample_rate, int):
model_sr = self.sample_rate
else:
if 'audio_sr' in self.sample_rate:
audio_sr = self.sample_rate['audio_sr']
if 'model_sr' in self.sample_rate:
model_sr = self.sample_rate['model_sr']
array = wav_utils.torch_resample(array, audio_sr, model_sr)
data[name] = array
# 3. [Option] Apply preprocessing
# e.g. espnet2.train.preprocessor:CommonPreprocessor
if self.preprocess is not None:
data = self.preprocess(uid, data)
# data: {'speech': array([ 1.25122070e-03 ... 6.10351562e-03])}
# 4. Force data-precision
for name in data:
# value is np.ndarray data
value = data[name]
if not isinstance(value, np.ndarray):
raise RuntimeError(
f'All values must be converted to np.ndarray object '
f'by preprocessing, but "{name}" is still {type(value)}.')
# Cast to desired type
if value.dtype.kind == 'f':
value = value.astype(self.float_dtype)
elif value.dtype.kind == 'i':
value = value.astype(self.int_dtype)
else:
raise NotImplementedError(
f'Not supported dtype: {value.dtype}')
data[name] = value
yield uid, data

View File

@@ -330,9 +330,10 @@ class Paraformer(AbsESPnetModel):
def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
- decoder_out, _ = self.decoder(
+ decoder_outs = self.decoder(
encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
)
+ decoder_out = decoder_outs[0]
decoder_out = torch.log_softmax(decoder_out, dim=-1)
return decoder_out, ys_pad_lens
@@ -553,7 +554,6 @@ class ParaformerBert(Paraformer):
postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
- joint_network: Optional[torch.nn.Module],
ctc_weight: float = 0.5,
interctc_weight: float = 0.0,
ignore_id: int = -1,
@@ -590,7 +590,6 @@ class ParaformerBert(Paraformer):
postencoder=postencoder,
decoder=decoder,
ctc=ctc,
- joint_network=joint_network,
ctc_weight=ctc_weight,
interctc_weight=interctc_weight,
ignore_id=ignore_id,

View File

@@ -0,0 +1,155 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
import copy
from typing import Optional, Tuple, Union
import humanfriendly
import numpy as np
import torch
import torchaudio.compliance.kaldi as kaldi
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.layers.log_mel import LogMel
from funasr.layers.stft import Stft
from funasr.utils.get_default_kwargs import get_default_kwargs
from funasr.modules.frontends.frontend import Frontend
from typeguard import check_argument_types
def apply_cmvn(inputs, mvn): # noqa
"""
Apply CMVN with mvn data
"""
device = inputs.device
dtype = inputs.dtype
frame, dim = inputs.shape
means = np.tile(mvn[0:1, :dim], (frame, 1))
vars = np.tile(mvn[1:2, :dim], (frame, 1))
inputs += torch.from_numpy(means).type(dtype).to(device)
inputs *= torch.from_numpy(vars).type(dtype).to(device)
return inputs.type(torch.float32)
def apply_lfr(inputs, lfr_m, lfr_n):
LFR_inputs = []
T = inputs.shape[0]
T_lfr = int(np.ceil(T / lfr_n))
left_padding = inputs[0].repeat((lfr_m - 1) // 2, 1)
inputs = torch.vstack((left_padding, inputs))
T = T + (lfr_m - 1) // 2
for i in range(T_lfr):
if lfr_m <= T - i * lfr_n:
LFR_inputs.append((inputs[i * lfr_n:i * lfr_n + lfr_m]).view(1, -1))
else: # process last LFR frame
num_padding = lfr_m - (T - i * lfr_n)
frame = (inputs[i * lfr_n:]).view(-1)
for _ in range(num_padding):
frame = torch.hstack((frame, inputs[-1]))
LFR_inputs.append(frame)
LFR_outputs = torch.vstack(LFR_inputs)
return LFR_outputs.type(torch.float32)
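# For intuition, a small doctest-style sketch of the stacking above (lfr_m=7,
# lfr_n=6 are common Paraformer settings; any values work):
# >>> feats = torch.randn(10, 80)
# >>> apply_lfr(feats, lfr_m=7, lfr_n=6).shape
# torch.Size([2, 560])  # ceil(10 / 6) output frames, each 7 * 80 dims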
class WavFrontend(AbsFrontend):
"""Conventional frontend structure for ASR.
"""
def __init__(
self,
fs: Union[int, str] = 16000,
n_fft: int = 512,
win_length: int = 400,
hop_length: int = 160,
window: Optional[str] = 'hamming',
center: bool = True,
normalized: bool = False,
onesided: bool = True,
n_mels: int = 80,
fmin: int = None,
fmax: int = None,
lfr_m: int = 1,
lfr_n: int = 1,
htk: bool = False,
mvn_data=None,
frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
apply_stft: bool = True,
):
assert check_argument_types()
super().__init__()
if isinstance(fs, str):
fs = humanfriendly.parse_size(fs)
# Deepcopy (In general, dict shouldn't be used as default arg)
frontend_conf = copy.deepcopy(frontend_conf)
self.hop_length = hop_length
self.win_length = win_length
self.window = window
self.fs = fs
self.mvn_data = mvn_data
self.lfr_m = lfr_m
self.lfr_n = lfr_n
if apply_stft:
self.stft = Stft(
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
center=center,
window=window,
normalized=normalized,
onesided=onesided,
)
else:
self.stft = None
self.apply_stft = apply_stft
if frontend_conf is not None:
self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
else:
self.frontend = None
self.logmel = LogMel(
fs=fs,
n_fft=n_fft,
n_mels=n_mels,
fmin=fmin,
fmax=fmax,
htk=htk,
)
self.n_mels = n_mels
self.frontend_type = 'default'
def output_size(self) -> int:
return self.n_mels
def forward(
self, input: torch.Tensor,
input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
sample_frequency = self.fs
num_mel_bins = self.n_mels
frame_length = self.win_length * 1000 / sample_frequency
frame_shift = self.hop_length * 1000 / sample_frequency
waveform = input * (1 << 15)
mat = kaldi.fbank(waveform,
num_mel_bins=num_mel_bins,
frame_length=frame_length,
frame_shift=frame_shift,
dither=1.0,
energy_floor=0.0,
window_type=self.window,
sample_frequency=sample_frequency)
if self.lfr_m != 1 or self.lfr_n != 1:
mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
if self.mvn_data is not None:
mat = apply_cmvn(mat, self.mvn_data)
input_feats = mat[None, :]
feats_lens = torch.randn(1)
feats_lens.fill_(input_feats.shape[1])
return input_feats, feats_lens
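
A minimal usage sketch of the frontend in isolation (shapes assume the common Paraformer setting lfr_m=7, lfr_n=6; frontend_conf=None skips the optional enhancement frontend):

import torch

frontend = WavFrontend(fs=16000, n_mels=80, lfr_m=7, lfr_n=6, frontend_conf=None)
wav = torch.randn(1, 16000)  # one second of 16 kHz audio
feats, feats_lens = frontend(wav, torch.tensor([16000]))
print(feats.shape)  # (1, 17, 560): ~98 fbank frames, stacked 7x and subsampled 6x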

View File

@@ -4,7 +4,7 @@ from torch import nn
from funasr.modules.nets_utils import make_pad_mask
class CifPredictor(nn.Module):
- def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0):
+ def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, tail_threshold=0.45):
super(CifPredictor, self).__init__()
self.pad = nn.ConstantPad1d((l_order, r_order), 0)

View File

@@ -38,6 +38,7 @@ from funasr.datasets.dataset import AbsDataset
from funasr.datasets.dataset import DATA_TYPES
from funasr.datasets.dataset import ESPnetDataset
from funasr.datasets.iterable_dataset import IterableESPnetDataset
+ from funasr.datasets.iterable_dataset_modelscope import IterableESPnetDatasetModelScope, IterableESPnetBytesModelScope
from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.iterators.chunk_iter_factory import ChunkIterFactory
from funasr.iterators.multiple_iter_factory import MultipleIterFactory
@@ -1026,7 +1027,7 @@ class AbsTask(ABC):
@classmethod
def check_task_requirements(
cls,
- dataset: Union[AbsDataset, IterableESPnetDataset],
+ dataset: Union[AbsDataset, IterableESPnetDataset, IterableESPnetDatasetModelScope, IterableESPnetBytesModelScope],
allow_variable_data_keys: bool,
train: bool,
inference: bool = False,
@@ -1748,6 +1749,64 @@
**kwargs,
)
@classmethod
def build_streaming_iterator_modelscope(
cls,
data_path_and_name_and_type,
preprocess_fn,
collate_fn,
key_file: str = None,
batch_size: int = 1,
dtype: str = np.float32,
num_workers: int = 1,
allow_variable_data_keys: bool = False,
ngpu: int = 0,
inference: bool = False,
sample_rate: Union[dict, int] = 16000
) -> DataLoader:
"""Build DataLoader using iterable dataset"""
assert check_argument_types()
# For backward compatibility for pytorch DataLoader
if collate_fn is not None:
kwargs = dict(collate_fn=collate_fn)
else:
kwargs = {}
audio_data = data_path_and_name_and_type[0]
if isinstance(audio_data, bytes):
dataset = IterableESPnetBytesModelScope(
data_path_and_name_and_type,
float_dtype=dtype,
preprocess=preprocess_fn,
key_file=key_file,
sample_rate=sample_rate
)
else:
dataset = IterableESPnetDatasetModelScope(
data_path_and_name_and_type,
float_dtype=dtype,
preprocess=preprocess_fn,
key_file=key_file,
sample_rate=sample_rate
)
if dataset.apply_utt2category:
kwargs.update(batch_size=1)
else:
kwargs.update(batch_size=batch_size)
cls.check_task_requirements(dataset,
allow_variable_data_keys,
train=False,
inference=inference)
return DataLoader(
dataset=dataset,
pin_memory=ngpu > 0,
num_workers=num_workers,
**kwargs,
)
# ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
@classmethod
def build_model_from_file(

View File

@@ -0,0 +1,85 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import ssl
import nltk
# mkdir nltk_data dir if not exist
try:
nltk.data.find('.')
except LookupError:
dir_list = nltk.data.path
for dir_item in dir_list:
if not os.path.exists(dir_item):
os.mkdir(dir_item)
if os.path.exists(dir_item):
break
# download one package if nltk_data not exist
try:
nltk.data.find('.')
except: # noqa: *
try:
_create_unverified_https_context = ssl._create_unverified_context
except AttributeError:
pass
else:
ssl._create_default_https_context = _create_unverified_https_context
nltk.download('cmudict', halt_on_error=False, raise_on_error=True)
# deploy taggers/averaged_perceptron_tagger
try:
nltk.data.find('taggers/averaged_perceptron_tagger')
except: # noqa: *
data_dir = nltk.data.find('.')
target_dir = os.path.join(data_dir, 'taggers')
if not os.path.exists(target_dir):
os.mkdir(target_dir)
src_file = os.path.join(os.path.dirname(__file__), '..', 'nltk_packages',
'averaged_perceptron_tagger.zip')
shutil.copyfile(src_file,
os.path.join(target_dir, 'averaged_perceptron_tagger.zip'))
shutil._unpack_zipfile(
os.path.join(target_dir, 'averaged_perceptron_tagger.zip'), target_dir)
# deploy corpora/cmudict
try:
nltk.data.find('corpora/cmudict')
except: # noqa: *
data_dir = nltk.data.find('.')
target_dir = os.path.join(data_dir, 'corpora')
if not os.path.exists(target_dir):
os.mkdir(target_dir)
src_file = os.path.join(os.path.dirname(__file__), '..', 'nltk_packages',
'cmudict.zip')
shutil.copyfile(src_file, os.path.join(target_dir, 'cmudict.zip'))
shutil._unpack_zipfile(os.path.join(target_dir, 'cmudict.zip'), target_dir)
try:
nltk.data.find('taggers/averaged_perceptron_tagger')
except: # noqa: *
try:
_create_unverified_https_context = ssl._create_unverified_context
except AttributeError:
pass
else:
ssl._create_default_https_context = _create_unverified_https_context
nltk.download('averaged_perceptron_tagger',
halt_on_error=False,
raise_on_error=True)
try:
nltk.data.find('corpora/cmudict')
except: # noqa: *
try:
_create_unverified_https_context = ssl._create_unverified_context
except AttributeError:
pass
else:
ssl._create_default_https_context = _create_unverified_https_context
nltk.download('cmudict', halt_on_error=False, raise_on_error=True)

funasr/utils/asr_utils.py (new file, 327 lines)
View File

@@ -0,0 +1,327 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import struct
from typing import Any, Dict, List, Union
import librosa
import numpy as np
import pkg_resources
from modelscope.utils.logger import get_logger
logger = get_logger()
green_color = '\033[1;32m'
red_color = '\033[0;31;40m'
yellow_color = '\033[0;33;40m'
end_color = '\033[0m'
global_asr_language = 'zh-cn'
def get_version():
    # float() would fail on a full 'x.y.z' version string, so convert only
    # the leading 'major.minor' part.
    version = pkg_resources.get_distribution('easyasr').version
    return float('.'.join(version.split('.')[:2]))
def sample_rate_checking(audio_in: Union[str, bytes], audio_format: str):
r_audio_fs = None
if audio_format == 'wav':
r_audio_fs = get_sr_from_wav(audio_in)
elif audio_format == 'pcm' and isinstance(audio_in, bytes):
r_audio_fs = get_sr_from_bytes(audio_in)
return r_audio_fs
def type_checking(audio_in: Union[str, bytes],
audio_fs: int = None,
recog_type: str = None,
audio_format: str = None):
r_recog_type = recog_type
r_audio_format = audio_format
r_wav_path = audio_in
if isinstance(audio_in, str):
assert os.path.exists(audio_in), f'wav_path:{audio_in} does not exist'
elif isinstance(audio_in, bytes):
        assert len(audio_in) > 0, 'audio_in is empty'
r_audio_format = 'pcm'
r_recog_type = 'wav'
if r_recog_type is None:
        # audio_in is a single wav file: recog_type and audio_format are wav
if os.path.isfile(audio_in):
if audio_in.endswith('.wav') or audio_in.endswith('.WAV'):
r_recog_type = 'wav'
r_audio_format = 'wav'
        # audio_in is a dataset directory: infer recog_type from its name
elif os.path.isdir(audio_in):
dir_name = os.path.basename(audio_in)
if 'test' in dir_name:
r_recog_type = 'test'
elif 'dev' in dir_name:
r_recog_type = 'dev'
elif 'train' in dir_name:
r_recog_type = 'train'
if r_audio_format is None:
if find_file_by_ends(audio_in, '.ark'):
r_audio_format = 'kaldi_ark'
elif find_file_by_ends(audio_in, '.wav') or find_file_by_ends(
audio_in, '.WAV'):
r_audio_format = 'wav'
elif find_file_by_ends(audio_in, '.records'):
r_audio_format = 'tfrecord'
if r_audio_format == 'kaldi_ark' and r_recog_type != 'wav':
# datasets with kaldi_ark file
r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../'))
elif r_audio_format == 'tfrecord' and r_recog_type != 'wav':
# datasets with tensorflow records file
r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../'))
elif r_audio_format == 'wav' and r_recog_type != 'wav':
# datasets with waveform files
r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../../'))
return r_recog_type, r_audio_format, r_wav_path
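# A worked example (the path is hypothetical): for a directory
# '/data/aishell/test' containing kaldi .ark files,
#   type_checking('/data/aishell/test')
# returns ('test', 'kaldi_ark', '/data/aishell'): recog_type comes from the
# directory name, audio_format from the file extensions found inside, and
# the wav path is rebased one level up for ark/tfrecord datasets (two
# levels up for wav datasets).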
def get_sr_from_bytes(wav: bytes):
sr = None
data = wav
if len(data) > 44:
try:
header_fields = {}
header_fields['ChunkID'] = str(data[0:4], 'UTF-8')
header_fields['Format'] = str(data[8:12], 'UTF-8')
header_fields['Subchunk1ID'] = str(data[12:16], 'UTF-8')
if header_fields['ChunkID'] == 'RIFF' and header_fields[
'Format'] == 'WAVE' and header_fields[
'Subchunk1ID'] == 'fmt ':
header_fields['SampleRate'] = struct.unpack('<I',
data[24:28])[0]
sr = header_fields['SampleRate']
except Exception:
            # ignore malformed headers
pass
else:
        logger.warning('audio bytes length ' + str(len(data)) + ' is invalid.')
return sr
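# In a canonical 44-byte RIFF/WAVE header, bytes 24:28 hold the sample rate
# as a little-endian uint32, e.g. (file name hypothetical):
#   get_sr_from_bytes(open('speech.wav', 'rb').read())  # -> 16000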
def get_sr_from_wav(fname: str):
fs = None
if os.path.isfile(fname):
audio, fs = librosa.load(fname, sr=None)
return fs
elif os.path.isdir(fname):
dir_files = os.listdir(fname)
for file in dir_files:
file_path = os.path.join(fname, file)
if os.path.isfile(file_path):
if file_path.endswith('.wav') or file_path.endswith('.WAV'):
fs = get_sr_from_wav(file_path)
elif os.path.isdir(file_path):
fs = get_sr_from_wav(file_path)
if fs is not None:
break
return fs
def find_file_by_ends(dir_path: str, ends: str):
dir_files = os.listdir(dir_path)
for file in dir_files:
file_path = os.path.join(dir_path, file)
if os.path.isfile(file_path):
if file_path.endswith(ends):
return True
elif os.path.isdir(file_path):
if find_file_by_ends(file_path, ends):
return True
return False
def recursion_dir_all_wav(wav_list, dir_path: str) -> List[str]:
dir_files = os.listdir(dir_path)
for file in dir_files:
file_path = os.path.join(dir_path, file)
if os.path.isfile(file_path):
if file_path.endswith('.wav') or file_path.endswith('.WAV'):
wav_list.append(file_path)
elif os.path.isdir(file_path):
recursion_dir_all_wav(wav_list, file_path)
return wav_list
def set_parameters(language: str = None):
if language is not None:
global global_asr_language
global_asr_language = language
def compute_wer(hyp_list: List[Any],
ref_list: List[Any],
lang: str = None) -> Dict[str, Any]:
assert len(hyp_list) > 0, 'hyp list is empty'
assert len(ref_list) > 0, 'ref list is empty'
if lang is not None:
global global_asr_language
global_asr_language = lang
rst = {
'Wrd': 0,
'Corr': 0,
'Ins': 0,
'Del': 0,
'Sub': 0,
'Snt': 0,
'Err': 0.0,
'S.Err': 0.0,
'wrong_words': 0,
'wrong_sentences': 0
}
for h_item in hyp_list:
for r_item in ref_list:
if h_item['key'] == r_item['key']:
out_item = compute_wer_by_line(h_item['value'],
r_item['value'],
global_asr_language)
rst['Wrd'] += out_item['nwords']
rst['Corr'] += out_item['cor']
rst['wrong_words'] += out_item['wrong']
rst['Ins'] += out_item['ins']
rst['Del'] += out_item['del']
rst['Sub'] += out_item['sub']
rst['Snt'] += 1
if out_item['wrong'] > 0:
rst['wrong_sentences'] += 1
print_wrong_sentence(key=h_item['key'],
hyp=h_item['value'],
ref=r_item['value'])
else:
print_correct_sentence(key=h_item['key'],
hyp=h_item['value'],
ref=r_item['value'])
break
if rst['Wrd'] > 0:
rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
if rst['Snt'] > 0:
rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2)
return rst
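# A usage sketch (keys and texts are hypothetical); both lists hold
# {'key': ..., 'value': ...} items that are matched by key:
#   hyp = [{'key': 'utt1', 'value': '今天天汽'}]
#   ref = [{'key': 'utt1', 'value': '今天天气'}]
#   rst = compute_wer(hyp, ref, lang='zh-cn')
#   # rst['Wrd'] == 4, rst['Sub'] == 1, rst['Err'] == 25.0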
def compute_wer_by_line(hyp: Union[str, List[str]],
                        ref: Union[str, List[str]],
                        lang: str = 'zh-cn') -> Dict[str, Any]:
if lang != 'zh-cn':
hyp = hyp.split()
ref = ref.split()
hyp = list(map(lambda x: x.lower(), hyp))
ref = list(map(lambda x: x.lower(), ref))
len_hyp = len(hyp)
len_ref = len(ref)
cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)
ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)
for i in range(len_hyp + 1):
cost_matrix[i][0] = i
for j in range(len_ref + 1):
cost_matrix[0][j] = j
for i in range(1, len_hyp + 1):
for j in range(1, len_ref + 1):
if hyp[i - 1] == ref[j - 1]:
cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
else:
substitution = cost_matrix[i - 1][j - 1] + 1
insertion = cost_matrix[i - 1][j] + 1
deletion = cost_matrix[i][j - 1] + 1
compare_val = [substitution, insertion, deletion]
min_val = min(compare_val)
operation_idx = compare_val.index(min_val) + 1
cost_matrix[i][j] = min_val
ops_matrix[i][j] = operation_idx
match_idx = []
i = len_hyp
j = len_ref
rst = {
'nwords': len_ref,
'cor': 0,
'wrong': 0,
'ins': 0,
'del': 0,
'sub': 0
}
while i >= 0 or j >= 0:
i_idx = max(0, i)
j_idx = max(0, j)
if ops_matrix[i_idx][j_idx] == 0: # correct
if i - 1 >= 0 and j - 1 >= 0:
match_idx.append((j - 1, i - 1))
rst['cor'] += 1
i -= 1
j -= 1
elif ops_matrix[i_idx][j_idx] == 2: # insert
i -= 1
rst['ins'] += 1
elif ops_matrix[i_idx][j_idx] == 3: # delete
j -= 1
rst['del'] += 1
elif ops_matrix[i_idx][j_idx] == 1: # substitute
i -= 1
j -= 1
rst['sub'] += 1
if i < 0 and j >= 0:
rst['del'] += 1
elif j < 0 and i >= 0:
rst['ins'] += 1
match_idx.reverse()
wrong_cnt = cost_matrix[len_hyp][len_ref]
rst['wrong'] = wrong_cnt
return rst
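# e.g. for non-Chinese input the strings are split on whitespace first:
#   compute_wer_by_line('good morning', 'good evening', lang='en')
#   # -> {'nwords': 2, 'cor': 1, 'wrong': 1, 'ins': 0, 'del': 0, 'sub': 1}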
def print_wrong_sentence(key: str, hyp: str, ref: str):
space = len(key)
print(key + yellow_color + ' ref: ' + ref)
print(' ' * space + red_color + ' hyp: ' + hyp + end_color)
def print_correct_sentence(key: str, hyp: str, ref: str):
space = len(key)
print(key + yellow_color + ' ref: ' + ref)
print(' ' * space + green_color + ' hyp: ' + hyp + end_color)
def print_progress(percent):
if percent > 1:
percent = 1
res = int(50 * percent) * '#'
print('\r[%-50s] %d%%' % (res, int(100 * percent)), end='')

funasr/utils/postprocess_utils.py Normal file
View File

@ -0,0 +1,174 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import string
from typing import Any, List, Union
def isChinese(ch: str):
if '\u4e00' <= ch <= '\u9fff':
return True
return False
def isAllChinese(word: Union[List[Any], str]):
word_lists = []
table = str.maketrans('', '', string.punctuation)
for i in word:
cur = i.translate(table)
cur = cur.replace(' ', '')
cur = cur.replace('</s>', '')
cur = cur.replace('<s>', '')
word_lists.append(cur)
if len(word_lists) == 0:
return False
for ch in word_lists:
if isChinese(ch) is False:
return False
return True
def isAllAlpha(word: Union[List[Any], str]):
word_lists = []
table = str.maketrans('', '', string.punctuation)
for i in word:
cur = i.translate(table)
cur = cur.replace(' ', '')
cur = cur.replace('</s>', '')
cur = cur.replace('<s>', '')
word_lists.append(cur)
if len(word_lists) == 0:
return False
for ch in word_lists:
if ch.isalpha() is False:
return False
elif ch.isalpha() is True and isChinese(ch) is True:
return False
return True
def abbr_dispose(words: List[Any]) -> List[Any]:
words_size = len(words)
word_lists = []
abbr_begin = []
abbr_end = []
last_num = -1
for num in range(words_size):
if num <= last_num:
continue
if len(words[num]) == 1 and words[num].encode('utf-8').isalpha():
if num + 1 < words_size and words[
num + 1] == ' ' and num + 2 < words_size and len(
words[num +
2]) == 1 and words[num +
2].encode('utf-8').isalpha():
                # found the beginning of an abbreviation
abbr_begin.append(num)
num += 2
abbr_end.append(num)
                # scan forward for the end of the abbreviation
while True:
num += 1
if num < words_size and words[num] == ' ':
num += 1
if num < words_size and len(
words[num]) == 1 and words[num].encode(
'utf-8').isalpha():
abbr_end.pop()
abbr_end.append(num)
last_num = num
else:
break
else:
break
last_num = -1
for num in range(words_size):
if num <= last_num:
continue
if num in abbr_begin:
word_lists.append(words[num].upper())
num += 1
while num < words_size:
if num in abbr_end:
word_lists.append(words[num].upper())
last_num = num
break
else:
if words[num].encode('utf-8').isalpha():
word_lists.append(words[num].upper())
num += 1
else:
word_lists.append(words[num])
return word_lists
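# Single letters separated by blanks are treated as an abbreviation and
# uppercased; multi-letter words pass through unchanged (input hypothetical):
#   abbr_dispose(['a', ' ', 'b', ' ', 'c'])  # -> ['A', ' ', 'B', ' ', 'C']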
def sentence_postprocess(words: List[Any]):
middle_lists = []
word_lists = []
word_item = ''
    # clean the token list: drop sentinel tokens like <s>, </s>, <unk>
for i in words:
word = ''
if isinstance(i, str):
word = i
else:
word = i.decode('utf-8')
if word in ['<s>', '</s>', '<unk>']:
continue
else:
middle_lists.append(word)
# all chinese characters
if isAllChinese(middle_lists):
for ch in middle_lists:
word_lists.append(ch.replace(' ', ''))
# all alpha characters
elif isAllAlpha(middle_lists):
for ch in middle_lists:
word = ''
if '@@' in ch:
word = ch.replace('@@', '')
word_item += word
else:
word_item += ch
word_lists.append(word_item)
word_lists.append(' ')
word_item = ''
    # mixed Chinese and alpha characters
else:
alpha_blank = False
for ch in middle_lists:
word = ''
if isAllChinese(ch):
if alpha_blank is True:
word_lists.pop()
word_lists.append(ch)
alpha_blank = False
elif '@@' in ch:
word = ch.replace('@@', '')
word_item += word
alpha_blank = False
elif isAllAlpha(ch):
word_item += ch
word_lists.append(word_item)
word_lists.append(' ')
word_item = ''
alpha_blank = True
else:
raise ValueError('invalid character: {}'.format(ch))
word_lists = abbr_dispose(word_lists)
sentence = ''.join(word_lists).strip()
return sentence
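# A sketch of the three branches (tokens are hypothetical):
#   sentence_postprocess(['今', '天'])                 # -> '今天'
#   sentence_postprocess(['hel@@', 'lo', 'world'])     # -> 'hello world'
#   sentence_postprocess(['hel@@', 'lo', '今', '天'])  # -> 'hello今天'
# '@@' marks BPE continuations that get merged; the trailing blank after an
# alpha word is dropped when a Chinese character follows it.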

funasr/utils/wav_utils.py Normal file
View File

@ -0,0 +1,178 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
from typing import Any, Dict, Union
import kaldiio
import librosa
import numpy as np
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
def ndarray_resample(audio_in: np.ndarray,
fs_in: int = 16000,
fs_out: int = 16000) -> np.ndarray:
audio_out = audio_in
if fs_in != fs_out:
audio_out = librosa.resample(audio_in, orig_sr=fs_in, target_sr=fs_out)
return audio_out
def torch_resample(audio_in: torch.Tensor,
fs_in: int = 16000,
fs_out: int = 16000) -> torch.Tensor:
audio_out = audio_in
if fs_in != fs_out:
audio_out = torchaudio.transforms.Resample(orig_freq=fs_in,
new_freq=fs_out)(audio_in)
return audio_out
def extract_CMVN_features(mvn_file):
    """
    Extract CMVN statistics (per-dim mean and inverse stddev) from a kaldi
    cmvn.ark file.
    """
if not os.path.exists(mvn_file):
return None
try:
cmvn = kaldiio.load_mat(mvn_file)
means = []
variance = []
for i in range(cmvn.shape[1] - 1):
means.append(float(cmvn[0][i]))
count = float(cmvn[0][-1])
for i in range(cmvn.shape[1] - 1):
variance.append(float(cmvn[1][i]))
for i in range(len(means)):
means[i] /= count
variance[i] = variance[i] / count - means[i] * means[i]
if variance[i] < 1.0e-20:
variance[i] = 1.0e-20
variance[i] = 1.0 / math.sqrt(variance[i])
cmvn = np.array([means, variance])
return cmvn
except Exception:
cmvn = extract_CMVN_features_txt(mvn_file)
return cmvn
def extract_CMVN_features_txt(mvn_file): # noqa
with open(mvn_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
add_shift_list = []
rescale_list = []
for i in range(len(lines)):
line_item = lines[i].split()
if line_item[0] == '<AddShift>':
line_item = lines[i + 1].split()
if line_item[0] == '<LearnRateCoef>':
add_shift_line = line_item[3:(len(line_item) - 1)]
add_shift_list = list(add_shift_line)
continue
elif line_item[0] == '<Rescale>':
line_item = lines[i + 1].split()
if line_item[0] == '<LearnRateCoef>':
rescale_line = line_item[3:(len(line_item) - 1)]
rescale_list = list(rescale_line)
continue
add_shift_list_f = [float(s) for s in add_shift_list]
rescale_list_f = [float(s) for s in rescale_list]
cmvn = np.array([add_shift_list_f, rescale_list_f])
return cmvn
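# How the returned stats are consumed is an assumption, not shown here:
# the ark variant yields cmvn[0] = per-dim mean and cmvn[1] = 1/stddev,
# suggesting normalization as
#   feats_norm = (feats - cmvn[0]) * cmvn[1]
# while the kaldi-nnet text format stores an <AddShift> (typically -mean)
# and a <Rescale> factor, applied as (feats + cmvn[0]) * cmvn[1].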
def build_LFR_features(inputs, m=7, n=6):  # noqa
    """
    Stack m consecutive frames and advance n frames at a time
    (low frame rate, LFR).
    If m = 1 and n = 1, just return the original features.
    If m = 1 and n > 1, it works like frame skipping.
    If m > 1 and n = 1, it works like stacking, but only with right frames.
    If m > 1 and n > 1, it works like LFR.
    Args:
        inputs: T x D np.ndarray
        m: number of frames to stack
        n: number of frames to skip
    """
LFR_inputs = []
T = inputs.shape[0]
T_lfr = int(np.ceil(T / n))
left_padding = np.tile(inputs[0], ((m - 1) // 2, 1))
inputs = np.vstack((left_padding, inputs))
T = T + (m - 1) // 2
for i in range(T_lfr):
if m <= T - i * n:
LFR_inputs.append(np.hstack(inputs[i * n:i * n + m]))
else: # process last LFR frame
num_padding = m - (T - i * n)
frame = np.hstack(inputs[i * n:])
for _ in range(num_padding):
frame = np.hstack((frame, inputs[-1]))
LFR_inputs.append(frame)
return np.vstack(LFR_inputs)
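# Shape check with the defaults m=7, n=6: a (100, 80) fbank matrix yields
# ceil(100 / 6) = 17 stacked frames of 7 * 80 = 560 dims:
#   lfr = build_LFR_features(np.random.rand(100, 80))  # lfr.shape == (17, 560)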
def compute_fbank(wav_file,
num_mel_bins=80,
frame_length=25,
frame_shift=10,
dither=0.0,
is_pcm=False,
fs: Union[int, Dict[Any, int]] = 16000):
audio_sr: int = 16000
model_sr: int = 16000
if isinstance(fs, int):
model_sr = fs
audio_sr = fs
else:
model_sr = fs['model_fs']
audio_sr = fs['audio_fs']
if is_pcm is True:
# byte(PCM16) to float32, and resample
value = wav_file
middle_data = np.frombuffer(value, dtype=np.int16)
middle_data = np.asarray(middle_data)
if middle_data.dtype.kind not in 'iu':
raise TypeError("'middle_data' must be an array of integers")
dtype = np.dtype('float32')
if dtype.kind != 'f':
raise TypeError("'dtype' must be a floating point type")
i = np.iinfo(middle_data.dtype)
abs_max = 2**(i.bits - 1)
offset = i.min + abs_max
        waveform = ((middle_data.astype(dtype) - offset) / abs_max).astype(np.float32)
waveform = ndarray_resample(waveform, audio_sr, model_sr)
waveform = torch.from_numpy(waveform.reshape(1, -1))
else:
# load pcm from wav, and resample
waveform, audio_sr = torchaudio.load(wav_file)
waveform = waveform * (1 << 15)
waveform = torch_resample(waveform, audio_sr, model_sr)
mat = kaldi.fbank(waveform,
num_mel_bins=num_mel_bins,
frame_length=frame_length,
frame_shift=frame_shift,
dither=dither,
energy_floor=0.0,
window_type='hamming',
sample_frequency=model_sr)
input_feats = mat
return input_feats
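# Typical calls (file and variable names are hypothetical):
#   feats = compute_fbank('speech.wav')  # wav file on disk
#   feats = compute_fbank(pcm_bytes, is_pcm=True,
#                         fs={'audio_fs': 8000, 'model_fs': 16000})
# Both return a (num_frames, num_mel_bins) tensor; the pcm path converts
# int16 bytes to float32 and resamples from audio_fs to model_fs.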

View File

@ -1 +1 @@
0.1.0
0.1.3