diff --git a/funasr/bin/punc_train.py b/funasr/bin/punc_train.py new file mode 100644 index 000000000..61b63ec93 --- /dev/null +++ b/funasr/bin/punc_train.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +import os +from funasr.tasks.punctuation import PunctuationTask + + +def parse_args(): + parser = PunctuationTask.get_parser() + parser.add_argument( + "--gpu_id", + type=int, + default=0, + help="local gpu id.", + ) + parser.add_argument( + "--punc_list", + type=str, + default=None, + help="Punctuation list", + ) + args = parser.parse_args() + return args + + +def main(args=None, cmd=None): + """ + punc training. + """ + PunctuationTask.main(args=args, cmd=cmd) + + +if __name__ == "__main__": + args = parse_args() + + # setup local gpu_id + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) + + # DDP settings + if args.ngpu > 1: + args.distributed = True + else: + args.distributed = False + + main(args=args) diff --git a/funasr/bin/punc_train_vadrealtime.py b/funasr/bin/punc_train_vadrealtime.py new file mode 100644 index 000000000..c5afaad80 --- /dev/null +++ b/funasr/bin/punc_train_vadrealtime.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +import os +from funasr.tasks.punctuation import PunctuationTask + + +def parse_args(): + parser = PunctuationTask.get_parser() + parser.add_argument( + "--gpu_id", + type=int, + default=0, + help="local gpu id.", + ) + parser.add_argument( + "--punc_list", + type=str, + default=None, + help="Punctuation list", + ) + args = parser.parse_args() + return args + + +def main(args=None, cmd=None): + """ + punc training. + """ + PunctuationTask.main(args=args, cmd=cmd) + + +if __name__ == "__main__": + args = parse_args() + + # setup local gpu_id + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) + + # DDP settings + if args.ngpu > 1: + args.distributed = True + else: + args.distributed = False + assert args.num_worker_count == 1 + + main(args=args) diff --git a/funasr/datasets/large_datasets/build_dataloader.py b/funasr/datasets/large_datasets/build_dataloader.py index 8f7fd0b35..093ad608e 100644 --- a/funasr/datasets/large_datasets/build_dataloader.py +++ b/funasr/datasets/large_datasets/build_dataloader.py @@ -34,16 +34,20 @@ def load_seg_dict(seg_dict_file): return seg_dict class ArkDataLoader(AbsIterFactory): - def __init__(self, data_list, dict_file, dataset_conf, seg_dict_file=None, mode="train"): + def __init__(self, data_list, dict_file, dataset_conf, seg_dict_file=None, punc_dict_file=None, mode="train"): symbol_table = read_symbol_table(dict_file) if dict_file is not None else None if seg_dict_file is not None: seg_dict = load_seg_dict(seg_dict_file) else: seg_dict = None + if punc_dict_file is not None: + punc_dict = read_symbol_table(punc_dict_file) + else: + punc_dict = None self.dataset_conf = dataset_conf logging.info("dataloader config: {}".format(self.dataset_conf)) batch_mode = self.dataset_conf.get("batch_mode", "padding") - self.dataset = Dataset(data_list, symbol_table, seg_dict, + self.dataset = Dataset(data_list, symbol_table, seg_dict, punc_dict, self.dataset_conf, mode=mode, batch_mode=batch_mode) def build_iter(self, epoch, shuffle=True): diff --git a/funasr/datasets/large_datasets/dataset.py b/funasr/datasets/large_datasets/dataset.py index 212373720..61231d24c 100644 --- a/funasr/datasets/large_datasets/dataset.py +++ b/funasr/datasets/large_datasets/dataset.py @@ -127,14 +127,17 @@ class AudioDataset(IterableDataset): sample_dict["key"] = key else: text = item - sample_dict[data_name] = text.strip().split()[1:] + segs = text.strip().split() + sample_dict[data_name] = segs[1:] + if "key" not in sample_dict: + sample_dict["key"] = segs[0] yield sample_dict self.close_reader(reader_list) def len_fn_example(data): - return len(data) + return 1 def len_fn_token(data): @@ -148,6 +151,7 @@ def len_fn_token(data): def Dataset(data_list_file, dict, seg_dict, + punc_dict, conf, mode="train", batch_mode="padding"): @@ -162,7 +166,7 @@ def Dataset(data_list_file, dataset = FilterIterDataPipe(dataset, fn=filter_fn) if "text" in data_names: - vocab = {'vocab': dict, 'seg_dict': seg_dict} + vocab = {'vocab': dict, 'seg_dict': seg_dict, 'punc_dict': punc_dict} tokenize_fn = partial(tokenize, **vocab) dataset = MapperIterDataPipe(dataset, fn=tokenize_fn) @@ -191,6 +195,10 @@ def Dataset(data_list_file, sort_size=sort_size, batch_mode=batch_mode) - dataset = MapperIterDataPipe(dataset, fn=padding if batch_mode == "padding" else clipping) + int_pad_value = conf.get("int_pad_value", -1) + float_pad_value = conf.get("float_pad_value", 0.0) + padding_conf = {"int_pad_value": int_pad_value, "float_pad_value": float_pad_value} + padding_fn = partial(padding, **padding_conf) + dataset = MapperIterDataPipe(dataset, fn=padding_fn if batch_mode == "padding" else clipping) return dataset diff --git a/funasr/datasets/large_datasets/utils/padding.py b/funasr/datasets/large_datasets/utils/padding.py index e814b1cd3..e0feac62e 100644 --- a/funasr/datasets/large_datasets/utils/padding.py +++ b/funasr/datasets/large_datasets/utils/padding.py @@ -6,9 +6,8 @@ from torch.nn.utils.rnn import pad_sequence def padding(data, float_pad_value=0.0, int_pad_value=-1): assert isinstance(data, list) assert "key" in data[0] - assert "speech" in data[0] - assert "text" in data[0] - + assert "speech" in data[0] or "text" in data[0] + keys = [x["key"] for x in data] batch = {} diff --git a/funasr/datasets/large_datasets/utils/tokenize.py b/funasr/datasets/large_datasets/utils/tokenize.py index 0c0188564..caeb42626 100644 --- a/funasr/datasets/large_datasets/utils/tokenize.py +++ b/funasr/datasets/large_datasets/utils/tokenize.py @@ -31,22 +31,43 @@ def seg_tokenize(txt, seg_dict): def tokenize(data, vocab=None, - seg_dict=None): + seg_dict=None, + punc_dict=None): assert "text" in data assert isinstance(vocab, dict) text = data["text"] token = [] + vad = -2 if seg_dict is not None: assert isinstance(seg_dict, dict) txt = forward_segment("".join(text).lower(), seg_dict) text = seg_tokenize(txt, seg_dict) - - for x in text: - if x in vocab: + + length = len(text) + for i in range(length): + x = text[i] + if i == length-1 and "punc" in data and text[i].startswith("vad:"): + vad = x[-1][4:] + if len(vad) == 0: + vad = -1 + else: + vad = int(vad) + elif x in vocab: token.append(vocab[x]) else: token.append(vocab['']) + if "punc" in data and punc_dict is not None: + punc_token = [] + for punc in data["punc"]: + if punc in punc_dict: + punc_token.append(punc_dict[punc]) + else: + punc_token.append(punc_dict["_"]) + data["punc"] = np.array(punc_token) + data["text"] = np.array(token) + if vad is not -2: + data["vad_indexes"]=np.array([vad], dtype=np.int64) return data diff --git a/funasr/datasets/preprocessor.py b/funasr/datasets/preprocessor.py index 8e8679475..20a3791de 100644 --- a/funasr/datasets/preprocessor.py +++ b/funasr/datasets/preprocessor.py @@ -704,3 +704,103 @@ class CodeMixTokenizerCommonPreprocessor(CommonPreprocessor): del data[self.split_text_name] return result +class PuncTrainTokenizerCommonPreprocessor(CommonPreprocessor): + def __init__( + self, + train: bool, + token_type: List[str] = [None], + token_list: List[Union[Path, str, Iterable[str]]] = [None], + bpemodel: List[Union[Path, str, Iterable[str]]] = [None], + text_cleaner: Collection[str] = None, + g2p_type: str = None, + unk_symbol: str = "", + space_symbol: str = "", + non_linguistic_symbols: Union[Path, str, Iterable[str]] = None, + delimiter: str = None, + rir_scp: str = None, + rir_apply_prob: float = 1.0, + noise_scp: str = None, + noise_apply_prob: float = 1.0, + noise_db_range: str = "3_10", + speech_volume_normalize: float = None, + speech_name: str = "speech", + text_name: List[str] = ["text"], + vad_name: str = "vad_indexes", + ): + # TODO(jiatong): sync with Kamo and Jing on interface for preprocessor + super().__init__( + train=train, + token_type=token_type[0], + token_list=token_list[0], + bpemodel=bpemodel[0], + text_cleaner=text_cleaner, + g2p_type=g2p_type, + unk_symbol=unk_symbol, + space_symbol=space_symbol, + non_linguistic_symbols=non_linguistic_symbols, + delimiter=delimiter, + speech_name=speech_name, + text_name=text_name[0], + rir_scp=rir_scp, + rir_apply_prob=rir_apply_prob, + noise_scp=noise_scp, + noise_apply_prob=noise_apply_prob, + noise_db_range=noise_db_range, + speech_volume_normalize=speech_volume_normalize, + ) + + assert ( + len(token_type) == len(token_list) == len(bpemodel) == len(text_name) + ), "token_type, token_list, bpemodel, or processing text_name mismatched" + self.num_tokenizer = len(token_type) + self.tokenizer = [] + self.token_id_converter = [] + + for i in range(self.num_tokenizer): + if token_type[i] is not None: + if token_list[i] is None: + raise ValueError("token_list is required if token_type is not None") + + self.tokenizer.append( + build_tokenizer( + token_type=token_type[i], + bpemodel=bpemodel[i], + delimiter=delimiter, + space_symbol=space_symbol, + non_linguistic_symbols=non_linguistic_symbols, + g2p_type=g2p_type, + ) + ) + self.token_id_converter.append( + TokenIDConverter( + token_list=token_list[i], + unk_symbol=unk_symbol, + ) + ) + else: + self.tokenizer.append(None) + self.token_id_converter.append(None) + + self.text_cleaner = TextCleaner(text_cleaner) + self.text_name = text_name # override the text_name from CommonPreprocessor + self.vad_name = vad_name + + def _text_process( + self, data: Dict[str, Union[str, np.ndarray]] + ) -> Dict[str, np.ndarray]: + for i in range(self.num_tokenizer): + text_name = self.text_name[i] + if text_name in data and self.tokenizer[i] is not None: + text = data[text_name] + text = self.text_cleaner(text) + tokens = self.tokenizer[i].text2tokens(text) + if "vad:" in tokens[-1]: + vad = tokens[-1][4:] + tokens = tokens[:-1] + if len(vad) == 0: + vad = -1 + else: + vad = int(vad) + data[self.vad_name] = np.array([vad], dtype=np.int64) + text_ints = self.token_id_converter[i].tokens2ids(tokens) + data[text_name] = np.array(text_ints, dtype=np.int64) diff --git a/funasr/modules/attention.py b/funasr/modules/attention.py index c47d96d06..627700524 100644 --- a/funasr/modules/attention.py +++ b/funasr/modules/attention.py @@ -439,6 +439,18 @@ class MultiHeadedAttentionSANM(nn.Module): att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder) return att_outs + fsmn_memory +class MultiHeadedAttentionSANMwithMask(MultiHeadedAttentionSANM): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None): + q_h, k_h, v_h, v = self.forward_qkv(x) + fsmn_memory = self.forward_fsmn(v, mask[0], mask_shfit_chunk) + q_h = q_h * self.d_k ** (-0.5) + scores = torch.matmul(q_h, k_h.transpose(-2, -1)) + att_outs = self.forward_attention(v_h, scores, mask[1], mask_att_chunk_encoder) + return att_outs + fsmn_memory + class MultiHeadedAttentionSANMDecoder(nn.Module): """Multi-Head Attention layer. diff --git a/funasr/modules/mask.py b/funasr/modules/mask.py index 8f068e11c..a8c168bc1 100644 --- a/funasr/modules/mask.py +++ b/funasr/modules/mask.py @@ -33,3 +33,20 @@ def target_mask(ys_in_pad, ignore_id): ys_mask = ys_in_pad != ignore_id m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0) return ys_mask.unsqueeze(-2) & m + +def vad_mask(size, vad_pos, device="cpu", dtype=torch.bool): + """Create mask for decoder self-attention. + + :param int size: size of mask + :param int vad_pos: index of vad index + :param str device: "cpu" or "cuda" or torch.Tensor.device + :param torch.dtype dtype: result dtype + :rtype: torch.Tensor (B, Lmax, Lmax) + """ + ret = torch.ones(size, size, device=device, dtype=dtype) + if vad_pos <= 0 or vad_pos >= size: + return ret + sub_corner = torch.zeros( + vad_pos - 1, size - vad_pos, device=device, dtype=dtype) + ret[0:vad_pos - 1, vad_pos:] = sub_corner + return ret diff --git a/funasr/punctuation/abs_model.py b/funasr/punctuation/abs_model.py index 5f6afb749..404d5e893 100644 --- a/funasr/punctuation/abs_model.py +++ b/funasr/punctuation/abs_model.py @@ -25,3 +25,7 @@ class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC): @abstractmethod def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError + + @abstractmethod + def with_vad(self) -> bool: + raise NotImplementedError diff --git a/funasr/punctuation/espnet_model.py b/funasr/punctuation/espnet_model.py index 65edaad08..c513779d8 100644 --- a/funasr/punctuation/espnet_model.py +++ b/funasr/punctuation/espnet_model.py @@ -14,15 +14,18 @@ from funasr.train.abs_espnet_model import AbsESPnetModel class ESPnetPunctuationModel(AbsESPnetModel): - def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0): + def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0, punc_weight: list = None): assert check_argument_types() super().__init__() self.punc_model = punc_model + self.punc_weight = torch.Tensor(punc_weight) self.sos = 1 self.eos = 2 # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR. self.ignore_id = ignore_id + if self.punc_model.with_vad(): + print("This is a vad puncuation model.") def nll( self, @@ -31,6 +34,8 @@ class ESPnetPunctuationModel(AbsESPnetModel): text_lengths: torch.Tensor, punc_lengths: torch.Tensor, max_length: Optional[int] = None, + vad_indexes: Optional[torch.Tensor] = None, + vad_indexes_lengths: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Compute negative log likelihood(nll) @@ -49,19 +54,16 @@ class ESPnetPunctuationModel(AbsESPnetModel): else: text = text[:, :max_length] punc = punc[:, :max_length] - # 1. Create a sentence pair like ' w1 w2 w3' and 'w1 w2 w3 ' - # text: (Batch, Length) -> x, y: (Batch, Length + 1) - #x = F.pad(text, [1, 0], "constant", self.eos) - #t = F.pad(text, [0, 1], "constant", self.ignore_id) - #for i, l in enumerate(text_lengths): - # t[i, l] = self.sos - #x_lengths = text_lengths + 1 + + if self.punc_model.with_vad(): + # Should be VadRealtimeTransformer + assert vad_indexes is not None + y, _ = self.punc_model(text, text_lengths, vad_indexes) + else: + # Should be TargetDelayTransformer, + y, _ = self.punc_model(text, text_lengths) - # 2. Forward Language model - # x: (Batch, Length) -> y: (Batch, Length, NVocab) - y, _ = self.punc_model(text, text_lengths) - - # 3. Calc negative log likelihood + # Calc negative log likelihood # nll: (BxL,) if self.training == False: _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1) @@ -72,7 +74,8 @@ class ESPnetPunctuationModel(AbsESPnetModel): nll = torch.Tensor([f1_score]).repeat(text_lengths.sum()) return nll, text_lengths else: - nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), reduction="none", ignore_index=self.ignore_id) + self.punc_weight = self.punc_weight.to(punc.device) + nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none", ignore_index=self.ignore_id) # nll: (BxL,) -> (BxL,) if max_length is None: nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0) @@ -130,9 +133,16 @@ class ESPnetPunctuationModel(AbsESPnetModel): assert x_lengths.size(0) == total_num return nll, x_lengths - def forward(self, text: torch.Tensor, punc: torch.Tensor, text_lengths: torch.Tensor, - punc_lengths: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: - nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths) + def forward( + self, + text: torch.Tensor, + punc: torch.Tensor, + text_lengths: torch.Tensor, + punc_lengths: torch.Tensor, + vad_indexes: Optional[torch.Tensor] = None, + vad_indexes_lengths: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes) ntokens = y_lengths.sum() loss = nll.sum() / ntokens stats = dict(loss=loss.detach()) @@ -145,5 +155,12 @@ class ESPnetPunctuationModel(AbsESPnetModel): text_lengths: torch.Tensor) -> Dict[str, torch.Tensor]: return {} - def inference(self, text: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]: - return self.punc_model(text, text_lengths) + def inference(self, + text: torch.Tensor, + text_lengths: torch.Tensor, + vad_indexes: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, None]: + if self.punc_model.with_vad(): + assert vad_indexes is not None + return self.punc_model(text, text_lengths, vad_indexes) + else: + return self.punc_model(text, text_lengths) diff --git a/funasr/punctuation/sanm_encoder.py b/funasr/punctuation/sanm_encoder.py new file mode 100644 index 000000000..896209323 --- /dev/null +++ b/funasr/punctuation/sanm_encoder.py @@ -0,0 +1,590 @@ +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union +import logging +import torch +import torch.nn as nn +from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk +from typeguard import check_argument_types +import numpy as np +from funasr.modules.nets_utils import make_pad_mask +from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask +from funasr.modules.embedding import SinusoidalPositionEncoder +from funasr.modules.layer_norm import LayerNorm +from funasr.modules.multi_layer_conv import Conv1dLinear +from funasr.modules.multi_layer_conv import MultiLayeredConv1d +from funasr.modules.positionwise_feed_forward import ( + PositionwiseFeedForward, # noqa: H301 +) +from funasr.modules.repeat import repeat +from funasr.modules.subsampling import Conv2dSubsampling +from funasr.modules.subsampling import Conv2dSubsampling2 +from funasr.modules.subsampling import Conv2dSubsampling6 +from funasr.modules.subsampling import Conv2dSubsampling8 +from funasr.modules.subsampling import TooShortUttError +from funasr.modules.subsampling import check_short_utt +from funasr.models.ctc import CTC +from funasr.models.encoder.abs_encoder import AbsEncoder + +from funasr.modules.nets_utils import make_pad_mask +from funasr.modules.mask import subsequent_mask, vad_mask + +class EncoderLayerSANM(nn.Module): + def __init__( + self, + in_size, + size, + self_attn, + feed_forward, + dropout_rate, + normalize_before=True, + concat_after=False, + stochastic_depth_rate=0.0, + ): + """Construct an EncoderLayer object.""" + super(EncoderLayerSANM, self).__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.norm1 = LayerNorm(in_size) + self.norm2 = LayerNorm(size) + self.dropout = nn.Dropout(dropout_rate) + self.in_size = in_size + self.size = size + self.normalize_before = normalize_before + self.concat_after = concat_after + if self.concat_after: + self.concat_linear = nn.Linear(size + size, size) + self.stochastic_depth_rate = stochastic_depth_rate + self.dropout_rate = dropout_rate + + def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None): + """Compute encoded features. + + Args: + x_input (torch.Tensor): Input tensor (#batch, time, size). + mask (torch.Tensor): Mask tensor for the input (#batch, time). + cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). + + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time). + + """ + skip_layer = False + # with stochastic depth, residual connection `x + f(x)` becomes + # `x <- x + 1 / (1 - p) * f(x)` at training time. + stoch_layer_coeff = 1.0 + if self.training and self.stochastic_depth_rate > 0: + skip_layer = torch.rand(1).item() < self.stochastic_depth_rate + stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) + + if skip_layer: + if cache is not None: + x = torch.cat([cache, x], dim=1) + return x, mask + + residual = x + if self.normalize_before: + x = self.norm1(x) + + if self.concat_after: + x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1) + if self.in_size == self.size: + x = residual + stoch_layer_coeff * self.concat_linear(x_concat) + else: + x = stoch_layer_coeff * self.concat_linear(x_concat) + else: + if self.in_size == self.size: + x = residual + stoch_layer_coeff * self.dropout( + self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder) + ) + else: + x = stoch_layer_coeff * self.dropout( + self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder) + ) + if not self.normalize_before: + x = self.norm1(x) + + residual = x + if self.normalize_before: + x = self.norm2(x) + x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm2(x) + + + return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder + +class SANMEncoder(AbsEncoder): + """ + author: Speech Lab, Alibaba Group, China + + """ + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: Optional[str] = "conv2d", + pos_enc_class=SinusoidalPositionEncoder, + normalize_before: bool = True, + concat_after: bool = False, + positionwise_layer_type: str = "linear", + positionwise_conv_kernel_size: int = 1, + padding_idx: int = -1, + interctc_layer_idx: List[int] = [], + interctc_use_conditioning: bool = False, + kernel_size : int = 11, + sanm_shfit : int = 0, + selfattention_layer_type: str = "sanm", + ): + assert check_argument_types() + super().__init__() + self._output_size = output_size + + if input_layer == "linear": + self.embed = torch.nn.Sequential( + torch.nn.Linear(input_size, output_size), + torch.nn.LayerNorm(output_size), + torch.nn.Dropout(dropout_rate), + torch.nn.ReLU(), + pos_enc_class(output_size, positional_dropout_rate), + ) + elif input_layer == "conv2d": + self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate) + elif input_layer == "conv2d2": + self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate) + elif input_layer == "conv2d6": + self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate) + elif input_layer == "conv2d8": + self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate) + elif input_layer == "embed": + self.embed = torch.nn.Sequential( + torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx), + SinusoidalPositionEncoder(), + ) + elif input_layer is None: + if input_size == output_size: + self.embed = None + else: + self.embed = torch.nn.Linear(input_size, output_size) + elif input_layer == "pe": + self.embed = SinusoidalPositionEncoder() + else: + raise ValueError("unknown input_layer: " + input_layer) + self.normalize_before = normalize_before + if positionwise_layer_type == "linear": + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + output_size, + linear_units, + dropout_rate, + ) + elif positionwise_layer_type == "conv1d": + positionwise_layer = MultiLayeredConv1d + positionwise_layer_args = ( + output_size, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + elif positionwise_layer_type == "conv1d-linear": + positionwise_layer = Conv1dLinear + positionwise_layer_args = ( + output_size, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + else: + raise NotImplementedError("Support only linear or conv1d.") + + if selfattention_layer_type == "selfattn": + encoder_selfattn_layer = MultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + ) + + elif selfattention_layer_type == "sanm": + self.encoder_selfattn_layer = MultiHeadedAttentionSANM + encoder_selfattn_layer_args0 = ( + attention_heads, + input_size, + output_size, + attention_dropout_rate, + kernel_size, + sanm_shfit, + ) + + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + output_size, + attention_dropout_rate, + kernel_size, + sanm_shfit, + ) + + self.encoders0 = repeat( + 1, + lambda lnum: EncoderLayerSANM( + input_size, + output_size, + self.encoder_selfattn_layer(*encoder_selfattn_layer_args0), + positionwise_layer(*positionwise_layer_args), + dropout_rate, + normalize_before, + concat_after, + ), + ) + + self.encoders = repeat( + num_blocks-1, + lambda lnum: EncoderLayerSANM( + output_size, + output_size, + self.encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + dropout_rate, + normalize_before, + concat_after, + ), + ) + if self.normalize_before: + self.after_norm = LayerNorm(output_size) + + self.interctc_layer_idx = interctc_layer_idx + if len(interctc_layer_idx) > 0: + assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks + self.interctc_use_conditioning = interctc_use_conditioning + self.conditioning_layer = None + self.dropout = nn.Dropout(dropout_rate) + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs_pad: torch.Tensor, + ilens: torch.Tensor, + prev_states: torch.Tensor = None, + ctc: CTC = None, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """Embed positions in tensor. + + Args: + xs_pad: input tensor (B, L, D) + ilens: input length (B) + prev_states: Not to be used now. + Returns: + position embedded tensor and mask + """ + masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) + xs_pad *= self.output_size()**0.5 + if self.embed is None: + xs_pad = xs_pad + elif ( + isinstance(self.embed, Conv2dSubsampling) + or isinstance(self.embed, Conv2dSubsampling2) + or isinstance(self.embed, Conv2dSubsampling6) + or isinstance(self.embed, Conv2dSubsampling8) + ): + short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1)) + if short_status: + raise TooShortUttError( + f"has {xs_pad.size(1)} frames and is too short for subsampling " + + f"(it needs more than {limit_size} frames), return empty results", + xs_pad.size(1), + limit_size, + ) + xs_pad, masks = self.embed(xs_pad, masks) + else: + xs_pad = self.embed(xs_pad) + + # xs_pad = self.dropout(xs_pad) + encoder_outs = self.encoders0(xs_pad, masks) + xs_pad, masks = encoder_outs[0], encoder_outs[1] + intermediate_outs = [] + if len(self.interctc_layer_idx) == 0: + encoder_outs = self.encoders(xs_pad, masks) + xs_pad, masks = encoder_outs[0], encoder_outs[1] + else: + for layer_idx, encoder_layer in enumerate(self.encoders): + encoder_outs = encoder_layer(xs_pad, masks) + xs_pad, masks = encoder_outs[0], encoder_outs[1] + + if layer_idx + 1 in self.interctc_layer_idx: + encoder_out = xs_pad + + # intermediate outputs are also normalized + if self.normalize_before: + encoder_out = self.after_norm(encoder_out) + + intermediate_outs.append((layer_idx + 1, encoder_out)) + + if self.interctc_use_conditioning: + ctc_out = ctc.softmax(encoder_out) + xs_pad = xs_pad + self.conditioning_layer(ctc_out) + + if self.normalize_before: + xs_pad = self.after_norm(xs_pad) + + olens = masks.squeeze(1).sum(1) + if len(intermediate_outs) > 0: + return (xs_pad, intermediate_outs), olens, None + return xs_pad, olens, None + +class SANMVadEncoder(AbsEncoder): + """ + author: Speech Lab, Alibaba Group, China + + """ + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: Optional[str] = "conv2d", + pos_enc_class=SinusoidalPositionEncoder, + normalize_before: bool = True, + concat_after: bool = False, + positionwise_layer_type: str = "linear", + positionwise_conv_kernel_size: int = 1, + padding_idx: int = -1, + interctc_layer_idx: List[int] = [], + interctc_use_conditioning: bool = False, + kernel_size : int = 11, + sanm_shfit : int = 0, + selfattention_layer_type: str = "sanm", + ): + assert check_argument_types() + super().__init__() + self._output_size = output_size + + if input_layer == "linear": + self.embed = torch.nn.Sequential( + torch.nn.Linear(input_size, output_size), + torch.nn.LayerNorm(output_size), + torch.nn.Dropout(dropout_rate), + torch.nn.ReLU(), + pos_enc_class(output_size, positional_dropout_rate), + ) + elif input_layer == "conv2d": + self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate) + elif input_layer == "conv2d2": + self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate) + elif input_layer == "conv2d6": + self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate) + elif input_layer == "conv2d8": + self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate) + elif input_layer == "embed": + self.embed = torch.nn.Sequential( + torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx), + SinusoidalPositionEncoder(), + ) + elif input_layer is None: + if input_size == output_size: + self.embed = None + else: + self.embed = torch.nn.Linear(input_size, output_size) + elif input_layer == "pe": + self.embed = SinusoidalPositionEncoder() + else: + raise ValueError("unknown input_layer: " + input_layer) + self.normalize_before = normalize_before + if positionwise_layer_type == "linear": + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + output_size, + linear_units, + dropout_rate, + ) + elif positionwise_layer_type == "conv1d": + positionwise_layer = MultiLayeredConv1d + positionwise_layer_args = ( + output_size, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + elif positionwise_layer_type == "conv1d-linear": + positionwise_layer = Conv1dLinear + positionwise_layer_args = ( + output_size, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + else: + raise NotImplementedError("Support only linear or conv1d.") + + if selfattention_layer_type == "selfattn": + encoder_selfattn_layer = MultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + ) + + elif selfattention_layer_type == "sanm": + self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask + encoder_selfattn_layer_args0 = ( + attention_heads, + input_size, + output_size, + attention_dropout_rate, + kernel_size, + sanm_shfit, + ) + + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + output_size, + attention_dropout_rate, + kernel_size, + sanm_shfit, + ) + + self.encoders0 = repeat( + 1, + lambda lnum: EncoderLayerSANM( + input_size, + output_size, + self.encoder_selfattn_layer(*encoder_selfattn_layer_args0), + positionwise_layer(*positionwise_layer_args), + dropout_rate, + normalize_before, + concat_after, + ), + ) + + self.encoders = repeat( + num_blocks-1, + lambda lnum: EncoderLayerSANM( + output_size, + output_size, + self.encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + dropout_rate, + normalize_before, + concat_after, + ), + ) + if self.normalize_before: + self.after_norm = LayerNorm(output_size) + + self.interctc_layer_idx = interctc_layer_idx + if len(interctc_layer_idx) > 0: + assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks + self.interctc_use_conditioning = interctc_use_conditioning + self.conditioning_layer = None + self.dropout = nn.Dropout(dropout_rate) + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs_pad: torch.Tensor, + ilens: torch.Tensor, + vad_indexes: torch.Tensor, + prev_states: torch.Tensor = None, + ctc: CTC = None, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """Embed positions in tensor. + + Args: + xs_pad: input tensor (B, L, D) + ilens: input length (B) + prev_states: Not to be used now. + Returns: + position embedded tensor and mask + """ + masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) + sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0) + no_future_masks = masks & sub_masks + xs_pad *= self.output_size()**0.5 + if self.embed is None: + xs_pad = xs_pad + elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2) + or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)): + short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1)) + if short_status: + raise TooShortUttError( + f"has {xs_pad.size(1)} frames and is too short for subsampling " + + f"(it needs more than {limit_size} frames), return empty results", + xs_pad.size(1), + limit_size, + ) + xs_pad, masks = self.embed(xs_pad, masks) + else: + xs_pad = self.embed(xs_pad) + + # xs_pad = self.dropout(xs_pad) + mask_tup0 = [masks, no_future_masks] + encoder_outs = self.encoders0(xs_pad, mask_tup0) + xs_pad, _ = encoder_outs[0], encoder_outs[1] + intermediate_outs = [] + #if len(self.interctc_layer_idx) == 0: + if False: + # Here, we should not use the repeat operation to do it for all layers. + encoder_outs = self.encoders(xs_pad, masks) + xs_pad, masks = encoder_outs[0], encoder_outs[1] + else: + for layer_idx, encoder_layer in enumerate(self.encoders): + if layer_idx + 1 == len(self.encoders): + # This is last layer. + coner_mask = torch.ones(masks.size(0), + masks.size(-1), + masks.size(-1), + device=xs_pad.device, + dtype=torch.bool) + for word_index, length in enumerate(ilens): + coner_mask[word_index, :, :] = vad_mask(masks.size(-1), + vad_indexes[word_index], + device=xs_pad.device) + layer_mask = masks & coner_mask + else: + layer_mask = no_future_masks + mask_tup1 = [masks, layer_mask] + encoder_outs = encoder_layer(xs_pad, mask_tup1) + xs_pad, layer_mask = encoder_outs[0], encoder_outs[1] + + if layer_idx + 1 in self.interctc_layer_idx: + encoder_out = xs_pad + + # intermediate outputs are also normalized + if self.normalize_before: + encoder_out = self.after_norm(encoder_out) + + intermediate_outs.append((layer_idx + 1, encoder_out)) + + if self.interctc_use_conditioning: + ctc_out = ctc.softmax(encoder_out) + xs_pad = xs_pad + self.conditioning_layer(ctc_out) + + if self.normalize_before: + xs_pad = self.after_norm(xs_pad) + + olens = masks.squeeze(1).sum(1) + if len(intermediate_outs) > 0: + return (xs_pad, intermediate_outs), olens, None + return xs_pad, olens, None + diff --git a/funasr/punctuation/target_delay_transformer.py b/funasr/punctuation/target_delay_transformer.py index 10cc5a845..219af263f 100644 --- a/funasr/punctuation/target_delay_transformer.py +++ b/funasr/punctuation/target_delay_transformer.py @@ -8,7 +8,7 @@ import torch.nn as nn from funasr.modules.embedding import PositionalEncoding from funasr.modules.embedding import SinusoidalPositionEncoder #from funasr.models.encoder.transformer_encoder import TransformerEncoder as Encoder -from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder +from funasr.punctuation.sanm_encoder import SANMEncoder as Encoder #from funasr.modules.mask import subsequent_n_mask from funasr.punctuation.abs_model import AbsPunctuation @@ -73,6 +73,9 @@ class TargetDelayTransformer(AbsPunctuation): y = self.decoder(h) return y, None + def with_vad(self): + return False + def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]: """Score new token. diff --git a/funasr/punctuation/vad_realtime_transformer.py b/funasr/punctuation/vad_realtime_transformer.py new file mode 100644 index 000000000..35224f9bd --- /dev/null +++ b/funasr/punctuation/vad_realtime_transformer.py @@ -0,0 +1,132 @@ +from typing import Any +from typing import List +from typing import Tuple + +import torch +import torch.nn as nn + +from funasr.modules.embedding import SinusoidalPositionEncoder +from funasr.punctuation.sanm_encoder import SANMVadEncoder as Encoder +from funasr.punctuation.abs_model import AbsPunctuation + + +class VadRealtimeTransformer(AbsPunctuation): + + def __init__( + self, + vocab_size: int, + punc_size: int, + pos_enc: str = None, + embed_unit: int = 128, + att_unit: int = 256, + head: int = 2, + unit: int = 1024, + layer: int = 4, + dropout_rate: float = 0.5, + kernel_size: int = 11, + sanm_shfit: int = 0, + ): + super().__init__() + if pos_enc == "sinusoidal": + # pos_enc_class = PositionalEncoding + pos_enc_class = SinusoidalPositionEncoder + elif pos_enc is None: + + def pos_enc_class(*args, **kwargs): + return nn.Sequential() # indentity + + else: + raise ValueError(f"unknown pos-enc option: {pos_enc}") + + self.embed = nn.Embedding(vocab_size, embed_unit) + self.encoder = Encoder( + input_size=embed_unit, + output_size=att_unit, + attention_heads=head, + linear_units=unit, + num_blocks=layer, + dropout_rate=dropout_rate, + input_layer="pe", + # pos_enc_class=pos_enc_class, + padding_idx=0, + kernel_size=kernel_size, + sanm_shfit=sanm_shfit, + ) + self.decoder = nn.Linear(att_unit, punc_size) + + +# def _target_mask(self, ys_in_pad): +# ys_mask = ys_in_pad != 0 +# m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0) +# return ys_mask.unsqueeze(-2) & m + + def forward(self, input: torch.Tensor, text_lengths: torch.Tensor, + vad_indexes: torch.Tensor) -> Tuple[torch.Tensor, None]: + """Compute loss value from buffer sequences. + + Args: + input (torch.Tensor): Input ids. (batch, len) + hidden (torch.Tensor): Target ids. (batch, len) + + """ + x = self.embed(input) + # mask = self._target_mask(input) + h, _, _ = self.encoder(x, text_lengths, vad_indexes) + y = self.decoder(h) + return y, None + + def with_vad(self): + return True + + def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]: + """Score new token. + + Args: + y (torch.Tensor): 1D torch.int64 prefix tokens. + state: Scorer state for prefix tokens + x (torch.Tensor): encoder feature that generates ys. + + Returns: + tuple[torch.Tensor, Any]: Tuple of + torch.float32 scores for next token (vocab_size) + and next state for ys + + """ + y = y.unsqueeze(0) + h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state) + h = self.decoder(h[:, -1]) + logp = h.log_softmax(dim=-1).squeeze(0) + return logp, cache + + def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]: + """Score new token batch. + + Args: + ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). + states (List[Any]): Scorer states for prefix tokens. + xs (torch.Tensor): + The encoder feature that generates ys (n_batch, xlen, n_feat). + + Returns: + tuple[torch.Tensor, List[Any]]: Tuple of + batchfied scores for next token with shape of `(n_batch, vocab_size)` + and next state list for ys. + + """ + # merge states + n_batch = len(ys) + n_layers = len(self.encoder.encoders) + if states[0] is None: + batch_state = None + else: + # transpose state of [batch, layer] into [layer, batch] + batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)] + + # batch decoding + h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state) + h = self.decoder(h[:, -1]) + logp = h.log_softmax(dim=-1) + + # transpose state of [layer, batch] into [batch, layer] + state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)] + return logp, state_list diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py index 5be9089d4..d2a00b2fc 100644 --- a/funasr/tasks/abs_task.py +++ b/funasr/tasks/abs_task.py @@ -1350,10 +1350,12 @@ class AbsTask(ABC): train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list, args.dataset_conf, seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None, + punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None, mode="train") valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf, seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None, + punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None, mode="eval") elif args.dataset_type == "small": train_iter_factory = cls.build_iter_factory( diff --git a/funasr/tasks/punctuation.py b/funasr/tasks/punctuation.py index 1837b2abc..ea1e10284 100644 --- a/funasr/tasks/punctuation.py +++ b/funasr/tasks/punctuation.py @@ -13,10 +13,11 @@ from typeguard import check_argument_types from typeguard import check_return_type from funasr.datasets.collate_fn import CommonCollateFn -from funasr.datasets.preprocessor import MutliTokenizerCommonPreprocessor +from funasr.datasets.preprocessor import PuncTrainTokenizerCommonPreprocessor from funasr.punctuation.abs_model import AbsPunctuation from funasr.punctuation.espnet_model import ESPnetPunctuationModel from funasr.punctuation.target_delay_transformer import TargetDelayTransformer +from funasr.punctuation.vad_realtime_transformer import VadRealtimeTransformer from funasr.tasks.abs_task import AbsTask from funasr.text.phoneme_tokenizer import g2p_choices from funasr.torch_utils.initialize import initialize @@ -29,11 +30,9 @@ from funasr.utils.types import str_or_none punc_choices = ClassChoices( "punctuation", - classes=dict( - target_delay=TargetDelayTransformer, - ), + classes=dict(target_delay=TargetDelayTransformer, vad_realtime=VadRealtimeTransformer), type_check=AbsPunctuation, - default="TargetDelayTransformer", + default="target_delay", ) @@ -56,8 +55,6 @@ class PunctuationTask(AbsTask): # NOTE(kamo): add_arguments(..., required=True) can't be used # to provide --print_config mode. Instead of it, do as required = parser.get_default("required") - #import pdb;pdb.set_trace() - #required += ["token_list"] group.add_argument( "--token_list", @@ -154,7 +151,7 @@ class PunctuationTask(AbsTask): bpemodels = [args.bpemodel, args.bpemodel] text_names = ["text", "punc"] if args.use_preprocessor: - retval = MutliTokenizerCommonPreprocessor( + retval = PuncTrainTokenizerCommonPreprocessor( train=train, token_type=token_types, token_list=token_lists, @@ -182,7 +179,7 @@ class PunctuationTask(AbsTask): def optional_data_names( cls, train: bool = True, inference: bool = False ) -> Tuple[str, ...]: - retval = () + retval = ("vad",) return retval @classmethod @@ -197,11 +194,13 @@ class PunctuationTask(AbsTask): args.token_list = token_list.copy() if isinstance(args.punc_list, str): with open(args.punc_list, encoding="utf-8") as f2: - punc_list = [line.rstrip() for line in f2] + pairs = [line.rstrip().split(":") for line in f2] + punc_list = [pair[0] for pair in pairs] + punc_weight_list = [float(pair[1]) for pair in pairs] args.punc_list = punc_list.copy() elif isinstance(args.punc_list, list): - # This is in the inference code path. punc_list = args.punc_list.copy() + punc_weight_list = [1] * len(punc_list) if isinstance(args.token_list, (tuple, list)): token_list = args.token_list.copy() else: @@ -217,7 +216,9 @@ class PunctuationTask(AbsTask): # 2. Build ESPnetModel # Assume the last-id is sos_and_eos - model = ESPnetPunctuationModel(punc_model=punc, vocab_size=vocab_size, **args.model_conf) + if "punc_weight" in args.model_conf: + args.model_conf.pop("punc_weight") + model = ESPnetPunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf) # FIXME(kamo): Should be done in model? # 3. Initialize