mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Merge pull request #6 from alibaba-damo-academy/dev
update funasr 0.1.3
commit fd278298f5
@@ -65,6 +65,7 @@ for dset in ${test_sets}; do
    ${decode_cmd} --max-jobs-run "${inference_nj}" JOB=1:"${inference_nj}" "${_logdir}"/asr_inference.JOB.log \
        python -m funasr.bin.modelscope_infer \
            --model_name ${model_name} \
            --model_revision ${model_revision} \
            --wav_list ${_logdir}/keys.JOB.scp \
            --output_file ${_logdir}/text.JOB \
            --gpuid_list ${gpuid_list} \
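The recipe change above threads the new `--model_revision` flag through to `funasr.bin.modelscope_infer`, pinning the ModelScope model version at download time. For reference, a minimal sketch of the equivalent direct pipeline call (the `audio_in` keyword and printed result layout are assumptions about the ModelScope API of this era, not shown in this diff):

```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Same construction as in funasr/bin/modelscope_infer (see the hunk near the
# end of this commit): model name plus an explicit, pinned revision.
inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    model_revision="v1.0.3")

# Hypothetical single-file invocation; the exact keyword may differ.
rec_result = inference_pipeline(audio_in="speech.wav")
print(rec_result)
```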
687  funasr/bin/asr_inference_modelscope.py  Executable file
@@ -0,0 +1,687 @@
#!/usr/bin/env python3
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import argparse
import logging
import sys
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict

import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type

from funasr.fileio.datadir_writer import DatadirWriter
from funasr.modules.beam_search.batch_beam_search import BatchBeamSearch
from funasr.modules.beam_search.batch_beam_search_online_sim import BatchBeamSearchOnlineSim
from funasr.modules.beam_search.beam_search import BeamSearch
from funasr.modules.beam_search.beam_search import Hypothesis
from funasr.modules.scorers.ctc import CTCPrefixScorer
from funasr.modules.scorers.length_bonus import LengthBonus
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
from funasr.modules.subsampling import TooShortUttError
from funasr.tasks.asr import ASRTask
from funasr.tasks.lm import LMTask
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend

from modelscope.utils.logger import get_logger

logger = get_logger()

header_colors = '\033[95m'
end_colors = '\033[0m'

global_asr_language: str = 'zh-cn'
global_sample_rate: Union[int, Dict[Any, int]] = {
    'audio_fs': 16000,
    'model_fs': 16000
}

class Speech2Text:
    """Speech2Text class

    Examples:
        >>> import soundfile
        >>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> speech2text(audio)
        [(text, token, token_int, hypothesis object), ...]

    """

    def __init__(
        self,
        asr_train_config: Union[Path, str] = None,
        asr_model_file: Union[Path, str] = None,
        lm_train_config: Union[Path, str] = None,
        lm_file: Union[Path, str] = None,
        token_type: str = None,
        bpemodel: str = None,
        device: str = "cpu",
        maxlenratio: float = 0.0,
        minlenratio: float = 0.0,
        batch_size: int = 1,
        dtype: str = "float32",
        beam_size: int = 20,
        ctc_weight: float = 0.5,
        lm_weight: float = 1.0,
        ngram_weight: float = 0.9,
        penalty: float = 0.0,
        nbest: int = 1,
        streaming: bool = False,
        frontend_conf: dict = None,
        **kwargs,
    ):
        assert check_argument_types()

        # 1. Build ASR model
        scorers = {}
        asr_model, asr_train_args = ASRTask.build_model_from_file(
            asr_train_config, asr_model_file, device
        )
        if asr_model.frontend is None and frontend_conf is not None:
            frontend = WavFrontend(**frontend_conf)
            asr_model.frontend = frontend
        asr_model.to(dtype=getattr(torch, dtype)).eval()

        decoder = asr_model.decoder

        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
        token_list = asr_model.token_list
        scorers.update(
            decoder=decoder,
            ctc=ctc,
            length_bonus=LengthBonus(len(token_list)),
        )

        # 2. Build Language model
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, device
            )
            scorers["lm"] = lm.lm

        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram

        # 4. Build BeamSearch object
        # transducer is not supported now
        beam_search_transducer = None

        weights = dict(
            decoder=1.0 - ctc_weight,
            ctc=ctc_weight,
            lm=lm_weight,
            ngram=ngram_weight,
            length_bonus=penalty,
        )
        beam_search = BeamSearch(
            beam_size=beam_size,
            weights=weights,
            scorers=scorers,
            sos=asr_model.sos,
            eos=asr_model.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
        )

        # TODO(karita): make all scorers batchified
        if batch_size == 1:
            non_batch = [
                k
                for k, v in beam_search.full_scorers.items()
                if not isinstance(v, BatchScorerInterface)
            ]
            if len(non_batch) == 0:
                if streaming:
                    beam_search.__class__ = BatchBeamSearchOnlineSim
                    beam_search.set_streaming_config(asr_train_config)
                    logging.info(
                        "BatchBeamSearchOnlineSim implementation is selected."
                    )
                else:
                    beam_search.__class__ = BatchBeamSearch
                    logging.info("BatchBeamSearch implementation is selected.")
            else:
                logging.warning(
                    f"As non-batch scorers {non_batch} are found, "
                    f"fall back to non-batch implementation."
                )

        beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
        for scorer in scorers.values():
            if isinstance(scorer, torch.nn.Module):
                scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
        logging.info(f"Beam_search: {beam_search}")
        logging.info(f"Decoding device={device}, dtype={dtype}")

        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel

        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
            if bpemodel is not None:
                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
            else:
                tokenizer = None
        else:
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")

        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.converter = converter
        self.tokenizer = tokenizer
        self.beam_search = beam_search
        self.beam_search_transducer = beam_search_transducer
        self.maxlenratio = maxlenratio
        self.minlenratio = minlenratio
        self.device = device
        self.dtype = dtype
        self.nbest = nbest

    @torch.no_grad()
    def __call__(
        self, speech: Union[torch.Tensor, np.ndarray]
    ) -> List[
        Tuple[
            Optional[str],
            List[str],
            List[int],
            Union[Hypothesis],
        ]
    ]:
        """Inference

        Args:
            speech: Input speech data
        Returns:
            text, token, token_int, hyp

        """
        assert check_argument_types()

        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)

        # data: (Nsamples,) -> (1, Nsamples)
        speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
        lfr_factor = max(1, (speech.size()[-1] // 80) - 1)
        # lengths: (1,)
        lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
        batch = {"speech": speech, "speech_lengths": lengths}

        # a. To device
        batch = to_device(batch, device=self.device)

        # b. Forward Encoder
        enc, _ = self.asr_model.encode(**batch)
        if isinstance(enc, tuple):
            enc = enc[0]
        assert len(enc) == 1, len(enc)

        # c. Pass the encoder result to the beam search
        nbest_hyps = self.beam_search(
            x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
        )

        nbest_hyps = nbest_hyps[: self.nbest]

        results = []
        for hyp in nbest_hyps:
            assert isinstance(hyp, (Hypothesis)), type(hyp)

            # remove sos/eos and get results
            last_pos = -1
            if isinstance(hyp.yseq, list):
                token_int = hyp.yseq[1:last_pos]
            else:
                token_int = hyp.yseq[1:last_pos].tolist()

            # remove blank symbol id, which is assumed to be 0
            token_int = list(filter(lambda x: x != 0, token_int))

            # Change integer-ids to tokens
            token = self.converter.ids2tokens(token_int)

            if self.tokenizer is not None:
                text = self.tokenizer.tokens2text(token)
            else:
                text = None
            results.append((text, token, token_int, hyp))

        assert check_return_type(results)
        return results

def inference(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    dtype: str,
    beam_size: int,
    ngpu: int,
    seed: int,
    ctc_weight: float,
    lm_weight: float,
    ngram_weight: float,
    penalty: float,
    nbest: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: list,
    audio_lists: Union[List[Any], bytes],
    key_file: Optional[str],
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    lm_train_config: Optional[str],
    lm_file: Optional[str],
    word_lm_train_config: Optional[str],
    token_type: Optional[str],
    bpemodel: Optional[str],
    output_dir: Optional[str],
    allow_variable_data_keys: bool,
    streaming: bool,
    frontend_conf: dict = None,
    fs: Union[dict, int] = 16000,
    **kwargs,
) -> List[Any]:
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1:
        device = "cuda"
    else:
        device = "cpu"
    features_type: str = data_path_and_name_and_type[1]
    hop_length: int = 160
    sr: int = 16000
    if isinstance(fs, int):
        sr = fs
    else:
        if 'model_fs' in fs and fs['model_fs'] is not None:
            sr = fs['model_fs']
    if features_type != 'sound':
        frontend_conf = None
    if frontend_conf is not None:
        if 'hop_length' in frontend_conf:
            hop_length = frontend_conf['hop_length']

    finish_count = 0
    file_count = 1
    if isinstance(audio_lists, bytes):
        file_count = 1
    else:
        file_count = len(audio_lists)
    if len(data_path_and_name_and_type) >= 3 and frontend_conf is not None:
        mvn_file = data_path_and_name_and_type[2]
        mvn_data = wav_utils.extract_CMVN_featrures(mvn_file)
        frontend_conf['mvn_data'] = mvn_data
    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        dtype=dtype,
        beam_size=beam_size,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
        streaming=streaming,
        frontend_conf=frontend_conf,
    )
    speech2text = Speech2Text(**speech2text_kwargs)
    data_path_and_name_and_type_new = [
        audio_lists, data_path_and_name_and_type[0], data_path_and_name_and_type[1]
    ]
    # 3. Build data-iterator
    loader = ASRTask.build_streaming_iterator_modelscope(
        data_path_and_name_and_type_new,
        dtype=dtype,
        batch_size=batch_size,
        key_file=key_file,
        num_workers=num_workers,
        preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
        collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
        allow_variable_data_keys=allow_variable_data_keys,
        inference=True,
        sample_rate=fs
    )

    # 7. Start for-loop
    # FIXME(kamo): The output format should be discussed
    asr_result_list = []
    for keys, batch in loader:
        assert isinstance(batch, dict), type(batch)
        assert all(isinstance(s, str) for s in keys), keys
        _bs = len(next(iter(batch.values())))
        assert len(keys) == _bs, f"{len(keys)} != {_bs}"
        batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}

        # N-best list of (text, token, token_int, hyp_object)
        try:
            results = speech2text(**batch)
        except TooShortUttError as e:
            logging.warning(f"Utterance {keys} {e}")
            hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
            results = [[" ", ["<space>"], [2], hyp]] * nbest

        # Only supporting batch_size==1
        key = keys[0]
        for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
            if text is not None:
                text_postprocessed = postprocess_utils.sentence_postprocess(token)
                item = {'key': key, 'value': text_postprocessed}
                asr_result_list.append(item)
        finish_count += 1
        asr_utils.print_progress(finish_count / file_count)

    return asr_result_list

def set_parameters(language: str = None,
                   sample_rate: Union[int, Dict[Any, int]] = None):
    if language is not None:
        global global_asr_language
        global_asr_language = language
    if sample_rate is not None:
        global global_sample_rate
        global_sample_rate = sample_rate

def asr_inference(maxlenratio: float,
                  minlenratio: float,
                  beam_size: int,
                  ngpu: int,
                  ctc_weight: float,
                  lm_weight: float,
                  penalty: float,
                  name_and_type: list,
                  audio_lists: Union[List[Any], bytes],
                  asr_train_config: Optional[str],
                  asr_model_file: Optional[str],
                  nbest: int = 1,
                  num_workers: int = 1,
                  log_level: Union[int, str] = 'INFO',
                  batch_size: int = 1,
                  dtype: str = 'float32',
                  seed: int = 0,
                  key_file: Optional[str] = None,
                  lm_train_config: Optional[str] = None,
                  lm_file: Optional[str] = None,
                  word_lm_train_config: Optional[str] = None,
                  word_lm_file: Optional[str] = None,
                  ngram_file: Optional[str] = None,
                  ngram_weight: float = 0.9,
                  model_tag: Optional[str] = None,
                  token_type: Optional[str] = None,
                  bpemodel: Optional[str] = None,
                  allow_variable_data_keys: bool = False,
                  transducer_conf: Optional[dict] = None,
                  streaming: bool = False,
                  frontend_conf: dict = None,
                  fs: Union[dict, int] = None,
                  lang: Optional[str] = None,
                  outputdir: Optional[str] = None):
    if lang is not None:
        global global_asr_language
        global_asr_language = lang
    if fs is not None:
        global global_sample_rate
        global_sample_rate = fs

    # force use CPU if data type is bytes
    if isinstance(audio_lists, bytes):
        num_workers = 0
        ngpu = 0

    return inference(output_dir=outputdir,
                     maxlenratio=maxlenratio,
                     minlenratio=minlenratio,
                     batch_size=batch_size,
                     dtype=dtype,
                     beam_size=beam_size,
                     ngpu=ngpu,
                     seed=seed,
                     ctc_weight=ctc_weight,
                     lm_weight=lm_weight,
                     ngram_weight=ngram_weight,
                     penalty=penalty,
                     nbest=nbest,
                     num_workers=num_workers,
                     log_level=log_level,
                     data_path_and_name_and_type=name_and_type,
                     audio_lists=audio_lists,
                     key_file=key_file,
                     asr_train_config=asr_train_config,
                     asr_model_file=asr_model_file,
                     lm_train_config=lm_train_config,
                     lm_file=lm_file,
                     word_lm_train_config=word_lm_train_config,
                     word_lm_file=word_lm_file,
                     ngram_file=ngram_file,
                     model_tag=model_tag,
                     token_type=token_type,
                     bpemodel=bpemodel,
                     allow_variable_data_keys=allow_variable_data_keys,
                     transducer_conf=transducer_conf,
                     streaming=streaming,
                     frontend_conf=frontend_conf)

def get_parser():
    parser = config_argparse.ArgumentParser(
        description="ASR Decoding",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Note(kamo): Use '_' instead of '-' as separator.
    # '-' is confusing if written in yaml.
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )

    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument(
        "--ngpu",
        type=int,
        default=0,
        help="The number of gpus. 0 indicates CPU mode",
    )
    parser.add_argument(
        "--gpuid_list",
        type=str,
        default="",
        help="The visible gpus",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="The number of workers used for DataLoader",
    )

    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        required=True,
        action="append",
    )
    group.add_argument("--audio_lists", type=list,
                       default=[{'key': 'EdevDEWdIYQ_0021',
                                 'file': '/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
    group.add_argument("--key_file", type=str_or_none)
    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)

    group = parser.add_argument_group("The model configuration related")
    group.add_argument(
        "--asr_train_config",
        type=str,
        help="ASR training configuration",
    )
    group.add_argument(
        "--asr_model_file",
        type=str,
        help="ASR model parameter file",
    )
    group.add_argument(
        "--lm_train_config",
        type=str,
        help="LM training configuration",
    )
    group.add_argument(
        "--lm_file",
        type=str,
        help="LM parameter file",
    )
    group.add_argument(
        "--word_lm_train_config",
        type=str,
        help="Word LM training configuration",
    )
    group.add_argument(
        "--word_lm_file",
        type=str,
        help="Word LM parameter file",
    )
    group.add_argument(
        "--ngram_file",
        type=str,
        help="N-gram parameter file",
    )
    group.add_argument(
        "--model_tag",
        type=str,
        help="Pretrained model tag. If you specify this option, *_train_config and "
        "*_file will be overwritten",
    )

    group = parser.add_argument_group("Beam-search related")
    group.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="The batch size for inference",
    )
    group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
    group.add_argument("--beam_size", type=int, default=20, help="Beam size")
    group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
    group.add_argument(
        "--maxlenratio",
        type=float,
        default=0.0,
        help="Input length ratio to obtain max output length. "
        "If maxlenratio=0.0 (default), it uses an end-detect "
        "function "
        "to automatically find maximum hypothesis lengths. "
        "If maxlenratio<0.0, its absolute value is interpreted "
        "as a constant max output length",
    )
    group.add_argument(
        "--minlenratio",
        type=float,
        default=0.0,
        help="Input length ratio to obtain min output length",
    )
    group.add_argument(
        "--ctc_weight",
        type=float,
        default=0.5,
        help="CTC weight in joint decoding",
    )
    group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
    group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
    group.add_argument("--streaming", type=str2bool, default=False)

    group = parser.add_argument_group("Text converter related")
    group.add_argument(
        "--token_type",
        type=str_or_none,
        default=None,
        choices=["char", "bpe", None],
        help="The token type for ASR model. "
        "If not given, it is inferred from the training args",
    )
    group.add_argument(
        "--bpemodel",
        type=str_or_none,
        default=None,
        help="The model path of sentencepiece. "
        "If not given, it is inferred from the training args",
    )

    return parser

def main(cmd=None):
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    args = parser.parse_args(cmd)
    kwargs = vars(args)
    kwargs.pop("config", None)
    inference(**kwargs)


if __name__ == "__main__":
    main()
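For orientation, a sketch of driving the module above programmatically rather than through the CLI. The config and model paths are placeholders, and `name_and_type` mirrors how `inference()` unpacks `data_path_and_name_and_type` (data name, then feature type, then an optional CMVN file):

```python
from funasr.bin.asr_inference_modelscope import asr_inference

# audio_lists entries follow the {'key', 'file'} layout used by the
# --audio_lists default above; paths here are placeholders.
results = asr_inference(
    maxlenratio=0.0,
    minlenratio=0.0,
    beam_size=20,
    ngpu=0,  # CPU decoding
    ctc_weight=0.5,
    lm_weight=1.0,
    penalty=0.0,
    name_and_type=['speech', 'sound'],
    audio_lists=[{'key': 'utt1', 'file': '/path/to/utt1.wav'}],
    asr_train_config='exp/asr_config.yaml',
    asr_model_file='exp/asr.pth',
)
# Each result is a {'key': ..., 'value': ...} dict.
print(results)
```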
686  funasr/bin/asr_inference_paraformer_modelscope.py  Executable file
@@ -0,0 +1,686 @@
#!/usr/bin/env python3
import argparse
import logging
import sys
import time
from pathlib import Path
from typing import Any
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import List
from typing import Dict

import numpy as np
import torch
from typeguard import check_argument_types

from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
from funasr.modules.beam_search.beam_search import Hypothesis
from funasr.modules.scorers.ctc import CTCPrefixScorer
from funasr.modules.scorers.length_bonus import LengthBonus
from funasr.modules.subsampling import TooShortUttError
from funasr.tasks.asr import ASRTaskParaformer as ASRTask
from funasr.tasks.lm import LMTask
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend

from modelscope.utils.logger import get_logger

logger = get_logger()

header_colors = '\033[95m'
end_colors = '\033[0m'

global_asr_language: str = 'zh-cn'
global_sample_rate: Union[int, Dict[Any, int]] = {
    'audio_fs': 16000,
    'model_fs': 16000
}


class Speech2Text:
    """Speech2Text class

    Examples:
        >>> import soundfile
        >>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> speech2text(audio)
        [(text, token, token_int, hypothesis object), ...]

    """

    def __init__(
        self,
        asr_train_config: Union[Path, str] = None,
        asr_model_file: Union[Path, str] = None,
        lm_train_config: Union[Path, str] = None,
        lm_file: Union[Path, str] = None,
        token_type: str = None,
        bpemodel: str = None,
        device: str = "cpu",
        maxlenratio: float = 0.0,
        minlenratio: float = 0.0,
        dtype: str = "float32",
        beam_size: int = 20,
        ctc_weight: float = 0.5,
        lm_weight: float = 1.0,
        ngram_weight: float = 0.9,
        penalty: float = 0.0,
        nbest: int = 1,
        frontend_conf: dict = None,
        **kwargs,
    ):
        assert check_argument_types()

        # 1. Build ASR model
        scorers = {}
        asr_model, asr_train_args = ASRTask.build_model_from_file(
            asr_train_config, asr_model_file, device
        )
        if asr_model.frontend is None and frontend_conf is not None:
            frontend = WavFrontend(**frontend_conf)
            asr_model.frontend = frontend
        asr_model.to(dtype=getattr(torch, dtype)).eval()

        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
        token_list = asr_model.token_list
        scorers.update(
            ctc=ctc,
            length_bonus=LengthBonus(len(token_list)),
        )

        # 2. Build Language model
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, device
            )
            scorers["lm"] = lm.lm

        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram

        # 4. Build BeamSearch object
        # transducer is not supported now
        beam_search_transducer = None

        weights = dict(
            decoder=1.0 - ctc_weight,
            ctc=ctc_weight,
            lm=lm_weight,
            ngram=ngram_weight,
            length_bonus=penalty,
        )
        beam_search = BeamSearch(
            beam_size=beam_size,
            weights=weights,
            scorers=scorers,
            sos=asr_model.sos,
            eos=asr_model.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
        )

        beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
        for scorer in scorers.values():
            if isinstance(scorer, torch.nn.Module):
                scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
        logging.info(f"Beam_search: {beam_search}")
        logging.info(f"Decoding device={device}, dtype={dtype}")

        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel

        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
            if bpemodel is not None:
                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
            else:
                tokenizer = None
        else:
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")

        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.converter = converter
        self.tokenizer = tokenizer
        self.beam_search = beam_search
        self.beam_search_transducer = beam_search_transducer
        self.maxlenratio = maxlenratio
        self.minlenratio = minlenratio
        self.device = device
        self.dtype = dtype
        self.nbest = nbest

    @torch.no_grad()
    def __call__(
        self, speech: Union[torch.Tensor, np.ndarray]
    ):
        """Inference

        Args:
            speech: Input speech data
        Returns:
            text, token, token_int, hyp

        """
        assert check_argument_types()

        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)

        # data: (Nsamples,) -> (1, Nsamples)
        speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
        lfr_factor = max(1, (speech.size()[-1] // 80) - 1)
        # lengths: (1,)
        lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
        batch = {"speech": speech, "speech_lengths": lengths}

        # a. To device
        batch = to_device(batch, device=self.device)

        # b. Forward Encoder
        enc, enc_len = self.asr_model.encode(**batch)
        if isinstance(enc, tuple):
            enc = enc[0]
        assert len(enc) == 1, len(enc)

        predictor_outs = self.asr_model.calc_predictor(enc, enc_len)
        pre_acoustic_embeds, pre_token_length = predictor_outs[0], predictor_outs[1]
        pre_token_length = torch.tensor([pre_acoustic_embeds.size(1)], device=pre_acoustic_embeds.device)
        decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]

        nbest_hyps = self.beam_search(
            x=enc[0], am_scores=decoder_out[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
        )

        nbest_hyps = nbest_hyps[: self.nbest]
        results = []
        for hyp in nbest_hyps:
            assert isinstance(hyp, (Hypothesis)), type(hyp)

            # remove sos/eos and get results
            last_pos = -1
            if isinstance(hyp.yseq, list):
                token_int = hyp.yseq[1:last_pos]
            else:
                token_int = hyp.yseq[1:last_pos].tolist()

            # remove blank symbol id, which is assumed to be 0
            token_int = list(filter(lambda x: x != 0, token_int))

            # Change integer-ids to tokens
            token = self.converter.ids2tokens(token_int)

            if self.tokenizer is not None:
                text = self.tokenizer.tokens2text(token)
            else:
                text = None

            results.append((text, token, token_int, hyp, speech.size(1), lfr_factor))

        # assert check_return_type(results)
        return results

def inference(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    dtype: str,
    beam_size: int,
    ngpu: int,
    seed: int,
    ctc_weight: float,
    lm_weight: float,
    ngram_weight: float,
    penalty: float,
    nbest: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: list,
    audio_lists: Union[List[Any], bytes],
    key_file: Optional[str],
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    lm_train_config: Optional[str],
    lm_file: Optional[str],
    word_lm_train_config: Optional[str],
    model_tag: Optional[str],
    token_type: Optional[str],
    bpemodel: Optional[str],
    output_dir: Optional[str],
    allow_variable_data_keys: bool,
    frontend_conf: dict = None,
    fs: Union[dict, int] = 16000,
    **kwargs,
) -> List[Any]:
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1:
        device = "cuda"
    else:
        device = "cpu"
    # data_path_and_name_and_type = data_path_and_name_and_type[0]
    features_type: str = data_path_and_name_and_type[1]
    hop_length: int = 160
    sr: int = 16000
    if isinstance(fs, int):
        sr = fs
    else:
        if 'model_fs' in fs and fs['model_fs'] is not None:
            sr = fs['model_fs']
    if features_type != 'sound':
        frontend_conf = None
    if frontend_conf is not None:
        if 'hop_length' in frontend_conf:
            hop_length = frontend_conf['hop_length']

    finish_count = 0
    file_count = 1
    if isinstance(audio_lists, bytes):
        file_count = 1
    else:
        file_count = len(audio_lists)
    if len(data_path_and_name_and_type) >= 3 and frontend_conf is not None:
        mvn_file = data_path_and_name_and_type[2]
        mvn_data = wav_utils.extract_CMVN_featrures(mvn_file)
        frontend_conf['mvn_data'] = mvn_data

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        dtype=dtype,
        beam_size=beam_size,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
        frontend_conf=frontend_conf,
    )
    speech2text = Speech2Text(**speech2text_kwargs)

    data_path_and_name_and_type_new = [
        audio_lists, data_path_and_name_and_type[0], data_path_and_name_and_type[1]
    ]

    # 3. Build data-iterator
    loader = ASRTask.build_streaming_iterator_modelscope(
        data_path_and_name_and_type_new,
        dtype=dtype,
        batch_size=batch_size,
        key_file=key_file,
        num_workers=num_workers,
        preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
        collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
        allow_variable_data_keys=allow_variable_data_keys,
        inference=True,
        sample_rate=fs
    )

    forward_time_total = 0.0
    length_total = 0.0
    asr_result_list = []
    # 7. Start for-loop
    # FIXME(kamo): The output format should be discussed
    for keys, batch in loader:
        assert isinstance(batch, dict), type(batch)
        assert all(isinstance(s, str) for s in keys), keys
        _bs = len(next(iter(batch.values())))
        assert len(keys) == _bs, f"{len(keys)} != {_bs}"
        batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}

        logging.info("decoding, utt_id: {}".format(keys))
        # N-best list of (text, token, token_int, hyp_object)

        try:
            time_beg = time.time()
            results = speech2text(**batch)
            time_end = time.time()
            forward_time = time_end - time_beg
            lfr_factor = results[0][-1]
            length = results[0][-2]
            results = [results[0][:-2]]
            forward_time_total += forward_time
            length_total += length
            logging.info(
                "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".
                format(length, forward_time, 100 * forward_time / (length * lfr_factor)))
        except TooShortUttError as e:
            logging.warning(f"Utterance {keys} {e}")
            hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
            results = [[" ", ["<space>"], [2], hyp]] * nbest

        # Only supporting batch_size==1
        key = keys[0]
        for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
            if text is not None:
                text_postprocessed = postprocess_utils.sentence_postprocess(token)
                item = {'key': key, 'value': text_postprocessed}
                asr_result_list.append(item)

                logging.info("decoding, predictions: {}".format(text))
        finish_count += 1
        asr_utils.print_progress(finish_count / file_count)

    logging.info("decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".
                 format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor)))
    if features_type == 'sound':
        # data format is wav
        length_total_seconds = length_total / sr
        length_total_bytes = length_total * 2
    else:
        # data format is kaldi_ark
        length_total_seconds = length_total * hop_length / sr
        length_total_bytes = length_total * hop_length * 2

    logger.info(
        header_colors +  # noqa: *
        'decoding, feature length total: {}bytes, forward_time total: {:.4f}s, rtf avg: {:.4f}'
        .format(length_total_bytes, forward_time_total, forward_time_total /
                length_total_seconds) + end_colors)

    return asr_result_list

def set_parameters(language: str = None,
                   sample_rate: Union[int, Dict[Any, int]] = None):
    if language is not None:
        global global_asr_language
        global_asr_language = language
    if sample_rate is not None:
        global global_sample_rate
        global_sample_rate = sample_rate

def asr_inference(maxlenratio: float,
                  minlenratio: float,
                  beam_size: int,
                  ngpu: int,
                  ctc_weight: float,
                  lm_weight: float,
                  penalty: float,
                  name_and_type: list,
                  audio_lists: Union[List[Any], bytes],
                  asr_train_config: Optional[str],
                  asr_model_file: Optional[str],
                  nbest: int = 1,
                  num_workers: int = 1,
                  log_level: Union[int, str] = 'INFO',
                  batch_size: int = 1,
                  dtype: str = 'float32',
                  seed: int = 0,
                  key_file: Optional[str] = None,
                  lm_train_config: Optional[str] = None,
                  lm_file: Optional[str] = None,
                  word_lm_train_config: Optional[str] = None,
                  word_lm_file: Optional[str] = None,
                  ngram_file: Optional[str] = None,
                  ngram_weight: float = 0.9,
                  model_tag: Optional[str] = None,
                  token_type: Optional[str] = None,
                  bpemodel: Optional[str] = None,
                  allow_variable_data_keys: bool = False,
                  transducer_conf: Optional[dict] = None,
                  streaming: bool = False,
                  frontend_conf: dict = None,
                  fs: Union[dict, int] = None,
                  lang: Optional[str] = None,
                  outputdir: Optional[str] = None):
    if lang is not None:
        global global_asr_language
        global_asr_language = lang
    if fs is not None:
        global global_sample_rate
        global_sample_rate = fs

    # force use CPU if data type is bytes
    if isinstance(audio_lists, bytes):
        num_workers = 0
        ngpu = 0

    return inference(output_dir=outputdir,
                     maxlenratio=maxlenratio,
                     minlenratio=minlenratio,
                     batch_size=batch_size,
                     dtype=dtype,
                     beam_size=beam_size,
                     ngpu=ngpu,
                     seed=seed,
                     ctc_weight=ctc_weight,
                     lm_weight=lm_weight,
                     ngram_weight=ngram_weight,
                     penalty=penalty,
                     nbest=nbest,
                     num_workers=num_workers,
                     log_level=log_level,
                     data_path_and_name_and_type=name_and_type,
                     audio_lists=audio_lists,
                     key_file=key_file,
                     asr_train_config=asr_train_config,
                     asr_model_file=asr_model_file,
                     lm_train_config=lm_train_config,
                     lm_file=lm_file,
                     word_lm_train_config=word_lm_train_config,
                     word_lm_file=word_lm_file,
                     ngram_file=ngram_file,
                     model_tag=model_tag,
                     token_type=token_type,
                     bpemodel=bpemodel,
                     allow_variable_data_keys=allow_variable_data_keys,
                     transducer_conf=transducer_conf,
                     streaming=streaming,
                     frontend_conf=frontend_conf)

def get_parser():
    parser = config_argparse.ArgumentParser(
        description="ASR Decoding",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Note(kamo): Use '_' instead of '-' as separator.
    # '-' is confusing if written in yaml.
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )

    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument(
        "--ngpu",
        type=int,
        default=0,
        help="The number of gpus. 0 indicates CPU mode",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="The number of workers used for DataLoader",
    )

    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        required=True,
        action="append",
    )
    group.add_argument("--audio_lists", type=list, default=[{'key': 'EdevDEWdIYQ_0021', 'file': '/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
    group.add_argument("--key_file", type=str_or_none)
    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)

    group = parser.add_argument_group("The model configuration related")
    group.add_argument(
        "--asr_train_config",
        type=str,
        help="ASR training configuration",
    )
    group.add_argument(
        "--asr_model_file",
        type=str,
        help="ASR model parameter file",
    )
    group.add_argument(
        "--lm_train_config",
        type=str,
        help="LM training configuration",
    )
    group.add_argument(
        "--lm_file",
        type=str,
        help="LM parameter file",
    )
    group.add_argument(
        "--word_lm_train_config",
        type=str,
        help="Word LM training configuration",
    )
    group.add_argument(
        "--word_lm_file",
        type=str,
        help="Word LM parameter file",
    )
    group.add_argument(
        "--ngram_file",
        type=str,
        help="N-gram parameter file",
    )
    group.add_argument(
        "--model_tag",
        type=str,
        help="Pretrained model tag. If you specify this option, *_train_config and "
        "*_file will be overwritten",
    )

    group = parser.add_argument_group("Beam-search related")
    group.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="The batch size for inference",
    )
    group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
    group.add_argument("--beam_size", type=int, default=20, help="Beam size")
    group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
    group.add_argument(
        "--maxlenratio",
        type=float,
        default=0.0,
        help="Input length ratio to obtain max output length. "
        "If maxlenratio=0.0 (default), it uses an end-detect "
        "function "
        "to automatically find maximum hypothesis lengths. "
        "If maxlenratio<0.0, its absolute value is interpreted "
        "as a constant max output length",
    )
    group.add_argument(
        "--minlenratio",
        type=float,
        default=0.0,
        help="Input length ratio to obtain min output length",
    )
    group.add_argument(
        "--ctc_weight",
        type=float,
        default=0.5,
        help="CTC weight in joint decoding",
    )
    group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
    group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
    group.add_argument("--streaming", type=str2bool, default=False)

    group.add_argument(
        "--asr_model_config",
        default=None,
        help="",
    )

    group = parser.add_argument_group("Text converter related")
    group.add_argument(
        "--token_type",
        type=str_or_none,
        default=None,
        choices=["char", "bpe", None],
        help="The token type for ASR model. "
        "If not given, it is inferred from the training args",
    )
    group.add_argument(
        "--bpemodel",
        type=str_or_none,
        default=None,
        help="The model path of sentencepiece. "
        "If not given, it is inferred from the training args",
    )

    return parser

def main(cmd=None):
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    args = parser.parse_args(cmd)
    kwargs = vars(args)
    kwargs.pop("config", None)
    inference(**kwargs)


if __name__ == "__main__":
    main()
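The Paraformer variant above differs from the autoregressive script in one essential way: a predictor first estimates one acoustic embedding per output token, the decoder then scores all positions in parallel, and `BeamSearchPara` only rescores those per-position acoustic scores. A condensed, non-runnable sketch of that path using the same method names as `Speech2Text.__call__` above:

```python
# Condensed decode path (tensors and model setup elided).
enc, enc_len = asr_model.encode(speech=speech, speech_lengths=lengths)
# Predictor: one acoustic embedding per predicted output token.
pre_acoustic_embeds, pre_token_length = asr_model.calc_predictor(enc, enc_len)[:2]
# Decoder consumes all embeddings at once (non-autoregressive).
decoder_out, ys_pad_lens = asr_model.cal_decoder_with_predictor(
    enc, enc_len, pre_acoustic_embeds, pre_token_length)[:2]
# Beam search rescoring over the per-position acoustic scores.
nbest_hyps = beam_search(x=enc[0], am_scores=decoder_out[0],
                         maxlenratio=0.0, minlenratio=0.0)
```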
@@ -15,6 +15,10 @@ if __name__ == '__main__':
                        type=str,
                        default="speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                        help="model name in modelscope")
    parser.add_argument("--model_revision",
                        type=str,
                        default="v1.0.3",
                        help="model revision in modelscope")
    parser.add_argument("--local_model_path",
                        type=str,
                        default=None,
@@ -62,7 +66,8 @@ if __name__ == '__main__':
    if args.local_model_path is None:
        inference_pipeline = pipeline(
            task=Tasks.auto_speech_recognition,
            model="damo/{}".format(args.model_name))
            model="damo/{}".format(args.model_name),
            model_revision=args.model_revision)
    else:
        inference_pipeline = pipeline(
            task=Tasks.auto_speech_recognition,
349  funasr/datasets/iterable_dataset_modelscope.py  Normal file
@@ -0,0 +1,349 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
"""Iterable dataset module."""
import copy
from io import StringIO
from pathlib import Path
from typing import Callable, Collection, Dict, Iterator, Tuple, Union

import kaldiio
import numpy as np
import soundfile
import torch
from funasr.datasets.dataset import ESPnetDataset
from torch.utils.data.dataset import IterableDataset
from typeguard import check_argument_types

from funasr.utils import wav_utils


def load_kaldi(input):
    retval = kaldiio.load_mat(input)
    if isinstance(retval, tuple):
        assert len(retval) == 2, len(retval)
        if isinstance(retval[0], int) and isinstance(retval[1], np.ndarray):
            # sound scp case
            rate, array = retval
        elif isinstance(retval[1], int) and isinstance(retval[0], np.ndarray):
            # Extended ark format case
            array, rate = retval
        else:
            raise RuntimeError(
                f'Unexpected type: {type(retval[0])}, {type(retval[1])}')

        # Multichannel wave file
        # array: (NSample, Channel) or (Nsample)

    else:
        # Normal ark case
        assert isinstance(retval, np.ndarray), type(retval)
        array = retval
    return array


DATA_TYPES = {
    'sound':
    lambda x: soundfile.read(x)[0],
    'kaldi_ark':
    load_kaldi,
    'npy':
    np.load,
    'text_int':
    lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.long, delimiter=' '),
    'csv_int':
    lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.long, delimiter=','),
    'text_float':
    lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.float32, delimiter=' '
                         ),
    'csv_float':
    lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.float32, delimiter=','
                         ),
    'text':
    lambda x: x,
}


class IterableESPnetDatasetModelScope(IterableDataset):
    """Pytorch Dataset class for ESPNet.

    Examples:
        >>> dataset = IterableESPnetDataset([('wav.scp', 'input', 'sound'),
        ...                                  ('token_int', 'output', 'text_int')],
        ...                                 )
        >>> for uid, data in dataset:
        ...     data
        {'input': per_utt_array, 'output': per_utt_array}
    """
    def __init__(self,
                 path_name_type_list: Collection[Tuple[any, str, str]],
                 preprocess: Callable[[str, Dict[str, np.ndarray]],
                                      Dict[str, np.ndarray]] = None,
                 float_dtype: str = 'float32',
                 int_dtype: str = 'long',
                 key_file: str = None,
                 sample_rate: Union[dict, int] = 16000):
        assert check_argument_types()
        if len(path_name_type_list) == 0:
            raise ValueError(
                '1 or more elements are required for "path_name_type_list"')

        self.preprocess = preprocess

        self.float_dtype = float_dtype
        self.int_dtype = int_dtype
        self.key_file = key_file
        self.sample_rate = sample_rate

        self.debug_info = {}
        non_iterable_list = []
        self.path_name_type_list = []

        path_list = path_name_type_list[0]
        name = path_name_type_list[1]
        _type = path_name_type_list[2]
        if name in self.debug_info:
            raise RuntimeError(f'"{name}" is duplicated for data-key')
        self.debug_info[name] = path_list, _type
        # for path, name, _type in path_name_type_list:
        for path in path_list:
            self.path_name_type_list.append((path, name, _type))

        if len(non_iterable_list) != 0:
            # Some types don't support iterable mode
            self.non_iterable_dataset = ESPnetDataset(
                path_name_type_list=non_iterable_list,
                preprocess=preprocess,
                float_dtype=float_dtype,
                int_dtype=int_dtype,
            )
        else:
            self.non_iterable_dataset = None

        self.apply_utt2category = False

    def has_name(self, name) -> bool:
        return name in self.debug_info

    def names(self) -> Tuple[str, ...]:
        return tuple(self.debug_info)

    def __repr__(self):
        _mes = self.__class__.__name__
        _mes += '('
        for name, (path, _type) in self.debug_info.items():
            _mes += f'\n  {name}: {{"path": "{path}", "type": "{_type}"}}'
        _mes += f'\n  preprocess: {self.preprocess})'
        return _mes

    def __iter__(
            self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
        torch.set_printoptions(profile='default')
        count = len(self.path_name_type_list)
        for idx in range(count):
            # 2. Load the entry from each line and create a dict
            data = {}
            # 2.a. Load data in a streaming fashion

            # value: /home/fsc/code/MaaS/MaaS-lib-nls-asr/data/test/audios/asr_example.wav
            value = self.path_name_type_list[idx][0]['file']
            uid = self.path_name_type_list[idx][0]['key']
            # name: speech
            name = self.path_name_type_list[idx][1]
            _type = self.path_name_type_list[idx][2]
            func = DATA_TYPES[_type]
            array = func(value)

            # 2.b. audio resample
            if _type == 'sound':
                audio_sr: int = 16000
                model_sr: int = 16000
                if isinstance(self.sample_rate, int):
                    model_sr = self.sample_rate
                else:
                    if 'audio_sr' in self.sample_rate:
                        audio_sr = self.sample_rate['audio_sr']
                    if 'model_sr' in self.sample_rate:
                        model_sr = self.sample_rate['model_sr']
                array = wav_utils.torch_resample(array, audio_sr, model_sr)

            # array: [ 1.25122070e-03 ... ]
            data[name] = array

            # 3. [Option] Apply preprocessing
            # e.g. espnet2.train.preprocessor:CommonPreprocessor
            if self.preprocess is not None:
                data = self.preprocess(uid, data)
            # data: {'speech': array([ 1.25122070e-03 ... 6.10351562e-03])}

            # 4. Force data-precision
            for name in data:
                # value is np.ndarray data
                value = data[name]
                if not isinstance(value, np.ndarray):
                    raise RuntimeError(
                        f'All values must be converted to np.ndarray object '
                        f'by preprocessing, but "{name}" is still {type(value)}.'
                    )

                # Cast to desired type
                if value.dtype.kind == 'f':
                    value = value.astype(self.float_dtype)
                elif value.dtype.kind == 'i':
                    value = value.astype(self.int_dtype)
                else:
                    raise NotImplementedError(
                        f'Not supported dtype: {value.dtype}')
                data[name] = value

            yield uid, data

        if count == 0:
            raise RuntimeError('No iteration')

class IterableESPnetBytesModelScope(IterableDataset):
|
||||
"""Pytorch audio bytes class for ESPNet.
|
||||
|
||||
Examples:
|
||||
>>> dataset = IterableESPnetBytes([('audio bytes', 'input', 'sound'),
|
||||
... ('token_int', 'output', 'text_int')],
|
||||
... )
|
||||
>>> for uid, data in dataset:
|
||||
... data
|
||||
{'input': per_utt_array, 'output': per_utt_array}
|
||||
"""
|
||||
def __init__(self,
|
||||
path_name_type_list: Collection[Tuple[any, str, str]],
|
||||
preprocess: Callable[[str, Dict[str, np.ndarray]],
|
||||
Dict[str, np.ndarray]] = None,
|
||||
float_dtype: str = 'float32',
|
||||
int_dtype: str = 'long',
|
||||
key_file: str = None,
|
||||
sample_rate: Union[dict, int] = 16000):
|
||||
assert check_argument_types()
|
||||
if len(path_name_type_list) == 0:
|
||||
raise ValueError(
|
||||
'1 or more elements are required for "path_name_type_list"')
|
||||
|
||||
self.preprocess = preprocess
|
||||
|
||||
self.float_dtype = float_dtype
|
||||
self.int_dtype = int_dtype
|
||||
self.key_file = key_file
|
||||
self.sample_rate = sample_rate
|
||||
|
||||
self.debug_info = {}
|
||||
non_iterable_list = []
|
||||
self.path_name_type_list = []
|
||||
|
||||
audio_data = path_name_type_list[0]
|
||||
name = path_name_type_list[1]
|
||||
_type = path_name_type_list[2]
|
||||
if name in self.debug_info:
|
||||
raise RuntimeError(f'"{name}" is duplicated for data-key')
|
||||
self.debug_info[name] = audio_data, _type
|
||||
self.path_name_type_list.append((audio_data, name, _type))
|
||||
|
||||
if len(non_iterable_list) != 0:
|
||||
# Some types doesn't support iterable mode
|
||||
self.non_iterable_dataset = ESPnetDataset(
|
||||
path_name_type_list=non_iterable_list,
|
||||
preprocess=preprocess,
|
||||
float_dtype=float_dtype,
|
||||
int_dtype=int_dtype,
|
||||
)
|
||||
else:
|
||||
self.non_iterable_dataset = None
|
||||
|
||||
self.apply_utt2category = False
|
||||
|
||||
if float_dtype == 'float32':
|
||||
self.np_dtype = np.float32
|
||||
|
||||
    def has_name(self, name) -> bool:
        return name in self.debug_info

    def names(self) -> Tuple[str, ...]:
        return tuple(self.debug_info)

    def __repr__(self):
        _mes = self.__class__.__name__
        _mes += '('
        for name, (path, _type) in self.debug_info.items():
            _mes += f'\n  {name}: {{"path": "{path}", "type": "{_type}"}}'
        _mes += f'\n  preprocess: {self.preprocess})'
        return _mes

    def __iter__(
            self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:

        torch.set_printoptions(profile='default')
        # 2. Load the entry from each line and create a dict
        data = {}
        # 2.a. Load data streamingly
        value = self.path_name_type_list[0][0]
        uid = 'pcm_data'
        # name: speech
        name = self.path_name_type_list[0][1]
        _type = self.path_name_type_list[0][2]
        func = DATA_TYPES[_type]
        # array: [ 1.25122070e-03 ... ]
        # data[name] = np.frombuffer(value, dtype=self.np_dtype)

        # 2.b. byte(PCM16) to float32
        middle_data = np.frombuffer(value, dtype=np.int16)
        middle_data = np.asarray(middle_data)
        if middle_data.dtype.kind not in 'iu':
            raise TypeError("'middle_data' must be an array of integers")
        dtype = np.dtype('float32')
        if dtype.kind != 'f':
            raise TypeError("'dtype' must be a floating point type")

        # Map the int16 range [-32768, 32767] onto [-1.0, 1.0)
        i = np.iinfo(middle_data.dtype)
        abs_max = 2**(i.bits - 1)
        offset = i.min + abs_max
        array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max,
                              dtype=self.np_dtype)

        # 2.c. audio resample
        if _type == 'sound':
            audio_sr: int = 16000
            model_sr: int = 16000
            if isinstance(self.sample_rate, int):
                model_sr = self.sample_rate
            else:
                if 'audio_sr' in self.sample_rate:
                    audio_sr = self.sample_rate['audio_sr']
                if 'model_sr' in self.sample_rate:
                    model_sr = self.sample_rate['model_sr']
            array = wav_utils.torch_resample(array, audio_sr, model_sr)

        data[name] = array

        # 3. [Option] Apply preprocessing
        # e.g. espnet2.train.preprocessor:CommonPreprocessor
        if self.preprocess is not None:
            data = self.preprocess(uid, data)
            # data: {'speech': array([ 1.25122070e-03 ... 6.10351562e-03])}

        # 4. Force data-precision
        for name in data:
            # value is np.ndarray data
            value = data[name]
            if not isinstance(value, np.ndarray):
                raise RuntimeError(
                    f'All values must be converted to np.ndarray object '
                    f'by preprocessing, but "{name}" is still {type(value)}.')

            # Cast to desired type
            if value.dtype.kind == 'f':
                value = value.astype(self.float_dtype)
            elif value.dtype.kind == 'i':
                value = value.astype(self.int_dtype)
            else:
                raise NotImplementedError(
                    f'Unsupported dtype: {value.dtype}')
            data[name] = value

        yield uid, data
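For reference, the byte-to-float conversion in step 2.b can be exercised in isolation. A minimal sketch, assuming PCM16 mono input; the helper name pcm16_bytes_to_float32 is ours, for illustration only:

import numpy as np

def pcm16_bytes_to_float32(pcm: bytes) -> np.ndarray:
    # Hypothetical helper mirroring step 2.b above; not part of FunASR.
    samples = np.frombuffer(pcm, dtype=np.int16)
    info = np.iinfo(samples.dtype)
    abs_max = 2 ** (info.bits - 1)   # 32768 for int16
    offset = info.min + abs_max      # 0 for two's-complement int16
    return (samples.astype(np.float32) - offset) / abs_max

pcm = np.array([0, 16384, -32768], dtype=np.int16).tobytes()
print(pcm16_bytes_to_float32(pcm))   # -> [ 0.   0.5 -1. ]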
@ -330,9 +330,10 @@ class Paraformer(AbsESPnetModel):

    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):

        decoder_out, _ = self.decoder(
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
        )
        decoder_out = decoder_outs[0]
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        return decoder_out, ys_pad_lens
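The change above accommodates a decoder that now returns a tuple rather than a single tensor; the first element holds the logits, which are then normalized. A minimal sketch of that normalization step, with illustrative shapes:

import torch

batch, length, vocab = 2, 5, 10
decoder_outs = (torch.randn(batch, length, vocab), None)  # (logits, extra state)
decoder_out = decoder_outs[0]
log_probs = torch.log_softmax(decoder_out, dim=-1)
# Each position now sums to 1 in probability space:
print(torch.exp(log_probs).sum(dim=-1))   # ~ones(batch, length)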
@ -553,7 +554,6 @@ class ParaformerBert(Paraformer):
        postencoder: Optional[AbsPostEncoder],
        decoder: AbsDecoder,
        ctc: CTC,
        joint_network: Optional[torch.nn.Module],
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        ignore_id: int = -1,
@ -590,7 +590,6 @@ class ParaformerBert(Paraformer):
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            joint_network=joint_network,
            ctc_weight=ctc_weight,
            interctc_weight=interctc_weight,
            ignore_id=ignore_id,
155
funasr/models/frontend/wav_frontend.py
Normal file
@ -0,0 +1,155 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.

import copy
from typing import Optional, Tuple, Union

import humanfriendly
import numpy as np
import torch
import torchaudio.compliance.kaldi as kaldi
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.layers.log_mel import LogMel
from funasr.layers.stft import Stft
from funasr.utils.get_default_kwargs import get_default_kwargs
from funasr.modules.frontends.frontend import Frontend
from typeguard import check_argument_types


def apply_cmvn(inputs, mvn):  # noqa
    """
    Apply CMVN with mvn data
    """

    device = inputs.device
    dtype = inputs.dtype
    frame, dim = inputs.shape

    # mvn row 0 is the additive shift, row 1 the multiplicative scale
    means = np.tile(mvn[0:1, :dim], (frame, 1))
    vars = np.tile(mvn[1:2, :dim], (frame, 1))
    inputs += torch.from_numpy(means).type(dtype).to(device)
    inputs *= torch.from_numpy(vars).type(dtype).to(device)

    return inputs.type(torch.float32)


def apply_lfr(inputs, lfr_m, lfr_n):
    LFR_inputs = []
    T = inputs.shape[0]
    T_lfr = int(np.ceil(T / lfr_n))
    left_padding = inputs[0].repeat((lfr_m - 1) // 2, 1)
    inputs = torch.vstack((left_padding, inputs))
    T = T + (lfr_m - 1) // 2
    for i in range(T_lfr):
        if lfr_m <= T - i * lfr_n:
            LFR_inputs.append((inputs[i * lfr_n:i * lfr_n + lfr_m]).view(1, -1))
        else:  # process last LFR frame
            num_padding = lfr_m - (T - i * lfr_n)
            frame = (inputs[i * lfr_n:]).view(-1)
            for _ in range(num_padding):
                frame = torch.hstack((frame, inputs[-1]))
            LFR_inputs.append(frame)
    LFR_outputs = torch.vstack(LFR_inputs)
    return LFR_outputs.type(torch.float32)


class WavFrontend(AbsFrontend):
    """Conventional frontend structure for ASR."""

    def __init__(
            self,
            fs: Union[int, str] = 16000,
            n_fft: int = 512,
            win_length: int = 400,
            hop_length: int = 160,
            window: Optional[str] = 'hamming',
            center: bool = True,
            normalized: bool = False,
            onesided: bool = True,
            n_mels: int = 80,
            fmin: int = None,
            fmax: int = None,
            lfr_m: int = 1,
            lfr_n: int = 1,
            htk: bool = False,
            mvn_data=None,
            frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
            apply_stft: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        if isinstance(fs, str):
            fs = humanfriendly.parse_size(fs)

        # Deepcopy (In general, dict shouldn't be used as default arg)
        frontend_conf = copy.deepcopy(frontend_conf)
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.fs = fs
        self.mvn_data = mvn_data
        self.lfr_m = lfr_m
        self.lfr_n = lfr_n

        if apply_stft:
            self.stft = Stft(
                n_fft=n_fft,
                win_length=win_length,
                hop_length=hop_length,
                center=center,
                window=window,
                normalized=normalized,
                onesided=onesided,
            )
        else:
            self.stft = None
        self.apply_stft = apply_stft

        if frontend_conf is not None:
            self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
        else:
            self.frontend = None

        self.logmel = LogMel(
            fs=fs,
            n_fft=n_fft,
            n_mels=n_mels,
            fmin=fmin,
            fmax=fmax,
            htk=htk,
        )
        self.n_mels = n_mels
        self.frontend_type = 'default'

    def output_size(self) -> int:
        return self.n_mels

    def forward(
            self, input: torch.Tensor,
            input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

        sample_frequency = self.fs
        num_mel_bins = self.n_mels
        # Convert window sizes from samples to milliseconds for kaldi.fbank
        frame_length = self.win_length * 1000 / sample_frequency
        frame_shift = self.hop_length * 1000 / sample_frequency

        # Rescale [-1, 1) float audio back to the int16 range Kaldi expects
        waveform = input * (1 << 15)

        mat = kaldi.fbank(waveform,
                          num_mel_bins=num_mel_bins,
                          frame_length=frame_length,
                          frame_shift=frame_shift,
                          dither=1.0,
                          energy_floor=0.0,
                          window_type=self.window,
                          sample_frequency=sample_frequency)
        if self.lfr_m != 1 or self.lfr_n != 1:
            mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
        if self.mvn_data is not None:
            mat = apply_cmvn(mat, self.mvn_data)

        # Add a batch dimension and report the frame count as the length
        input_feats = mat[None, :]
        feats_lens = torch.randn(1)
        feats_lens.fill_(input_feats.shape[1])

        return input_feats, feats_lens
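To make the low-frame-rate (LFR) arithmetic concrete: with lfr_m = 7 and lfr_n = 6, a matrix of T frames of dimension D becomes ceil(T/6) frames of dimension 7·D. A minimal sketch with illustrative values:

import torch
from funasr.models.frontend.wav_frontend import apply_lfr

feats = torch.randn(100, 80)            # T=100 frames, 80 mel bins
stacked = apply_lfr(feats, lfr_m=7, lfr_n=6)
print(stacked.shape)                    # torch.Size([17, 560]) = (ceil(100/6), 7*80)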
@ -4,7 +4,7 @@ from torch import nn
from funasr.modules.nets_utils import make_pad_mask

class CifPredictor(nn.Module):
    def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0):
    def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, tail_threshold=0.45):
        super(CifPredictor, self).__init__()

        self.pad = nn.ConstantPad1d((l_order, r_order), 0)
@ -38,6 +38,7 @@ from funasr.datasets.dataset import AbsDataset
from funasr.datasets.dataset import DATA_TYPES
from funasr.datasets.dataset import ESPnetDataset
from funasr.datasets.iterable_dataset import IterableESPnetDataset
from funasr.datasets.iterable_dataset_modelscope import IterableESPnetDatasetModelScope, IterableESPnetBytesModelScope
from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.iterators.chunk_iter_factory import ChunkIterFactory
from funasr.iterators.multiple_iter_factory import MultipleIterFactory
@ -1026,7 +1027,7 @@ class AbsTask(ABC):
    @classmethod
    def check_task_requirements(
        cls,
        dataset: Union[AbsDataset, IterableESPnetDataset],
        dataset: Union[AbsDataset, IterableESPnetDataset, IterableESPnetDatasetModelScope, IterableESPnetBytesModelScope],
        allow_variable_data_keys: bool,
        train: bool,
        inference: bool = False,
@ -1748,6 +1749,64 @@ class AbsTask(ABC):
            **kwargs,
        )

    @classmethod
    def build_streaming_iterator_modelscope(
        cls,
        data_path_and_name_and_type,
        preprocess_fn,
        collate_fn,
        key_file: str = None,
        batch_size: int = 1,
        dtype: str = np.float32,
        num_workers: int = 1,
        allow_variable_data_keys: bool = False,
        ngpu: int = 0,
        inference: bool = False,
        sample_rate: Union[dict, int] = 16000
    ) -> DataLoader:
        """Build DataLoader using iterable dataset"""
        assert check_argument_types()
        # For backward compatibility for pytorch DataLoader
        if collate_fn is not None:
            kwargs = dict(collate_fn=collate_fn)
        else:
            kwargs = {}

        audio_data = data_path_and_name_and_type[0]
        if isinstance(audio_data, bytes):
            dataset = IterableESPnetBytesModelScope(
                data_path_and_name_and_type,
                float_dtype=dtype,
                preprocess=preprocess_fn,
                key_file=key_file,
                sample_rate=sample_rate
            )
        else:
            dataset = IterableESPnetDatasetModelScope(
                data_path_and_name_and_type,
                float_dtype=dtype,
                preprocess=preprocess_fn,
                key_file=key_file,
                sample_rate=sample_rate
            )

        if dataset.apply_utt2category:
            kwargs.update(batch_size=1)
        else:
            kwargs.update(batch_size=batch_size)

        cls.check_task_requirements(dataset,
                                    allow_variable_data_keys,
                                    train=False,
                                    inference=inference)

        return DataLoader(
            dataset=dataset,
            pin_memory=ngpu > 0,
            num_workers=num_workers,
            **kwargs,
        )

    # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
    @classmethod
    def build_model_from_file(
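For context, a minimal usage sketch of the new entry point, assuming raw 16 kHz PCM16 bytes, no preprocessing and default collation; the synthetic audio and the unpacking of the loader items are illustrative, not part of FunASR:

import numpy as np
from funasr.tasks.asr import ASRTask

pcm_bytes = (np.sin(np.linspace(0, 100, 16000)) * 3000).astype('int16').tobytes()

loader = ASRTask.build_streaming_iterator_modelscope(
    (pcm_bytes, 'speech', 'sound'),   # bytes input selects IterableESPnetBytesModelScope
    preprocess_fn=None,
    collate_fn=None,
    inference=True,
    sample_rate={'audio_sr': 16000, 'model_sr': 16000},
)
for uid, data in loader:              # items follow the dataset's (uid, dict) shape
    print(uid, data['speech'].shape)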
85
funasr/utils/asr_env_checking.py
Normal file
@ -0,0 +1,85 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import shutil
import ssl

import nltk

# mkdir nltk_data dir if it does not exist
try:
    nltk.data.find('.')
except LookupError:
    dir_list = nltk.data.path
    for dir_item in dir_list:
        if not os.path.exists(dir_item):
            os.mkdir(dir_item)
        if os.path.exists(dir_item):
            break

# download one package if nltk_data does not exist
try:
    nltk.data.find('.')
except:  # noqa: *
    try:
        _create_unverified_https_context = ssl._create_unverified_context
    except AttributeError:
        pass
    else:
        ssl._create_default_https_context = _create_unverified_https_context

    nltk.download('cmudict', halt_on_error=False, raise_on_error=True)

# deploy taggers/averaged_perceptron_tagger
try:
    nltk.data.find('taggers/averaged_perceptron_tagger')
except:  # noqa: *
    data_dir = nltk.data.find('.')
    target_dir = os.path.join(data_dir, 'taggers')
    if not os.path.exists(target_dir):
        os.mkdir(target_dir)
    src_file = os.path.join(os.path.dirname(__file__), '..', 'nltk_packages',
                            'averaged_perceptron_tagger.zip')
    shutil.copyfile(src_file,
                    os.path.join(target_dir, 'averaged_perceptron_tagger.zip'))
    shutil._unpack_zipfile(
        os.path.join(target_dir, 'averaged_perceptron_tagger.zip'), target_dir)

# deploy corpora/cmudict
try:
    nltk.data.find('corpora/cmudict')
except:  # noqa: *
    data_dir = nltk.data.find('.')
    target_dir = os.path.join(data_dir, 'corpora')
    if not os.path.exists(target_dir):
        os.mkdir(target_dir)
    src_file = os.path.join(os.path.dirname(__file__), '..', 'nltk_packages',
                            'cmudict.zip')
    shutil.copyfile(src_file, os.path.join(target_dir, 'cmudict.zip'))
    shutil._unpack_zipfile(os.path.join(target_dir, 'cmudict.zip'), target_dir)

try:
    nltk.data.find('taggers/averaged_perceptron_tagger')
except:  # noqa: *
    try:
        _create_unverified_https_context = ssl._create_unverified_context
    except AttributeError:
        pass
    else:
        ssl._create_default_https_context = _create_unverified_https_context

    nltk.download('averaged_perceptron_tagger',
                  halt_on_error=False,
                  raise_on_error=True)

try:
    nltk.data.find('corpora/cmudict')
except:  # noqa: *
    try:
        _create_unverified_https_context = ssl._create_unverified_context
    except AttributeError:
        pass
    else:
        ssl._create_default_https_context = _create_unverified_https_context

    nltk.download('cmudict', halt_on_error=False, raise_on_error=True)
327
funasr/utils/asr_utils.py
Normal file
@ -0,0 +1,327 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import struct
from typing import Any, Dict, List, Union

import librosa
import numpy as np
import pkg_resources
from modelscope.utils.logger import get_logger

logger = get_logger()

green_color = '\033[1;32m'
red_color = '\033[0;31;40m'
yellow_color = '\033[0;33;40m'
end_color = '\033[0m'

global_asr_language = 'zh-cn'


def get_version():
    return float(pkg_resources.get_distribution('easyasr').version)


def sample_rate_checking(audio_in: Union[str, bytes], audio_format: str):
    r_audio_fs = None

    if audio_format == 'wav':
        r_audio_fs = get_sr_from_wav(audio_in)
    elif audio_format == 'pcm' and isinstance(audio_in, bytes):
        r_audio_fs = get_sr_from_bytes(audio_in)

    return r_audio_fs


def type_checking(audio_in: Union[str, bytes],
                  audio_fs: int = None,
                  recog_type: str = None,
                  audio_format: str = None):
    r_recog_type = recog_type
    r_audio_format = audio_format
    r_wav_path = audio_in

    if isinstance(audio_in, str):
        assert os.path.exists(audio_in), f'wav_path:{audio_in} does not exist'
    elif isinstance(audio_in, bytes):
        assert len(audio_in) > 0, 'audio in is empty'
        r_audio_format = 'pcm'
        r_recog_type = 'wav'

    if r_recog_type is None:
        # audio_in is wav, recog_type is wav_file
        if os.path.isfile(audio_in):
            if audio_in.endswith('.wav') or audio_in.endswith('.WAV'):
                r_recog_type = 'wav'
                r_audio_format = 'wav'

        # recog_type is datasets_file
        elif os.path.isdir(audio_in):
            dir_name = os.path.basename(audio_in)
            if 'test' in dir_name:
                r_recog_type = 'test'
            elif 'dev' in dir_name:
                r_recog_type = 'dev'
            elif 'train' in dir_name:
                r_recog_type = 'train'

    if r_audio_format is None:
        if find_file_by_ends(audio_in, '.ark'):
            r_audio_format = 'kaldi_ark'
        elif find_file_by_ends(audio_in, '.wav') or find_file_by_ends(
                audio_in, '.WAV'):
            r_audio_format = 'wav'
        elif find_file_by_ends(audio_in, '.records'):
            r_audio_format = 'tfrecord'

    if r_audio_format == 'kaldi_ark' and r_recog_type != 'wav':
        # datasets with kaldi_ark file
        r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../'))
    elif r_audio_format == 'tfrecord' and r_recog_type != 'wav':
        # datasets with tensorflow records file
        r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../'))
    elif r_audio_format == 'wav' and r_recog_type != 'wav':
        # datasets with waveform files
        r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../../'))

    return r_recog_type, r_audio_format, r_wav_path
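As a concrete illustration of the dispatch above: a plain existing .wav path with no hints is classified as a single-utterance task. A minimal sketch, where the file name is illustrative and must exist on disk:

from funasr.utils import asr_utils

recog_type, audio_format, wav_path = asr_utils.type_checking('speech.wav')
print(recog_type, audio_format)   # -> wav wav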
def get_sr_from_bytes(wav: bytes):
    sr = None
    data = wav
    if len(data) > 44:
        try:
            # Parse the RIFF/WAVE header to read the sample rate
            header_fields = {}
            header_fields['ChunkID'] = str(data[0:4], 'UTF-8')
            header_fields['Format'] = str(data[8:12], 'UTF-8')
            header_fields['Subchunk1ID'] = str(data[12:16], 'UTF-8')
            if header_fields['ChunkID'] == 'RIFF' and header_fields[
                    'Format'] == 'WAVE' and header_fields[
                        'Subchunk1ID'] == 'fmt ':
                header_fields['SampleRate'] = struct.unpack('<I',
                                                            data[24:28])[0]
                sr = header_fields['SampleRate']
        except Exception:
            # no treatment
            pass
    else:
        logger.warning('audio bytes length ' + str(len(data))
                       + ' is too short to hold a WAV header.')

    return sr


def get_sr_from_wav(fname: str):
    fs = None
    if os.path.isfile(fname):
        audio, fs = librosa.load(fname, sr=None)
        return fs
    elif os.path.isdir(fname):
        dir_files = os.listdir(fname)
        for file in dir_files:
            file_path = os.path.join(fname, file)
            if os.path.isfile(file_path):
                if file_path.endswith('.wav') or file_path.endswith('.WAV'):
                    fs = get_sr_from_wav(file_path)
            elif os.path.isdir(file_path):
                fs = get_sr_from_wav(file_path)

            if fs is not None:
                break

    return fs


def find_file_by_ends(dir_path: str, ends: str):
    dir_files = os.listdir(dir_path)
    for file in dir_files:
        file_path = os.path.join(dir_path, file)
        if os.path.isfile(file_path):
            if file_path.endswith(ends):
                return True
        elif os.path.isdir(file_path):
            if find_file_by_ends(file_path, ends):
                return True

    return False


def recursion_dir_all_wav(wav_list, dir_path: str) -> List[str]:
    dir_files = os.listdir(dir_path)
    for file in dir_files:
        file_path = os.path.join(dir_path, file)
        if os.path.isfile(file_path):
            if file_path.endswith('.wav') or file_path.endswith('.WAV'):
                wav_list.append(file_path)
        elif os.path.isdir(file_path):
            recursion_dir_all_wav(wav_list, file_path)

    return wav_list


def set_parameters(language: str = None):
    if language is not None:
        global global_asr_language
        global_asr_language = language


def compute_wer(hyp_list: List[Any],
                ref_list: List[Any],
                lang: str = None) -> Dict[str, Any]:
    assert len(hyp_list) > 0, 'hyp list is empty'
    assert len(ref_list) > 0, 'ref list is empty'

    if lang is not None:
        global global_asr_language
        global_asr_language = lang

    rst = {
        'Wrd': 0,
        'Corr': 0,
        'Ins': 0,
        'Del': 0,
        'Sub': 0,
        'Snt': 0,
        'Err': 0.0,
        'S.Err': 0.0,
        'wrong_words': 0,
        'wrong_sentences': 0
    }

    for h_item in hyp_list:
        for r_item in ref_list:
            if h_item['key'] == r_item['key']:
                out_item = compute_wer_by_line(h_item['value'],
                                               r_item['value'],
                                               global_asr_language)
                rst['Wrd'] += out_item['nwords']
                rst['Corr'] += out_item['cor']
                rst['wrong_words'] += out_item['wrong']
                rst['Ins'] += out_item['ins']
                rst['Del'] += out_item['del']
                rst['Sub'] += out_item['sub']
                rst['Snt'] += 1
                if out_item['wrong'] > 0:
                    rst['wrong_sentences'] += 1
                    print_wrong_sentence(key=h_item['key'],
                                         hyp=h_item['value'],
                                         ref=r_item['value'])
                else:
                    print_correct_sentence(key=h_item['key'],
                                           hyp=h_item['value'],
                                           ref=r_item['value'])

                break

    if rst['Wrd'] > 0:
        rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
    if rst['Snt'] > 0:
        rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2)

    return rst


def compute_wer_by_line(hyp: List[str],
                        ref: List[str],
                        lang: str = 'zh-cn') -> Dict[str, Any]:
    if lang != 'zh-cn':
        hyp = hyp.split()
        ref = ref.split()

    hyp = list(map(lambda x: x.lower(), hyp))
    ref = list(map(lambda x: x.lower(), ref))

    len_hyp = len(hyp)
    len_ref = len(ref)

    # Standard Levenshtein dynamic program over (hyp, ref) prefixes
    cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)

    ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)

    for i in range(len_hyp + 1):
        cost_matrix[i][0] = i
    for j in range(len_ref + 1):
        cost_matrix[0][j] = j

    for i in range(1, len_hyp + 1):
        for j in range(1, len_ref + 1):
            if hyp[i - 1] == ref[j - 1]:
                cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
            else:
                substitution = cost_matrix[i - 1][j - 1] + 1
                insertion = cost_matrix[i - 1][j] + 1
                deletion = cost_matrix[i][j - 1] + 1

                compare_val = [substitution, insertion, deletion]

                min_val = min(compare_val)
                operation_idx = compare_val.index(min_val) + 1
                cost_matrix[i][j] = min_val
                ops_matrix[i][j] = operation_idx

    # Trace the operations back to count correct/ins/del/sub
    match_idx = []
    i = len_hyp
    j = len_ref
    rst = {
        'nwords': len_ref,
        'cor': 0,
        'wrong': 0,
        'ins': 0,
        'del': 0,
        'sub': 0
    }
    while i >= 0 or j >= 0:
        i_idx = max(0, i)
        j_idx = max(0, j)

        if ops_matrix[i_idx][j_idx] == 0:  # correct
            if i - 1 >= 0 and j - 1 >= 0:
                match_idx.append((j - 1, i - 1))
                rst['cor'] += 1

            i -= 1
            j -= 1

        elif ops_matrix[i_idx][j_idx] == 2:  # insert
            i -= 1
            rst['ins'] += 1

        elif ops_matrix[i_idx][j_idx] == 3:  # delete
            j -= 1
            rst['del'] += 1

        elif ops_matrix[i_idx][j_idx] == 1:  # substitute
            i -= 1
            j -= 1
            rst['sub'] += 1

        if i < 0 and j >= 0:
            rst['del'] += 1
        elif j < 0 and i >= 0:
            rst['ins'] += 1

    match_idx.reverse()
    wrong_cnt = cost_matrix[len_hyp][len_ref]
    rst['wrong'] = wrong_cnt

    return rst


def print_wrong_sentence(key: str, hyp: str, ref: str):
    space = len(key)
    print(key + yellow_color + ' ref: ' + ref)
    print(' ' * space + red_color + ' hyp: ' + hyp + end_color)


def print_correct_sentence(key: str, hyp: str, ref: str):
    space = len(key)
    print(key + yellow_color + ' ref: ' + ref)
    print(' ' * space + green_color + ' hyp: ' + hyp + end_color)


def print_progress(percent):
    if percent > 1:
        percent = 1
    res = int(50 * percent) * '#'
    print('\r[%-50s] %d%%' % (res, int(100 * percent)), end='')
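As a quick sanity check of the edit-distance bookkeeping above, a minimal sketch; a non-Chinese lang triggers the whitespace split, and the expected counts are shown in the comment:

from funasr.utils import asr_utils

rst = asr_utils.compute_wer_by_line('the cat sat', 'the cat sat down', lang='en')
# One reference word has no hypothesis match, so under this code's conventions:
# {'nwords': 4, 'cor': 3, 'wrong': 1, 'ins': 0, 'del': 1, 'sub': 0}
print(rst)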
174
funasr/utils/postprocess_utils.py
Normal file
@ -0,0 +1,174 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import string
from typing import Any, List, Union


def isChinese(ch: str):
    if '\u4e00' <= ch <= '\u9fff':
        return True
    return False


def isAllChinese(word: Union[List[Any], str]):
    word_lists = []
    table = str.maketrans('', '', string.punctuation)
    for i in word:
        cur = i.translate(table)
        cur = cur.replace(' ', '')
        cur = cur.replace('</s>', '')
        cur = cur.replace('<s>', '')
        word_lists.append(cur)

    if len(word_lists) == 0:
        return False

    for ch in word_lists:
        if isChinese(ch) is False:
            return False
    return True


def isAllAlpha(word: Union[List[Any], str]):
    word_lists = []
    table = str.maketrans('', '', string.punctuation)
    for i in word:
        cur = i.translate(table)
        cur = cur.replace(' ', '')
        cur = cur.replace('</s>', '')
        cur = cur.replace('<s>', '')
        word_lists.append(cur)

    if len(word_lists) == 0:
        return False

    for ch in word_lists:
        if ch.isalpha() is False:
            return False
        elif ch.isalpha() is True and isChinese(ch) is True:
            return False

    return True


def abbr_dispose(words: List[Any]) -> List[Any]:
    words_size = len(words)
    word_lists = []
    abbr_begin = []
    abbr_end = []
    last_num = -1
    for num in range(words_size):
        if num <= last_num:
            continue

        if len(words[num]) == 1 and words[num].encode('utf-8').isalpha():
            if num + 1 < words_size and words[num + 1] == ' ' \
                    and num + 2 < words_size and len(words[num + 2]) == 1 \
                    and words[num + 2].encode('utf-8').isalpha():
                # found the begin of abbr
                abbr_begin.append(num)
                num += 2
                abbr_end.append(num)
                # to find the end of abbr
                while True:
                    num += 1
                    if num < words_size and words[num] == ' ':
                        num += 1
                        if num < words_size and len(words[num]) == 1 \
                                and words[num].encode('utf-8').isalpha():
                            abbr_end.pop()
                            abbr_end.append(num)
                            last_num = num
                        else:
                            break
                    else:
                        break

    last_num = -1
    for num in range(words_size):
        if num <= last_num:
            continue

        if num in abbr_begin:
            word_lists.append(words[num].upper())
            num += 1
            while num < words_size:
                if num in abbr_end:
                    word_lists.append(words[num].upper())
                    last_num = num
                    break
                else:
                    if words[num].encode('utf-8').isalpha():
                        word_lists.append(words[num].upper())
                    num += 1
        else:
            word_lists.append(words[num])

    return word_lists


def sentence_postprocess(words: List[Any]):
    middle_lists = []
    word_lists = []
    word_item = ''

    # clean the word list: drop sentence markers
    for i in words:
        word = ''
        if isinstance(i, str):
            word = i
        else:
            word = i.decode('utf-8')

        if word in ['<s>', '</s>', '<unk>']:
            continue
        else:
            middle_lists.append(word)

    # all chinese characters
    if isAllChinese(middle_lists):
        for ch in middle_lists:
            word_lists.append(ch.replace(' ', ''))

    # all alpha characters
    elif isAllAlpha(middle_lists):
        for ch in middle_lists:
            word = ''
            if '@@' in ch:
                word = ch.replace('@@', '')
                word_item += word
            else:
                word_item += ch
                word_lists.append(word_item)
                word_lists.append(' ')
                word_item = ''

    # mix characters
    else:
        alpha_blank = False
        for ch in middle_lists:
            word = ''
            if isAllChinese(ch):
                if alpha_blank is True:
                    word_lists.pop()
                word_lists.append(ch)
                alpha_blank = False
            elif '@@' in ch:
                word = ch.replace('@@', '')
                word_item += word
                alpha_blank = False
            elif isAllAlpha(ch):
                word_item += ch
                word_lists.append(word_item)
                word_lists.append(' ')
                word_item = ''
                alpha_blank = True
            else:
                raise ValueError('invalid character: {}'.format(ch))

    word_lists = abbr_dispose(word_lists)
    sentence = ''.join(word_lists).strip()
    return sentence
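A minimal sketch of the post-processing behaviour above: BPE pieces marked with '@@' are merged for alphabetic output, while pure-Chinese token lists are simply concatenated:

from funasr.utils import postprocess_utils

print(postprocess_utils.sentence_postprocess(['he@@', 'llo']))          # -> 'hello'
print(postprocess_utils.sentence_postprocess(['今', '天', '天', '气']))  # -> '今天天气'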
178
funasr/utils/wav_utils.py
Normal file
@ -0,0 +1,178 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import math
import os
from typing import Any, Dict, Union

import kaldiio
import librosa
import numpy as np
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi


def ndarray_resample(audio_in: np.ndarray,
                     fs_in: int = 16000,
                     fs_out: int = 16000) -> np.ndarray:
    audio_out = audio_in
    if fs_in != fs_out:
        audio_out = librosa.resample(audio_in, orig_sr=fs_in, target_sr=fs_out)
    return audio_out


def torch_resample(audio_in: torch.Tensor,
                   fs_in: int = 16000,
                   fs_out: int = 16000) -> torch.Tensor:
    audio_out = audio_in
    if fs_in != fs_out:
        audio_out = torchaudio.transforms.Resample(orig_freq=fs_in,
                                                   new_freq=fs_out)(audio_in)
    return audio_out


def extract_CMVN_featrures(mvn_file):
    """
    extract CMVN from cmvn.ark
    """

    if not os.path.exists(mvn_file):
        return None
    try:
        cmvn = kaldiio.load_mat(mvn_file)
        means = []
        variance = []

        for i in range(cmvn.shape[1] - 1):
            means.append(float(cmvn[0][i]))

        count = float(cmvn[0][-1])

        for i in range(cmvn.shape[1] - 1):
            variance.append(float(cmvn[1][i]))

        # Normalize the accumulated statistics: row 0 holds per-dim sums,
        # row 1 squared sums; convert them to mean and inverse stddev.
        for i in range(len(means)):
            means[i] /= count
            variance[i] = variance[i] / count - means[i] * means[i]
            if variance[i] < 1.0e-20:
                variance[i] = 1.0e-20
            variance[i] = 1.0 / math.sqrt(variance[i])

        cmvn = np.array([means, variance])
        return cmvn
    except Exception:
        # Fall back to the Kaldi-nnet text format (<AddShift>/<Rescale>)
        cmvn = extract_CMVN_features_txt(mvn_file)
        return cmvn


def extract_CMVN_features_txt(mvn_file):  # noqa
    with open(mvn_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    add_shift_list = []
    rescale_list = []
    for i in range(len(lines)):
        line_item = lines[i].split()
        if line_item[0] == '<AddShift>':
            line_item = lines[i + 1].split()
            if line_item[0] == '<LearnRateCoef>':
                add_shift_line = line_item[3:(len(line_item) - 1)]
                add_shift_list = list(add_shift_line)
                continue
        elif line_item[0] == '<Rescale>':
            line_item = lines[i + 1].split()
            if line_item[0] == '<LearnRateCoef>':
                rescale_line = line_item[3:(len(line_item) - 1)]
                rescale_list = list(rescale_line)
                continue
    add_shift_list_f = [float(s) for s in add_shift_list]
    rescale_list_f = [float(s) for s in rescale_list]
    cmvn = np.array([add_shift_list_f, rescale_list_f])
    return cmvn


def build_LFR_features(inputs, m=7, n=6):  # noqa
    """
    Actually, this implements stacking frames and skipping frames.
    if m = 1 and n = 1, just return the origin features.
    if m = 1 and n > 1, it works like skipping.
    if m > 1 and n = 1, it works like stacking but only supports right frames.
    if m > 1 and n > 1, it works like LFR.

    Args:
        inputs: T x D np.ndarray
        m: number of frames to stack
        n: number of frames to skip
    """
    # LFR_inputs_batch = []
    # for inputs in inputs_batch:
    LFR_inputs = []
    T = inputs.shape[0]
    T_lfr = int(np.ceil(T / n))
    left_padding = np.tile(inputs[0], ((m - 1) // 2, 1))
    inputs = np.vstack((left_padding, inputs))
    T = T + (m - 1) // 2
    for i in range(T_lfr):
        if m <= T - i * n:
            LFR_inputs.append(np.hstack(inputs[i * n:i * n + m]))
        else:  # process last LFR frame
            num_padding = m - (T - i * n)
            frame = np.hstack(inputs[i * n:])
            for _ in range(num_padding):
                frame = np.hstack((frame, inputs[-1]))
            LFR_inputs.append(frame)
    return np.vstack(LFR_inputs)


def compute_fbank(wav_file,
                  num_mel_bins=80,
                  frame_length=25,
                  frame_shift=10,
                  dither=0.0,
                  is_pcm=False,
                  fs: Union[int, Dict[Any, int]] = 16000):
    audio_sr: int = 16000
    model_sr: int = 16000
    if isinstance(fs, int):
        model_sr = fs
        audio_sr = fs
    else:
        model_sr = fs['model_fs']
        audio_sr = fs['audio_fs']

    if is_pcm is True:
        # byte(PCM16) to float32, and resample
        value = wav_file
        middle_data = np.frombuffer(value, dtype=np.int16)
        middle_data = np.asarray(middle_data)
        if middle_data.dtype.kind not in 'iu':
            raise TypeError("'middle_data' must be an array of integers")
        dtype = np.dtype('float32')
        if dtype.kind != 'f':
            raise TypeError("'dtype' must be a floating point type")

        i = np.iinfo(middle_data.dtype)
        abs_max = 2**(i.bits - 1)
        offset = i.min + abs_max
        waveform = np.frombuffer(
            (middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
        waveform = ndarray_resample(waveform, audio_sr, model_sr)
        waveform = torch.from_numpy(waveform.reshape(1, -1))
    else:
        # load pcm from wav, and resample
        waveform, audio_sr = torchaudio.load(wav_file)
        waveform = waveform * (1 << 15)
        waveform = torch_resample(waveform, audio_sr, model_sr)

    mat = kaldi.fbank(waveform,
                      num_mel_bins=num_mel_bins,
                      frame_length=frame_length,
                      frame_shift=frame_shift,
                      dither=dither,
                      energy_floor=0.0,
                      window_type='hamming',
                      sample_frequency=model_sr)

    input_feats = mat

    return input_feats
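To make the frame-stacking arithmetic of build_LFR_features concrete: with m = 7 and n = 6, a T x D matrix becomes ceil(T/6) rows of width 7·D. A minimal numpy sketch with illustrative values:

import numpy as np
from funasr.utils.wav_utils import build_LFR_features

feats = np.random.randn(100, 80).astype(np.float32)   # T=100 frames, D=80
lfr = build_LFR_features(feats, m=7, n=6)
print(lfr.shape)   # (17, 560) = (ceil(100/6), 7*80)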
@ -1 +1 @@
0.1.0
0.1.3