mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
inference
This commit is contained in:
parent
94a4dbba3d
commit
0271fbe4fd
@ -32,6 +32,7 @@ from funasr.modules.beam_search.beam_search import BeamSearch
|
||||
from funasr.modules.beam_search.beam_search import Hypothesis
|
||||
from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
|
||||
from funasr.modules.beam_search.beam_search_transducer import Hypothesis as HypothesisTransducer
|
||||
from funasr.modules.beam_search.beam_search_sa_asr import Hypothesis as HypothesisSAASR
|
||||
from funasr.modules.scorers.ctc import CTCPrefixScorer
|
||||
from funasr.modules.scorers.length_bonus import LengthBonus
|
||||
from funasr.modules.subsampling import TooShortUttError
|
||||
@ -58,7 +59,7 @@ from funasr.bin.punc_infer import Text2Punc
|
||||
from funasr.utils.vad_utils import slice_padding_fbank
|
||||
from funasr.tasks.vad import VADTask
|
||||
from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
|
||||
|
||||
from funasr.tasks.asr import frontend_choices
|
||||
|
||||
class Speech2Text:
|
||||
"""Speech2Text class
|
||||
@ -1599,3 +1600,251 @@ class Speech2TextTransducer:
|
||||
|
||||
return Speech2Text(**kwargs)
|
||||
|
||||
|
||||
class Speech2TextSAASR:
|
||||
"""Speech2Text class
|
||||
|
||||
Examples:
|
||||
>>> import soundfile
|
||||
>>> speech2text = Speech2TextSAASR("asr_config.yml", "asr.pb")
|
||||
>>> audio, rate = soundfile.read("speech.wav")
|
||||
>>> speech2text(audio)
|
||||
[(text, token, token_int, hypothesis object), ...]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
asr_train_config: Union[Path, str] = None,
|
||||
asr_model_file: Union[Path, str] = None,
|
||||
cmvn_file: Union[Path, str] = None,
|
||||
lm_train_config: Union[Path, str] = None,
|
||||
lm_file: Union[Path, str] = None,
|
||||
token_type: str = None,
|
||||
bpemodel: str = None,
|
||||
device: str = "cpu",
|
||||
maxlenratio: float = 0.0,
|
||||
minlenratio: float = 0.0,
|
||||
batch_size: int = 1,
|
||||
dtype: str = "float32",
|
||||
beam_size: int = 20,
|
||||
ctc_weight: float = 0.5,
|
||||
lm_weight: float = 1.0,
|
||||
ngram_weight: float = 0.9,
|
||||
penalty: float = 0.0,
|
||||
nbest: int = 1,
|
||||
streaming: bool = False,
|
||||
frontend_conf: dict = None,
|
||||
**kwargs,
|
||||
):
|
||||
assert check_argument_types()
|
||||
|
||||
# 1. Build ASR model
|
||||
from funasr.modules.beam_search.beam_search_sa_asr import BeamSearch
|
||||
scorers = {}
|
||||
asr_model, asr_train_args = ASRTask.build_model_from_file(
|
||||
asr_train_config, asr_model_file, cmvn_file, device
|
||||
)
|
||||
frontend = None
|
||||
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
|
||||
if asr_train_args.frontend == 'wav_frontend':
|
||||
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
|
||||
else:
|
||||
frontend_class = frontend_choices.get_class(asr_train_args.frontend)
|
||||
frontend = frontend_class(**asr_train_args.frontend_conf).eval()
|
||||
|
||||
logging.info("asr_model: {}".format(asr_model))
|
||||
logging.info("asr_train_args: {}".format(asr_train_args))
|
||||
asr_model.to(dtype=getattr(torch, dtype)).eval()
|
||||
|
||||
decoder = asr_model.decoder
|
||||
|
||||
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
|
||||
token_list = asr_model.token_list
|
||||
scorers.update(
|
||||
decoder=decoder,
|
||||
ctc=ctc,
|
||||
length_bonus=LengthBonus(len(token_list)),
|
||||
)
|
||||
|
||||
# 2. Build Language model
|
||||
if lm_train_config is not None:
|
||||
lm, lm_train_args = LMTask.build_model_from_file(
|
||||
lm_train_config, lm_file, None, device
|
||||
)
|
||||
scorers["lm"] = lm.lm
|
||||
|
||||
# 3. Build ngram model
|
||||
# ngram is not supported now
|
||||
ngram = None
|
||||
scorers["ngram"] = ngram
|
||||
|
||||
# 4. Build BeamSearch object
|
||||
# transducer is not supported now
|
||||
beam_search_transducer = None
|
||||
|
||||
weights = dict(
|
||||
decoder=1.0 - ctc_weight,
|
||||
ctc=ctc_weight,
|
||||
lm=lm_weight,
|
||||
ngram=ngram_weight,
|
||||
length_bonus=penalty,
|
||||
)
|
||||
beam_search = BeamSearch(
|
||||
beam_size=beam_size,
|
||||
weights=weights,
|
||||
scorers=scorers,
|
||||
sos=asr_model.sos,
|
||||
eos=asr_model.eos,
|
||||
vocab_size=len(token_list),
|
||||
token_list=token_list,
|
||||
pre_beam_score_key=None if ctc_weight == 1.0 else "full",
|
||||
)
|
||||
|
||||
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
|
||||
if token_type is None:
|
||||
token_type = asr_train_args.token_type
|
||||
if bpemodel is None:
|
||||
bpemodel = asr_train_args.bpemodel
|
||||
|
||||
if token_type is None:
|
||||
tokenizer = None
|
||||
elif token_type == "bpe":
|
||||
if bpemodel is not None:
|
||||
tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
|
||||
else:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer = build_tokenizer(token_type=token_type)
|
||||
converter = TokenIDConverter(token_list=token_list)
|
||||
logging.info(f"Text tokenizer: {tokenizer}")
|
||||
|
||||
self.asr_model = asr_model
|
||||
self.asr_train_args = asr_train_args
|
||||
self.converter = converter
|
||||
self.tokenizer = tokenizer
|
||||
self.beam_search = beam_search
|
||||
self.beam_search_transducer = beam_search_transducer
|
||||
self.maxlenratio = maxlenratio
|
||||
self.minlenratio = minlenratio
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
self.nbest = nbest
|
||||
self.frontend = frontend
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray],
|
||||
profile: Union[torch.Tensor, np.ndarray], profile_lengths: Union[torch.Tensor, np.ndarray]
|
||||
) -> List[
|
||||
Tuple[
|
||||
Optional[str],
|
||||
Optional[str],
|
||||
List[str],
|
||||
List[int],
|
||||
Union[HypothesisSAASR],
|
||||
]
|
||||
]:
|
||||
"""Inference
|
||||
|
||||
Args:
|
||||
speech: Input speech data
|
||||
Returns:
|
||||
text, text_id, token, token_int, hyp
|
||||
|
||||
"""
|
||||
assert check_argument_types()
|
||||
|
||||
# Input as audio signal
|
||||
if isinstance(speech, np.ndarray):
|
||||
speech = torch.tensor(speech)
|
||||
|
||||
if isinstance(profile, np.ndarray):
|
||||
profile = torch.tensor(profile)
|
||||
|
||||
if self.frontend is not None:
|
||||
feats, feats_len = self.frontend.forward(speech, speech_lengths)
|
||||
feats = to_device(feats, device=self.device)
|
||||
feats_len = feats_len.int()
|
||||
self.asr_model.frontend = None
|
||||
else:
|
||||
feats = speech
|
||||
feats_len = speech_lengths
|
||||
lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
|
||||
batch = {"speech": feats, "speech_lengths": feats_len}
|
||||
|
||||
# a. To device
|
||||
batch = to_device(batch, device=self.device)
|
||||
|
||||
# b. Forward Encoder
|
||||
asr_enc, _, spk_enc = self.asr_model.encode(**batch)
|
||||
if isinstance(asr_enc, tuple):
|
||||
asr_enc = asr_enc[0]
|
||||
if isinstance(spk_enc, tuple):
|
||||
spk_enc = spk_enc[0]
|
||||
assert len(asr_enc) == 1, len(asr_enc)
|
||||
assert len(spk_enc) == 1, len(spk_enc)
|
||||
|
||||
# c. Passed the encoder result and the beam search
|
||||
nbest_hyps = self.beam_search(
|
||||
asr_enc[0], spk_enc[0], profile[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
|
||||
)
|
||||
|
||||
nbest_hyps = nbest_hyps[: self.nbest]
|
||||
|
||||
results = []
|
||||
for hyp in nbest_hyps:
|
||||
assert isinstance(hyp, (HypothesisSAASR)), type(hyp)
|
||||
|
||||
# remove sos/eos and get results
|
||||
last_pos = -1
|
||||
if isinstance(hyp.yseq, list):
|
||||
token_int = hyp.yseq[1: last_pos]
|
||||
else:
|
||||
token_int = hyp.yseq[1: last_pos].tolist()
|
||||
|
||||
spk_weigths = torch.stack(hyp.spk_weigths, dim=0)
|
||||
|
||||
token_ori = self.converter.ids2tokens(token_int)
|
||||
text_ori = self.tokenizer.tokens2text(token_ori)
|
||||
|
||||
text_ori_spklist = text_ori.split('$')
|
||||
cur_index = 0
|
||||
spk_choose = []
|
||||
for i in range(len(text_ori_spklist)):
|
||||
text_ori_split = text_ori_spklist[i]
|
||||
n = len(text_ori_split)
|
||||
spk_weights_local = spk_weigths[cur_index: cur_index + n]
|
||||
cur_index = cur_index + n + 1
|
||||
spk_weights_local = spk_weights_local.mean(dim=0)
|
||||
spk_choose_local = spk_weights_local.argmax(-1)
|
||||
spk_choose.append(spk_choose_local.item() + 1)
|
||||
|
||||
# remove blank symbol id, which is assumed to be 0
|
||||
token_int = list(filter(lambda x: x != 0, token_int))
|
||||
|
||||
# Change integer-ids to tokens
|
||||
token = self.converter.ids2tokens(token_int)
|
||||
|
||||
if self.tokenizer is not None:
|
||||
text = self.tokenizer.tokens2text(token)
|
||||
else:
|
||||
text = None
|
||||
|
||||
text_spklist = text.split('$')
|
||||
assert len(spk_choose) == len(text_spklist)
|
||||
|
||||
spk_list = []
|
||||
for i in range(len(text_spklist)):
|
||||
text_split = text_spklist[i]
|
||||
n = len(text_split)
|
||||
spk_list.append(str(spk_choose[i]) * n)
|
||||
|
||||
text_id = '$'.join(spk_list)
|
||||
|
||||
assert len(text) == len(text_id)
|
||||
|
||||
results.append((text, text_id, token, token_int, hyp))
|
||||
|
||||
assert check_return_type(results)
|
||||
return results
|
||||
|
||||
@ -77,6 +77,7 @@ from funasr.bin.vad_infer import Speech2VadSegment
|
||||
from funasr.bin.punc_infer import Text2Punc
|
||||
from funasr.bin.tp_infer import Speech2Timestamp
|
||||
from funasr.bin.asr_infer import Speech2TextTransducer
|
||||
from funasr.bin.asr_infer import Speech2TextSAASR
|
||||
|
||||
def inference_asr(
|
||||
maxlenratio: float,
|
||||
@ -1444,6 +1445,167 @@ def inference_transducer(
|
||||
return _forward
|
||||
|
||||
|
||||
def inference_sa_asr(
|
||||
maxlenratio: float,
|
||||
minlenratio: float,
|
||||
batch_size: int,
|
||||
beam_size: int,
|
||||
ngpu: int,
|
||||
ctc_weight: float,
|
||||
lm_weight: float,
|
||||
penalty: float,
|
||||
log_level: Union[int, str],
|
||||
# data_path_and_name_and_type,
|
||||
asr_train_config: Optional[str],
|
||||
asr_model_file: Optional[str],
|
||||
cmvn_file: Optional[str] = None,
|
||||
lm_train_config: Optional[str] = None,
|
||||
lm_file: Optional[str] = None,
|
||||
token_type: Optional[str] = None,
|
||||
key_file: Optional[str] = None,
|
||||
word_lm_train_config: Optional[str] = None,
|
||||
bpemodel: Optional[str] = None,
|
||||
allow_variable_data_keys: bool = False,
|
||||
streaming: bool = False,
|
||||
output_dir: Optional[str] = None,
|
||||
dtype: str = "float32",
|
||||
seed: int = 0,
|
||||
ngram_weight: float = 0.9,
|
||||
nbest: int = 1,
|
||||
num_workers: int = 1,
|
||||
mc: bool = False,
|
||||
param_dict: dict = None,
|
||||
**kwargs,
|
||||
):
|
||||
assert check_argument_types()
|
||||
if batch_size > 1:
|
||||
raise NotImplementedError("batch decoding is not implemented")
|
||||
if word_lm_train_config is not None:
|
||||
raise NotImplementedError("Word LM is not implemented")
|
||||
if ngpu > 1:
|
||||
raise NotImplementedError("only single GPU decoding is supported")
|
||||
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
|
||||
logging.basicConfig(
|
||||
level=log_level,
|
||||
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
||||
)
|
||||
|
||||
if ngpu >= 1 and torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
# 1. Set random-seed
|
||||
set_all_random_seed(seed)
|
||||
|
||||
# 2. Build speech2text
|
||||
speech2text_kwargs = dict(
|
||||
asr_train_config=asr_train_config,
|
||||
asr_model_file=asr_model_file,
|
||||
cmvn_file=cmvn_file,
|
||||
lm_train_config=lm_train_config,
|
||||
lm_file=lm_file,
|
||||
token_type=token_type,
|
||||
bpemodel=bpemodel,
|
||||
device=device,
|
||||
maxlenratio=maxlenratio,
|
||||
minlenratio=minlenratio,
|
||||
dtype=dtype,
|
||||
beam_size=beam_size,
|
||||
ctc_weight=ctc_weight,
|
||||
lm_weight=lm_weight,
|
||||
ngram_weight=ngram_weight,
|
||||
penalty=penalty,
|
||||
nbest=nbest,
|
||||
streaming=streaming,
|
||||
)
|
||||
logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
|
||||
speech2text = Speech2TextSAASR(**speech2text_kwargs)
|
||||
|
||||
def _forward(data_path_and_name_and_type,
|
||||
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
|
||||
output_dir_v2: Optional[str] = None,
|
||||
fs: dict = None,
|
||||
param_dict: dict = None,
|
||||
**kwargs,
|
||||
):
|
||||
# 3. Build data-iterator
|
||||
if data_path_and_name_and_type is None and raw_inputs is not None:
|
||||
if isinstance(raw_inputs, torch.Tensor):
|
||||
raw_inputs = raw_inputs.numpy()
|
||||
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
|
||||
loader = ASRTask.build_streaming_iterator(
|
||||
data_path_and_name_and_type,
|
||||
dtype=dtype,
|
||||
fs=fs,
|
||||
mc=mc,
|
||||
batch_size=batch_size,
|
||||
key_file=key_file,
|
||||
num_workers=num_workers,
|
||||
preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
|
||||
collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
|
||||
allow_variable_data_keys=allow_variable_data_keys,
|
||||
inference=True,
|
||||
)
|
||||
|
||||
finish_count = 0
|
||||
file_count = 1
|
||||
# 7 .Start for-loop
|
||||
# FIXME(kamo): The output format should be discussed about
|
||||
asr_result_list = []
|
||||
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
|
||||
if output_path is not None:
|
||||
writer = DatadirWriter(output_path)
|
||||
else:
|
||||
writer = None
|
||||
|
||||
for keys, batch in loader:
|
||||
assert isinstance(batch, dict), type(batch)
|
||||
assert all(isinstance(s, str) for s in keys), keys
|
||||
_bs = len(next(iter(batch.values())))
|
||||
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
|
||||
# batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
|
||||
# N-best list of (text, token, token_int, hyp_object)
|
||||
try:
|
||||
results = speech2text(**batch)
|
||||
except TooShortUttError as e:
|
||||
logging.warning(f"Utterance {keys} {e}")
|
||||
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
|
||||
results = [[" ", ["sil"], [2], hyp]] * nbest
|
||||
|
||||
# Only supporting batch_size==1
|
||||
key = keys[0]
|
||||
for n, (text, text_id, token, token_int, hyp) in zip(range(1, nbest + 1), results):
|
||||
# Create a directory: outdir/{n}best_recog
|
||||
if writer is not None:
|
||||
ibest_writer = writer[f"{n}best_recog"]
|
||||
|
||||
# Write the result to each file
|
||||
ibest_writer["token"][key] = " ".join(token)
|
||||
ibest_writer["token_int"][key] = " ".join(map(str, token_int))
|
||||
ibest_writer["score"][key] = str(hyp.score)
|
||||
ibest_writer["text_id"][key] = text_id
|
||||
|
||||
if text is not None:
|
||||
text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
|
||||
item = {'key': key, 'value': text_postprocessed}
|
||||
asr_result_list.append(item)
|
||||
finish_count += 1
|
||||
asr_utils.print_progress(finish_count / file_count)
|
||||
if writer is not None:
|
||||
ibest_writer["text"][key] = text
|
||||
|
||||
logging.info("uttid: {}".format(key))
|
||||
logging.info("text predictions: {}".format(text))
|
||||
logging.info("text_id predictions: {}\n".format(text_id))
|
||||
return asr_result_list
|
||||
|
||||
return _forward
|
||||
|
||||
|
||||
def inference_launch(**kwargs):
|
||||
if 'mode' in kwargs:
|
||||
mode = kwargs['mode']
|
||||
@ -1464,6 +1626,8 @@ def inference_launch(**kwargs):
|
||||
return inference_mfcca(**kwargs)
|
||||
elif mode == "rnnt":
|
||||
return inference_transducer(**kwargs)
|
||||
elif mode == "sa_asr":
|
||||
return inference_sa_asr(**kwargs)
|
||||
else:
|
||||
logging.info("Unknown decoding mode: {}".format(mode))
|
||||
return None
|
||||
|
||||
Loading…
Reference in New Issue
Block a user