mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
rnnt
This commit is contained in:
parent
8f3930494a
commit
fc606ceef3
@ -131,6 +131,11 @@ def get_parser():
|
||||
help="Pretrained model tag. If specify this option, *_train_config and "
|
||||
"*_file will be overwritten",
|
||||
)
|
||||
group.add_argument(
|
||||
"--beam_search_config",
|
||||
default={},
|
||||
help="The keyword arguments for transducer beam search.",
|
||||
)
|
||||
|
||||
group = parser.add_argument_group("Beam-search related")
|
||||
group.add_argument(
|
||||
@ -168,6 +173,41 @@ def get_parser():
|
||||
group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
|
||||
group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
|
||||
group.add_argument("--streaming", type=str2bool, default=False)
|
||||
group.add_argument("--simu_streaming", type=str2bool, default=False)
|
||||
group.add_argument("--chunk_size", type=int, default=16)
|
||||
group.add_argument("--left_context", type=int, default=16)
|
||||
group.add_argument("--right_context", type=int, default=0)
|
||||
group.add_argument(
|
||||
"--display_partial_hypotheses",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Whether to display partial hypotheses during chunk-by-chunk inference.",
|
||||
)
|
||||
|
||||
group = parser.add_argument_group("Dynamic quantization related")
|
||||
group.add_argument(
|
||||
"--quantize_asr_model",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Apply dynamic quantization to ASR model.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--quantize_modules",
|
||||
nargs="*",
|
||||
default=None,
|
||||
help="""Module names to apply dynamic quantization on.
|
||||
The module names are provided as a list, where each name is separated
|
||||
by a comma (e.g.: --quantize-config=[Linear,LSTM,GRU]).
|
||||
Each specified name should be an attribute of 'torch.nn', e.g.:
|
||||
torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""",
|
||||
)
|
||||
group.add_argument(
|
||||
"--quantize_dtype",
|
||||
type=str,
|
||||
default="qint8",
|
||||
choices=["float16", "qint8"],
|
||||
help="Dtype for dynamic quantization.",
|
||||
)
|
||||
|
||||
group = parser.add_argument_group("Text converter related")
|
||||
group.add_argument(
|
||||
@ -262,6 +302,9 @@ def inference_launch_funasr(**kwargs):
|
||||
elif mode == "mfcca":
|
||||
from funasr.bin.asr_inference_mfcca import inference_modelscope
|
||||
return inference_modelscope(**kwargs)
|
||||
elif mode == "rnnt":
|
||||
from funasr.bin.asr_inference_rnnt import inference
|
||||
return inference(**kwargs)
|
||||
else:
|
||||
logging.info("Unknown decoding mode: {}".format(mode))
|
||||
return None
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
46
funasr/bin/asr_train_transducer.py
Executable file
46
funasr/bin/asr_train_transducer.py
Executable file
@ -0,0 +1,46 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import os
|
||||
|
||||
from funasr.tasks.asr_transducer import ASRTransducerTask
|
||||
|
||||
|
||||
# for ASR Training
|
||||
def parse_args():
    """Build the transducer-task CLI parser, extend it, and parse arguments.

    Returns:
        Parsed argparse Namespace for ASR transducer training.
    """
    arg_parser = ASRTransducerTask.get_parser()
    # Extra option on top of the task-level parser: which local GPU to use.
    arg_parser.add_argument("--gpu_id", type=int, default=0, help="local gpu id.")
    return arg_parser.parse_args()
|
||||
|
||||
|
||||
def main(args=None, cmd=None):
    """Entry point for ASR training: delegate to the transducer task class.

    Args:
        args: Pre-parsed argument namespace (optional).
        cmd: Raw command-line argument list (optional).
    """
    ASRTransducerTask.main(args=args, cmd=cmd)
|
||||
|
||||
|
||||
# Script entry point: parse CLI arguments, configure GPU/DDP environment
# variables, then launch transducer training.
if __name__ == '__main__':
    args = parse_args()

    # setup local gpu_id: restrict this process to the selected device.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)

    # DDP settings: enable distributed training only for multi-GPU runs.
    if args.ngpu > 1:
        args.distributed = True
    else:
        args.distributed = False
        # NOTE(review): `assert` is stripped under `python -O`; an explicit
        # check with a raised error would be more robust — confirm intent.
        assert args.num_worker_count == 1

    # re-compute batch size: when dataset type is small, scale batch_size /
    # batch_bins by the GPU count (presumably to keep the per-device batch
    # as configured when data is sharded across GPUs — confirm with trainer).
    if args.dataset_type == "small":
        if args.batch_size is not None:
            args.batch_size = args.batch_size * args.ngpu
        if args.batch_bins is not None:
            args.batch_bins = args.batch_bins * args.ngpu

    main(args=args)
|
||||
0
funasr/models_transducer/__init__.py
Normal file
0
funasr/models_transducer/__init__.py
Normal file
213
funasr/models_transducer/activation.py
Normal file
213
funasr/models_transducer/activation.py
Normal file
@ -0,0 +1,213 @@
|
||||
"""Activation functions for Transducer."""
|
||||
|
||||
import torch
|
||||
from packaging.version import parse as V
|
||||
|
||||
|
||||
def get_activation(
    activation_type: str,
    ftswish_threshold: float = -0.2,
    ftswish_mean_shift: float = 0.0,
    hardtanh_min_val: float = -1.0,
    hardtanh_max_val: float = 1.0,
    leakyrelu_neg_slope: float = 0.01,
    smish_alpha: float = 1.0,
    smish_beta: float = 1.0,
    softplus_beta: float = 1.0,
    softplus_threshold: int = 20,
    swish_beta: float = 1.0,
) -> torch.nn.Module:
    """Return an instantiated activation function module.

    Args:
        activation_type: Activation function type. One of: "ftswish",
            "hardtanh", "leaky_relu", "mish", "relu", "selu", "smish",
            "swish", "tanh", "identity".
        ftswish_threshold: Threshold value for FTSwish activation formulation.
        ftswish_mean_shift: Mean shifting value for FTSwish activation formulation.
        hardtanh_min_val: Minimum value of the linear region range for HardTanh.
        hardtanh_max_val: Maximum value of the linear region range for HardTanh.
        leakyrelu_neg_slope: Negative slope value for LeakyReLU activation formulation.
        smish_alpha: Alpha value for Smish activation formulation.
        smish_beta: Beta value for Smish activation formulation.
        softplus_beta: Beta value for softplus activation formulation in Mish.
        softplus_threshold: Values above this revert to a linear function in Mish.
        swish_beta: Beta value for Swish variant formulation.

    Returns:
        : Activation function module.

    Raises:
        KeyError: If ``activation_type`` is not one of the supported types.

    """
    # Mish and SiLU have native PyTorch implementations from 1.9 / 1.8
    # onward; older versions fall back to the manual formulations below.
    torch_version = V(torch.__version__)

    # Map each supported type to (module class, constructor kwargs).
    activations = {
        "ftswish": (
            FTSwish,
            {"threshold": ftswish_threshold, "mean_shift": ftswish_mean_shift},
        ),
        "hardtanh": (
            torch.nn.Hardtanh,
            {"min_val": hardtanh_min_val, "max_val": hardtanh_max_val},
        ),
        "leaky_relu": (torch.nn.LeakyReLU, {"negative_slope": leakyrelu_neg_slope}),
        "mish": (
            Mish,
            {
                "softplus_beta": softplus_beta,
                "softplus_threshold": softplus_threshold,
                "use_builtin": torch_version >= V("1.9"),
            },
        ),
        "relu": (torch.nn.ReLU, {}),
        "selu": (torch.nn.SELU, {}),
        "smish": (Smish, {"alpha": smish_alpha, "beta": smish_beta}),
        "swish": (
            Swish,
            {"beta": swish_beta, "use_builtin": torch_version >= V("1.8")},
        ),
        "tanh": (torch.nn.Tanh, {}),
        "identity": (torch.nn.Identity, {}),
    }

    act_func, act_args = activations[activation_type]

    return act_func(**act_args)
|
||||
|
||||
|
||||
class FTSwish(torch.nn.Module):
    """Flatten-T Swish activation definition.

    FTSwish(x) = x * sigmoid(x) + threshold, with every output below zero
    clamped to the (negative) threshold value.

    Reference: https://arxiv.org/abs/1812.06247

    Args:
        threshold: Threshold value for FTSwish activation formulation. (threshold < 0)
        mean_shift: Mean shifting value for FTSwish activation formulation.
            (applied only if != 0, disabled by default)

    """

    def __init__(self, threshold: float = -0.2, mean_shift: float = 0) -> None:
        super().__init__()

        assert threshold < 0, "FTSwish threshold parameter should be < 0."

        self.threshold = threshold
        self.mean_shift = mean_shift

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply FTSwish element-wise to *x*."""
        out = x * torch.sigmoid(x) + self.threshold
        # Clamp negative outputs to the threshold value.
        floor = torch.tensor([self.threshold], device=x.device)
        out = torch.where(out >= 0, out, floor)

        if self.mean_shift != 0:
            out.sub_(self.mean_shift)

        return out
|
||||
|
||||
|
||||
class Mish(torch.nn.Module):
    """Mish activation definition.

    Mish(x) = x * tanh(softplus(x))

    Reference: https://arxiv.org/abs/1908.08681.

    Args:
        softplus_beta: Beta value for softplus activation formulation.
            (Usually 0 > softplus_beta >= 2)
        softplus_threshold: Values above this revert to a linear function.
            (Usually 10 > softplus_threshold >= 20)
        use_builtin: Whether to use PyTorch activation function if available.

    """

    def __init__(
        self,
        softplus_beta: float = 1.0,
        softplus_threshold: int = 20,
        use_builtin: bool = False,
    ) -> None:
        super().__init__()

        if use_builtin:
            # Native implementation (available from PyTorch 1.9).
            self.mish = torch.nn.Mish()
        else:
            # Manual formulation for older PyTorch versions.
            self.tanh = torch.nn.Tanh()
            self.softplus = torch.nn.Softplus(
                beta=softplus_beta, threshold=softplus_threshold
            )

            def _mish(inputs: torch.Tensor) -> torch.Tensor:
                return inputs * self.tanh(self.softplus(inputs))

            self.mish = _mish

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply Mish element-wise to *x*."""
        return self.mish(x)
|
||||
|
||||
|
||||
class Smish(torch.nn.Module):
    """Smish activation definition.

    Smish(x) = (alpha * x) * tanh(log(1 + sigmoid(beta * x)))
    where alpha > 0 and beta > 0.

    Reference: https://www.mdpi.com/2079-9292/11/4/540/htm.

    Args:
        alpha: Alpha value for Smish activation formulation.
            (Usually, alpha = 1. If alpha <= 0, set value to 1).
        beta: Beta value for Smish activation formulation.
            (Usually, beta = 1. If beta <= 0, set value to 1).

    """

    def __init__(self, alpha: float = 1.0, beta: float = 1.0) -> None:
        super().__init__()

        self.tanh = torch.nn.Tanh()

        # Non-positive coefficients fall back to the standard value 1.
        self.alpha = alpha if alpha > 0 else 1
        self.beta = beta if beta > 0 else 1

        def _smish(inputs: torch.Tensor) -> torch.Tensor:
            gated = torch.log(1 + torch.sigmoid(self.beta * inputs))
            return (self.alpha * inputs) * self.tanh(gated)

        self.smish = _smish

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply Smish element-wise to *x*."""
        return self.smish(x)
|
||||
|
||||
|
||||
class Swish(torch.nn.Module):
    """Swish activation definition.

    Swish(x) = (beta * x) * sigmoid(x)
    where beta = 1 defines standard Swish activation.

    References:
        https://arxiv.org/abs/2108.12943 / https://arxiv.org/abs/1710.05941v1.
        E-swish variant: https://arxiv.org/abs/1801.07145.

    Args:
        beta: Beta parameter for E-Swish.
            (beta >= 1. If beta < 1, use standard Swish).
        use_builtin: Whether to use PyTorch function if available.

    """

    def __init__(self, beta: float = 1.0, use_builtin: bool = False) -> None:
        super().__init__()

        self.beta = beta

        if beta > 1:
            # E-Swish variant: scale the input by beta before gating.
            def _eswish(inputs: torch.Tensor) -> torch.Tensor:
                return (self.beta * inputs) * torch.sigmoid(inputs)

            self.swish = _eswish
        elif use_builtin:
            # Standard Swish is SiLU (available from PyTorch 1.8).
            self.swish = torch.nn.SiLU()
        else:
            self.swish = lambda inputs: inputs * torch.sigmoid(inputs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply Swish element-wise to *x*."""
        return self.swish(x)
|
||||
705
funasr/models_transducer/beam_search_transducer.py
Normal file
705
funasr/models_transducer/beam_search_transducer.py
Normal file
@ -0,0 +1,705 @@
|
||||
"""Search algorithms for Transducer models."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
|
||||
from funasr.models_transducer.joint_network import JointNetwork
|
||||
|
||||
|
||||
@dataclass
class Hypothesis:
    """Default hypothesis definition for Transducer search algorithms.

    Args:
        score: Total log-probability.
        yseq: Label sequence as integer ID sequence.
        dec_state: RNNDecoder or StatelessDecoder state.
            ((N, 1, D_dec), (N, 1, D_dec) or None) or None
        lm_state: RNNLM state. ((N, D_lm), (N, D_lm)) or None

    """

    # Accumulated log-probability of this label sequence.
    score: float
    # Emitted label IDs; searches seed this with [0] (blank/start token).
    yseq: List[int]
    # Decoder hidden state carried between expansion steps.
    dec_state: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None
    # LM hidden state for shallow fusion (None when no LM is used).
    lm_state: Optional[Union[Dict[str, Any], List[Any]]] = None
|
||||
|
||||
|
||||
@dataclass
class ExtendedHypothesis(Hypothesis):
    """Extended hypothesis definition for NSC beam search and mAES.

    Args:
        : Hypothesis dataclass arguments.
        dec_out: Decoder output sequence. (B, D_dec)
        lm_score: Log-probabilities of the LM for given label. (vocab_size)

    """

    # Cached decoder output for the last label, reused across mAES expansions.
    dec_out: torch.Tensor = None
    # Per-label LM log-probabilities used for shallow fusion scoring.
    lm_score: torch.Tensor = None
|
||||
|
||||
|
||||
class BeamSearchTransducer:
    """Beam search implementation for Transducer.

    Args:
        decoder: Decoder module.
        joint_network: Joint network module.
        beam_size: Size of the beam.
        lm: LM class.
        lm_weight: LM weight for soft fusion.
        search_type: Search algorithm to use during inference.
        max_sym_exp: Number of maximum symbol expansions at each time step. (TSD)
        u_max: Maximum expected target sequence length. (ALSD)
        nstep: Number of maximum expansion steps at each time step. (mAES)
        expansion_gamma: Allowed logp difference for prune-by-value method. (mAES)
        expansion_beta:
            Number of additional candidates for expanded hypotheses selection. (mAES)
        score_norm: Normalize final scores by length.
        nbest: Number of final hypothesis.
        streaming: Whether to perform chunk-by-chunk beam search.

    """

    def __init__(
        self,
        decoder: AbsDecoder,
        joint_network: JointNetwork,
        beam_size: int,
        lm: Optional[torch.nn.Module] = None,
        lm_weight: float = 0.1,
        search_type: str = "default",
        max_sym_exp: int = 3,
        u_max: int = 50,
        nstep: int = 2,
        expansion_gamma: float = 2.3,
        expansion_beta: int = 2,
        score_norm: bool = False,
        nbest: int = 1,
        streaming: bool = False,
    ) -> None:
        """Construct a BeamSearchTransducer object."""
        super().__init__()

        self.decoder = decoder
        self.joint_network = joint_network

        self.vocab_size = decoder.vocab_size

        # Beam width cannot exceed the number of candidate labels.
        assert beam_size <= self.vocab_size, (
            "beam_size (%d) should be smaller than or equal to vocabulary size (%d)."
            % (
                beam_size,
                self.vocab_size,
            )
        )
        self.beam_size = beam_size

        # Bind the chosen search algorithm, validating its hyper-parameters.
        if search_type == "default":
            self.search_algorithm = self.default_beam_search
        elif search_type == "tsd":
            assert max_sym_exp > 1, "max_sym_exp (%d) should be greater than one." % (
                max_sym_exp
            )
            self.max_sym_exp = max_sym_exp

            self.search_algorithm = self.time_sync_decoding
        elif search_type == "alsd":
            assert not streaming, "ALSD is not available in streaming mode."

            assert u_max >= 0, "u_max should be a positive integer, a portion of max_T."
            self.u_max = u_max

            self.search_algorithm = self.align_length_sync_decoding
        elif search_type == "maes":
            assert self.vocab_size >= beam_size + expansion_beta, (
                "beam_size (%d) + expansion_beta (%d) "
                " should be smaller than or equal to vocab size (%d)."
                % (beam_size, expansion_beta, self.vocab_size)
            )
            self.max_candidates = beam_size + expansion_beta

            self.nstep = nstep
            self.expansion_gamma = expansion_gamma

            self.search_algorithm = self.modified_adaptive_expansion_search
        else:
            raise NotImplementedError(
                "Specified search type (%s) is not supported." % search_type
            )

        self.use_lm = lm is not None

        if self.use_lm:
            assert hasattr(lm, "rnn_type"), "Transformer LM is currently not supported."

            # Start-of-sentence ID for LM scoring (last vocabulary index).
            self.sos = self.vocab_size - 1

            self.lm = lm
            self.lm_weight = lm_weight

        self.score_norm = score_norm
        self.nbest = nbest

        self.reset_inference_cache()

    def __call__(
        self,
        enc_out: torch.Tensor,
        is_final: bool = True,
    ) -> List[Hypothesis]:
        """Perform beam search.

        Args:
            enc_out: Encoder output sequence. (T, D_enc)
            is_final: Whether enc_out is the final chunk of data.

        Returns:
            nbest_hyps: N-best decoding results

        """
        self.decoder.set_device(enc_out.device)

        hyps = self.search_algorithm(enc_out)

        if is_final:
            # End of utterance: clear streaming state and return sorted n-best.
            self.reset_inference_cache()

            return self.sort_nbest(hyps)

        # Streaming: keep hypotheses to resume from on the next chunk.
        self.search_cache = hyps

        return hyps

    def reset_inference_cache(self) -> None:
        """Reset cache for decoder scoring and streaming."""
        self.decoder.score_cache = {}
        self.search_cache = None

    def sort_nbest(self, hyps: List[Hypothesis]) -> List[Hypothesis]:
        """Sort in-place hypotheses by score or score given sequence length.

        Args:
            hyps: Hypothesis.

        Return:
            hyps: Sorted hypothesis.

        """
        if self.score_norm:
            # Length-normalized score to reduce bias toward short sequences.
            hyps.sort(key=lambda x: x.score / len(x.yseq), reverse=True)
        else:
            hyps.sort(key=lambda x: x.score, reverse=True)

        return hyps[: self.nbest]

    def recombine_hyps(self, hyps: List[Hypothesis]) -> List[Hypothesis]:
        """Recombine hypotheses with same label ID sequence.

        Duplicate label sequences are merged by log-sum-exp of their scores.

        Args:
            hyps: Hypotheses.

        Returns:
            final: Recombined hypotheses.

        """
        final = {}

        for hyp in hyps:
            # Use the label sequence itself as the merge key.
            str_yseq = "_".join(map(str, hyp.yseq))

            if str_yseq in final:
                final[str_yseq].score = np.logaddexp(final[str_yseq].score, hyp.score)
            else:
                final[str_yseq] = hyp

        return [*final.values()]

    def select_k_expansions(
        self,
        hyps: List[ExtendedHypothesis],
        topk_idx: torch.Tensor,
        topk_logp: torch.Tensor,
    ) -> List[ExtendedHypothesis]:
        """Return K hypotheses candidates for expansion from a list of hypothesis.

        K candidates are selected according to the extended hypotheses probabilities
        and a prune-by-value method. Where K is equal to beam_size + beta.

        Args:
            hyps: Hypotheses.
            topk_idx: Indices of candidates hypothesis.
            topk_logp: Log-probabilities of candidates hypothesis.

        Returns:
            k_expansions: Best K expansion hypotheses candidates.

        """
        k_expansions = []

        for i, hyp in enumerate(hyps):
            # Candidate (label, cumulative score) pairs for this hypothesis.
            hyp_i = [
                (int(k), hyp.score + float(v))
                for k, v in zip(topk_idx[i], topk_logp[i])
            ]
            k_best_exp = max(hyp_i, key=lambda x: x[1])[1]

            # Prune-by-value: keep candidates within expansion_gamma of the best.
            k_expansions.append(
                sorted(
                    filter(
                        lambda x: (k_best_exp - self.expansion_gamma) <= x[1], hyp_i
                    ),
                    key=lambda x: x[1],
                    reverse=True,
                )
            )

        return k_expansions

    def create_lm_batch_inputs(self, hyps_seq: List[List[int]]) -> torch.Tensor:
        """Make batch of inputs with left padding for LM scoring.

        Args:
            hyps_seq: Hypothesis sequences.

        Returns:
            : Padded batch of sequences.

        """
        max_len = max([len(h) for h in hyps_seq])

        # NOTE(review): torch.LongTensor does not accept a `device` keyword;
        # this likely raises TypeError — should probably be
        # torch.tensor(..., dtype=torch.long, device=...). Confirm.
        return torch.LongTensor(
            [[self.sos] + ([0] * (max_len - len(h))) + h[1:] for h in hyps_seq],
            device=self.decoder.device,
        )

    def default_beam_search(self, enc_out: torch.Tensor) -> List[Hypothesis]:
        """Beam search implementation without prefix search.

        Modified from https://arxiv.org/pdf/1211.3711.pdf

        Args:
            enc_out: Encoder output sequence. (T, D)

        Returns:
            nbest_hyps: N-best hypothesis.

        """
        # Blank (index 0) is excluded from the top-k label expansion.
        beam_k = min(self.beam_size, (self.vocab_size - 1))
        max_t = len(enc_out)

        # Resume from the streaming cache when decoding chunk-by-chunk.
        if self.search_cache is not None:
            kept_hyps = self.search_cache
        else:
            kept_hyps = [
                Hypothesis(
                    score=0.0,
                    yseq=[0],
                    dec_state=self.decoder.init_state(1),
                )
            ]

        for t in range(max_t):
            hyps = kept_hyps
            kept_hyps = []

            # Expand the current best hypothesis until enough finished
            # (blank-emitting) hypotheses outscore all pending ones.
            while True:
                max_hyp = max(hyps, key=lambda x: x.score)
                hyps.remove(max_hyp)

                label = torch.full(
                    (1, 1),
                    max_hyp.yseq[-1],
                    dtype=torch.long,
                    device=self.decoder.device,
                )
                dec_out, state = self.decoder.score(
                    label,
                    max_hyp.yseq,
                    max_hyp.dec_state,
                )

                logp = torch.log_softmax(
                    self.joint_network(enc_out[t : t + 1, :], dec_out),
                    dim=-1,
                ).squeeze(0)
                top_k = logp[1:].topk(beam_k, dim=-1)

                # Blank emission: hypothesis advances in time, sequence unchanged.
                kept_hyps.append(
                    Hypothesis(
                        score=(max_hyp.score + float(logp[0:1])),
                        yseq=max_hyp.yseq,
                        dec_state=max_hyp.dec_state,
                        lm_state=max_hyp.lm_state,
                    )
                )

                if self.use_lm:
                    # NOTE(review): same torch.LongTensor(..., device=...)
                    # concern as in create_lm_batch_inputs — confirm.
                    lm_scores, lm_state = self.lm.score(
                        torch.LongTensor(
                            [self.sos] + max_hyp.yseq[1:], device=self.decoder.device
                        ),
                        max_hyp.lm_state,
                        None,
                    )
                else:
                    lm_state = max_hyp.lm_state

                # Non-blank emissions: extend the sequence with each top-k label
                # (+1 offsets the blank that was sliced off before topk).
                for logp, k in zip(*top_k):
                    score = max_hyp.score + float(logp)

                    if self.use_lm:
                        score += self.lm_weight * lm_scores[k + 1]

                    hyps.append(
                        Hypothesis(
                            score=score,
                            yseq=max_hyp.yseq + [int(k + 1)],
                            dec_state=state,
                            lm_state=lm_state,
                        )
                    )

                # Stop when beam_size kept hypotheses beat every pending one.
                hyps_max = float(max(hyps, key=lambda x: x.score).score)
                kept_most_prob = sorted(
                    [hyp for hyp in kept_hyps if hyp.score > hyps_max],
                    key=lambda x: x.score,
                )
                if len(kept_most_prob) >= self.beam_size:
                    kept_hyps = kept_most_prob
                    break

        return kept_hyps

    def align_length_sync_decoding(
        self,
        enc_out: torch.Tensor,
    ) -> List[Hypothesis]:
        """Alignment-length synchronous beam search implementation.

        Based on https://ieeexplore.ieee.org/document/9053040

        Args:
            enc_out: Encoder output sequences. (T, D)

        Returns:
            nbest_hyps: N-best hypothesis.

        """
        t_max = int(enc_out.size(0))
        u_max = min(self.u_max, (t_max - 1))

        B = [Hypothesis(yseq=[0], score=0.0, dec_state=self.decoder.init_state(1))]
        final = []

        if self.use_lm:
            B[0].lm_state = self.lm.zero_state()

        # Iterate over alignment length i = t + u (time frames + emitted labels).
        for i in range(t_max + u_max):
            A = []

            B_ = []
            B_enc_out = []
            for hyp in B:
                u = len(hyp.yseq) - 1
                t = i - u

                # Skip hypotheses whose implied time index is out of range.
                if t > (t_max - 1):
                    continue

                B_.append(hyp)
                B_enc_out.append((t, enc_out[t]))

            if B_:
                beam_enc_out = torch.stack([b[1] for b in B_enc_out])
                beam_dec_out, beam_state = self.decoder.batch_score(B_)

                beam_logp = torch.log_softmax(
                    self.joint_network(beam_enc_out, beam_dec_out),
                    dim=-1,
                )
                beam_topk = beam_logp[:, 1:].topk(self.beam_size, dim=-1)

                if self.use_lm:
                    beam_lm_scores, beam_lm_states = self.lm.batch_score(
                        self.create_lm_batch_inputs([b.yseq for b in B_]),
                        [b.lm_state for b in B_],
                        None,
                    )

                # NOTE(review): this loop variable shadows the outer `i`;
                # harmless as written (outer `i` is re-bound next iteration)
                # but fragile — confirm before refactoring.
                for i, hyp in enumerate(B_):
                    # Blank transition: advance in time without emitting.
                    new_hyp = Hypothesis(
                        score=(hyp.score + float(beam_logp[i, 0])),
                        yseq=hyp.yseq[:],
                        dec_state=hyp.dec_state,
                        lm_state=hyp.lm_state,
                    )

                    A.append(new_hyp)

                    # Hypotheses consuming the last frame are final candidates.
                    if B_enc_out[i][0] == (t_max - 1):
                        final.append(new_hyp)

                    # Label transitions (+1 re-offsets the sliced-off blank).
                    for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                        new_hyp = Hypothesis(
                            score=(hyp.score + float(logp)),
                            yseq=(hyp.yseq[:] + [int(k)]),
                            dec_state=self.decoder.select_state(beam_state, i),
                            lm_state=hyp.lm_state,
                        )

                        if self.use_lm:
                            new_hyp.score += self.lm_weight * beam_lm_scores[i, k]
                            new_hyp.lm_state = beam_lm_states[i]

                        A.append(new_hyp)

                B = sorted(A, key=lambda x: x.score, reverse=True)[: self.beam_size]
                B = self.recombine_hyps(B)

        if final:
            return final

        return B

    def time_sync_decoding(self, enc_out: torch.Tensor) -> List[Hypothesis]:
        """Time synchronous beam search implementation.

        Based on https://ieeexplore.ieee.org/document/9053040

        Args:
            enc_out: Encoder output sequence. (T, D)

        Returns:
            nbest_hyps: N-best hypothesis.

        """
        # Resume from the streaming cache when decoding chunk-by-chunk.
        if self.search_cache is not None:
            B = self.search_cache
        else:
            B = [
                Hypothesis(
                    yseq=[0],
                    score=0.0,
                    dec_state=self.decoder.init_state(1),
                )
            ]

            if self.use_lm:
                B[0].lm_state = self.lm.zero_state()

        for enc_out_t in enc_out:
            A = []
            C = B

            enc_out_t = enc_out_t.unsqueeze(0)

            # Allow up to max_sym_exp label emissions per time frame.
            for v in range(self.max_sym_exp):
                D = []

                beam_dec_out, beam_state = self.decoder.batch_score(C)

                beam_logp = torch.log_softmax(
                    self.joint_network(enc_out_t, beam_dec_out),
                    dim=-1,
                )
                beam_topk = beam_logp[:, 1:].topk(self.beam_size, dim=-1)

                seq_A = [h.yseq for h in A]

                for i, hyp in enumerate(C):
                    if hyp.yseq not in seq_A:
                        # Blank transition creates a new entry in A.
                        A.append(
                            Hypothesis(
                                score=(hyp.score + float(beam_logp[i, 0])),
                                yseq=hyp.yseq[:],
                                dec_state=hyp.dec_state,
                                lm_state=hyp.lm_state,
                            )
                        )
                    else:
                        # Duplicate sequence: merge scores via log-sum-exp.
                        dict_pos = seq_A.index(hyp.yseq)

                        A[dict_pos].score = np.logaddexp(
                            A[dict_pos].score, (hyp.score + float(beam_logp[i, 0]))
                        )

                # Label expansions are skipped on the last expansion step.
                if v < (self.max_sym_exp - 1):
                    if self.use_lm:
                        beam_lm_scores, beam_lm_states = self.lm.batch_score(
                            self.create_lm_batch_inputs([c.yseq for c in C]),
                            [c.lm_state for c in C],
                            None,
                        )

                    for i, hyp in enumerate(C):
                        for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                            new_hyp = Hypothesis(
                                score=(hyp.score + float(logp)),
                                yseq=(hyp.yseq + [int(k)]),
                                dec_state=self.decoder.select_state(beam_state, i),
                                lm_state=hyp.lm_state,
                            )

                            if self.use_lm:
                                new_hyp.score += self.lm_weight * beam_lm_scores[i, k]
                                new_hyp.lm_state = beam_lm_states[i]

                            D.append(new_hyp)

                    C = sorted(D, key=lambda x: x.score, reverse=True)[: self.beam_size]

            B = sorted(A, key=lambda x: x.score, reverse=True)[: self.beam_size]

        return B

    def modified_adaptive_expansion_search(
        self,
        enc_out: torch.Tensor,
    ) -> List[ExtendedHypothesis]:
        """Modified version of Adaptive Expansion Search (mAES).

        Based on AES (https://ieeexplore.ieee.org/document/9250505) and
        NSC (https://arxiv.org/abs/2201.05420).

        Args:
            enc_out: Encoder output sequence. (T, D_enc)

        Returns:
            nbest_hyps: N-best hypothesis.

        """
        # Resume from the streaming cache, or build the initial hypothesis
        # with pre-computed decoder output and optional LM state.
        if self.search_cache is not None:
            kept_hyps = self.search_cache
        else:
            init_tokens = [
                ExtendedHypothesis(
                    yseq=[0],
                    score=0.0,
                    dec_state=self.decoder.init_state(1),
                )
            ]

            beam_dec_out, beam_state = self.decoder.batch_score(
                init_tokens,
            )

            if self.use_lm:
                beam_lm_scores, beam_lm_states = self.lm.batch_score(
                    self.create_lm_batch_inputs([h.yseq for h in init_tokens]),
                    [h.lm_state for h in init_tokens],
                    None,
                )

                lm_state = beam_lm_states[0]
                lm_score = beam_lm_scores[0]
            else:
                lm_state = None
                lm_score = None

            kept_hyps = [
                ExtendedHypothesis(
                    yseq=[0],
                    score=0.0,
                    dec_state=self.decoder.select_state(beam_state, 0),
                    dec_out=beam_dec_out[0],
                    lm_state=lm_state,
                    lm_score=lm_score,
                )
            ]

        for enc_out_t in enc_out:
            hyps = kept_hyps
            kept_hyps = []

            beam_enc_out = enc_out_t.unsqueeze(0)

            # list_b collects blank-ending hypotheses for this frame.
            list_b = []
            for n in range(self.nstep):
                beam_dec_out = torch.stack([h.dec_out for h in hyps])

                beam_logp, beam_idx = torch.log_softmax(
                    self.joint_network(beam_enc_out, beam_dec_out),
                    dim=-1,
                ).topk(self.max_candidates, dim=-1)

                k_expansions = self.select_k_expansions(hyps, beam_idx, beam_logp)

                list_exp = []
                for i, hyp in enumerate(hyps):
                    for k, new_score in k_expansions[i]:
                        new_hyp = ExtendedHypothesis(
                            yseq=hyp.yseq[:],
                            score=new_score,
                            dec_out=hyp.dec_out,
                            dec_state=hyp.dec_state,
                            lm_state=hyp.lm_state,
                            lm_score=hyp.lm_score,
                        )

                        if k == 0:
                            # Blank: hypothesis is done expanding for this frame.
                            list_b.append(new_hyp)
                        else:
                            new_hyp.yseq.append(int(k))

                            if self.use_lm:
                                new_hyp.score += self.lm_weight * float(hyp.lm_score[k])

                            list_exp.append(new_hyp)

                if not list_exp:
                    # No non-blank expansions left: finalize this frame early.
                    kept_hyps = sorted(
                        self.recombine_hyps(list_b), key=lambda x: x.score, reverse=True
                    )[: self.beam_size]

                    break
                else:
                    beam_dec_out, beam_state = self.decoder.batch_score(
                        list_exp,
                    )

                    if self.use_lm:
                        beam_lm_scores, beam_lm_states = self.lm.batch_score(
                            self.create_lm_batch_inputs([h.yseq for h in list_exp]),
                            [h.lm_state for h in list_exp],
                            None,
                        )

                    if n < (self.nstep - 1):
                        # Intermediate step: refresh cached states and keep expanding.
                        for i, hyp in enumerate(list_exp):
                            hyp.dec_out = beam_dec_out[i]
                            hyp.dec_state = self.decoder.select_state(beam_state, i)

                            if self.use_lm:
                                hyp.lm_state = beam_lm_states[i]
                                hyp.lm_score = beam_lm_scores[i]

                        hyps = list_exp[:]
                    else:
                        # Last step: fold in the blank probability and finalize.
                        beam_logp = torch.log_softmax(
                            self.joint_network(beam_enc_out, beam_dec_out),
                            dim=-1,
                        )

                        for i, hyp in enumerate(list_exp):
                            hyp.score += float(beam_logp[i, 0])

                            hyp.dec_out = beam_dec_out[i]
                            hyp.dec_state = self.decoder.select_state(beam_state, i)

                            if self.use_lm:
                                hyp.lm_state = beam_lm_states[i]
                                hyp.lm_score = beam_lm_scores[i]

                        kept_hyps = sorted(
                            self.recombine_hyps(list_b + list_exp),
                            key=lambda x: x.score,
                            reverse=True,
                        )[: self.beam_size]

        return kept_hyps
|
||||
0
funasr/models_transducer/decoder/__init__.py
Normal file
0
funasr/models_transducer/decoder/__init__.py
Normal file
110
funasr/models_transducer/decoder/abs_decoder.py
Normal file
110
funasr/models_transducer/decoder/abs_decoder.py
Normal file
@ -0,0 +1,110 @@
|
||||
"""Abstract decoder definition for Transducer models."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class AbsDecoder(torch.nn.Module, ABC):
    """Abstract decoder module for Transducer models.

    Concrete decoders must implement both the batched training forward and
    the one-step scoring interface (single and batched) used by beam search.
    """

    @abstractmethod
    def forward(self, labels: torch.Tensor) -> torch.Tensor:
        """Encode source label sequences.

        Args:
            labels: Label ID sequences. (B, L)

        Returns:
            dec_out: Decoder output sequences. (B, T, D_dec)

        """
        raise NotImplementedError

    @abstractmethod
    def score(
        self,
        label: torch.Tensor,
        label_sequence: List[int],
        dec_state: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]]:
        """One-step forward hypothesis.

        Args:
            label: Previous label. (1, 1)
            label_sequence: Current label sequence.
            dec_state: Previous decoder hidden states.
                ((N, 1, D_dec), (N, 1, D_dec) or None) or None

        Returns:
            dec_out: Decoder output sequence. (1, D_dec) or (1, D_emb)
            dec_state: Decoder hidden states.
                ((N, 1, D_dec), (N, 1, D_dec) or None) or None

        """
        raise NotImplementedError

    @abstractmethod
    def batch_score(
        self,
        hyps: List[Any],
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]]:
        """One-step forward hypotheses.

        Args:
            hyps: Hypotheses.

        Returns:
            dec_out: Decoder output sequences. (B, D_dec) or (B, D_emb)
            states: Decoder hidden states.
                ((N, B, D_dec), (N, B, D_dec) or None) or None

        """
        raise NotImplementedError

    @abstractmethod
    def set_device(self, device: torch.device) -> None:
        """Set GPU device to use.

        Args:
            device: Device ID.

        """
        # NOTE: annotation fixed from torch.Tensor to torch.device — callers
        # pass a device object, not a tensor.
        raise NotImplementedError

    @abstractmethod
    def init_state(
        self, batch_size: int
    ) -> Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Initialize decoder states.

        Args:
            batch_size: Batch size.

        Returns:
            : Initial decoder hidden states.
                ((N, B, D_dec), (N, B, D_dec) or None) or None

        """
        # NOTE: annotation fixed from torch.tensor (a factory function) to
        # the torch.Tensor type.
        raise NotImplementedError

    @abstractmethod
    def select_state(
        self,
        states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None,
        idx: int = 0,
    ) -> Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Get specified ID state from batch of states, if provided.

        Args:
            states: Decoder hidden states.
                ((N, B, D_dec), (N, B, D_dec) or None) or None
            idx: State ID to extract.

        Returns:
            : Decoder hidden state for given ID.
                ((N, 1, D_dec), (N, 1, D_dec) or None) or None

        """
        raise NotImplementedError
|
||||
259
funasr/models_transducer/decoder/rnn_decoder.py
Normal file
259
funasr/models_transducer/decoder/rnn_decoder.py
Normal file
@ -0,0 +1,259 @@
|
||||
"""RNN decoder definition for Transducer models."""
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from funasr.models_transducer.beam_search_transducer import Hypothesis
|
||||
from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
|
||||
from funasr.models.specaug.specaug import SpecAug
|
||||
|
||||
class RNNDecoder(AbsDecoder):
    """RNN decoder module for Transducer models.

    Args:
        vocab_size: Vocabulary size.
        embed_size: Embedding size.
        hidden_size: Hidden size.
        rnn_type: Decoder layers type ("lstm" or "gru").
        num_layers: Number of decoder layers.
        dropout_rate: Dropout rate for decoder layers.
        embed_dropout_rate: Dropout rate for embedding layer.
        embed_pad: Embedding padding symbol ID.

    """

    def __init__(
        self,
        vocab_size: int,
        embed_size: int = 256,
        hidden_size: int = 256,
        rnn_type: str = "lstm",
        num_layers: int = 1,
        dropout_rate: float = 0.0,
        embed_dropout_rate: float = 0.0,
        embed_pad: int = 0,
    ) -> None:
        """Construct a RNNDecoder object."""
        super().__init__()

        assert check_argument_types()

        if rnn_type not in ("lstm", "gru"):
            raise ValueError(f"Not supported: rnn_type={rnn_type}")

        self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad)
        self.dropout_embed = torch.nn.Dropout(p=embed_dropout_rate)

        rnn_class = torch.nn.LSTM if rnn_type == "lstm" else torch.nn.GRU

        # First layer consumes embeddings; subsequent layers keep hidden size.
        self.rnn = torch.nn.ModuleList(
            [rnn_class(embed_size, hidden_size, 1, batch_first=True)]
        )

        for _ in range(1, num_layers):
            self.rnn += [rnn_class(hidden_size, hidden_size, 1, batch_first=True)]

        self.dropout_rnn = torch.nn.ModuleList(
            [torch.nn.Dropout(p=dropout_rate) for _ in range(num_layers)]
        )

        self.dlayers = num_layers
        self.dtype = rnn_type

        self.output_size = hidden_size
        self.vocab_size = vocab_size

        self.device = next(self.parameters()).device
        # Maps a label-history string to (dec_out, dec_state) so identical
        # hypothesis prefixes are only decoded once during beam search.
        self.score_cache = {}

    def forward(
        self,
        labels: torch.Tensor,
        label_lens: torch.Tensor,
        states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None,
    ) -> torch.Tensor:
        """Encode source label sequences.

        Args:
            labels: Label ID sequences. (B, L)
            label_lens: Label lengths. (B,) — unused here, kept for API parity.
            states: Decoder hidden states.
                ((N, B, D_dec), (N, B, D_dec) or None) or None

        Returns:
            dec_out: Decoder output sequences. (B, U, D_dec)

        """
        if states is None:
            states = self.init_state(labels.size(0))

        dec_embed = self.dropout_embed(self.embed(labels))
        dec_out, states = self.rnn_forward(dec_embed, states)
        return dec_out

    def rnn_forward(
        self,
        x: torch.Tensor,
        state: Tuple[torch.Tensor, Optional[torch.Tensor]],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Run the stacked RNN layers over the embedded inputs.

        Args:
            x: RNN input sequences. (B, D_emb)
            state: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)

        Returns:
            x: RNN output sequences. (B, D_dec)
            (h_next, c_next): Decoder hidden states.
                ((N, B, D_dec), (N, B, D_dec) or None)

        """
        h_prev, c_prev = state
        h_next, c_next = self.init_state(x.size(0))

        for layer in range(self.dlayers):
            # Each ModuleList entry is a single-layer RNN; slice keeps the
            # (1, B, D_dec) layout expected by hx.
            if self.dtype == "lstm":
                x, (h_next[layer : layer + 1], c_next[layer : layer + 1]) = self.rnn[
                    layer
                ](x, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1]))
            else:
                x, h_next[layer : layer + 1] = self.rnn[layer](
                    x, hx=h_prev[layer : layer + 1]
                )

            x = self.dropout_rnn[layer](x)

        return x, (h_next, c_next)

    def score(
        self,
        label: torch.Tensor,
        label_sequence: List[int],
        dec_state: Tuple[torch.Tensor, Optional[torch.Tensor]],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """One-step forward hypothesis.

        Args:
            label: Previous label. (1, 1)
            label_sequence: Current label sequence.
            dec_state: Previous decoder hidden states.
                ((N, 1, D_dec), (N, 1, D_dec) or None)

        Returns:
            dec_out: Decoder output sequence. (1, D_dec)
            dec_state: Decoder hidden states.
                ((N, 1, D_dec), (N, 1, D_dec) or None)

        """
        str_labels = "_".join(map(str, label_sequence))

        if str_labels in self.score_cache:
            # Cache hit: the supplied dec_state is ignored and the cached
            # output/state for this exact label history is reused.
            dec_out, dec_state = self.score_cache[str_labels]
        else:
            dec_embed = self.embed(label)
            dec_out, dec_state = self.rnn_forward(dec_embed, dec_state)

            self.score_cache[str_labels] = (dec_out, dec_state)

        return dec_out[0], dec_state

    def batch_score(
        self,
        hyps: List[Hypothesis],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """One-step forward hypotheses.

        Args:
            hyps: Hypotheses.

        Returns:
            dec_out: Decoder output sequences. (B, D_dec)
            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)

        """
        # torch.tensor instead of the legacy torch.LongTensor constructor:
        # the latter does not accept a `device` keyword for non-CPU devices
        # and raises at runtime.
        labels = torch.tensor(
            [[h.yseq[-1]] for h in hyps], dtype=torch.long, device=self.device
        )
        dec_embed = self.embed(labels)

        states = self.create_batch_states([h.dec_state for h in hyps])
        dec_out, states = self.rnn_forward(dec_embed, states)

        return dec_out.squeeze(1), states

    def set_device(self, device: torch.device) -> None:
        """Set GPU device to use.

        Args:
            device: Device ID.

        """
        self.device = device

    def init_state(
        self, batch_size: int
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Initialize decoder states.

        Args:
            batch_size: Batch size.

        Returns:
            : Initial decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)

        """
        h_n = torch.zeros(
            self.dlayers,
            batch_size,
            self.output_size,
            device=self.device,
        )

        if self.dtype == "lstm":
            # LSTM additionally carries a cell state.
            c_n = torch.zeros(
                self.dlayers,
                batch_size,
                self.output_size,
                device=self.device,
            )

            return (h_n, c_n)

        return (h_n, None)

    def select_state(
        self, states: Tuple[torch.Tensor, Optional[torch.Tensor]], idx: int
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Get specified ID state from decoder hidden states.

        Args:
            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
            idx: State ID to extract.

        Returns:
            : Decoder hidden state for given ID. ((N, 1, D_dec), (N, 1, D_dec) or None)

        """
        return (
            states[0][:, idx : idx + 1, :],
            states[1][:, idx : idx + 1, :] if self.dtype == "lstm" else None,
        )

    def create_batch_states(
        self,
        new_states: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Stack per-hypothesis states into one batched state.

        Args:
            new_states: Decoder hidden states. [N x ((1, D_dec), (1, D_dec) or None)]

        Returns:
            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)

        """
        return (
            torch.cat([s[0] for s in new_states], dim=1),
            torch.cat([s[1] for s in new_states], dim=1)
            if self.dtype == "lstm"
            else None,
        )
|
||||
157
funasr/models_transducer/decoder/stateless_decoder.py
Normal file
157
funasr/models_transducer/decoder/stateless_decoder.py
Normal file
@ -0,0 +1,157 @@
|
||||
"""Stateless decoder definition for Transducer models."""
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from funasr.models_transducer.beam_search_transducer import Hypothesis
|
||||
from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
|
||||
from funasr.models.specaug.specaug import SpecAug
|
||||
|
||||
class StatelessDecoder(AbsDecoder):
    """Stateless Transducer decoder module (embedding-only, no recurrence).

    Args:
        vocab_size: Output size.
        embed_size: Embedding size.
        embed_dropout_rate: Dropout rate for embedding layer.
        embed_pad: Embed/Blank symbol ID.
        use_embed_mask: Whether to apply SpecAug-style time masking on the
            embeddings during training.

    """

    def __init__(
        self,
        vocab_size: int,
        embed_size: int = 256,
        embed_dropout_rate: float = 0.0,
        embed_pad: int = 0,
        use_embed_mask: bool = False,
    ) -> None:
        """Construct a StatelessDecoder object."""
        super().__init__()

        assert check_argument_types()

        self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad)
        # NOTE: despite its name this attribute holds the Dropout module, not
        # the rate; kept as-is for checkpoint/attribute compatibility.
        self.embed_dropout_rate = torch.nn.Dropout(p=embed_dropout_rate)

        self.output_size = embed_size
        self.vocab_size = vocab_size

        self.device = next(self.parameters()).device
        # Maps a label-history string to its embedding so identical
        # hypothesis prefixes are only embedded once during beam search.
        self.score_cache = {}

        self.use_embed_mask = use_embed_mask
        if self.use_embed_mask:
            self._embed_mask = SpecAug(
                time_mask_width_range=3,
                num_time_mask=1,
                apply_freq_mask=False,
                apply_time_warp=False
            )


    def forward(
        self,
        labels: torch.Tensor,
        label_lens: torch.Tensor,
        states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None,
    ) -> torch.Tensor:
        """Encode source label sequences.

        Args:
            labels: Label ID sequences. (B, L)
            label_lens: Label lengths. (B,) — used only for embedding masking.
            states: Decoder hidden states. None (stateless decoder).

        Returns:
            dec_embed: Decoder output sequences. (B, U, D_emb)

        """
        dec_embed = self.embed_dropout_rate(self.embed(labels))
        if self.use_embed_mask and self.training:
            dec_embed = self._embed_mask(dec_embed, label_lens)[0]

        return dec_embed

    def score(
        self,
        label: torch.Tensor,
        label_sequence: List[int],
        state: None,
    ) -> Tuple[torch.Tensor, None]:
        """One-step forward hypothesis.

        Args:
            label: Previous label. (1, 1)
            label_sequence: Current label sequence.
            state: Previous decoder hidden states. None

        Returns:
            dec_out: Decoder output sequence. (1, D_emb)
            state: Decoder hidden states. None

        """
        str_labels = "_".join(map(str, label_sequence))

        if str_labels in self.score_cache:
            dec_embed = self.score_cache[str_labels]
        else:
            dec_embed = self.embed(label)

            self.score_cache[str_labels] = dec_embed

        return dec_embed[0], None

    def batch_score(
        self,
        hyps: List[Hypothesis],
    ) -> Tuple[torch.Tensor, None]:
        """One-step forward hypotheses.

        Args:
            hyps: Hypotheses.

        Returns:
            dec_out: Decoder output sequences. (B, D_emb)
            states: Decoder hidden states. None

        """
        # torch.tensor instead of the legacy torch.LongTensor constructor:
        # the latter does not accept a `device` keyword for non-CPU devices
        # and raises at runtime.
        labels = torch.tensor(
            [[h.yseq[-1]] for h in hyps], dtype=torch.long, device=self.device
        )
        dec_embed = self.embed(labels)

        return dec_embed.squeeze(1), None

    def set_device(self, device: torch.device) -> None:
        """Set GPU device to use.

        Args:
            device: Device ID.

        """
        self.device = device

    def init_state(self, batch_size: int) -> None:
        """Initialize decoder states.

        Args:
            batch_size: Batch size.

        Returns:
            : Initial decoder hidden states. None (stateless decoder).

        """
        return None

    def select_state(self, states: Optional[torch.Tensor], idx: int) -> None:
        """Get specified ID state from decoder hidden states.

        Args:
            states: Decoder hidden states. None
            idx: State ID to extract.

        Returns:
            : Decoder hidden state for given ID. None

        """
        return None
|
||||
0
funasr/models_transducer/encoder/__init__.py
Normal file
0
funasr/models_transducer/encoder/__init__.py
Normal file
0
funasr/models_transducer/encoder/blocks/__init__.py
Normal file
0
funasr/models_transducer/encoder/blocks/__init__.py
Normal file
178
funasr/models_transducer/encoder/blocks/branchformer.py
Normal file
178
funasr/models_transducer/encoder/blocks/branchformer.py
Normal file
@ -0,0 +1,178 @@
|
||||
"""Branchformer block for Transducer encoder."""
|
||||
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class Branchformer(torch.nn.Module):
    """Branchformer module definition.

    Reference: https://arxiv.org/pdf/2207.02971.pdf

    Args:
        block_size: Input/output size.
        linear_size: Linear layers' hidden size.
        self_att: Self-attention module instance.
        conv_mod: Convolution module instance.
        norm_class: Normalization class.
        norm_args: Normalization module arguments (None means no extra args).
        dropout_rate: Dropout rate.

    """

    def __init__(
        self,
        block_size: int,
        linear_size: int,
        self_att: torch.nn.Module,
        conv_mod: torch.nn.Module,
        norm_class: torch.nn.Module = torch.nn.LayerNorm,
        norm_args: Optional[Dict] = None,
        dropout_rate: float = 0.0,
    ) -> None:
        """Construct a Branchformer object."""
        super().__init__()

        # None sentinel instead of a mutable {} default (shared dict across
        # instantiations); behavior is unchanged.
        if norm_args is None:
            norm_args = {}

        self.self_att = self_att
        self.conv_mod = conv_mod

        self.channel_proj1 = torch.nn.Sequential(
            torch.nn.Linear(block_size, linear_size), torch.nn.GELU()
        )
        # channel_proj2 expects linear_size // 2 channels: conv_mod's output
        # is fed here (presumably a gated/CSGU module halving the channels —
        # confirm against the conv module used).
        self.channel_proj2 = torch.nn.Linear(linear_size // 2, block_size)

        self.merge_proj = torch.nn.Linear(block_size + block_size, block_size)

        self.norm_self_att = norm_class(block_size, **norm_args)
        self.norm_mlp = norm_class(block_size, **norm_args)
        self.norm_final = norm_class(block_size, **norm_args)

        self.dropout = torch.nn.Dropout(dropout_rate)

        self.block_size = block_size
        self.linear_size = linear_size
        # Streaming caches, set by reset_streaming_cache / chunk_forward.
        self.cache = None

    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
        """Initialize/Reset self-attention and convolution modules cache for streaming.

        Args:
            left_context: Number of left frames during chunk-by-chunk inference.
            device: Device to use for cache tensor.

        """
        self.cache = [
            torch.zeros(
                (1, left_context, self.block_size),
                device=device,
            ),
            torch.zeros(
                (
                    1,
                    self.linear_size // 2,
                    self.conv_mod.kernel_size - 1,
                ),
                device=device,
            ),
        ]

    def forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode input sequences.

        Args:
            x: Branchformer input sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
            mask: Source mask. (B, T)
            chunk_mask: Chunk mask. (T_2, T_2)

        Returns:
            x: Branchformer output sequences. (B, T, D_block)
            mask: Source mask. (B, T)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)

        """
        # Two parallel branches from the same input: attention and conv/MLP.
        x1 = x
        x2 = x

        x1 = self.norm_self_att(x1)

        x1 = self.dropout(
            self.self_att(x1, x1, x1, pos_enc, mask=mask, chunk_mask=chunk_mask)
        )

        x2 = self.norm_mlp(x2)

        x2 = self.channel_proj1(x2)
        x2, _ = self.conv_mod(x2)
        x2 = self.channel_proj2(x2)

        x2 = self.dropout(x2)

        # Merge both branches and add the residual connection.
        x = x + self.dropout(self.merge_proj(torch.cat([x1, x2], dim=-1)))

        x = self.norm_final(x)

        return x, mask, pos_enc

    def chunk_forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        left_context: int = 0,
        right_context: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode chunk of input sequence.

        Args:
            x: Branchformer input sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
            mask: Source mask. (B, T_2)
            left_context: Number of frames in left context.
            right_context: Number of frames in right context.

        Returns:
            x: Branchformer output sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)

        """
        x1 = x
        x2 = x

        x1 = self.norm_self_att(x1)

        # Prepend cached left-context frames as attention keys/values.
        if left_context > 0:
            key = torch.cat([self.cache[0], x1], dim=1)
        else:
            key = x1
        val = key

        # Keep the last left_context frames (excluding right context) as the
        # attention cache for the next chunk.
        if right_context > 0:
            att_cache = key[:, -(left_context + right_context) : -right_context, :]
        else:
            att_cache = key[:, -left_context:, :]

        x1 = self.self_att(x1, key, val, pos_enc, mask=mask, left_context=left_context)

        x2 = self.norm_mlp(x2)
        x2 = self.channel_proj1(x2)

        x2, conv_cache = self.conv_mod(
            x2, cache=self.cache[1], right_context=right_context
        )

        x2 = self.channel_proj2(x2)

        # NOTE: no dropout here, unlike forward() — this is the inference path.
        x = x + self.merge_proj(torch.cat([x1, x2], dim=-1))

        x = self.norm_final(x)
        self.cache = [att_cache, conv_cache]

        return x, pos_enc
|
||||
198
funasr/models_transducer/encoder/blocks/conformer.py
Normal file
198
funasr/models_transducer/encoder/blocks/conformer.py
Normal file
@ -0,0 +1,198 @@
|
||||
"""Conformer block for Transducer encoder."""
|
||||
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class Conformer(torch.nn.Module):
    """Conformer module definition.

    Args:
        block_size: Input/output size.
        self_att: Self-attention module instance.
        feed_forward: Feed-forward module instance.
        feed_forward_macaron: Feed-forward module instance for macaron network.
        conv_mod: Convolution module instance.
        norm_class: Normalization module class.
        norm_args: Normalization module arguments (None means no extra args).
        dropout_rate: Dropout rate.

    """

    def __init__(
        self,
        block_size: int,
        self_att: torch.nn.Module,
        feed_forward: torch.nn.Module,
        feed_forward_macaron: torch.nn.Module,
        conv_mod: torch.nn.Module,
        norm_class: torch.nn.Module = torch.nn.LayerNorm,
        norm_args: Optional[Dict] = None,
        dropout_rate: float = 0.0,
    ) -> None:
        """Construct a Conformer object."""
        super().__init__()

        # None sentinel instead of a mutable {} default (shared dict across
        # instantiations); behavior is unchanged.
        if norm_args is None:
            norm_args = {}

        self.self_att = self_att

        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        # Macaron-style half-step residual weight for both FFN blocks.
        self.feed_forward_scale = 0.5

        self.conv_mod = conv_mod

        self.norm_feed_forward = norm_class(block_size, **norm_args)
        self.norm_self_att = norm_class(block_size, **norm_args)

        self.norm_macaron = norm_class(block_size, **norm_args)
        self.norm_conv = norm_class(block_size, **norm_args)
        self.norm_final = norm_class(block_size, **norm_args)

        self.dropout = torch.nn.Dropout(dropout_rate)

        self.block_size = block_size
        # Streaming caches, set by reset_streaming_cache / chunk_forward.
        self.cache = None

    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
        """Initialize/Reset self-attention and convolution modules cache for streaming.

        Args:
            left_context: Number of left frames during chunk-by-chunk inference.
            device: Device to use for cache tensor.

        """
        self.cache = [
            torch.zeros(
                (1, left_context, self.block_size),
                device=device,
            ),
            torch.zeros(
                (
                    1,
                    self.block_size,
                    self.conv_mod.kernel_size - 1,
                ),
                device=device,
            ),
        ]

    def forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode input sequences.

        Args:
            x: Conformer input sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
            mask: Source mask. (B, T)
            chunk_mask: Chunk mask. (T_2, T_2)

        Returns:
            x: Conformer output sequences. (B, T, D_block)
            mask: Source mask. (B, T)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)

        """
        # 1) Macaron feed-forward (half-step residual).
        residual = x

        x = self.norm_macaron(x)
        x = residual + self.feed_forward_scale * self.dropout(
            self.feed_forward_macaron(x)
        )

        # 2) Multi-head self-attention.
        residual = x
        x = self.norm_self_att(x)
        x_q = x
        x = residual + self.dropout(
            self.self_att(
                x_q,
                x,
                x,
                pos_enc,
                mask,
                chunk_mask=chunk_mask,
            )
        )

        # 3) Convolution module.
        residual = x

        x = self.norm_conv(x)
        x, _ = self.conv_mod(x)
        x = residual + self.dropout(x)

        # 4) Second feed-forward (half-step residual) and final norm.
        residual = x

        x = self.norm_feed_forward(x)
        x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x))

        x = self.norm_final(x)
        return x, mask, pos_enc

    def chunk_forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int = 16,
        left_context: int = 0,
        right_context: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode chunk of input sequence.

        Args:
            x: Conformer input sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
            mask: Source mask. (B, T_2)
            chunk_size: Part of the streaming API; unused in this block.
            left_context: Number of frames in left context.
            right_context: Number of frames in right context.

        Returns:
            x: Conformer output sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)

        """
        # NOTE: no dropout in this path, unlike forward() — inference only.
        residual = x

        x = self.norm_macaron(x)
        x = residual + self.feed_forward_scale * self.feed_forward_macaron(x)

        residual = x
        x = self.norm_self_att(x)
        # Prepend cached left-context frames as attention keys/values.
        if left_context > 0:
            key = torch.cat([self.cache[0], x], dim=1)
        else:
            key = x
        val = key

        # Keep the last left_context frames (excluding right context) as the
        # attention cache for the next chunk.
        if right_context > 0:
            att_cache = key[:, -(left_context + right_context) : -right_context, :]
        else:
            att_cache = key[:, -left_context:, :]
        x = residual + self.self_att(
            x,
            key,
            val,
            pos_enc,
            mask,
            left_context=left_context,
        )

        residual = x
        x = self.norm_conv(x)
        x, conv_cache = self.conv_mod(
            x, cache=self.cache[1], right_context=right_context
        )
        x = residual + x
        residual = x

        x = self.norm_feed_forward(x)
        x = residual + self.feed_forward_scale * self.feed_forward(x)

        x = self.norm_final(x)
        self.cache = [att_cache, conv_cache]

        return x, pos_enc
|
||||
221
funasr/models_transducer/encoder/blocks/conv1d.py
Normal file
221
funasr/models_transducer/encoder/blocks/conv1d.py
Normal file
@ -0,0 +1,221 @@
|
||||
"""Conv1d block for Transducer encoder."""
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class Conv1d(torch.nn.Module):
|
||||
"""Conv1d module definition.
|
||||
|
||||
Args:
|
||||
input_size: Input dimension.
|
||||
output_size: Output dimension.
|
||||
kernel_size: Size of the convolving kernel.
|
||||
stride: Stride of the convolution.
|
||||
dilation: Spacing between the kernel points.
|
||||
groups: Number of blocked connections from input channels to output channels.
|
||||
bias: Whether to add a learnable bias to the output.
|
||||
batch_norm: Whether to use batch normalization after convolution.
|
||||
relu: Whether to use a ReLU activation after convolution.
|
||||
causal: Whether to use causal convolution (set to True if streaming).
|
||||
dropout_rate: Dropout rate.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
kernel_size: Union[int, Tuple],
|
||||
stride: Union[int, Tuple] = 1,
|
||||
dilation: Union[int, Tuple] = 1,
|
||||
groups: Union[int, Tuple] = 1,
|
||||
bias: bool = True,
|
||||
batch_norm: bool = False,
|
||||
relu: bool = True,
|
||||
causal: bool = False,
|
||||
dropout_rate: float = 0.0,
|
||||
) -> None:
|
||||
"""Construct a Conv1d object."""
|
||||
super().__init__()
|
||||
|
||||
if causal:
|
||||
self.lorder = kernel_size - 1
|
||||
stride = 1
|
||||
else:
|
||||
self.lorder = 0
|
||||
stride = stride
|
||||
|
||||
self.conv = torch.nn.Conv1d(
|
||||
input_size,
|
||||
output_size,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
||||
|
||||
if relu:
|
||||
self.relu_func = torch.nn.ReLU()
|
||||
|
||||
if batch_norm:
|
||||
self.bn = torch.nn.BatchNorm1d(output_size)
|
||||
|
||||
self.out_pos = torch.nn.Linear(input_size, output_size)
|
||||
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
|
||||
self.relu = relu
|
||||
self.batch_norm = batch_norm
|
||||
self.causal = causal
|
||||
|
||||
self.kernel_size = kernel_size
|
||||
self.padding = dilation * (kernel_size - 1)
|
||||
self.stride = stride
|
||||
|
||||
self.cache = None
|
||||
|
||||
def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
|
||||
"""Initialize/Reset Conv1d cache for streaming.
|
||||
|
||||
Args:
|
||||
left_context: Number of left frames during chunk-by-chunk inference.
|
||||
device: Device to use for cache tensor.
|
||||
|
||||
"""
|
||||
self.cache = torch.zeros(
|
||||
(1, self.input_size, self.kernel_size - 1), device=device
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
pos_enc: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
chunk_mask: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Encode input sequences.
|
||||
|
||||
Args:
|
||||
x: Conv1d input sequences. (B, T, D_in)
|
||||
pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_in)
|
||||
mask: Source mask. (B, T)
|
||||
chunk_mask: Chunk mask. (T_2, T_2)
|
||||
|
||||
Returns:
|
||||
x: Conv1d output sequences. (B, sub(T), D_out)
|
||||
mask: Source mask. (B, T) or (B, sub(T))
|
||||
pos_enc: Positional embedding sequences.
|
||||
(B, 2 * (T - 1), D_att) or (B, 2 * (sub(T) - 1), D_out)
|
||||
|
||||
"""
|
||||
x = x.transpose(1, 2)
|
||||
|
||||
if self.lorder > 0:
|
||||
x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
|
||||
else:
|
||||
mask = self.create_new_mask(mask)
|
||||
pos_enc = self.create_new_pos_enc(pos_enc)
|
||||
|
||||
x = self.conv(x)
|
||||
|
||||
if self.batch_norm:
|
||||
x = self.bn(x)
|
||||
|
||||
x = self.dropout(x)
|
||||
|
||||
if self.relu:
|
||||
x = self.relu_func(x)
|
||||
|
||||
x = x.transpose(1, 2)
|
||||
|
||||
return x, mask, self.out_pos(pos_enc)
|
||||
|
||||
def chunk_forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
pos_enc: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
left_context: int = 0,
|
||||
right_context: int = 0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Encode chunk of input sequence.
|
||||
|
||||
Args:
|
||||
x: Conv1d input sequences. (B, T, D_in)
|
||||
pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_in)
|
||||
mask: Source mask. (B, T)
|
||||
left_context: Number of frames in left context.
|
||||
right_context: Number of frames in right context.
|
||||
|
||||
Returns:
|
||||
x: Conv1d output sequences. (B, T, D_out)
|
||||
pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_out)
|
||||
|
||||
"""
|
||||
x = torch.cat([self.cache, x.transpose(1, 2)], dim=2)
|
||||
|
||||
if right_context > 0:
|
||||
self.cache = x[:, :, -(self.lorder + right_context) : -right_context]
|
||||
else:
|
||||
self.cache = x[:, :, -self.lorder :]
|
||||
|
||||
x = self.conv(x)
|
||||
|
||||
if self.batch_norm:
|
||||
x = self.bn(x)
|
||||
|
||||
x = self.dropout(x)
|
||||
|
||||
if self.relu:
|
||||
x = self.relu_func(x)
|
||||
|
||||
x = x.transpose(1, 2)
|
||||
|
||||
return x, self.out_pos(pos_enc)
|
||||
|
||||
def create_new_mask(self, mask: torch.Tensor) -> torch.Tensor:
    """Subsample the source mask to match the Conv1d output length.

    Args:
        mask: Mask of input sequences. (B, T)

    Returns:
        mask: Mask of output sequences. (B, sub(T))

    """
    # Drop the frames lost to padding, then keep every stride-th frame.
    trimmed = mask[:, : -self.padding] if self.padding != 0 else mask
    return trimmed[:, :: self.stride]
||||
def create_new_pos_enc(self, pos_enc: torch.Tensor) -> torch.Tensor:
    """Subsample the relative positional embedding to the new time length.

    Args:
        pos_enc: Input sequences positional embedding.
                 (B, 2 * (T - 1), D_in)

    Returns:
        pos_enc: Output sequences positional embedding.
                 (B, 2 * (sub(T) - 1), D_in)

    """
    # Relative positions are laid out as [positive offsets | negative
    # offsets], sharing the zero-offset entry at the center.
    center = pos_enc.size(1) // 2
    positive = pos_enc[:, : center + 1, :]
    negative = pos_enc[:, center:, :]

    if self.padding != 0:
        positive = positive[:, : -self.padding, :]
        negative = negative[:, : -self.padding, :]

    positive = positive[:, :: self.stride, :]
    negative = negative[:, :: self.stride, :]

    # Re-join both halves, dropping the duplicated zero-offset entry
    # from the negative half.
    return torch.cat([positive, negative[:, 1:, :]], dim=1)
226
funasr/models_transducer/encoder/blocks/conv_input.py
Normal file
226
funasr/models_transducer/encoder/blocks/conv_input.py
Normal file
@ -0,0 +1,226 @@
|
||||
"""ConvInput block for Transducer encoder."""
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import math
|
||||
|
||||
from funasr.models_transducer.utils import sub_factor_to_params, pad_to_len
|
||||
|
||||
|
||||
class ConvInput(torch.nn.Module):
    """ConvInput module definition.

    Convolutional front-end that subsamples input features along time before
    the encoder body. Two variants: a VGG-like stack (two Conv2d pairs, each
    followed by max pooling) or a plain two-layer Conv2d stack.

    Args:
        input_size: Input size.
        conv_size: Convolution size.
        subsampling_factor: Subsampling factor.
        vgg_like: Whether to use a VGG-like network.
        output_size: Block output dimension.

    """

    def __init__(
        self,
        input_size: int,
        conv_size: Union[int, Tuple],
        subsampling_factor: int = 4,
        vgg_like: bool = True,
        output_size: Optional[int] = None,
    ) -> None:
        """Construct a ConvInput object."""
        super().__init__()
        if vgg_like:
            if subsampling_factor == 1:
                # VGG variant without time subsampling: both poolings act only
                # on the feature axis (pool kernel (1, 2)).
                conv_size1, conv_size2 = conv_size

                self.conv = torch.nn.Sequential(
                    torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.MaxPool2d((1, 2)),
                    torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.MaxPool2d((1, 2)),
                )

                # Feature axis is halved twice by the two (*, 2) poolings.
                output_proj = conv_size2 * ((input_size // 2) // 2)

                self.subsampling_factor = 1

                self.stride_1 = 1

                self.create_new_mask = self.create_new_vgg_mask

            else:
                conv_size1, conv_size2 = conv_size

                # Time pooling is split: kernel_1 in the first pool, 2 in the
                # second, so total time subsampling = kernel_1 * 2 =
                # subsampling_factor.
                kernel_1 = int(subsampling_factor / 2)

                self.conv = torch.nn.Sequential(
                    torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.MaxPool2d((kernel_1, 2)),
                    torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.MaxPool2d((2, 2)),
                )

                output_proj = conv_size2 * ((input_size // 2) // 2)

                self.subsampling_factor = subsampling_factor

                self.create_new_mask = self.create_new_vgg_mask

                self.stride_1 = kernel_1

        else:
            if subsampling_factor == 1:
                # Plain Conv2d variant without time subsampling: stride [1, 2]
                # keeps time resolution while halving the feature axis.
                self.conv = torch.nn.Sequential(
                    torch.nn.Conv2d(1, conv_size, 3, [1,2], [1,0]),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size, conv_size, 3, [1,2], [1,0]),
                    torch.nn.ReLU(),
                )

                output_proj = conv_size * (((input_size - 1) // 2 - 1) // 2)

                self.subsampling_factor = subsampling_factor
                self.kernel_2 = 3
                self.stride_2 = 1

                self.create_new_mask = self.create_new_conv2d_mask

            else:
                # Derive the second convolution's kernel/stride (and resulting
                # feature size) from the requested subsampling factor.
                kernel_2, stride_2, conv_2_output_size = sub_factor_to_params(
                    subsampling_factor,
                    input_size,
                )

                self.conv = torch.nn.Sequential(
                    torch.nn.Conv2d(1, conv_size, 3, 2),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2),
                    torch.nn.ReLU(),
                )

                output_proj = conv_size * conv_2_output_size

                self.subsampling_factor = subsampling_factor
                self.kernel_2 = kernel_2
                self.stride_2 = stride_2

                self.create_new_mask = self.create_new_conv2d_mask

        self.vgg_like = vgg_like
        # Minimum number of input frames this front-end can process.
        self.min_frame_length = 2

        if output_size is not None:
            # Optional linear projection to the encoder dimension.
            self.output = torch.nn.Linear(output_proj, output_size)
            self.output_size = output_size
        else:
            self.output = None
            self.output_size = output_proj

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor], chunk_size: Optional[int]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode input sequences.

        Args:
            x: ConvInput input sequences. (B, T, D_feats)
            mask: Mask of input sequences. (B, 1, T)
            chunk_size: If given, the time axis is padded to a multiple of
                (chunk_size * subsampling_factor) and convolved chunk-wise.

        Returns:
            x: ConvInput output sequences. (B, sub(T), D_out)
            mask: Mask of output sequences. (B, 1, sub(T))

        """
        # NOTE(review): `olens` is only bound when `mask is not None`, but it
        # is read unconditionally below — callers appear to always pass a
        # mask; confirm before relying on mask=None.
        if mask is not None:
            mask = self.create_new_mask(mask)
            # Longest unmasked output length in the batch. Assumes mask is
            # nonzero at padded positions — TODO confirm against
            # make_source_mask.
            olens = max(mask.eq(0).sum(1))

        b, t_input, f = x.size()
        x = x.unsqueeze(1) # (b. 1. t. f)
        if chunk_size is not None:
            # Pad time to a whole number of (chunk * subsampling) windows so
            # the batch can be reshaped into independent chunks.
            max_input_length = int(
                chunk_size * self.subsampling_factor * (math.ceil(float(t_input) / (chunk_size * self.subsampling_factor) ))
            )
            x = map(lambda inputs: pad_to_len(inputs, max_input_length, 1), x)
            x = list(x)
            x = torch.stack(x, dim=0)
            N_chunks = max_input_length // ( chunk_size * self.subsampling_factor)
            # Fold chunks into the batch axis so the convolution cannot look
            # across chunk boundaries.
            x = x.view(b * N_chunks, 1, chunk_size * self.subsampling_factor, f)
        x = self.conv(x)

        _, c, t, f = x.size()

        if chunk_size is not None:
            # Unfold chunks back into the time axis and trim padding frames.
            x = x.transpose(1, 2).contiguous().view(b, -1, c * f)[:,:olens,:]
        else:
            x = x.transpose(1, 2).contiguous().view(b, t, c * f)

        if self.output is not None:
            x = self.output(x)

        # Trim the mask to the actual produced length.
        return x, mask[:,:olens][:,:x.size(1)]

    def create_new_vgg_mask(self, mask: torch.Tensor) -> torch.Tensor:
        """Create a new mask for VGG output sequences.

        Mirrors the two time poolings of the VGG stack: subsample by
        (subsampling_factor // 2), then by 2.

        Args:
            mask: Mask of input sequences. (B, T)

        Returns:
            mask: Mask of output sequences. (B, sub(T))

        """
        if self.subsampling_factor > 1:
            # First pooling: truncate to a multiple of the pool kernel, then
            # take every kernel-th frame.
            vgg1_t_len = mask.size(1) - (mask.size(1) % (self.subsampling_factor // 2 ))
            mask = mask[:, :vgg1_t_len][:, ::self.subsampling_factor // 2]

            # Second pooling: same scheme with kernel 2.
            vgg2_t_len = mask.size(1) - (mask.size(1) % 2)
            mask = mask[:, :vgg2_t_len][:, ::2]
        else:
            mask = mask

        return mask

    def create_new_conv2d_mask(self, mask: torch.Tensor) -> torch.Tensor:
        """Create new conformer mask for Conv2d output sequences.

        Mirrors the two strided convolutions: stride 2 with kernel 3, then
        stride_2 with kernel_2 (no padding in either).

        Args:
            mask: Mask of input sequences. (B, T)

        Returns:
            mask: Mask of output sequences. (B, sub(T))

        """
        if self.subsampling_factor > 1:
            return mask[:, :-2:2][:, : -(self.kernel_2 - 1) : self.stride_2]
        else:
            return mask

    def get_size_before_subsampling(self, size: int) -> int:
        """Return the original size before subsampling for a given size.

        Args:
            size: Number of frames after subsampling.

        Returns:
            : Number of frames before subsampling.

        """
        if self.subsampling_factor > 1:
            if self.vgg_like:
                # Invert the two poolings (factor 2 then stride_1).
                return ((size * 2) * self.stride_1) + 1

            # Invert the two strided convolutions.
            return ((size + 2) * 2) + (self.kernel_2 - 1) * self.stride_2
        return size
||||
52
funasr/models_transducer/encoder/blocks/linear_input.py
Normal file
52
funasr/models_transducer/encoder/blocks/linear_input.py
Normal file
@ -0,0 +1,52 @@
|
||||
"""LinearInput block for Transducer encoder."""
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
class LinearInput(torch.nn.Module):
    """LinearInput module definition.

    Linear front-end that projects input features to the encoder dimension.
    No actual time subsampling is performed.

    Args:
        input_size: Input size.
        output_size: Block output dimension.
        subsampling_factor: Subsampling factor (kept for interface parity
            with ConvInput; no subsampling is applied).

    """

    def __init__(
        self,
        input_size: int,
        output_size: Optional[int] = None,
        subsampling_factor: int = 1,
    ) -> None:
        """Construct a LinearInput object."""
        super().__init__()

        self.embed = torch.nn.Sequential(
            torch.nn.Linear(input_size, output_size),
            torch.nn.LayerNorm(output_size),
            torch.nn.Dropout(0.1),
        )

        self.subsampling_factor = subsampling_factor
        # Minimum number of input frames this front-end can process.
        self.min_frame_length = 1

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Project input sequences to the encoder dimension.

        Args:
            x: Input sequences. (B, T, D_feats)
            mask: Mask of input sequences. (B, 1, T)

        Returns:
            x: Projected sequences. (B, T, D_out)
            mask: Unchanged input mask.

        """
        return self.embed(x), mask

    def get_size_before_subsampling(self, size: int) -> int:
        """Return the original size before subsampling for a given size.

        Args:
            size: Number of frames after subsampling.

        Returns:
            : Number of frames before subsampling (identity here).

        """
        return size
||||
352
funasr/models_transducer/encoder/building.py
Normal file
352
funasr/models_transducer/encoder/building.py
Normal file
@ -0,0 +1,352 @@
|
||||
"""Set of methods to build Transducer encoder architecture."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from funasr.models_transducer.activation import get_activation
|
||||
from funasr.models_transducer.encoder.blocks.branchformer import Branchformer
|
||||
from funasr.models_transducer.encoder.blocks.conformer import Conformer
|
||||
from funasr.models_transducer.encoder.blocks.conv1d import Conv1d
|
||||
from funasr.models_transducer.encoder.blocks.conv_input import ConvInput
|
||||
from funasr.models_transducer.encoder.blocks.linear_input import LinearInput
|
||||
from funasr.models_transducer.encoder.modules.attention import ( # noqa: H301
|
||||
RelPositionMultiHeadedAttention,
|
||||
)
|
||||
from funasr.models_transducer.encoder.modules.convolution import ( # noqa: H301
|
||||
ConformerConvolution,
|
||||
ConvolutionalSpatialGatingUnit,
|
||||
)
|
||||
from funasr.models_transducer.encoder.modules.multi_blocks import MultiBlocks
|
||||
from funasr.models_transducer.encoder.modules.normalization import get_normalization
|
||||
from funasr.models_transducer.encoder.modules.positional_encoding import ( # noqa: H301
|
||||
RelPositionalEncoding,
|
||||
)
|
||||
from funasr.modules.positionwise_feed_forward import (
|
||||
PositionwiseFeedForward,
|
||||
)
|
||||
|
||||
|
||||
def build_main_parameters(
    pos_wise_act_type: str = "swish",
    conv_mod_act_type: str = "swish",
    pos_enc_dropout_rate: float = 0.0,
    pos_enc_max_len: int = 5000,
    simplified_att_score: bool = False,
    norm_type: str = "layer_norm",
    conv_mod_norm_type: str = "layer_norm",
    after_norm_eps: Optional[float] = None,
    after_norm_partial: Optional[float] = None,
    dynamic_chunk_training: bool = False,
    short_chunk_threshold: float = 0.75,
    short_chunk_size: int = 25,
    left_chunk_size: int = 0,
    time_reduction_factor: int = 1,
    unified_model_training: bool = False,
    default_chunk_size: int = 16,
    jitter_range: int = 4,
    **activation_parameters,
) -> Dict[str, Any]:
    """Build encoder main parameters.

    Args:
        pos_wise_act_type: Conformer position-wise feed-forward activation type.
        conv_mod_act_type: Conformer convolution module activation type.
        pos_enc_dropout_rate: Positional encoding dropout rate.
        pos_enc_max_len: Positional encoding maximum length.
        simplified_att_score: Whether to use simplified attention score computation.
        norm_type: X-former normalization module type.
        conv_mod_norm_type: Conformer convolution module normalization type.
        after_norm_eps: Epsilon value for the final normalization.
        after_norm_partial: Value for the final normalization with RMSNorm.
        dynamic_chunk_training: Whether to use dynamic chunk training.
        short_chunk_threshold: Threshold for dynamic chunk selection.
        short_chunk_size: Minimum number of frames during dynamic chunk training.
        left_chunk_size: Number of frames in left context.
        time_reduction_factor: Time subsampling applied after the encoder body.
        unified_model_training: Whether to train full-utterance and chunked
            passes jointly.
        default_chunk_size: Base chunk size for unified training.
        jitter_range: Random jitter applied to the chunk size.
        **activation_parameters: Parameters of the activation functions.
            (See espnet2/asr_transducer/activation.py)

    Returns:
        : Main encoder parameters

    """
    after_norm_class, after_norm_args = get_normalization(
        norm_type, eps=after_norm_eps, partial=after_norm_partial
    )

    # Sizes/counts are clamped to be non-negative; types and semantics of
    # each key match what Encoder.__init__ reads back out.
    return {
        "pos_wise_act": get_activation(pos_wise_act_type, **activation_parameters),
        "conv_mod_act": get_activation(conv_mod_act_type, **activation_parameters),
        "pos_enc_dropout_rate": pos_enc_dropout_rate,
        "pos_enc_max_len": pos_enc_max_len,
        "simplified_att_score": simplified_att_score,
        "norm_type": norm_type,
        "conv_mod_norm_type": conv_mod_norm_type,
        "after_norm_class": after_norm_class,
        "after_norm_args": after_norm_args,
        "dynamic_chunk_training": dynamic_chunk_training,
        "short_chunk_threshold": max(0, short_chunk_threshold),
        "short_chunk_size": max(0, short_chunk_size),
        "left_chunk_size": max(0, left_chunk_size),
        "unified_model_training": unified_model_training,
        "default_chunk_size": max(0, default_chunk_size),
        "jitter_range": max(0, jitter_range),
        "time_reduction_factor": time_reduction_factor,
    }
|
||||
def build_positional_encoding(
    block_size: int, configuration: Dict[str, Any]
) -> RelPositionalEncoding:
    """Build positional encoding block.

    Args:
        block_size: Input/output size.
        configuration: Positional encoding configuration.

    Returns:
        : Positional encoding module.

    """
    dropout_rate = configuration.get("pos_enc_dropout_rate", 0.0)
    max_len = configuration.get("pos_enc_max_len", 5000)

    return RelPositionalEncoding(block_size, dropout_rate, max_len=max_len)
|
||||
|
||||
def build_input_block(
    input_size: int,
    configuration: Dict[str, Union[str, int]],
) -> ConvInput:
    """Build encoder input block.

    Args:
        input_size: Input size.
        configuration: Input block configuration.

    Returns:
        : ConvInput block function.

    """
    # A purely linear front-end skips convolutional subsampling entirely.
    if configuration["linear"]:
        return LinearInput(
            input_size,
            configuration["output_size"],
            configuration["subsampling_factor"],
        )

    return ConvInput(
        input_size,
        configuration["conv_size"],
        configuration["subsampling_factor"],
        vgg_like=configuration["vgg_like"],
        output_size=configuration["output_size"],
    )
|
||||
|
||||
def build_branchformer_block(
    configuration: List[Dict[str, Any]],
    main_params: Dict[str, Any],
) -> Conformer:
    """Build Branchformer block.

    Args:
        configuration: Branchformer block configuration.
        main_params: Encoder main parameters.

    Returns:
        : Branchformer block function.

    """
    hidden_size = configuration["hidden_size"]
    linear_size = configuration["linear_size"]
    dropout_rate = configuration.get("dropout_rate", 0.0)

    norm_class, norm_args = get_normalization(
        main_params["norm_type"],
        eps=configuration.get("norm_eps"),
        partial=configuration.get("norm_partial"),
    )

    conv_norm_class, conv_norm_args = get_normalization(
        main_params["conv_mod_norm_type"],
        eps=configuration.get("conv_mod_norm_eps"),
        partial=configuration.get("conv_mod_norm_partial"),
    )

    def _build() -> Branchformer:
        # A fresh attention/conv module per call so repeated blocks do not
        # share parameters.
        return Branchformer(
            hidden_size,
            linear_size,
            RelPositionMultiHeadedAttention(
                configuration.get("heads", 4),
                hidden_size,
                configuration.get("att_dropout_rate", 0.0),
                main_params["simplified_att_score"],
            ),
            ConvolutionalSpatialGatingUnit(
                linear_size,
                configuration["conv_mod_kernel_size"],
                conv_norm_class,
                conv_norm_args,
                dropout_rate,
                main_params["dynamic_chunk_training"],
            ),
            norm_class=norm_class,
            norm_args=norm_args,
            dropout_rate=dropout_rate,
        )

    return _build
|
||||
|
||||
def build_conformer_block(
    configuration: List[Dict[str, Any]],
    main_params: Dict[str, Any],
) -> Conformer:
    """Build Conformer block.

    Args:
        configuration: Conformer block configuration.
        main_params: Encoder main parameters.

    Returns:
        : Conformer block function.

    """
    hidden_size = configuration["hidden_size"]
    linear_size = configuration["linear_size"]

    norm_class, norm_args = get_normalization(
        main_params["norm_type"],
        eps=configuration.get("norm_eps"),
        partial=configuration.get("norm_partial"),
    )

    conv_mod_norm_args = {
        "eps": configuration.get("conv_mod_norm_eps", 1e-05),
        "momentum": configuration.get("conv_mod_norm_momentum", 0.1),
    }

    # Causal convolution is required for any chunked training mode.
    causal_conv = (
        main_params["dynamic_chunk_training"] or main_params["unified_model_training"]
    )

    def _pos_wise() -> PositionwiseFeedForward:
        # One instance per call: the macaron-style block needs two distinct
        # feed-forward modules.
        return PositionwiseFeedForward(
            hidden_size,
            linear_size,
            configuration.get("pos_wise_dropout_rate", 0.0),
            main_params["pos_wise_act"],
        )

    def _build() -> Conformer:
        return Conformer(
            hidden_size,
            RelPositionMultiHeadedAttention(
                configuration.get("heads", 4),
                hidden_size,
                configuration.get("att_dropout_rate", 0.0),
                main_params["simplified_att_score"],
            ),
            _pos_wise(),
            _pos_wise(),
            ConformerConvolution(
                hidden_size,
                configuration["conv_mod_kernel_size"],
                main_params["conv_mod_act"],
                conv_mod_norm_args,
                causal_conv,
            ),
            norm_class=norm_class,
            norm_args=norm_args,
            dropout_rate=configuration.get("dropout_rate", 0.0),
        )

    return _build
|
||||
|
||||
def build_conv1d_block(
    configuration: List[Dict[str, Any]],
    causal: bool,
) -> Conv1d:
    """Build Conv1d block.

    Args:
        configuration: Conv1d block configuration.
        causal: Whether the convolution should be causal.

    Returns:
        : Conv1d block function.

    """

    def _build() -> Conv1d:
        return Conv1d(
            configuration["input_size"],
            configuration["output_size"],
            configuration["kernel_size"],
            stride=configuration.get("stride", 1),
            dilation=configuration.get("dilation", 1),
            groups=configuration.get("groups", 1),
            bias=configuration.get("bias", True),
            relu=configuration.get("relu", True),
            batch_norm=configuration.get("batch_norm", False),
            causal=causal,
            dropout_rate=configuration.get("dropout_rate", 0.0),
        )

    return _build
|
||||
|
||||
def build_body_blocks(
    configuration: List[Dict[str, Any]],
    main_params: Dict[str, Any],
    output_size: int,
) -> MultiBlocks:
    """Build encoder body blocks.

    Args:
        configuration: Body blocks configuration.
        main_params: Encoder main parameters.
        output_size: Architecture output size.

    Returns:
        MultiBlocks function encapsulation all encoder blocks.

    """
    # Expand every entry carrying a "num_blocks" count into that many copies.
    expanded = []
    for conf in configuration:
        repeat = conf.get("num_blocks")
        if repeat is not None:
            body = {key: conf[key] for key in conf if key != "num_blocks"}
            expanded.extend(repeat * [body])
        else:
            expanded.append(conf)

    modules = []
    for conf in expanded:
        block_type = conf["block_type"]

        if block_type == "branchformer":
            builder = build_branchformer_block(conf, main_params)
        elif block_type == "conformer":
            builder = build_conformer_block(conf, main_params)
        elif block_type == "conv1d":
            builder = build_conv1d_block(conf, main_params["dynamic_chunk_training"])
        else:
            raise NotImplementedError

        modules.append(builder())

    return MultiBlocks(
        modules,
        output_size,
        norm_class=main_params["after_norm_class"],
        norm_args=main_params["after_norm_args"],
    )
||||
294
funasr/models_transducer/encoder/encoder.py
Normal file
294
funasr/models_transducer/encoder/encoder.py
Normal file
@ -0,0 +1,294 @@
|
||||
"""Encoder for Transducer model."""
|
||||
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from funasr.models_transducer.encoder.building import (
|
||||
build_body_blocks,
|
||||
build_input_block,
|
||||
build_main_parameters,
|
||||
build_positional_encoding,
|
||||
)
|
||||
from funasr.models_transducer.encoder.validation import validate_architecture
|
||||
from funasr.models_transducer.utils import (
|
||||
TooShortUttError,
|
||||
check_short_utt,
|
||||
make_chunk_mask,
|
||||
make_source_mask,
|
||||
)
|
||||
|
||||
|
||||
class Encoder(torch.nn.Module):
    """Encoder module definition.

    Wraps an input (subsampling) block, relative positional encoding and a
    stack of body blocks; supports full-utterance, dynamic-chunk, unified
    (utterance + chunk) training and streaming chunk-by-chunk inference.

    Args:
        input_size: Input size.
        body_conf: Encoder body configuration.
        input_conf: Encoder input configuration.
        main_conf: Encoder main configuration.

    """

    def __init__(
        self,
        input_size: int,
        body_conf: List[Dict[str, Any]],
        input_conf: Dict[str, Any] = {},
        main_conf: Dict[str, Any] = {},
    ) -> None:
        """Construct an Encoder object."""
        super().__init__()

        assert check_argument_types()

        # Validate block chaining and derive embed/output dimensions.
        embed_size, output_size = validate_architecture(
            input_conf, body_conf, input_size
        )
        main_params = build_main_parameters(**main_conf)

        self.embed = build_input_block(input_size, input_conf)
        self.pos_enc = build_positional_encoding(embed_size, main_params)
        self.encoders = build_body_blocks(body_conf, main_params, output_size)

        self.output_size = output_size

        # Dynamic chunk training: a random chunk size is drawn per forward.
        self.dynamic_chunk_training = main_params["dynamic_chunk_training"]
        self.short_chunk_threshold = main_params["short_chunk_threshold"]
        self.short_chunk_size = main_params["short_chunk_size"]
        self.left_chunk_size = main_params["left_chunk_size"]

        # Unified training: full-utterance and chunked passes run jointly.
        self.unified_model_training = main_params["unified_model_training"]
        self.default_chunk_size = main_params["default_chunk_size"]
        self.jitter_range = main_params["jitter_range"]

        # Extra time subsampling applied to encoder outputs.
        self.time_reduction_factor = main_params["time_reduction_factor"]

    def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
        """Return the corresponding number of sample for a given chunk size, in frames.

        Where size is the number of features frames after applying subsampling.

        Args:
            size: Number of frames after subsampling.
            hop_length: Frontend's hop length

        Returns:
            : Number of raw samples

        """
        return self.embed.get_size_before_subsampling(size) * hop_length

    def get_encoder_input_size(self, size: int) -> int:
        """Return the corresponding number of frames before subsampling.

        Args:
            size: Number of frames after subsampling.

        Returns:
            : Number of feature frames before subsampling.

        """
        return self.embed.get_size_before_subsampling(size)


    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
        """Initialize/Reset encoder streaming cache.

        Args:
            left_context: Number of frames in left context.
            device: Device ID.

        """
        return self.encoders.reset_streaming_cache(left_context, device)

    def forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode input sequences.

        Args:
            x: Encoder input features. (B, T_in, F)
            x_len: Encoder input features lengths. (B,)

        Returns:
            x: Encoder outputs. (B, T_out, D_enc)
            x_len: Encoder outputs lengths. (B,)

        NOTE(review): in the unified-training branch this returns a 3-tuple
        (x_utt, x_chunk, olens) rather than the annotated 2-tuple — callers
        must branch on `unified_model_training`; confirm against the model.

        """
        short_status, limit_size = check_short_utt(
            self.embed.subsampling_factor, x.size(1)
        )

        if short_status:
            raise TooShortUttError(
                f"has {x.size(1)} frames and is too short for subsampling "
                + f"(it needs more than {limit_size} frames), return empty results",
                x.size(1),
                limit_size,
            )

        mask = make_source_mask(x_len)
        if self.unified_model_training:
            x, mask = self.embed(x, mask, self.default_chunk_size)
        else:
            x, mask = self.embed(x, mask)
        pos_enc = self.pos_enc(x)

        if self.unified_model_training:
            # Jitter the chunk size so the model does not overfit one size.
            chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
            chunk_mask = make_chunk_mask(
                x.size(1),
                chunk_size,
                left_chunk_size=self.left_chunk_size,
                device=x.device,
            )
            # Full-utterance pass (no chunk restriction).
            x_utt = self.encoders(
                x,
                pos_enc,
                mask,
                chunk_mask=None,
            )
            # Chunk-restricted pass over the same inputs.
            x_chunk = self.encoders(
                x,
                pos_enc,
                mask,
                chunk_mask=chunk_mask,
            )

            # Output lengths; assumes mask is nonzero at padded positions —
            # TODO confirm against make_source_mask.
            olens = mask.eq(0).sum(1)
            if self.time_reduction_factor > 1:
                x_utt = x_utt[:,::self.time_reduction_factor,:]
                x_chunk = x_chunk[:,::self.time_reduction_factor,:]
                olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1

            return x_utt, x_chunk, olens

        elif self.dynamic_chunk_training:
            max_len = x.size(1)
            chunk_size = torch.randint(1, max_len, (1,)).item()

            # With probability ~short_chunk_threshold use the full utterance,
            # otherwise fold into a short chunk size in [1, short_chunk_size].
            if chunk_size > (max_len * self.short_chunk_threshold):
                chunk_size = max_len
            else:
                chunk_size = (chunk_size % self.short_chunk_size) + 1

            chunk_mask = make_chunk_mask(
                x.size(1),
                chunk_size,
                left_chunk_size=self.left_chunk_size,
                device=x.device,
            )
        else:
            chunk_mask = None
        x = self.encoders(
            x,
            pos_enc,
            mask,
            chunk_mask=chunk_mask,
        )

        olens = mask.eq(0).sum(1)
        if self.time_reduction_factor > 1:
            x = x[:,::self.time_reduction_factor,:]
            olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1

        return x, olens

    def simu_chunk_forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
        chunk_size: int = 16,
        left_context: int = 32,
        right_context: int = 0,
    ) -> torch.Tensor:
        """Simulate streaming by encoding the whole utterance with a chunk mask.

        Args:
            x: Encoder input features. (B, T_in, F)
            x_len: Encoder input features lengths. (B,)
            chunk_size: Number of frames per chunk.
            left_context: Number of frames in left context (unused here; the
                chunk mask uses self.left_chunk_size).
            right_context: Number of frames in right context (unused here).

        Returns:
            x: Encoder outputs. (B, T_out, D_enc)

        """
        short_status, limit_size = check_short_utt(
            self.embed.subsampling_factor, x.size(1)
        )

        if short_status:
            raise TooShortUttError(
                f"has {x.size(1)} frames and is too short for subsampling "
                + f"(it needs more than {limit_size} frames), return empty results",
                x.size(1),
                limit_size,
            )

        mask = make_source_mask(x_len)

        x, mask = self.embed(x, mask, chunk_size)
        pos_enc = self.pos_enc(x)
        chunk_mask = make_chunk_mask(
            x.size(1),
            chunk_size,
            left_chunk_size=self.left_chunk_size,
            device=x.device,
        )

        x = self.encoders(
            x,
            pos_enc,
            mask,
            chunk_mask=chunk_mask,
        )
        olens = mask.eq(0).sum(1)
        if self.time_reduction_factor > 1:
            x = x[:,::self.time_reduction_factor,:]

        return x

    def chunk_forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
        processed_frames: torch.tensor,
        chunk_size: int = 16,
        left_context: int = 32,
        right_context: int = 0,
    ) -> torch.Tensor:
        """Encode input sequences as chunks.

        Args:
            x: Encoder input features. (1, T_in, F)
            x_len: Encoder input features lengths. (1,)
            processed_frames: Number of frames already seen.
            chunk_size: Number of frames per chunk.
            left_context: Number of frames in left context.
            right_context: Number of frames in right context.

        Returns:
            x: Encoder outputs. (B, T_out, D_enc)

        """
        mask = make_source_mask(x_len)
        x, mask = self.embed(x, mask, None)

        if left_context > 0:
            # Extend the mask over the cached left-context frames: positions
            # beyond what has actually been processed stay masked.
            processed_mask = (
                torch.arange(left_context, device=x.device)
                .view(1, left_context)
                .flip(1)
            )
            processed_mask = processed_mask >= processed_frames
            mask = torch.cat([processed_mask, mask], dim=1)
        pos_enc = self.pos_enc(x, left_context=left_context)
        x = self.encoders.chunk_forward(
            x,
            pos_enc,
            mask,
            chunk_size=chunk_size,
            left_context=left_context,
            right_context=right_context,
        )

        # Look-ahead frames belong to the next chunk; drop them here.
        if right_context > 0:
            x = x[:, 0:-right_context, :]

        if self.time_reduction_factor > 1:
            x = x[:,::self.time_reduction_factor,:]
        return x
||||
246
funasr/models_transducer/encoder/modules/attention.py
Normal file
246
funasr/models_transducer/encoder/modules/attention.py
Normal file
@ -0,0 +1,246 @@
|
||||
"""Multi-Head attention layers with relative positional encoding."""
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class RelPositionMultiHeadedAttention(torch.nn.Module):
    """RelPositionMultiHeadedAttention definition.

    Multi-head attention with Transformer-XL style relative positional
    encoding, with an optional simplified positional score (icefall-style).

    Args:
        num_heads: Number of attention heads.
        embed_size: Embedding size.
        dropout_rate: Dropout rate.
        simplified_attention_score: Whether to use the simplified (cheaper)
            positional score computation instead of the full
            Transformer-XL scheme.

    """

    def __init__(
        self,
        num_heads: int,
        embed_size: int,
        dropout_rate: float = 0.0,
        simplified_attention_score: bool = False,
    ) -> None:
        """Construct an MultiHeadedAttention object."""
        super().__init__()

        self.d_k = embed_size // num_heads
        self.num_heads = num_heads

        # Bug fix: the assert message used to be a (format_string, args) tuple
        # that was never interpolated; format the message explicitly so a
        # failing assert actually reports the offending sizes.
        assert self.d_k * num_heads == embed_size, (
            "embed_size (%d) must be divisible by num_heads (%d)"
            % (embed_size, num_heads)
        )

        self.linear_q = torch.nn.Linear(embed_size, embed_size)
        self.linear_k = torch.nn.Linear(embed_size, embed_size)
        self.linear_v = torch.nn.Linear(embed_size, embed_size)

        self.linear_out = torch.nn.Linear(embed_size, embed_size)

        if simplified_attention_score:
            # One scalar positional score per head (icefall-style).
            self.linear_pos = torch.nn.Linear(embed_size, num_heads)

            self.compute_att_score = self.compute_simplified_attention_score
        else:
            # Full Transformer-XL scheme with learned content/position biases.
            self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False)

            self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
            self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
            torch.nn.init.xavier_uniform_(self.pos_bias_u)
            torch.nn.init.xavier_uniform_(self.pos_bias_v)

            self.compute_att_score = self.compute_attention_score

        self.dropout = torch.nn.Dropout(p=dropout_rate)
        # Last attention weights, kept for inspection/visualization.
        self.attn = None

    def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
        """Compute relative positional encoding.

        Converts scores indexed by relative position (last dim, 2 * T_1 - 1)
        into scores indexed by absolute key position, using a strided view
        (no data copy).

        Args:
            x: Input sequence. (B, H, T_1, 2 * T_1 - 1)
            left_context: Number of frames in left context.

        Returns:
            x: Output sequence. (B, H, T_1, T_2)

        """
        batch_size, n_heads, time1, n = x.shape
        time2 = time1 + left_context

        batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()

        # Each successive query row starts one element earlier in the
        # relative-position axis, which is exactly a stride reduction of
        # n_stride on the time1 dimension.
        return x.as_strided(
            (batch_size, n_heads, time1, time2),
            (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
            storage_offset=(n_stride * (time1 - 1)),
        )

    def compute_simplified_attention_score(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        pos_enc: torch.Tensor,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Simplified attention score computation.

        Reference: https://github.com/k2-fsa/icefall/pull/458

        Args:
            query: Transformed query tensor. (B, H, T_1, d_k)
            key: Transformed key tensor. (B, H, T_2, d_k)
            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
            left_context: Number of frames in left context.

        Returns:
            : Attention score. (B, H, T_1, T_2)

        """
        # One scalar positional contribution per (head, relative position).
        pos_enc = self.linear_pos(pos_enc)

        matrix_ac = torch.matmul(query, key.transpose(2, 3))

        # Broadcast the per-head positional scores over the query axis, then
        # re-index relative -> absolute positions.
        matrix_bd = self.rel_shift(
            pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1),
            left_context=left_context,
        )

        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)

    def compute_attention_score(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        pos_enc: torch.Tensor,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Attention score computation (Transformer-XL style).

        Args:
            query: Transformed query tensor. (B, H, T_1, d_k)
            key: Transformed key tensor. (B, H, T_2, d_k)
            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
            left_context: Number of frames in left context.

        Returns:
            : Attention score. (B, H, T_1, T_2)

        """
        p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k)

        # Add the learned content (u) and position (v) biases to the query.
        query = query.transpose(1, 2)
        q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
        q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)

        # Content-content term.
        matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))

        # Content-position term, re-indexed from relative to absolute positions.
        matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1))
        matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)

        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)

    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Transform query, key and value.

        Args:
            query: Query tensor. (B, T_1, size)
            key: Key tensor. (B, T_2, size)
            value: Value tensor. (B, T_2, size)

        Returns:
            q: Transformed query tensor. (B, H, T_1, d_k)
            k: Transformed key tensor. (B, H, T_2, d_k)
            v: Transformed value tensor. (B, H, T_2, d_k)

        """
        n_batch = query.size(0)

        q = (
            self.linear_q(query)
            .view(n_batch, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )
        k = (
            self.linear_k(key)
            .view(n_batch, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )
        v = (
            self.linear_v(value)
            .view(n_batch, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )

        return q, k, v

    def forward_attention(
        self,
        value: torch.Tensor,
        scores: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute attention context vector.

        Args:
            value: Transformed value. (B, H, T_2, d_k)
            scores: Attention score. (B, H, T_1, T_2)
            mask: Source mask (True = masked position). (B, T_2)
            chunk_mask: Chunk mask. (T_1, T_1)

        Returns:
            attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k)

        """
        batch_size = scores.size(0)
        mask = mask.unsqueeze(1).unsqueeze(2)
        if chunk_mask is not None:
            mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
        # Masked positions get -inf before softmax, then are zeroed again
        # afterwards so fully-masked rows do not produce NaN contributions.
        scores = scores.masked_fill(mask, float("-inf"))
        self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)

        attn_output = self.dropout(self.attn)
        attn_output = torch.matmul(attn_output, value)

        # Merge heads back: (B, H, T_1, d_k) -> (B, T_1, H * d_k).
        attn_output = self.linear_out(
            attn_output.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.num_heads * self.d_k)
        )

        return attn_output

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Compute scaled dot product attention with rel. positional encoding.

        Args:
            query: Query tensor. (B, T_1, size)
            key: Key tensor. (B, T_2, size)
            value: Value tensor. (B, T_2, size)
            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
            mask: Source mask. (B, T_2)
            chunk_mask: Chunk mask. (T_1, T_1)
            left_context: Number of frames in left context.

        Returns:
            : Output tensor. (B, T_1, H * d_k)

        """
        q, k, v = self.forward_qkv(query, key, value)
        scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
        return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
|
||||
196
funasr/models_transducer/encoder/modules/convolution.py
Normal file
196
funasr/models_transducer/encoder/modules/convolution.py
Normal file
@ -0,0 +1,196 @@
|
||||
"""Convolution modules for X-former blocks."""
|
||||
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class ConformerConvolution(torch.nn.Module):
    """ConformerConvolution module definition.

    Pointwise conv + GLU, depthwise conv, norm + activation, pointwise conv.
    Supports causal (streaming) operation via left padding and a cache of the
    last (kernel_size - 1) frames.

    Args:
        channels: The number of channels.
        kernel_size: Size of the convolving kernel.
        activation: Type of activation function.
        norm_args: Normalization module arguments.
        causal: Whether to use causal convolution (set to True if streaming).

    """

    def __init__(
        self,
        channels: int,
        kernel_size: int,
        activation: torch.nn.Module = torch.nn.ReLU(),
        norm_args: Dict = {},
        causal: bool = False,
    ) -> None:
        """Construct an ConformerConvolution object."""
        super().__init__()

        # Odd kernel required so the non-causal padding below is symmetric.
        assert (kernel_size - 1) % 2 == 0

        self.kernel_size = kernel_size

        # Doubles channels so GLU can halve them again.
        self.pointwise_conv1 = torch.nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )

        if causal:
            # Causal: pad manually on the left (lorder frames) in forward().
            self.lorder = kernel_size - 1
            padding = 0
        else:
            # Non-causal: symmetric padding handled by Conv1d itself.
            self.lorder = 0
            padding = (kernel_size - 1) // 2

        self.depthwise_conv = torch.nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=padding,
            groups=channels,  # depthwise: one filter per channel
        )
        self.norm = torch.nn.BatchNorm1d(channels, **norm_args)
        self.pointwise_conv2 = torch.nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )

        self.activation = activation

    def forward(
        self,
        x: torch.Tensor,
        cache: Optional[torch.Tensor] = None,
        right_context: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute convolution module.

        Args:
            x: ConformerConvolution input sequences. (B, T, D_hidden)
            cache: ConformerConvolution input cache, channels-first slice of
                the previous chunk. (1, D_hidden, kernel_size - 1)
            right_context: Number of frames in right context.

        Returns:
            x: ConformerConvolution output sequences. (B, T, D_hidden)
            cache: ConformerConvolution output cache. (1, D_hidden, kernel_size - 1)
                Unchanged (possibly still None) in the non-causal case.

        """
        # (B, T, D) -> (B, 2 * D, T); GLU halves channels back to D.
        x = self.pointwise_conv1(x.transpose(1, 2))
        x = torch.nn.functional.glu(x, dim=1)

        if self.lorder > 0:
            # Causal path: zero left-pad on the first chunk, otherwise
            # prepend the cached tail of the previous chunk.
            if cache is None:
                x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
            else:
                x = torch.cat([cache, x], dim=2)

            # New cache = last lorder frames, excluding right-context
            # lookahead frames (those will be re-fed with the next chunk).
            if right_context > 0:
                cache = x[:, :, -(self.lorder + right_context) : -right_context]
            else:
                cache = x[:, :, -self.lorder :]

        x = self.depthwise_conv(x)
        x = self.activation(self.norm(x))

        # Back to (B, T, D).
        x = self.pointwise_conv2(x).transpose(1, 2)

        return x, cache
|
||||
|
||||
|
||||
class ConvolutionalSpatialGatingUnit(torch.nn.Module):
    """Convolutional Spatial Gating Unit module definition.

    Splits the input features in half; one half gates the other after a
    depthwise convolution. Supports causal (streaming) operation via left
    padding and a cache of the last (kernel_size - 1) frames.

    Args:
        size: Initial size to determine the number of channels.
        kernel_size: Size of the convolving kernel.
        norm_class: Normalization module class.
        norm_args: Normalization module arguments.
        dropout_rate: Dropout rate.
        causal: Whether to use causal convolution (set to True if streaming).

    """

    def __init__(
        self,
        size: int,
        kernel_size: int,
        norm_class: torch.nn.Module = torch.nn.LayerNorm,
        norm_args: Dict = {},
        dropout_rate: float = 0.0,
        causal: bool = False,
    ) -> None:
        """Construct a ConvolutionalSpatialGatingUnit object."""
        super().__init__()

        # Input is chunked in half along the feature dim (see forward),
        # so the conv operates on size // 2 channels.
        channels = size // 2

        self.kernel_size = kernel_size

        if causal:
            # Causal: pad manually on the left (lorder frames) in forward().
            self.lorder = kernel_size - 1
            padding = 0
        else:
            # Non-causal: symmetric padding handled by Conv1d itself.
            self.lorder = 0
            padding = (kernel_size - 1) // 2

        self.conv = torch.nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=padding,
            groups=channels,  # depthwise: one filter per channel
        )

        self.norm = norm_class(channels, **norm_args)
        # Identity by default; kept as an attribute so it can be replaced.
        self.activation = torch.nn.Identity()

        self.dropout = torch.nn.Dropout(dropout_rate)

    def forward(
        self,
        x: torch.Tensor,
        cache: Optional[torch.Tensor] = None,
        right_context: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute convolution module.

        Args:
            x: ConvolutionalSpatialGatingUnit input sequences. (B, T, D_hidden)
            cache: ConvolutionalSpatialGatingUnit input cache, channels-first
                slice of the previous chunk. (1, D_hidden // 2, kernel_size - 1)
            right_context: Number of frames in right context.

        Returns:
            x: ConvolutionalSpatialGatingUnit output sequences. (B, T, D_hidden // 2)
            cache: ConvolutionalSpatialGatingUnit output cache.
                (1, D_hidden // 2, kernel_size - 1)
                Unchanged (possibly still None) in the non-causal case.

        """
        # x_r is the gated (residual) half, x_g is normalized and convolved.
        x_r, x_g = x.chunk(2, dim=-1)

        x_g = self.norm(x_g).transpose(1, 2)

        if self.lorder > 0:
            # Causal path: zero left-pad on the first chunk, otherwise
            # prepend the cached tail of the previous chunk.
            if cache is None:
                x_g = torch.nn.functional.pad(x_g, (self.lorder, 0), "constant", 0.0)
            else:
                x_g = torch.cat([cache, x_g], dim=2)

            # New cache = last lorder frames, excluding right-context
            # lookahead frames (those will be re-fed with the next chunk).
            if right_context > 0:
                cache = x_g[:, :, -(self.lorder + right_context) : -right_context]
            else:
                cache = x_g[:, :, -self.lorder :]

        x_g = self.conv(x_g).transpose(1, 2)

        # Spatial gating: elementwise product of the two halves.
        x = self.dropout(x_r * self.activation(x_g))

        return x, cache
|
||||
105
funasr/models_transducer/encoder/modules/multi_blocks.py
Normal file
105
funasr/models_transducer/encoder/modules/multi_blocks.py
Normal file
@ -0,0 +1,105 @@
|
||||
"""MultiBlocks for encoder architecture."""
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class MultiBlocks(torch.nn.Module):
    """MultiBlocks definition.

    Stack of encoder blocks followed by a final normalization layer.

    Args:
        block_list: Individual blocks of the encoder architecture.
        output_size: Architecture output size.
        norm_class: Normalization module class.
        norm_args: Normalization module arguments.

    """

    def __init__(
        self,
        block_list: List[torch.nn.Module],
        output_size: int,
        norm_class: torch.nn.Module = torch.nn.LayerNorm,
        norm_args: Optional[Dict] = None,
    ) -> None:
        """Construct a MultiBlocks object."""
        super().__init__()

        self.blocks = torch.nn.ModuleList(block_list)
        # Bug fix: norm_args defaults to None but used to be splatted
        # unconditionally (**None -> TypeError whenever the caller omitted
        # norm_args). Fall back to an empty kwargs dict instead.
        self.norm_blocks = norm_class(output_size, **(norm_args or {}))

        self.num_blocks = len(block_list)

    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
        """Initialize/Reset encoder streaming cache.

        Args:
            left_context: Number of left frames during chunk-by-chunk inference.
            device: Device to use for cache tensor.

        """
        for block in self.blocks:
            block.reset_streaming_cache(left_context, device)

    def forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward each block of the encoder architecture.

        Args:
            x: MultiBlocks input sequences. (B, T, D_block_1)
            pos_enc: Positional embedding sequences.
            mask: Source mask. (B, T)
            chunk_mask: Chunk mask. (T_2, T_2)

        Returns:
            x: Output sequences. (B, T, D_block_N)

        """
        # Each block returns its (possibly updated) mask and pos_enc, which
        # are threaded through to the next block.
        for block in self.blocks:
            x, mask, pos_enc = block(x, pos_enc, mask, chunk_mask=chunk_mask)

        x = self.norm_blocks(x)

        return x

    def chunk_forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int = 0,
        left_context: int = 0,
        right_context: int = 0,
    ) -> torch.Tensor:
        """Forward each block of the encoder architecture, chunk-by-chunk.

        Args:
            x: MultiBlocks input sequences. (B, T, D_block_1)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_att)
            mask: Source mask. (B, T_2)
            chunk_size: Number of frames in each chunk.
            left_context: Number of frames in left context.
            right_context: Number of frames in right context.

        Returns:
            x: MultiBlocks output sequences. (B, T, D_block_N)

        """
        for block in self.blocks:
            x, pos_enc = block.chunk_forward(
                x,
                pos_enc,
                mask,
                chunk_size=chunk_size,
                left_context=left_context,
                right_context=right_context,
            )

        x = self.norm_blocks(x)

        return x
|
||||
170
funasr/models_transducer/encoder/modules/normalization.py
Normal file
170
funasr/models_transducer/encoder/modules/normalization.py
Normal file
@ -0,0 +1,170 @@
|
||||
"""Normalization modules for X-former blocks."""
|
||||
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def get_normalization(
    normalization_type: str,
    eps: Optional[float] = None,
    partial: Optional[float] = None,
) -> Tuple[torch.nn.Module, Dict]:
    """Get normalization module and arguments given parameters.

    Args:
        normalization_type: Normalization module type.
        eps: Value added to the denominator.
        partial: Value defining the part of the input used for RMS stats (RMSNorm).

    Return:
        : Normalization module class
        : Normalization module arguments

    """
    if normalization_type == "basic_norm":
        return BasicNorm, {"eps": 0.25 if eps is None else eps}

    if normalization_type == "layer_norm":
        return torch.nn.LayerNorm, {"eps": 1e-12 if eps is None else eps}

    if normalization_type == "rms_norm":
        return RMSNorm, {
            "eps": 1e-05 if eps is None else eps,
            "partial": -1.0 if partial is None else partial,
        }

    if normalization_type == "scale_norm":
        return ScaleNorm, {"eps": 1e-05 if eps is None else eps}

    # Preserve the original dict-lookup behavior for unknown types.
    raise KeyError(normalization_type)
|
||||
|
||||
|
||||
class BasicNorm(torch.nn.Module):
    """BasicNorm module definition.

    Scales the input by the inverse root of its mean square plus a learnable
    epsilon (stored in log-space so it stays positive during training).

    Reference: https://github.com/k2-fsa/icefall/pull/288

    Args:
        normalized_shape: Expected size.
        eps: Value added to the denominator for numerical stability.

    """

    def __init__(
        self,
        normalized_shape: int,
        eps: float = 0.25,
    ) -> None:
        """Construct a BasicNorm object."""
        super().__init__()

        # Learn eps as log(eps); exponentiated in forward().
        self.eps = torch.nn.Parameter(torch.tensor(eps).log().detach())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute basic normalization.

        Args:
            x: Input sequences. (B, T, D_hidden)

        Returns:
            : Output sequences. (B, T, D_hidden)

        """
        mean_square = torch.mean(x.pow(2), dim=-1, keepdim=True)
        return x * (mean_square + self.eps.exp()) ** -0.5
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
    """RMSNorm module definition.

    Divides the input by its root-mean-square (optionally computed on only
    the first `partial` fraction of the features) and applies a learned
    per-feature scale.

    Reference: https://arxiv.org/pdf/1910.07467.pdf

    Args:
        normalized_shape: Expected size.
        eps: Value added to the denominator for numerical stability.
        partial: Value defining the part of the input used for RMS stats.

    """

    def __init__(
        self,
        normalized_shape: int,
        eps: float = 1e-5,
        partial: float = 0.0,
    ) -> None:
        """Construct a RMSNorm object."""
        super().__init__()

        self.normalized_shape = normalized_shape

        # Partial RMS is only active for a fraction strictly inside (0, 1).
        self.partial = 0 < partial < 1
        self.p = partial
        self.eps = eps

        self.scale = torch.nn.Parameter(torch.ones(normalized_shape))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute RMS normalization.

        Args:
            x: Input sequences. (B, T, D_hidden)

        Returns:
            x: Output sequences. (B, T, D_hidden)

        """
        if self.partial:
            # Statistics over the leading p-fraction of features only.
            stat_dim = int(self.normalized_shape * self.p)
            stat_slice = x[..., :stat_dim]
        else:
            stat_dim = self.normalized_shape
            stat_slice = x

        rms = stat_slice.norm(2, dim=-1, keepdim=True) * stat_dim ** (-1.0 / 2)

        return self.scale * (x / (rms + self.eps))
|
||||
|
||||
|
||||
class ScaleNorm(torch.nn.Module):
    """ScaleNorm module definition.

    Rescales the input so its L2 norm equals a single learned scalar
    (initialized to sqrt(normalized_shape)).

    Reference: https://arxiv.org/pdf/1910.05895.pdf

    Args:
        normalized_shape: Expected size.
        eps: Value added to the denominator for numerical stability.

    """

    def __init__(self, normalized_shape: int, eps: float = 1e-5) -> None:
        """Construct a ScaleNorm object."""
        super().__init__()

        self.eps = eps
        # Single scalar scale, learned; init at sqrt(d).
        self.scale = torch.nn.Parameter(torch.tensor(normalized_shape**0.5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute scale normalization.

        Args:
            x: Input sequences. (B, T, D_hidden)

        Returns:
            : Output sequences. (B, T, D_hidden)

        """
        # Clamp the norm so near-zero vectors do not explode.
        l2 = torch.norm(x, dim=-1, keepdim=True).clamp(min=self.eps)
        return x * (self.scale / l2)
|
||||
@ -0,0 +1,91 @@
|
||||
"""Positional encoding modules."""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
from funasr.modules.embedding import _pre_hook
|
||||
|
||||
|
||||
class RelPositionalEncoding(torch.nn.Module):
    """Relative positional encoding.

    Produces Transformer-XL style relative position embeddings covering
    offsets (time1 - 1) down to -(time1 - 1), for consumption by a
    relative-position multi-head attention module.

    Args:
        size: Module size.
        max_len: Maximum input length.
        dropout_rate: Dropout rate.

    """

    def __init__(
        self, size: int, dropout_rate: float = 0.0, max_len: int = 5000
    ) -> None:
        """Construct a RelativePositionalEncoding object."""
        super().__init__()

        self.size = size

        # Cached table of shape (1, 2 * time1 - 1, size); rebuilt on demand.
        self.pe = None
        self.dropout = torch.nn.Dropout(p=dropout_rate)

        # Pre-build the table for max_len frames.
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
        # Hook for loading checkpoints saved with older key layouts.
        self._register_load_state_dict_pre_hook(_pre_hook)

    def extend_pe(self, x: torch.Tensor, left_context: int = 0) -> None:
        """Reset positional encoding.

        Rebuilds the cached table when it is too short for the current input
        (plus left context); otherwise only syncs its dtype/device to x.

        Args:
            x: Input sequences. (B, T, ?)
            left_context: Number of frames in left context.

        """
        time1 = x.size(1) + left_context

        if self.pe is not None:
            if self.pe.size(1) >= time1 * 2 - 1:
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(device=x.device, dtype=x.dtype)
                return

        # Sinusoidal embeddings for positive offsets (rows 0..time1-1) and
        # negative offsets, built separately then concatenated.
        pe_positive = torch.zeros(time1, self.size)
        pe_negative = torch.zeros(time1, self.size)

        position = torch.arange(0, time1, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.size, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.size)
        )

        pe_positive[:, 0::2] = torch.sin(position * div_term)
        pe_positive[:, 1::2] = torch.cos(position * div_term)
        # Flip so the table runs from the largest positive offset downwards.
        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)

        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
        # Drop offset 0 (already the last row of pe_positive).
        pe_negative = pe_negative[1:].unsqueeze(0)

        # Final table length: 2 * time1 - 1.
        self.pe = torch.cat([pe_positive, pe_negative], dim=1).to(
            dtype=x.dtype, device=x.device
        )

    def forward(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
        """Compute positional encoding.

        Args:
            x: Input sequences. (B, T, ?)
            left_context: Number of frames in left context.

        Returns:
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), ?)

        """
        self.extend_pe(x, left_context=left_context)

        time1 = x.size(1) + left_context

        # Center window of the cached table covering offsets
        # (time1 - 1) .. -(T - 1); pe.size(1) // 2 is the offset-0 row.
        pos_enc = self.pe[
            :, self.pe.size(1) // 2 - time1 + 1 : self.pe.size(1) // 2 + x.size(1)
        ]
        pos_enc = self.dropout(pos_enc)

        return pos_enc
|
||||
835
funasr/models_transducer/encoder/sanm_encoder.py
Normal file
835
funasr/models_transducer/encoder/sanm_encoder.py
Normal file
@ -0,0 +1,835 @@
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Sequence
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
|
||||
from typeguard import check_argument_types
|
||||
import numpy as np
|
||||
from funasr.modules.nets_utils import make_pad_mask
|
||||
from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
|
||||
from funasr.modules.embedding import SinusoidalPositionEncoder
|
||||
from funasr.modules.layer_norm import LayerNorm
|
||||
from funasr.modules.multi_layer_conv import Conv1dLinear
|
||||
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
|
||||
from funasr.modules.positionwise_feed_forward import (
|
||||
PositionwiseFeedForward, # noqa: H301
|
||||
)
|
||||
from funasr.modules.repeat import repeat
|
||||
from funasr.modules.subsampling import Conv2dSubsampling
|
||||
from funasr.modules.subsampling import Conv2dSubsampling2
|
||||
from funasr.modules.subsampling import Conv2dSubsampling6
|
||||
from funasr.modules.subsampling import Conv2dSubsampling8
|
||||
from funasr.modules.subsampling import TooShortUttError
|
||||
from funasr.modules.subsampling import check_short_utt
|
||||
from funasr.models.ctc import CTC
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
|
||||
|
||||
class EncoderLayerSANM(nn.Module):
    # SAN-M encoder layer: SANM self-attention + feed-forward sub-blocks,
    # each with pre/post layer norm, residual connections and optional
    # stochastic depth.
    def __init__(
        self,
        in_size,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object.

        Args:
            in_size: Input feature size (may differ from `size` on the first layer).
            size: Output feature size.
            self_attn: Self-attention module instance.
            feed_forward: Position-wise feed-forward module instance.
            dropout_rate: Dropout rate.
            normalize_before: Apply layer norm before (True) or after (False)
                each sub-block.
            concat_after: Concatenate attention input and output and project,
                instead of a plain residual add.
            stochastic_depth_rate: Probability of skipping this layer during
                training (stochastic depth).
        """
        super(EncoderLayerSANM, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        # norm1 uses in_size: it normalizes the layer input, which may still
        # have the original input dimension on the first layer.
        self.norm1 = LayerNorm(in_size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.in_size = in_size
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate
        self.dropout_rate = dropout_rate

    def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        """Compute encoded features.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
            mask_shfit_chunk: Optional chunk mask forwarded to the attention
                (streaming/chunk mode).
            mask_att_chunk_encoder: Optional attention chunk mask
                (streaming/chunk mode).
        Returns:
            Tuple (x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder)
            so the outputs can be chained directly as the next layer's inputs.
            NOTE(review): the stochastic-depth skip path below returns only
            (x, mask) — confirm downstream callers consume at most the first
            two elements when training with stochastic depth.
        """
        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            return x, mask

        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if self.concat_after:
            # Concatenate attention input and output, then project back to size.
            x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
            else:
                # Dimensions differ: no residual connection is possible.
                x = stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.dropout(
                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
                )
            else:
                # Dimensions differ: no residual connection is possible.
                x = stoch_layer_coeff * self.dropout(
                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
                )
        if not self.normalize_before:
            x = self.norm1(x)

        # Feed-forward sub-block with its own residual connection.
        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)


        return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
|
||||
|
||||
class SANMEncoder(AbsEncoder):
|
||||
"""
|
||||
author: Speech Lab, Alibaba Group, China
|
||||
San-m: Memory equipped self-attention for end-to-end speech recognition
|
||||
https://arxiv.org/abs/2006.01713
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int = 256,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
attention_dropout_rate: float = 0.0,
|
||||
pos_enc_class=SinusoidalPositionEncoder,
|
||||
normalize_before: bool = True,
|
||||
concat_after: bool = False,
|
||||
positionwise_layer_type: str = "linear",
|
||||
positionwise_conv_kernel_size: int = 1,
|
||||
padding_idx: int = -1,
|
||||
interctc_layer_idx: List[int] = [],
|
||||
interctc_use_conditioning: bool = False,
|
||||
kernel_size : int = 11,
|
||||
sanm_shfit : int = 0,
|
||||
tf2torch_tensor_name_prefix_torch: str = "encoder",
|
||||
tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
|
||||
):
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
|
||||
self.embed = SinusoidalPositionEncoder()
|
||||
self.normalize_before = normalize_before
|
||||
if positionwise_layer_type == "linear":
|
||||
positionwise_layer = PositionwiseFeedForward
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
dropout_rate,
|
||||
)
|
||||
elif positionwise_layer_type == "conv1d":
|
||||
positionwise_layer = MultiLayeredConv1d
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
elif positionwise_layer_type == "conv1d-linear":
|
||||
positionwise_layer = Conv1dLinear
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Support only linear or conv1d.")
|
||||
|
||||
encoder_selfattn_layer = MultiHeadedAttentionSANM
|
||||
encoder_selfattn_layer_args0 = (
|
||||
attention_heads,
|
||||
input_size,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
kernel_size,
|
||||
sanm_shfit,
|
||||
)
|
||||
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
output_size,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
kernel_size,
|
||||
sanm_shfit,
|
||||
)
|
||||
self.encoders0 = repeat(
|
||||
1,
|
||||
lambda lnum: EncoderLayerSANM(
|
||||
input_size,
|
||||
output_size,
|
||||
encoder_selfattn_layer(*encoder_selfattn_layer_args0),
|
||||
positionwise_layer(*positionwise_layer_args),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
|
||||
self.encoders = repeat(
|
||||
num_blocks-1,
|
||||
lambda lnum: EncoderLayerSANM(
|
||||
output_size,
|
||||
output_size,
|
||||
encoder_selfattn_layer(*encoder_selfattn_layer_args),
|
||||
positionwise_layer(*positionwise_layer_args),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
if self.normalize_before:
|
||||
self.after_norm = LayerNorm(output_size)
|
||||
|
||||
self.interctc_layer_idx = interctc_layer_idx
|
||||
if len(interctc_layer_idx) > 0:
|
||||
assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
|
||||
self.interctc_use_conditioning = interctc_use_conditioning
|
||||
self.conditioning_layer = None
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
|
||||
self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
|
||||
|
||||
def forward(
    self,
    xs_pad: torch.Tensor,
    ilens: torch.Tensor,
    prev_states: torch.Tensor = None,
    ctc: CTC = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Embed positions in tensor.
    Args:
        xs_pad: input tensor (B, L, D)
        ilens: input length (B)
        prev_states: Not to be used now.
        ctc: CTC module, used only when intermediate-CTC conditioning is on.
    Returns:
        position embedded tensor and mask
        NOTE(review): when intermediate layers are configured this returns a
        3-tuple ending in None, otherwise a 2-tuple — callers must handle
        both arities; confirm the asymmetry is intended.
    """
    # Boolean mask of valid (non-padding) frames: (B, 1, L).
    masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
    # Scale inputs by sqrt(d_model) before encoding.
    xs_pad = xs_pad * self.output_size**0.5
    if self.embed is None:
        xs_pad = xs_pad
    elif (
        isinstance(self.embed, Conv2dSubsampling)
        or isinstance(self.embed, Conv2dSubsampling2)
        or isinstance(self.embed, Conv2dSubsampling6)
        or isinstance(self.embed, Conv2dSubsampling8)
    ):
        # Conv subsampling needs a minimum number of frames; fail early with
        # an informative error rather than a shape error downstream.
        short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
        if short_status:
            raise TooShortUttError(
                f"has {xs_pad.size(1)} frames and is too short for subsampling "
                + f"(it needs more than {limit_size} frames), return empty results",
                xs_pad.size(1),
                limit_size,
            )
        xs_pad, masks = self.embed(xs_pad, masks)
    else:
        xs_pad = self.embed(xs_pad)

    # xs_pad = self.dropout(xs_pad)
    # First layer (encoders0) maps input_size -> output_size.
    encoder_outs = self.encoders0(xs_pad, masks)
    xs_pad, masks = encoder_outs[0], encoder_outs[1]
    intermediate_outs = []
    if len(self.interctc_layer_idx) == 0:
        # Fast path: run the remaining layers in a single call.
        encoder_outs = self.encoders(xs_pad, masks)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]
    else:
        # Run layer-by-layer so intermediate outputs can be captured.
        for layer_idx, encoder_layer in enumerate(self.encoders):
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]

            if layer_idx + 1 in self.interctc_layer_idx:
                encoder_out = xs_pad

                # intermediate outputs are also normalized
                if self.normalize_before:
                    encoder_out = self.after_norm(encoder_out)

                intermediate_outs.append((layer_idx + 1, encoder_out))

                if self.interctc_use_conditioning:
                    # Self-conditioned CTC: feed CTC posteriors back in.
                    ctc_out = ctc.softmax(encoder_out)
                    xs_pad = xs_pad + self.conditioning_layer(ctc_out)

    if self.normalize_before:
        xs_pad = self.after_norm(xs_pad)

    # Valid output lengths derived from the (possibly subsampled) mask.
    olens = masks.squeeze(1).sum(1)
    if len(intermediate_outs) > 0:
        return (xs_pad, intermediate_outs), olens, None
    return xs_pad, olens
|
||||
|
||||
def gen_tf2torch_map_dict(self):
    """Build the TF -> PyTorch tensor-name mapping for this encoder.

    Each entry maps a templated torch parameter name (containing the
    literal placeholder ``layeridx``) to the matching TF variable name
    plus the squeeze axis / transpose permutation needed so the TF
    array matches the torch parameter shape.
    """
    p_torch = self.tf2torch_tensor_name_prefix_torch
    p_tf = self.tf2torch_tensor_name_prefix_tf

    # (torch suffix, tf suffix, squeeze axis, transpose permutation)
    layer_specs = [
        # self-attention (cicd) part
        ("norm1.weight", "multi_head/LayerNorm/gamma", None, None),
        ("norm1.bias", "multi_head/LayerNorm/beta", None, None),
        ("self_attn.linear_q_k_v.weight", "multi_head/conv1d/kernel", 0, (1, 0)),
        ("self_attn.linear_q_k_v.bias", "multi_head/conv1d/bias", None, None),
        ("self_attn.fsmn_block.weight", "multi_head/depth_conv_w", 0, (1, 2, 0)),
        ("self_attn.linear_out.weight", "multi_head/conv1d_1/kernel", 0, (1, 0)),
        ("self_attn.linear_out.bias", "multi_head/conv1d_1/bias", None, None),
        # ffn part
        ("norm2.weight", "ffn/LayerNorm/gamma", None, None),
        ("norm2.bias", "ffn/LayerNorm/beta", None, None),
        ("feed_forward.w_1.weight", "ffn/conv1d/kernel", 0, (1, 0)),
        ("feed_forward.w_1.bias", "ffn/conv1d/bias", None, None),
        ("feed_forward.w_2.weight", "ffn/conv1d_1/kernel", 0, (1, 0)),
        ("feed_forward.w_2.bias", "ffn/conv1d_1/bias", None, None),
    ]

    map_dict_local = {
        "{}.encoders.layeridx.{}".format(p_torch, torch_suffix): {
            "name": "{}/layer_layeridx/{}".format(p_tf, tf_suffix),
            "squeeze": squeeze_axis,
            "transpose": perm,
        }
        for torch_suffix, tf_suffix, squeeze_axis, perm in layer_specs
    }

    # final output LayerNorm (not per-layer)
    for torch_part, tf_part in (("weight", "gamma"), ("bias", "beta")):
        map_dict_local["{}.after_norm.{}".format(p_torch, torch_part)] = {
            "name": "{}/LayerNorm/{}".format(p_tf, tf_part),
            "squeeze": None,
            "transpose": None,
        }

    return map_dict_local
|
||||
|
||||
def _copy_mapped_tf_tensor(self, map_dict, name, name_q, layeridx,
                           var_dict_tf, var_dict_torch, var_dict_torch_update):
    """Copy one TF variable into the torch update dict.

    Looks up the templated torch name ``name_q`` in ``map_dict``,
    resolves the concrete TF variable name for ``layeridx``, applies the
    configured squeeze/transpose, validates the shape against
    ``var_dict_torch[name]``, and stores the converted tensor under the
    concrete torch name ``name``. No-op when ``name_q`` is unmapped.
    """
    if name_q not in map_dict:
        return
    name_v = map_dict[name_q]["name"]
    name_tf = name_v.replace("layeridx", "{}".format(layeridx))
    data_tf = var_dict_tf[name_tf]
    if map_dict[name_q]["squeeze"] is not None:
        data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
    if map_dict[name_q]["transpose"] is not None:
        data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
    assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(
        name, name_tf, var_dict_torch[name].size(), data_tf.size())
    var_dict_torch_update[name] = data_tf
    logging.info(
        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
            name, data_tf.size(), name_v, var_dict_tf[name_tf].shape))

def convert_tf2torch(self,
                     var_dict_tf,
                     var_dict_torch,
                     ):
    """Convert TF checkpoint variables to torch parameters for this encoder.

    Args:
        var_dict_tf: Mapping of TF variable name -> numpy array.
        var_dict_torch: Mapping of torch parameter name -> torch tensor
            (used only to validate converted shapes).

    Returns:
        Mapping of torch parameter name -> converted float32 CPU tensor.
    """
    map_dict = self.gen_tf2torch_map_dict()

    var_dict_torch_update = dict()
    for name in sorted(var_dict_torch.keys(), reverse=False):
        names = name.split('.')
        if names[0] != self.tf2torch_tensor_name_prefix_torch:
            continue
        if names[1] == "encoders0":
            # The first layer is stored as "encoders0" on the torch side but
            # shares the "encoders" mapping template; its TF layer index is 0.
            layeridx = int(names[2])
            name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
            name_q = name_q.replace("encoders0", "encoders")
            self._copy_mapped_tf_tensor(
                map_dict, name, name_q, layeridx,
                var_dict_tf, var_dict_torch, var_dict_torch_update)
        elif names[1] == "encoders":
            # Remaining layers are offset by one TF layer (TF layer 0 is the
            # encoders0 block handled above).
            layeridx = int(names[2])
            name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
            self._copy_mapped_tf_tensor(
                map_dict, name, name_q, layeridx + 1,
                var_dict_tf, var_dict_torch, var_dict_torch_update)
        elif names[1] == "after_norm":
            # Final LayerNorm maps 1:1, no squeeze/transpose needed.
            name_tf = map_dict[name]["name"]
            data_tf = var_dict_tf[name_tf]
            data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
            var_dict_torch_update[name] = data_tf
            logging.info(
                "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                    name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape))

    return var_dict_torch_update
|
||||
|
||||
|
||||
class SANMEncoderChunkOpt(AbsEncoder):
|
||||
"""
|
||||
author: Speech Lab, Alibaba Group, China
|
||||
SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
|
||||
https://arxiv.org/abs/2006.01713
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    input_size: int,
    output_size: int = 256,
    attention_heads: int = 4,
    linear_units: int = 2048,
    num_blocks: int = 6,
    dropout_rate: float = 0.1,
    positional_dropout_rate: float = 0.1,
    attention_dropout_rate: float = 0.0,
    pos_enc_class=SinusoidalPositionEncoder,
    normalize_before: bool = True,
    concat_after: bool = False,
    positionwise_layer_type: str = "linear",
    positionwise_conv_kernel_size: int = 1,
    padding_idx: int = -1,
    interctc_layer_idx: List[int] = [],
    interctc_use_conditioning: bool = False,
    kernel_size: int = 11,
    sanm_shfit: int = 0,
    chunk_size: Union[int, Sequence[int]] = (16,),
    stride: Union[int, Sequence[int]] = (10,),
    pad_left: Union[int, Sequence[int]] = (0,),
    time_reduction_factor: int = 1,
    encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
    decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
    tf2torch_tensor_name_prefix_torch: str = "encoder",
    tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
):
    """Build a chunk-based SANM encoder for streaming (SCAMA) recognition.

    Most arguments mirror the offline SANM encoder. The chunk_size /
    stride / pad_left / *_att_look_back_factor arguments configure the
    overlapped-chunk splitting and attention masks used for streaming,
    and time_reduction_factor subsamples the output frames in time.
    NOTE(review): interctc_layer_idx uses a mutable default ([]); it is
    only read here, but confirm no caller mutates it.
    """
    assert check_argument_types()
    super().__init__()
    self.output_size = output_size

    # Sinusoidal positions are added to the raw features; no conv subsampling.
    self.embed = SinusoidalPositionEncoder()

    self.normalize_before = normalize_before
    # Select the position-wise feed-forward implementation.
    if positionwise_layer_type == "linear":
        positionwise_layer = PositionwiseFeedForward
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
        )
    elif positionwise_layer_type == "conv1d":
        positionwise_layer = MultiLayeredConv1d
        positionwise_layer_args = (
            output_size,
            linear_units,
            positionwise_conv_kernel_size,
            dropout_rate,
        )
    elif positionwise_layer_type == "conv1d-linear":
        positionwise_layer = Conv1dLinear
        positionwise_layer_args = (
            output_size,
            linear_units,
            positionwise_conv_kernel_size,
            dropout_rate,
        )
    else:
        raise NotImplementedError("Support only linear or conv1d.")

    encoder_selfattn_layer = MultiHeadedAttentionSANM
    # First layer attends over the raw input dimension (input_size).
    encoder_selfattn_layer_args0 = (
        attention_heads,
        input_size,
        output_size,
        attention_dropout_rate,
        kernel_size,
        sanm_shfit,
    )

    # Remaining layers operate fully in the model dimension (output_size).
    encoder_selfattn_layer_args = (
        attention_heads,
        output_size,
        output_size,
        attention_dropout_rate,
        kernel_size,
        sanm_shfit,
    )
    # encoders0: a single projection layer mapping input_size -> output_size.
    self.encoders0 = repeat(
        1,
        lambda lnum: EncoderLayerSANM(
            input_size,
            output_size,
            encoder_selfattn_layer(*encoder_selfattn_layer_args0),
            positionwise_layer(*positionwise_layer_args),
            dropout_rate,
            normalize_before,
            concat_after,
        ),
    )

    self.encoders = repeat(
        num_blocks - 1,
        lambda lnum: EncoderLayerSANM(
            output_size,
            output_size,
            encoder_selfattn_layer(*encoder_selfattn_layer_args),
            positionwise_layer(*positionwise_layer_args),
            dropout_rate,
            normalize_before,
            concat_after,
        ),
    )
    if self.normalize_before:
        self.after_norm = LayerNorm(output_size)

    self.interctc_layer_idx = interctc_layer_idx
    if len(interctc_layer_idx) > 0:
        # Intermediate-CTC taps must lie strictly inside the layer stack.
        assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
    self.interctc_use_conditioning = interctc_use_conditioning
    self.conditioning_layer = None
    # Half the FSMN kernel is consumed as context on each side of a chunk.
    shfit_fsmn = (kernel_size - 1) // 2
    self.overlap_chunk_cls = overlap_chunk(
        chunk_size=chunk_size,
        stride=stride,
        pad_left=pad_left,
        shfit_fsmn=shfit_fsmn,
        encoder_att_look_back_factor=encoder_att_look_back_factor,
        decoder_att_look_back_factor=decoder_att_look_back_factor,
    )
    self.time_reduction_factor = time_reduction_factor
    self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
    self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
|
||||
|
||||
def forward(
    self,
    xs_pad: torch.Tensor,
    ilens: torch.Tensor,
    prev_states: torch.Tensor = None,
    ctc: CTC = None,
    ind: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Encode padded features with overlapped-chunk streaming masks.

    Args:
        xs_pad: input tensor (B, L, D)
        ilens: input length (B)
        prev_states: Not to be used now.
        ctc: CTC module, used only for intermediate-CTC conditioning.
        ind: chunk configuration index forwarded to the chunk helper.

    Returns:
        Encoded output (or a tuple of output and intermediate-layer
        outputs when intermediate CTC is configured) and output lengths.
    """
    masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
    # Scale by sqrt(d_model) out-of-place. The previous in-place
    # `xs_pad *= ...` mutated the caller's tensor and raises on autograd
    # leaf tensors requiring grad; this also matches the sibling offline
    # encoder's forward.
    xs_pad = xs_pad * self.output_size**0.5
    if self.embed is None:
        xs_pad = xs_pad
    elif (
        isinstance(self.embed, Conv2dSubsampling)
        or isinstance(self.embed, Conv2dSubsampling2)
        or isinstance(self.embed, Conv2dSubsampling6)
        or isinstance(self.embed, Conv2dSubsampling8)
    ):
        # Conv subsampling needs a minimum number of frames; fail early
        # with an informative error instead of a shape error downstream.
        short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
        if short_status:
            raise TooShortUttError(
                f"has {xs_pad.size(1)} frames and is too short for subsampling "
                + f"(it needs more than {limit_size} frames), return empty results",
                xs_pad.size(1),
                limit_size,
            )
        xs_pad, masks = self.embed(xs_pad, masks)
    else:
        xs_pad = self.embed(xs_pad)

    mask_shfit_chunk, mask_att_chunk_encoder = None, None
    if self.overlap_chunk_cls is not None:
        # Split the batch into overlapped chunks and build the chunk-local
        # attention masks consumed by the SANM layers.
        ilens = masks.squeeze(1).sum(1)
        chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind)
        xs_pad, ilens = self.overlap_chunk_cls.split_chunk(xs_pad, ilens, chunk_outs=chunk_outs)
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(
            chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype)
        mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(
            chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype)

    # First layer maps input_size -> output_size.
    encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
    xs_pad, masks = encoder_outs[0], encoder_outs[1]
    intermediate_outs = []
    if len(self.interctc_layer_idx) == 0:
        encoder_outs = self.encoders(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]
    else:
        # Run layer-by-layer so intermediate-CTC outputs can be captured.
        for layer_idx, encoder_layer in enumerate(self.encoders):
            encoder_outs = encoder_layer(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
            if layer_idx + 1 in self.interctc_layer_idx:
                encoder_out = xs_pad

                # intermediate outputs are also normalized
                if self.normalize_before:
                    encoder_out = self.after_norm(encoder_out)

                intermediate_outs.append((layer_idx + 1, encoder_out))

                if self.interctc_use_conditioning:
                    # Self-conditioned CTC: feed posteriors back into the stack.
                    ctc_out = ctc.softmax(encoder_out)
                    xs_pad = xs_pad + self.conditioning_layer(ctc_out)

    if self.normalize_before:
        xs_pad = self.after_norm(xs_pad)

    olens = masks.squeeze(1).sum(1)

    # Undo the chunk splitting to restore a contiguous time axis.
    xs_pad, olens = self.overlap_chunk_cls.remove_chunk(xs_pad, olens, chunk_outs=None)

    if self.time_reduction_factor > 1:
        # Subsample frames in time and recompute the valid lengths.
        xs_pad = xs_pad[:, ::self.time_reduction_factor, :]
        olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1

    if len(intermediate_outs) > 0:
        return (xs_pad, intermediate_outs), olens, None
    return xs_pad, olens
|
||||
|
||||
def gen_tf2torch_map_dict(self):
    """Build the TF -> PyTorch tensor-name mapping for this encoder.

    Each entry maps a templated torch parameter name (containing the
    literal placeholder ``layeridx``) to the matching TF variable name
    plus the squeeze axis / transpose permutation needed so the TF
    array matches the torch parameter shape.
    """
    p_torch = self.tf2torch_tensor_name_prefix_torch
    p_tf = self.tf2torch_tensor_name_prefix_tf

    # (torch suffix, tf suffix, squeeze axis, transpose permutation)
    layer_specs = [
        # self-attention (cicd) part
        ("norm1.weight", "multi_head/LayerNorm/gamma", None, None),
        ("norm1.bias", "multi_head/LayerNorm/beta", None, None),
        ("self_attn.linear_q_k_v.weight", "multi_head/conv1d/kernel", 0, (1, 0)),
        ("self_attn.linear_q_k_v.bias", "multi_head/conv1d/bias", None, None),
        ("self_attn.fsmn_block.weight", "multi_head/depth_conv_w", 0, (1, 2, 0)),
        ("self_attn.linear_out.weight", "multi_head/conv1d_1/kernel", 0, (1, 0)),
        ("self_attn.linear_out.bias", "multi_head/conv1d_1/bias", None, None),
        # ffn part
        ("norm2.weight", "ffn/LayerNorm/gamma", None, None),
        ("norm2.bias", "ffn/LayerNorm/beta", None, None),
        ("feed_forward.w_1.weight", "ffn/conv1d/kernel", 0, (1, 0)),
        ("feed_forward.w_1.bias", "ffn/conv1d/bias", None, None),
        ("feed_forward.w_2.weight", "ffn/conv1d_1/kernel", 0, (1, 0)),
        ("feed_forward.w_2.bias", "ffn/conv1d_1/bias", None, None),
    ]

    map_dict_local = {
        "{}.encoders.layeridx.{}".format(p_torch, torch_suffix): {
            "name": "{}/layer_layeridx/{}".format(p_tf, tf_suffix),
            "squeeze": squeeze_axis,
            "transpose": perm,
        }
        for torch_suffix, tf_suffix, squeeze_axis, perm in layer_specs
    }

    # final output LayerNorm (not per-layer)
    for torch_part, tf_part in (("weight", "gamma"), ("bias", "beta")):
        map_dict_local["{}.after_norm.{}".format(p_torch, torch_part)] = {
            "name": "{}/LayerNorm/{}".format(p_tf, tf_part),
            "squeeze": None,
            "transpose": None,
        }

    return map_dict_local
|
||||
|
||||
def _copy_mapped_tf_tensor(self, map_dict, name, name_q, layeridx,
                           var_dict_tf, var_dict_torch, var_dict_torch_update):
    """Copy one TF variable into the torch update dict.

    Looks up the templated torch name ``name_q`` in ``map_dict``,
    resolves the concrete TF variable name for ``layeridx``, applies the
    configured squeeze/transpose, validates the shape against
    ``var_dict_torch[name]``, and stores the converted tensor under the
    concrete torch name ``name``. No-op when ``name_q`` is unmapped.
    """
    if name_q not in map_dict:
        return
    name_v = map_dict[name_q]["name"]
    name_tf = name_v.replace("layeridx", "{}".format(layeridx))
    data_tf = var_dict_tf[name_tf]
    if map_dict[name_q]["squeeze"] is not None:
        data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
    if map_dict[name_q]["transpose"] is not None:
        data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
    assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(
        name, name_tf, var_dict_torch[name].size(), data_tf.size())
    var_dict_torch_update[name] = data_tf
    logging.info(
        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
            name, data_tf.size(), name_v, var_dict_tf[name_tf].shape))

def convert_tf2torch(self,
                     var_dict_tf,
                     var_dict_torch,
                     ):
    """Convert TF checkpoint variables to torch parameters for this encoder.

    Args:
        var_dict_tf: Mapping of TF variable name -> numpy array.
        var_dict_torch: Mapping of torch parameter name -> torch tensor
            (used only to validate converted shapes).

    Returns:
        Mapping of torch parameter name -> converted float32 CPU tensor.
    """
    map_dict = self.gen_tf2torch_map_dict()

    var_dict_torch_update = dict()
    for name in sorted(var_dict_torch.keys(), reverse=False):
        names = name.split('.')
        if names[0] != self.tf2torch_tensor_name_prefix_torch:
            continue
        if names[1] == "encoders0":
            # The first layer is stored as "encoders0" on the torch side but
            # shares the "encoders" mapping template; its TF layer index is 0.
            layeridx = int(names[2])
            name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
            name_q = name_q.replace("encoders0", "encoders")
            self._copy_mapped_tf_tensor(
                map_dict, name, name_q, layeridx,
                var_dict_tf, var_dict_torch, var_dict_torch_update)
        elif names[1] == "encoders":
            # Remaining layers are offset by one TF layer (TF layer 0 is the
            # encoders0 block handled above).
            layeridx = int(names[2])
            name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
            self._copy_mapped_tf_tensor(
                map_dict, name, name_q, layeridx + 1,
                var_dict_tf, var_dict_torch, var_dict_torch_update)
        elif names[1] == "after_norm":
            # Final LayerNorm maps 1:1, no squeeze/transpose needed.
            name_tf = map_dict[name]["name"]
            data_tf = var_dict_tf[name_tf]
            data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
            var_dict_torch_update[name] = data_tf
            logging.info(
                "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                    name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape))

    return var_dict_torch_update
|
||||
171
funasr/models_transducer/encoder/validation.py
Normal file
171
funasr/models_transducer/encoder/validation.py
Normal file
@ -0,0 +1,171 @@
|
||||
"""Set of methods to validate encoder architecture."""
|
||||
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from funasr.models_transducer.utils import sub_factor_to_params
|
||||
|
||||
|
||||
def validate_block_arguments(
    configuration: Dict[str, Any],
    block_id: int,
    previous_block_output: int,
) -> Tuple[int, int]:
    """Validate a single encoder body block configuration.

    Args:
        configuration: Block configuration. For conv1d blocks its
            "input_size" entry is set to the previous block's output.
        block_id: 1-based block ID, used in error messages.
        previous_block_output: Output size of the preceding block.

    Returns:
        The (input_size, output_size) pair for the block.

    Raises:
        ValueError: If the block type is missing/unsupported or a
            required argument for that block type is absent.
    """
    block_type = configuration.get("block_type")

    if block_type is None:
        raise ValueError(
            "Block %d in encoder doesn't have a type assigned. " % block_id
        )

    if block_type in ["branchformer", "conformer"]:
        # X-former blocks keep the hidden size on both sides.
        if configuration.get("linear_size") is None:
            raise ValueError(
                "Missing 'linear_size' argument for X-former block (ID: %d)" % block_id
            )
        if configuration.get("conv_mod_kernel_size") is None:
            raise ValueError(
                "Missing 'conv_mod_kernel_size' argument for X-former block (ID: %d)"
                % block_id
            )
        hidden = configuration.get("hidden_size")
        return hidden, hidden

    if block_type == "conv1d":
        conv_output = configuration.get("output_size")

        if conv_output is None:
            raise ValueError(
                "Missing 'output_size' argument for Conv1d block (ID: %d)" % block_id
            )
        if configuration.get("kernel_size") is None:
            raise ValueError(
                "Missing 'kernel_size' argument for Conv1d block (ID: %d)" % block_id
            )

        # Conv1d chains from the previous block; record it in the config.
        configuration["input_size"] = previous_block_output
        return previous_block_output, conv_output

    raise ValueError("Block type: %s is not supported." % block_type)
|
||||
|
||||
|
||||
def validate_input_block(
    configuration: Dict[str, Any], body_first_conf: Dict[str, Any], input_size: int
) -> int:
    """Validate and normalize the encoder input block configuration.

    Fills in defaults ("subsampling_factor", "conv_size", "vgg_like",
    "linear", "output_size") in-place and computes the input block's
    output size from the first body block's configuration.

    Args:
        configuration: Encoder input block configuration (mutated).
        body_first_conf: Encoder first body block configuration.
        input_size: Encoder input block input size.

    Returns:
        The input block output size, or -1 when the first body block
        type is missing or unsupported.
    """
    use_vgg = configuration.get("vgg_like", False)
    use_linear = configuration.get("linear", False)
    next_block_type = body_first_conf.get("block_type")

    if next_block_type not in ("branchformer", "conformer", "conv1d"):
        # Covers both a missing type and an unsupported one.
        return -1

    if configuration.get("subsampling_factor") is None:
        configuration["subsampling_factor"] = 4

    if use_vgg:
        # VGG front-end: conv_size is a (stage1, stage2) channel pair.
        conv_channels = configuration.get("conv_size", (64, 128))
        if isinstance(conv_channels, int):
            conv_channels = (conv_channels, conv_channels)
    else:
        conv_channels = configuration.get("conv_size", None)
        if isinstance(conv_channels, tuple):
            conv_channels = conv_channels[0]

    if next_block_type == "conv1d":
        if use_vgg:
            # Two stride-2 poolings halve the feature dimension twice.
            block_osize = conv_channels[1] * ((input_size // 2) // 2)
        else:
            if conv_channels is None:
                conv_channels = body_first_conf.get("output_size", 64)

            sub_factor = configuration["subsampling_factor"]
            _, _, conv_osize = sub_factor_to_params(sub_factor, input_size)
            assert (
                conv_osize > 0
            ), "Conv2D output size is <1 with input size %d and subsampling %d" % (
                input_size,
                sub_factor,
            )
            block_osize = conv_osize * conv_channels

        configuration["output_size"] = None
    else:
        # X-former follows: match its hidden size.
        block_osize = body_first_conf.get("hidden_size")
        if conv_channels is None:
            conv_channels = block_osize
        configuration["output_size"] = block_osize

    configuration["conv_size"] = conv_channels
    configuration["vgg_like"] = use_vgg
    configuration["linear"] = use_linear

    return block_osize
|
||||
|
||||
|
||||
def validate_architecture(
    input_conf: Dict[str, Any], body_conf: List[Dict[str, Any]], input_size: int
) -> Tuple[int, int]:
    """Validate that the specified encoder architecture is consistent.

    Args:
        input_conf: Encoder input block configuration (mutated in-place
            by validate_input_block).
        body_conf: Encoder body blocks configuration (entries may be
            mutated by validate_block_arguments).
        input_size: Encoder input size.

    Returns:
        The input block output size and the last body block output size.

    Raises:
        ValueError: If any block is invalid or consecutive body blocks
            have mismatched output/input sizes.
    """
    input_block_osize = validate_input_block(input_conf, body_conf[0], input_size)

    # First pass: validate every body block and collect its (in, out) sizes,
    # feeding each block the previous block's output size.
    io_sizes = []
    previous_out = input_block_osize
    for idx, block_conf in enumerate(body_conf):
        pair = validate_block_arguments(block_conf, idx + 1, previous_out)
        io_sizes.append(pair)
        previous_out = pair[1]

    # Second pass: neighbouring blocks must chain output -> input.
    for left, (pair_a, pair_b) in enumerate(zip(io_sizes, io_sizes[1:])):
        if pair_a[1] != pair_b[0]:
            raise ValueError(
                "Output/Input mismatch between blocks %d and %d"
                " in the encoder body." % (left, left + 1)
            )

    return input_block_osize, io_sizes[-1][1]
|
||||
170
funasr/models_transducer/error_calculator.py
Normal file
170
funasr/models_transducer/error_calculator.py
Normal file
@ -0,0 +1,170 @@
|
||||
"""Error Calculator module for Transducer."""
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from funasr.models_transducer.beam_search_transducer import BeamSearchTransducer
|
||||
from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
|
||||
from funasr.models_transducer.joint_network import JointNetwork
|
||||
|
||||
|
||||
class ErrorCalculator:
|
||||
"""Calculate CER and WER for transducer models.
|
||||
|
||||
Args:
|
||||
decoder: Decoder module.
|
||||
joint_network: Joint Network module.
|
||||
token_list: List of token units.
|
||||
sym_space: Space symbol.
|
||||
sym_blank: Blank symbol.
|
||||
report_cer: Whether to compute CER.
|
||||
report_wer: Whether to compute WER.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    decoder: AbsDecoder,
    joint_network: JointNetwork,
    token_list: List[int],
    sym_space: str,
    sym_blank: str,
    report_cer: bool = False,
    report_wer: bool = False,
) -> None:
    """Construct an ErrorCalculatorTransducer object.

    A transducer beam search with beam size 1 is built once here and
    reused for every scoring call.
    """
    super().__init__()

    self.decoder = decoder
    self.token_list = token_list
    self.space = sym_space
    self.blank = sym_blank
    self.report_cer = report_cer
    self.report_wer = report_wer

    # Greedy decoding (beam_size=1) is enough for training-time metrics.
    self.beam_search = BeamSearchTransducer(
        decoder=decoder,
        joint_network=joint_network,
        beam_size=1,
        search_type="default",
        score_norm=False,
    )
|
||||
|
||||
def __call__(
|
||||
self, encoder_out: torch.Tensor, target: torch.Tensor
|
||||
) -> Tuple[Optional[float], Optional[float]]:
|
||||
"""Calculate sentence-level WER or/and CER score for Transducer model.
|
||||
|
||||
Args:
|
||||
encoder_out: Encoder output sequences. (B, T, D_enc)
|
||||
target: Target label ID sequences. (B, L)
|
||||
|
||||
Returns:
|
||||
: Sentence-level CER score.
|
||||
: Sentence-level WER score.
|
||||
|
||||
"""
|
||||
cer, wer = None, None
|
||||
|
||||
batchsize = int(encoder_out.size(0))
|
||||
|
||||
encoder_out = encoder_out.to(next(self.decoder.parameters()).device)
|
||||
|
||||
batch_nbest = [self.beam_search(encoder_out[b]) for b in range(batchsize)]
|
||||
pred = [nbest_hyp[0].yseq[1:] for nbest_hyp in batch_nbest]
|
||||
|
||||
char_pred, char_target = self.convert_to_char(pred, target)
|
||||
|
||||
if self.report_cer:
|
||||
cer = self.calculate_cer(char_pred, char_target)
|
||||
|
||||
if self.report_wer:
|
||||
wer = self.calculate_wer(char_pred, char_target)
|
||||
|
||||
return cer, wer
|
||||
|
||||
def convert_to_char(
|
||||
self, pred: torch.Tensor, target: torch.Tensor
|
||||
) -> Tuple[List, List]:
|
||||
"""Convert label ID sequences to character sequences.
|
||||
|
||||
Args:
|
||||
pred: Prediction label ID sequences. (B, U)
|
||||
target: Target label ID sequences. (B, L)
|
||||
|
||||
Returns:
|
||||
char_pred: Prediction character sequences. (B, ?)
|
||||
char_target: Target character sequences. (B, ?)
|
||||
|
||||
"""
|
||||
char_pred, char_target = [], []
|
||||
|
||||
for i, pred_i in enumerate(pred):
|
||||
char_pred_i = [self.token_list[int(h)] for h in pred_i]
|
||||
char_target_i = [self.token_list[int(r)] for r in target[i]]
|
||||
|
||||
char_pred_i = "".join(char_pred_i).replace(self.space, " ")
|
||||
char_pred_i = char_pred_i.replace(self.blank, "")
|
||||
|
||||
char_target_i = "".join(char_target_i).replace(self.space, " ")
|
||||
char_target_i = char_target_i.replace(self.blank, "")
|
||||
|
||||
char_pred.append(char_pred_i)
|
||||
char_target.append(char_target_i)
|
||||
|
||||
return char_pred, char_target
|
||||
|
||||
def calculate_cer(
|
||||
self, char_pred: torch.Tensor, char_target: torch.Tensor
|
||||
) -> float:
|
||||
"""Calculate sentence-level CER score.
|
||||
|
||||
Args:
|
||||
char_pred: Prediction character sequences. (B, ?)
|
||||
char_target: Target character sequences. (B, ?)
|
||||
|
||||
Returns:
|
||||
: Average sentence-level CER score.
|
||||
|
||||
"""
|
||||
import editdistance
|
||||
|
||||
distances, lens = [], []
|
||||
|
||||
for i, char_pred_i in enumerate(char_pred):
|
||||
pred = char_pred_i.replace(" ", "")
|
||||
target = char_target[i].replace(" ", "")
|
||||
|
||||
distances.append(editdistance.eval(pred, target))
|
||||
lens.append(len(target))
|
||||
|
||||
return float(sum(distances)) / sum(lens)
|
||||
|
||||
def calculate_wer(
|
||||
self, char_pred: torch.Tensor, char_target: torch.Tensor
|
||||
) -> float:
|
||||
"""Calculate sentence-level WER score.
|
||||
|
||||
Args:
|
||||
char_pred: Prediction character sequences. (B, ?)
|
||||
char_target: Target character sequences. (B, ?)
|
||||
|
||||
Returns:
|
||||
: Average sentence-level WER score
|
||||
|
||||
"""
|
||||
import editdistance
|
||||
|
||||
distances, lens = [], []
|
||||
|
||||
for i, char_pred_i in enumerate(char_pred):
|
||||
pred = char_pred_i.replace("▁", " ").split()
|
||||
target = char_target[i].replace("▁", " ").split()
|
||||
|
||||
distances.append(editdistance.eval(pred, target))
|
||||
lens.append(len(target))
|
||||
|
||||
return float(sum(distances)) / sum(lens)
|
||||
484
funasr/models_transducer/espnet_transducer_model.py
Normal file
484
funasr/models_transducer/espnet_transducer_model.py
Normal file
@ -0,0 +1,484 @@
|
||||
"""ESPnet2 ASR Transducer model."""
|
||||
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from packaging.version import parse as V
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from funasr.models.frontend.abs_frontend import AbsFrontend
|
||||
from funasr.models.specaug.abs_specaug import AbsSpecAug
|
||||
from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
|
||||
from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
|
||||
from funasr.models_transducer.encoder.encoder import Encoder
|
||||
from funasr.models_transducer.joint_network import JointNetwork
|
||||
from funasr.models_transducer.utils import get_transducer_task_io
|
||||
from funasr.layers.abs_normalize import AbsNormalize
|
||||
from funasr.torch_utils.device_funcs import force_gatherable
|
||||
from funasr.train.abs_espnet_model import AbsESPnetModel
|
||||
|
||||
# torch.cuda.amp.autocast only exists from torch 1.6.0; on older versions
# fall back to a no-op context manager so `with autocast(False):` still works.
if V(torch.__version__) >= V("1.6.0"):
    from torch.cuda.amp import autocast
else:

    @contextmanager
    def autocast(enabled=True):
        # No-op replacement: AMP is simply disabled on torch < 1.6.
        yield
|
||||
|
||||
|
||||
class ESPnetASRTransducerModel(AbsESPnetModel):
    """ESPnet2ASRTransducerModel module definition.

    Args:
        vocab_size: Size of complete vocabulary (w/ EOS and blank included).
        token_list: List of token units.
        frontend: Frontend module.
        specaug: SpecAugment module.
        normalize: Normalization module.
        encoder: Encoder module.
        decoder: Decoder module.
        att_decoder: Optional attention decoder module (kept for interface
            compatibility; not used by this model).
        joint_network: Joint Network module.
        transducer_weight: Weight of the Transducer loss.
        fastemit_lambda: FastEmit lambda value.
        auxiliary_ctc_weight: Weight of auxiliary CTC loss.
        auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs.
        auxiliary_lm_loss_weight: Weight of auxiliary LM loss.
        auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing.
        ignore_id: Initial padding ID.
        sym_space: Space symbol.
        sym_blank: Blank Symbol.
        report_cer: Whether to report Character Error Rate during validation.
        report_wer: Whether to report Word Error Rate during validation.
        extract_feats_in_collect_stats: Whether to use extract_feats stats collection.

    """

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        encoder: Encoder,
        decoder: AbsDecoder,
        att_decoder: Optional[AbsAttDecoder],
        joint_network: JointNetwork,
        transducer_weight: float = 1.0,
        fastemit_lambda: float = 0.0,
        auxiliary_ctc_weight: float = 0.0,
        auxiliary_ctc_dropout_rate: float = 0.0,
        auxiliary_lm_loss_weight: float = 0.0,
        auxiliary_lm_loss_smoothing: float = 0.0,
        ignore_id: int = -1,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        report_cer: bool = True,
        report_wer: bool = True,
        extract_feats_in_collect_stats: bool = True,
    ) -> None:
        """Construct an ESPnetASRTransducerModel object."""
        super().__init__()

        assert check_argument_types()

        # The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
        self.blank_id = 0
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        # Copy so later mutation of the caller's list cannot affect the model.
        self.token_list = token_list.copy()

        self.sym_space = sym_space
        self.sym_blank = sym_blank

        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize

        self.encoder = encoder
        self.decoder = decoder
        self.joint_network = joint_network

        # Built lazily on first use (loss import / beam search are costly).
        self.criterion_transducer = None
        self.error_calculator = None

        self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
        self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0

        if self.use_auxiliary_ctc:
            self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size)
            self.ctc_dropout_rate = auxiliary_ctc_dropout_rate

        if self.use_auxiliary_lm_loss:
            self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
            self.lm_loss_smoothing = auxiliary_lm_loss_smoothing

        self.transducer_weight = transducer_weight
        self.fastemit_lambda = fastemit_lambda

        self.auxiliary_ctc_weight = auxiliary_ctc_weight
        self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight

        self.report_cer = report_cer
        self.report_wer = report_wer

        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Forward architecture and compute loss(es).

        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)
            text: Label ID sequences. (B, L)
            text_lengths: Label ID sequences lengths. (B,)
            kwargs: Contains "utts_id".

        Return:
            loss: Main loss value.
            stats: Task statistics.
            weight: Task weights.

        """
        assert text_lengths.dim() == 1, text_lengths.shape
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)

        batch_size = speech.shape[0]
        # Trim padding beyond the longest label sequence (data-parallel safety).
        text = text[:, : text_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        # 2. Transducer-related I/O preparation
        decoder_in, target, t_len, u_len = get_transducer_task_io(
            text,
            encoder_out_lens,
            ignore_id=self.ignore_id,
        )

        # 3. Decoder
        self.decoder.set_device(encoder_out.device)
        decoder_out = self.decoder(decoder_in, u_len)

        # 4. Joint Network: combine (B, T, 1, D) x (B, 1, U, D) -> (B, T, U, V)
        joint_out = self.joint_network(
            encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
        )

        # 5. Losses
        loss_trans, cer_trans, wer_trans = self._calc_transducer_loss(
            encoder_out,
            joint_out,
            target,
            t_len,
            u_len,
        )

        loss_ctc, loss_lm = 0.0, 0.0

        if self.use_auxiliary_ctc:
            loss_ctc = self._calc_ctc_loss(
                encoder_out,
                target,
                t_len,
                u_len,
            )

        if self.use_auxiliary_lm_loss:
            loss_lm = self._calc_lm_loss(decoder_out, target)

        loss = (
            self.transducer_weight * loss_trans
            + self.auxiliary_ctc_weight * loss_ctc
            + self.auxiliary_lm_loss_weight * loss_lm
        )

        stats = dict(
            loss=loss.detach(),
            loss_transducer=loss_trans.detach(),
            # Auxiliary losses stay 0.0 (a float) when disabled, so only
            # detach them when they are actual tensors (> 0.0).
            aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
            aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
            cer_transducer=cer_trans,
            wer_transducer=wer_trans,
        )

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)

        return loss, stats, weight

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Collect features sequences and features lengths sequences.

        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)
            text: Label ID sequences. (B, L)
            text_lengths: Label ID sequences lengths. (B,)
            kwargs: Contains "utts_id".

        Return:
            {}: "feats": Features sequences. (B, T, D_feats),
                "feats_lengths": Features sequences lengths. (B,)

        """
        if self.extract_feats_in_collect_stats:
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        else:
            # Generate dummy stats if extract_feats_in_collect_stats is False
            logging.warning(
                "Generating dummy stats for feats and feats_lengths, "
                "because encoder_conf.extract_feats_in_collect_stats is "
                f"{self.extract_feats_in_collect_stats}"
            )

            feats, feats_lengths = speech, speech_lengths

        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encoder speech sequences.

        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)

        Return:
            encoder_out: Encoder outputs. (B, T, D_enc)
            encoder_out_lens: Encoder outputs lengths. (B,)

        """
        # Feature extraction and augmentation are kept in fp32 (AMP off).
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

            # 2. Data augmentation (training only)
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # 4. Forward encoder
        encoder_out, encoder_out_lens = self.encoder(feats, feats_lengths)

        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        assert encoder_out.size(1) <= encoder_out_lens.max(), (
            encoder_out.size(),
            encoder_out_lens.max(),
        )

        return encoder_out, encoder_out_lens

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract features sequences and features sequences lengths.

        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)

        Return:
            feats: Features sequences. (B, T, D_feats)
            feats_lengths: Features sequences lengths. (B,)

        """
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend: raw input is already the feature sequence.
            feats, feats_lengths = speech, speech_lengths

        return feats, feats_lengths

    def _calc_transducer_loss(
        self,
        encoder_out: torch.Tensor,
        joint_out: torch.Tensor,
        target: torch.Tensor,
        t_len: torch.Tensor,
        u_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]:
        """Compute Transducer loss.

        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            joint_out: Joint Network output sequences (B, T, U, D_joint)
            target: Target label ID sequences. (B, L)
            t_len: Encoder output sequences lengths. (B,)
            u_len: Target label ID sequences lengths. (B,)

        Return:
            loss_transducer: Transducer loss value.
            cer_transducer: Character error rate for Transducer.
            wer_transducer: Word Error Rate for Transducer.

        """
        if self.criterion_transducer is None:
            try:
                # Optional dependency, imported lazily on first loss call.
                from warp_rnnt import rnnt_loss as RNNTLoss

                self.criterion_transducer = RNNTLoss

            except ImportError:
                logging.error(
                    "warp-rnnt was not installed."
                    "Please consult the installation documentation."
                )
                exit(1)

        # warp_rnnt expects log-probabilities, not raw joint logits.
        log_probs = torch.log_softmax(joint_out, dim=-1)

        loss_transducer = self.criterion_transducer(
            log_probs,
            target,
            t_len,
            u_len,
            reduction="mean",
            blank=self.blank_id,
            fastemit_lambda=self.fastemit_lambda,
            gather=True,
        )

        if not self.training and (self.report_cer or self.report_wer):
            if self.error_calculator is None:
                # Lazy import/construction: beam search is only needed for
                # validation-time error reporting.
                from funasr.models_transducer.error_calculator import ErrorCalculator

                self.error_calculator = ErrorCalculator(
                    self.decoder,
                    self.joint_network,
                    self.token_list,
                    self.sym_space,
                    self.sym_blank,
                    report_cer=self.report_cer,
                    report_wer=self.report_wer,
                )

            cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)

            return loss_transducer, cer_transducer, wer_transducer

        return loss_transducer, None, None

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        target: torch.Tensor,
        t_len: torch.Tensor,
        u_len: torch.Tensor,
    ) -> torch.Tensor:
        """Compute CTC loss.

        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            target: Target label ID sequences. (B, L)
            t_len: Encoder output sequences lengths. (B,)
            u_len: Target label ID sequences lengths. (B,)

        Return:
            loss_ctc: CTC loss value.

        """
        ctc_in = self.ctc_lin(
            torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
        )
        # ctc_loss expects (T, B, V) log-probabilities.
        ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)

        # Flatten non-padded targets (0 is blank/padding here).
        target_mask = target != 0
        ctc_target = target[target_mask].cpu()

        # Deterministic cudnn for reproducible CTC gradients.
        with torch.backends.cudnn.flags(deterministic=True):
            loss_ctc = torch.nn.functional.ctc_loss(
                ctc_in,
                ctc_target,
                t_len,
                u_len,
                zero_infinity=True,
                reduction="sum",
            )
        # Normalize the summed loss by batch size.
        loss_ctc /= target.size(0)

        return loss_ctc

    def _calc_lm_loss(
        self,
        decoder_out: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """Compute LM loss (label-smoothed cross-entropy on decoder outputs).

        Args:
            decoder_out: Decoder output sequences. (B, U, D_dec)
            target: Target label ID sequences. (B, L)

        Return:
            loss_lm: LM loss value.

        """
        # Drop the last decoder step so predictions align with targets.
        lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
        lm_target = target.view(-1).type(torch.int64)

        with torch.no_grad():
            # Build the label-smoothed target distribution.
            true_dist = lm_loss_in.clone()
            true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))

            # Ignore blank ID (0)
            ignore = lm_target == 0
            lm_target = lm_target.masked_fill(ignore, 0)

            true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))

        loss_lm = torch.nn.functional.kl_div(
            torch.log_softmax(lm_loss_in, dim=1),
            true_dist,
            reduction="none",
        )
        # Zero out ignored positions, then normalize by batch size.
        loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(
            0
        )

        return loss_lm
|
||||
485
funasr/models_transducer/espnet_transducer_model_uni_asr.py
Normal file
485
funasr/models_transducer/espnet_transducer_model_uni_asr.py
Normal file
@ -0,0 +1,485 @@
|
||||
"""ESPnet2 ASR Transducer model."""
|
||||
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from packaging.version import parse as V
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from funasr.models.frontend.abs_frontend import AbsFrontend
|
||||
from funasr.models.specaug.abs_specaug import AbsSpecAug
|
||||
from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
|
||||
from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
|
||||
from funasr.models_transducer.encoder.encoder import Encoder
|
||||
from funasr.models_transducer.joint_network import JointNetwork
|
||||
from funasr.models_transducer.utils import get_transducer_task_io
|
||||
from funasr.layers.abs_normalize import AbsNormalize
|
||||
from funasr.torch_utils.device_funcs import force_gatherable
|
||||
from funasr.train.abs_espnet_model import AbsESPnetModel
|
||||
|
||||
# torch.cuda.amp.autocast only exists from torch 1.6.0; on older versions
# fall back to a no-op context manager so `with autocast(False):` still works.
if V(torch.__version__) >= V("1.6.0"):
    from torch.cuda.amp import autocast
else:

    @contextmanager
    def autocast(enabled=True):
        # No-op replacement: AMP is simply disabled on torch < 1.6.
        yield
|
||||
|
||||
|
||||
class UniASRTransducerModel(AbsESPnetModel):
|
||||
"""ESPnet2ASRTransducerModel module definition.
|
||||
|
||||
Args:
|
||||
vocab_size: Size of complete vocabulary (w/ EOS and blank included).
|
||||
token_list: List of token
|
||||
frontend: Frontend module.
|
||||
specaug: SpecAugment module.
|
||||
normalize: Normalization module.
|
||||
encoder: Encoder module.
|
||||
decoder: Decoder module.
|
||||
joint_network: Joint Network module.
|
||||
transducer_weight: Weight of the Transducer loss.
|
||||
fastemit_lambda: FastEmit lambda value.
|
||||
auxiliary_ctc_weight: Weight of auxiliary CTC loss.
|
||||
auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs.
|
||||
auxiliary_lm_loss_weight: Weight of auxiliary LM loss.
|
||||
auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing.
|
||||
ignore_id: Initial padding ID.
|
||||
sym_space: Space symbol.
|
||||
sym_blank: Blank Symbol
|
||||
report_cer: Whether to report Character Error Rate during validation.
|
||||
report_wer: Whether to report Word Error Rate during validation.
|
||||
extract_feats_in_collect_stats: Whether to use extract_feats stats collection.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
token_list: Union[Tuple[str, ...], List[str]],
|
||||
frontend: Optional[AbsFrontend],
|
||||
specaug: Optional[AbsSpecAug],
|
||||
normalize: Optional[AbsNormalize],
|
||||
encoder,
|
||||
decoder: AbsDecoder,
|
||||
att_decoder: Optional[AbsAttDecoder],
|
||||
joint_network: JointNetwork,
|
||||
transducer_weight: float = 1.0,
|
||||
fastemit_lambda: float = 0.0,
|
||||
auxiliary_ctc_weight: float = 0.0,
|
||||
auxiliary_ctc_dropout_rate: float = 0.0,
|
||||
auxiliary_lm_loss_weight: float = 0.0,
|
||||
auxiliary_lm_loss_smoothing: float = 0.0,
|
||||
ignore_id: int = -1,
|
||||
sym_space: str = "<space>",
|
||||
sym_blank: str = "<blank>",
|
||||
report_cer: bool = True,
|
||||
report_wer: bool = True,
|
||||
extract_feats_in_collect_stats: bool = True,
|
||||
) -> None:
|
||||
"""Construct an ESPnetASRTransducerModel object."""
|
||||
super().__init__()
|
||||
|
||||
assert check_argument_types()
|
||||
|
||||
# The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
|
||||
self.blank_id = 0
|
||||
self.vocab_size = vocab_size
|
||||
self.ignore_id = ignore_id
|
||||
self.token_list = token_list.copy()
|
||||
|
||||
self.sym_space = sym_space
|
||||
self.sym_blank = sym_blank
|
||||
|
||||
self.frontend = frontend
|
||||
self.specaug = specaug
|
||||
self.normalize = normalize
|
||||
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
self.joint_network = joint_network
|
||||
|
||||
self.criterion_transducer = None
|
||||
self.error_calculator = None
|
||||
|
||||
self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
|
||||
self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0
|
||||
|
||||
if self.use_auxiliary_ctc:
|
||||
self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size)
|
||||
self.ctc_dropout_rate = auxiliary_ctc_dropout_rate
|
||||
|
||||
if self.use_auxiliary_lm_loss:
|
||||
self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
|
||||
self.lm_loss_smoothing = auxiliary_lm_loss_smoothing
|
||||
|
||||
self.transducer_weight = transducer_weight
|
||||
self.fastemit_lambda = fastemit_lambda
|
||||
|
||||
self.auxiliary_ctc_weight = auxiliary_ctc_weight
|
||||
self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight
|
||||
|
||||
self.report_cer = report_cer
|
||||
self.report_wer = report_wer
|
||||
|
||||
self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
|
||||
|
||||
def forward(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
decoding_ind: int = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
"""Forward architecture and compute loss(es).
|
||||
|
||||
Args:
|
||||
speech: Speech sequences. (B, S)
|
||||
speech_lengths: Speech sequences lengths. (B,)
|
||||
text: Label ID sequences. (B, L)
|
||||
text_lengths: Label ID sequences lengths. (B,)
|
||||
kwargs: Contains "utts_id".
|
||||
|
||||
Return:
|
||||
loss: Main loss value.
|
||||
stats: Task statistics.
|
||||
weight: Task weights.
|
||||
|
||||
"""
|
||||
assert text_lengths.dim() == 1, text_lengths.shape
|
||||
assert (
|
||||
speech.shape[0]
|
||||
== speech_lengths.shape[0]
|
||||
== text.shape[0]
|
||||
== text_lengths.shape[0]
|
||||
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
|
||||
|
||||
batch_size = speech.shape[0]
|
||||
text = text[:, : text_lengths.max()]
|
||||
|
||||
# 1. Encoder
|
||||
ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
|
||||
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
|
||||
# 2. Transducer-related I/O preparation
|
||||
decoder_in, target, t_len, u_len = get_transducer_task_io(
|
||||
text,
|
||||
encoder_out_lens,
|
||||
ignore_id=self.ignore_id,
|
||||
)
|
||||
|
||||
# 3. Decoder
|
||||
self.decoder.set_device(encoder_out.device)
|
||||
decoder_out = self.decoder(decoder_in, u_len)
|
||||
|
||||
# 4. Joint Network
|
||||
joint_out = self.joint_network(
|
||||
encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
|
||||
)
|
||||
|
||||
# 5. Losses
|
||||
loss_trans, cer_trans, wer_trans = self._calc_transducer_loss(
|
||||
encoder_out,
|
||||
joint_out,
|
||||
target,
|
||||
t_len,
|
||||
u_len,
|
||||
)
|
||||
|
||||
loss_ctc, loss_lm = 0.0, 0.0
|
||||
|
||||
if self.use_auxiliary_ctc:
|
||||
loss_ctc = self._calc_ctc_loss(
|
||||
encoder_out,
|
||||
target,
|
||||
t_len,
|
||||
u_len,
|
||||
)
|
||||
|
||||
if self.use_auxiliary_lm_loss:
|
||||
loss_lm = self._calc_lm_loss(decoder_out, target)
|
||||
|
||||
loss = (
|
||||
self.transducer_weight * loss_trans
|
||||
+ self.auxiliary_ctc_weight * loss_ctc
|
||||
+ self.auxiliary_lm_loss_weight * loss_lm
|
||||
)
|
||||
|
||||
stats = dict(
|
||||
loss=loss.detach(),
|
||||
loss_transducer=loss_trans.detach(),
|
||||
aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
|
||||
aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
|
||||
cer_transducer=cer_trans,
|
||||
wer_transducer=wer_trans,
|
||||
)
|
||||
|
||||
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
||||
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
||||
|
||||
return loss, stats, weight
|
||||
|
||||
def collect_feats(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""Collect features sequences and features lengths sequences.
|
||||
|
||||
Args:
|
||||
speech: Speech sequences. (B, S)
|
||||
speech_lengths: Speech sequences lengths. (B,)
|
||||
text: Label ID sequences. (B, L)
|
||||
text_lengths: Label ID sequences lengths. (B,)
|
||||
kwargs: Contains "utts_id".
|
||||
|
||||
Return:
|
||||
{}: "feats": Features sequences. (B, T, D_feats),
|
||||
"feats_lengths": Features sequences lengths. (B,)
|
||||
|
||||
"""
|
||||
if self.extract_feats_in_collect_stats:
|
||||
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
|
||||
else:
|
||||
# Generate dummy stats if extract_feats_in_collect_stats is False
|
||||
logging.warning(
|
||||
"Generating dummy stats for feats and feats_lengths, "
|
||||
"because encoder_conf.extract_feats_in_collect_stats is "
|
||||
f"{self.extract_feats_in_collect_stats}"
|
||||
)
|
||||
|
||||
feats, feats_lengths = speech, speech_lengths
|
||||
|
||||
return {"feats": feats, "feats_lengths": feats_lengths}
|
||||
|
||||
def encode(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
ind: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Encoder speech sequences.
|
||||
|
||||
Args:
|
||||
speech: Speech sequences. (B, S)
|
||||
speech_lengths: Speech sequences lengths. (B,)
|
||||
|
||||
Return:
|
||||
encoder_out: Encoder outputs. (B, T, D_enc)
|
||||
encoder_out_lens: Encoder outputs lengths. (B,)
|
||||
|
||||
"""
|
||||
with autocast(False):
|
||||
# 1. Extract feats
|
||||
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
|
||||
|
||||
# 2. Data augmentation
|
||||
if self.specaug is not None and self.training:
|
||||
feats, feats_lengths = self.specaug(feats, feats_lengths)
|
||||
|
||||
# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
|
||||
if self.normalize is not None:
|
||||
feats, feats_lengths = self.normalize(feats, feats_lengths)
|
||||
|
||||
# 4. Forward encoder
|
||||
encoder_out, encoder_out_lens = self.encoder(feats, feats_lengths, ind=ind)
|
||||
|
||||
assert encoder_out.size(0) == speech.size(0), (
|
||||
encoder_out.size(),
|
||||
speech.size(0),
|
||||
)
|
||||
assert encoder_out.size(1) <= encoder_out_lens.max(), (
|
||||
encoder_out.size(),
|
||||
encoder_out_lens.max(),
|
||||
)
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
def _extract_feats(
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Extract features sequences and features sequences lengths.
|
||||
|
||||
Args:
|
||||
speech: Speech sequences. (B, S)
|
||||
speech_lengths: Speech sequences lengths. (B,)
|
||||
|
||||
Return:
|
||||
feats: Features sequences. (B, T, D_feats)
|
||||
feats_lengths: Features sequences lengths. (B,)
|
||||
|
||||
"""
|
||||
assert speech_lengths.dim() == 1, speech_lengths.shape
|
||||
|
||||
# for data-parallel
|
||||
speech = speech[:, : speech_lengths.max()]
|
||||
|
||||
if self.frontend is not None:
|
||||
feats, feats_lengths = self.frontend(speech, speech_lengths)
|
||||
else:
|
||||
feats, feats_lengths = speech, speech_lengths
|
||||
|
||||
return feats, feats_lengths
|
||||
|
||||
def _calc_transducer_loss(
    self,
    encoder_out: torch.Tensor,
    joint_out: torch.Tensor,
    target: torch.Tensor,
    t_len: torch.Tensor,
    u_len: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[float], Optional[float]]:
    """Compute Transducer loss.

    Args:
        encoder_out: Encoder output sequences. (B, T, D_enc)
        joint_out: Joint Network output sequences (B, T, U, D_joint)
        target: Target label ID sequences. (B, L)
        t_len: Encoder output sequences lengths. (B,)
        u_len: Target label ID sequences lengths. (B,)

    Return:
        loss_transducer: Transducer loss value.
        cer_transducer: Character error rate for Transducer.
        wer_transducer: Word Error Rate for Transducer.

    """
    # Lazily import and cache the loss function on first use, so the
    # optional warp-rnnt dependency is only needed when a transducer
    # loss is actually computed.
    if self.criterion_transducer is None:
        try:
            # from warprnnt_pytorch import RNNTLoss
            # self.criterion_transducer = RNNTLoss(
            #     reduction="mean",
            #     fastemit_lambda=self.fastemit_lambda,
            # )
            from warp_rnnt import rnnt_loss as RNNTLoss
            self.criterion_transducer = RNNTLoss

        except ImportError:
            logging.error(
                "warp-rnnt was not installed."
                "Please consult the installation documentation."
            )
            exit(1)

    # loss_transducer = self.criterion_transducer(
    #     joint_out,
    #     target,
    #     t_len,
    #     u_len,
    # )
    # warp_rnnt's rnnt_loss expects log-probabilities, not raw logits.
    log_probs = torch.log_softmax(joint_out, dim=-1)

    # NOTE(review): self.fastemit_lambda is not forwarded to rnnt_loss here,
    # unlike the unified-model variant of this method -- confirm whether
    # FastEmit regularization should also apply to this loss.
    loss_transducer = self.criterion_transducer(
        log_probs,
        target,
        t_len,
        u_len,
        reduction="mean",
        blank=self.blank_id,
        gather=True,
    )

    # CER/WER are only computed during validation; the ErrorCalculator runs
    # decoding, which is expensive, so it is built lazily and cached.
    if not self.training and (self.report_cer or self.report_wer):
        if self.error_calculator is None:
            from espnet2.asr_transducer.error_calculator import ErrorCalculator

            self.error_calculator = ErrorCalculator(
                self.decoder,
                self.joint_network,
                self.token_list,
                self.sym_space,
                self.sym_blank,
                report_cer=self.report_cer,
                report_wer=self.report_wer,
            )

        cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)

        return loss_transducer, cer_transducer, wer_transducer

    return loss_transducer, None, None
|
||||
|
||||
def _calc_ctc_loss(
    self,
    encoder_out: torch.Tensor,
    target: torch.Tensor,
    t_len: torch.Tensor,
    u_len: torch.Tensor,
) -> torch.Tensor:
    """Compute CTC loss.

    Args:
        encoder_out: Encoder output sequences. (B, T, D_enc)
        target: Target label ID sequences. (B, L)
        t_len: Encoder output sequences lengths. (B,)
        u_len: Target label ID sequences lengths. (B,)

    Return:
        loss_ctc: CTC loss value.

    """
    # Project encoder states to vocabulary logits; dropout regularizes
    # only this auxiliary branch.
    ctc_in = self.ctc_lin(
        torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
    )
    # ctc_loss expects (T, B, V) log-probabilities.
    ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)

    # Flatten the non-padding labels into a 1-D target (ID 0 is treated
    # as blank/padding here) and move it to CPU for ctc_loss.
    target_mask = target != 0
    ctc_target = target[target_mask].cpu()

    # Force deterministic cuDNN so CTC gradients are reproducible.
    with torch.backends.cudnn.flags(deterministic=True):
        loss_ctc = torch.nn.functional.ctc_loss(
            ctc_in,
            ctc_target,
            t_len,
            u_len,
            zero_infinity=True,
            reduction="sum",
        )
    # Average the summed loss over the batch.
    loss_ctc /= target.size(0)

    return loss_ctc
|
||||
|
||||
def _calc_lm_loss(
    self,
    decoder_out: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    """Compute LM loss (label-smoothed cross-entropy via KL divergence).

    Args:
        decoder_out: Decoder output sequences. (B, U, D_dec)
        target: Target label ID sequences. (B, L)

    Return:
        loss_lm: LM loss value.

    """
    # Drop the last decoder step (next-token prediction) and flatten to
    # (B * (U - 1), vocab_size).
    lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
    lm_target = target.view(-1).type(torch.int64)

    # Build the smoothed target distribution without tracking gradients.
    with torch.no_grad():
        true_dist = lm_loss_in.clone()
        true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))

        # Ignore blank ID (0)
        ignore = lm_target == 0
        lm_target = lm_target.masked_fill(ignore, 0)

        true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))

    loss_lm = torch.nn.functional.kl_div(
        torch.log_softmax(lm_loss_in, dim=1),
        true_dist,
        reduction="none",
    )
    # Zero out ignored (blank) positions, then average over the batch.
    loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(
        0
    )

    return loss_lm
|
||||
588
funasr/models_transducer/espnet_transducer_model_unified.py
Normal file
588
funasr/models_transducer/espnet_transducer_model_unified.py
Normal file
@ -0,0 +1,588 @@
|
||||
"""ESPnet2 ASR Transducer model."""
|
||||
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from packaging.version import parse as V
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from funasr.models.frontend.abs_frontend import AbsFrontend
|
||||
from funasr.models.specaug.abs_specaug import AbsSpecAug
|
||||
from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
|
||||
from funasr.models_transducer.encoder.encoder import Encoder
|
||||
from funasr.models_transducer.joint_network import JointNetwork
|
||||
from funasr.models_transducer.utils import get_transducer_task_io
|
||||
from funasr.layers.abs_normalize import AbsNormalize
|
||||
from funasr.torch_utils.device_funcs import force_gatherable
|
||||
from funasr.train.abs_espnet_model import AbsESPnetModel
|
||||
from funasr.modules.add_sos_eos import add_sos_eos
|
||||
from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
|
||||
from funasr.modules.nets_utils import th_accuracy
|
||||
from funasr.losses.label_smoothing_loss import ( # noqa: H301
|
||||
LabelSmoothingLoss,
|
||||
)
|
||||
from funasr.models_transducer.error_calculator import ErrorCalculator
|
||||
if V(torch.__version__) >= V("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # torch < 1.6 has no AMP support: provide a no-op stand-in so that
    # `with autocast(False):` in encode() still works.

    @contextmanager
    def autocast(enabled=True):
        yield
|
||||
|
||||
|
||||
class ESPnetASRUnifiedTransducerModel(AbsESPnetModel):
    """Unified (offline + chunk/streaming) ASR Transducer model.

    The encoder produces two views of the same input: a full-context output
    and a chunk-restricted output.  Transducer (and optional auxiliary CTC /
    attention) losses are computed on both views and summed, so a single
    model supports both offline and streaming decoding.

    Args:
        vocab_size: Size of complete vocabulary (w/ EOS and blank included).
        token_list: List of token
        frontend: Frontend module.
        specaug: SpecAugment module.
        normalize: Normalization module.
        encoder: Encoder module.
        decoder: Decoder module.
        att_decoder: Attention decoder module for the auxiliary attention loss.
        joint_network: Joint Network module.
        transducer_weight: Weight of the Transducer loss.
        fastemit_lambda: FastEmit lambda value.
        auxiliary_ctc_weight: Weight of auxiliary CTC loss.
        auxiliary_att_weight: Weight of auxiliary attention loss.
        auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs.
        auxiliary_lm_loss_weight: Weight of auxiliary LM loss.
        auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing.
        ignore_id: Initial padding ID.
        sym_space: Space symbol.
        sym_blank: Blank Symbol
        report_cer: Whether to report Character Error Rate during validation.
        report_wer: Whether to report Word Error Rate during validation.
        sym_sos: Start-of-sentence symbol.
        sym_eos: End-of-sentence symbol.
        extract_feats_in_collect_stats: Whether to use extract_feats stats collection.
        lsm_weight: Label-smoothing weight for the auxiliary attention loss.
        length_normalized_loss: Whether to normalize the attention loss by length.

    """

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        encoder: Encoder,
        decoder: AbsDecoder,
        att_decoder: Optional[AbsAttDecoder],
        joint_network: JointNetwork,
        transducer_weight: float = 1.0,
        fastemit_lambda: float = 0.0,
        auxiliary_ctc_weight: float = 0.0,
        auxiliary_att_weight: float = 0.0,
        auxiliary_ctc_dropout_rate: float = 0.0,
        auxiliary_lm_loss_weight: float = 0.0,
        auxiliary_lm_loss_smoothing: float = 0.0,
        ignore_id: int = -1,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        report_cer: bool = True,
        report_wer: bool = True,
        sym_sos: str = "<sos/eos>",
        sym_eos: str = "<sos/eos>",
        extract_feats_in_collect_stats: bool = True,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
    ) -> None:
        """Construct an ESPnetASRUnifiedTransducerModel object."""
        super().__init__()

        assert check_argument_types()

        # The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
        self.blank_id = 0

        if sym_sos in token_list:
            self.sos = token_list.index(sym_sos)
        else:
            self.sos = vocab_size - 1
        if sym_eos in token_list:
            self.eos = token_list.index(sym_eos)
        else:
            self.eos = vocab_size - 1

        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        # Copy so later mutation by the caller cannot affect this model.
        self.token_list = token_list.copy()

        self.sym_space = sym_space
        self.sym_blank = sym_blank

        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize

        self.encoder = encoder
        self.decoder = decoder
        self.joint_network = joint_network

        # Transducer criterion and error calculator are built lazily on
        # first use (see _calc_transducer_loss).
        self.criterion_transducer = None
        self.error_calculator = None

        self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
        self.use_auxiliary_att = auxiliary_att_weight > 0
        self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0

        if self.use_auxiliary_ctc:
            self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size)
            self.ctc_dropout_rate = auxiliary_ctc_dropout_rate

        if self.use_auxiliary_att:
            self.att_decoder = att_decoder

            self.criterion_att = LabelSmoothingLoss(
                size=vocab_size,
                padding_idx=ignore_id,
                smoothing=lsm_weight,
                normalize_length=length_normalized_loss,
            )

        if self.use_auxiliary_lm_loss:
            self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
            self.lm_loss_smoothing = auxiliary_lm_loss_smoothing

        self.transducer_weight = transducer_weight
        self.fastemit_lambda = fastemit_lambda

        self.auxiliary_ctc_weight = auxiliary_ctc_weight
        self.auxiliary_att_weight = auxiliary_att_weight
        self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight

        self.report_cer = report_cer
        self.report_wer = report_wer

        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Forward architecture and compute loss(es).

        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)
            text: Label ID sequences. (B, L)
            text_lengths: Label ID sequences lengths. (B,)
            kwargs: Contains "utts_id".

        Return:
            loss: Main loss value.
            stats: Task statistics.
            weight: Task weights.

        """
        assert text_lengths.dim() == 1, text_lengths.shape
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)

        batch_size = speech.shape[0]
        text = text[:, : text_lengths.max()]

        # 1. Encoder: full-context and chunk-restricted outputs.
        encoder_out, encoder_out_chunk, encoder_out_lens = self.encode(speech, speech_lengths)

        loss_att, loss_att_chunk = 0.0, 0.0

        if self.use_auxiliary_att:
            loss_att, _ = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
            loss_att_chunk, _ = self._calc_att_loss(
                encoder_out_chunk, encoder_out_lens, text, text_lengths
            )

        # 2. Transducer-related I/O preparation
        decoder_in, target, t_len, u_len = get_transducer_task_io(
            text,
            encoder_out_lens,
            ignore_id=self.ignore_id,
        )

        # 3. Decoder (shared between the two encoder views).
        self.decoder.set_device(encoder_out.device)
        decoder_out = self.decoder(decoder_in, u_len)

        # 4. Joint Network
        joint_out = self.joint_network(
            encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
        )

        joint_out_chunk = self.joint_network(
            encoder_out_chunk.unsqueeze(2), decoder_out.unsqueeze(1)
        )

        # 5. Losses
        loss_trans_utt, cer_trans, wer_trans = self._calc_transducer_loss(
            encoder_out,
            joint_out,
            target,
            t_len,
            u_len,
        )

        loss_trans_chunk, cer_trans_chunk, wer_trans_chunk = self._calc_transducer_loss(
            encoder_out_chunk,
            joint_out_chunk,
            target,
            t_len,
            u_len,
        )

        loss_ctc, loss_ctc_chunk, loss_lm = 0.0, 0.0, 0.0

        if self.use_auxiliary_ctc:
            loss_ctc = self._calc_ctc_loss(
                encoder_out,
                target,
                t_len,
                u_len,
            )
            loss_ctc_chunk = self._calc_ctc_loss(
                encoder_out_chunk,
                target,
                t_len,
                u_len,
            )

        if self.use_auxiliary_lm_loss:
            loss_lm = self._calc_lm_loss(decoder_out, target)

        # Combine the full-context and chunk losses of each kind.
        loss_trans = loss_trans_utt + loss_trans_chunk
        loss_ctc = loss_ctc + loss_ctc_chunk
        # FIX: this previously assigned the attention sum to loss_ctc,
        # clobbering the CTC loss and leaving loss_att un-combined.
        loss_att = loss_att + loss_att_chunk

        loss = (
            self.transducer_weight * loss_trans
            + self.auxiliary_ctc_weight * loss_ctc
            + self.auxiliary_att_weight * loss_att
            + self.auxiliary_lm_loss_weight * loss_lm
        )

        stats = dict(
            loss=loss.detach(),
            loss_transducer=loss_trans_utt.detach(),
            loss_transducer_chunk=loss_trans_chunk.detach(),
            aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
            aux_ctc_loss_chunk=loss_ctc_chunk.detach() if loss_ctc_chunk > 0.0 else None,
            aux_att_loss=loss_att.detach() if loss_att > 0.0 else None,
            aux_att_loss_chunk=loss_att_chunk.detach() if loss_att_chunk > 0.0 else None,
            aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
            cer_transducer=cer_trans,
            wer_transducer=wer_trans,
            cer_transducer_chunk=cer_trans_chunk,
            wer_transducer_chunk=wer_trans_chunk,
        )

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)

        return loss, stats, weight

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Collect features sequences and features lengths sequences.

        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)
            text: Label ID sequences. (B, L)
            text_lengths: Label ID sequences lengths. (B,)
            kwargs: Contains "utts_id".

        Return:
            {}: "feats": Features sequences. (B, T, D_feats),
                "feats_lengths": Features sequences lengths. (B,)

        """
        if self.extract_feats_in_collect_stats:
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        else:
            # Generate dummy stats if extract_feats_in_collect_stats is False
            logging.warning(
                "Generating dummy stats for feats and feats_lengths, "
                "because encoder_conf.extract_feats_in_collect_stats is "
                f"{self.extract_feats_in_collect_stats}"
            )

            feats, feats_lengths = speech, speech_lengths

        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode speech sequences.

        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)

        Return:
            encoder_out: Full-context encoder outputs. (B, T, D_enc)
            encoder_out_chunk: Chunk-restricted encoder outputs. (B, T, D_enc)
            encoder_out_lens: Encoder outputs lengths. (B,)

        """
        # Feature extraction runs in fp32 even under AMP.
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # 4. Forward encoder
        encoder_out, encoder_out_chunk, encoder_out_lens = self.encoder(feats, feats_lengths)

        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        assert encoder_out.size(1) <= encoder_out_lens.max(), (
            encoder_out.size(),
            encoder_out_lens.max(),
        )

        return encoder_out, encoder_out_chunk, encoder_out_lens

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract features sequences and features sequences lengths.

        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)

        Return:
            feats: Features sequences. (B, T, D_feats)
            feats_lengths: Features sequences lengths. (B,)

        """
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            feats, feats_lengths = speech, speech_lengths

        return feats, feats_lengths

    def _calc_transducer_loss(
        self,
        encoder_out: torch.Tensor,
        joint_out: torch.Tensor,
        target: torch.Tensor,
        t_len: torch.Tensor,
        u_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]:
        """Compute Transducer loss.

        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            joint_out: Joint Network output sequences (B, T, U, D_joint)
            target: Target label ID sequences. (B, L)
            t_len: Encoder output sequences lengths. (B,)
            u_len: Target label ID sequences lengths. (B,)

        Return:
            loss_transducer: Transducer loss value.
            cer_transducer: Character error rate for Transducer.
            wer_transducer: Word Error Rate for Transducer.

        """
        # Lazily import warp-rnnt so the dependency is only required when a
        # transducer loss is actually computed.
        if self.criterion_transducer is None:
            try:
                from warp_rnnt import rnnt_loss as RNNTLoss
                self.criterion_transducer = RNNTLoss

            except ImportError:
                logging.error(
                    "warp-rnnt was not installed."
                    "Please consult the installation documentation."
                )
                exit(1)

        # rnnt_loss expects log-probabilities, not raw logits.
        log_probs = torch.log_softmax(joint_out, dim=-1)

        loss_transducer = self.criterion_transducer(
            log_probs,
            target,
            t_len,
            u_len,
            reduction="mean",
            blank=self.blank_id,
            fastemit_lambda=self.fastemit_lambda,
            gather=True,
        )

        # CER/WER reporting only happens during validation; the calculator
        # runs decoding, which is expensive, so it is built lazily.
        if not self.training and (self.report_cer or self.report_wer):
            if self.error_calculator is None:
                self.error_calculator = ErrorCalculator(
                    self.decoder,
                    self.joint_network,
                    self.token_list,
                    self.sym_space,
                    self.sym_blank,
                    report_cer=self.report_cer,
                    report_wer=self.report_wer,
                )

            cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)

            return loss_transducer, cer_transducer, wer_transducer

        return loss_transducer, None, None

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        target: torch.Tensor,
        t_len: torch.Tensor,
        u_len: torch.Tensor,
    ) -> torch.Tensor:
        """Compute CTC loss.

        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            target: Target label ID sequences. (B, L)
            t_len: Encoder output sequences lengths. (B,)
            u_len: Target label ID sequences lengths. (B,)

        Return:
            loss_ctc: CTC loss value.

        """
        # Project encoder states to vocabulary logits; dropout regularizes
        # only this auxiliary branch.
        ctc_in = self.ctc_lin(
            torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
        )
        # ctc_loss expects (T, B, V) log-probabilities.
        ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)

        # Flatten the non-padding labels into a 1-D CPU target tensor
        # (ID 0 is treated as blank/padding here).
        target_mask = target != 0
        ctc_target = target[target_mask].cpu()

        # Deterministic cuDNN keeps CTC gradients reproducible.
        with torch.backends.cudnn.flags(deterministic=True):
            loss_ctc = torch.nn.functional.ctc_loss(
                ctc_in,
                ctc_target,
                t_len,
                u_len,
                zero_infinity=True,
                reduction="sum",
            )
        loss_ctc /= target.size(0)

        return loss_ctc

    def _calc_lm_loss(
        self,
        decoder_out: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """Compute LM loss (label-smoothed cross-entropy via KL divergence).

        Args:
            decoder_out: Decoder output sequences. (B, U, D_dec)
            target: Target label ID sequences. (B, L)

        Return:
            loss_lm: LM loss value.

        """
        # Drop the last decoder step and flatten to (B * (U - 1), vocab_size).
        lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
        lm_target = target.view(-1).type(torch.int64)

        # Build the smoothed target distribution without tracking gradients.
        with torch.no_grad():
            true_dist = lm_loss_in.clone()
            true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))

            # Ignore blank ID (0)
            ignore = lm_target == 0
            lm_target = lm_target.masked_fill(ignore, 0)

            true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))

        loss_lm = torch.nn.functional.kl_div(
            torch.log_softmax(lm_loss_in, dim=1),
            true_dist,
            reduction="none",
        )
        # Zero out ignored (blank) positions, then average over the batch.
        loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(
            0
        )

        return loss_lm

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Compute the auxiliary attention-decoder loss and accuracy.

        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            encoder_out_lens: Encoder output sequences lengths. (B,)
            ys_pad: Padded label ID sequences. (B, L)
            ys_pad_lens: Label ID sequences lengths. (B,)

        Return:
            loss_att: Attention loss value.
            acc_att: Attention accuracy.

        """
        # Optionally prepend a language token (multilingual training setups).
        if hasattr(self, "lang_token_id") and self.lang_token_id is not None:
            ys_pad = torch.cat(
                [
                    self.lang_token_id.repeat(ys_pad.size(0), 1).to(ys_pad.device),
                    ys_pad,
                ],
                dim=1,
            )
            ys_pad_lens += 1

        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.att_decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )

        return loss_att, acc_att
|
||||
62
funasr/models_transducer/joint_network.py
Normal file
62
funasr/models_transducer/joint_network.py
Normal file
@ -0,0 +1,62 @@
|
||||
"""Transducer joint network implementation."""
|
||||
|
||||
import torch
|
||||
|
||||
from funasr.models_transducer.activation import get_activation
|
||||
|
||||
|
||||
class JointNetwork(torch.nn.Module):
    """Transducer joint network module.

    Combines an encoder state and a decoder (prediction network) state into
    a distribution over output labels.

    Args:
        output_size: Output size.
        encoder_size: Encoder output size.
        decoder_size: Decoder output size..
        joint_space_size: Joint space size.
        joint_act_type: Type of activation for joint network.
        **activation_parameters: Parameters for the activation function.

    """

    def __init__(
        self,
        output_size: int,
        encoder_size: int,
        decoder_size: int,
        joint_space_size: int = 256,
        joint_activation_type: str = "tanh",
        **activation_parameters,
    ) -> None:
        """Construct a JointNetwork object."""
        super().__init__()

        # Projections of encoder/decoder states into the shared joint space.
        # The decoder projection is bias-free, mirroring the reference
        # implementation.
        self.lin_enc = torch.nn.Linear(encoder_size, joint_space_size)
        self.lin_dec = torch.nn.Linear(decoder_size, joint_space_size, bias=False)

        # Output projection from the joint space to the label vocabulary.
        self.lin_out = torch.nn.Linear(joint_space_size, output_size)

        self.joint_activation = get_activation(
            joint_activation_type, **activation_parameters
        )

    def forward(
        self,
        enc_out: torch.Tensor,
        dec_out: torch.Tensor,
        project_input: bool = True,
    ) -> torch.Tensor:
        """Joint computation of encoder and decoder hidden state sequences.

        Args:
            enc_out: Expanded encoder output state sequences (B, T, 1, D_enc)
            dec_out: Expanded decoder output state sequences (B, 1, U, D_dec)
            project_input: Whether inputs still need the joint-space
                projections (callers may pre-project for efficiency).

        Returns:
            joint_out: Joint output state sequences. (B, T, U, D_out)

        """
        if project_input:
            joint_in = self.lin_enc(enc_out) + self.lin_dec(dec_out)
        else:
            # Caller already applied lin_enc / lin_dec.
            joint_in = enc_out + dec_out

        return self.lin_out(self.joint_activation(joint_in))
|
||||
200
funasr/models_transducer/utils.py
Normal file
200
funasr/models_transducer/utils.py
Normal file
@ -0,0 +1,200 @@
|
||||
"""Utility functions for Transducer models."""
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class TooShortUttError(Exception):
    """Raised when an utterance is too short for the subsampling module.

    Args:
        message: Error message to display.
        actual_size: The size that cannot pass the subsampling.
        limit: The size limit for subsampling.

    """

    def __init__(self, message: str, actual_size: int, limit: int) -> None:
        """Construct a TooShortUttError module."""
        super().__init__(message)

        # Keep the offending size and the limit available for callers that
        # want to report or recover from the error.
        self.actual_size = actual_size
        self.limit = limit
|
||||
|
||||
|
||||
def check_short_utt(sub_factor: int, size: int) -> Tuple[bool, int]:
    """Check if the input is too short for subsampling.

    Args:
        sub_factor: Subsampling factor for Conv2DSubsampling.
        size: Input size.

    Returns:
        : Whether an error should be sent.
        : Size limit for specified subsampling factor.

    """
    # Minimum required input size and reported limit per subsampling factor.
    thresholds = {2: (3, 7), 4: (7, 7), 6: (11, 11)}

    if sub_factor in thresholds:
        min_size, limit = thresholds[sub_factor]
        if size < min_size:
            return True, limit

    return False, -1
|
||||
|
||||
|
||||
def sub_factor_to_params(sub_factor: int, input_size: int) -> Tuple[int, int, int]:
    """Get conv2D second layer parameters for given subsampling factor.

    Args:
        sub_factor: Subsampling factor (1/X).
        input_size: Input size.

    Returns:
        : Kernel size for second convolution.
        : Stride for second convolution.
        : Conv2DSubsampling output size.

    """
    # Feature dimension after the first conv layer (kernel 3, stride 2).
    after_first = (input_size - 1) // 2

    if sub_factor == 2:
        return 3, 1, after_first - 2
    if sub_factor == 4:
        return 3, 2, (after_first - 1) // 2
    if sub_factor == 6:
        return 5, 3, (after_first - 2) // 3

    raise ValueError(
        "subsampling_factor parameter should be set to either 2, 4 or 6."
    )
|
||||
|
||||
|
||||
def make_chunk_mask(
    size: int,
    chunk_size: int,
    left_chunk_size: int = 0,
    device: torch.device = None,
) -> torch.Tensor:
    """Create chunk mask for the subsequent steps (size, size).

    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py

    Args:
        size: Size of the source mask.
        chunk_size: Number of frames in chunk.
        left_chunk_size: Size of the left context in chunks (0 means full context).
        device: Device for the mask tensor.

    Returns:
        mask: Chunk mask. (size, size)

    """
    visible = torch.zeros(size, size, device=device, dtype=torch.bool)

    for row in range(size):
        chunk_idx = row // chunk_size

        # With no left-context limit, every earlier frame is visible.
        if left_chunk_size <= 0:
            begin = 0
        else:
            begin = max((chunk_idx - left_chunk_size) * chunk_size, 0)

        end = min((chunk_idx + 1) * chunk_size, size)
        visible[row, begin:end] = True

    # Invert: True marks masked (invisible) positions.
    return ~visible
|
||||
|
||||
|
||||
def make_source_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Create source mask for given lengths.

    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py

    Args:
        lengths: Sequence lengths. (B,)

    Returns:
        : Mask for the sequence lengths. (B, max_len)

    """
    batch_size = lengths.size(0)
    longest = lengths.max()

    # Positions >= the sequence length are padding and get True in the mask.
    positions = torch.arange(longest).expand(batch_size, longest).to(lengths)

    return positions >= lengths.unsqueeze(1)
|
||||
|
||||
|
||||
def get_transducer_task_io(
    labels: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ignore_id: int = -1,
    blank_id: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Get Transducer loss I/O.

    Args:
        labels: Label ID sequences. (B, L)
        encoder_out_lens: Encoder output lengths. (B,)
        ignore_id: Padding symbol ID.
        blank_id: Blank symbol ID.

    Returns:
        decoder_in: Decoder inputs. (B, U)
        target: Target label ID sequences. (B, U)
        t_len: Time lengths. (B,)
        u_len: Label lengths. (B,)

    """
    device = labels.device

    # Strip the padding symbols from every label sequence.
    labels_unpad = [seq[seq != ignore_id] for seq in labels]
    blank = labels[0].new([blank_id])

    # Decoder input: each sequence prefixed with blank, right-padded with
    # blank to the longest sequence in the batch.
    decoder_in = torch.nn.utils.rnn.pad_sequence(
        [torch.cat([blank, seq], dim=0) for seq in labels_unpad],
        batch_first=True,
        padding_value=blank_id,
    ).to(device)

    # Loss target: unpadded labels, blank-padded, as int32 for warp-rnnt.
    target = (
        torch.nn.utils.rnn.pad_sequence(
            labels_unpad, batch_first=True, padding_value=blank_id
        )
        .type(torch.int32)
        .to(device)
    )

    t_len = torch.IntTensor([int(t) for t in encoder_out_lens]).to(device)
    u_len = torch.IntTensor([seq.size(0) for seq in labels_unpad]).to(device)

    return decoder_in, target, t_len, u_len
|
||||
|
||||
def pad_to_len(t: torch.Tensor, pad_len: int, dim: int):
    """Pad the tensor `t` at `dim` to the length `pad_len` with right padding zeros."""
    deficit = pad_len - t.size(dim)
    if deficit == 0:
        # Already at the target length: return the tensor untouched.
        return t

    filler_shape = list(t.shape)
    filler_shape[dim] = deficit
    filler = torch.zeros(*filler_shape, dtype=t.dtype, device=t.device)

    return torch.cat([t, filler], dim=dim)
|
||||
487
funasr/tasks/asr_transducer.py
Normal file
487
funasr/tasks/asr_transducer.py
Normal file
@ -0,0 +1,487 @@
|
||||
"""ASR Transducer Task."""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from typing import Callable, Collection, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typeguard import check_argument_types, check_return_type
|
||||
|
||||
from funasr.models.frontend.abs_frontend import AbsFrontend
|
||||
from funasr.models.frontend.default import DefaultFrontend
|
||||
from funasr.models.frontend.windowing import SlidingWindow
|
||||
from funasr.models.specaug.abs_specaug import AbsSpecAug
|
||||
from funasr.models.specaug.specaug import SpecAug
|
||||
from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
|
||||
from funasr.models.decoder.transformer_decoder import (
|
||||
DynamicConvolution2DTransformerDecoder,
|
||||
DynamicConvolutionTransformerDecoder,
|
||||
LightweightConvolution2DTransformerDecoder,
|
||||
LightweightConvolutionTransformerDecoder,
|
||||
TransformerDecoder,
|
||||
)
|
||||
from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
|
||||
from funasr.models_transducer.decoder.rnn_decoder import RNNDecoder
|
||||
from funasr.models_transducer.decoder.stateless_decoder import StatelessDecoder
|
||||
from funasr.models_transducer.encoder.encoder import Encoder
|
||||
from funasr.models_transducer.encoder.sanm_encoder import SANMEncoderChunkOpt
|
||||
from funasr.models_transducer.espnet_transducer_model import ESPnetASRTransducerModel
|
||||
from funasr.models_transducer.espnet_transducer_model_unified import ESPnetASRUnifiedTransducerModel
|
||||
from funasr.models_transducer.espnet_transducer_model_uni_asr import UniASRTransducerModel
|
||||
from funasr.models_transducer.joint_network import JointNetwork
|
||||
from funasr.layers.abs_normalize import AbsNormalize
|
||||
from funasr.layers.global_mvn import GlobalMVN
|
||||
from funasr.layers.utterance_mvn import UtteranceMVN
|
||||
from funasr.tasks.abs_task import AbsTask
|
||||
from funasr.text.phoneme_tokenizer import g2p_choices
|
||||
from funasr.train.class_choices import ClassChoices
|
||||
from funasr.datasets.collate_fn import CommonCollateFn
|
||||
from funasr.datasets.preprocessor import CommonPreprocessor
|
||||
from funasr.train.trainer import Trainer
|
||||
from funasr.utils.get_default_kwargs import get_default_kwargs
|
||||
from funasr.utils.nested_dict_action import NestedDictAction
|
||||
from funasr.utils.types import float_or_none, int_or_none, str2bool, str_or_none
|
||||
|
||||
# Selectable component implementations; each ClassChoices instance later
# contributes a --<name> / --<name>_conf argument pair via add_arguments().
frontend_choices = ClassChoices(
    name="frontend",
    classes=dict(default=DefaultFrontend, sliding_window=SlidingWindow),
    type_check=AbsFrontend,
    default="default",
)
specaug_choices = ClassChoices(
    name="specaug",
    classes=dict(specaug=SpecAug),
    type_check=AbsSpecAug,
    default=None,
    optional=True,
)
normalize_choices = ClassChoices(
    name="normalize",
    classes=dict(global_mvn=GlobalMVN, utterance_mvn=UtteranceMVN),
    type_check=AbsNormalize,
    default="utterance_mvn",
    optional=True,
)
encoder_choices = ClassChoices(
    name="encoder",
    classes=dict(encoder=Encoder, sanm_chunk_opt=SANMEncoderChunkOpt),
    default="encoder",
)

decoder_choices = ClassChoices(
    name="decoder",
    classes=dict(rnn=RNNDecoder, stateless=StatelessDecoder),
    type_check=AbsDecoder,
    default="rnn",
)

# Optional attention decoder used for joint/auxiliary attention training.
att_decoder_choices = ClassChoices(
    name="att_decoder",
    classes=dict(
        transformer=TransformerDecoder,
        lightweight_conv=LightweightConvolutionTransformerDecoder,
        lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
        dynamic_conv=DynamicConvolutionTransformerDecoder,
        dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
    ),
    type_check=AbsAttDecoder,
    default=None,
    optional=True,
)
|
||||
class ASRTransducerTask(AbsTask):
    """ASR Transducer Task definition.

    Assembles the full transducer model from CLI arguments:
    frontend -> specaug -> normalize -> encoder -> (decoder, joint network),
    with an optional attention decoder.
    """

    # A single optimizer drives the whole model.
    num_optimizers: int = 1

    # Each entry contributes a --<name> and --<name>_conf CLI pair.
    class_choices_list = [
        frontend_choices,
        specaug_choices,
        normalize_choices,
        encoder_choices,
        decoder_choices,
        att_decoder_choices,
    ]

    trainer = Trainer

    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser):
        """Add Transducer task arguments.

        Args:
            cls: ASRTransducerTask object.
            parser: Transducer arguments parser.

        """
        group = parser.add_argument_group(description="Task related.")

        group.add_argument(
            "--token_list",
            type=str_or_none,
            default=None,
            help="Integer-string mapper for tokens.",
        )
        group.add_argument(
            "--input_size",
            type=int_or_none,
            default=None,
            help="The number of dimensions for input features.",
        )
        group.add_argument(
            "--init",
            type=str_or_none,
            default=None,
            help="Type of model initialization to use.",
        )
        group.add_argument(
            "--model_conf",
            action=NestedDictAction,
            default=get_default_kwargs(ESPnetASRTransducerModel),
            help="The keyword arguments for the model class.",
        )
        group.add_argument(
            "--joint_network_conf",
            action=NestedDictAction,
            default={},
            help="The keyword arguments for the joint network class.",
        )

        group = parser.add_argument_group(description="Preprocess related.")
        group.add_argument(
            "--use_preprocessor",
            type=str2bool,
            default=True,
            help="Whether to apply preprocessing to input data.",
        )
        group.add_argument(
            "--token_type",
            type=str,
            default="bpe",
            choices=["bpe", "char", "word", "phn"],
            help="The type of tokens to use during tokenization.",
        )
        group.add_argument(
            "--bpemodel",
            type=str_or_none,
            default=None,
            help="The path of the sentencepiece model.",
        )
        # NOTE: the options below were previously registered on the bare
        # parser; they are attached to the "Preprocess related" group so that
        # --help output is grouped consistently (parsing is unchanged).
        group.add_argument(
            "--non_linguistic_symbols",
            type=str_or_none,
            help="The 'non_linguistic_symbols' file path.",
        )
        group.add_argument(
            "--cleaner",
            type=str_or_none,
            choices=[None, "tacotron", "jaconv", "vietnamese"],
            default=None,
            help="Text cleaner to use.",
        )
        group.add_argument(
            "--g2p",
            type=str_or_none,
            choices=g2p_choices,
            default=None,
            help="g2p method to use if --token_type=phn.",
        )
        group.add_argument(
            "--speech_volume_normalize",
            type=float_or_none,
            default=None,
            help="Normalization value for maximum amplitude scaling.",
        )
        group.add_argument(
            "--rir_scp",
            type=str_or_none,
            default=None,
            help="The RIR SCP file path.",
        )
        group.add_argument(
            "--rir_apply_prob",
            type=float,
            default=1.0,
            help="The probability of the applied RIR convolution.",
        )
        group.add_argument(
            "--noise_scp",
            type=str_or_none,
            default=None,
            help="The path of noise SCP file.",
        )
        group.add_argument(
            "--noise_apply_prob",
            type=float,
            default=1.0,
            help="The probability of the applied noise addition.",
        )
        group.add_argument(
            "--noise_db_range",
            type=str,
            default="13_15",
            help="The range of the noise decibel level.",
        )

        for class_choices in cls.class_choices_list:
            # Append --<name> and --<name>_conf.
            # e.g. --decoder and --decoder_conf
            class_choices.add_arguments(group)

    @classmethod
    def build_collate_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Callable[
        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
        Tuple[List[str], Dict[str, torch.Tensor]],
    ]:
        """Build collate function.

        Args:
            cls: ASRTransducerTask object.
            args: Task arguments.
            train: Training mode.

        Return:
            : Callable collate function.

        """
        assert check_argument_types()

        # int_pad_value=-1 marks label padding (filtered out downstream).
        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)

    @classmethod
    def build_preprocess_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
        """Build pre-processing function.

        Args:
            cls: ASRTransducerTask object.
            args: Task arguments.
            train: Training mode.

        Return:
            : Callable pre-processing function, or None when preprocessing
              is disabled.

        """
        assert check_argument_types()

        if args.use_preprocessor:
            # getattr with a default replaces the hasattr ternaries of the
            # previous version; notably, speech_volume_normalize used to be
            # gated on hasattr(args, "rir_scp") — a bug that could raise
            # AttributeError or drop the setting.
            retval = CommonPreprocessor(
                train=train,
                token_type=args.token_type,
                token_list=args.token_list,
                bpemodel=args.bpemodel,
                non_linguistic_symbols=args.non_linguistic_symbols,
                text_cleaner=args.cleaner,
                g2p_type=args.g2p,
                rir_scp=getattr(args, "rir_scp", None),
                rir_apply_prob=getattr(args, "rir_apply_prob", 1.0),
                noise_scp=getattr(args, "noise_scp", None),
                noise_apply_prob=getattr(args, "noise_apply_prob", 1.0),
                noise_db_range=getattr(args, "noise_db_range", "13_15"),
                speech_volume_normalize=getattr(
                    args, "speech_volume_normalize", None
                ),
            )
        else:
            retval = None

        assert check_return_type(retval)
        return retval

    @classmethod
    def required_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Required data depending on task mode.

        Args:
            cls: ASRTransducerTask object.
            train: Training mode.
            inference: Inference mode.

        Return:
            retval: Required task data.

        """
        # Training needs transcripts; inference only needs audio.
        if not inference:
            retval = ("speech", "text")
        else:
            retval = ("speech",)

        return retval

    @classmethod
    def optional_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Optional data depending on task mode.

        Args:
            cls: ASRTransducerTask object.
            train: Training mode.
            inference: Inference mode.

        Return:
            retval: Optional task data (none for this task).

        """
        retval = ()
        assert check_return_type(retval)

        return retval

    @classmethod
    def build_model(cls, args: argparse.Namespace) -> ESPnetASRTransducerModel:
        """Build the ASR Transducer model from task arguments.

        Args:
            cls: ASRTransducerTask object.
            args: Task arguments.

        Return:
            model: ASR Transducer model.

        Raises:
            RuntimeError: If token_list is neither a path nor a list.
            NotImplementedError: If args.init is set (initialization is
                currently unsupported).

        """
        assert check_argument_types()

        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]

            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Encoder
        if getattr(args, "encoder", None) is not None:
            encoder_class = encoder_choices.get_class(args.encoder)
            encoder = encoder_class(input_size, **args.encoder_conf)
        else:
            encoder = Encoder(input_size, **args.encoder_conf)
        encoder_output_size = encoder.output_size

        # 5. Decoder
        decoder_class = decoder_choices.get_class(args.decoder)
        decoder = decoder_class(
            vocab_size,
            **args.decoder_conf,
        )
        decoder_output_size = decoder.output_size

        if getattr(args, "att_decoder", None) is not None:
            att_decoder_class = att_decoder_choices.get_class(args.att_decoder)

            att_decoder = att_decoder_class(
                vocab_size=vocab_size,
                encoder_output_size=encoder_output_size,
                **args.att_decoder_conf,
            )
        else:
            att_decoder = None

        # 6. Joint Network
        joint_network = JointNetwork(
            vocab_size,
            encoder_output_size,
            decoder_output_size,
            **args.joint_network_conf,
        )

        # 7. Build model. The three variants share an identical constructor
        # signature, so only the class selection differs per configuration.
        if getattr(args, "encoder", None) is not None and args.encoder == "sanm_chunk_opt":
            model_class = UniASRTransducerModel
        elif encoder.unified_model_training:
            model_class = ESPnetASRUnifiedTransducerModel
        else:
            model_class = ESPnetASRTransducerModel

        model = model_class(
            vocab_size=vocab_size,
            token_list=token_list,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            encoder=encoder,
            decoder=decoder,
            att_decoder=att_decoder,
            joint_network=joint_network,
            **args.model_conf,
        )

        # 8. Initialize model
        if args.init is not None:
            raise NotImplementedError(
                "Currently not supported.",
                "Initialization part will be reworked in a short future.",
            )

        return model
|
||||
Loading…
Reference in New Issue
Block a user