mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
punctuation:add training code, support largedataset
This commit is contained in:
parent
be7230fd94
commit
ee06cb9c68
43
funasr/bin/punc_train.py
Normal file
43
funasr/bin/punc_train.py
Normal file
@ -0,0 +1,43 @@
|
||||
#!/usr/bin/env python3
|
||||
import os
|
||||
from funasr.tasks.punctuation import PunctuationTask
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = PunctuationTask.get_parser()
|
||||
parser.add_argument(
|
||||
"--gpu_id",
|
||||
type=int,
|
||||
default=0,
|
||||
help="local gpu id.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--punc_list",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Punctuation list",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main(args=None, cmd=None):
|
||||
"""
|
||||
punc training.
|
||||
"""
|
||||
PunctuationTask.main(args=args, cmd=cmd)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
# setup local gpu_id
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
|
||||
|
||||
# DDP settings
|
||||
if args.ngpu > 1:
|
||||
args.distributed = True
|
||||
else:
|
||||
args.distributed = False
|
||||
|
||||
main(args=args)
|
||||
44
funasr/bin/punc_train_vadrealtime.py
Normal file
44
funasr/bin/punc_train_vadrealtime.py
Normal file
@ -0,0 +1,44 @@
|
||||
#!/usr/bin/env python3
|
||||
import os
|
||||
from funasr.tasks.punctuation import PunctuationTask
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = PunctuationTask.get_parser()
|
||||
parser.add_argument(
|
||||
"--gpu_id",
|
||||
type=int,
|
||||
default=0,
|
||||
help="local gpu id.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--punc_list",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Punctuation list",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main(args=None, cmd=None):
|
||||
"""
|
||||
punc training.
|
||||
"""
|
||||
PunctuationTask.main(args=args, cmd=cmd)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
# setup local gpu_id
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
|
||||
|
||||
# DDP settings
|
||||
if args.ngpu > 1:
|
||||
args.distributed = True
|
||||
else:
|
||||
args.distributed = False
|
||||
assert args.num_worker_count == 1
|
||||
|
||||
main(args=args)
|
||||
@ -34,16 +34,20 @@ def load_seg_dict(seg_dict_file):
|
||||
return seg_dict
|
||||
|
||||
class ArkDataLoader(AbsIterFactory):
|
||||
def __init__(self, data_list, dict_file, dataset_conf, seg_dict_file=None, mode="train"):
|
||||
def __init__(self, data_list, dict_file, dataset_conf, seg_dict_file=None, punc_dict_file=None, mode="train"):
|
||||
symbol_table = read_symbol_table(dict_file) if dict_file is not None else None
|
||||
if seg_dict_file is not None:
|
||||
seg_dict = load_seg_dict(seg_dict_file)
|
||||
else:
|
||||
seg_dict = None
|
||||
if punc_dict_file is not None:
|
||||
punc_dict = read_symbol_table(punc_dict_file)
|
||||
else:
|
||||
punc_dict = None
|
||||
self.dataset_conf = dataset_conf
|
||||
logging.info("dataloader config: {}".format(self.dataset_conf))
|
||||
batch_mode = self.dataset_conf.get("batch_mode", "padding")
|
||||
self.dataset = Dataset(data_list, symbol_table, seg_dict,
|
||||
self.dataset = Dataset(data_list, symbol_table, seg_dict, punc_dict,
|
||||
self.dataset_conf, mode=mode, batch_mode=batch_mode)
|
||||
|
||||
def build_iter(self, epoch, shuffle=True):
|
||||
|
||||
@ -127,14 +127,17 @@ class AudioDataset(IterableDataset):
|
||||
sample_dict["key"] = key
|
||||
else:
|
||||
text = item
|
||||
sample_dict[data_name] = text.strip().split()[1:]
|
||||
segs = text.strip().split()
|
||||
sample_dict[data_name] = segs[1:]
|
||||
if "key" not in sample_dict:
|
||||
sample_dict["key"] = segs[0]
|
||||
yield sample_dict
|
||||
|
||||
self.close_reader(reader_list)
|
||||
|
||||
|
||||
def len_fn_example(data):
|
||||
return len(data)
|
||||
return 1
|
||||
|
||||
|
||||
def len_fn_token(data):
|
||||
@ -148,6 +151,7 @@ def len_fn_token(data):
|
||||
def Dataset(data_list_file,
|
||||
dict,
|
||||
seg_dict,
|
||||
punc_dict,
|
||||
conf,
|
||||
mode="train",
|
||||
batch_mode="padding"):
|
||||
@ -162,7 +166,7 @@ def Dataset(data_list_file,
|
||||
dataset = FilterIterDataPipe(dataset, fn=filter_fn)
|
||||
|
||||
if "text" in data_names:
|
||||
vocab = {'vocab': dict, 'seg_dict': seg_dict}
|
||||
vocab = {'vocab': dict, 'seg_dict': seg_dict, 'punc_dict': punc_dict}
|
||||
tokenize_fn = partial(tokenize, **vocab)
|
||||
dataset = MapperIterDataPipe(dataset, fn=tokenize_fn)
|
||||
|
||||
@ -191,6 +195,10 @@ def Dataset(data_list_file,
|
||||
sort_size=sort_size,
|
||||
batch_mode=batch_mode)
|
||||
|
||||
dataset = MapperIterDataPipe(dataset, fn=padding if batch_mode == "padding" else clipping)
|
||||
int_pad_value = conf.get("int_pad_value", -1)
|
||||
float_pad_value = conf.get("float_pad_value", 0.0)
|
||||
padding_conf = {"int_pad_value": int_pad_value, "float_pad_value": float_pad_value}
|
||||
padding_fn = partial(padding, **padding_conf)
|
||||
dataset = MapperIterDataPipe(dataset, fn=padding_fn if batch_mode == "padding" else clipping)
|
||||
|
||||
return dataset
|
||||
|
||||
@ -6,9 +6,8 @@ from torch.nn.utils.rnn import pad_sequence
|
||||
def padding(data, float_pad_value=0.0, int_pad_value=-1):
|
||||
assert isinstance(data, list)
|
||||
assert "key" in data[0]
|
||||
assert "speech" in data[0]
|
||||
assert "text" in data[0]
|
||||
|
||||
assert "speech" in data[0] or "text" in data[0]
|
||||
|
||||
keys = [x["key"] for x in data]
|
||||
|
||||
batch = {}
|
||||
|
||||
@ -31,22 +31,43 @@ def seg_tokenize(txt, seg_dict):
|
||||
|
||||
def tokenize(data,
|
||||
vocab=None,
|
||||
seg_dict=None):
|
||||
seg_dict=None,
|
||||
punc_dict=None):
|
||||
assert "text" in data
|
||||
assert isinstance(vocab, dict)
|
||||
text = data["text"]
|
||||
token = []
|
||||
vad = -2
|
||||
|
||||
if seg_dict is not None:
|
||||
assert isinstance(seg_dict, dict)
|
||||
txt = forward_segment("".join(text).lower(), seg_dict)
|
||||
text = seg_tokenize(txt, seg_dict)
|
||||
|
||||
for x in text:
|
||||
if x in vocab:
|
||||
|
||||
length = len(text)
|
||||
for i in range(length):
|
||||
x = text[i]
|
||||
if i == length-1 and "punc" in data and text[i].startswith("vad:"):
|
||||
vad = x[-1][4:]
|
||||
if len(vad) == 0:
|
||||
vad = -1
|
||||
else:
|
||||
vad = int(vad)
|
||||
elif x in vocab:
|
||||
token.append(vocab[x])
|
||||
else:
|
||||
token.append(vocab['<unk>'])
|
||||
|
||||
if "punc" in data and punc_dict is not None:
|
||||
punc_token = []
|
||||
for punc in data["punc"]:
|
||||
if punc in punc_dict:
|
||||
punc_token.append(punc_dict[punc])
|
||||
else:
|
||||
punc_token.append(punc_dict["_"])
|
||||
data["punc"] = np.array(punc_token)
|
||||
|
||||
data["text"] = np.array(token)
|
||||
if vad is not -2:
|
||||
data["vad_indexes"]=np.array([vad], dtype=np.int64)
|
||||
return data
|
||||
|
||||
@ -704,3 +704,103 @@ class CodeMixTokenizerCommonPreprocessor(CommonPreprocessor):
|
||||
del data[self.split_text_name]
|
||||
return result
|
||||
|
||||
class PuncTrainTokenizerCommonPreprocessor(CommonPreprocessor):
|
||||
def __init__(
|
||||
self,
|
||||
train: bool,
|
||||
token_type: List[str] = [None],
|
||||
token_list: List[Union[Path, str, Iterable[str]]] = [None],
|
||||
bpemodel: List[Union[Path, str, Iterable[str]]] = [None],
|
||||
text_cleaner: Collection[str] = None,
|
||||
g2p_type: str = None,
|
||||
unk_symbol: str = "<unk>",
|
||||
space_symbol: str = "<space>",
|
||||
non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
|
||||
delimiter: str = None,
|
||||
rir_scp: str = None,
|
||||
rir_apply_prob: float = 1.0,
|
||||
noise_scp: str = None,
|
||||
noise_apply_prob: float = 1.0,
|
||||
noise_db_range: str = "3_10",
|
||||
speech_volume_normalize: float = None,
|
||||
speech_name: str = "speech",
|
||||
text_name: List[str] = ["text"],
|
||||
vad_name: str = "vad_indexes",
|
||||
):
|
||||
# TODO(jiatong): sync with Kamo and Jing on interface for preprocessor
|
||||
super().__init__(
|
||||
train=train,
|
||||
token_type=token_type[0],
|
||||
token_list=token_list[0],
|
||||
bpemodel=bpemodel[0],
|
||||
text_cleaner=text_cleaner,
|
||||
g2p_type=g2p_type,
|
||||
unk_symbol=unk_symbol,
|
||||
space_symbol=space_symbol,
|
||||
non_linguistic_symbols=non_linguistic_symbols,
|
||||
delimiter=delimiter,
|
||||
speech_name=speech_name,
|
||||
text_name=text_name[0],
|
||||
rir_scp=rir_scp,
|
||||
rir_apply_prob=rir_apply_prob,
|
||||
noise_scp=noise_scp,
|
||||
noise_apply_prob=noise_apply_prob,
|
||||
noise_db_range=noise_db_range,
|
||||
speech_volume_normalize=speech_volume_normalize,
|
||||
)
|
||||
|
||||
assert (
|
||||
len(token_type) == len(token_list) == len(bpemodel) == len(text_name)
|
||||
), "token_type, token_list, bpemodel, or processing text_name mismatched"
|
||||
self.num_tokenizer = len(token_type)
|
||||
self.tokenizer = []
|
||||
self.token_id_converter = []
|
||||
|
||||
for i in range(self.num_tokenizer):
|
||||
if token_type[i] is not None:
|
||||
if token_list[i] is None:
|
||||
raise ValueError("token_list is required if token_type is not None")
|
||||
|
||||
self.tokenizer.append(
|
||||
build_tokenizer(
|
||||
token_type=token_type[i],
|
||||
bpemodel=bpemodel[i],
|
||||
delimiter=delimiter,
|
||||
space_symbol=space_symbol,
|
||||
non_linguistic_symbols=non_linguistic_symbols,
|
||||
g2p_type=g2p_type,
|
||||
)
|
||||
)
|
||||
self.token_id_converter.append(
|
||||
TokenIDConverter(
|
||||
token_list=token_list[i],
|
||||
unk_symbol=unk_symbol,
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.tokenizer.append(None)
|
||||
self.token_id_converter.append(None)
|
||||
|
||||
self.text_cleaner = TextCleaner(text_cleaner)
|
||||
self.text_name = text_name # override the text_name from CommonPreprocessor
|
||||
self.vad_name = vad_name
|
||||
|
||||
def _text_process(
|
||||
self, data: Dict[str, Union[str, np.ndarray]]
|
||||
) -> Dict[str, np.ndarray]:
|
||||
for i in range(self.num_tokenizer):
|
||||
text_name = self.text_name[i]
|
||||
if text_name in data and self.tokenizer[i] is not None:
|
||||
text = data[text_name]
|
||||
text = self.text_cleaner(text)
|
||||
tokens = self.tokenizer[i].text2tokens(text)
|
||||
if "vad:" in tokens[-1]:
|
||||
vad = tokens[-1][4:]
|
||||
tokens = tokens[:-1]
|
||||
if len(vad) == 0:
|
||||
vad = -1
|
||||
else:
|
||||
vad = int(vad)
|
||||
data[self.vad_name] = np.array([vad], dtype=np.int64)
|
||||
text_ints = self.token_id_converter[i].tokens2ids(tokens)
|
||||
data[text_name] = np.array(text_ints, dtype=np.int64)
|
||||
|
||||
@ -439,6 +439,18 @@ class MultiHeadedAttentionSANM(nn.Module):
|
||||
att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
|
||||
return att_outs + fsmn_memory
|
||||
|
||||
class MultiHeadedAttentionSANMwithMask(MultiHeadedAttentionSANM):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
|
||||
q_h, k_h, v_h, v = self.forward_qkv(x)
|
||||
fsmn_memory = self.forward_fsmn(v, mask[0], mask_shfit_chunk)
|
||||
q_h = q_h * self.d_k ** (-0.5)
|
||||
scores = torch.matmul(q_h, k_h.transpose(-2, -1))
|
||||
att_outs = self.forward_attention(v_h, scores, mask[1], mask_att_chunk_encoder)
|
||||
return att_outs + fsmn_memory
|
||||
|
||||
class MultiHeadedAttentionSANMDecoder(nn.Module):
|
||||
"""Multi-Head Attention layer.
|
||||
|
||||
|
||||
@ -33,3 +33,20 @@ def target_mask(ys_in_pad, ignore_id):
|
||||
ys_mask = ys_in_pad != ignore_id
|
||||
m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0)
|
||||
return ys_mask.unsqueeze(-2) & m
|
||||
|
||||
def vad_mask(size, vad_pos, device="cpu", dtype=torch.bool):
|
||||
"""Create mask for decoder self-attention.
|
||||
|
||||
:param int size: size of mask
|
||||
:param int vad_pos: index of vad index
|
||||
:param str device: "cpu" or "cuda" or torch.Tensor.device
|
||||
:param torch.dtype dtype: result dtype
|
||||
:rtype: torch.Tensor (B, Lmax, Lmax)
|
||||
"""
|
||||
ret = torch.ones(size, size, device=device, dtype=dtype)
|
||||
if vad_pos <= 0 or vad_pos >= size:
|
||||
return ret
|
||||
sub_corner = torch.zeros(
|
||||
vad_pos - 1, size - vad_pos, device=device, dtype=dtype)
|
||||
ret[0:vad_pos - 1, vad_pos:] = sub_corner
|
||||
return ret
|
||||
|
||||
@ -25,3 +25,7 @@ class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC):
|
||||
@abstractmethod
|
||||
def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def with_vad(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -14,15 +14,18 @@ from funasr.train.abs_espnet_model import AbsESPnetModel
|
||||
|
||||
class ESPnetPunctuationModel(AbsESPnetModel):
|
||||
|
||||
def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0):
|
||||
def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0, punc_weight: list = None):
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
self.punc_model = punc_model
|
||||
self.punc_weight = torch.Tensor(punc_weight)
|
||||
self.sos = 1
|
||||
self.eos = 2
|
||||
|
||||
# ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
|
||||
self.ignore_id = ignore_id
|
||||
if self.punc_model.with_vad():
|
||||
print("This is a vad puncuation model.")
|
||||
|
||||
def nll(
|
||||
self,
|
||||
@ -31,6 +34,8 @@ class ESPnetPunctuationModel(AbsESPnetModel):
|
||||
text_lengths: torch.Tensor,
|
||||
punc_lengths: torch.Tensor,
|
||||
max_length: Optional[int] = None,
|
||||
vad_indexes: Optional[torch.Tensor] = None,
|
||||
vad_indexes_lengths: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute negative log likelihood(nll)
|
||||
|
||||
@ -49,19 +54,16 @@ class ESPnetPunctuationModel(AbsESPnetModel):
|
||||
else:
|
||||
text = text[:, :max_length]
|
||||
punc = punc[:, :max_length]
|
||||
# 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
|
||||
# text: (Batch, Length) -> x, y: (Batch, Length + 1)
|
||||
#x = F.pad(text, [1, 0], "constant", self.eos)
|
||||
#t = F.pad(text, [0, 1], "constant", self.ignore_id)
|
||||
#for i, l in enumerate(text_lengths):
|
||||
# t[i, l] = self.sos
|
||||
#x_lengths = text_lengths + 1
|
||||
|
||||
if self.punc_model.with_vad():
|
||||
# Should be VadRealtimeTransformer
|
||||
assert vad_indexes is not None
|
||||
y, _ = self.punc_model(text, text_lengths, vad_indexes)
|
||||
else:
|
||||
# Should be TargetDelayTransformer,
|
||||
y, _ = self.punc_model(text, text_lengths)
|
||||
|
||||
# 2. Forward Language model
|
||||
# x: (Batch, Length) -> y: (Batch, Length, NVocab)
|
||||
y, _ = self.punc_model(text, text_lengths)
|
||||
|
||||
# 3. Calc negative log likelihood
|
||||
# Calc negative log likelihood
|
||||
# nll: (BxL,)
|
||||
if self.training == False:
|
||||
_, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
|
||||
@ -72,7 +74,8 @@ class ESPnetPunctuationModel(AbsESPnetModel):
|
||||
nll = torch.Tensor([f1_score]).repeat(text_lengths.sum())
|
||||
return nll, text_lengths
|
||||
else:
|
||||
nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), reduction="none", ignore_index=self.ignore_id)
|
||||
self.punc_weight = self.punc_weight.to(punc.device)
|
||||
nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none", ignore_index=self.ignore_id)
|
||||
# nll: (BxL,) -> (BxL,)
|
||||
if max_length is None:
|
||||
nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
|
||||
@ -130,9 +133,16 @@ class ESPnetPunctuationModel(AbsESPnetModel):
|
||||
assert x_lengths.size(0) == total_num
|
||||
return nll, x_lengths
|
||||
|
||||
def forward(self, text: torch.Tensor, punc: torch.Tensor, text_lengths: torch.Tensor,
|
||||
punc_lengths: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths)
|
||||
def forward(
|
||||
self,
|
||||
text: torch.Tensor,
|
||||
punc: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
punc_lengths: torch.Tensor,
|
||||
vad_indexes: Optional[torch.Tensor] = None,
|
||||
vad_indexes_lengths: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes)
|
||||
ntokens = y_lengths.sum()
|
||||
loss = nll.sum() / ntokens
|
||||
stats = dict(loss=loss.detach())
|
||||
@ -145,5 +155,12 @@ class ESPnetPunctuationModel(AbsESPnetModel):
|
||||
text_lengths: torch.Tensor) -> Dict[str, torch.Tensor]:
|
||||
return {}
|
||||
|
||||
def inference(self, text: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
|
||||
return self.punc_model(text, text_lengths)
|
||||
def inference(self,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
vad_indexes: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, None]:
|
||||
if self.punc_model.with_vad():
|
||||
assert vad_indexes is not None
|
||||
return self.punc_model(text, text_lengths, vad_indexes)
|
||||
else:
|
||||
return self.punc_model(text, text_lengths)
|
||||
|
||||
590
funasr/punctuation/sanm_encoder.py
Normal file
590
funasr/punctuation/sanm_encoder.py
Normal file
@ -0,0 +1,590 @@
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Sequence
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
|
||||
from typeguard import check_argument_types
|
||||
import numpy as np
|
||||
from funasr.modules.nets_utils import make_pad_mask
|
||||
from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask
|
||||
from funasr.modules.embedding import SinusoidalPositionEncoder
|
||||
from funasr.modules.layer_norm import LayerNorm
|
||||
from funasr.modules.multi_layer_conv import Conv1dLinear
|
||||
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
|
||||
from funasr.modules.positionwise_feed_forward import (
|
||||
PositionwiseFeedForward, # noqa: H301
|
||||
)
|
||||
from funasr.modules.repeat import repeat
|
||||
from funasr.modules.subsampling import Conv2dSubsampling
|
||||
from funasr.modules.subsampling import Conv2dSubsampling2
|
||||
from funasr.modules.subsampling import Conv2dSubsampling6
|
||||
from funasr.modules.subsampling import Conv2dSubsampling8
|
||||
from funasr.modules.subsampling import TooShortUttError
|
||||
from funasr.modules.subsampling import check_short_utt
|
||||
from funasr.models.ctc import CTC
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
|
||||
from funasr.modules.nets_utils import make_pad_mask
|
||||
from funasr.modules.mask import subsequent_mask, vad_mask
|
||||
|
||||
class EncoderLayerSANM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_size,
|
||||
size,
|
||||
self_attn,
|
||||
feed_forward,
|
||||
dropout_rate,
|
||||
normalize_before=True,
|
||||
concat_after=False,
|
||||
stochastic_depth_rate=0.0,
|
||||
):
|
||||
"""Construct an EncoderLayer object."""
|
||||
super(EncoderLayerSANM, self).__init__()
|
||||
self.self_attn = self_attn
|
||||
self.feed_forward = feed_forward
|
||||
self.norm1 = LayerNorm(in_size)
|
||||
self.norm2 = LayerNorm(size)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
self.in_size = in_size
|
||||
self.size = size
|
||||
self.normalize_before = normalize_before
|
||||
self.concat_after = concat_after
|
||||
if self.concat_after:
|
||||
self.concat_linear = nn.Linear(size + size, size)
|
||||
self.stochastic_depth_rate = stochastic_depth_rate
|
||||
self.dropout_rate = dropout_rate
|
||||
|
||||
def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
|
||||
"""Compute encoded features.
|
||||
|
||||
Args:
|
||||
x_input (torch.Tensor): Input tensor (#batch, time, size).
|
||||
mask (torch.Tensor): Mask tensor for the input (#batch, time).
|
||||
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor (#batch, time, size).
|
||||
torch.Tensor: Mask tensor (#batch, time).
|
||||
|
||||
"""
|
||||
skip_layer = False
|
||||
# with stochastic depth, residual connection `x + f(x)` becomes
|
||||
# `x <- x + 1 / (1 - p) * f(x)` at training time.
|
||||
stoch_layer_coeff = 1.0
|
||||
if self.training and self.stochastic_depth_rate > 0:
|
||||
skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
|
||||
stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
|
||||
|
||||
if skip_layer:
|
||||
if cache is not None:
|
||||
x = torch.cat([cache, x], dim=1)
|
||||
return x, mask
|
||||
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm1(x)
|
||||
|
||||
if self.concat_after:
|
||||
x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
|
||||
if self.in_size == self.size:
|
||||
x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
|
||||
else:
|
||||
x = stoch_layer_coeff * self.concat_linear(x_concat)
|
||||
else:
|
||||
if self.in_size == self.size:
|
||||
x = residual + stoch_layer_coeff * self.dropout(
|
||||
self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
|
||||
)
|
||||
else:
|
||||
x = stoch_layer_coeff * self.dropout(
|
||||
self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
|
||||
)
|
||||
if not self.normalize_before:
|
||||
x = self.norm1(x)
|
||||
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm2(x)
|
||||
x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
|
||||
if not self.normalize_before:
|
||||
x = self.norm2(x)
|
||||
|
||||
|
||||
return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
|
||||
|
||||
class SANMEncoder(AbsEncoder):
|
||||
"""
|
||||
author: Speech Lab, Alibaba Group, China
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int = 256,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
attention_dropout_rate: float = 0.0,
|
||||
input_layer: Optional[str] = "conv2d",
|
||||
pos_enc_class=SinusoidalPositionEncoder,
|
||||
normalize_before: bool = True,
|
||||
concat_after: bool = False,
|
||||
positionwise_layer_type: str = "linear",
|
||||
positionwise_conv_kernel_size: int = 1,
|
||||
padding_idx: int = -1,
|
||||
interctc_layer_idx: List[int] = [],
|
||||
interctc_use_conditioning: bool = False,
|
||||
kernel_size : int = 11,
|
||||
sanm_shfit : int = 0,
|
||||
selfattention_layer_type: str = "sanm",
|
||||
):
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
self._output_size = output_size
|
||||
|
||||
if input_layer == "linear":
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Linear(input_size, output_size),
|
||||
torch.nn.LayerNorm(output_size),
|
||||
torch.nn.Dropout(dropout_rate),
|
||||
torch.nn.ReLU(),
|
||||
pos_enc_class(output_size, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer == "conv2d":
|
||||
self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
|
||||
elif input_layer == "conv2d2":
|
||||
self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
|
||||
elif input_layer == "conv2d6":
|
||||
self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
|
||||
elif input_layer == "conv2d8":
|
||||
self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
|
||||
elif input_layer == "embed":
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
|
||||
SinusoidalPositionEncoder(),
|
||||
)
|
||||
elif input_layer is None:
|
||||
if input_size == output_size:
|
||||
self.embed = None
|
||||
else:
|
||||
self.embed = torch.nn.Linear(input_size, output_size)
|
||||
elif input_layer == "pe":
|
||||
self.embed = SinusoidalPositionEncoder()
|
||||
else:
|
||||
raise ValueError("unknown input_layer: " + input_layer)
|
||||
self.normalize_before = normalize_before
|
||||
if positionwise_layer_type == "linear":
|
||||
positionwise_layer = PositionwiseFeedForward
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
dropout_rate,
|
||||
)
|
||||
elif positionwise_layer_type == "conv1d":
|
||||
positionwise_layer = MultiLayeredConv1d
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
elif positionwise_layer_type == "conv1d-linear":
|
||||
positionwise_layer = Conv1dLinear
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Support only linear or conv1d.")
|
||||
|
||||
if selfattention_layer_type == "selfattn":
|
||||
encoder_selfattn_layer = MultiHeadedAttention
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
)
|
||||
|
||||
elif selfattention_layer_type == "sanm":
|
||||
self.encoder_selfattn_layer = MultiHeadedAttentionSANM
|
||||
encoder_selfattn_layer_args0 = (
|
||||
attention_heads,
|
||||
input_size,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
kernel_size,
|
||||
sanm_shfit,
|
||||
)
|
||||
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
output_size,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
kernel_size,
|
||||
sanm_shfit,
|
||||
)
|
||||
|
||||
self.encoders0 = repeat(
|
||||
1,
|
||||
lambda lnum: EncoderLayerSANM(
|
||||
input_size,
|
||||
output_size,
|
||||
self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
|
||||
positionwise_layer(*positionwise_layer_args),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
|
||||
self.encoders = repeat(
|
||||
num_blocks-1,
|
||||
lambda lnum: EncoderLayerSANM(
|
||||
output_size,
|
||||
output_size,
|
||||
self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
|
||||
positionwise_layer(*positionwise_layer_args),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
if self.normalize_before:
|
||||
self.after_norm = LayerNorm(output_size)
|
||||
|
||||
self.interctc_layer_idx = interctc_layer_idx
|
||||
if len(interctc_layer_idx) > 0:
|
||||
assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
|
||||
self.interctc_use_conditioning = interctc_use_conditioning
|
||||
self.conditioning_layer = None
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
|
||||
def output_size(self) -> int:
|
||||
return self._output_size
|
||||
|
||||
def forward(
|
||||
self,
|
||||
xs_pad: torch.Tensor,
|
||||
ilens: torch.Tensor,
|
||||
prev_states: torch.Tensor = None,
|
||||
ctc: CTC = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Embed positions in tensor.
|
||||
|
||||
Args:
|
||||
xs_pad: input tensor (B, L, D)
|
||||
ilens: input length (B)
|
||||
prev_states: Not to be used now.
|
||||
Returns:
|
||||
position embedded tensor and mask
|
||||
"""
|
||||
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
|
||||
xs_pad *= self.output_size()**0.5
|
||||
if self.embed is None:
|
||||
xs_pad = xs_pad
|
||||
elif (
|
||||
isinstance(self.embed, Conv2dSubsampling)
|
||||
or isinstance(self.embed, Conv2dSubsampling2)
|
||||
or isinstance(self.embed, Conv2dSubsampling6)
|
||||
or isinstance(self.embed, Conv2dSubsampling8)
|
||||
):
|
||||
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
|
||||
if short_status:
|
||||
raise TooShortUttError(
|
||||
f"has {xs_pad.size(1)} frames and is too short for subsampling "
|
||||
+ f"(it needs more than {limit_size} frames), return empty results",
|
||||
xs_pad.size(1),
|
||||
limit_size,
|
||||
)
|
||||
xs_pad, masks = self.embed(xs_pad, masks)
|
||||
else:
|
||||
xs_pad = self.embed(xs_pad)
|
||||
|
||||
# xs_pad = self.dropout(xs_pad)
|
||||
encoder_outs = self.encoders0(xs_pad, masks)
|
||||
xs_pad, masks = encoder_outs[0], encoder_outs[1]
|
||||
intermediate_outs = []
|
||||
if len(self.interctc_layer_idx) == 0:
|
||||
encoder_outs = self.encoders(xs_pad, masks)
|
||||
xs_pad, masks = encoder_outs[0], encoder_outs[1]
|
||||
else:
|
||||
for layer_idx, encoder_layer in enumerate(self.encoders):
|
||||
encoder_outs = encoder_layer(xs_pad, masks)
|
||||
xs_pad, masks = encoder_outs[0], encoder_outs[1]
|
||||
|
||||
if layer_idx + 1 in self.interctc_layer_idx:
|
||||
encoder_out = xs_pad
|
||||
|
||||
# intermediate outputs are also normalized
|
||||
if self.normalize_before:
|
||||
encoder_out = self.after_norm(encoder_out)
|
||||
|
||||
intermediate_outs.append((layer_idx + 1, encoder_out))
|
||||
|
||||
if self.interctc_use_conditioning:
|
||||
ctc_out = ctc.softmax(encoder_out)
|
||||
xs_pad = xs_pad + self.conditioning_layer(ctc_out)
|
||||
|
||||
if self.normalize_before:
|
||||
xs_pad = self.after_norm(xs_pad)
|
||||
|
||||
olens = masks.squeeze(1).sum(1)
|
||||
if len(intermediate_outs) > 0:
|
||||
return (xs_pad, intermediate_outs), olens, None
|
||||
return xs_pad, olens, None
|
||||
|
||||
class SANMVadEncoder(AbsEncoder):
|
||||
"""
|
||||
author: Speech Lab, Alibaba Group, China
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int = 256,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
attention_dropout_rate: float = 0.0,
|
||||
input_layer: Optional[str] = "conv2d",
|
||||
pos_enc_class=SinusoidalPositionEncoder,
|
||||
normalize_before: bool = True,
|
||||
concat_after: bool = False,
|
||||
positionwise_layer_type: str = "linear",
|
||||
positionwise_conv_kernel_size: int = 1,
|
||||
padding_idx: int = -1,
|
||||
interctc_layer_idx: List[int] = [],
|
||||
interctc_use_conditioning: bool = False,
|
||||
kernel_size : int = 11,
|
||||
sanm_shfit : int = 0,
|
||||
selfattention_layer_type: str = "sanm",
|
||||
):
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
self._output_size = output_size
|
||||
|
||||
if input_layer == "linear":
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Linear(input_size, output_size),
|
||||
torch.nn.LayerNorm(output_size),
|
||||
torch.nn.Dropout(dropout_rate),
|
||||
torch.nn.ReLU(),
|
||||
pos_enc_class(output_size, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer == "conv2d":
|
||||
self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
|
||||
elif input_layer == "conv2d2":
|
||||
self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
|
||||
elif input_layer == "conv2d6":
|
||||
self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
|
||||
elif input_layer == "conv2d8":
|
||||
self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
|
||||
elif input_layer == "embed":
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
|
||||
SinusoidalPositionEncoder(),
|
||||
)
|
||||
elif input_layer is None:
|
||||
if input_size == output_size:
|
||||
self.embed = None
|
||||
else:
|
||||
self.embed = torch.nn.Linear(input_size, output_size)
|
||||
elif input_layer == "pe":
|
||||
self.embed = SinusoidalPositionEncoder()
|
||||
else:
|
||||
raise ValueError("unknown input_layer: " + input_layer)
|
||||
self.normalize_before = normalize_before
|
||||
if positionwise_layer_type == "linear":
|
||||
positionwise_layer = PositionwiseFeedForward
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
dropout_rate,
|
||||
)
|
||||
elif positionwise_layer_type == "conv1d":
|
||||
positionwise_layer = MultiLayeredConv1d
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
elif positionwise_layer_type == "conv1d-linear":
|
||||
positionwise_layer = Conv1dLinear
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Support only linear or conv1d.")
|
||||
|
||||
if selfattention_layer_type == "selfattn":
|
||||
encoder_selfattn_layer = MultiHeadedAttention
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
)
|
||||
|
||||
elif selfattention_layer_type == "sanm":
|
||||
self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
|
||||
encoder_selfattn_layer_args0 = (
|
||||
attention_heads,
|
||||
input_size,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
kernel_size,
|
||||
sanm_shfit,
|
||||
)
|
||||
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
output_size,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
kernel_size,
|
||||
sanm_shfit,
|
||||
)
|
||||
|
||||
self.encoders0 = repeat(
|
||||
1,
|
||||
lambda lnum: EncoderLayerSANM(
|
||||
input_size,
|
||||
output_size,
|
||||
self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
|
||||
positionwise_layer(*positionwise_layer_args),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
|
||||
self.encoders = repeat(
|
||||
num_blocks-1,
|
||||
lambda lnum: EncoderLayerSANM(
|
||||
output_size,
|
||||
output_size,
|
||||
self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
|
||||
positionwise_layer(*positionwise_layer_args),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
if self.normalize_before:
|
||||
self.after_norm = LayerNorm(output_size)
|
||||
|
||||
self.interctc_layer_idx = interctc_layer_idx
|
||||
if len(interctc_layer_idx) > 0:
|
||||
assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
|
||||
self.interctc_use_conditioning = interctc_use_conditioning
|
||||
self.conditioning_layer = None
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
|
||||
def output_size(self) -> int:
|
||||
return self._output_size
|
||||
|
||||
def forward(
|
||||
self,
|
||||
xs_pad: torch.Tensor,
|
||||
ilens: torch.Tensor,
|
||||
vad_indexes: torch.Tensor,
|
||||
prev_states: torch.Tensor = None,
|
||||
ctc: CTC = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Embed positions in tensor.
|
||||
|
||||
Args:
|
||||
xs_pad: input tensor (B, L, D)
|
||||
ilens: input length (B)
|
||||
prev_states: Not to be used now.
|
||||
Returns:
|
||||
position embedded tensor and mask
|
||||
"""
|
||||
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
|
||||
sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
|
||||
no_future_masks = masks & sub_masks
|
||||
xs_pad *= self.output_size()**0.5
|
||||
if self.embed is None:
|
||||
xs_pad = xs_pad
|
||||
elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
|
||||
or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
|
||||
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
|
||||
if short_status:
|
||||
raise TooShortUttError(
|
||||
f"has {xs_pad.size(1)} frames and is too short for subsampling " +
|
||||
f"(it needs more than {limit_size} frames), return empty results",
|
||||
xs_pad.size(1),
|
||||
limit_size,
|
||||
)
|
||||
xs_pad, masks = self.embed(xs_pad, masks)
|
||||
else:
|
||||
xs_pad = self.embed(xs_pad)
|
||||
|
||||
# xs_pad = self.dropout(xs_pad)
|
||||
mask_tup0 = [masks, no_future_masks]
|
||||
encoder_outs = self.encoders0(xs_pad, mask_tup0)
|
||||
xs_pad, _ = encoder_outs[0], encoder_outs[1]
|
||||
intermediate_outs = []
|
||||
#if len(self.interctc_layer_idx) == 0:
|
||||
if False:
|
||||
# Here, we should not use the repeat operation to do it for all layers.
|
||||
encoder_outs = self.encoders(xs_pad, masks)
|
||||
xs_pad, masks = encoder_outs[0], encoder_outs[1]
|
||||
else:
|
||||
for layer_idx, encoder_layer in enumerate(self.encoders):
|
||||
if layer_idx + 1 == len(self.encoders):
|
||||
# This is last layer.
|
||||
coner_mask = torch.ones(masks.size(0),
|
||||
masks.size(-1),
|
||||
masks.size(-1),
|
||||
device=xs_pad.device,
|
||||
dtype=torch.bool)
|
||||
for word_index, length in enumerate(ilens):
|
||||
coner_mask[word_index, :, :] = vad_mask(masks.size(-1),
|
||||
vad_indexes[word_index],
|
||||
device=xs_pad.device)
|
||||
layer_mask = masks & coner_mask
|
||||
else:
|
||||
layer_mask = no_future_masks
|
||||
mask_tup1 = [masks, layer_mask]
|
||||
encoder_outs = encoder_layer(xs_pad, mask_tup1)
|
||||
xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
|
||||
|
||||
if layer_idx + 1 in self.interctc_layer_idx:
|
||||
encoder_out = xs_pad
|
||||
|
||||
# intermediate outputs are also normalized
|
||||
if self.normalize_before:
|
||||
encoder_out = self.after_norm(encoder_out)
|
||||
|
||||
intermediate_outs.append((layer_idx + 1, encoder_out))
|
||||
|
||||
if self.interctc_use_conditioning:
|
||||
ctc_out = ctc.softmax(encoder_out)
|
||||
xs_pad = xs_pad + self.conditioning_layer(ctc_out)
|
||||
|
||||
if self.normalize_before:
|
||||
xs_pad = self.after_norm(xs_pad)
|
||||
|
||||
olens = masks.squeeze(1).sum(1)
|
||||
if len(intermediate_outs) > 0:
|
||||
return (xs_pad, intermediate_outs), olens, None
|
||||
return xs_pad, olens, None
|
||||
|
||||
@ -8,7 +8,7 @@ import torch.nn as nn
|
||||
from funasr.modules.embedding import PositionalEncoding
|
||||
from funasr.modules.embedding import SinusoidalPositionEncoder
|
||||
#from funasr.models.encoder.transformer_encoder import TransformerEncoder as Encoder
|
||||
from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder
|
||||
from funasr.punctuation.sanm_encoder import SANMEncoder as Encoder
|
||||
#from funasr.modules.mask import subsequent_n_mask
|
||||
from funasr.punctuation.abs_model import AbsPunctuation
|
||||
|
||||
@ -73,6 +73,9 @@ class TargetDelayTransformer(AbsPunctuation):
|
||||
y = self.decoder(h)
|
||||
return y, None
|
||||
|
||||
def with_vad(self):
|
||||
return False
|
||||
|
||||
def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
|
||||
"""Score new token.
|
||||
|
||||
|
||||
132
funasr/punctuation/vad_realtime_transformer.py
Normal file
132
funasr/punctuation/vad_realtime_transformer.py
Normal file
@ -0,0 +1,132 @@
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from funasr.modules.embedding import SinusoidalPositionEncoder
|
||||
from funasr.punctuation.sanm_encoder import SANMVadEncoder as Encoder
|
||||
from funasr.punctuation.abs_model import AbsPunctuation
|
||||
|
||||
|
||||
class VadRealtimeTransformer(AbsPunctuation):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
punc_size: int,
|
||||
pos_enc: str = None,
|
||||
embed_unit: int = 128,
|
||||
att_unit: int = 256,
|
||||
head: int = 2,
|
||||
unit: int = 1024,
|
||||
layer: int = 4,
|
||||
dropout_rate: float = 0.5,
|
||||
kernel_size: int = 11,
|
||||
sanm_shfit: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
if pos_enc == "sinusoidal":
|
||||
# pos_enc_class = PositionalEncoding
|
||||
pos_enc_class = SinusoidalPositionEncoder
|
||||
elif pos_enc is None:
|
||||
|
||||
def pos_enc_class(*args, **kwargs):
|
||||
return nn.Sequential() # indentity
|
||||
|
||||
else:
|
||||
raise ValueError(f"unknown pos-enc option: {pos_enc}")
|
||||
|
||||
self.embed = nn.Embedding(vocab_size, embed_unit)
|
||||
self.encoder = Encoder(
|
||||
input_size=embed_unit,
|
||||
output_size=att_unit,
|
||||
attention_heads=head,
|
||||
linear_units=unit,
|
||||
num_blocks=layer,
|
||||
dropout_rate=dropout_rate,
|
||||
input_layer="pe",
|
||||
# pos_enc_class=pos_enc_class,
|
||||
padding_idx=0,
|
||||
kernel_size=kernel_size,
|
||||
sanm_shfit=sanm_shfit,
|
||||
)
|
||||
self.decoder = nn.Linear(att_unit, punc_size)
|
||||
|
||||
|
||||
# def _target_mask(self, ys_in_pad):
|
||||
# ys_mask = ys_in_pad != 0
|
||||
# m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0)
|
||||
# return ys_mask.unsqueeze(-2) & m
|
||||
|
||||
def forward(self, input: torch.Tensor, text_lengths: torch.Tensor,
|
||||
vad_indexes: torch.Tensor) -> Tuple[torch.Tensor, None]:
|
||||
"""Compute loss value from buffer sequences.
|
||||
|
||||
Args:
|
||||
input (torch.Tensor): Input ids. (batch, len)
|
||||
hidden (torch.Tensor): Target ids. (batch, len)
|
||||
|
||||
"""
|
||||
x = self.embed(input)
|
||||
# mask = self._target_mask(input)
|
||||
h, _, _ = self.encoder(x, text_lengths, vad_indexes)
|
||||
y = self.decoder(h)
|
||||
return y, None
|
||||
|
||||
def with_vad(self):
|
||||
return True
|
||||
|
||||
def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
|
||||
"""Score new token.
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): 1D torch.int64 prefix tokens.
|
||||
state: Scorer state for prefix tokens
|
||||
x (torch.Tensor): encoder feature that generates ys.
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor, Any]: Tuple of
|
||||
torch.float32 scores for next token (vocab_size)
|
||||
and next state for ys
|
||||
|
||||
"""
|
||||
y = y.unsqueeze(0)
|
||||
h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
|
||||
h = self.decoder(h[:, -1])
|
||||
logp = h.log_softmax(dim=-1).squeeze(0)
|
||||
return logp, cache
|
||||
|
||||
def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
|
||||
"""Score new token batch.
|
||||
|
||||
Args:
|
||||
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
|
||||
states (List[Any]): Scorer states for prefix tokens.
|
||||
xs (torch.Tensor):
|
||||
The encoder feature that generates ys (n_batch, xlen, n_feat).
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor, List[Any]]: Tuple of
|
||||
batchfied scores for next token with shape of `(n_batch, vocab_size)`
|
||||
and next state list for ys.
|
||||
|
||||
"""
|
||||
# merge states
|
||||
n_batch = len(ys)
|
||||
n_layers = len(self.encoder.encoders)
|
||||
if states[0] is None:
|
||||
batch_state = None
|
||||
else:
|
||||
# transpose state of [batch, layer] into [layer, batch]
|
||||
batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]
|
||||
|
||||
# batch decoding
|
||||
h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
|
||||
h = self.decoder(h[:, -1])
|
||||
logp = h.log_softmax(dim=-1)
|
||||
|
||||
# transpose state of [layer, batch] into [batch, layer]
|
||||
state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
|
||||
return logp, state_list
|
||||
@ -1350,10 +1350,12 @@ class AbsTask(ABC):
|
||||
train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list, args.dataset_conf,
|
||||
seg_dict_file=args.seg_dict_file if hasattr(args,
|
||||
"seg_dict_file") else None,
|
||||
punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None,
|
||||
mode="train")
|
||||
valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf,
|
||||
seg_dict_file=args.seg_dict_file if hasattr(args,
|
||||
"seg_dict_file") else None,
|
||||
punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None,
|
||||
mode="eval")
|
||||
elif args.dataset_type == "small":
|
||||
train_iter_factory = cls.build_iter_factory(
|
||||
|
||||
@ -13,10 +13,11 @@ from typeguard import check_argument_types
|
||||
from typeguard import check_return_type
|
||||
|
||||
from funasr.datasets.collate_fn import CommonCollateFn
|
||||
from funasr.datasets.preprocessor import MutliTokenizerCommonPreprocessor
|
||||
from funasr.datasets.preprocessor import PuncTrainTokenizerCommonPreprocessor
|
||||
from funasr.punctuation.abs_model import AbsPunctuation
|
||||
from funasr.punctuation.espnet_model import ESPnetPunctuationModel
|
||||
from funasr.punctuation.target_delay_transformer import TargetDelayTransformer
|
||||
from funasr.punctuation.vad_realtime_transformer import VadRealtimeTransformer
|
||||
from funasr.tasks.abs_task import AbsTask
|
||||
from funasr.text.phoneme_tokenizer import g2p_choices
|
||||
from funasr.torch_utils.initialize import initialize
|
||||
@ -29,11 +30,9 @@ from funasr.utils.types import str_or_none
|
||||
|
||||
punc_choices = ClassChoices(
|
||||
"punctuation",
|
||||
classes=dict(
|
||||
target_delay=TargetDelayTransformer,
|
||||
),
|
||||
classes=dict(target_delay=TargetDelayTransformer, vad_realtime=VadRealtimeTransformer),
|
||||
type_check=AbsPunctuation,
|
||||
default="TargetDelayTransformer",
|
||||
default="target_delay",
|
||||
)
|
||||
|
||||
|
||||
@ -56,8 +55,6 @@ class PunctuationTask(AbsTask):
|
||||
# NOTE(kamo): add_arguments(..., required=True) can't be used
|
||||
# to provide --print_config mode. Instead of it, do as
|
||||
required = parser.get_default("required")
|
||||
#import pdb;pdb.set_trace()
|
||||
#required += ["token_list"]
|
||||
|
||||
group.add_argument(
|
||||
"--token_list",
|
||||
@ -154,7 +151,7 @@ class PunctuationTask(AbsTask):
|
||||
bpemodels = [args.bpemodel, args.bpemodel]
|
||||
text_names = ["text", "punc"]
|
||||
if args.use_preprocessor:
|
||||
retval = MutliTokenizerCommonPreprocessor(
|
||||
retval = PuncTrainTokenizerCommonPreprocessor(
|
||||
train=train,
|
||||
token_type=token_types,
|
||||
token_list=token_lists,
|
||||
@ -182,7 +179,7 @@ class PunctuationTask(AbsTask):
|
||||
def optional_data_names(
|
||||
cls, train: bool = True, inference: bool = False
|
||||
) -> Tuple[str, ...]:
|
||||
retval = ()
|
||||
retval = ("vad",)
|
||||
return retval
|
||||
|
||||
@classmethod
|
||||
@ -197,11 +194,13 @@ class PunctuationTask(AbsTask):
|
||||
args.token_list = token_list.copy()
|
||||
if isinstance(args.punc_list, str):
|
||||
with open(args.punc_list, encoding="utf-8") as f2:
|
||||
punc_list = [line.rstrip() for line in f2]
|
||||
pairs = [line.rstrip().split(":") for line in f2]
|
||||
punc_list = [pair[0] for pair in pairs]
|
||||
punc_weight_list = [float(pair[1]) for pair in pairs]
|
||||
args.punc_list = punc_list.copy()
|
||||
elif isinstance(args.punc_list, list):
|
||||
# This is in the inference code path.
|
||||
punc_list = args.punc_list.copy()
|
||||
punc_weight_list = [1] * len(punc_list)
|
||||
if isinstance(args.token_list, (tuple, list)):
|
||||
token_list = args.token_list.copy()
|
||||
else:
|
||||
@ -217,7 +216,9 @@ class PunctuationTask(AbsTask):
|
||||
|
||||
# 2. Build ESPnetModel
|
||||
# Assume the last-id is sos_and_eos
|
||||
model = ESPnetPunctuationModel(punc_model=punc, vocab_size=vocab_size, **args.model_conf)
|
||||
if "punc_weight" in args.model_conf:
|
||||
args.model_conf.pop("punc_weight")
|
||||
model = ESPnetPunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)
|
||||
|
||||
# FIXME(kamo): Should be done in model?
|
||||
# 3. Initialize
|
||||
|
||||
Loading…
Reference in New Issue
Block a user