mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
funasr1.0
This commit is contained in:
parent
835369d631
commit
bdfd27b9e9
@ -8,7 +8,7 @@ from funasr import AutoModel
|
|||||||
model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
||||||
model_revision="v2.0.0",
|
model_revision="v2.0.0",
|
||||||
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||||
vad_model_revision="v2.0.1",
|
vad_model_revision="v2.0.2",
|
||||||
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
|
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
|
||||||
punc_model_revision="v2.0.1",
|
punc_model_revision="v2.0.1",
|
||||||
spk_model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common",
|
spk_model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common",
|
||||||
@ -21,7 +21,7 @@ print(res)
|
|||||||
model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
||||||
model_revision="v2.0.0",
|
model_revision="v2.0.0",
|
||||||
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||||
vad_model_revision="v2.0.1",
|
vad_model_revision="v2.0.2",
|
||||||
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
|
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
|
||||||
punc_model_revision="v2.0.1",
|
punc_model_revision="v2.0.1",
|
||||||
spk_model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common",
|
spk_model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common",
|
||||||
|
|||||||
@ -5,7 +5,7 @@
|
|||||||
|
|
||||||
from funasr import AutoModel
|
from funasr import AutoModel
|
||||||
|
|
||||||
model = AutoModel(model="damo/emotion2vec_base", model_revision="v2.0.0")
|
model = AutoModel(model="damo/emotion2vec_base", model_revision="v2.0.1")
|
||||||
|
|
||||||
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", output_dir="./outputs")
|
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", output_dir="./outputs")
|
||||||
print(res)
|
print(res)
|
||||||
@ -7,7 +7,7 @@ from funasr import AutoModel
|
|||||||
wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
|
wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
|
||||||
|
|
||||||
chunk_size = 60000 # ms
|
chunk_size = 60000 # ms
|
||||||
model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.1")
|
model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.2")
|
||||||
|
|
||||||
res = model(input=wav_file, chunk_size=chunk_size, )
|
res = model(input=wav_file, chunk_size=chunk_size, )
|
||||||
print(res)
|
print(res)
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
|
|
||||||
|
|
||||||
model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
|
model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
|
||||||
model_revision="v2.0.1"
|
model_revision="v2.0.2"
|
||||||
|
|
||||||
python funasr/bin/inference.py \
|
python funasr/bin/inference.py \
|
||||||
+model=${model} \
|
+model=${model} \
|
||||||
|
|||||||
@ -8,7 +8,7 @@ from funasr import AutoModel
|
|||||||
model = AutoModel(model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
model = AutoModel(model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
||||||
model_revision="v2.0.0",
|
model_revision="v2.0.0",
|
||||||
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||||
vad_model_revision="v2.0.1",
|
vad_model_revision="v2.0.2",
|
||||||
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
|
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
|
||||||
punc_model_revision="v2.0.1",
|
punc_model_revision="v2.0.1",
|
||||||
spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
|
spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
|
model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
|
||||||
model_revision="v2.0.0"
|
model_revision="v2.0.0"
|
||||||
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
|
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
|
||||||
vad_model_revision="v2.0.1"
|
vad_model_revision="v2.0.2"
|
||||||
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
|
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
|
||||||
punc_model_revision="v2.0.1"
|
punc_model_revision="v2.0.1"
|
||||||
spk_model="damo/speech_campplus_sv_zh-cn_16k-common"
|
spk_model="damo/speech_campplus_sv_zh-cn_16k-common"
|
||||||
|
|||||||
@ -8,7 +8,7 @@ from funasr import AutoModel
|
|||||||
model = AutoModel(model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
model = AutoModel(model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
||||||
model_revision="v2.0.0",
|
model_revision="v2.0.0",
|
||||||
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||||
vad_model_revision="v2.0.1",
|
vad_model_revision="v2.0.2",
|
||||||
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
|
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
|
||||||
punc_model_revision="v2.0.1",
|
punc_model_revision="v2.0.1",
|
||||||
)
|
)
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
|
model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
|
||||||
model_revision="v2.0.0"
|
model_revision="v2.0.0"
|
||||||
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
|
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
|
||||||
vad_model_revision="v2.0.1"
|
vad_model_revision="v2.0.2"
|
||||||
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
|
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
|
||||||
punc_model_revision="v2.0.1"
|
punc_model_revision="v2.0.1"
|
||||||
|
|
||||||
|
|||||||
@ -22,8 +22,8 @@ def download_from_ms(**kwargs):
|
|||||||
|
|
||||||
config = os.path.join(model_or_path, "config.yaml")
|
config = os.path.join(model_or_path, "config.yaml")
|
||||||
if os.path.exists(config) and os.path.exists(os.path.join(model_or_path, "model.pb")):
|
if os.path.exists(config) and os.path.exists(os.path.join(model_or_path, "model.pb")):
|
||||||
cfg = OmegaConf.load(config)
|
config = OmegaConf.load(config)
|
||||||
kwargs = OmegaConf.merge(cfg, kwargs)
|
kwargs = OmegaConf.merge(config, kwargs)
|
||||||
init_param = os.path.join(model_or_path, "model.pb")
|
init_param = os.path.join(model_or_path, "model.pb")
|
||||||
kwargs["init_param"] = init_param
|
kwargs["init_param"] = init_param
|
||||||
if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
|
if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
|
||||||
@ -34,7 +34,7 @@ def download_from_ms(**kwargs):
|
|||||||
kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
|
kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
|
||||||
if os.path.exists(os.path.join(model_or_path, "bpe.model")):
|
if os.path.exists(os.path.join(model_or_path, "bpe.model")):
|
||||||
kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
|
kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
|
||||||
kwargs["model"] = cfg["model"]
|
kwargs["model"] = config["model"]
|
||||||
if os.path.exists(os.path.join(model_or_path, "am.mvn")):
|
if os.path.exists(os.path.join(model_or_path, "am.mvn")):
|
||||||
kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
|
kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
|
||||||
if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
|
if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
|
||||||
@ -43,14 +43,30 @@ def download_from_ms(**kwargs):
|
|||||||
assert os.path.exists(os.path.join(model_or_path, "configuration.json"))
|
assert os.path.exists(os.path.join(model_or_path, "configuration.json"))
|
||||||
with open(os.path.join(model_or_path, "configuration.json"), 'r', encoding='utf-8') as f:
|
with open(os.path.join(model_or_path, "configuration.json"), 'r', encoding='utf-8') as f:
|
||||||
conf_json = json.load(f)
|
conf_json = json.load(f)
|
||||||
config = os.path.join(model_or_path, conf_json["model_config"])
|
cfg = {}
|
||||||
cfg = OmegaConf.load(config)
|
add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
|
||||||
kwargs = OmegaConf.merge(cfg, kwargs)
|
cfg.update(kwargs)
|
||||||
init_param = os.path.join(model_or_path, conf_json["model_file"])
|
config = OmegaConf.load(cfg["config"])
|
||||||
kwargs["init_param"] = init_param
|
kwargs = OmegaConf.merge(config, cfg)
|
||||||
kwargs["model"] = cfg["model"]
|
kwargs["model"] = config["model"]
|
||||||
return OmegaConf.to_container(kwargs, resolve=True)
|
return OmegaConf.to_container(kwargs, resolve=True)
|
||||||
|
|
||||||
|
def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg = {}):
|
||||||
|
|
||||||
|
if isinstance(file_path_metas, dict):
|
||||||
|
for k, v in file_path_metas.items():
|
||||||
|
if isinstance(v, str):
|
||||||
|
p = os.path.join(model_or_path, v)
|
||||||
|
if os.path.exists(p):
|
||||||
|
cfg[k] = p
|
||||||
|
elif isinstance(v, dict):
|
||||||
|
if k not in cfg:
|
||||||
|
cfg[k] = {}
|
||||||
|
return add_file_root_path(model_or_path, v, cfg[k])
|
||||||
|
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
def get_or_download_model_dir(
|
def get_or_download_model_dir(
|
||||||
model,
|
model,
|
||||||
model_revision=None,
|
model_revision=None,
|
||||||
|
|||||||
0
funasr/models/ct_transformer_streaming/__init__.py
Normal file
0
funasr/models/ct_transformer_streaming/__init__.py
Normal file
1091
funasr/models/ct_transformer_streaming/attention.py
Normal file
1091
funasr/models/ct_transformer_streaming/attention.py
Normal file
File diff suppressed because it is too large
Load Diff
383
funasr/models/ct_transformer_streaming/encoder.py
Normal file
383
funasr/models/ct_transformer_streaming/encoder.py
Normal file
@ -0,0 +1,383 @@
|
|||||||
|
from typing import List
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Sequence
|
||||||
|
from typing import Tuple
|
||||||
|
from typing import Union
|
||||||
|
import logging
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from funasr.models.scama.chunk_utilis import overlap_chunk
|
||||||
|
import numpy as np
|
||||||
|
from funasr.train_utils.device_funcs import to_device
|
||||||
|
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||||
|
from funasr.models.sanm.attention import MultiHeadedAttention
|
||||||
|
from funasr.models.ct_transformer.attention import MultiHeadedAttentionSANMwithMask
|
||||||
|
from funasr.models.transformer.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
|
||||||
|
from funasr.models.transformer.layer_norm import LayerNorm
|
||||||
|
from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
|
||||||
|
from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
|
||||||
|
from funasr.models.transformer.positionwise_feed_forward import (
|
||||||
|
PositionwiseFeedForward, # noqa: H301
|
||||||
|
)
|
||||||
|
from funasr.models.transformer.utils.repeat import repeat
|
||||||
|
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
|
||||||
|
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
|
||||||
|
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
|
||||||
|
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
|
||||||
|
from funasr.models.transformer.utils.subsampling import TooShortUttError
|
||||||
|
from funasr.models.transformer.utils.subsampling import check_short_utt
|
||||||
|
from funasr.models.transformer.utils.mask import subsequent_mask, vad_mask
|
||||||
|
|
||||||
|
from funasr.models.ctc.ctc import CTC
|
||||||
|
|
||||||
|
from funasr.register import tables
|
||||||
|
|
||||||
|
class EncoderLayerSANM(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_size,
|
||||||
|
size,
|
||||||
|
self_attn,
|
||||||
|
feed_forward,
|
||||||
|
dropout_rate,
|
||||||
|
normalize_before=True,
|
||||||
|
concat_after=False,
|
||||||
|
stochastic_depth_rate=0.0,
|
||||||
|
):
|
||||||
|
"""Construct an EncoderLayer object."""
|
||||||
|
super(EncoderLayerSANM, self).__init__()
|
||||||
|
self.self_attn = self_attn
|
||||||
|
self.feed_forward = feed_forward
|
||||||
|
self.norm1 = LayerNorm(in_size)
|
||||||
|
self.norm2 = LayerNorm(size)
|
||||||
|
self.dropout = nn.Dropout(dropout_rate)
|
||||||
|
self.in_size = in_size
|
||||||
|
self.size = size
|
||||||
|
self.normalize_before = normalize_before
|
||||||
|
self.concat_after = concat_after
|
||||||
|
if self.concat_after:
|
||||||
|
self.concat_linear = nn.Linear(size + size, size)
|
||||||
|
self.stochastic_depth_rate = stochastic_depth_rate
|
||||||
|
self.dropout_rate = dropout_rate
|
||||||
|
|
||||||
|
def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
|
||||||
|
"""Compute encoded features.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x_input (torch.Tensor): Input tensor (#batch, time, size).
|
||||||
|
mask (torch.Tensor): Mask tensor for the input (#batch, time).
|
||||||
|
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Output tensor (#batch, time, size).
|
||||||
|
torch.Tensor: Mask tensor (#batch, time).
|
||||||
|
|
||||||
|
"""
|
||||||
|
skip_layer = False
|
||||||
|
# with stochastic depth, residual connection `x + f(x)` becomes
|
||||||
|
# `x <- x + 1 / (1 - p) * f(x)` at training time.
|
||||||
|
stoch_layer_coeff = 1.0
|
||||||
|
if self.training and self.stochastic_depth_rate > 0:
|
||||||
|
skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
|
||||||
|
stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
|
||||||
|
|
||||||
|
if skip_layer:
|
||||||
|
if cache is not None:
|
||||||
|
x = torch.cat([cache, x], dim=1)
|
||||||
|
return x, mask
|
||||||
|
|
||||||
|
residual = x
|
||||||
|
if self.normalize_before:
|
||||||
|
x = self.norm1(x)
|
||||||
|
|
||||||
|
if self.concat_after:
|
||||||
|
x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
|
||||||
|
if self.in_size == self.size:
|
||||||
|
x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
|
||||||
|
else:
|
||||||
|
x = stoch_layer_coeff * self.concat_linear(x_concat)
|
||||||
|
else:
|
||||||
|
if self.in_size == self.size:
|
||||||
|
x = residual + stoch_layer_coeff * self.dropout(
|
||||||
|
self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
x = stoch_layer_coeff * self.dropout(
|
||||||
|
self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
|
||||||
|
)
|
||||||
|
if not self.normalize_before:
|
||||||
|
x = self.norm1(x)
|
||||||
|
|
||||||
|
residual = x
|
||||||
|
if self.normalize_before:
|
||||||
|
x = self.norm2(x)
|
||||||
|
x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
|
||||||
|
if not self.normalize_before:
|
||||||
|
x = self.norm2(x)
|
||||||
|
|
||||||
|
return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
|
||||||
|
|
||||||
|
def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
|
||||||
|
"""Compute encoded features.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x_input (torch.Tensor): Input tensor (#batch, time, size).
|
||||||
|
mask (torch.Tensor): Mask tensor for the input (#batch, time).
|
||||||
|
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Output tensor (#batch, time, size).
|
||||||
|
torch.Tensor: Mask tensor (#batch, time).
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
residual = x
|
||||||
|
if self.normalize_before:
|
||||||
|
x = self.norm1(x)
|
||||||
|
|
||||||
|
if self.in_size == self.size:
|
||||||
|
attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
|
||||||
|
x = residual + attn
|
||||||
|
else:
|
||||||
|
x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
|
||||||
|
|
||||||
|
if not self.normalize_before:
|
||||||
|
x = self.norm1(x)
|
||||||
|
|
||||||
|
residual = x
|
||||||
|
if self.normalize_before:
|
||||||
|
x = self.norm2(x)
|
||||||
|
x = residual + self.feed_forward(x)
|
||||||
|
if not self.normalize_before:
|
||||||
|
x = self.norm2(x)
|
||||||
|
|
||||||
|
return x, cache
|
||||||
|
|
||||||
|
|
||||||
|
@tables.register("encoder_classes", "SANMVadEncoder")
|
||||||
|
class SANMVadEncoder(nn.Module):
|
||||||
|
"""
|
||||||
|
Author: Speech Lab of DAMO Academy, Alibaba Group
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_size: int,
|
||||||
|
output_size: int = 256,
|
||||||
|
attention_heads: int = 4,
|
||||||
|
linear_units: int = 2048,
|
||||||
|
num_blocks: int = 6,
|
||||||
|
dropout_rate: float = 0.1,
|
||||||
|
positional_dropout_rate: float = 0.1,
|
||||||
|
attention_dropout_rate: float = 0.0,
|
||||||
|
input_layer: Optional[str] = "conv2d",
|
||||||
|
pos_enc_class=SinusoidalPositionEncoder,
|
||||||
|
normalize_before: bool = True,
|
||||||
|
concat_after: bool = False,
|
||||||
|
positionwise_layer_type: str = "linear",
|
||||||
|
positionwise_conv_kernel_size: int = 1,
|
||||||
|
padding_idx: int = -1,
|
||||||
|
interctc_layer_idx: List[int] = [],
|
||||||
|
interctc_use_conditioning: bool = False,
|
||||||
|
kernel_size : int = 11,
|
||||||
|
sanm_shfit : int = 0,
|
||||||
|
selfattention_layer_type: str = "sanm",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self._output_size = output_size
|
||||||
|
|
||||||
|
if input_layer == "linear":
|
||||||
|
self.embed = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(input_size, output_size),
|
||||||
|
torch.nn.LayerNorm(output_size),
|
||||||
|
torch.nn.Dropout(dropout_rate),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
pos_enc_class(output_size, positional_dropout_rate),
|
||||||
|
)
|
||||||
|
elif input_layer == "conv2d":
|
||||||
|
self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
|
||||||
|
elif input_layer == "conv2d2":
|
||||||
|
self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
|
||||||
|
elif input_layer == "conv2d6":
|
||||||
|
self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
|
||||||
|
elif input_layer == "conv2d8":
|
||||||
|
self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
|
||||||
|
elif input_layer == "embed":
|
||||||
|
self.embed = torch.nn.Sequential(
|
||||||
|
torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
|
||||||
|
SinusoidalPositionEncoder(),
|
||||||
|
)
|
||||||
|
elif input_layer is None:
|
||||||
|
if input_size == output_size:
|
||||||
|
self.embed = None
|
||||||
|
else:
|
||||||
|
self.embed = torch.nn.Linear(input_size, output_size)
|
||||||
|
elif input_layer == "pe":
|
||||||
|
self.embed = SinusoidalPositionEncoder()
|
||||||
|
else:
|
||||||
|
raise ValueError("unknown input_layer: " + input_layer)
|
||||||
|
self.normalize_before = normalize_before
|
||||||
|
if positionwise_layer_type == "linear":
|
||||||
|
positionwise_layer = PositionwiseFeedForward
|
||||||
|
positionwise_layer_args = (
|
||||||
|
output_size,
|
||||||
|
linear_units,
|
||||||
|
dropout_rate,
|
||||||
|
)
|
||||||
|
elif positionwise_layer_type == "conv1d":
|
||||||
|
positionwise_layer = MultiLayeredConv1d
|
||||||
|
positionwise_layer_args = (
|
||||||
|
output_size,
|
||||||
|
linear_units,
|
||||||
|
positionwise_conv_kernel_size,
|
||||||
|
dropout_rate,
|
||||||
|
)
|
||||||
|
elif positionwise_layer_type == "conv1d-linear":
|
||||||
|
positionwise_layer = Conv1dLinear
|
||||||
|
positionwise_layer_args = (
|
||||||
|
output_size,
|
||||||
|
linear_units,
|
||||||
|
positionwise_conv_kernel_size,
|
||||||
|
dropout_rate,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Support only linear or conv1d.")
|
||||||
|
|
||||||
|
if selfattention_layer_type == "selfattn":
|
||||||
|
encoder_selfattn_layer = MultiHeadedAttention
|
||||||
|
encoder_selfattn_layer_args = (
|
||||||
|
attention_heads,
|
||||||
|
output_size,
|
||||||
|
attention_dropout_rate,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif selfattention_layer_type == "sanm":
|
||||||
|
self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
|
||||||
|
encoder_selfattn_layer_args0 = (
|
||||||
|
attention_heads,
|
||||||
|
input_size,
|
||||||
|
output_size,
|
||||||
|
attention_dropout_rate,
|
||||||
|
kernel_size,
|
||||||
|
sanm_shfit,
|
||||||
|
)
|
||||||
|
|
||||||
|
encoder_selfattn_layer_args = (
|
||||||
|
attention_heads,
|
||||||
|
output_size,
|
||||||
|
output_size,
|
||||||
|
attention_dropout_rate,
|
||||||
|
kernel_size,
|
||||||
|
sanm_shfit,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.encoders0 = repeat(
|
||||||
|
1,
|
||||||
|
lambda lnum: EncoderLayerSANM(
|
||||||
|
input_size,
|
||||||
|
output_size,
|
||||||
|
self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
|
||||||
|
positionwise_layer(*positionwise_layer_args),
|
||||||
|
dropout_rate,
|
||||||
|
normalize_before,
|
||||||
|
concat_after,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.encoders = repeat(
|
||||||
|
num_blocks-1,
|
||||||
|
lambda lnum: EncoderLayerSANM(
|
||||||
|
output_size,
|
||||||
|
output_size,
|
||||||
|
self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
|
||||||
|
positionwise_layer(*positionwise_layer_args),
|
||||||
|
dropout_rate,
|
||||||
|
normalize_before,
|
||||||
|
concat_after,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if self.normalize_before:
|
||||||
|
self.after_norm = LayerNorm(output_size)
|
||||||
|
|
||||||
|
self.interctc_layer_idx = interctc_layer_idx
|
||||||
|
if len(interctc_layer_idx) > 0:
|
||||||
|
assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
|
||||||
|
self.interctc_use_conditioning = interctc_use_conditioning
|
||||||
|
self.conditioning_layer = None
|
||||||
|
self.dropout = nn.Dropout(dropout_rate)
|
||||||
|
|
||||||
|
def output_size(self) -> int:
|
||||||
|
return self._output_size
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
xs_pad: torch.Tensor,
|
||||||
|
ilens: torch.Tensor,
|
||||||
|
vad_indexes: torch.Tensor,
|
||||||
|
prev_states: torch.Tensor = None,
|
||||||
|
ctc: CTC = None,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
"""Embed positions in tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
xs_pad: input tensor (B, L, D)
|
||||||
|
ilens: input length (B)
|
||||||
|
prev_states: Not to be used now.
|
||||||
|
Returns:
|
||||||
|
position embedded tensor and mask
|
||||||
|
"""
|
||||||
|
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
|
||||||
|
sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
|
||||||
|
no_future_masks = masks & sub_masks
|
||||||
|
xs_pad *= self.output_size()**0.5
|
||||||
|
if self.embed is None:
|
||||||
|
xs_pad = xs_pad
|
||||||
|
elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
|
||||||
|
or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
|
||||||
|
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
|
||||||
|
if short_status:
|
||||||
|
raise TooShortUttError(
|
||||||
|
f"has {xs_pad.size(1)} frames and is too short for subsampling " +
|
||||||
|
f"(it needs more than {limit_size} frames), return empty results",
|
||||||
|
xs_pad.size(1),
|
||||||
|
limit_size,
|
||||||
|
)
|
||||||
|
xs_pad, masks = self.embed(xs_pad, masks)
|
||||||
|
else:
|
||||||
|
xs_pad = self.embed(xs_pad)
|
||||||
|
|
||||||
|
# xs_pad = self.dropout(xs_pad)
|
||||||
|
mask_tup0 = [masks, no_future_masks]
|
||||||
|
encoder_outs = self.encoders0(xs_pad, mask_tup0)
|
||||||
|
xs_pad, _ = encoder_outs[0], encoder_outs[1]
|
||||||
|
intermediate_outs = []
|
||||||
|
|
||||||
|
|
||||||
|
for layer_idx, encoder_layer in enumerate(self.encoders):
|
||||||
|
if layer_idx + 1 == len(self.encoders):
|
||||||
|
# This is last layer.
|
||||||
|
coner_mask = torch.ones(masks.size(0),
|
||||||
|
masks.size(-1),
|
||||||
|
masks.size(-1),
|
||||||
|
device=xs_pad.device,
|
||||||
|
dtype=torch.bool)
|
||||||
|
for word_index, length in enumerate(ilens):
|
||||||
|
coner_mask[word_index, :, :] = vad_mask(masks.size(-1),
|
||||||
|
vad_indexes[word_index],
|
||||||
|
device=xs_pad.device)
|
||||||
|
layer_mask = masks & coner_mask
|
||||||
|
else:
|
||||||
|
layer_mask = no_future_masks
|
||||||
|
mask_tup1 = [masks, layer_mask]
|
||||||
|
encoder_outs = encoder_layer(xs_pad, mask_tup1)
|
||||||
|
xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
|
||||||
|
|
||||||
|
if self.normalize_before:
|
||||||
|
xs_pad = self.after_norm(xs_pad)
|
||||||
|
|
||||||
|
olens = masks.squeeze(1).sum(1)
|
||||||
|
if len(intermediate_outs) > 0:
|
||||||
|
return (xs_pad, intermediate_outs), olens, None
|
||||||
|
return xs_pad, olens, None
|
||||||
340
funasr/models/ct_transformer_streaming/model.py
Normal file
340
funasr/models/ct_transformer_streaming/model.py
Normal file
@ -0,0 +1,340 @@
|
|||||||
|
from typing import Any
|
||||||
|
from typing import List
|
||||||
|
from typing import Tuple
|
||||||
|
from typing import Optional
|
||||||
|
import numpy as np
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||||
|
from funasr.train_utils.device_funcs import force_gatherable
|
||||||
|
from funasr.train_utils.device_funcs import to_device
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
|
||||||
|
from funasr.utils.load_utils import load_audio_text_image_video
|
||||||
|
|
||||||
|
from funasr.register import tables
|
||||||
|
|
||||||
|
@tables.register("model_classes", "CTTransformerStreaming")
|
||||||
|
class CTTransformerStreaming(nn.Module):
|
||||||
|
"""
|
||||||
|
Author: Speech Lab of DAMO Academy, Alibaba Group
|
||||||
|
CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
|
||||||
|
https://arxiv.org/pdf/2003.01309.pdf
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
encoder: str = None,
|
||||||
|
encoder_conf: dict = None,
|
||||||
|
vocab_size: int = -1,
|
||||||
|
punc_list: list = None,
|
||||||
|
punc_weight: list = None,
|
||||||
|
embed_unit: int = 128,
|
||||||
|
att_unit: int = 256,
|
||||||
|
dropout_rate: float = 0.5,
|
||||||
|
ignore_id: int = -1,
|
||||||
|
sos: int = 1,
|
||||||
|
eos: int = 2,
|
||||||
|
sentence_end_id: int = 3,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
punc_size = len(punc_list)
|
||||||
|
if punc_weight is None:
|
||||||
|
punc_weight = [1] * punc_size
|
||||||
|
|
||||||
|
|
||||||
|
self.embed = nn.Embedding(vocab_size, embed_unit)
|
||||||
|
encoder_class = tables.encoder_classes.get(encoder.lower())
|
||||||
|
encoder = encoder_class(**encoder_conf)
|
||||||
|
|
||||||
|
self.decoder = nn.Linear(att_unit, punc_size)
|
||||||
|
self.encoder = encoder
|
||||||
|
self.punc_list = punc_list
|
||||||
|
self.punc_weight = punc_weight
|
||||||
|
self.ignore_id = ignore_id
|
||||||
|
self.sos = sos
|
||||||
|
self.eos = eos
|
||||||
|
self.sentence_end_id = sentence_end_id
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
|
||||||
|
"""Compute loss value from buffer sequences.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input (torch.Tensor): Input ids. (batch, len)
|
||||||
|
hidden (torch.Tensor): Target ids. (batch, len)
|
||||||
|
|
||||||
|
"""
|
||||||
|
x = self.embed(text)
|
||||||
|
# mask = self._target_mask(input)
|
||||||
|
h, _, _ = self.encoder(x, text_lengths)
|
||||||
|
y = self.decoder(h)
|
||||||
|
return y, None
|
||||||
|
|
||||||
|
def with_vad(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
|
||||||
|
"""Score new token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y (torch.Tensor): 1D torch.int64 prefix tokens.
|
||||||
|
state: Scorer state for prefix tokens
|
||||||
|
x (torch.Tensor): encoder feature that generates ys.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[torch.Tensor, Any]: Tuple of
|
||||||
|
torch.float32 scores for next token (vocab_size)
|
||||||
|
and next state for ys
|
||||||
|
|
||||||
|
"""
|
||||||
|
y = y.unsqueeze(0)
|
||||||
|
h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
|
||||||
|
h = self.decoder(h[:, -1])
|
||||||
|
logp = h.log_softmax(dim=-1).squeeze(0)
|
||||||
|
return logp, cache
|
||||||
|
|
||||||
|
def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
|
||||||
|
"""Score new token batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
|
||||||
|
states (List[Any]): Scorer states for prefix tokens.
|
||||||
|
xs (torch.Tensor):
|
||||||
|
The encoder feature that generates ys (n_batch, xlen, n_feat).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[torch.Tensor, List[Any]]: Tuple of
|
||||||
|
batchfied scores for next token with shape of `(n_batch, vocab_size)`
|
||||||
|
and next state list for ys.
|
||||||
|
|
||||||
|
"""
|
||||||
|
# merge states
|
||||||
|
n_batch = len(ys)
|
||||||
|
n_layers = len(self.encoder.encoders)
|
||||||
|
if states[0] is None:
|
||||||
|
batch_state = None
|
||||||
|
else:
|
||||||
|
# transpose state of [batch, layer] into [layer, batch]
|
||||||
|
batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]
|
||||||
|
|
||||||
|
# batch decoding
|
||||||
|
h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
|
||||||
|
h = self.decoder(h[:, -1])
|
||||||
|
logp = h.log_softmax(dim=-1)
|
||||||
|
|
||||||
|
# transpose state of [layer, batch] into [batch, layer]
|
||||||
|
state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
|
||||||
|
return logp, state_list
|
||||||
|
|
||||||
|
def nll(
|
||||||
|
self,
|
||||||
|
text: torch.Tensor,
|
||||||
|
punc: torch.Tensor,
|
||||||
|
text_lengths: torch.Tensor,
|
||||||
|
punc_lengths: torch.Tensor,
|
||||||
|
max_length: Optional[int] = None,
|
||||||
|
vad_indexes: Optional[torch.Tensor] = None,
|
||||||
|
vad_indexes_lengths: Optional[torch.Tensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Compute negative log likelihood(nll)
|
||||||
|
|
||||||
|
Normally, this function is called in batchify_nll.
|
||||||
|
Args:
|
||||||
|
text: (Batch, Length)
|
||||||
|
punc: (Batch, Length)
|
||||||
|
text_lengths: (Batch,)
|
||||||
|
max_lengths: int
|
||||||
|
"""
|
||||||
|
batch_size = text.size(0)
|
||||||
|
# For data parallel
|
||||||
|
if max_length is None:
|
||||||
|
text = text[:, :text_lengths.max()]
|
||||||
|
punc = punc[:, :text_lengths.max()]
|
||||||
|
else:
|
||||||
|
text = text[:, :max_length]
|
||||||
|
punc = punc[:, :max_length]
|
||||||
|
|
||||||
|
if self.with_vad():
|
||||||
|
# Should be VadRealtimeTransformer
|
||||||
|
assert vad_indexes is not None
|
||||||
|
y, _ = self.punc_forward(text, text_lengths, vad_indexes)
|
||||||
|
else:
|
||||||
|
# Should be TargetDelayTransformer,
|
||||||
|
y, _ = self.punc_forward(text, text_lengths)
|
||||||
|
|
||||||
|
# Calc negative log likelihood
|
||||||
|
# nll: (BxL,)
|
||||||
|
if self.training == False:
|
||||||
|
_, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
|
||||||
|
from sklearn.metrics import f1_score
|
||||||
|
f1_score = f1_score(punc.view(-1).detach().cpu().numpy(),
|
||||||
|
indices.squeeze(-1).detach().cpu().numpy(),
|
||||||
|
average='micro')
|
||||||
|
nll = torch.Tensor([f1_score]).repeat(text_lengths.sum())
|
||||||
|
return nll, text_lengths
|
||||||
|
else:
|
||||||
|
self.punc_weight = self.punc_weight.to(punc.device)
|
||||||
|
nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none",
|
||||||
|
ignore_index=self.ignore_id)
|
||||||
|
# nll: (BxL,) -> (BxL,)
|
||||||
|
if max_length is None:
|
||||||
|
nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
|
||||||
|
else:
|
||||||
|
nll.masked_fill_(
|
||||||
|
make_pad_mask(text_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
|
||||||
|
0.0,
|
||||||
|
)
|
||||||
|
# nll: (BxL,) -> (B, L)
|
||||||
|
nll = nll.view(batch_size, -1)
|
||||||
|
return nll, text_lengths
|
||||||
|
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
text: torch.Tensor,
|
||||||
|
punc: torch.Tensor,
|
||||||
|
text_lengths: torch.Tensor,
|
||||||
|
punc_lengths: torch.Tensor,
|
||||||
|
vad_indexes: Optional[torch.Tensor] = None,
|
||||||
|
vad_indexes_lengths: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes)
|
||||||
|
ntokens = y_lengths.sum()
|
||||||
|
loss = nll.sum() / ntokens
|
||||||
|
stats = dict(loss=loss.detach())
|
||||||
|
|
||||||
|
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
||||||
|
loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
|
||||||
|
return loss, stats, weight
|
||||||
|
|
||||||
|
def generate(self,
|
||||||
|
data_in,
|
||||||
|
data_lengths=None,
|
||||||
|
key: list = None,
|
||||||
|
tokenizer=None,
|
||||||
|
frontend=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
assert len(data_in) == 1
|
||||||
|
text = load_audio_text_image_video(data_in, data_type=kwargs.get("kwargs", "text"))[0]
|
||||||
|
vad_indexes = kwargs.get("vad_indexes", None)
|
||||||
|
# text = data_in[0]
|
||||||
|
# text_lengths = data_lengths[0] if data_lengths is not None else None
|
||||||
|
split_size = kwargs.get("split_size", 20)
|
||||||
|
|
||||||
|
jieba_usr_dict = kwargs.get("jieba_usr_dict", None)
|
||||||
|
if jieba_usr_dict and isinstance(jieba_usr_dict, str):
|
||||||
|
import jieba
|
||||||
|
jieba.load_userdict(jieba_usr_dict)
|
||||||
|
jieba_usr_dict = jieba
|
||||||
|
kwargs["jieba_usr_dict"] = "jieba_usr_dict"
|
||||||
|
tokens = split_words(text, jieba_usr_dict=jieba_usr_dict)
|
||||||
|
tokens_int = tokenizer.encode(tokens)
|
||||||
|
|
||||||
|
mini_sentences = split_to_mini_sentence(tokens, split_size)
|
||||||
|
mini_sentences_id = split_to_mini_sentence(tokens_int, split_size)
|
||||||
|
assert len(mini_sentences) == len(mini_sentences_id)
|
||||||
|
cache_sent = []
|
||||||
|
cache_sent_id = torch.from_numpy(np.array([], dtype='int32'))
|
||||||
|
new_mini_sentence = ""
|
||||||
|
new_mini_sentence_punc = []
|
||||||
|
cache_pop_trigger_limit = 200
|
||||||
|
results = []
|
||||||
|
meta_data = {}
|
||||||
|
punc_array = None
|
||||||
|
for mini_sentence_i in range(len(mini_sentences)):
|
||||||
|
mini_sentence = mini_sentences[mini_sentence_i]
|
||||||
|
mini_sentence_id = mini_sentences_id[mini_sentence_i]
|
||||||
|
mini_sentence = cache_sent + mini_sentence
|
||||||
|
mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
|
||||||
|
data = {
|
||||||
|
"text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
|
||||||
|
"text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
|
||||||
|
}
|
||||||
|
data = to_device(data, kwargs["device"])
|
||||||
|
# y, _ = self.wrapped_model(**data)
|
||||||
|
y, _ = self.punc_forward(**data)
|
||||||
|
_, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
|
||||||
|
punctuations = indices
|
||||||
|
if indices.size()[0] != 1:
|
||||||
|
punctuations = torch.squeeze(indices)
|
||||||
|
assert punctuations.size()[0] == len(mini_sentence)
|
||||||
|
|
||||||
|
# Search for the last Period/QuestionMark as cache
|
||||||
|
if mini_sentence_i < len(mini_sentences) - 1:
|
||||||
|
sentenceEnd = -1
|
||||||
|
last_comma_index = -1
|
||||||
|
for i in range(len(punctuations) - 2, 1, -1):
|
||||||
|
if self.punc_list[punctuations[i]] == "。" or self.punc_list[punctuations[i]] == "?":
|
||||||
|
sentenceEnd = i
|
||||||
|
break
|
||||||
|
if last_comma_index < 0 and self.punc_list[punctuations[i]] == ",":
|
||||||
|
last_comma_index = i
|
||||||
|
|
||||||
|
if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
|
||||||
|
# The sentence it too long, cut off at a comma.
|
||||||
|
sentenceEnd = last_comma_index
|
||||||
|
punctuations[sentenceEnd] = self.sentence_end_id
|
||||||
|
cache_sent = mini_sentence[sentenceEnd + 1:]
|
||||||
|
cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
|
||||||
|
mini_sentence = mini_sentence[0:sentenceEnd + 1]
|
||||||
|
punctuations = punctuations[0:sentenceEnd + 1]
|
||||||
|
|
||||||
|
# if len(punctuations) == 0:
|
||||||
|
# continue
|
||||||
|
|
||||||
|
punctuations_np = punctuations.cpu().numpy()
|
||||||
|
new_mini_sentence_punc += [int(x) for x in punctuations_np]
|
||||||
|
words_with_punc = []
|
||||||
|
for i in range(len(mini_sentence)):
|
||||||
|
if (i==0 or self.punc_list[punctuations[i-1]] == "。" or self.punc_list[punctuations[i-1]] == "?") and len(mini_sentence[i][0].encode()) == 1:
|
||||||
|
mini_sentence[i] = mini_sentence[i].capitalize()
|
||||||
|
if i == 0:
|
||||||
|
if len(mini_sentence[i][0].encode()) == 1:
|
||||||
|
mini_sentence[i] = " " + mini_sentence[i]
|
||||||
|
if i > 0:
|
||||||
|
if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
|
||||||
|
mini_sentence[i] = " " + mini_sentence[i]
|
||||||
|
words_with_punc.append(mini_sentence[i])
|
||||||
|
if self.punc_list[punctuations[i]] != "_":
|
||||||
|
punc_res = self.punc_list[punctuations[i]]
|
||||||
|
if len(mini_sentence[i][0].encode()) == 1:
|
||||||
|
if punc_res == ",":
|
||||||
|
punc_res = ","
|
||||||
|
elif punc_res == "。":
|
||||||
|
punc_res = "."
|
||||||
|
elif punc_res == "?":
|
||||||
|
punc_res = "?"
|
||||||
|
words_with_punc.append(punc_res)
|
||||||
|
new_mini_sentence += "".join(words_with_punc)
|
||||||
|
# Add Period for the end of the sentence
|
||||||
|
new_mini_sentence_out = new_mini_sentence
|
||||||
|
new_mini_sentence_punc_out = new_mini_sentence_punc
|
||||||
|
if mini_sentence_i == len(mini_sentences) - 1:
|
||||||
|
if new_mini_sentence[-1] == "," or new_mini_sentence[-1] == "、":
|
||||||
|
new_mini_sentence_out = new_mini_sentence[:-1] + "。"
|
||||||
|
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
|
||||||
|
elif new_mini_sentence[-1] == ",":
|
||||||
|
new_mini_sentence_out = new_mini_sentence[:-1] + "."
|
||||||
|
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
|
||||||
|
elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==0:
|
||||||
|
new_mini_sentence_out = new_mini_sentence + "。"
|
||||||
|
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
|
||||||
|
elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1:
|
||||||
|
new_mini_sentence_out = new_mini_sentence + "."
|
||||||
|
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
|
||||||
|
# keep a punctuations array for punc segment
|
||||||
|
if punc_array is None:
|
||||||
|
punc_array = punctuations
|
||||||
|
else:
|
||||||
|
punc_array = torch.cat([punc_array, punctuations], dim=0)
|
||||||
|
|
||||||
|
result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array}
|
||||||
|
results.append(result_i)
|
||||||
|
|
||||||
|
return results, meta_data
|
||||||
|
|
||||||
53
funasr/models/ct_transformer_streaming/template.yaml
Normal file
53
funasr/models/ct_transformer_streaming/template.yaml
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
# This is an example that demonstrates how to configure a model file.
|
||||||
|
# You can modify the configuration according to your own requirements.
|
||||||
|
|
||||||
|
# to print the register_table:
|
||||||
|
# from funasr.register import tables
|
||||||
|
# tables.print()
|
||||||
|
|
||||||
|
model: CTTransformerStreaming
|
||||||
|
model_conf:
|
||||||
|
ignore_id: 0
|
||||||
|
embed_unit: 256
|
||||||
|
att_unit: 256
|
||||||
|
dropout_rate: 0.1
|
||||||
|
punc_list:
|
||||||
|
- <unk>
|
||||||
|
- _
|
||||||
|
- ,
|
||||||
|
- 。
|
||||||
|
- ?
|
||||||
|
- 、
|
||||||
|
punc_weight:
|
||||||
|
- 1.0
|
||||||
|
- 1.0
|
||||||
|
- 1.0
|
||||||
|
- 1.0
|
||||||
|
- 1.0
|
||||||
|
- 1.0
|
||||||
|
sentence_end_id: 3
|
||||||
|
|
||||||
|
encoder: SANMEncoder
|
||||||
|
encoder_conf:
|
||||||
|
input_size: 256
|
||||||
|
output_size: 256
|
||||||
|
attention_heads: 8
|
||||||
|
linear_units: 1024
|
||||||
|
num_blocks: 4
|
||||||
|
dropout_rate: 0.1
|
||||||
|
positional_dropout_rate: 0.1
|
||||||
|
attention_dropout_rate: 0.0
|
||||||
|
input_layer: pe
|
||||||
|
pos_enc_class: SinusoidalPositionEncoder
|
||||||
|
normalize_before: true
|
||||||
|
kernel_size: 11
|
||||||
|
sanm_shfit: 0
|
||||||
|
selfattention_layer_type: sanm
|
||||||
|
padding_idx: 0
|
||||||
|
|
||||||
|
tokenizer: CharTokenizer
|
||||||
|
tokenizer_conf:
|
||||||
|
unk_symbol: <unk>
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
111
funasr/models/ct_transformer_streaming/utils.py
Normal file
111
funasr/models/ct_transformer_streaming/utils.py
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
import re
|
||||||
|
|
||||||
|
def split_to_mini_sentence(words: list, word_limit: int = 20):
|
||||||
|
assert word_limit > 1
|
||||||
|
if len(words) <= word_limit:
|
||||||
|
return [words]
|
||||||
|
sentences = []
|
||||||
|
length = len(words)
|
||||||
|
sentence_len = length // word_limit
|
||||||
|
for i in range(sentence_len):
|
||||||
|
sentences.append(words[i * word_limit:(i + 1) * word_limit])
|
||||||
|
if length % word_limit > 0:
|
||||||
|
sentences.append(words[sentence_len * word_limit:])
|
||||||
|
return sentences
|
||||||
|
|
||||||
|
|
||||||
|
# def split_words(text: str, **kwargs):
|
||||||
|
# words = []
|
||||||
|
# segs = text.split()
|
||||||
|
# for seg in segs:
|
||||||
|
# # There is no space in seg.
|
||||||
|
# current_word = ""
|
||||||
|
# for c in seg:
|
||||||
|
# if len(c.encode()) == 1:
|
||||||
|
# # This is an ASCII char.
|
||||||
|
# current_word += c
|
||||||
|
# else:
|
||||||
|
# # This is a Chinese char.
|
||||||
|
# if len(current_word) > 0:
|
||||||
|
# words.append(current_word)
|
||||||
|
# current_word = ""
|
||||||
|
# words.append(c)
|
||||||
|
# if len(current_word) > 0:
|
||||||
|
# words.append(current_word)
|
||||||
|
#
|
||||||
|
# return words
|
||||||
|
|
||||||
|
def split_words(text: str, jieba_usr_dict=None, **kwargs):
|
||||||
|
if jieba_usr_dict:
|
||||||
|
input_list = text.split()
|
||||||
|
token_list_all = []
|
||||||
|
langauge_list = []
|
||||||
|
token_list_tmp = []
|
||||||
|
language_flag = None
|
||||||
|
for token in input_list:
|
||||||
|
if isEnglish(token) and language_flag == 'Chinese':
|
||||||
|
token_list_all.append(token_list_tmp)
|
||||||
|
langauge_list.append('Chinese')
|
||||||
|
token_list_tmp = []
|
||||||
|
elif not isEnglish(token) and language_flag == 'English':
|
||||||
|
token_list_all.append(token_list_tmp)
|
||||||
|
langauge_list.append('English')
|
||||||
|
token_list_tmp = []
|
||||||
|
|
||||||
|
token_list_tmp.append(token)
|
||||||
|
|
||||||
|
if isEnglish(token):
|
||||||
|
language_flag = 'English'
|
||||||
|
else:
|
||||||
|
language_flag = 'Chinese'
|
||||||
|
|
||||||
|
if token_list_tmp:
|
||||||
|
token_list_all.append(token_list_tmp)
|
||||||
|
langauge_list.append(language_flag)
|
||||||
|
|
||||||
|
result_list = []
|
||||||
|
for token_list_tmp, language_flag in zip(token_list_all, langauge_list):
|
||||||
|
if language_flag == 'English':
|
||||||
|
result_list.extend(token_list_tmp)
|
||||||
|
else:
|
||||||
|
seg_list = jieba_usr_dict.cut(join_chinese_and_english(token_list_tmp), HMM=False)
|
||||||
|
result_list.extend(seg_list)
|
||||||
|
|
||||||
|
return result_list
|
||||||
|
|
||||||
|
else:
|
||||||
|
words = []
|
||||||
|
segs = text.split()
|
||||||
|
for seg in segs:
|
||||||
|
# There is no space in seg.
|
||||||
|
current_word = ""
|
||||||
|
for c in seg:
|
||||||
|
if len(c.encode()) == 1:
|
||||||
|
# This is an ASCII char.
|
||||||
|
current_word += c
|
||||||
|
else:
|
||||||
|
# This is a Chinese char.
|
||||||
|
if len(current_word) > 0:
|
||||||
|
words.append(current_word)
|
||||||
|
current_word = ""
|
||||||
|
words.append(c)
|
||||||
|
if len(current_word) > 0:
|
||||||
|
words.append(current_word)
|
||||||
|
return words
|
||||||
|
|
||||||
|
def isEnglish(text:str):
|
||||||
|
if re.search('^[a-zA-Z\']+$', text):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def join_chinese_and_english(input_list):
|
||||||
|
line = ''
|
||||||
|
for token in input_list:
|
||||||
|
if isEnglish(token):
|
||||||
|
line = line + ' ' + token
|
||||||
|
else:
|
||||||
|
line = line + token
|
||||||
|
|
||||||
|
line = line.strip()
|
||||||
|
return line
|
||||||
@ -0,0 +1,135 @@
|
|||||||
|
from typing import Any
|
||||||
|
from typing import List
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from funasr.models.transformer.embedding import SinusoidalPositionEncoder
|
||||||
|
from funasr.models.ct_transformer.sanm_encoder import SANMVadEncoder as Encoder
|
||||||
|
|
||||||
|
|
||||||
|
class VadRealtimeTransformer(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Author: Speech Lab of DAMO Academy, Alibaba Group
|
||||||
|
CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
|
||||||
|
https://arxiv.org/pdf/2003.01309.pdf
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size: int,
|
||||||
|
punc_size: int,
|
||||||
|
pos_enc: str = None,
|
||||||
|
embed_unit: int = 128,
|
||||||
|
att_unit: int = 256,
|
||||||
|
head: int = 2,
|
||||||
|
unit: int = 1024,
|
||||||
|
layer: int = 4,
|
||||||
|
dropout_rate: float = 0.5,
|
||||||
|
kernel_size: int = 11,
|
||||||
|
sanm_shfit: int = 0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
if pos_enc == "sinusoidal":
|
||||||
|
# pos_enc_class = PositionalEncoding
|
||||||
|
pos_enc_class = SinusoidalPositionEncoder
|
||||||
|
elif pos_enc is None:
|
||||||
|
|
||||||
|
def pos_enc_class(*args, **kwargs):
|
||||||
|
return nn.Sequential() # indentity
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"unknown pos-enc option: {pos_enc}")
|
||||||
|
|
||||||
|
self.embed = nn.Embedding(vocab_size, embed_unit)
|
||||||
|
self.encoder = Encoder(
|
||||||
|
input_size=embed_unit,
|
||||||
|
output_size=att_unit,
|
||||||
|
attention_heads=head,
|
||||||
|
linear_units=unit,
|
||||||
|
num_blocks=layer,
|
||||||
|
dropout_rate=dropout_rate,
|
||||||
|
input_layer="pe",
|
||||||
|
# pos_enc_class=pos_enc_class,
|
||||||
|
padding_idx=0,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
sanm_shfit=sanm_shfit,
|
||||||
|
)
|
||||||
|
self.decoder = nn.Linear(att_unit, punc_size)
|
||||||
|
|
||||||
|
|
||||||
|
# def _target_mask(self, ys_in_pad):
|
||||||
|
# ys_mask = ys_in_pad != 0
|
||||||
|
# m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0)
|
||||||
|
# return ys_mask.unsqueeze(-2) & m
|
||||||
|
|
||||||
|
def forward(self, input: torch.Tensor, text_lengths: torch.Tensor,
|
||||||
|
vad_indexes: torch.Tensor) -> Tuple[torch.Tensor, None]:
|
||||||
|
"""Compute loss value from buffer sequences.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input (torch.Tensor): Input ids. (batch, len)
|
||||||
|
hidden (torch.Tensor): Target ids. (batch, len)
|
||||||
|
|
||||||
|
"""
|
||||||
|
x = self.embed(input)
|
||||||
|
# mask = self._target_mask(input)
|
||||||
|
h, _, _ = self.encoder(x, text_lengths, vad_indexes)
|
||||||
|
y = self.decoder(h)
|
||||||
|
return y, None
|
||||||
|
|
||||||
|
def with_vad(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
|
||||||
|
"""Score new token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y (torch.Tensor): 1D torch.int64 prefix tokens.
|
||||||
|
state: Scorer state for prefix tokens
|
||||||
|
x (torch.Tensor): encoder feature that generates ys.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[torch.Tensor, Any]: Tuple of
|
||||||
|
torch.float32 scores for next token (vocab_size)
|
||||||
|
and next state for ys
|
||||||
|
|
||||||
|
"""
|
||||||
|
y = y.unsqueeze(0)
|
||||||
|
h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
|
||||||
|
h = self.decoder(h[:, -1])
|
||||||
|
logp = h.log_softmax(dim=-1).squeeze(0)
|
||||||
|
return logp, cache
|
||||||
|
|
||||||
|
def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
|
||||||
|
"""Score new token batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
|
||||||
|
states (List[Any]): Scorer states for prefix tokens.
|
||||||
|
xs (torch.Tensor):
|
||||||
|
The encoder feature that generates ys (n_batch, xlen, n_feat).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[torch.Tensor, List[Any]]: Tuple of
|
||||||
|
batchfied scores for next token with shape of `(n_batch, vocab_size)`
|
||||||
|
and next state list for ys.
|
||||||
|
|
||||||
|
"""
|
||||||
|
# merge states
|
||||||
|
n_batch = len(ys)
|
||||||
|
n_layers = len(self.encoder.encoders)
|
||||||
|
if states[0] is None:
|
||||||
|
batch_state = None
|
||||||
|
else:
|
||||||
|
# transpose state of [batch, layer] into [layer, batch]
|
||||||
|
batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]
|
||||||
|
|
||||||
|
# batch decoding
|
||||||
|
h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
|
||||||
|
h = self.decoder(h[:, -1])
|
||||||
|
logp = h.log_softmax(dim=-1)
|
||||||
|
|
||||||
|
# transpose state of [layer, batch] into [batch, layer]
|
||||||
|
state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
|
||||||
|
return logp, state_list
|
||||||
Loading…
Reference in New Issue
Block a user