mirror of https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Merge branch 'main' of github.com:alibaba-damo-academy/FunASR
add
This commit is contained in commit 74464315c1
@@ -216,6 +216,9 @@ def inference_launch(**kwargs):
     elif mode == "paraformer":
         from funasr.bin.asr_inference_paraformer import inference_modelscope
         return inference_modelscope(**kwargs)
+    elif mode == "paraformer_streaming":
+        from funasr.bin.asr_inference_paraformer_streaming import inference_modelscope
+        return inference_modelscope(**kwargs)
     elif mode == "paraformer_vad":
         from funasr.bin.asr_inference_paraformer_vad import inference_modelscope
         return inference_modelscope(**kwargs)
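For context, this branch means the generic launcher can now hand back a streaming Paraformer pipeline. A minimal sketch of how that entry point might be used; the config/model/cmvn paths are placeholders (not from this commit) and the decoding values are illustrative only:

from funasr.bin.asr_inference_paraformer_streaming import inference_modelscope

# Equivalent to what inference_launch(mode="paraformer_streaming", ...) dispatches to;
# every path below is a placeholder.
pipeline = inference_modelscope(
    maxlenratio=0.0,
    minlenratio=0.0,
    batch_size=1,
    beam_size=20,
    ngpu=0,
    ctc_weight=0.0,
    lm_weight=0.0,
    penalty=0.0,
    log_level="INFO",
    asr_train_config="exp/paraformer_streaming/config.yaml",  # placeholder
    asr_model_file="exp/paraformer_streaming/model.pb",       # placeholder
    cmvn_file="exp/paraformer_streaming/am.mvn",              # placeholder
)
# `pipeline` is the `_forward` closure defined in the new file below; it accepts
# data_path_and_name_and_type or raw_inputs, plus an optional param_dict (e.g. a decoding cache).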
907  funasr/bin/asr_inference_paraformer_streaming.py  Normal file
@@ -0,0 +1,907 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
import copy
|
||||
import os
|
||||
import codecs
|
||||
import tempfile
|
||||
import requests
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Sequence
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
from typing import Dict
|
||||
from typing import Any
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from funasr.fileio.datadir_writer import DatadirWriter
|
||||
from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
|
||||
from funasr.modules.beam_search.beam_search import Hypothesis
|
||||
from funasr.modules.scorers.ctc import CTCPrefixScorer
|
||||
from funasr.modules.scorers.length_bonus import LengthBonus
|
||||
from funasr.modules.subsampling import TooShortUttError
|
||||
from funasr.tasks.asr import ASRTaskParaformer as ASRTask
|
||||
from funasr.tasks.lm import LMTask
|
||||
from funasr.text.build_tokenizer import build_tokenizer
|
||||
from funasr.text.token_id_converter import TokenIDConverter
|
||||
from funasr.torch_utils.device_funcs import to_device
|
||||
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
|
||||
from funasr.utils import config_argparse
|
||||
from funasr.utils.cli_utils import get_commandline_args
|
||||
from funasr.utils.types import str2bool
|
||||
from funasr.utils.types import str2triple_str
|
||||
from funasr.utils.types import str_or_none
|
||||
from funasr.utils import asr_utils, wav_utils, postprocess_utils
|
||||
from funasr.models.frontend.wav_frontend import WavFrontend
|
||||
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
|
||||
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
|
||||
|
||||
class Speech2Text:
|
||||
"""Speech2Text class
|
||||
|
||||
Examples:
|
||||
>>> import soundfile
|
||||
>>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
|
||||
>>> audio, rate = soundfile.read("speech.wav")
|
||||
>>> speech2text(audio)
|
||||
[(text, token, token_int, hypothesis object), ...]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
asr_train_config: Union[Path, str] = None,
|
||||
asr_model_file: Union[Path, str] = None,
|
||||
cmvn_file: Union[Path, str] = None,
|
||||
lm_train_config: Union[Path, str] = None,
|
||||
lm_file: Union[Path, str] = None,
|
||||
token_type: str = None,
|
||||
bpemodel: str = None,
|
||||
device: str = "cpu",
|
||||
maxlenratio: float = 0.0,
|
||||
minlenratio: float = 0.0,
|
||||
dtype: str = "float32",
|
||||
beam_size: int = 20,
|
||||
ctc_weight: float = 0.5,
|
||||
lm_weight: float = 1.0,
|
||||
ngram_weight: float = 0.9,
|
||||
penalty: float = 0.0,
|
||||
nbest: int = 1,
|
||||
frontend_conf: dict = None,
|
||||
hotword_list_or_file: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
assert check_argument_types()
|
||||
|
||||
# 1. Build ASR model
|
||||
scorers = {}
|
||||
asr_model, asr_train_args = ASRTask.build_model_from_file(
|
||||
asr_train_config, asr_model_file, cmvn_file, device
|
||||
)
|
||||
frontend = None
|
||||
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
|
||||
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
|
||||
|
||||
logging.info("asr_model: {}".format(asr_model))
|
||||
logging.info("asr_train_args: {}".format(asr_train_args))
|
||||
asr_model.to(dtype=getattr(torch, dtype)).eval()
|
||||
|
||||
if asr_model.ctc != None:
|
||||
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
|
||||
scorers.update(
|
||||
ctc=ctc
|
||||
)
|
||||
token_list = asr_model.token_list
|
||||
scorers.update(
|
||||
length_bonus=LengthBonus(len(token_list)),
|
||||
)
|
||||
|
||||
# 2. Build Language model
|
||||
if lm_train_config is not None:
|
||||
lm, lm_train_args = LMTask.build_model_from_file(
|
||||
lm_train_config, lm_file, device
|
||||
)
|
||||
scorers["lm"] = lm.lm
|
||||
|
||||
# 3. Build ngram model
|
||||
# ngram is not supported now
|
||||
ngram = None
|
||||
scorers["ngram"] = ngram
|
||||
|
||||
# 4. Build BeamSearch object
|
||||
# transducer is not supported now
|
||||
beam_search_transducer = None
|
||||
|
||||
weights = dict(
|
||||
decoder=1.0 - ctc_weight,
|
||||
ctc=ctc_weight,
|
||||
lm=lm_weight,
|
||||
ngram=ngram_weight,
|
||||
length_bonus=penalty,
|
||||
)
|
||||
beam_search = BeamSearch(
|
||||
beam_size=beam_size,
|
||||
weights=weights,
|
||||
scorers=scorers,
|
||||
sos=asr_model.sos,
|
||||
eos=asr_model.eos,
|
||||
vocab_size=len(token_list),
|
||||
token_list=token_list,
|
||||
pre_beam_score_key=None if ctc_weight == 1.0 else "full",
|
||||
)
|
||||
|
||||
beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
|
||||
for scorer in scorers.values():
|
||||
if isinstance(scorer, torch.nn.Module):
|
||||
scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
|
||||
|
||||
logging.info(f"Decoding device={device}, dtype={dtype}")
|
||||
|
||||
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
|
||||
if token_type is None:
|
||||
token_type = asr_train_args.token_type
|
||||
if bpemodel is None:
|
||||
bpemodel = asr_train_args.bpemodel
|
||||
|
||||
if token_type is None:
|
||||
tokenizer = None
|
||||
elif token_type == "bpe":
|
||||
if bpemodel is not None:
|
||||
tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
|
||||
else:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer = build_tokenizer(token_type=token_type)
|
||||
converter = TokenIDConverter(token_list=token_list)
|
||||
logging.info(f"Text tokenizer: {tokenizer}")
|
||||
|
||||
self.asr_model = asr_model
|
||||
self.asr_train_args = asr_train_args
|
||||
self.converter = converter
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
# 6. [Optional] Build hotword list from str, local file or url
|
||||
|
||||
is_use_lm = lm_weight != 0.0 and lm_file is not None
|
||||
if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
|
||||
beam_search = None
|
||||
self.beam_search = beam_search
|
||||
logging.info(f"Beam_search: {self.beam_search}")
|
||||
self.beam_search_transducer = beam_search_transducer
|
||||
self.maxlenratio = maxlenratio
|
||||
self.minlenratio = minlenratio
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
self.nbest = nbest
|
||||
self.frontend = frontend
|
||||
self.encoder_downsampling_factor = 1
|
||||
if asr_train_args.encoder == "data2vec_encoder" or asr_train_args.encoder_conf["input_layer"] == "conv2d":
|
||||
self.encoder_downsampling_factor = 4
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self, cache: dict, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
|
||||
begin_time: int = 0, end_time: int = None,
|
||||
):
|
||||
"""Inference
|
||||
|
||||
Args:
|
||||
speech: Input speech data
|
||||
Returns:
|
||||
text, token, token_int, hyp
|
||||
|
||||
"""
|
||||
assert check_argument_types()
|
||||
|
||||
# Input as audio signal
|
||||
if isinstance(speech, np.ndarray):
|
||||
speech = torch.tensor(speech)
|
||||
|
||||
if self.frontend is not None:
|
||||
feats, feats_len = self.frontend.forward(speech, speech_lengths)
|
||||
feats = to_device(feats, device=self.device)
|
||||
feats_len = feats_len.int()
|
||||
self.asr_model.frontend = None
|
||||
else:
|
||||
feats = speech
|
||||
feats_len = speech_lengths
|
||||
lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
|
||||
batch = {"speech": feats, "speech_lengths": feats_len, "cache": cache}
|
||||
|
||||
# a. To device
|
||||
batch = to_device(batch, device=self.device)
|
||||
|
||||
# b. Forward Encoder
|
||||
enc, enc_len = self.asr_model.encode_chunk(**batch)
|
||||
if isinstance(enc, tuple):
|
||||
enc = enc[0]
|
||||
# assert len(enc) == 1, len(enc)
|
||||
enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor
|
||||
|
||||
predictor_outs = self.asr_model.calc_predictor_chunk(enc, cache)
|
||||
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
|
||||
predictor_outs[2], predictor_outs[3]
|
||||
pre_token_length = pre_token_length.floor().long()
|
||||
if torch.max(pre_token_length) < 1:
|
||||
return []
|
||||
decoder_outs = self.asr_model.cal_decoder_with_predictor_chunk(enc, pre_acoustic_embeds, cache)
|
||||
decoder_out = decoder_outs
|
||||
|
||||
results = []
|
||||
b, n, d = decoder_out.size()
|
||||
for i in range(b):
|
||||
x = enc[i, :enc_len[i], :]
|
||||
am_scores = decoder_out[i, :pre_token_length[i], :]
|
||||
if self.beam_search is not None:
|
||||
nbest_hyps = self.beam_search(
|
||||
x=x, am_scores=am_scores, maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
|
||||
)
|
||||
|
||||
nbest_hyps = nbest_hyps[: self.nbest]
|
||||
else:
|
||||
yseq = am_scores.argmax(dim=-1)
|
||||
score = am_scores.max(dim=-1)[0]
|
||||
score = torch.sum(score, dim=-1)
|
||||
# pad with mask tokens to ensure compatibility with sos/eos tokens
|
||||
yseq = torch.tensor(
|
||||
[self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
|
||||
)
|
||||
nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
|
||||
|
||||
for hyp in nbest_hyps:
|
||||
assert isinstance(hyp, (Hypothesis)), type(hyp)
|
||||
|
||||
# remove sos/eos and get results
|
||||
last_pos = -1
|
||||
if isinstance(hyp.yseq, list):
|
||||
token_int = hyp.yseq[1:last_pos]
|
||||
else:
|
||||
token_int = hyp.yseq[1:last_pos].tolist()
|
||||
|
||||
# remove blank symbol id, which is assumed to be 0
|
||||
token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
|
||||
|
||||
# Change integer-ids to tokens
|
||||
token = self.converter.ids2tokens(token_int)
|
||||
|
||||
if self.tokenizer is not None:
|
||||
text = self.tokenizer.tokens2text(token)
|
||||
else:
|
||||
text = None
|
||||
|
||||
results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor))
|
||||
|
||||
# assert check_return_type(results)
|
||||
return results
|
||||
|
||||
|
||||
class Speech2TextExport:
|
||||
"""Speech2TextExport class
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
asr_train_config: Union[Path, str] = None,
|
||||
asr_model_file: Union[Path, str] = None,
|
||||
cmvn_file: Union[Path, str] = None,
|
||||
lm_train_config: Union[Path, str] = None,
|
||||
lm_file: Union[Path, str] = None,
|
||||
token_type: str = None,
|
||||
bpemodel: str = None,
|
||||
device: str = "cpu",
|
||||
maxlenratio: float = 0.0,
|
||||
minlenratio: float = 0.0,
|
||||
dtype: str = "float32",
|
||||
beam_size: int = 20,
|
||||
ctc_weight: float = 0.5,
|
||||
lm_weight: float = 1.0,
|
||||
ngram_weight: float = 0.9,
|
||||
penalty: float = 0.0,
|
||||
nbest: int = 1,
|
||||
frontend_conf: dict = None,
|
||||
hotword_list_or_file: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
# 1. Build ASR model
|
||||
asr_model, asr_train_args = ASRTask.build_model_from_file(
|
||||
asr_train_config, asr_model_file, cmvn_file, device
|
||||
)
|
||||
frontend = None
|
||||
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
|
||||
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
|
||||
|
||||
logging.info("asr_model: {}".format(asr_model))
|
||||
logging.info("asr_train_args: {}".format(asr_train_args))
|
||||
asr_model.to(dtype=getattr(torch, dtype)).eval()
|
||||
|
||||
token_list = asr_model.token_list
|
||||
|
||||
logging.info(f"Decoding device={device}, dtype={dtype}")
|
||||
|
||||
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
|
||||
if token_type is None:
|
||||
token_type = asr_train_args.token_type
|
||||
if bpemodel is None:
|
||||
bpemodel = asr_train_args.bpemodel
|
||||
|
||||
if token_type is None:
|
||||
tokenizer = None
|
||||
elif token_type == "bpe":
|
||||
if bpemodel is not None:
|
||||
tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
|
||||
else:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer = build_tokenizer(token_type=token_type)
|
||||
converter = TokenIDConverter(token_list=token_list)
|
||||
logging.info(f"Text tokenizer: {tokenizer}")
|
||||
|
||||
# self.asr_model = asr_model
|
||||
self.asr_train_args = asr_train_args
|
||||
self.converter = converter
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
self.nbest = nbest
|
||||
self.frontend = frontend
|
||||
|
||||
model = Paraformer_export(asr_model, onnx=False)
|
||||
self.asr_model = model
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
|
||||
):
|
||||
"""Inference
|
||||
|
||||
Args:
|
||||
speech: Input speech data
|
||||
Returns:
|
||||
text, token, token_int, hyp
|
||||
|
||||
"""
|
||||
assert check_argument_types()
|
||||
|
||||
# Input as audio signal
|
||||
if isinstance(speech, np.ndarray):
|
||||
speech = torch.tensor(speech)
|
||||
|
||||
if self.frontend is not None:
|
||||
feats, feats_len = self.frontend.forward(speech, speech_lengths)
|
||||
feats = to_device(feats, device=self.device)
|
||||
feats_len = feats_len.int()
|
||||
self.asr_model.frontend = None
|
||||
else:
|
||||
feats = speech
|
||||
feats_len = speech_lengths
|
||||
|
||||
enc_len_batch_total = feats_len.sum()
|
||||
lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
|
||||
batch = {"speech": feats, "speech_lengths": feats_len}
|
||||
|
||||
# a. To device
|
||||
batch = to_device(batch, device=self.device)
|
||||
|
||||
decoder_outs = self.asr_model(**batch)
|
||||
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
|
||||
|
||||
results = []
|
||||
b, n, d = decoder_out.size()
|
||||
for i in range(b):
|
||||
am_scores = decoder_out[i, :ys_pad_lens[i], :]
|
||||
|
||||
yseq = am_scores.argmax(dim=-1)
|
||||
score = am_scores.max(dim=-1)[0]
|
||||
score = torch.sum(score, dim=-1)
|
||||
# pad with mask tokens to ensure compatibility with sos/eos tokens
|
||||
yseq = torch.tensor(
|
||||
yseq.tolist(), device=yseq.device
|
||||
)
|
||||
nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
|
||||
|
||||
for hyp in nbest_hyps:
|
||||
assert isinstance(hyp, (Hypothesis)), type(hyp)
|
||||
|
||||
# remove sos/eos and get results
|
||||
last_pos = -1
|
||||
if isinstance(hyp.yseq, list):
|
||||
token_int = hyp.yseq[1:last_pos]
|
||||
else:
|
||||
token_int = hyp.yseq[1:last_pos].tolist()
|
||||
|
||||
# remove blank symbol id, which is assumed to be 0
|
||||
token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
|
||||
|
||||
# Change integer-ids to tokens
|
||||
token = self.converter.ids2tokens(token_int)
|
||||
|
||||
if self.tokenizer is not None:
|
||||
text = self.tokenizer.tokens2text(token)
|
||||
else:
|
||||
text = None
|
||||
|
||||
results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def inference(
|
||||
maxlenratio: float,
|
||||
minlenratio: float,
|
||||
batch_size: int,
|
||||
beam_size: int,
|
||||
ngpu: int,
|
||||
ctc_weight: float,
|
||||
lm_weight: float,
|
||||
penalty: float,
|
||||
log_level: Union[int, str],
|
||||
data_path_and_name_and_type,
|
||||
asr_train_config: Optional[str],
|
||||
asr_model_file: Optional[str],
|
||||
cmvn_file: Optional[str] = None,
|
||||
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
|
||||
lm_train_config: Optional[str] = None,
|
||||
lm_file: Optional[str] = None,
|
||||
token_type: Optional[str] = None,
|
||||
key_file: Optional[str] = None,
|
||||
word_lm_train_config: Optional[str] = None,
|
||||
bpemodel: Optional[str] = None,
|
||||
allow_variable_data_keys: bool = False,
|
||||
streaming: bool = False,
|
||||
output_dir: Optional[str] = None,
|
||||
dtype: str = "float32",
|
||||
seed: int = 0,
|
||||
ngram_weight: float = 0.9,
|
||||
nbest: int = 1,
|
||||
num_workers: int = 1,
|
||||
|
||||
**kwargs,
|
||||
):
|
||||
inference_pipeline = inference_modelscope(
|
||||
maxlenratio=maxlenratio,
|
||||
minlenratio=minlenratio,
|
||||
batch_size=batch_size,
|
||||
beam_size=beam_size,
|
||||
ngpu=ngpu,
|
||||
ctc_weight=ctc_weight,
|
||||
lm_weight=lm_weight,
|
||||
penalty=penalty,
|
||||
log_level=log_level,
|
||||
asr_train_config=asr_train_config,
|
||||
asr_model_file=asr_model_file,
|
||||
cmvn_file=cmvn_file,
|
||||
raw_inputs=raw_inputs,
|
||||
lm_train_config=lm_train_config,
|
||||
lm_file=lm_file,
|
||||
token_type=token_type,
|
||||
key_file=key_file,
|
||||
word_lm_train_config=word_lm_train_config,
|
||||
bpemodel=bpemodel,
|
||||
allow_variable_data_keys=allow_variable_data_keys,
|
||||
streaming=streaming,
|
||||
output_dir=output_dir,
|
||||
dtype=dtype,
|
||||
seed=seed,
|
||||
ngram_weight=ngram_weight,
|
||||
nbest=nbest,
|
||||
num_workers=num_workers,
|
||||
|
||||
**kwargs,
|
||||
)
|
||||
return inference_pipeline(data_path_and_name_and_type, raw_inputs)
|
||||
|
||||
|
||||
def inference_modelscope(
|
||||
maxlenratio: float,
|
||||
minlenratio: float,
|
||||
batch_size: int,
|
||||
beam_size: int,
|
||||
ngpu: int,
|
||||
ctc_weight: float,
|
||||
lm_weight: float,
|
||||
penalty: float,
|
||||
log_level: Union[int, str],
|
||||
# data_path_and_name_and_type,
|
||||
asr_train_config: Optional[str],
|
||||
asr_model_file: Optional[str],
|
||||
cmvn_file: Optional[str] = None,
|
||||
lm_train_config: Optional[str] = None,
|
||||
lm_file: Optional[str] = None,
|
||||
token_type: Optional[str] = None,
|
||||
key_file: Optional[str] = None,
|
||||
word_lm_train_config: Optional[str] = None,
|
||||
bpemodel: Optional[str] = None,
|
||||
allow_variable_data_keys: bool = False,
|
||||
dtype: str = "float32",
|
||||
seed: int = 0,
|
||||
ngram_weight: float = 0.9,
|
||||
nbest: int = 1,
|
||||
num_workers: int = 1,
|
||||
output_dir: Optional[str] = None,
|
||||
param_dict: dict = None,
|
||||
**kwargs,
|
||||
):
|
||||
assert check_argument_types()
|
||||
|
||||
if word_lm_train_config is not None:
|
||||
raise NotImplementedError("Word LM is not implemented")
|
||||
if ngpu > 1:
|
||||
raise NotImplementedError("only single GPU decoding is supported")
|
||||
|
||||
logging.basicConfig(
|
||||
level=log_level,
|
||||
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
||||
)
|
||||
|
||||
export_mode = False
|
||||
if param_dict is not None:
|
||||
hotword_list_or_file = param_dict.get('hotword')
|
||||
export_mode = param_dict.get("export_mode", False)
|
||||
else:
|
||||
hotword_list_or_file = None
|
||||
|
||||
if ngpu >= 1 and torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
else:
|
||||
device = "cpu"
|
||||
batch_size = 1
|
||||
|
||||
# 1. Set random-seed
|
||||
set_all_random_seed(seed)
|
||||
|
||||
# 2. Build speech2text
|
||||
speech2text_kwargs = dict(
|
||||
asr_train_config=asr_train_config,
|
||||
asr_model_file=asr_model_file,
|
||||
cmvn_file=cmvn_file,
|
||||
lm_train_config=lm_train_config,
|
||||
lm_file=lm_file,
|
||||
token_type=token_type,
|
||||
bpemodel=bpemodel,
|
||||
device=device,
|
||||
maxlenratio=maxlenratio,
|
||||
minlenratio=minlenratio,
|
||||
dtype=dtype,
|
||||
beam_size=beam_size,
|
||||
ctc_weight=ctc_weight,
|
||||
lm_weight=lm_weight,
|
||||
ngram_weight=ngram_weight,
|
||||
penalty=penalty,
|
||||
nbest=nbest,
|
||||
hotword_list_or_file=hotword_list_or_file,
|
||||
)
|
||||
if export_mode:
|
||||
speech2text = Speech2TextExport(**speech2text_kwargs)
|
||||
else:
|
||||
speech2text = Speech2Text(**speech2text_kwargs)
|
||||
|
||||
def _forward(
|
||||
data_path_and_name_and_type,
|
||||
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
|
||||
output_dir_v2: Optional[str] = None,
|
||||
fs: dict = None,
|
||||
param_dict: dict = None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
hotword_list_or_file = None
|
||||
if param_dict is not None:
|
||||
hotword_list_or_file = param_dict.get('hotword')
|
||||
if 'hotword' in kwargs:
|
||||
hotword_list_or_file = kwargs['hotword']
|
||||
if hotword_list_or_file is not None or 'hotword' in kwargs:
|
||||
speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
|
||||
|
||||
# 3. Build data-iterator
|
||||
if data_path_and_name_and_type is None and raw_inputs is not None:
|
||||
if isinstance(raw_inputs, torch.Tensor):
|
||||
raw_inputs = raw_inputs.numpy()
|
||||
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
|
||||
loader = ASRTask.build_streaming_iterator(
|
||||
data_path_and_name_and_type,
|
||||
dtype=dtype,
|
||||
fs=fs,
|
||||
batch_size=batch_size,
|
||||
key_file=key_file,
|
||||
num_workers=num_workers,
|
||||
preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
|
||||
collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
|
||||
allow_variable_data_keys=allow_variable_data_keys,
|
||||
inference=True,
|
||||
)
|
||||
|
||||
if param_dict is not None:
|
||||
use_timestamp = param_dict.get('use_timestamp', True)
|
||||
else:
|
||||
use_timestamp = True
|
||||
|
||||
forward_time_total = 0.0
|
||||
length_total = 0.0
|
||||
finish_count = 0
|
||||
file_count = 1
|
||||
cache = None
|
||||
# 7 .Start for-loop
|
||||
# FIXME(kamo): The output format should be discussed
|
||||
asr_result_list = []
|
||||
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
|
||||
if output_path is not None:
|
||||
writer = DatadirWriter(output_path)
|
||||
else:
|
||||
writer = None
|
||||
if param_dict is not None and "cache" in param_dict:
|
||||
cache = param_dict["cache"]
|
||||
for keys, batch in loader:
|
||||
assert isinstance(batch, dict), type(batch)
|
||||
assert all(isinstance(s, str) for s in keys), keys
|
||||
_bs = len(next(iter(batch.values())))
|
||||
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
|
||||
# batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}
|
||||
logging.info("decoding, utt_id: {}".format(keys))
|
||||
# N-best list of (text, token, token_int, hyp_object)
|
||||
|
||||
time_beg = time.time()
|
||||
results = speech2text(cache=cache, **batch)
|
||||
if len(results) < 1:
|
||||
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
|
||||
results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest
|
||||
time_end = time.time()
|
||||
forward_time = time_end - time_beg
|
||||
lfr_factor = results[0][-1]
|
||||
length = results[0][-2]
|
||||
forward_time_total += forward_time
|
||||
length_total += length
|
||||
rtf_cur = "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".format(length, forward_time,
|
||||
100 * forward_time / (
|
||||
length * lfr_factor))
|
||||
logging.info(rtf_cur)
|
||||
|
||||
for batch_id in range(_bs):
|
||||
result = [results[batch_id][:-2]]
|
||||
|
||||
key = keys[batch_id]
|
||||
for n, result in zip(range(1, nbest + 1), result):
|
||||
text, token, token_int, hyp = result[0], result[1], result[2], result[3]
|
||||
time_stamp = None if len(result) < 5 else result[4]
|
||||
# Create a directory: outdir/{n}best_recog
|
||||
if writer is not None:
|
||||
ibest_writer = writer[f"{n}best_recog"]
|
||||
|
||||
# Write the result to each file
|
||||
ibest_writer["token"][key] = " ".join(token)
|
||||
# ibest_writer["token_int"][key] = " ".join(map(str, token_int))
|
||||
ibest_writer["score"][key] = str(hyp.score)
|
||||
ibest_writer["rtf"][key] = rtf_cur
|
||||
|
||||
if text is not None:
|
||||
if use_timestamp and time_stamp is not None:
|
||||
postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
|
||||
else:
|
||||
postprocessed_result = postprocess_utils.sentence_postprocess(token)
|
||||
time_stamp_postprocessed = ""
|
||||
if len(postprocessed_result) == 3:
|
||||
text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \
|
||||
postprocessed_result[1], \
|
||||
postprocessed_result[2]
|
||||
else:
|
||||
text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
|
||||
item = {'key': key, 'value': text_postprocessed}
|
||||
if time_stamp_postprocessed != "":
|
||||
item['time_stamp'] = time_stamp_postprocessed
|
||||
asr_result_list.append(item)
|
||||
finish_count += 1
|
||||
# asr_utils.print_progress(finish_count / file_count)
|
||||
if writer is not None:
|
||||
ibest_writer["text"][key] = text_postprocessed
|
||||
|
||||
logging.info("decoding, utt: {}, predictions: {}".format(key, text))
|
||||
rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total,
|
||||
forward_time_total,
|
||||
100 * forward_time_total / (
|
||||
length_total * lfr_factor))
|
||||
logging.info(rtf_avg)
|
||||
if writer is not None:
|
||||
ibest_writer["rtf"]["rtf_avf"] = rtf_avg
|
||||
return asr_result_list
|
||||
|
||||
return _forward
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = config_argparse.ArgumentParser(
|
||||
description="ASR Decoding",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
|
||||
# Note(kamo): Use '_' instead of '-' as separator.
|
||||
# '-' is confusing if written in yaml.
|
||||
parser.add_argument(
|
||||
"--log_level",
|
||||
type=lambda x: x.upper(),
|
||||
default="INFO",
|
||||
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
|
||||
help="The verbose level of logging",
|
||||
)
|
||||
|
||||
parser.add_argument("--output_dir", type=str, required=True)
|
||||
parser.add_argument(
|
||||
"--ngpu",
|
||||
type=int,
|
||||
default=0,
|
||||
help="The number of gpus. 0 indicates CPU mode",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0, help="Random seed")
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
default="float32",
|
||||
choices=["float16", "float32", "float64"],
|
||||
help="Data type",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_workers",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The number of workers used for DataLoader",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hotword",
|
||||
type=str_or_none,
|
||||
default=None,
|
||||
help="hotword file path or hotwords seperated by space"
|
||||
)
|
||||
group = parser.add_argument_group("Input data related")
|
||||
group.add_argument(
|
||||
"--data_path_and_name_and_type",
|
||||
type=str2triple_str,
|
||||
required=False,
|
||||
action="append",
|
||||
)
|
||||
group.add_argument("--key_file", type=str_or_none)
|
||||
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
|
||||
|
||||
group = parser.add_argument_group("The model configuration related")
|
||||
group.add_argument(
|
||||
"--asr_train_config",
|
||||
type=str,
|
||||
help="ASR training configuration",
|
||||
)
|
||||
group.add_argument(
|
||||
"--asr_model_file",
|
||||
type=str,
|
||||
help="ASR model parameter file",
|
||||
)
|
||||
group.add_argument(
|
||||
"--cmvn_file",
|
||||
type=str,
|
||||
help="Global cmvn file",
|
||||
)
|
||||
group.add_argument(
|
||||
"--lm_train_config",
|
||||
type=str,
|
||||
help="LM training configuration",
|
||||
)
|
||||
group.add_argument(
|
||||
"--lm_file",
|
||||
type=str,
|
||||
help="LM parameter file",
|
||||
)
|
||||
group.add_argument(
|
||||
"--word_lm_train_config",
|
||||
type=str,
|
||||
help="Word LM training configuration",
|
||||
)
|
||||
group.add_argument(
|
||||
"--word_lm_file",
|
||||
type=str,
|
||||
help="Word LM parameter file",
|
||||
)
|
||||
group.add_argument(
|
||||
"--ngram_file",
|
||||
type=str,
|
||||
help="N-gram parameter file",
|
||||
)
|
||||
group.add_argument(
|
||||
"--model_tag",
|
||||
type=str,
|
||||
help="Pretrained model tag. If specify this option, *_train_config and "
|
||||
"*_file will be overwritten",
|
||||
)
|
||||
|
||||
group = parser.add_argument_group("Beam-search related")
|
||||
group.add_argument(
|
||||
"--batch_size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The batch size for inference",
|
||||
)
|
||||
group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
|
||||
group.add_argument("--beam_size", type=int, default=20, help="Beam size")
|
||||
group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
|
||||
group.add_argument(
|
||||
"--maxlenratio",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Input length ratio to obtain max output length. "
|
||||
"If maxlenratio=0.0 (default), it uses a end-detect "
|
||||
"function "
|
||||
"to automatically find maximum hypothesis lengths."
|
||||
"If maxlenratio<0.0, its absolute value is interpreted"
|
||||
"as a constant max output length",
|
||||
)
|
||||
group.add_argument(
|
||||
"--minlenratio",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Input length ratio to obtain min output length",
|
||||
)
|
||||
group.add_argument(
|
||||
"--ctc_weight",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="CTC weight in joint decoding",
|
||||
)
|
||||
group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
|
||||
group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
|
||||
group.add_argument("--streaming", type=str2bool, default=False)
|
||||
|
||||
group.add_argument(
|
||||
"--frontend_conf",
|
||||
default=None,
|
||||
help="",
|
||||
)
|
||||
group.add_argument("--raw_inputs", type=list, default=None)
|
||||
# example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
|
||||
|
||||
group = parser.add_argument_group("Text converter related")
|
||||
group.add_argument(
|
||||
"--token_type",
|
||||
type=str_or_none,
|
||||
default=None,
|
||||
choices=["char", "bpe", None],
|
||||
help="The token type for ASR model. "
|
||||
"If not given, refers from the training args",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bpemodel",
|
||||
type=str_or_none,
|
||||
default=None,
|
||||
help="The model path of sentencepiece. "
|
||||
"If not given, refers from the training args",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main(cmd=None):
|
||||
print(get_commandline_args(), file=sys.stderr)
|
||||
parser = get_parser()
|
||||
args = parser.parse_args(cmd)
|
||||
param_dict = {'hotword': args.hotword}
|
||||
kwargs = vars(args)
|
||||
kwargs.pop("config", None)
|
||||
kwargs['param_dict'] = param_dict
|
||||
inference(**kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
# from modelscope.pipelines import pipeline
|
||||
# from modelscope.utils.constant import Tasks
|
||||
#
|
||||
# inference_16k_pipline = pipeline(
|
||||
# task=Tasks.auto_speech_recognition,
|
||||
# model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
|
||||
#
|
||||
# rec_result = inference_16k_pipline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
|
||||
# print(rec_result)
|
||||
|
||||
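Nothing in the new file builds the streaming cache itself; it only reads whatever arrives through param_dict["cache"]. Pieced together from the keys this commit touches, the cache is a nested dict along the following lines. This is a descriptive sketch only: the concrete values are made up, and the code that actually initializes and advances the cache is not part of the diff shown here.

cache = {
    "encoder": {                # consumed by SANMEncoder.forward_chunk and CifPredictorV2.forward_chunk
        "start_idx": 0,         # frames already encoded; read by SinusoidalPositionEncoder.forward_chunk
        "pad_left": 5,          # left-context frames of the current chunk (illustrative value)
        "stride": 10,           # central frames of the chunk allowed to fire CIF peaks (illustrative value)
        "cif_hidden": None,     # hidden states kept after the last CIF firing point
        "cif_alphas": None,     # leftover CIF weights carried into the next chunk
    },
    "decoder": {                # consumed by ParaformerSANMDecoder.forward_chunk
        "decode_fsmn": None,    # filled with one FSMN memory tensor per decoder layer on the first chunk
    },
}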
@@ -54,7 +54,7 @@ class Speech2Diarization:
|
||||
self,
|
||||
diar_train_config: Union[Path, str] = None,
|
||||
diar_model_file: Union[Path, str] = None,
|
||||
device: str = "cpu",
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
batch_size: int = 1,
|
||||
dtype: str = "float32",
|
||||
streaming: bool = False,
|
||||
@@ -114,9 +114,19 @@ class Speech2Diarization:
|
||||
# little-endian order: lower bit first
|
||||
return (np.array(list(b)[::-1]) == '1').astype(dtype)
|
||||
|
||||
return np.row_stack([int2vec(int(x), vec_dim) for x in seq])
|
||||
# process oov
|
||||
seq = np.array([int(x) for x in seq])
|
||||
new_seq = []
|
||||
for i, x in enumerate(seq):
|
||||
if x < 2 ** vec_dim:
|
||||
new_seq.append(x)
|
||||
else:
|
||||
idx_list = np.where(seq < 2 ** vec_dim)[0]
|
||||
idx = np.abs(idx_list - i).argmin()
|
||||
new_seq.append(seq[idx_list[idx]])
|
||||
return np.row_stack([int2vec(x, vec_dim) for x in new_seq])
|
||||
|
||||
def post_processing(self, raw_logits: torch.Tensor, spk_num: int):
|
||||
def post_processing(self, raw_logits: torch.Tensor, spk_num: int, output_format: str = "speaker_turn"):
|
||||
logits_idx = raw_logits.argmax(-1) # B, T, vocab_size -> B, T
|
||||
# upsampling outputs to match inputs
|
||||
ut = logits_idx.shape[1] * self.diar_model.encoder.time_ds_ratio
|
||||
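The OOV handling above replaces any frame label that cannot be encoded with vec_dim speakers by the label of the nearest frame that can. A runnable toy illustration with two speakers; the example sequence is invented, and the binary-formatting line inside int2vec is reconstructed for the example (the original helper is defined just above this hunk):

import numpy as np

def int2vec(x, vec_dim=2, dtype=np.int32):
    b = ('{:0' + str(vec_dim) + 'b}').format(x)
    # little-endian order: lower bit first, one bit per speaker
    return (np.array(list(b)[::-1]) == '1').astype(dtype)

vec_dim = 2
seq = np.array([0, 1, 5, 3])                         # "5" is out of range for 2 speakers (>= 2 ** 2)
new_seq = []
for i, x in enumerate(seq):
    if x < 2 ** vec_dim:
        new_seq.append(x)
    else:
        idx_list = np.where(seq < 2 ** vec_dim)[0]   # frames with representable labels
        idx = np.abs(idx_list - i).argmin()          # nearest such frame
        new_seq.append(seq[idx_list[idx]])
print(np.row_stack([int2vec(x, vec_dim) for x in new_seq]))
# [[0 0]    frame 0: silence
#  [1 0]    frame 1: spk1
#  [1 0]    frame 2: OOV "5" replaced by the label of frame 1
#  [1 1]]   frame 3: spk1 + spk2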
@@ -127,8 +137,14 @@ class Speech2Diarization:
|
||||
).squeeze(1).long()
|
||||
logits_idx = logits_idx[0].tolist()
|
||||
pse_labels = [self.token_list[x] for x in logits_idx]
|
||||
if output_format == "pse_labels":
|
||||
return pse_labels, None
|
||||
|
||||
multi_labels = self.seq2arr(pse_labels, spk_num)[:, :spk_num] # remove padding speakers
|
||||
multi_labels = self.smooth_multi_labels(multi_labels)
|
||||
if output_format == "binary_labels":
|
||||
return multi_labels, None
|
||||
|
||||
spk_list = ["spk{}".format(i + 1) for i in range(spk_num)]
|
||||
spk_turns = self.calc_spk_turns(multi_labels, spk_list)
|
||||
results = OrderedDict()
|
||||
@@ -149,6 +165,7 @@ class Speech2Diarization:
|
||||
self,
|
||||
speech: Union[torch.Tensor, np.ndarray],
|
||||
profile: Union[torch.Tensor, np.ndarray],
|
||||
output_format: str = "speaker_turn"
|
||||
):
|
||||
"""Inference
|
||||
|
||||
@@ -178,7 +195,7 @@ class Speech2Diarization:
|
||||
batch = to_device(batch, device=self.device)
|
||||
|
||||
logits = self.diar_model.prediction_forward(**batch)
|
||||
results, pse_labels = self.post_processing(logits, profile.shape[1])
|
||||
results, pse_labels = self.post_processing(logits, profile.shape[1], output_format)
|
||||
|
||||
return results, pse_labels
|
||||
|
||||
@@ -367,7 +384,7 @@ def inference_modelscope(
|
||||
pse_label_writer = open("{}/labels.txt".format(output_path), "w")
|
||||
logging.info("Start to diarize...")
|
||||
result_list = []
|
||||
for keys, batch in loader:
|
||||
for idx, (keys, batch) in enumerate(loader):
|
||||
assert isinstance(batch, dict), type(batch)
|
||||
assert all(isinstance(s, str) for s in keys), keys
|
||||
_bs = len(next(iter(batch.values())))
|
||||
@@ -385,6 +402,9 @@ def inference_modelscope(
|
||||
pse_label_writer.write("{} {}\n".format(key, " ".join(pse_labels)))
|
||||
pse_label_writer.flush()
|
||||
|
||||
if idx % 100 == 0:
|
||||
logging.info("Processing {:5d}: {}".format(idx, key))
|
||||
|
||||
if output_path is not None:
|
||||
output_writer.close()
|
||||
pse_label_writer.close()
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Dict
|
||||
from typing import Iterator
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
from typing import List
|
||||
|
||||
import kaldiio
|
||||
import numpy as np
|
||||
@@ -129,7 +130,7 @@ class IterableESPnetDataset(IterableDataset):
|
||||
non_iterable_list = []
|
||||
self.path_name_type_list = []
|
||||
|
||||
if not isinstance(path_name_type_list[0], Tuple):
|
||||
if not isinstance(path_name_type_list[0], (Tuple, List)):
|
||||
path = path_name_type_list[0]
|
||||
name = path_name_type_list[1]
|
||||
_type = path_name_type_list[2]
|
||||
|
||||
@@ -90,6 +90,47 @@ class DecoderLayerSANM(nn.Module):
|
||||
tgt = self.norm1(tgt)
|
||||
tgt = self.feed_forward(tgt)
|
||||
|
||||
x = tgt
|
||||
if self.self_attn:
|
||||
if self.normalize_before:
|
||||
tgt = self.norm2(tgt)
|
||||
x, _ = self.self_attn(tgt, tgt_mask)
|
||||
x = residual + self.dropout(x)
|
||||
|
||||
if self.src_attn is not None:
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm3(x)
|
||||
|
||||
x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
|
||||
|
||||
|
||||
return x, tgt_mask, memory, memory_mask, cache
|
||||
|
||||
def forward_chunk(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
|
||||
"""Compute decoded features.
|
||||
|
||||
Args:
|
||||
tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
|
||||
tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
|
||||
memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
|
||||
memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
|
||||
cache (List[torch.Tensor]): List of cached tensors.
|
||||
Each tensor shape should be (#batch, maxlen_out - 1, size).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor(#batch, maxlen_out, size).
|
||||
torch.Tensor: Mask for output tensor (#batch, maxlen_out).
|
||||
torch.Tensor: Encoded memory (#batch, maxlen_in, size).
|
||||
torch.Tensor: Encoded memory mask (#batch, maxlen_in).
|
||||
|
||||
"""
|
||||
# tgt = self.dropout(tgt)
|
||||
residual = tgt
|
||||
if self.normalize_before:
|
||||
tgt = self.norm1(tgt)
|
||||
tgt = self.feed_forward(tgt)
|
||||
|
||||
x = tgt
|
||||
if self.self_attn:
|
||||
if self.normalize_before:
|
||||
@@ -109,7 +150,6 @@ class DecoderLayerSANM(nn.Module):
|
||||
|
||||
return x, tgt_mask, memory, memory_mask, cache
|
||||
|
||||
|
||||
class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
|
||||
"""
|
||||
author: Speech Lab, Alibaba Group, China
|
||||
@@ -947,6 +987,65 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
|
||||
)
|
||||
return logp.squeeze(0), state
|
||||
|
||||
def forward_chunk(
|
||||
self,
|
||||
memory: torch.Tensor,
|
||||
tgt: torch.Tensor,
|
||||
cache: dict = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Forward decoder.
|
||||
|
||||
Args:
|
||||
hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
|
||||
hlens: (batch)
|
||||
ys_in_pad:
|
||||
input token ids, int64 (batch, maxlen_out)
|
||||
if input_layer == "embed"
|
||||
input tensor (batch, maxlen_out, #mels) in the other cases
|
||||
ys_in_lens: (batch)
|
||||
Returns:
|
||||
(tuple): tuple containing:
|
||||
|
||||
x: decoded token score before softmax (batch, maxlen_out, token)
|
||||
if use_output_layer is True,
|
||||
olens: (batch, )
|
||||
"""
|
||||
x = tgt
|
||||
if cache["decode_fsmn"] is None:
|
||||
cache_layer_num = len(self.decoders)
|
||||
if self.decoders2 is not None:
|
||||
cache_layer_num += len(self.decoders2)
|
||||
new_cache = [None] * cache_layer_num
|
||||
else:
|
||||
new_cache = cache["decode_fsmn"]
|
||||
for i in range(self.att_layer_num):
|
||||
decoder = self.decoders[i]
|
||||
x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
|
||||
x, None, memory, None, cache=new_cache[i]
|
||||
)
|
||||
new_cache[i] = c_ret
|
||||
|
||||
if self.num_blocks - self.att_layer_num > 1:
|
||||
for i in range(self.num_blocks - self.att_layer_num):
|
||||
j = i + self.att_layer_num
|
||||
decoder = self.decoders2[i]
|
||||
x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
|
||||
x, None, memory, None, cache=new_cache[j]
|
||||
)
|
||||
new_cache[j] = c_ret
|
||||
|
||||
for decoder in self.decoders3:
|
||||
|
||||
x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
|
||||
x, None, memory, None, cache=None
|
||||
)
|
||||
if self.normalize_before:
|
||||
x = self.after_norm(x)
|
||||
if self.output_layer is not None:
|
||||
x = self.output_layer(x)
|
||||
cache["decode_fsmn"] = new_cache
|
||||
return x
|
||||
|
||||
def forward_one_step(
|
||||
self,
|
||||
tgt: torch.Tensor,
|
||||
|
||||
@@ -325,6 +325,65 @@ class Paraformer(AbsESPnetModel):
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
def encode_chunk(
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Frontend + Encoder. Note that this method is used by asr_inference.py
|
||||
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
"""
|
||||
with autocast(False):
|
||||
# 1. Extract feats
|
||||
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
|
||||
|
||||
# 2. Data augmentation
|
||||
if self.specaug is not None and self.training:
|
||||
feats, feats_lengths = self.specaug(feats, feats_lengths)
|
||||
|
||||
# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
|
||||
if self.normalize is not None:
|
||||
feats, feats_lengths = self.normalize(feats, feats_lengths)
|
||||
|
||||
# Pre-encoder, e.g. used for raw input data
|
||||
if self.preencoder is not None:
|
||||
feats, feats_lengths = self.preencoder(feats, feats_lengths)
|
||||
|
||||
# 4. Forward encoder
|
||||
# feats: (Batch, Length, Dim)
|
||||
# -> encoder_out: (Batch, Length2, Dim2)
|
||||
if self.encoder.interctc_use_conditioning:
|
||||
encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(
|
||||
feats, feats_lengths, cache=cache["encoder"], ctc=self.ctc
|
||||
)
|
||||
else:
|
||||
encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"])
|
||||
intermediate_outs = None
|
||||
if isinstance(encoder_out, tuple):
|
||||
intermediate_outs = encoder_out[1]
|
||||
encoder_out = encoder_out[0]
|
||||
|
||||
# Post-encoder, e.g. NLU
|
||||
if self.postencoder is not None:
|
||||
encoder_out, encoder_out_lens = self.postencoder(
|
||||
encoder_out, encoder_out_lens
|
||||
)
|
||||
|
||||
assert encoder_out.size(0) == speech.size(0), (
|
||||
encoder_out.size(),
|
||||
speech.size(0),
|
||||
)
|
||||
assert encoder_out.size(1) <= encoder_out_lens.max(), (
|
||||
encoder_out.size(),
|
||||
encoder_out_lens.max(),
|
||||
)
|
||||
|
||||
if intermediate_outs is not None:
|
||||
return (encoder_out, intermediate_outs), encoder_out_lens
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
def calc_predictor(self, encoder_out, encoder_out_lens):
|
||||
|
||||
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
|
||||
@@ -333,6 +392,11 @@ class Paraformer(AbsESPnetModel):
|
||||
ignore_id=self.ignore_id)
|
||||
return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
|
||||
|
||||
def calc_predictor_chunk(self, encoder_out, cache=None):
|
||||
|
||||
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor.forward_chunk(encoder_out, cache["encoder"])
|
||||
return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
|
||||
|
||||
def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
|
||||
|
||||
decoder_outs = self.decoder(
|
||||
@@ -342,6 +406,14 @@ class Paraformer(AbsESPnetModel):
|
||||
decoder_out = torch.log_softmax(decoder_out, dim=-1)
|
||||
return decoder_out, ys_pad_lens
|
||||
|
||||
def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
|
||||
decoder_outs = self.decoder.forward_chunk(
|
||||
encoder_out, sematic_embeds, cache["decoder"]
|
||||
)
|
||||
decoder_out = decoder_outs
|
||||
decoder_out = torch.log_softmax(decoder_out, dim=-1)
|
||||
return decoder_out
|
||||
|
||||
def _extract_feats(
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
@@ -1459,4 +1531,4 @@ class ContextualParaformer(Paraformer):
|
||||
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
|
||||
var_dict_tf[name_tf].shape))
|
||||
|
||||
return var_dict_torch_update
|
||||
return var_dict_torch_update
|
||||
|
||||
@@ -59,7 +59,8 @@ class DiarSondModel(AbsESPnetModel):
|
||||
normalize_speech_speaker: bool = False,
|
||||
ignore_id: int = -1,
|
||||
speaker_discrimination_loss_weight: float = 1.0,
|
||||
inter_score_loss_weight: float = 0.0
|
||||
inter_score_loss_weight: float = 0.0,
|
||||
inputs_type: str = "raw",
|
||||
):
|
||||
assert check_argument_types()
|
||||
|
||||
@@ -86,14 +87,12 @@ class DiarSondModel(AbsESPnetModel):
|
||||
)
|
||||
self.criterion_bce = SequenceBinaryCrossEntropy(normalize_length=length_normalized_loss)
|
||||
self.pse_embedding = self.generate_pse_embedding()
|
||||
# self.register_buffer("pse_embedding", pse_embedding)
|
||||
self.power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float()
|
||||
# self.register_buffer("power_weight", power_weight)
|
||||
self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int()
|
||||
# self.register_buffer("int_token_arr", int_token_arr)
|
||||
self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight
|
||||
self.inter_score_loss_weight = inter_score_loss_weight
|
||||
self.forward_steps = 0
|
||||
self.inputs_type = inputs_type
|
||||
|
||||
def generate_pse_embedding(self):
|
||||
embedding = np.zeros((len(self.token_list), self.max_spk_num), dtype=np.float)
|
||||
@@ -125,9 +124,14 @@ class DiarSondModel(AbsESPnetModel):
|
||||
binary_labels: (Batch, frames, max_spk_num)
|
||||
binary_labels_lengths: (Batch,)
|
||||
"""
|
||||
assert speech.shape[0] == binary_labels.shape[0], (speech.shape, binary_labels.shape)
|
||||
assert speech.shape[0] <= binary_labels.shape[0], (speech.shape, binary_labels.shape)
|
||||
batch_size = speech.shape[0]
|
||||
self.forward_steps = self.forward_steps + 1
|
||||
if self.pse_embedding.device != speech.device:
|
||||
self.pse_embedding = self.pse_embedding.to(speech.device)
|
||||
self.power_weight = self.power_weight.to(speech.device)
|
||||
self.int_token_arr = self.int_token_arr.to(speech.device)
|
||||
|
||||
# 1. Network forward
|
||||
pred, inter_outputs = self.prediction_forward(
|
||||
speech, speech_lengths,
|
||||
@@ -149,9 +153,13 @@ class DiarSondModel(AbsESPnetModel):
|
||||
# the sequence length of 'pred' might be slightly less than the
|
||||
# length of 'spk_labels'. Here we force them to be equal.
|
||||
length_diff_tolerance = 2
|
||||
length_diff = pse_labels.shape[1] - pred.shape[1]
|
||||
if 0 < length_diff <= length_diff_tolerance:
|
||||
pse_labels = pse_labels[:, 0: pred.shape[1]]
|
||||
length_diff = abs(pse_labels.shape[1] - pred.shape[1])
|
||||
if length_diff <= length_diff_tolerance:
|
||||
min_len = min(pred.shape[1], pse_labels.shape[1])
|
||||
pse_labels = pse_labels[:, :min_len]
|
||||
pred = pred[:, :min_len]
|
||||
cd_score = cd_score[:, :min_len]
|
||||
ci_score = ci_score[:, :min_len]
|
||||
|
||||
loss_diar = self.classification_loss(pred, pse_labels, binary_labels_lengths)
|
||||
loss_spk_dis = self.speaker_discrimination_loss(profile, profile_lengths)
|
||||
@@ -299,7 +307,7 @@ class DiarSondModel(AbsESPnetModel):
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.encoder is not None:
|
||||
if self.encoder is not None and self.inputs_type == "raw":
|
||||
speech, speech_lengths = self.encode(speech, speech_lengths)
|
||||
speech_mask = ~make_pad_mask(speech_lengths, maxlen=speech.shape[1])
|
||||
speech_mask = speech_mask.to(speech.device).unsqueeze(-1).float()
|
||||
|
||||
@@ -347,6 +347,48 @@ class SANMEncoder(AbsEncoder):
|
||||
return (xs_pad, intermediate_outs), olens, None
|
||||
return xs_pad, olens, None
|
||||
|
||||
def forward_chunk(self,
|
||||
xs_pad: torch.Tensor,
|
||||
ilens: torch.Tensor,
|
||||
cache: dict = None,
|
||||
ctc: CTC = None,
|
||||
):
|
||||
xs_pad *= self.output_size() ** 0.5
|
||||
if self.embed is None:
|
||||
xs_pad = xs_pad
|
||||
else:
|
||||
xs_pad = self.embed.forward_chunk(xs_pad, cache)
|
||||
|
||||
encoder_outs = self.encoders0(xs_pad, None, None, None, None)
|
||||
xs_pad, masks = encoder_outs[0], encoder_outs[1]
|
||||
intermediate_outs = []
|
||||
if len(self.interctc_layer_idx) == 0:
|
||||
encoder_outs = self.encoders(xs_pad, None, None, None, None)
|
||||
xs_pad, masks = encoder_outs[0], encoder_outs[1]
|
||||
else:
|
||||
for layer_idx, encoder_layer in enumerate(self.encoders):
|
||||
encoder_outs = encoder_layer(xs_pad, None, None, None, None)
|
||||
xs_pad, masks = encoder_outs[0], encoder_outs[1]
|
||||
if layer_idx + 1 in self.interctc_layer_idx:
|
||||
encoder_out = xs_pad
|
||||
|
||||
# intermediate outputs are also normalized
|
||||
if self.normalize_before:
|
||||
encoder_out = self.after_norm(encoder_out)
|
||||
|
||||
intermediate_outs.append((layer_idx + 1, encoder_out))
|
||||
|
||||
if self.interctc_use_conditioning:
|
||||
ctc_out = ctc.softmax(encoder_out)
|
||||
xs_pad = xs_pad + self.conditioning_layer(ctc_out)
|
||||
|
||||
if self.normalize_before:
|
||||
xs_pad = self.after_norm(xs_pad)
|
||||
|
||||
if len(intermediate_outs) > 0:
|
||||
return (xs_pad, intermediate_outs), None, None
|
||||
return xs_pad, ilens, None
|
||||
|
||||
def gen_tf2torch_map_dict(self):
|
||||
tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
|
||||
tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
|
||||
|
||||
@@ -199,6 +199,63 @@ class CifPredictorV2(nn.Module):
|
||||
|
||||
return acoustic_embeds, token_num, alphas, cif_peak
|
||||
|
||||
def forward_chunk(self, hidden, cache=None):
|
||||
h = hidden
|
||||
context = h.transpose(1, 2)
|
||||
queries = self.pad(context)
|
||||
output = torch.relu(self.cif_conv1d(queries))
|
||||
output = output.transpose(1, 2)
|
||||
output = self.cif_output(output)
|
||||
alphas = torch.sigmoid(output)
|
||||
alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
|
||||
|
||||
alphas = alphas.squeeze(-1)
|
||||
mask_chunk_predictor = None
|
||||
if cache is not None:
|
||||
mask_chunk_predictor = None
|
||||
mask_chunk_predictor = torch.zeros_like(alphas)
|
||||
mask_chunk_predictor[:, cache["pad_left"]:cache["stride"] + cache["pad_left"]] = 1.0
|
||||
|
||||
if mask_chunk_predictor is not None:
|
||||
alphas = alphas * mask_chunk_predictor
|
||||
|
||||
if cache is not None:
|
||||
if cache["cif_hidden"] is not None:
|
||||
hidden = torch.cat((cache["cif_hidden"], hidden), 1)
|
||||
if cache["cif_alphas"] is not None:
|
||||
alphas = torch.cat((cache["cif_alphas"], alphas), -1)
|
||||
|
||||
token_num = alphas.sum(-1)
|
||||
acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
|
||||
len_time = alphas.size(-1)
|
||||
last_fire_place = len_time - 1
|
||||
last_fire_remainds = 0.0
|
||||
pre_alphas_length = 0
|
||||
|
||||
mask_chunk_peak_predictor = None
|
||||
if cache is not None:
|
||||
mask_chunk_peak_predictor = None
|
||||
mask_chunk_peak_predictor = torch.zeros_like(cif_peak)
|
||||
if cache["cif_alphas"] is not None:
|
||||
pre_alphas_length = cache["cif_alphas"].size(-1)
|
||||
mask_chunk_peak_predictor[:, :pre_alphas_length] = 1.0
|
||||
mask_chunk_peak_predictor[:, pre_alphas_length + cache["pad_left"]:pre_alphas_length + cache["stride"] + cache["pad_left"]] = 1.0
|
||||
|
||||
|
||||
if mask_chunk_peak_predictor is not None:
|
||||
cif_peak = cif_peak * mask_chunk_peak_predictor.squeeze(-1)
|
||||
|
||||
for i in range(len_time):
|
||||
if cif_peak[0][len_time - 1 - i] > self.threshold or cif_peak[0][len_time - 1 - i] == self.threshold:
|
||||
last_fire_place = len_time - 1 - i
|
||||
last_fire_remainds = cif_peak[0][len_time - 1 - i] - self.threshold
|
||||
break
|
||||
last_fire_remainds = torch.tensor([last_fire_remainds], dtype=alphas.dtype).to(alphas.device)
|
||||
cache["cif_hidden"] = hidden[:, last_fire_place:, :]
|
||||
cache["cif_alphas"] = torch.cat((last_fire_remainds.unsqueeze(0), alphas[:, last_fire_place+1:]), -1)
|
||||
token_num_int = token_num.floor().type(torch.int32).item()
|
||||
return acoustic_embeds[:, 0:token_num_int, :], token_num, alphas, cif_peak
|
||||
|
||||
def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
|
||||
b, t, d = hidden.size()
|
||||
tail_threshold = self.tail_threshold
|
||||
|
||||
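The chunk masks above confine the CIF weights and peaks to the central stride frames of each chunk, so padded context on either side cannot fire. A standalone sketch of just that masking step (tensor sizes and the pad_left/stride values are invented for illustration):

import torch

alphas = torch.rand(1, 16)                     # CIF weights for one 16-frame chunk
cache = {"pad_left": 3, "stride": 10}          # hypothetical chunk layout
mask = torch.zeros_like(alphas)
mask[:, cache["pad_left"]: cache["stride"] + cache["pad_left"]] = 1.0
alphas = alphas * mask                         # only the central 10 frames keep their weight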
@@ -347,15 +347,17 @@ class MultiHeadedAttentionSANM(nn.Module):
|
||||
mask = torch.reshape(mask, (b, -1, 1))
|
||||
if mask_shfit_chunk is not None:
|
||||
mask = mask * mask_shfit_chunk
|
||||
inputs = inputs * mask
|
||||
|
||||
inputs = inputs * mask
|
||||
x = inputs.transpose(1, 2)
|
||||
x = self.pad_fn(x)
|
||||
x = self.fsmn_block(x)
|
||||
x = x.transpose(1, 2)
|
||||
x += inputs
|
||||
x = self.dropout(x)
|
||||
return x * mask
|
||||
if mask is not None:
|
||||
x = x * mask
|
||||
return x
|
||||
|
||||
def forward_qkv(self, x):
|
||||
"""Transform query, key and value.
|
||||
@@ -505,7 +507,7 @@ class MultiHeadedAttentionSANMDecoder(nn.Module):
|
||||
# print("in fsmn, cache is None, x", x.size())
|
||||
|
||||
x = self.pad_fn(x)
|
||||
if not self.training and t <= 1:
|
||||
if not self.training:
|
||||
cache = x
|
||||
else:
|
||||
# print("in fsmn, cache is not None, x", x.size())
|
||||
@@ -513,7 +515,7 @@ class MultiHeadedAttentionSANMDecoder(nn.Module):
|
||||
# if t < self.kernel_size:
|
||||
# x = self.pad_fn(x)
|
||||
x = torch.cat((cache[:, :, 1:], x), dim=2)
|
||||
x = x[:, :, -self.kernel_size:]
|
||||
x = x[:, :, -(self.kernel_size+t-1):]
|
||||
# print("in fsmn, cache is not None, x_cat", x.size())
|
||||
cache = x
|
||||
x = self.fsmn_block(x)
|
||||
|
||||
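The decoder-side FSMN memory used to assume single-token steps (the old `t <= 1` guard and the `-self.kernel_size:` slice); the change keeps kernel_size + t - 1 frames so that every token in a chunk of length t still sees a full kernel of left context. A standalone sketch of that bookkeeping (all shapes and sizes are illustrative only):

import torch

kernel_size, t, feat_dim = 11, 6, 4
cache = torch.zeros(1, feat_dim, kernel_size + 5)    # memory carried over from earlier chunks (B, D, T_cache)
x = torch.randn(1, feat_dim, t)                      # current chunk, already transposed to (B, D, t)

x = torch.cat((cache[:, :, 1:], x), dim=2)
x = x[:, :, -(kernel_size + t - 1):]                 # keep a kernel_size-1 tail plus the t new frames
cache = x                                            # becomes the memory for the next chunk
out = torch.nn.functional.conv1d(x, torch.randn(feat_dim, 1, kernel_size), groups=feat_dim)
assert out.shape[-1] == t                            # one output frame per new frame in the chunk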
@@ -405,4 +405,13 @@ class SinusoidalPositionEncoder(torch.nn.Module):
|
||||
positions = torch.arange(1, timesteps+1)[None, :]
|
||||
position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
|
||||
|
||||
return x + position_encoding
|
||||
return x + position_encoding
|
||||
|
||||
def forward_chunk(self, x, cache=None):
|
||||
start_idx = 0
|
||||
batch_size, timesteps, input_dim = x.size()
|
||||
if cache is not None:
|
||||
start_idx = cache["start_idx"]
|
||||
positions = torch.arange(1, timesteps+start_idx+1)[None, :]
|
||||
position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
|
||||
return x + position_encoding[:, start_idx: start_idx + timesteps]
|
||||
|
||||
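forward_chunk offsets the sinusoidal position indices by how many frames the stream has already produced, so chunk-wise encoding lines up with full-utterance encoding. A tiny sketch of just the index bookkeeping; advancing cache["start_idx"] between chunks is the caller's job and is not part of the code shown here:

import torch

def chunk_positions(timesteps: int, cache: dict = None) -> torch.Tensor:
    start_idx = cache["start_idx"] if cache is not None else 0
    positions = torch.arange(1, timesteps + start_idx + 1)[None, :]
    return positions[:, start_idx: start_idx + timesteps]   # indices for the current chunk only

cache = {"start_idx": 0}
print(chunk_positions(5, cache))   # tensor([[1, 2, 3, 4, 5]])
cache["start_idx"] += 5
print(chunk_positions(5, cache))   # tensor([[ 6,  7,  8,  9, 10]])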
@@ -507,7 +507,7 @@ class DiarTask(AbsTask):
|
||||
config_file: Union[Path, str] = None,
|
||||
model_file: Union[Path, str] = None,
|
||||
cmvn_file: Union[Path, str] = None,
|
||||
device: str = "cpu",
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
):
|
||||
"""Build model from the files.
|
||||
|
||||
@@ -562,6 +562,7 @@ class DiarTask(AbsTask):
|
||||
model.load_state_dict(model_dict)
|
||||
else:
|
||||
model_dict = torch.load(model_file, map_location=device)
|
||||
model_dict = cls.fileter_model_dict(model_dict, model.state_dict())
|
||||
model.load_state_dict(model_dict)
|
||||
if model_name_pth is not None and not os.path.exists(model_name_pth):
|
||||
torch.save(model_dict, model_name_pth)
|
||||
@@ -569,6 +570,20 @@ class DiarTask(AbsTask):
|
||||
|
||||
return model, args
|
||||
|
||||
@classmethod
|
||||
def fileter_model_dict(cls, src_dict: dict, dest_dict: dict):
|
||||
from collections import OrderedDict
|
||||
new_dict = OrderedDict()
|
||||
for key, value in src_dict.items():
|
||||
if key in dest_dict:
|
||||
new_dict[key] = value
|
||||
else:
|
||||
logging.info("{} is no longer needed in this model.".format(key))
|
||||
for key, value in dest_dict.items():
|
||||
if key not in new_dict:
|
||||
logging.warning("{} is missing in the checkpoint.".format(key))
|
||||
return new_dict
|
||||
|
||||
@classmethod
|
||||
def convert_tf2torch(
|
||||
cls,
|
||||