from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

# NOTE: the helper imports below are assumed to follow the usual FunASR package layout.
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.register import tables
from funasr.train_utils.device_funcs import force_gatherable


@tables.register("model_classes", "CTTransformer")
class CTTransformer(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
    https://arxiv.org/pdf/2003.01309.pdf
    """
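
    # High-level data flow (as implemented below): token ids -> nn.Embedding ->
    # encoder built from the FunASR encoder registry -> nn.Linear projecting
    # encoder states onto the punctuation label set (punc_list).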
    def __init__(
        self,
        encoder: str = None,
        encoder_conf: dict = None,
        vocab_size: int = -1,
        punc_list: list = None,
        punc_weight: list = None,
        embed_unit: int = 128,
        att_unit: int = 256,
        dropout_rate: float = 0.5,
        ignore_id: int = -1,
        sos: int = 1,
        eos: int = 2,
        **kwargs,
    ):
        super().__init__()

        punc_size = len(punc_list)
        if punc_weight is None:
            punc_weight = [1] * punc_size

        self.embed = nn.Embedding(vocab_size, embed_unit)
        # Instantiate the encoder through the FunASR registry (tables.encoder_classes).
        encoder_class = tables.encoder_classes.get(encoder.lower())
        encoder = encoder_class(**encoder_conf)

        self.decoder = nn.Linear(att_unit, punc_size)
        self.encoder = encoder
        self.punc_list = punc_list
        # Keep the class weights as a float tensor so nll() can move them to the input device.
        self.punc_weight = torch.tensor(punc_weight, dtype=torch.float32)
        self.ignore_id = ignore_id
        self.sos = sos
        self.eos = eos

    def punc_forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """Compute punctuation logits for a batch of token sequences.

        Args:
            input (torch.Tensor): Input token ids. (batch, len)
            text_lengths (torch.Tensor): Valid length of each sequence. (batch,)

        Returns:
            Tuple[torch.Tensor, None]: Logits over the punctuation classes (batch, len, punc_size)
                and a placeholder state.
        """
        x = self.embed(input)
        # mask = self._target_mask(input)
        h, _, _ = self.encoder(x, text_lengths)
        y = self.decoder(h)
        return y, None

    def with_vad(self):
        return False
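
    # NOTE: score()/batch_score() implement one-step incremental scoring with a
    # per-layer cache. They call self._target_mask(), which is not defined in this
    # file and is expected to be provided by the concrete model variant.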
    def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        """Score new token.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens.
            x (torch.Tensor): encoder feature that generates ys.

        Returns:
            Tuple[torch.Tensor, Any]: Tuple of
                torch.float32 scores for next token (vocab_size)
                and next state for ys.

        """
        y = y.unsqueeze(0)
        h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1).squeeze(0)
        return logp, cache

    def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            Tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, vocab_size)`
                and next state list for ys.

        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.encoder.encoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]

        # batch decoding
        h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1)

        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
        return logp, state_list
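
    # nll() has two behaviors: in training it returns the per-token weighted
    # cross-entropy (padding masked to zero); in evaluation it returns a micro-F1
    # score broadcast over the valid token count, so forward()'s normalized sum
    # recovers the F1 value as the reported "loss".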
    def nll(
        self,
        text: torch.Tensor,
        punc: torch.Tensor,
        text_lengths: torch.Tensor,
        punc_lengths: torch.Tensor,
        max_length: Optional[int] = None,
        vad_indexes: Optional[torch.Tensor] = None,
        vad_indexes_lengths: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute negative log likelihood (nll).

        Normally, this function is called in batchify_nll.

        Args:
            text: (Batch, Length)
            punc: (Batch, Length)
            text_lengths: (Batch,)
            punc_lengths: (Batch,)
            max_length: int
        """
        batch_size = text.size(0)
        # For data parallel
        if max_length is None:
            text = text[:, :text_lengths.max()]
            punc = punc[:, :text_lengths.max()]
        else:
            text = text[:, :max_length]
            punc = punc[:, :max_length]

        if self.with_vad():
            # Should be VadRealtimeTransformer
            assert vad_indexes is not None
            y, _ = self.punc_forward(text, text_lengths, vad_indexes)
        else:
            # Should be TargetDelayTransformer
            y, _ = self.punc_forward(text, text_lengths)

        # Calc negative log likelihood
        # nll: (BxL,)
        if not self.training:
            # Evaluation: report micro-F1 between predicted and reference punctuation
            # labels, broadcast over the number of valid tokens.
            _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
            from sklearn.metrics import f1_score

            f1 = f1_score(
                punc.view(-1).detach().cpu().numpy(),
                indices.squeeze(-1).detach().cpu().numpy(),
                average="micro",
            )
            nll = torch.Tensor([f1]).repeat(text_lengths.sum())
            return nll, text_lengths
        else:
            # Training: token-level weighted cross entropy, padding positions masked to zero.
            self.punc_weight = self.punc_weight.to(punc.device)
            nll = F.cross_entropy(
                y.view(-1, y.shape[-1]),
                punc.view(-1),
                self.punc_weight,
                reduction="none",
                ignore_index=self.ignore_id,
            )
            # nll: (BxL,) -> (BxL,)
            if max_length is None:
                nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
            else:
                nll.masked_fill_(
                    make_pad_mask(text_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
                    0.0,
                )
            # nll: (BxL,) -> (B, L)
            nll = nll.view(batch_size, -1)
            return nll, text_lengths

    def forward(
        self,
        text: torch.Tensor,
        punc: torch.Tensor,
        text_lengths: torch.Tensor,
        punc_lengths: torch.Tensor,
        vad_indexes: Optional[torch.Tensor] = None,
        vad_indexes_lengths: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes)
        ntokens = y_lengths.sum()
        loss = nll.sum() / ntokens
        stats = dict(loss=loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
        return loss, stats, weight

    def generate(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        vad_indexes: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, None]:
        if self.with_vad():
            assert vad_indexes is not None
            return self.punc_forward(text, text_lengths, vad_indexes)
        else:
            return self.punc_forward(text, text_lengths)
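

# ------------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the file): the encoder name and
# configuration below are hypothetical placeholders; real values come from the
# model's YAML config, and the encoder key must resolve via tables.encoder_classes.
#
#   model = CTTransformer(
#       encoder="TransformerEncoder",          # assumed registry key
#       encoder_conf=dict(output_size=256),    # assumed minimal config, matching att_unit
#       vocab_size=10000,
#       punc_list=["<unk>", "_", ",", ".", "?"],
#   )
#   tokens = torch.randint(0, 10000, (1, 20))
#   lengths = torch.tensor([20])
#   logits, _ = model.punc_forward(tokens, lengths)   # (1, 20, len(punc_list))
#   punc_ids = logits.argmax(dim=-1)                  # indices into punc_list
# ------------------------------------------------------------------------------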