Merge branch 'main' of github.com:alibaba-damo-academy/FunASR

add
游雁 2023-03-16 20:40:55 +08:00
commit 74464315c1
12 changed files with 1258 additions and 23 deletions

View File

@@ -216,6 +216,9 @@ def inference_launch(**kwargs):
elif mode == "paraformer":
from funasr.bin.asr_inference_paraformer import inference_modelscope
return inference_modelscope(**kwargs)
elif mode == "paraformer_streaming":
from funasr.bin.asr_inference_paraformer_streaming import inference_modelscope
return inference_modelscope(**kwargs)
elif mode == "paraformer_vad":
from funasr.bin.asr_inference_paraformer_vad import inference_modelscope
return inference_modelscope(**kwargs)
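
The new "paraformer_streaming" branch simply forwards every keyword argument to the streaming entry point, so callers select it purely through the mode string. A minimal sketch of such a call follows; the launcher module path and all file paths are assumptions for illustration, not part of this commit.

# Hedged sketch: dispatching to the new streaming mode via the launcher.
# The module path and the config/model paths below are illustrative assumptions.
from funasr.bin.asr_inference_launch import inference_launch

pipeline = inference_launch(
    mode="paraformer_streaming",
    maxlenratio=0.0, minlenratio=0.0, batch_size=1, beam_size=1, ngpu=0,
    ctc_weight=0.0, lm_weight=0.0, penalty=0.0, log_level="INFO",
    asr_train_config="exp/paraformer/config.yaml",  # placeholder
    asr_model_file="exp/paraformer/model.pb",       # placeholder
    cmvn_file="exp/paraformer/am.mvn",              # placeholder
)
# pipeline is the _forward closure returned by inference_modelscope below and
# accepts data_path_and_name_and_type, raw_inputs and param_dict.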

View File

@@ -0,0 +1,907 @@
#!/usr/bin/env python3
import argparse
import logging
import sys
import time
import copy
import os
import codecs
import tempfile
import requests
from pathlib import Path
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
from typing import Any
from typing import List
import numpy as np
import torch
from typeguard import check_argument_types
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
from funasr.modules.beam_search.beam_search import Hypothesis
from funasr.modules.scorers.ctc import CTCPrefixScorer
from funasr.modules.scorers.length_bonus import LengthBonus
from funasr.modules.subsampling import TooShortUttError
from funasr.tasks.asr import ASRTaskParaformer as ASRTask
from funasr.tasks.lm import LMTask
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
class Speech2Text:
"""Speech2Text class
Examples:
>>> import soundfile
>>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
>>> audio, rate = soundfile.read("speech.wav")
>>> speech2text(cache, audio)
[(text, token, token_int, hypothesis object), ...]
"""
def __init__(
self,
asr_train_config: Union[Path, str] = None,
asr_model_file: Union[Path, str] = None,
cmvn_file: Union[Path, str] = None,
lm_train_config: Union[Path, str] = None,
lm_file: Union[Path, str] = None,
token_type: str = None,
bpemodel: str = None,
device: str = "cpu",
maxlenratio: float = 0.0,
minlenratio: float = 0.0,
dtype: str = "float32",
beam_size: int = 20,
ctc_weight: float = 0.5,
lm_weight: float = 1.0,
ngram_weight: float = 0.9,
penalty: float = 0.0,
nbest: int = 1,
frontend_conf: dict = None,
hotword_list_or_file: str = None,
**kwargs,
):
assert check_argument_types()
# 1. Build ASR model
scorers = {}
asr_model, asr_train_args = ASRTask.build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, device
)
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
logging.info("asr_model: {}".format(asr_model))
logging.info("asr_train_args: {}".format(asr_train_args))
asr_model.to(dtype=getattr(torch, dtype)).eval()
if asr_model.ctc is not None:
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
scorers.update(
ctc=ctc
)
token_list = asr_model.token_list
scorers.update(
length_bonus=LengthBonus(len(token_list)),
)
# 2. Build Language model
if lm_train_config is not None:
lm, lm_train_args = LMTask.build_model_from_file(
lm_train_config, lm_file, device
)
scorers["lm"] = lm.lm
# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
# 4. Build BeamSearch object
# transducer is not supported now
beam_search_transducer = None
weights = dict(
decoder=1.0 - ctc_weight,
ctc=ctc_weight,
lm=lm_weight,
ngram=ngram_weight,
length_bonus=penalty,
)
beam_search = BeamSearch(
beam_size=beam_size,
weights=weights,
scorers=scorers,
sos=asr_model.sos,
eos=asr_model.eos,
vocab_size=len(token_list),
token_list=token_list,
pre_beam_score_key=None if ctc_weight == 1.0 else "full",
)
beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
for scorer in scorers.values():
if isinstance(scorer, torch.nn.Module):
scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
logging.info(f"Decoding device={device}, dtype={dtype}")
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
if token_type is None:
token_type = asr_train_args.token_type
if bpemodel is None:
bpemodel = asr_train_args.bpemodel
if token_type is None:
tokenizer = None
elif token_type == "bpe":
if bpemodel is not None:
tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
else:
tokenizer = None
else:
tokenizer = build_tokenizer(token_type=token_type)
converter = TokenIDConverter(token_list=token_list)
logging.info(f"Text tokenizer: {tokenizer}")
self.asr_model = asr_model
self.asr_train_args = asr_train_args
self.converter = converter
self.tokenizer = tokenizer
# 6. [Optional] Build hotword list from str, local file or url
is_use_lm = lm_weight != 0.0 and lm_file is not None
if (ctc_weight == 0.0 or asr_model.ctc is None) and not is_use_lm:
beam_search = None
self.beam_search = beam_search
logging.info(f"Beam_search: {self.beam_search}")
self.beam_search_transducer = beam_search_transducer
self.maxlenratio = maxlenratio
self.minlenratio = minlenratio
self.device = device
self.dtype = dtype
self.nbest = nbest
self.frontend = frontend
self.encoder_downsampling_factor = 1
if asr_train_args.encoder == "data2vec_encoder" or asr_train_args.encoder_conf["input_layer"] == "conv2d":
self.encoder_downsampling_factor = 4
@torch.no_grad()
def __call__(
self, cache: dict, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
begin_time: int = 0, end_time: int = None,
):
"""Inference
Args:
speech: Input speech data
Returns:
text, token, token_int, hyp
"""
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
if self.frontend is not None:
feats, feats_len = self.frontend.forward(speech, speech_lengths)
feats = to_device(feats, device=self.device)
feats_len = feats_len.int()
self.asr_model.frontend = None
else:
feats = speech
feats_len = speech_lengths
lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
batch = {"speech": feats, "speech_lengths": feats_len, "cache": cache}
# a. To device
batch = to_device(batch, device=self.device)
# b. Forward Encoder
enc, enc_len = self.asr_model.encode_chunk(**batch)
if isinstance(enc, tuple):
enc = enc[0]
# assert len(enc) == 1, len(enc)
enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor
predictor_outs = self.asr_model.calc_predictor_chunk(enc, cache)
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
predictor_outs[2], predictor_outs[3]
pre_token_length = pre_token_length.floor().long()
if torch.max(pre_token_length) < 1:
return []
decoder_outs = self.asr_model.cal_decoder_with_predictor_chunk(enc, pre_acoustic_embeds, cache)
decoder_out = decoder_outs
results = []
b, n, d = decoder_out.size()
for i in range(b):
x = enc[i, :enc_len[i], :]
am_scores = decoder_out[i, :pre_token_length[i], :]
if self.beam_search is not None:
nbest_hyps = self.beam_search(
x=x, am_scores=am_scores, maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
)
nbest_hyps = nbest_hyps[: self.nbest]
else:
yseq = am_scores.argmax(dim=-1)
score = am_scores.max(dim=-1)[0]
score = torch.sum(score, dim=-1)
# pad with mask tokens to ensure compatibility with sos/eos tokens
yseq = torch.tensor(
[self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
)
nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
for hyp in nbest_hyps:
assert isinstance(hyp, (Hypothesis)), type(hyp)
# remove sos/eos and get results
last_pos = -1
if isinstance(hyp.yseq, list):
token_int = hyp.yseq[1:last_pos]
else:
token_int = hyp.yseq[1:last_pos].tolist()
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
if self.tokenizer is not None:
text = self.tokenizer.tokens2text(token)
else:
text = None
results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor))
# assert check_return_type(results)
return results
class Speech2TextExport:
"""Speech2TextExport class
"""
def __init__(
self,
asr_train_config: Union[Path, str] = None,
asr_model_file: Union[Path, str] = None,
cmvn_file: Union[Path, str] = None,
lm_train_config: Union[Path, str] = None,
lm_file: Union[Path, str] = None,
token_type: str = None,
bpemodel: str = None,
device: str = "cpu",
maxlenratio: float = 0.0,
minlenratio: float = 0.0,
dtype: str = "float32",
beam_size: int = 20,
ctc_weight: float = 0.5,
lm_weight: float = 1.0,
ngram_weight: float = 0.9,
penalty: float = 0.0,
nbest: int = 1,
frontend_conf: dict = None,
hotword_list_or_file: str = None,
**kwargs,
):
# 1. Build ASR model
asr_model, asr_train_args = ASRTask.build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, device
)
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
logging.info("asr_model: {}".format(asr_model))
logging.info("asr_train_args: {}".format(asr_train_args))
asr_model.to(dtype=getattr(torch, dtype)).eval()
token_list = asr_model.token_list
logging.info(f"Decoding device={device}, dtype={dtype}")
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
if token_type is None:
token_type = asr_train_args.token_type
if bpemodel is None:
bpemodel = asr_train_args.bpemodel
if token_type is None:
tokenizer = None
elif token_type == "bpe":
if bpemodel is not None:
tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
else:
tokenizer = None
else:
tokenizer = build_tokenizer(token_type=token_type)
converter = TokenIDConverter(token_list=token_list)
logging.info(f"Text tokenizer: {tokenizer}")
# self.asr_model = asr_model
self.asr_train_args = asr_train_args
self.converter = converter
self.tokenizer = tokenizer
self.device = device
self.dtype = dtype
self.nbest = nbest
self.frontend = frontend
model = Paraformer_export(asr_model, onnx=False)
self.asr_model = model
@torch.no_grad()
def __call__(
self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
):
"""Inference
Args:
speech: Input speech data
Returns:
text, token, token_int, hyp
"""
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
if self.frontend is not None:
feats, feats_len = self.frontend.forward(speech, speech_lengths)
feats = to_device(feats, device=self.device)
feats_len = feats_len.int()
self.asr_model.frontend = None
else:
feats = speech
feats_len = speech_lengths
enc_len_batch_total = feats_len.sum()
lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
batch = {"speech": feats, "speech_lengths": feats_len}
# a. To device
batch = to_device(batch, device=self.device)
decoder_outs = self.asr_model(**batch)
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
results = []
b, n, d = decoder_out.size()
for i in range(b):
am_scores = decoder_out[i, :ys_pad_lens[i], :]
yseq = am_scores.argmax(dim=-1)
score = am_scores.max(dim=-1)[0]
score = torch.sum(score, dim=-1)
# pad with mask tokens to ensure compatibility with sos/eos tokens
yseq = torch.tensor(
yseq.tolist(), device=yseq.device
)
nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
for hyp in nbest_hyps:
assert isinstance(hyp, (Hypothesis)), type(hyp)
# remove sos/eos and get results
last_pos = -1
if isinstance(hyp.yseq, list):
token_int = hyp.yseq[1:last_pos]
else:
token_int = hyp.yseq[1:last_pos].tolist()
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
if self.tokenizer is not None:
text = self.tokenizer.tokens2text(token)
else:
text = None
results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor))
return results
def inference(
maxlenratio: float,
minlenratio: float,
batch_size: int,
beam_size: int,
ngpu: int,
ctc_weight: float,
lm_weight: float,
penalty: float,
log_level: Union[int, str],
data_path_and_name_and_type,
asr_train_config: Optional[str],
asr_model_file: Optional[str],
cmvn_file: Optional[str] = None,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
lm_train_config: Optional[str] = None,
lm_file: Optional[str] = None,
token_type: Optional[str] = None,
key_file: Optional[str] = None,
word_lm_train_config: Optional[str] = None,
bpemodel: Optional[str] = None,
allow_variable_data_keys: bool = False,
streaming: bool = False,
output_dir: Optional[str] = None,
dtype: str = "float32",
seed: int = 0,
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
**kwargs,
):
inference_pipeline = inference_modelscope(
maxlenratio=maxlenratio,
minlenratio=minlenratio,
batch_size=batch_size,
beam_size=beam_size,
ngpu=ngpu,
ctc_weight=ctc_weight,
lm_weight=lm_weight,
penalty=penalty,
log_level=log_level,
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
cmvn_file=cmvn_file,
raw_inputs=raw_inputs,
lm_train_config=lm_train_config,
lm_file=lm_file,
token_type=token_type,
key_file=key_file,
word_lm_train_config=word_lm_train_config,
bpemodel=bpemodel,
allow_variable_data_keys=allow_variable_data_keys,
streaming=streaming,
output_dir=output_dir,
dtype=dtype,
seed=seed,
ngram_weight=ngram_weight,
nbest=nbest,
num_workers=num_workers,
**kwargs,
)
return inference_pipeline(data_path_and_name_and_type, raw_inputs)
def inference_modelscope(
maxlenratio: float,
minlenratio: float,
batch_size: int,
beam_size: int,
ngpu: int,
ctc_weight: float,
lm_weight: float,
penalty: float,
log_level: Union[int, str],
# data_path_and_name_and_type,
asr_train_config: Optional[str],
asr_model_file: Optional[str],
cmvn_file: Optional[str] = None,
lm_train_config: Optional[str] = None,
lm_file: Optional[str] = None,
token_type: Optional[str] = None,
key_file: Optional[str] = None,
word_lm_train_config: Optional[str] = None,
bpemodel: Optional[str] = None,
allow_variable_data_keys: bool = False,
dtype: str = "float32",
seed: int = 0,
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
output_dir: Optional[str] = None,
param_dict: dict = None,
**kwargs,
):
assert check_argument_types()
if word_lm_train_config is not None:
raise NotImplementedError("Word LM is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
export_mode = False
if param_dict is not None:
hotword_list_or_file = param_dict.get('hotword')
export_mode = param_dict.get("export_mode", False)
else:
hotword_list_or_file = None
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
batch_size = 1
# 1. Set random-seed
set_all_random_seed(seed)
# 2. Build speech2text
speech2text_kwargs = dict(
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
cmvn_file=cmvn_file,
lm_train_config=lm_train_config,
lm_file=lm_file,
token_type=token_type,
bpemodel=bpemodel,
device=device,
maxlenratio=maxlenratio,
minlenratio=minlenratio,
dtype=dtype,
beam_size=beam_size,
ctc_weight=ctc_weight,
lm_weight=lm_weight,
ngram_weight=ngram_weight,
penalty=penalty,
nbest=nbest,
hotword_list_or_file=hotword_list_or_file,
)
if export_mode:
speech2text = Speech2TextExport(**speech2text_kwargs)
else:
speech2text = Speech2Text(**speech2text_kwargs)
def _forward(
data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
fs: dict = None,
param_dict: dict = None,
**kwargs,
):
hotword_list_or_file = None
if param_dict is not None:
hotword_list_or_file = param_dict.get('hotword')
if 'hotword' in kwargs:
hotword_list_or_file = kwargs['hotword']
if hotword_list_or_file is not None or 'hotword' in kwargs:
speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
loader = ASRTask.build_streaming_iterator(
data_path_and_name_and_type,
dtype=dtype,
fs=fs,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
if param_dict is not None:
use_timestamp = param_dict.get('use_timestamp', True)
else:
use_timestamp = True
forward_time_total = 0.0
length_total = 0.0
finish_count = 0
file_count = 1
cache = None
# 7. Start for-loop
# FIXME(kamo): The output format should be discussed
asr_result_list = []
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
if output_path is not None:
writer = DatadirWriter(output_path)
else:
writer = None
if param_dict is not None and "cache" in param_dict:
cache = param_dict["cache"]
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
# batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}
logging.info("decoding, utt_id: {}".format(keys))
# N-best list of (text, token, token_int, hyp_object)
time_beg = time.time()
results = speech2text(cache=cache, **batch)
if len(results) < 1:
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest
time_end = time.time()
forward_time = time_end - time_beg
lfr_factor = results[0][-1]
length = results[0][-2]
forward_time_total += forward_time
length_total += length
rtf_cur = "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".format(length, forward_time,
100 * forward_time / (
length * lfr_factor))
logging.info(rtf_cur)
for batch_id in range(_bs):
result = [results[batch_id][:-2]]
key = keys[batch_id]
for n, result in zip(range(1, nbest + 1), result):
text, token, token_int, hyp = result[0], result[1], result[2], result[3]
time_stamp = None if len(result) < 5 else result[4]
# Create a directory: outdir/{n}best_recog
if writer is not None:
ibest_writer = writer[f"{n}best_recog"]
# Write the result to each file
ibest_writer["token"][key] = " ".join(token)
# ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["score"][key] = str(hyp.score)
ibest_writer["rtf"][key] = rtf_cur
if text is not None:
if use_timestamp and time_stamp is not None:
postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
else:
postprocessed_result = postprocess_utils.sentence_postprocess(token)
time_stamp_postprocessed = ""
if len(postprocessed_result) == 3:
text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \
postprocessed_result[1], \
postprocessed_result[2]
else:
text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
item = {'key': key, 'value': text_postprocessed}
if time_stamp_postprocessed != "":
item['time_stamp'] = time_stamp_postprocessed
asr_result_list.append(item)
finish_count += 1
# asr_utils.print_progress(finish_count / file_count)
if writer is not None:
ibest_writer["text"][key] = text_postprocessed
logging.info("decoding, utt: {}, predictions: {}".format(key, text))
rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total,
forward_time_total,
100 * forward_time_total / (
length_total * lfr_factor))
logging.info(rtf_avg)
if writer is not None:
ibest_writer["rtf"]["rtf_avf"] = rtf_avg
return asr_result_list
return _forward
def get_parser():
parser = config_argparse.ArgumentParser(
description="ASR Decoding",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
parser.add_argument(
"--hotword",
type=str_or_none,
default=None,
help="hotword file path or hotwords seperated by space"
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
required=False,
action="append",
)
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--asr_train_config",
type=str,
help="ASR training configuration",
)
group.add_argument(
"--asr_model_file",
type=str,
help="ASR model parameter file",
)
group.add_argument(
"--cmvn_file",
type=str,
help="Global cmvn file",
)
group.add_argument(
"--lm_train_config",
type=str,
help="LM training configuration",
)
group.add_argument(
"--lm_file",
type=str,
help="LM parameter file",
)
group.add_argument(
"--word_lm_train_config",
type=str,
help="Word LM training configuration",
)
group.add_argument(
"--word_lm_file",
type=str,
help="Word LM parameter file",
)
group.add_argument(
"--ngram_file",
type=str,
help="N-gram parameter file",
)
group.add_argument(
"--model_tag",
type=str,
help="Pretrained model tag. If specify this option, *_train_config and "
"*_file will be overwritten",
)
group = parser.add_argument_group("Beam-search related")
group.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
group.add_argument("--beam_size", type=int, default=20, help="Beam size")
group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
group.add_argument(
"--maxlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain max output length. "
"If maxlenratio=0.0 (default), it uses a end-detect "
"function "
"to automatically find maximum hypothesis lengths."
"If maxlenratio<0.0, its absolute value is interpreted"
"as a constant max output length",
)
group.add_argument(
"--minlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain min output length",
)
group.add_argument(
"--ctc_weight",
type=float,
default=0.5,
help="CTC weight in joint decoding",
)
group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
group.add_argument("--streaming", type=str2bool, default=False)
group.add_argument(
"--frontend_conf",
default=None,
help="",
)
group.add_argument("--raw_inputs", type=list, default=None)
# example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
group = parser.add_argument_group("Text converter related")
group.add_argument(
"--token_type",
type=str_or_none,
default=None,
choices=["char", "bpe", None],
help="The token type for ASR model. "
"If not given, refers from the training args",
)
group.add_argument(
"--bpemodel",
type=str_or_none,
default=None,
help="The model path of sentencepiece. "
"If not given, refers from the training args",
)
return parser
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
param_dict = {'hotword': args.hotword}
kwargs = vars(args)
kwargs.pop("config", None)
kwargs['param_dict'] = param_dict
inference(**kwargs)
if __name__ == "__main__":
main()
# from modelscope.pipelines import pipeline
# from modelscope.utils.constant import Tasks
#
# inference_16k_pipline = pipeline(
# task=Tasks.auto_speech_recognition,
# model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
#
# rec_result = inference_16k_pipline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
# print(rec_result)
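
The commented-out block above is the ModelScope pipeline entry. For driving this script's factory directly, chunk by chunk, a hedged sketch follows. The cache layout is reconstructed from the call sites in this file (encode_chunk reads cache["encoder"], the CIF predictor reads its pad_left/stride/cif_* fields, the decoder reads cache["decoder"]["decode_fsmn"]); real callers may build it differently or need additional keys, and all paths and numeric values are placeholders.

# Hedged sketch, not the official streaming API: cache keys are reconstructed
# from the call sites in this file and may be incomplete.
import numpy as np
from funasr.bin.asr_inference_paraformer_streaming import inference_modelscope

cache = {
    "encoder": {
        "start_idx": 0,       # absolute position offset for SinusoidalPositionEncoder.forward_chunk
        "pad_left": 5,        # left-context frames masked out of the CIF alphas (placeholder value)
        "stride": 10,         # central frames of the chunk kept by the predictor (placeholder value)
        "cif_hidden": None,   # CIF hidden frames carried over from the previous chunk
        "cif_alphas": None,   # residual CIF alphas carried over from the previous chunk
    },
    "decoder": {"decode_fsmn": None},  # per-layer FSMN caches, allocated lazily
}

pipeline = inference_modelscope(
    maxlenratio=0.0, minlenratio=0.0, batch_size=1, beam_size=1, ngpu=0,
    ctc_weight=0.0, lm_weight=0.0, penalty=0.0, log_level="INFO",
    asr_train_config="config.yaml",  # placeholder
    asr_model_file="model.pb",       # placeholder
    cmvn_file="am.mvn",              # placeholder
)

chunk = np.zeros(9600, dtype=np.float32)  # 600 ms of silence at 16 kHz, stand-in for a real chunk
result = pipeline(data_path_and_name_and_type=None, raw_inputs=chunk,
                  param_dict={"cache": cache})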

View File

@@ -54,7 +54,7 @@ class Speech2Diarization:
self,
diar_train_config: Union[Path, str] = None,
diar_model_file: Union[Path, str] = None,
device: str = "cpu",
device: Union[str, torch.device] = "cpu",
batch_size: int = 1,
dtype: str = "float32",
streaming: bool = False,
@@ -114,9 +114,19 @@ class Speech2Diarization:
# little-endian order: lower bit first
return (np.array(list(b)[::-1]) == '1').astype(dtype)
return np.row_stack([int2vec(int(x), vec_dim) for x in seq])
# process oov
seq = np.array([int(x) for x in seq])
new_seq = []
for i, x in enumerate(seq):
if x < 2 ** vec_dim:
new_seq.append(x)
else:
idx_list = np.where(seq < 2 ** vec_dim)[0]
idx = np.abs(idx_list - i).argmin()
new_seq.append(seq[idx_list[idx]])
return np.row_stack([int2vec(x, vec_dim) for x in new_seq])
def post_processing(self, raw_logits: torch.Tensor, spk_num: int):
def post_processing(self, raw_logits: torch.Tensor, spk_num: int, output_format: str = "speaker_turn"):
logits_idx = raw_logits.argmax(-1) # B, T, vocab_size -> B, T
# upsampling outputs to match inputs
ut = logits_idx.shape[1] * self.diar_model.encoder.time_ds_ratio
@@ -127,8 +137,14 @@
).squeeze(1).long()
logits_idx = logits_idx[0].tolist()
pse_labels = [self.token_list[x] for x in logits_idx]
if output_format == "pse_labels":
return pse_labels, None
multi_labels = self.seq2arr(pse_labels, spk_num)[:, :spk_num] # remove padding speakers
multi_labels = self.smooth_multi_labels(multi_labels)
if output_format == "binary_labels":
return multi_labels, None
spk_list = ["spk{}".format(i + 1) for i in range(spk_num)]
spk_turns = self.calc_spk_turns(multi_labels, spk_list)
results = OrderedDict()
@@ -149,6 +165,7 @@
self,
speech: Union[torch.Tensor, np.ndarray],
profile: Union[torch.Tensor, np.ndarray],
output_format: str = "speaker_turn"
):
"""Inference
@@ -178,7 +195,7 @@
batch = to_device(batch, device=self.device)
logits = self.diar_model.prediction_forward(**batch)
results, pse_labels = self.post_processing(logits, profile.shape[1])
results, pse_labels = self.post_processing(logits, profile.shape[1], output_format)
return results, pse_labels
@@ -367,7 +384,7 @@ def inference_modelscope(
pse_label_writer = open("{}/labels.txt".format(output_path), "w")
logging.info("Start to diarize...")
result_list = []
for keys, batch in loader:
for idx, (keys, batch) in enumerate(loader):
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
@@ -385,6 +402,9 @@ def inference_modelscope(
pse_label_writer.write("{} {}\n".format(key, " ".join(pse_labels)))
pse_label_writer.flush()
if idx % 100 == 0:
logging.info("Processing {:5d}: {}".format(idx, key))
if output_path is not None:
output_writer.close()
pse_label_writer.close()
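
The key change to seq2arr above is the out-of-vocabulary handling: a PSE label that cannot be represented with vec_dim bits is replaced by the nearest in-range label before being expanded into per-speaker binary activities. A standalone sketch of that logic follows; int2vec's body is reconstructed here, since its definition sits outside the shown hunk.

# Hedged, standalone re-implementation for illustration; int2vec's body is an
# assumption because its definition is not part of the shown hunk.
import numpy as np

def seq2arr_sketch(seq, vec_dim=4):
    def int2vec(x, vec_dim, dtype=np.int32):
        b = ("{:0" + str(vec_dim) + "b}").format(x)
        # little-endian order: lower bit first
        return (np.array(list(b)[::-1]) == "1").astype(dtype)

    seq = np.array([int(x) for x in seq])
    new_seq = []
    for i, x in enumerate(seq):
        if x < 2 ** vec_dim:
            new_seq.append(x)
        else:
            # out-of-vocabulary label: copy the closest in-range frame
            idx_list = np.where(seq < 2 ** vec_dim)[0]
            idx = np.abs(idx_list - i).argmin()
            new_seq.append(seq[idx_list[idx]])
    return np.row_stack([int2vec(x, vec_dim) for x in new_seq])

print(seq2arr_sketch([3, 99, 5]))  # 99 >= 16, so frame 1 borrows the label of frame 0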

View File

@@ -8,6 +8,7 @@ from typing import Dict
from typing import Iterator
from typing import Tuple
from typing import Union
from typing import List
import kaldiio
import numpy as np
@@ -129,7 +130,7 @@ class IterableESPnetDataset(IterableDataset):
non_iterable_list = []
self.path_name_type_list = []
if not isinstance(path_name_type_list[0], Tuple):
if not isinstance(path_name_type_list[0], (Tuple, List)):
path = path_name_type_list[0]
name = path_name_type_list[1]
_type = path_name_type_list[2]
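
Accepting List alongside Tuple matters when the data specification is built programmatically or round-tripped through JSON/YAML, where tuples become plain lists; without the change a list-of-lists input would be misread as a single (path, name, type) triple. A tiny illustration, with placeholder file names:

# Hedged illustration of the widened isinstance check; paths are placeholders.
entries = [["dump/wav.scp", "speech", "sound"], ["dump/text", "text", "text"]]
assert isinstance(entries[0], (tuple, list))  # new check: routed to the multi-entry branch
assert not isinstance(entries[0], tuple)      # old Tuple-only check would have misrouted it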

View File

@@ -90,6 +90,47 @@ class DecoderLayerSANM(nn.Module):
tgt = self.norm1(tgt)
tgt = self.feed_forward(tgt)
x = tgt
if self.self_attn:
if self.normalize_before:
tgt = self.norm2(tgt)
x, _ = self.self_attn(tgt, tgt_mask)
x = residual + self.dropout(x)
if self.src_attn is not None:
residual = x
if self.normalize_before:
x = self.norm3(x)
x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
return x, tgt_mask, memory, memory_mask, cache
def forward_chunk(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
"""Compute decoded features.
Args:
tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
cache (List[torch.Tensor]): List of cached tensors.
Each tensor shape should be (#batch, maxlen_out - 1, size).
Returns:
torch.Tensor: Output tensor(#batch, maxlen_out, size).
torch.Tensor: Mask for output tensor (#batch, maxlen_out).
torch.Tensor: Encoded memory (#batch, maxlen_in, size).
torch.Tensor: Encoded memory mask (#batch, maxlen_in).
"""
# tgt = self.dropout(tgt)
residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)
tgt = self.feed_forward(tgt)
x = tgt
if self.self_attn:
if self.normalize_before:
@@ -109,7 +150,6 @@ class DecoderLayerSANM(nn.Module):
return x, tgt_mask, memory, memory_mask, cache
class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
"""
author: Speech Lab, Alibaba Group, China
@@ -947,6 +987,65 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
)
return logp.squeeze(0), state
def forward_chunk(
self,
memory: torch.Tensor,
tgt: torch.Tensor,
cache: dict = None,
) -> torch.Tensor:
"""Forward decoder for one chunk.
Args:
memory: encoded memory, float32 (batch, maxlen_in, feat)
tgt: acoustic embeddings from the predictor (batch, maxlen_out, feat)
cache: dict whose "decode_fsmn" entry holds the per-layer FSMN caches
Returns:
x: decoded token score before softmax (batch, maxlen_out, token)
if use_output_layer is True
"""
x = tgt
if cache["decode_fsmn"] is None:
cache_layer_num = len(self.decoders)
if self.decoders2 is not None:
cache_layer_num += len(self.decoders2)
new_cache = [None] * cache_layer_num
else:
new_cache = cache["decode_fsmn"]
for i in range(self.att_layer_num):
decoder = self.decoders[i]
x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
x, None, memory, None, cache=new_cache[i]
)
new_cache[i] = c_ret
if self.num_blocks - self.att_layer_num > 1:
for i in range(self.num_blocks - self.att_layer_num):
j = i + self.att_layer_num
decoder = self.decoders2[i]
x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
x, None, memory, None, cache=new_cache[j]
)
new_cache[j] = c_ret
for decoder in self.decoders3:
x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
x, None, memory, None, cache=None
)
if self.normalize_before:
x = self.after_norm(x)
if self.output_layer is not None:
x = self.output_layer(x)
cache["decode_fsmn"] = new_cache
return x
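
forward_chunk threads a per-layer FSMN cache through the stack: on the first chunk cache["decode_fsmn"] is None and a slot list covering decoders and decoders2 is allocated; on later chunks each layer reads and overwrites its own slot, while decoders3 runs cache-free. A stripped-down, runnable sketch of that bookkeeping, with fake_layer_forward_chunk standing in for DecoderLayerSANM.forward_chunk:

# Hedged sketch of the cache bookkeeping only; the layer itself is a stand-in.
cache = {"decode_fsmn": None}
num_layers = 3

def fake_layer_forward_chunk(x, layer_cache):
    # stand-in for DecoderLayerSANM.forward_chunk: returns an output plus an updated cache slot
    return x + 1, (layer_cache or 0) + 1

for chunk_idx in range(2):
    new_cache = [None] * num_layers if cache["decode_fsmn"] is None else cache["decode_fsmn"]
    x = chunk_idx
    for i in range(num_layers):
        x, new_cache[i] = fake_layer_forward_chunk(x, new_cache[i])
    cache["decode_fsmn"] = new_cache

print(cache["decode_fsmn"])  # [2, 2, 2]: every layer's slot advanced once per chunk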
def forward_one_step(
self,
tgt: torch.Tensor,

View File

@@ -325,6 +325,65 @@ class Paraformer(AbsESPnetModel):
return encoder_out, encoder_out_lens
def encode_chunk(
self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
"""
with autocast(False):
# 1. Extract feats
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
# 2. Data augmentation
if self.specaug is not None and self.training:
feats, feats_lengths = self.specaug(feats, feats_lengths)
# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
feats, feats_lengths = self.normalize(feats, feats_lengths)
# Pre-encoder, e.g. used for raw input data
if self.preencoder is not None:
feats, feats_lengths = self.preencoder(feats, feats_lengths)
# 4. Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
if self.encoder.interctc_use_conditioning:
encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(
feats, feats_lengths, cache=cache["encoder"], ctc=self.ctc
)
else:
encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"])
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
# Post-encoder, e.g. NLU
if self.postencoder is not None:
encoder_out, encoder_out_lens = self.postencoder(
encoder_out, encoder_out_lens
)
assert encoder_out.size(0) == speech.size(0), (
encoder_out.size(),
speech.size(0),
)
assert encoder_out.size(1) <= encoder_out_lens.max(), (
encoder_out.size(),
encoder_out_lens.max(),
)
if intermediate_outs is not None:
return (encoder_out, intermediate_outs), encoder_out_lens
return encoder_out, encoder_out_lens
def calc_predictor(self, encoder_out, encoder_out_lens):
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
@@ -333,6 +392,11 @@ class Paraformer(AbsESPnetModel):
ignore_id=self.ignore_id)
return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
def calc_predictor_chunk(self, encoder_out, cache=None):
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor.forward_chunk(encoder_out, cache["encoder"])
return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
decoder_outs = self.decoder(
@@ -342,6 +406,14 @@ class Paraformer(AbsESPnetModel):
decoder_out = torch.log_softmax(decoder_out, dim=-1)
return decoder_out, ys_pad_lens
def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
decoder_outs = self.decoder.forward_chunk(
encoder_out, sematic_embeds, cache["decoder"]
)
decoder_out = decoder_outs
decoder_out = torch.log_softmax(decoder_out, dim=-1)
return decoder_out
def _extract_feats(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -1459,4 +1531,4 @@ class ContextualParaformer(Paraformer):
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
var_dict_tf[name_tf].shape))
return var_dict_torch_update
return var_dict_torch_update
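
The three *_chunk methods added above are meant to be called in sequence once per chunk, exactly as Speech2Text.__call__ does earlier in this commit. A hedged sketch of that call order, where model is a loaded Paraformer and cache is the nested dict reconstructed from the call sites:

# Hedged sketch of the per-chunk call order; greedy argmax decoding only.
import torch

def decode_chunk(model, speech, speech_lengths, cache):
    enc, enc_len = model.encode_chunk(speech, speech_lengths, cache)
    if isinstance(enc, tuple):  # drop intermediate CTC outputs when interctc is enabled
        enc = enc[0]
    pre_emb, pre_token_length, alphas, peaks = model.calc_predictor_chunk(enc, cache)
    if torch.max(pre_token_length.floor().long()) < 1:
        return None  # the predictor fired no tokens inside this chunk
    log_probs = model.cal_decoder_with_predictor_chunk(enc, pre_emb, cache)
    return log_probs.argmax(dim=-1)  # greedy token ids for this chunk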

View File

@@ -59,7 +59,8 @@ class DiarSondModel(AbsESPnetModel):
normalize_speech_speaker: bool = False,
ignore_id: int = -1,
speaker_discrimination_loss_weight: float = 1.0,
inter_score_loss_weight: float = 0.0
inter_score_loss_weight: float = 0.0,
inputs_type: str = "raw",
):
assert check_argument_types()
@@ -86,14 +87,12 @@
)
self.criterion_bce = SequenceBinaryCrossEntropy(normalize_length=length_normalized_loss)
self.pse_embedding = self.generate_pse_embedding()
# self.register_buffer("pse_embedding", pse_embedding)
self.power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float()
# self.register_buffer("power_weight", power_weight)
self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int()
# self.register_buffer("int_token_arr", int_token_arr)
self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight
self.inter_score_loss_weight = inter_score_loss_weight
self.forward_steps = 0
self.inputs_type = inputs_type
def generate_pse_embedding(self):
embedding = np.zeros((len(self.token_list), self.max_spk_num), dtype=np.float)
@@ -125,9 +124,14 @@
binary_labels: (Batch, frames, max_spk_num)
binary_labels_lengths: (Batch,)
"""
assert speech.shape[0] == binary_labels.shape[0], (speech.shape, binary_labels.shape)
assert speech.shape[0] <= binary_labels.shape[0], (speech.shape, binary_labels.shape)
batch_size = speech.shape[0]
self.forward_steps = self.forward_steps + 1
if self.pse_embedding.device != speech.device:
self.pse_embedding = self.pse_embedding.to(speech.device)
self.power_weight = self.power_weight.to(speech.device)
self.int_token_arr = self.int_token_arr.to(speech.device)
# 1. Network forward
pred, inter_outputs = self.prediction_forward(
speech, speech_lengths,
@@ -149,9 +153,13 @@
# the sequence length of 'pred' might be slightly less than the
# length of 'spk_labels'. Here we force them to be equal.
length_diff_tolerance = 2
length_diff = pse_labels.shape[1] - pred.shape[1]
if 0 < length_diff <= length_diff_tolerance:
pse_labels = pse_labels[:, 0: pred.shape[1]]
length_diff = abs(pse_labels.shape[1] - pred.shape[1])
if length_diff <= length_diff_tolerance:
min_len = min(pred.shape[1], pse_labels.shape[1])
pse_labels = pse_labels[:, :min_len]
pred = pred[:, :min_len]
cd_score = cd_score[:, :min_len]
ci_score = ci_score[:, :min_len]
loss_diar = self.classification_loss(pred, pse_labels, binary_labels_lengths)
loss_spk_dis = self.speaker_discrimination_loss(profile, profile_lengths)
@@ -299,7 +307,7 @@
speech: torch.Tensor,
speech_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.encoder is not None:
if self.encoder is not None and self.inputs_type == "raw":
speech, speech_lengths = self.encode(speech, speech_lengths)
speech_mask = ~make_pad_mask(speech_lengths, maxlen=speech.shape[1])
speech_mask = speech_mask.to(speech.device).unsqueeze(-1).float()

View File

@@ -347,6 +347,48 @@ class SANMEncoder(AbsEncoder):
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None
def forward_chunk(self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
cache: dict = None,
ctc: CTC = None,
):
xs_pad *= self.output_size() ** 0.5
if self.embed is None:
xs_pad = xs_pad
else:
xs_pad = self.embed.forward_chunk(xs_pad, cache)
encoder_outs = self.encoders0(xs_pad, None, None, None, None)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
intermediate_outs = []
if len(self.interctc_layer_idx) == 0:
encoder_outs = self.encoders(xs_pad, None, None, None, None)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
else:
for layer_idx, encoder_layer in enumerate(self.encoders):
encoder_outs = encoder_layer(xs_pad, None, None, None, None)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
if layer_idx + 1 in self.interctc_layer_idx:
encoder_out = xs_pad
# intermediate outputs are also normalized
if self.normalize_before:
encoder_out = self.after_norm(encoder_out)
intermediate_outs.append((layer_idx + 1, encoder_out))
if self.interctc_use_conditioning:
ctc_out = ctc.softmax(encoder_out)
xs_pad = xs_pad + self.conditioning_layer(ctc_out)
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), None, None
return xs_pad, ilens, None
def gen_tf2torch_map_dict(self):
tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf

View File

@@ -199,6 +199,63 @@ class CifPredictorV2(nn.Module):
return acoustic_embeds, token_num, alphas, cif_peak
def forward_chunk(self, hidden, cache=None):
h = hidden
context = h.transpose(1, 2)
queries = self.pad(context)
output = torch.relu(self.cif_conv1d(queries))
output = output.transpose(1, 2)
output = self.cif_output(output)
alphas = torch.sigmoid(output)
alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
alphas = alphas.squeeze(-1)
mask_chunk_predictor = None
if cache is not None:
mask_chunk_predictor = torch.zeros_like(alphas)
mask_chunk_predictor[:, cache["pad_left"]:cache["stride"] + cache["pad_left"]] = 1.0
if mask_chunk_predictor is not None:
alphas = alphas * mask_chunk_predictor
if cache is not None:
if cache["cif_hidden"] is not None:
hidden = torch.cat((cache["cif_hidden"], hidden), 1)
if cache["cif_alphas"] is not None:
alphas = torch.cat((cache["cif_alphas"], alphas), -1)
token_num = alphas.sum(-1)
acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
len_time = alphas.size(-1)
last_fire_place = len_time - 1
last_fire_remainds = 0.0
pre_alphas_length = 0
mask_chunk_peak_predictor = None
if cache is not None:
mask_chunk_peak_predictor = torch.zeros_like(cif_peak)
if cache["cif_alphas"] is not None:
pre_alphas_length = cache["cif_alphas"].size(-1)
mask_chunk_peak_predictor[:, :pre_alphas_length] = 1.0
mask_chunk_peak_predictor[:, pre_alphas_length + cache["pad_left"]:pre_alphas_length + cache["stride"] + cache["pad_left"]] = 1.0
if mask_chunk_peak_predictor is not None:
cif_peak = cif_peak * mask_chunk_peak_predictor.squeeze(-1)
for i in range(len_time):
if cif_peak[0][len_time - 1 - i] >= self.threshold:
last_fire_place = len_time - 1 - i
last_fire_remainds = cif_peak[0][len_time - 1 - i] - self.threshold
break
last_fire_remainds = torch.tensor([last_fire_remainds], dtype=alphas.dtype).to(alphas.device)
cache["cif_hidden"] = hidden[:, last_fire_place:, :]
cache["cif_alphas"] = torch.cat((last_fire_remainds.unsqueeze(0), alphas[:, last_fire_place+1:]), -1)
token_num_int = token_num.floor().type(torch.int32).item()
return acoustic_embeds[:, 0:token_num_int, :], token_num, alphas, cif_peak
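
A hedged summary of the cache fields this forward_chunk consumes and refreshes, reconstructed from the code above; the concrete pad_left/stride values are placeholders for whatever the streaming front end chooses.

# Hedged sketch of the CIF chunk cache; values are illustrative placeholders.
cif_cache = {
    "pad_left": 5,       # left-context frames whose alphas are zeroed for this chunk
    "stride": 10,        # central frames of the chunk whose alphas are kept
    "cif_hidden": None,  # hidden frames from the last firing position onwards, carried over
    "cif_alphas": None,  # residual alpha mass after the last firing, carried over
}
# After each call the method overwrites cif_hidden/cif_alphas with the tail of
# the current chunk (everything from the last peak >= threshold), so the next
# chunk continues integrating exactly where this one stopped.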
def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
b, t, d = hidden.size()
tail_threshold = self.tail_threshold

View File

@@ -347,15 +347,17 @@ class MultiHeadedAttentionSANM(nn.Module):
mask = torch.reshape(mask, (b, -1, 1))
if mask_shfit_chunk is not None:
mask = mask * mask_shfit_chunk
inputs = inputs * mask
inputs = inputs * mask
x = inputs.transpose(1, 2)
x = self.pad_fn(x)
x = self.fsmn_block(x)
x = x.transpose(1, 2)
x += inputs
x = self.dropout(x)
return x * mask
if mask is not None:
x = x * mask
return x
def forward_qkv(self, x):
"""Transform query, key and value.
@@ -505,7 +507,7 @@ class MultiHeadedAttentionSANMDecoder(nn.Module):
# print("in fsmn, cache is None, x", x.size())
x = self.pad_fn(x)
if not self.training and t <= 1:
if not self.training:
cache = x
else:
# print("in fsmn, cache is not None, x", x.size())
@@ -513,7 +515,7 @@
# if t < self.kernel_size:
# x = self.pad_fn(x)
x = torch.cat((cache[:, :, 1:], x), dim=2)
x = x[:, :, -self.kernel_size:]
x = x[:, :, -(self.kernel_size+t-1):]
# print("in fsmn, cache is not None, x_cat", x.size())
cache = x
x = self.fsmn_block(x)
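
The widened cache slice is what lets a multi-frame chunk keep full left context: holding the last kernel_size + t - 1 frames gives the width-kernel_size FSMN convolution exactly t valid outputs, one per new frame, whereas the old kernel_size-frame window only covered a single frame. A small shape check under that assumption, with a depthwise Conv1d standing in for self.fsmn_block:

# Hedged shape check; a depthwise Conv1d stands in for self.fsmn_block and the
# cache is assumed to already hold kernel_size + t - 1 frames from earlier chunks.
import torch

kernel_size, d, t = 11, 4, 5
cache = torch.zeros(1, d, kernel_size + t - 1)  # carried over from the previous chunk
new_frames = torch.zeros(1, d, t)               # current chunk, shaped (B, D, T)

x = torch.cat((cache[:, :, 1:], new_frames), dim=2)
x = x[:, :, -(kernel_size + t - 1):]
conv = torch.nn.Conv1d(d, d, kernel_size, groups=d, bias=False)
print(conv(x).shape[-1])  # -> 5, i.e. t outputs: one per new frame in the chunk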

View File

@@ -405,4 +405,13 @@ class SinusoidalPositionEncoder(torch.nn.Module):
positions = torch.arange(1, timesteps+1)[None, :]
position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
return x + position_encoding
return x + position_encoding
def forward_chunk(self, x, cache=None):
start_idx = 0
batch_size, timesteps, input_dim = x.size()
if cache is not None:
start_idx = cache["start_idx"]
positions = torch.arange(1, timesteps+start_idx+1)[None, :]
position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
return x + position_encoding[:, start_idx: start_idx + timesteps]
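
forward_chunk keeps absolute positions continuous across chunks: with cache["start_idx"] advanced by the caller after every chunk, chunk n is encoded with positions start_idx+1 .. start_idx+timesteps instead of restarting at 1. A small sketch of just the position arithmetic; the cache update itself happens outside this method.

# Hedged sketch of the position arithmetic only.
import torch

def chunk_positions(timesteps, cache=None):
    start_idx = cache["start_idx"] if cache is not None else 0
    positions = torch.arange(1, timesteps + start_idx + 1)[None, :]
    return positions[:, start_idx: start_idx + timesteps]

cache = {"start_idx": 0}
first = chunk_positions(10, cache)   # positions 1..10
cache["start_idx"] += 10             # assumed to be advanced by the caller
second = chunk_positions(10, cache)  # positions 11..20
print(first[0, -1].item(), second[0, 0].item())  # 10 11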

View File

@@ -507,7 +507,7 @@ class DiarTask(AbsTask):
config_file: Union[Path, str] = None,
model_file: Union[Path, str] = None,
cmvn_file: Union[Path, str] = None,
device: str = "cpu",
device: Union[str, torch.device] = "cpu",
):
"""Build model from the files.
@@ -562,6 +562,7 @@ class DiarTask(AbsTask):
model.load_state_dict(model_dict)
else:
model_dict = torch.load(model_file, map_location=device)
model_dict = cls.fileter_model_dict(model_dict, model.state_dict())
model.load_state_dict(model_dict)
if model_name_pth is not None and not os.path.exists(model_name_pth):
torch.save(model_dict, model_name_pth)
@@ -569,6 +570,20 @@ class DiarTask(AbsTask):
return model, args
@classmethod
def fileter_model_dict(cls, src_dict: dict, dest_dict: dict):
from collections import OrderedDict
new_dict = OrderedDict()
for key, value in src_dict.items():
if key in dest_dict:
new_dict[key] = value
else:
logging.info("{} is no longer needed in this model.".format(key))
for key, value in dest_dict.items():
if key not in new_dict:
logging.warning("{} is missed in checkpoint.".format(key))
return new_dict
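
The helper keeps only checkpoint parameters that the current model still defines, so renamed or removed modules no longer trip strict loading, and it logs a warning for every parameter the model expects but the checkpoint lacks. A standalone illustration of that filtering, with made-up parameter names:

# Hedged, standalone illustration; parameter names are made up.
from collections import OrderedDict

src_dict = {"encoder.weight": 1.0, "legacy_head.weight": 2.0}  # checkpoint
dest_dict = {"encoder.weight": 0.0, "new_head.weight": 0.0}    # current model.state_dict()

new_dict = OrderedDict((k, v) for k, v in src_dict.items() if k in dest_dict)
missing = [k for k in dest_dict if k not in new_dict]
print(dict(new_dict))  # {'encoder.weight': 1.0}: the stale key is dropped
print(missing)         # ['new_head.weight']: the helper only logs a warning for these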
@classmethod
def convert_tf2torch(
cls,