funasr1.0

游雁 2024-01-13 23:43:17 +08:00
parent 835369d631
commit bdfd27b9e9
16 changed files with 2147 additions and 18 deletions

View File

@@ -8,7 +8,7 @@ from funasr import AutoModel
 model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                   model_revision="v2.0.0",
                   vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
-                  vad_model_revision="v2.0.1",
+                  vad_model_revision="v2.0.2",
                   punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                   punc_model_revision="v2.0.1",
                   spk_model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common",
@@ -21,7 +21,7 @@ print(res)
 model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                   model_revision="v2.0.0",
                   vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
-                  vad_model_revision="v2.0.1",
+                  vad_model_revision="v2.0.2",
                   punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                   punc_model_revision="v2.0.1",
                   spk_model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common",

View File

@@ -5,7 +5,7 @@
 from funasr import AutoModel
-model = AutoModel(model="damo/emotion2vec_base", model_revision="v2.0.0")
+model = AutoModel(model="damo/emotion2vec_base", model_revision="v2.0.1")
 res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", output_dir="./outputs")
 print(res)

View File

@@ -7,7 +7,7 @@ from funasr import AutoModel
 wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
 chunk_size = 60000 # ms
-model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.1")
+model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.2")
 res = model(input=wav_file, chunk_size=chunk_size, )
 print(res)

View File

@@ -1,7 +1,7 @@
 model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
-model_revision="v2.0.1"
+model_revision="v2.0.2"
 python funasr/bin/inference.py \
 +model=${model} \

View File

@@ -8,7 +8,7 @@ from funasr import AutoModel
 model = AutoModel(model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                   model_revision="v2.0.0",
                   vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
-                  vad_model_revision="v2.0.1",
+                  vad_model_revision="v2.0.2",
                   punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                   punc_model_revision="v2.0.1",
                   spk_model="damo/speech_campplus_sv_zh-cn_16k-common",

View File

@@ -2,7 +2,7 @@
 model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
 model_revision="v2.0.0"
 vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
-vad_model_revision="v2.0.1"
+vad_model_revision="v2.0.2"
 punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
 punc_model_revision="v2.0.1"
 spk_model="damo/speech_campplus_sv_zh-cn_16k-common"

View File

@@ -8,7 +8,7 @@ from funasr import AutoModel
 model = AutoModel(model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                   model_revision="v2.0.0",
                   vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
-                  vad_model_revision="v2.0.1",
+                  vad_model_revision="v2.0.2",
                   punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                   punc_model_revision="v2.0.1",
                   )

View File

@@ -2,7 +2,7 @@
 model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
 model_revision="v2.0.0"
 vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
-vad_model_revision="v2.0.1"
+vad_model_revision="v2.0.2"
 punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
 punc_model_revision="v2.0.1"

View File

@@ -22,8 +22,8 @@ def download_from_ms(**kwargs):
     config = os.path.join(model_or_path, "config.yaml")
     if os.path.exists(config) and os.path.exists(os.path.join(model_or_path, "model.pb")):
-        cfg = OmegaConf.load(config)
-        kwargs = OmegaConf.merge(cfg, kwargs)
+        config = OmegaConf.load(config)
+        kwargs = OmegaConf.merge(config, kwargs)
         init_param = os.path.join(model_or_path, "model.pb")
         kwargs["init_param"] = init_param
         if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
@@ -34,7 +34,7 @@ def download_from_ms(**kwargs):
             kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
         if os.path.exists(os.path.join(model_or_path, "bpe.model")):
            kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
-        kwargs["model"] = cfg["model"]
+        kwargs["model"] = config["model"]
         if os.path.exists(os.path.join(model_or_path, "am.mvn")):
            kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
         if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
@@ -43,14 +43,30 @@ def download_from_ms(**kwargs):
         assert os.path.exists(os.path.join(model_or_path, "configuration.json"))
         with open(os.path.join(model_or_path, "configuration.json"), 'r', encoding='utf-8') as f:
             conf_json = json.load(f)
-        config = os.path.join(model_or_path, conf_json["model_config"])
-        cfg = OmegaConf.load(config)
-        kwargs = OmegaConf.merge(cfg, kwargs)
-        init_param = os.path.join(model_or_path, conf_json["model_file"])
-        kwargs["init_param"] = init_param
-        kwargs["model"] = cfg["model"]
+        cfg = {}
+        add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
+        cfg.update(kwargs)
+        config = OmegaConf.load(cfg["config"])
+        kwargs = OmegaConf.merge(config, cfg)
+        kwargs["model"] = config["model"]
     return OmegaConf.to_container(kwargs, resolve=True)
+def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg = {}):
+    if isinstance(file_path_metas, dict):
+        for k, v in file_path_metas.items():
+            if isinstance(v, str):
+                p = os.path.join(model_or_path, v)
+                if os.path.exists(p):
+                    cfg[k] = p
+            elif isinstance(v, dict):
+                if k not in cfg:
+                    cfg[k] = {}
+                return add_file_root_path(model_or_path, v, cfg[k])
+    return cfg
 def get_or_download_model_dir(
     model,
     model_revision=None,
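To make the new add_file_root_path helper concrete, a hedged illustration with a hypothetical file_path_metas block; the file names below are assumptions for illustration, not taken from a real configuration.json:

# hypothetical configuration.json fragment
file_path_metas = {
    "config": "config.yaml",
    "init_param": "model.pt",
    "frontend_conf": {"cmvn_file": "am.mvn"},  # nested dicts recurse into sub-configs
}
cfg = {}
add_file_root_path("/local/model_dir", file_path_metas, cfg)
# string entries whose files exist on disk are rewritten to absolute paths, e.g.
#   cfg["config"] == "/local/model_dir/config.yaml"
#   cfg["frontend_conf"]["cmvn_file"] == "/local/model_dir/am.mvn"
# entries whose files are missing are skipped; download_from_ms then overlays the
# caller's kwargs on top of cfg and merges the resolved config.yaml via OmegaConf.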

File diff suppressed because it is too large

View File

@ -0,0 +1,383 @@
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from funasr.models.scama.chunk_utilis import overlap_chunk
import numpy as np
from funasr.train_utils.device_funcs import to_device
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.sanm.attention import MultiHeadedAttention
from funasr.models.ct_transformer.attention import MultiHeadedAttentionSANMwithMask
from funasr.models.transformer.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
from funasr.models.transformer.positionwise_feed_forward import (
PositionwiseFeedForward, # noqa: H301
)
from funasr.models.transformer.utils.repeat import repeat
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
from funasr.models.transformer.utils.subsampling import TooShortUttError
from funasr.models.transformer.utils.subsampling import check_short_utt
from funasr.models.transformer.utils.mask import subsequent_mask, vad_mask
from funasr.models.ctc.ctc import CTC
from funasr.register import tables
class EncoderLayerSANM(nn.Module):
def __init__(
self,
in_size,
size,
self_attn,
feed_forward,
dropout_rate,
normalize_before=True,
concat_after=False,
stochastic_depth_rate=0.0,
):
"""Construct an EncoderLayer object."""
super(EncoderLayerSANM, self).__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(in_size)
self.norm2 = LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.in_size = in_size
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear = nn.Linear(size + size, size)
self.stochastic_depth_rate = stochastic_depth_rate
self.dropout_rate = dropout_rate
def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
"""Compute encoded features.
Args:
x_input (torch.Tensor): Input tensor (#batch, time, size).
mask (torch.Tensor): Mask tensor for the input (#batch, time).
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
Returns:
torch.Tensor: Output tensor (#batch, time, size).
torch.Tensor: Mask tensor (#batch, time).
"""
skip_layer = False
# with stochastic depth, residual connection `x + f(x)` becomes
# `x <- x + 1 / (1 - p) * f(x)` at training time.
stoch_layer_coeff = 1.0
if self.training and self.stochastic_depth_rate > 0:
skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
if skip_layer:
if cache is not None:
x = torch.cat([cache, x], dim=1)
return x, mask
residual = x
if self.normalize_before:
x = self.norm1(x)
if self.concat_after:
x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
if self.in_size == self.size:
x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
else:
x = stoch_layer_coeff * self.concat_linear(x_concat)
else:
if self.in_size == self.size:
x = residual + stoch_layer_coeff * self.dropout(
self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
)
else:
x = stoch_layer_coeff * self.dropout(
self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
)
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm2(x)
return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
"""Compute encoded features.
Args:
x_input (torch.Tensor): Input tensor (#batch, time, size).
mask (torch.Tensor): Mask tensor for the input (#batch, time).
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
Returns:
torch.Tensor: Output tensor (#batch, time, size).
torch.Tensor: Mask tensor (#batch, time).
"""
residual = x
if self.normalize_before:
x = self.norm1(x)
if self.in_size == self.size:
attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
x = residual + attn
else:
x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
x = residual + self.feed_forward(x)
if not self.normalize_before:
x = self.norm2(x)
return x, cache
@tables.register("encoder_classes", "SANMVadEncoder")
class SANMVadEncoder(nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
"""
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: Optional[str] = "conv2d",
pos_enc_class=SinusoidalPositionEncoder,
normalize_before: bool = True,
concat_after: bool = False,
positionwise_layer_type: str = "linear",
positionwise_conv_kernel_size: int = 1,
padding_idx: int = -1,
interctc_layer_idx: List[int] = [],
interctc_use_conditioning: bool = False,
kernel_size : int = 11,
sanm_shfit : int = 0,
selfattention_layer_type: str = "sanm",
):
super().__init__()
self._output_size = output_size
if input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(input_size, output_size),
torch.nn.LayerNorm(output_size),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == "conv2d":
self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
elif input_layer == "conv2d2":
self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
elif input_layer == "conv2d6":
self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
elif input_layer == "conv2d8":
self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
elif input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
SinusoidalPositionEncoder(),
)
elif input_layer is None:
if input_size == output_size:
self.embed = None
else:
self.embed = torch.nn.Linear(input_size, output_size)
elif input_layer == "pe":
self.embed = SinusoidalPositionEncoder()
else:
raise ValueError("unknown input_layer: " + input_layer)
self.normalize_before = normalize_before
if positionwise_layer_type == "linear":
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
output_size,
linear_units,
dropout_rate,
)
elif positionwise_layer_type == "conv1d":
positionwise_layer = MultiLayeredConv1d
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
elif positionwise_layer_type == "conv1d-linear":
positionwise_layer = Conv1dLinear
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
else:
raise NotImplementedError("Support only linear or conv1d.")
if selfattention_layer_type == "selfattn":
encoder_selfattn_layer = MultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
)
elif selfattention_layer_type == "sanm":
self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
encoder_selfattn_layer_args0 = (
attention_heads,
input_size,
output_size,
attention_dropout_rate,
kernel_size,
sanm_shfit,
)
encoder_selfattn_layer_args = (
attention_heads,
output_size,
output_size,
attention_dropout_rate,
kernel_size,
sanm_shfit,
)
self.encoders0 = repeat(
1,
lambda lnum: EncoderLayerSANM(
input_size,
output_size,
self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
),
)
self.encoders = repeat(
num_blocks-1,
lambda lnum: EncoderLayerSANM(
output_size,
output_size,
self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
),
)
if self.normalize_before:
self.after_norm = LayerNorm(output_size)
self.interctc_layer_idx = interctc_layer_idx
if len(interctc_layer_idx) > 0:
assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
self.interctc_use_conditioning = interctc_use_conditioning
self.conditioning_layer = None
self.dropout = nn.Dropout(dropout_rate)
def output_size(self) -> int:
return self._output_size
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
vad_indexes: torch.Tensor,
prev_states: torch.Tensor = None,
ctc: CTC = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Embed positions in tensor.
Args:
xs_pad: input tensor (B, L, D)
ilens: input length (B)
prev_states: Not to be used now.
Returns:
position embedded tensor and mask
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
no_future_masks = masks & sub_masks
xs_pad *= self.output_size()**0.5
if self.embed is None:
xs_pad = xs_pad
elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
raise TooShortUttError(
f"has {xs_pad.size(1)} frames and is too short for subsampling " +
f"(it needs more than {limit_size} frames), return empty results",
xs_pad.size(1),
limit_size,
)
xs_pad, masks = self.embed(xs_pad, masks)
else:
xs_pad = self.embed(xs_pad)
# xs_pad = self.dropout(xs_pad)
mask_tup0 = [masks, no_future_masks]
encoder_outs = self.encoders0(xs_pad, mask_tup0)
xs_pad, _ = encoder_outs[0], encoder_outs[1]
intermediate_outs = []
for layer_idx, encoder_layer in enumerate(self.encoders):
if layer_idx + 1 == len(self.encoders):
# This is last layer.
coner_mask = torch.ones(masks.size(0),
masks.size(-1),
masks.size(-1),
device=xs_pad.device,
dtype=torch.bool)
for word_index, length in enumerate(ilens):
coner_mask[word_index, :, :] = vad_mask(masks.size(-1),
vad_indexes[word_index],
device=xs_pad.device)
layer_mask = masks & coner_mask
else:
layer_mask = no_future_masks
mask_tup1 = [masks, layer_mask]
encoder_outs = encoder_layer(xs_pad, mask_tup1)
xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
olens = masks.squeeze(1).sum(1)
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None
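A hedged smoke-test sketch for the encoder above, assuming this file lands at the import path used by vad_realtime_transformer.py later in this commit; the hyper-parameters mirror the SANMEncoder block of the example config, while batch shape, lengths and VAD split points are made-up values:

import torch
from funasr.models.ct_transformer.sanm_encoder import SANMVadEncoder

enc = SANMVadEncoder(input_size=256, output_size=256, attention_heads=8,
                     linear_units=1024, num_blocks=4, input_layer="pe",
                     kernel_size=11, sanm_shfit=0, selfattention_layer_type="sanm")
xs = torch.randn(2, 50, 256)          # (batch, time, feature) dummy embeddings
ilens = torch.tensor([50, 42])        # valid length of each utterance
vad_indexes = torch.tensor([25, 20])  # per-utterance VAD split point for the last-layer mask
out, olens, _ = enc(xs, ilens, vad_indexes)
# expected: out of shape (2, 50, 256), olens == tensor([50, 42])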

View File

@ -0,0 +1,340 @@
from typing import Any
from typing import List
from typing import Tuple
from typing import Optional
import numpy as np
import torch.nn.functional as F
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.train_utils.device_funcs import force_gatherable
from funasr.train_utils.device_funcs import to_device
import torch
import torch.nn as nn
from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
from funasr.utils.load_utils import load_audio_text_image_video
from funasr.register import tables
@tables.register("model_classes", "CTTransformerStreaming")
class CTTransformerStreaming(nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
https://arxiv.org/pdf/2003.01309.pdf
"""
def __init__(
self,
encoder: str = None,
encoder_conf: dict = None,
vocab_size: int = -1,
punc_list: list = None,
punc_weight: list = None,
embed_unit: int = 128,
att_unit: int = 256,
dropout_rate: float = 0.5,
ignore_id: int = -1,
sos: int = 1,
eos: int = 2,
sentence_end_id: int = 3,
**kwargs,
):
super().__init__()
punc_size = len(punc_list)
if punc_weight is None:
punc_weight = [1] * punc_size
self.embed = nn.Embedding(vocab_size, embed_unit)
encoder_class = tables.encoder_classes.get(encoder.lower())
encoder = encoder_class(**encoder_conf)
self.decoder = nn.Linear(att_unit, punc_size)
self.encoder = encoder
self.punc_list = punc_list
self.punc_weight = punc_weight
self.ignore_id = ignore_id
self.sos = sos
self.eos = eos
self.sentence_end_id = sentence_end_id
def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
"""Compute loss value from buffer sequences.
Args:
input (torch.Tensor): Input ids. (batch, len)
hidden (torch.Tensor): Target ids. (batch, len)
"""
x = self.embed(text)
# mask = self._target_mask(input)
h, _, _ = self.encoder(x, text_lengths)
y = self.decoder(h)
return y, None
def with_vad(self):
return False
def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
"""Score new token.
Args:
y (torch.Tensor): 1D torch.int64 prefix tokens.
state: Scorer state for prefix tokens
x (torch.Tensor): encoder feature that generates ys.
Returns:
tuple[torch.Tensor, Any]: Tuple of
torch.float32 scores for next token (vocab_size)
and next state for ys
"""
y = y.unsqueeze(0)
h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
h = self.decoder(h[:, -1])
logp = h.log_softmax(dim=-1).squeeze(0)
return logp, cache
def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
"""Score new token batch.
Args:
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
states (List[Any]): Scorer states for prefix tokens.
xs (torch.Tensor):
The encoder feature that generates ys (n_batch, xlen, n_feat).
Returns:
tuple[torch.Tensor, List[Any]]: Tuple of
batchfied scores for next token with shape of `(n_batch, vocab_size)`
and next state list for ys.
"""
# merge states
n_batch = len(ys)
n_layers = len(self.encoder.encoders)
if states[0] is None:
batch_state = None
else:
# transpose state of [batch, layer] into [layer, batch]
batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]
# batch decoding
h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
h = self.decoder(h[:, -1])
logp = h.log_softmax(dim=-1)
# transpose state of [layer, batch] into [batch, layer]
state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
return logp, state_list
def nll(
self,
text: torch.Tensor,
punc: torch.Tensor,
text_lengths: torch.Tensor,
punc_lengths: torch.Tensor,
max_length: Optional[int] = None,
vad_indexes: Optional[torch.Tensor] = None,
vad_indexes_lengths: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute negative log likelihood(nll)
Normally, this function is called in batchify_nll.
Args:
text: (Batch, Length)
punc: (Batch, Length)
text_lengths: (Batch,)
max_lengths: int
"""
batch_size = text.size(0)
# For data parallel
if max_length is None:
text = text[:, :text_lengths.max()]
punc = punc[:, :text_lengths.max()]
else:
text = text[:, :max_length]
punc = punc[:, :max_length]
if self.with_vad():
# Should be VadRealtimeTransformer
assert vad_indexes is not None
y, _ = self.punc_forward(text, text_lengths, vad_indexes)
else:
# Should be TargetDelayTransformer,
y, _ = self.punc_forward(text, text_lengths)
# Calc negative log likelihood
# nll: (BxL,)
if self.training == False:
_, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
from sklearn.metrics import f1_score
f1_score = f1_score(punc.view(-1).detach().cpu().numpy(),
indices.squeeze(-1).detach().cpu().numpy(),
average='micro')
nll = torch.Tensor([f1_score]).repeat(text_lengths.sum())
return nll, text_lengths
else:
self.punc_weight = self.punc_weight.to(punc.device)
nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none",
ignore_index=self.ignore_id)
# nll: (BxL,) -> (BxL,)
if max_length is None:
nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
else:
nll.masked_fill_(
make_pad_mask(text_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
0.0,
)
# nll: (BxL,) -> (B, L)
nll = nll.view(batch_size, -1)
return nll, text_lengths
def forward(
self,
text: torch.Tensor,
punc: torch.Tensor,
text_lengths: torch.Tensor,
punc_lengths: torch.Tensor,
vad_indexes: Optional[torch.Tensor] = None,
vad_indexes_lengths: Optional[torch.Tensor] = None,
):
nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes)
ntokens = y_lengths.sum()
loss = nll.sum() / ntokens
stats = dict(loss=loss.detach())
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
return loss, stats, weight
def generate(self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
assert len(data_in) == 1
text = load_audio_text_image_video(data_in, data_type=kwargs.get("data_type", "text"))[0]
vad_indexes = kwargs.get("vad_indexes", None)
# text = data_in[0]
# text_lengths = data_lengths[0] if data_lengths is not None else None
split_size = kwargs.get("split_size", 20)
jieba_usr_dict = kwargs.get("jieba_usr_dict", None)
if jieba_usr_dict and isinstance(jieba_usr_dict, str):
import jieba
jieba.load_userdict(jieba_usr_dict)
jieba_usr_dict = jieba
kwargs["jieba_usr_dict"] = "jieba_usr_dict"
tokens = split_words(text, jieba_usr_dict=jieba_usr_dict)
tokens_int = tokenizer.encode(tokens)
mini_sentences = split_to_mini_sentence(tokens, split_size)
mini_sentences_id = split_to_mini_sentence(tokens_int, split_size)
assert len(mini_sentences) == len(mini_sentences_id)
cache_sent = []
cache_sent_id = torch.from_numpy(np.array([], dtype='int32'))
new_mini_sentence = ""
new_mini_sentence_punc = []
cache_pop_trigger_limit = 200
results = []
meta_data = {}
punc_array = None
for mini_sentence_i in range(len(mini_sentences)):
mini_sentence = mini_sentences[mini_sentence_i]
mini_sentence_id = mini_sentences_id[mini_sentence_i]
mini_sentence = cache_sent + mini_sentence
mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
data = {
"text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
"text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
}
data = to_device(data, kwargs["device"])
# y, _ = self.wrapped_model(**data)
y, _ = self.punc_forward(**data)
_, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
punctuations = indices
if indices.size()[0] != 1:
punctuations = torch.squeeze(indices)
assert punctuations.size()[0] == len(mini_sentence)
# Search for the last Period/QuestionMark as cache
if mini_sentence_i < len(mini_sentences) - 1:
sentenceEnd = -1
last_comma_index = -1
for i in range(len(punctuations) - 2, 1, -1):
if self.punc_list[punctuations[i]] == "。" or self.punc_list[punctuations[i]] == "？":
sentenceEnd = i
break
if last_comma_index < 0 and self.punc_list[punctuations[i]] == "，":
last_comma_index = i
if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
# The sentence is too long, cut off at a comma.
sentenceEnd = last_comma_index
punctuations[sentenceEnd] = self.sentence_end_id
cache_sent = mini_sentence[sentenceEnd + 1:]
cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
mini_sentence = mini_sentence[0:sentenceEnd + 1]
punctuations = punctuations[0:sentenceEnd + 1]
# if len(punctuations) == 0:
# continue
punctuations_np = punctuations.cpu().numpy()
new_mini_sentence_punc += [int(x) for x in punctuations_np]
words_with_punc = []
for i in range(len(mini_sentence)):
if (i==0 or self.punc_list[punctuations[i-1]] == "。" or self.punc_list[punctuations[i-1]] == "？") and len(mini_sentence[i][0].encode()) == 1:
mini_sentence[i] = mini_sentence[i].capitalize()
if i == 0:
if len(mini_sentence[i][0].encode()) == 1:
mini_sentence[i] = " " + mini_sentence[i]
if i > 0:
if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
mini_sentence[i] = " " + mini_sentence[i]
words_with_punc.append(mini_sentence[i])
if self.punc_list[punctuations[i]] != "_":
punc_res = self.punc_list[punctuations[i]]
if len(mini_sentence[i][0].encode()) == 1:
if punc_res == "，":
punc_res = ","
elif punc_res == "。":
punc_res = "."
elif punc_res == "？":
punc_res = "?"
words_with_punc.append(punc_res)
new_mini_sentence += "".join(words_with_punc)
# Add Period for the end of the sentence
new_mini_sentence_out = new_mini_sentence
new_mini_sentence_punc_out = new_mini_sentence_punc
if mini_sentence_i == len(mini_sentences) - 1:
if new_mini_sentence[-1] == "，" or new_mini_sentence[-1] == "、":
new_mini_sentence_out = new_mini_sentence[:-1] + "。"
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
elif new_mini_sentence[-1] == ",":
new_mini_sentence_out = new_mini_sentence[:-1] + "."
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "？" and len(new_mini_sentence[-1].encode()) != 1:
new_mini_sentence_out = new_mini_sentence + "。"
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1:
new_mini_sentence_out = new_mini_sentence + "."
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
# keep a punctuations array for punc segment
if punc_array is None:
punc_array = punctuations
else:
punc_array = torch.cat([punc_array, punctuations], dim=0)
result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array}
results.append(result_i)
return results, meta_data
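For reference, the punctuation models in this commit are normally driven through AutoModel rather than by calling generate() by hand; a hedged sketch that reuses the punc model id and revision from the diff hunks at the top of the commit, with a made-up input sentence:

from funasr import AutoModel

punc = AutoModel(model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                 model_revision="v2.0.1")
res = punc(input="那今天的会就到这里吧 happy new year 明年见")
print(res)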

View File

@ -0,0 +1,53 @@
# This is an example that demonstrates how to configure a model file.
# You can modify the configuration according to your own requirements.
# to print the register_table:
# from funasr.register import tables
# tables.print()
model: CTTransformerStreaming
model_conf:
ignore_id: 0
embed_unit: 256
att_unit: 256
dropout_rate: 0.1
punc_list:
- <unk>
- _
- ，
- 。
- ？
- 、
punc_weight:
- 1.0
- 1.0
- 1.0
- 1.0
- 1.0
- 1.0
sentence_end_id: 3
encoder: SANMEncoder
encoder_conf:
input_size: 256
output_size: 256
attention_heads: 8
linear_units: 1024
num_blocks: 4
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: pe
pos_enc_class: SinusoidalPositionEncoder
normalize_before: true
kernel_size: 11
sanm_shfit: 0
selfattention_layer_type: sanm
padding_idx: 0
tokenizer: CharTokenizer
tokenizer_conf:
unk_symbol: <unk>
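As a sanity check on the lists above: the ids produced by the decoder index directly into punc_list, so sentence_end_id: 3 has to line up with the full stop. A small hedged illustration:

# hedged illustration; values mirror the config above
punc_list = ["<unk>", "_", "，", "。", "？", "、"]
sentence_end_id = 3
assert punc_list[sentence_end_id] == "。"  # generate() writes this id at forced sentence breaks
assert punc_list[1] == "_"                 # "_" means no punctuation after the token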

View File

@ -0,0 +1,111 @@
import re
def split_to_mini_sentence(words: list, word_limit: int = 20):
assert word_limit > 1
if len(words) <= word_limit:
return [words]
sentences = []
length = len(words)
sentence_len = length // word_limit
for i in range(sentence_len):
sentences.append(words[i * word_limit:(i + 1) * word_limit])
if length % word_limit > 0:
sentences.append(words[sentence_len * word_limit:])
return sentences
# def split_words(text: str, **kwargs):
# words = []
# segs = text.split()
# for seg in segs:
# # There is no space in seg.
# current_word = ""
# for c in seg:
# if len(c.encode()) == 1:
# # This is an ASCII char.
# current_word += c
# else:
# # This is a Chinese char.
# if len(current_word) > 0:
# words.append(current_word)
# current_word = ""
# words.append(c)
# if len(current_word) > 0:
# words.append(current_word)
#
# return words
def split_words(text: str, jieba_usr_dict=None, **kwargs):
if jieba_usr_dict:
input_list = text.split()
token_list_all = []
langauge_list = []
token_list_tmp = []
language_flag = None
for token in input_list:
if isEnglish(token) and language_flag == 'Chinese':
token_list_all.append(token_list_tmp)
langauge_list.append('Chinese')
token_list_tmp = []
elif not isEnglish(token) and language_flag == 'English':
token_list_all.append(token_list_tmp)
langauge_list.append('English')
token_list_tmp = []
token_list_tmp.append(token)
if isEnglish(token):
language_flag = 'English'
else:
language_flag = 'Chinese'
if token_list_tmp:
token_list_all.append(token_list_tmp)
langauge_list.append(language_flag)
result_list = []
for token_list_tmp, language_flag in zip(token_list_all, langauge_list):
if language_flag == 'English':
result_list.extend(token_list_tmp)
else:
seg_list = jieba_usr_dict.cut(join_chinese_and_english(token_list_tmp), HMM=False)
result_list.extend(seg_list)
return result_list
else:
words = []
segs = text.split()
for seg in segs:
# There is no space in seg.
current_word = ""
for c in seg:
if len(c.encode()) == 1:
# This is an ASCII char.
current_word += c
else:
# This is a Chinese char.
if len(current_word) > 0:
words.append(current_word)
current_word = ""
words.append(c)
if len(current_word) > 0:
words.append(current_word)
return words
def isEnglish(text:str):
if re.search('^[a-zA-Z\']+$', text):
return True
else:
return False
def join_chinese_and_english(input_list):
line = ''
for token in input_list:
if isEnglish(token):
line = line + ' ' + token
else:
line = line + token
line = line.strip()
return line
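A hedged usage sketch of the helpers above on a made-up mixed Chinese/English string, taking the default (no jieba) path:

from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words

tokens = split_words("hello world 今天天气不错")
# ASCII runs stay whole words, every CJK character becomes its own token:
# ['hello', 'world', '今', '天', '天', '气', '不', '错']
mini = split_to_mini_sentence(tokens, word_limit=3)
# [['hello', 'world', '今'], ['天', '天', '气'], ['不', '错']]
print(tokens, mini)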

View File

@ -0,0 +1,135 @@
from typing import Any
from typing import List
from typing import Tuple
import torch
import torch.nn as nn
from funasr.models.transformer.embedding import SinusoidalPositionEncoder
from funasr.models.ct_transformer.sanm_encoder import SANMVadEncoder as Encoder
class VadRealtimeTransformer(torch.nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
https://arxiv.org/pdf/2003.01309.pdf
"""
def __init__(
self,
vocab_size: int,
punc_size: int,
pos_enc: str = None,
embed_unit: int = 128,
att_unit: int = 256,
head: int = 2,
unit: int = 1024,
layer: int = 4,
dropout_rate: float = 0.5,
kernel_size: int = 11,
sanm_shfit: int = 0,
):
super().__init__()
if pos_enc == "sinusoidal":
# pos_enc_class = PositionalEncoding
pos_enc_class = SinusoidalPositionEncoder
elif pos_enc is None:
def pos_enc_class(*args, **kwargs):
return nn.Sequential()  # identity
else:
raise ValueError(f"unknown pos-enc option: {pos_enc}")
self.embed = nn.Embedding(vocab_size, embed_unit)
self.encoder = Encoder(
input_size=embed_unit,
output_size=att_unit,
attention_heads=head,
linear_units=unit,
num_blocks=layer,
dropout_rate=dropout_rate,
input_layer="pe",
# pos_enc_class=pos_enc_class,
padding_idx=0,
kernel_size=kernel_size,
sanm_shfit=sanm_shfit,
)
self.decoder = nn.Linear(att_unit, punc_size)
# def _target_mask(self, ys_in_pad):
# ys_mask = ys_in_pad != 0
# m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0)
# return ys_mask.unsqueeze(-2) & m
def forward(self, input: torch.Tensor, text_lengths: torch.Tensor,
vad_indexes: torch.Tensor) -> Tuple[torch.Tensor, None]:
"""Compute loss value from buffer sequences.
Args:
input (torch.Tensor): Input ids. (batch, len)
hidden (torch.Tensor): Target ids. (batch, len)
"""
x = self.embed(input)
# mask = self._target_mask(input)
h, _, _ = self.encoder(x, text_lengths, vad_indexes)
y = self.decoder(h)
return y, None
def with_vad(self):
return True
def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
"""Score new token.
Args:
y (torch.Tensor): 1D torch.int64 prefix tokens.
state: Scorer state for prefix tokens
x (torch.Tensor): encoder feature that generates ys.
Returns:
tuple[torch.Tensor, Any]: Tuple of
torch.float32 scores for next token (vocab_size)
and next state for ys
"""
y = y.unsqueeze(0)
h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
h = self.decoder(h[:, -1])
logp = h.log_softmax(dim=-1).squeeze(0)
return logp, cache
def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
"""Score new token batch.
Args:
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
states (List[Any]): Scorer states for prefix tokens.
xs (torch.Tensor):
The encoder feature that generates ys (n_batch, xlen, n_feat).
Returns:
tuple[torch.Tensor, List[Any]]: Tuple of
batchfied scores for next token with shape of `(n_batch, vocab_size)`
and next state list for ys.
"""
# merge states
n_batch = len(ys)
n_layers = len(self.encoder.encoders)
if states[0] is None:
batch_state = None
else:
# transpose state of [batch, layer] into [layer, batch]
batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]
# batch decoding
h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
h = self.decoder(h[:, -1])
logp = h.log_softmax(dim=-1)
# transpose state of [layer, batch] into [batch, layer]
state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
return logp, state_list
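Finally, a hedged smoke-test sketch for VadRealtimeTransformer; the vocabulary size, punctuation size and tensor shapes below are assumptions made up for illustration, and the class is assumed to be in scope:

import torch

model = VadRealtimeTransformer(vocab_size=272727, punc_size=6,
                               embed_unit=256, att_unit=256, head=8,
                               unit=1024, layer=4, dropout_rate=0.1)
ids = torch.randint(0, 272727, (1, 12))  # (batch, len) token ids
lengths = torch.tensor([12])
vad_indexes = torch.tensor([6])          # VAD split point fed to the last-layer mask
logits, _ = model(ids, lengths, vad_indexes)
# expected: logits of shape (1, 12, 6), one score per punctuation label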