funasr1.0 ct-transformer streaming

游雁 2024-01-14 23:21:08 +08:00
parent bdfd27b9e9
commit 99730b35f4
14 changed files with 91 additions and 3167 deletions

View File

@@ -0,0 +1,19 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
from funasr import AutoModel
model = AutoModel(model="damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727", model_revision="v2.0.1")
inputs = "跨境河流是养育沿岸|人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员|在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险|向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流问题上的关切|愿意进一步完善双方联合工作机制|凡是|中方能做的我们|都会去做而且会做得更好我请印度朋友们放心中国在上游的|任何开发利用都会经过科学|规划和论证兼顾上下游的利益"
vads = inputs.split("|")
rec_result_all = "outputs: "
cache = {}
for vad in vads:
    rec_result = model(input=vad, cache=cache)
    print(rec_result)
    rec_result_all += rec_result[0]['text']
print(rec_result_all)
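Each call punctuates the new segment together with the still-unpunctuated tail of the previous call, which the model carries in cache. One way to observe that state (a sketch: it assumes the cache["pre_text"] layout set up by CTTransformerStreaming.generate later in this commit):

# Peek at the streaming state after one segment (assumes the cache layout
# used by generate() in this commit).
cache = {}
model(input="跨境河流是养育沿岸", cache=cache)
print(cache.get("pre_text"))  # words after the last sentence-final mark,
                              # re-fed together with the next segment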

View File

@@ -0,0 +1,10 @@
model="damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727"
model_revision="v2.0.1"
python funasr/bin/inference.py \
+model=${model} \
+model_revision=${model_revision} \
+input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt" \
+output_dir="./outputs/debug" \
+device="cpu"
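For reference, a rough Python equivalent of the shell invocation above (a sketch: it reuses the AutoModel wrapper from the demo script earlier in this commit and assumes the wrapper accepts a URL or local path as input):

from funasr import AutoModel

model = AutoModel(model="damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727", model_revision="v2.0.1")
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
print(res)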

File diff suppressed because it is too large

View File

@@ -1,383 +0,0 @@
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union

import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from funasr.models.scama.chunk_utilis import overlap_chunk
import numpy as np
from funasr.train_utils.device_funcs import to_device
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.sanm.attention import MultiHeadedAttention
from funasr.models.ct_transformer.attention import MultiHeadedAttentionSANMwithMask
from funasr.models.transformer.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
from funasr.models.transformer.positionwise_feed_forward import (
    PositionwiseFeedForward,  # noqa: H301
)
from funasr.models.transformer.utils.repeat import repeat
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
from funasr.models.transformer.utils.subsampling import TooShortUttError
from funasr.models.transformer.utils.subsampling import check_short_utt
from funasr.models.transformer.utils.mask import subsequent_mask, vad_mask
from funasr.models.ctc.ctc import CTC
from funasr.register import tables


class EncoderLayerSANM(nn.Module):
    def __init__(
        self,
        in_size,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayerSANM, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(in_size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.in_size = in_size
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate
        self.dropout_rate = dropout_rate

    def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        """Compute encoded features.

        Args:
            x_input (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
        """
        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            return x, mask

        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if self.concat_after:
            x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
            else:
                x = stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.dropout(
                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
                )
            else:
                x = stoch_layer_coeff * self.dropout(
                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
                )
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder

    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
        """Compute encoded features.

        Args:
            x_input (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
        """
        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if self.in_size == self.size:
            attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
            x = residual + attn
        else:
            x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)

        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + self.feed_forward(x)
        if not self.normalize_before:
            x = self.norm2(x)

        return x, cache


@tables.register("encoder_classes", "SANMVadEncoder")
class SANMVadEncoder(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
        selfattention_layer_type: str = "sanm",
    ):
        super().__init__()
        self._output_size = output_size

        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                SinusoidalPositionEncoder(),
            )
        elif input_layer is None:
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        elif input_layer == "pe":
            self.embed = SinusoidalPositionEncoder()
        else:
            raise ValueError("unknown input_layer: " + input_layer)

        self.normalize_before = normalize_before

        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        if selfattention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "sanm":
            self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
            encoder_selfattn_layer_args0 = (
                attention_heads,
                input_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )

        self.encoders0 = repeat(
            1,
            lambda lnum: EncoderLayerSANM(
                input_size,
                output_size,
                self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        self.encoders = repeat(
            num_blocks - 1,
            lambda lnum: EncoderLayerSANM(
                output_size,
                output_size,
                self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None
        self.dropout = nn.Dropout(dropout_rate)

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        vad_indexes: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: Not to be used now.

        Returns:
            position embedded tensor and mask
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
        no_future_masks = masks & sub_masks
        xs_pad *= self.output_size() ** 0.5
        if self.embed is None:
            xs_pad = xs_pad
        elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
              or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling " +
                    f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)

        # xs_pad = self.dropout(xs_pad)
        mask_tup0 = [masks, no_future_masks]
        encoder_outs = self.encoders0(xs_pad, mask_tup0)
        xs_pad, _ = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []
        for layer_idx, encoder_layer in enumerate(self.encoders):
            if layer_idx + 1 == len(self.encoders):
                # This is last layer.
                coner_mask = torch.ones(masks.size(0),
                                        masks.size(-1),
                                        masks.size(-1),
                                        device=xs_pad.device,
                                        dtype=torch.bool)
                for word_index, length in enumerate(ilens):
                    coner_mask[word_index, :, :] = vad_mask(masks.size(-1),
                                                            vad_indexes[word_index],
                                                            device=xs_pad.device)
                layer_mask = masks & coner_mask
            else:
                layer_mask = no_future_masks
            mask_tup1 = [masks, layer_mask]
            encoder_outs = encoder_layer(xs_pad, mask_tup1)
            xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
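A minimal shape check for the encoder above (a sketch: the import path is hypothetical because the diff view hides file names, and the constructor values follow the YAML config later in this commit):

import torch
# Hypothetical import path -- the diff view does not show the new file's name:
# from funasr.models.ct_transformer_streaming.encoder import SANMVadEncoder

enc = SANMVadEncoder(
    input_size=256, output_size=256, attention_heads=8, linear_units=1024,
    num_blocks=3, input_layer="pe", kernel_size=11, sanm_shfit=5, padding_idx=0,
)
enc.eval()
x = torch.randn(2, 12, 256)         # (batch, tokens, embedding dim)
ilens = torch.tensor([12, 9])       # valid length per batch entry
vad_indexes = torch.tensor([5, 3])  # where the cached prefix ends per entry
h, olens, _ = enc(x, ilens, vad_indexes)
print(h.shape, olens)               # expected: torch.Size([2, 12, 256]) tensor([12, 9])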

View File

@@ -60,7 +60,7 @@ class CTTransformer(nn.Module):
    def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
    def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor, **kwargs):
        """Compute loss value from buffer sequences.

        Args:

View File

@@ -14,26 +14,6 @@ def split_to_mini_sentence(words: list, word_limit: int = 20):
    return sentences

# def split_words(text: str, **kwargs):
#     words = []
#     segs = text.split()
#     for seg in segs:
#         # There is no space in seg.
#         current_word = ""
#         for c in seg:
#             if len(c.encode()) == 1:
#                 # This is an ASCII char.
#                 current_word += c
#             else:
#                 # This is a Chinese char.
#                 if len(current_word) > 0:
#                     words.append(current_word)
#                     current_word = ""
#                 words.append(c)
#         if len(current_word) > 0:
#             words.append(current_word)
#
#     return words

def split_words(text: str, jieba_usr_dict=None, **kwargs):
    if jieba_usr_dict:

View File

@@ -1,135 +0,0 @@
from typing import Any
from typing import List
from typing import Tuple

import torch
import torch.nn as nn

from funasr.models.transformer.embedding import SinusoidalPositionEncoder
from funasr.models.ct_transformer.sanm_encoder import SANMVadEncoder as Encoder


class VadRealtimeTransformer(torch.nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
    https://arxiv.org/pdf/2003.01309.pdf
    """

    def __init__(
        self,
        vocab_size: int,
        punc_size: int,
        pos_enc: str = None,
        embed_unit: int = 128,
        att_unit: int = 256,
        head: int = 2,
        unit: int = 1024,
        layer: int = 4,
        dropout_rate: float = 0.5,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
    ):
        super().__init__()
        if pos_enc == "sinusoidal":
            # pos_enc_class = PositionalEncoding
            pos_enc_class = SinusoidalPositionEncoder
        elif pos_enc is None:

            def pos_enc_class(*args, **kwargs):
                return nn.Sequential()  # identity

        else:
            raise ValueError(f"unknown pos-enc option: {pos_enc}")

        self.embed = nn.Embedding(vocab_size, embed_unit)
        self.encoder = Encoder(
            input_size=embed_unit,
            output_size=att_unit,
            attention_heads=head,
            linear_units=unit,
            num_blocks=layer,
            dropout_rate=dropout_rate,
            input_layer="pe",
            # pos_enc_class=pos_enc_class,
            padding_idx=0,
            kernel_size=kernel_size,
            sanm_shfit=sanm_shfit,
        )
        self.decoder = nn.Linear(att_unit, punc_size)

    # def _target_mask(self, ys_in_pad):
    #     ys_mask = ys_in_pad != 0
    #     m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0)
    #     return ys_mask.unsqueeze(-2) & m

    def forward(self, input: torch.Tensor, text_lengths: torch.Tensor,
                vad_indexes: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """Compute loss value from buffer sequences.

        Args:
            input (torch.Tensor): Input ids. (batch, len)
            hidden (torch.Tensor): Target ids. (batch, len)
        """
        x = self.embed(input)
        # mask = self._target_mask(input)
        h, _, _ = self.encoder(x, text_lengths, vad_indexes)
        y = self.decoder(h)
        return y, None

    def with_vad(self):
        return True

    def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        """Score new token.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (torch.Tensor): encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                torch.float32 scores for next token (vocab_size)
                and next state for ys
        """
        y = y.unsqueeze(0)
        h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1).squeeze(0)
        return logp, cache

    def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, vocab_size)`
                and next state list for ys.
        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.encoder.encoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]
        # batch decoding
        h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1)
        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
        return logp, state_list

File diff suppressed because it is too large

View File

@@ -12,7 +12,7 @@ import numpy as np
from funasr.train_utils.device_funcs import to_device
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.sanm.attention import MultiHeadedAttention
from funasr.models.ct_transformer.attention import MultiHeadedAttentionSANMwithMask
from funasr.models.ct_transformer_streaming.attention import MultiHeadedAttentionSANMwithMask
from funasr.models.transformer.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear

View File

@@ -12,11 +12,12 @@ import torch
import torch.nn as nn

from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
from funasr.utils.load_utils import load_audio_text_image_video
from funasr.models.ct_transformer.model import CTTransformer
from funasr.register import tables


@tables.register("model_classes", "CTTransformerStreaming")
class CTTransformerStreaming(nn.Module):
class CTTransformerStreaming(CTTransformer):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
@@ -24,43 +25,13 @@ class CTTransformerStreaming(nn.Module):
    """

    def __init__(
        self,
        encoder: str = None,
        encoder_conf: dict = None,
        vocab_size: int = -1,
        punc_list: list = None,
        punc_weight: list = None,
        embed_unit: int = 128,
        att_unit: int = 256,
        dropout_rate: float = 0.5,
        ignore_id: int = -1,
        sos: int = 1,
        eos: int = 2,
        sentence_end_id: int = 3,
        *args,
        **kwargs,
    ):
        super().__init__()
        super().__init__(*args, **kwargs)
        punc_size = len(punc_list)
        if punc_weight is None:
            punc_weight = [1] * punc_size

        self.embed = nn.Embedding(vocab_size, embed_unit)
        encoder_class = tables.encoder_classes.get(encoder.lower())
        encoder = encoder_class(**encoder_conf)
        self.decoder = nn.Linear(att_unit, punc_size)
        self.encoder = encoder
        self.punc_list = punc_list
        self.punc_weight = punc_weight
        self.ignore_id = ignore_id
        self.sos = sos
        self.eos = eos
        self.sentence_end_id = sentence_end_id

    def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
    def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor, vad_indexes: torch.Tensor, **kwargs):
        """Compute loss value from buffer sequences.

        Args:
@@ -70,146 +41,14 @@ class CTTransformerStreaming(nn.Module):
        """
        x = self.embed(text)
        # mask = self._target_mask(input)
        h, _, _ = self.encoder(x, text_lengths)
        h, _, _ = self.encoder(x, text_lengths, vad_indexes=vad_indexes)
        y = self.decoder(h)
        return y, None

    def with_vad(self):
        return False

    def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        """Score new token.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (torch.Tensor): encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                torch.float32 scores for next token (vocab_size)
                and next state for ys
        """
        y = y.unsqueeze(0)
        h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1).squeeze(0)
        return logp, cache

    def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, vocab_size)`
                and next state list for ys.
        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.encoder.encoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]
        # batch decoding
        h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1)
        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
        return logp, state_list

    def nll(
        self,
        text: torch.Tensor,
        punc: torch.Tensor,
        text_lengths: torch.Tensor,
        punc_lengths: torch.Tensor,
        max_length: Optional[int] = None,
        vad_indexes: Optional[torch.Tensor] = None,
        vad_indexes_lengths: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute negative log likelihood(nll)

        Normally, this function is called in batchify_nll.

        Args:
            text: (Batch, Length)
            punc: (Batch, Length)
            text_lengths: (Batch,)
            max_lengths: int
        """
        batch_size = text.size(0)
        # For data parallel
        if max_length is None:
            text = text[:, :text_lengths.max()]
            punc = punc[:, :text_lengths.max()]
        else:
            text = text[:, :max_length]
            punc = punc[:, :max_length]

        if self.with_vad():
            # Should be VadRealtimeTransformer
            assert vad_indexes is not None
            y, _ = self.punc_forward(text, text_lengths, vad_indexes)
        else:
            # Should be TargetDelayTransformer,
            y, _ = self.punc_forward(text, text_lengths)

        # Calc negative log likelihood
        # nll: (BxL,)
        if self.training == False:
            _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
            from sklearn.metrics import f1_score
            f1_score = f1_score(punc.view(-1).detach().cpu().numpy(),
                                indices.squeeze(-1).detach().cpu().numpy(),
                                average='micro')
            nll = torch.Tensor([f1_score]).repeat(text_lengths.sum())
            return nll, text_lengths
        else:
            self.punc_weight = self.punc_weight.to(punc.device)
            nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none",
                                  ignore_index=self.ignore_id)
            # nll: (BxL,) -> (BxL,)
            if max_length is None:
                nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
            else:
                nll.masked_fill_(
                    make_pad_mask(text_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
                    0.0,
                )
            # nll: (BxL,) -> (B, L)
            nll = nll.view(batch_size, -1)
            return nll, text_lengths
        return True

    def forward(
        self,
        text: torch.Tensor,
        punc: torch.Tensor,
        text_lengths: torch.Tensor,
        punc_lengths: torch.Tensor,
        vad_indexes: Optional[torch.Tensor] = None,
        vad_indexes_lengths: Optional[torch.Tensor] = None,
    ):
        nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes)
        ntokens = y_lengths.sum()
        loss = nll.sum() / ntokens
        stats = dict(loss=loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
        return loss, stats, weight
    def generate(self,
                 data_in,
@@ -217,22 +56,20 @@ class CTTransformerStreaming(nn.Module):
                 key: list = None,
                 tokenizer=None,
                 frontend=None,
                 cache: dict = {},
                 **kwargs,
                 ):
        assert len(data_in) == 1
        if len(cache) == 0:
            cache["pre_text"] = []
        text = load_audio_text_image_video(data_in, data_type=kwargs.get("kwargs", "text"))[0]
        vad_indexes = kwargs.get("vad_indexes", None)
        # text = data_in[0]
        # text_lengths = data_lengths[0] if data_lengths is not None else None
        text = "".join(cache["pre_text"]) + " " + text
        split_size = kwargs.get("split_size", 20)

        jieba_usr_dict = kwargs.get("jieba_usr_dict", None)
        if jieba_usr_dict and isinstance(jieba_usr_dict, str):
            import jieba
            jieba.load_userdict(jieba_usr_dict)
            jieba_usr_dict = jieba
            kwargs["jieba_usr_dict"] = "jieba_usr_dict"
        tokens = split_words(text, jieba_usr_dict=jieba_usr_dict)
        tokens = split_words(text)
        tokens_int = tokenizer.encode(tokens)

        mini_sentences = split_to_mini_sentence(tokens, split_size)
@@ -240,8 +77,9 @@
        assert len(mini_sentences) == len(mini_sentences_id)
        cache_sent = []
        cache_sent_id = torch.from_numpy(np.array([], dtype='int32'))
        new_mini_sentence = ""
        new_mini_sentence_punc = []
        skip_num = 0
        sentence_punc_list = []
        sentence_words_list = []
        cache_pop_trigger_limit = 200
        results = []
        meta_data = {}
@@ -254,6 +92,7 @@
            data = {
                "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
                "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
                "vad_indexes": torch.from_numpy(np.array([len(cache["pre_text"])], dtype='int32')),
            }
            data = to_device(data, kwargs["device"])
            # y, _ = self.wrapped_model(**data)
@@ -288,52 +127,42 @@ class CTTransformerStreaming(nn.Module):
            #     continue
            punctuations_np = punctuations.cpu().numpy()
            new_mini_sentence_punc += [int(x) for x in punctuations_np]
            words_with_punc = []
            for i in range(len(mini_sentence)):
                if (i == 0 or self.punc_list[punctuations[i - 1]] == "。" or self.punc_list[punctuations[i - 1]] == "？") and len(mini_sentence[i][0].encode()) == 1:
                    mini_sentence[i] = mini_sentence[i].capitalize()
                if i == 0:
                    if len(mini_sentence[i][0].encode()) == 1:
                        mini_sentence[i] = " " + mini_sentence[i]
                if i > 0:
                    if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
                        mini_sentence[i] = " " + mini_sentence[i]
                words_with_punc.append(mini_sentence[i])
                if self.punc_list[punctuations[i]] != "_":
                    punc_res = self.punc_list[punctuations[i]]
                    if len(mini_sentence[i][0].encode()) == 1:
                        if punc_res == "，":
                            punc_res = ","
                        elif punc_res == "。":
                            punc_res = "."
                        elif punc_res == "？":
                            punc_res = "?"
                    words_with_punc.append(punc_res)
            new_mini_sentence += "".join(words_with_punc)
            # Add Period for the end of the sentence
            new_mini_sentence_out = new_mini_sentence
            new_mini_sentence_punc_out = new_mini_sentence_punc
            if mini_sentence_i == len(mini_sentences) - 1:
                if new_mini_sentence[-1] == "，" or new_mini_sentence[-1] == "、":
                    new_mini_sentence_out = new_mini_sentence[:-1] + "。"
                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
                elif new_mini_sentence[-1] == ",":
                    new_mini_sentence_out = new_mini_sentence[:-1] + "."
                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
                elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "？" and len(new_mini_sentence[-1].encode()) != 1:
                    new_mini_sentence_out = new_mini_sentence + "。"
                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
                elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode()) == 1:
                    new_mini_sentence_out = new_mini_sentence + "."
                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
            # keep a punctuations array for punc segment
            if punc_array is None:
                punc_array = punctuations
            sentence_punc_list += [self.punc_list[int(x)] for x in punctuations_np]
            sentence_words_list += mini_sentence
        assert len(sentence_punc_list) == len(sentence_words_list)
        words_with_punc = []
        sentence_punc_list_out = []
        for i in range(0, len(sentence_words_list)):
            if i > 0:
                if len(sentence_words_list[i][0].encode()) == 1 and len(sentence_words_list[i - 1][-1].encode()) == 1:
                    sentence_words_list[i] = " " + sentence_words_list[i]
            if skip_num < len(cache["pre_text"]):
                skip_num += 1
            else:
                punc_array = torch.cat([punc_array, punctuations], dim=0)
                words_with_punc.append(sentence_words_list[i])
            if skip_num >= len(cache["pre_text"]):
                sentence_punc_list_out.append(sentence_punc_list[i])
                if sentence_punc_list[i] != "_":
                    words_with_punc.append(sentence_punc_list[i])
        sentence_out = "".join(words_with_punc)

        sentenceEnd = -1
        for i in range(len(sentence_punc_list) - 2, 1, -1):
            if sentence_punc_list[i] == "。" or sentence_punc_list[i] == "？":
                sentenceEnd = i
                break
        cache["pre_text"] = sentence_words_list[sentenceEnd + 1:]
        if sentence_out[-1] in self.punc_list:
            sentence_out = sentence_out[:-1]
            sentence_punc_list_out[-1] = "_"
        # keep a punctuations array for punc segment
        if punc_array is None:
            punc_array = punctuations
        else:
            punc_array = torch.cat([punc_array, punctuations], dim=0)
        result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array}
        result_i = {"key": key[0], "text": sentence_out, "punc_array": punc_array}
        results.append(result_i)
        return results, meta_data
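The sentenceEnd scan above is the heart of the streaming behavior: everything after the last sentence-final mark stays in cache["pre_text"] and is re-punctuated together with the next call's input. A standalone re-implementation for illustration (my reading of the loop; not part of the commit):

def keep_unfinished_tail(words, puncs):
    # Mirror of the sentenceEnd scan in generate() above.
    sentence_end = -1
    for i in range(len(puncs) - 2, 1, -1):
        if puncs[i] in ("。", "？"):
            sentence_end = i
            break
    return words[sentence_end + 1:]

words = ["你", "好", "吗", "我", "很", "好"]
puncs = ["_", "_", "？", "_", "_", "_"]
print(keep_unfinished_tail(words, puncs))  # ['我', '很', '好']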

View File

@@ -27,13 +27,13 @@ model_conf:
  - 1.0
  sentence_end_id: 3

encoder: SANMEncoder
encoder: SANMVadEncoder
encoder_conf:
  input_size: 256
  output_size: 256
  attention_heads: 8
  linear_units: 1024
  num_blocks: 4
  num_blocks: 3
  dropout_rate: 0.1
  positional_dropout_rate: 0.1
  attention_dropout_rate: 0.0
@@ -41,13 +41,10 @@ encoder_conf:
  pos_enc_class: SinusoidalPositionEncoder
  normalize_before: true
  kernel_size: 11
  sanm_shfit: 0
  sanm_shfit: 5
  selfattention_layer_type: sanm
  padding_idx: 0

tokenizer: CharTokenizer
tokenizer_conf:
  unk_symbol: <unk>
  unk_symbol: <unk>
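Cross-reference with the encoder added earlier in this commit: num_blocks: 3 yields one input-sized encoders0 layer plus num_blocks - 1 = 2 encoders layers, and it is the final encoders layer that applies vad_mask; sanm_shfit: 5 shifts the SANM memory block by five positions (my reading of the constructor arguments; the attention internals are not part of this diff).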

View File

@@ -1,111 +0,0 @@
import re


def split_to_mini_sentence(words: list, word_limit: int = 20):
    assert word_limit > 1
    if len(words) <= word_limit:
        return [words]
    sentences = []
    length = len(words)
    sentence_len = length // word_limit
    for i in range(sentence_len):
        sentences.append(words[i * word_limit:(i + 1) * word_limit])
    if length % word_limit > 0:
        sentences.append(words[sentence_len * word_limit:])
    return sentences


# def split_words(text: str, **kwargs):
#     words = []
#     segs = text.split()
#     for seg in segs:
#         # There is no space in seg.
#         current_word = ""
#         for c in seg:
#             if len(c.encode()) == 1:
#                 # This is an ASCII char.
#                 current_word += c
#             else:
#                 # This is a Chinese char.
#                 if len(current_word) > 0:
#                     words.append(current_word)
#                     current_word = ""
#                 words.append(c)
#         if len(current_word) > 0:
#             words.append(current_word)
#
#     return words


def split_words(text: str, jieba_usr_dict=None, **kwargs):
    if jieba_usr_dict:
        input_list = text.split()
        token_list_all = []
        langauge_list = []
        token_list_tmp = []
        language_flag = None
        for token in input_list:
            if isEnglish(token) and language_flag == 'Chinese':
                token_list_all.append(token_list_tmp)
                langauge_list.append('Chinese')
                token_list_tmp = []
            elif not isEnglish(token) and language_flag == 'English':
                token_list_all.append(token_list_tmp)
                langauge_list.append('English')
                token_list_tmp = []
            token_list_tmp.append(token)
            if isEnglish(token):
                language_flag = 'English'
            else:
                language_flag = 'Chinese'
        if token_list_tmp:
            token_list_all.append(token_list_tmp)
            langauge_list.append(language_flag)
        result_list = []
        for token_list_tmp, language_flag in zip(token_list_all, langauge_list):
            if language_flag == 'English':
                result_list.extend(token_list_tmp)
            else:
                seg_list = jieba_usr_dict.cut(join_chinese_and_english(token_list_tmp), HMM=False)
                result_list.extend(seg_list)
        return result_list
    else:
        words = []
        segs = text.split()
        for seg in segs:
            # There is no space in seg.
            current_word = ""
            for c in seg:
                if len(c.encode()) == 1:
                    # This is an ASCII char.
                    current_word += c
                else:
                    # This is a Chinese char.
                    if len(current_word) > 0:
                        words.append(current_word)
                        current_word = ""
                    words.append(c)
            if len(current_word) > 0:
                words.append(current_word)
        return words


def isEnglish(text: str):
    if re.search('^[a-zA-Z\']+$', text):
        return True
    else:
        return False


def join_chinese_and_english(input_list):
    line = ''
    for token in input_list:
        if isEnglish(token):
            line = line + ' ' + token
        else:
            line = line + token
    line = line.strip()
    return line
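For reference, the fallback (no-jieba) path of split_words keeps ASCII runs together and splits CJK text per character, and split_to_mini_sentence then chunks the word list. Expected behavior (my reading of the code above, not captured output):

words = split_words("hello world 你好世界")
# ['hello', 'world', '你', '好', '世', '界']
mini = split_to_mini_sentence(words, word_limit=3)
# [['hello', 'world', '你'], ['好', '世', '界']]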

View File

@@ -1,135 +0,0 @@
from typing import Any
from typing import List
from typing import Tuple

import torch
import torch.nn as nn

from funasr.models.transformer.embedding import SinusoidalPositionEncoder
from funasr.models.ct_transformer.sanm_encoder import SANMVadEncoder as Encoder


class VadRealtimeTransformer(torch.nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
    https://arxiv.org/pdf/2003.01309.pdf
    """

    def __init__(
        self,
        vocab_size: int,
        punc_size: int,
        pos_enc: str = None,
        embed_unit: int = 128,
        att_unit: int = 256,
        head: int = 2,
        unit: int = 1024,
        layer: int = 4,
        dropout_rate: float = 0.5,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
    ):
        super().__init__()
        if pos_enc == "sinusoidal":
            # pos_enc_class = PositionalEncoding
            pos_enc_class = SinusoidalPositionEncoder
        elif pos_enc is None:

            def pos_enc_class(*args, **kwargs):
                return nn.Sequential()  # identity

        else:
            raise ValueError(f"unknown pos-enc option: {pos_enc}")

        self.embed = nn.Embedding(vocab_size, embed_unit)
        self.encoder = Encoder(
            input_size=embed_unit,
            output_size=att_unit,
            attention_heads=head,
            linear_units=unit,
            num_blocks=layer,
            dropout_rate=dropout_rate,
            input_layer="pe",
            # pos_enc_class=pos_enc_class,
            padding_idx=0,
            kernel_size=kernel_size,
            sanm_shfit=sanm_shfit,
        )
        self.decoder = nn.Linear(att_unit, punc_size)

    # def _target_mask(self, ys_in_pad):
    #     ys_mask = ys_in_pad != 0
    #     m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0)
    #     return ys_mask.unsqueeze(-2) & m

    def forward(self, input: torch.Tensor, text_lengths: torch.Tensor,
                vad_indexes: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """Compute loss value from buffer sequences.

        Args:
            input (torch.Tensor): Input ids. (batch, len)
            hidden (torch.Tensor): Target ids. (batch, len)
        """
        x = self.embed(input)
        # mask = self._target_mask(input)
        h, _, _ = self.encoder(x, text_lengths, vad_indexes)
        y = self.decoder(h)
        return y, None

    def with_vad(self):
        return True

    def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        """Score new token.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (torch.Tensor): encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                torch.float32 scores for next token (vocab_size)
                and next state for ys
        """
        y = y.unsqueeze(0)
        h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1).squeeze(0)
        return logp, cache

    def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, vocab_size)`
                and next state list for ys.
        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.encoder.encoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]
        # batch decoding
        h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1)
        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
        return logp, state_list

View File

@@ -125,3 +125,4 @@ def load_pretrained_model(
    logging.debug("Loaded dst_state keys: {}".format(dst_state.keys()))
    dst_state.update(src_state)
    obj.load_state_dict(dst_state)