mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Merge branch 'main' of github.com:alibaba-damo-academy/FunASR

Commit message: add
Commit: 26a2a232a9
@@ -42,6 +42,7 @@ from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
 from funasr.models.encoder.branchformer_encoder import BranchformerEncoder
 from funasr.models.encoder.e_branchformer_encoder import EBranchformerEncoder
 from funasr.models.encoder.transformer_encoder import TransformerEncoder
+from funasr.models.encoder.rwkv_encoder import RWKVEncoder
 from funasr.models.frontend.default import DefaultFrontend
 from funasr.models.frontend.default import MultiChannelFrontend
 from funasr.models.frontend.fused import FusedFrontends
@@ -119,6 +120,7 @@ encoder_choices = ClassChoices(
         e_branchformer=EBranchformerEncoder,
         mfcca_enc=MFCCAEncoder,
         chunk_conformer=ConformerChunkEncoder,
+        rwkv=RWKVEncoder,
     ),
     default="rnn",
 )
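With this registration, a training configuration can select the new encoder by name. A minimal sketch of how the `rwkv` choice would resolve, assuming the ESPnet-style `ClassChoices.get_class` interface that FunASR's task definitions use; the feature dimension is an arbitrary example value:

```python
# Hedged sketch: resolving the new "rwkv" encoder choice by name.
# Assumes the ESPnet-style ClassChoices API (get_class); input_size=80 is an example.
from funasr.models.encoder.rwkv_encoder import RWKVEncoder

encoder_class = encoder_choices.get_class("rwkv")
assert encoder_class is RWKVEncoder
encoder = encoder_class(input_size=80)  # remaining arguments keep their defaults
```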
funasr/models/encoder/rwkv_encoder.py (new file, 155 lines)
@@ -0,0 +1,155 @@
"""RWKV encoder definition for Transducer models."""

import math
from typing import Dict, List, Optional, Tuple

import torch

from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.rwkv import RWKV
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.rwkv_subsampling import RWKVConvInput
from funasr.modules.nets_utils import make_source_mask


class RWKVEncoder(AbsEncoder):
    """RWKV encoder module.

    Based on https://arxiv.org/pdf/2305.13048.pdf.

    Args:
        input_size: Input feature size.
        output_size: Input/Output size of the RWKV blocks.
        context_size: Context size for WKV computation.
        linear_size: FeedForward hidden size.
        attention_size: SelfAttention hidden size.
        num_blocks: Number of RWKV blocks.
        att_dropout_rate: Dropout rate for the attention module.
        ffn_dropout_rate: Dropout rate for the feed-forward module.
        dropout_rate: Dropout rate for the residual branches.
        subsampling_factor: Subsampling factor of the convolutional frontend.
        time_reduction_factor: Additional frame-rate reduction after the blocks.
        kernel: Convolution kernel size of the frontend.

    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 512,
        context_size: int = 1024,
        linear_size: Optional[int] = None,
        attention_size: Optional[int] = None,
        num_blocks: int = 4,
        att_dropout_rate: float = 0.0,
        ffn_dropout_rate: float = 0.0,
        dropout_rate: float = 0.0,
        subsampling_factor: int = 4,
        time_reduction_factor: int = 1,
        kernel: int = 3,
    ) -> None:
        """Construct an RWKVEncoder object."""
        super().__init__()

        self.embed = RWKVConvInput(
            input_size,
            [output_size // 4, output_size // 2, output_size],
            subsampling_factor,
            conv_kernel_size=kernel,
            output_size=output_size,
        )

        self.subsampling_factor = subsampling_factor

        linear_size = output_size * 4 if linear_size is None else linear_size
        attention_size = output_size if attention_size is None else attention_size

        self.rwkv_blocks = torch.nn.ModuleList(
            [
                RWKV(
                    output_size,
                    linear_size,
                    attention_size,
                    context_size,
                    block_id,
                    num_blocks,
                    att_dropout_rate=att_dropout_rate,
                    ffn_dropout_rate=ffn_dropout_rate,
                    dropout_rate=dropout_rate,
                )
                for block_id in range(num_blocks)
            ]
        )

        self.embed_norm = LayerNorm(output_size)
        self.final_norm = LayerNorm(output_size)

        self._output_size = output_size
        self.context_size = context_size

        self.num_blocks = num_blocks
        self.time_reduction_factor = time_reduction_factor

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self, x: torch.Tensor, x_len: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, None]:
        """Encode feature sequences.

        Args:
            x: Encoder input sequences. (B, T, D_feats)
            x_len: Lengths of the input sequences. (B,)

        Returns:
            x: Encoder output sequences. (B, U, D_out)
            olens: Lengths of the output sequences. (B,)

        """
        _, length, _ = x.size()

        assert (
            length <= self.context_size * self.subsampling_factor
        ), "Context size is too short for current length: %d versus %d" % (
            length,
            self.context_size * self.subsampling_factor,
        )

        mask = make_source_mask(x_len).to(x.device)
        x, mask = self.embed(x, mask, None)
        x = self.embed_norm(x)
        olens = mask.eq(0).sum(1)

        for block in self.rwkv_blocks:
            x, _ = block(x)
            # for streaming inference
            # xs_pad = self.rwkv_infer(xs_pad)

        x = self.final_norm(x)

        if self.time_reduction_factor > 1:
            x = x[:, :: self.time_reduction_factor, :]
            olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1

        return x, olens, None

    def rwkv_infer(self, xs_pad):
        batch_size = xs_pad.shape[0]

        hidden_sizes = [self._output_size for i in range(5)]

        state = [
            torch.zeros(
                (batch_size, 1, hidden_sizes[i], self.num_blocks),  # fixed: was self.num_rwkv_blocks, which is undefined
                dtype=torch.float32,
                device=xs_pad.device,  # fixed: self.device is not defined on this module
            )
            for i in range(5)
        ]

        state[4] -= 1e-30

        xs_out = []
        for t in range(xs_pad.shape[1]):
            x_t = xs_pad[:, t, :]
            for idx, block in enumerate(self.rwkv_blocks):
                x_t, state = block(x_t, state=state)
            xs_out.append(x_t)

        xs_out = torch.stack(xs_out, dim=1)
        return xs_out
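A minimal usage sketch for the new encoder. The WKV autograd functions JIT-compile a CUDA kernel when the first block is constructed, so this assumes a CUDA device and the `ninja` package; batch size, frame count, and feature dimension are arbitrary example values.

```python
# Hedged usage sketch for RWKVEncoder; requires CUDA + ninja for the WKV kernel build.
import torch
from funasr.models.encoder.rwkv_encoder import RWKVEncoder

encoder = RWKVEncoder(input_size=80, output_size=512, num_blocks=4).cuda()

feats = torch.randn(2, 200, 80, device="cuda")       # (B, T, D_feats)
feats_len = torch.tensor([200, 160], device="cuda")  # (B,)

out, olens, _ = encoder(feats, feats_len)
print(out.shape)  # roughly (2, T / subsampling_factor, 512)
print(olens)      # per-utterance subsampled lengths
```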
funasr/models/whisper_models/__init__.py (new file, empty)
funasr/modules/rwkv.py (new file, 145 lines)
@@ -0,0 +1,145 @@
"""Receptance Weighted Key Value (RWKV) block definition.

Based/modified from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/src/model.py

"""

from typing import Dict, Optional, Tuple

import torch

from funasr.modules.rwkv_attention import EncoderSelfAttention, DecoderSelfAttention
from funasr.modules.rwkv_feed_forward import FeedForward
from funasr.modules.layer_norm import LayerNorm


class RWKV(torch.nn.Module):
    """RWKV module (encoder variant).

    Args:
        size: Input/Output size.
        linear_size: Feed-forward hidden size.
        attention_size: SelfAttention hidden size.
        context_size: Context size for WKV computation.
        block_id: Block index.
        num_blocks: Number of blocks in the architecture.
        att_dropout_rate: Dropout rate for the attention module.
        ffn_dropout_rate: Dropout rate for the feed-forward module.
        dropout_rate: Dropout rate for the residual branches.

    """

    def __init__(
        self,
        size: int,
        linear_size: int,
        attention_size: int,
        context_size: int,
        block_id: int,
        num_blocks: int,
        att_dropout_rate: float = 0.0,
        ffn_dropout_rate: float = 0.0,
        dropout_rate: float = 0.0,
    ) -> None:
        """Construct an RWKV object."""
        super().__init__()

        self.layer_norm_att = LayerNorm(size)
        self.layer_norm_ffn = LayerNorm(size)

        self.att = EncoderSelfAttention(
            size, attention_size, context_size, block_id, att_dropout_rate, num_blocks
        )
        self.dropout_att = torch.nn.Dropout(p=dropout_rate)

        self.ffn = FeedForward(size, linear_size, block_id, ffn_dropout_rate, num_blocks)
        self.dropout_ffn = torch.nn.Dropout(p=dropout_rate)

    def forward(
        self,
        x: torch.Tensor,
        state: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute receptance weighted key value.

        Args:
            x: RWKV input sequences. (B, L, size)
            state: Decoder hidden states. [5 x (B, D_att/size, N)]

        Returns:
            x: RWKV output sequences. (B, L, size)
            state: Updated decoder hidden states. [5 x (B, D_att/size, N)]

        """
        att, state = self.att(self.layer_norm_att(x), state=state)
        x = x + self.dropout_att(att)

        ffn, state = self.ffn(self.layer_norm_ffn(x), state=state)
        x = x + self.dropout_ffn(ffn)

        return x, state


class RWKVDecoderLayer(torch.nn.Module):
    """RWKV module (decoder variant).

    Args:
        size: Input/Output size.
        linear_size: Feed-forward hidden size.
        attention_size: SelfAttention hidden size.
        context_size: Context size for WKV computation.
        block_id: Block index.
        num_blocks: Number of blocks in the architecture.
        att_dropout_rate: Dropout rate for the attention module.
        ffn_dropout_rate: Dropout rate for the feed-forward module.
        dropout_rate: Dropout rate for the residual branches.

    """

    def __init__(
        self,
        size: int,
        linear_size: int,
        attention_size: int,
        context_size: int,
        block_id: int,
        num_blocks: int,
        att_dropout_rate: float = 0.0,
        ffn_dropout_rate: float = 0.0,
        dropout_rate: float = 0.0,
    ) -> None:
        """Construct an RWKVDecoderLayer object."""
        super().__init__()

        self.layer_norm_att = LayerNorm(size)
        self.layer_norm_ffn = LayerNorm(size)

        self.att = DecoderSelfAttention(
            size, attention_size, context_size, block_id, att_dropout_rate, num_blocks
        )
        self.dropout_att = torch.nn.Dropout(p=dropout_rate)

        self.ffn = FeedForward(size, linear_size, block_id, ffn_dropout_rate, num_blocks)
        self.dropout_ffn = torch.nn.Dropout(p=dropout_rate)

    def forward(
        self,
        x: torch.Tensor,
        state: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute receptance weighted key value.

        Args:
            x: RWKV input sequences. (B, L, size)
            state: Decoder hidden states. [5 x (B, D_att/size, N)]

        Returns:
            x: RWKV output sequences. (B, L, size)
            state: Updated decoder hidden states. [5 x (B, D_att/size, N)]

        """
        att, state = self.att(self.layer_norm_att(x), state=state)
        x = x + self.dropout_att(att)

        ffn, state = self.ffn(self.layer_norm_ffn(x), state=state)
        x = x + self.dropout_ffn(ffn)

        return x, state
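For streaming, each block threads a five-slot recurrent state through its attention and feed-forward modules: slot 0 holds the previous frame for channel mixing, slot 1 the previous frame for time mixing, and slots 2-4 the WKV numerator, denominator, and running maximum. A hedged sketch of stepping a single block one frame at a time, mirroring `RWKVEncoder.rwkv_infer`:

```python
# Hedged sketch of frame-by-frame decoding with the 5-slot recurrent state.
# Sizes are example values; constructing the block still triggers the CUDA WKV
# kernel build, so CUDA + ninja are assumed to be available.
import torch
from funasr.modules.rwkv import RWKV

size, num_blocks, batch = 512, 4, 1
block = RWKV(size, 4 * size, size, 1024, block_id=0, num_blocks=num_blocks)

state = [torch.zeros(batch, 1, size, num_blocks) for _ in range(5)]
state[4] -= 1e-30  # running max starts slightly below zero, as in rwkv_infer

x_t = torch.randn(batch, 1, size)     # one frame
y_t, state = block(x_t, state=state)  # the block updates its block_id slice in place
```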
funasr/modules/rwkv_attention.py (new file, 632 lines)
@@ -0,0 +1,632 @@
"""Attention (time mixing) modules for RWKV block.

Based/Modified from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/src/model.py.

Some variables are renamed according to https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py.

"""  # noqa

import math
from importlib.util import find_spec
from pathlib import Path
from typing import List, Optional, Tuple, Union

import torch

wkv_kernel_encoder = None
wkv_kernel_decoder = None


class WKVLinearAttentionEncoder(torch.autograd.Function):
    """WKVLinearAttention function definition (encoder kernel)."""

    @staticmethod
    def forward(
        ctx,
        time_decay: torch.Tensor,
        time_first: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """WKVLinearAttention function forward pass.

        Args:
            time_decay: Channel-wise time decay vector. (D_att)
            time_first: Channel-wise time first vector. (D_att)
            key: Key tensor. (B, U, D_att)
            value: Value tensor. (B, U, D_att)

        Returns:
            out: Weighted Key-Value tensor. (B, U, D_att)

        """
        batch, length, dim = key.size()

        assert length <= wkv_kernel_encoder.context_size, (
            f"Cannot process key of length {length} while context_size "
            f"is ({wkv_kernel_encoder.context_size}). Limit should be increased."
        )

        assert batch * dim % min(dim, 32) == 0, (
            f"batch size ({batch}) by dimension ({dim}) should be a multiple of "
            f"{min(dim, 32)}"
        )

        ctx.input_dtype = key.dtype

        time_decay = -torch.exp(time_decay.float().contiguous())
        time_first = time_first.float().contiguous()

        key = key.float().contiguous()
        value = value.float().contiguous()

        out = torch.empty_like(key, memory_format=torch.contiguous_format)

        wkv_kernel_encoder.forward(time_decay, time_first, key, value, out)
        ctx.save_for_backward(time_decay, time_first, key, value, out)

        return out

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """WKVLinearAttention function backward pass.

        Args:
            grad_output: Output gradient. (B, U, D_att)

        Returns:
            grad_time_decay: Gradient for channel-wise time decay vector. (D_att)
            grad_time_first: Gradient for channel-wise time first vector. (D_att)
            grad_key: Gradient for key tensor. (B, U, D_att)
            grad_value: Gradient for value tensor. (B, U, D_att)

        """
        time_decay, time_first, key, value, output = ctx.saved_tensors
        grad_dtype = ctx.input_dtype

        batch, _, dim = key.size()

        grad_time_decay = torch.empty(
            (batch, dim),
            memory_format=torch.contiguous_format,
            dtype=time_decay.dtype,
            device=time_decay.device,
        )

        grad_time_first = torch.empty(
            (batch, dim),
            memory_format=torch.contiguous_format,
            dtype=time_decay.dtype,
            device=time_decay.device,
        )

        grad_key = torch.empty_like(key, memory_format=torch.contiguous_format)
        grad_value = torch.empty_like(value, memory_format=torch.contiguous_format)

        wkv_kernel_encoder.backward(
            time_decay,
            time_first,
            key,
            value,
            output,
            grad_output.contiguous(),
            grad_time_decay,
            grad_time_first,
            grad_key,
            grad_value,
        )

        grad_time_decay = torch.sum(grad_time_decay, dim=0)
        grad_time_first = torch.sum(grad_time_first, dim=0)

        return (
            grad_time_decay,
            grad_time_first,
            grad_key,
            grad_value,
        )


class WKVLinearAttentionDecoder(torch.autograd.Function):
    """WKVLinearAttention function definition (decoder kernel)."""

    @staticmethod
    def forward(
        ctx,
        time_decay: torch.Tensor,
        time_first: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """WKVLinearAttention function forward pass.

        Args:
            time_decay: Channel-wise time decay vector. (D_att)
            time_first: Channel-wise time first vector. (D_att)
            key: Key tensor. (B, U, D_att)
            value: Value tensor. (B, U, D_att)

        Returns:
            out: Weighted Key-Value tensor. (B, U, D_att)

        """
        batch, length, dim = key.size()

        assert length <= wkv_kernel_decoder.context_size, (
            f"Cannot process key of length {length} while context_size "
            f"is ({wkv_kernel_decoder.context_size}). Limit should be increased."
        )  # fixed: previously referenced an undefined wkv_kernel

        assert batch * dim % min(dim, 32) == 0, (
            f"batch size ({batch}) by dimension ({dim}) should be a multiple of "
            f"{min(dim, 32)}"
        )

        ctx.input_dtype = key.dtype

        time_decay = -torch.exp(time_decay.float().contiguous())
        time_first = time_first.float().contiguous()

        key = key.float().contiguous()
        value = value.float().contiguous()

        out = torch.empty_like(key, memory_format=torch.contiguous_format)

        wkv_kernel_decoder.forward(time_decay, time_first, key, value, out)
        ctx.save_for_backward(time_decay, time_first, key, value, out)

        return out

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """WKVLinearAttention function backward pass.

        Args:
            grad_output: Output gradient. (B, U, D_att)

        Returns:
            grad_time_decay: Gradient for channel-wise time decay vector. (D_att)
            grad_time_first: Gradient for channel-wise time first vector. (D_att)
            grad_key: Gradient for key tensor. (B, U, D_att)
            grad_value: Gradient for value tensor. (B, U, D_att)

        """
        time_decay, time_first, key, value, output = ctx.saved_tensors
        grad_dtype = ctx.input_dtype

        batch, _, dim = key.size()

        grad_time_decay = torch.empty(
            (batch, dim),
            memory_format=torch.contiguous_format,
            dtype=time_decay.dtype,
            device=time_decay.device,
        )

        grad_time_first = torch.empty(
            (batch, dim),
            memory_format=torch.contiguous_format,
            dtype=time_decay.dtype,
            device=time_decay.device,
        )

        grad_key = torch.empty_like(key, memory_format=torch.contiguous_format)
        grad_value = torch.empty_like(value, memory_format=torch.contiguous_format)

        wkv_kernel_decoder.backward(
            time_decay,
            time_first,
            key,
            value,
            output,
            grad_output.contiguous(),
            grad_time_decay,
            grad_time_first,
            grad_key,
            grad_value,
        )

        grad_time_decay = torch.sum(grad_time_decay, dim=0)
        grad_time_first = torch.sum(grad_time_first, dim=0)

        return (
            grad_time_decay,
            grad_time_first,
            grad_key,
            grad_value,
        )


def load_encoder_wkv_kernel(context_size: int) -> None:
    """Load the encoder WKV CUDA kernel.

    Args:
        context_size: Context size.

    """
    from torch.utils.cpp_extension import load

    global wkv_kernel_encoder

    if wkv_kernel_encoder is not None and wkv_kernel_encoder.context_size == context_size:
        return

    if find_spec("ninja") is None:
        raise ImportError(
            "Ninja package was not found. WKV kernel module can't be loaded "
            "for training. Please 'pip install ninja' in your environment."
        )

    if not torch.cuda.is_available():
        raise ImportError(
            "CUDA is currently a requirement for WKV kernel loading. "
            "Please set your devices properly and launch again."
        )

    kernel_folder = Path(__file__).resolve().parent / "cuda_encoder"
    kernel_files = [kernel_folder / f for f in ["wkv_op.cpp", "wkv_cuda.cu"]]

    kernel_cflags = [
        "-res-usage",
        "--maxrregcount 60",
        "--use_fast_math",
        "-O3",
        "-Xptxas -O3",
        f"-DTmax={context_size}",
    ]
    wkv_kernel_encoder = load(
        name=f"encoder_wkv_{context_size}",
        sources=kernel_files,
        verbose=True,
        extra_cuda_cflags=kernel_cflags,
    )
    wkv_kernel_encoder.context_size = context_size


def load_decoder_wkv_kernel(context_size: int) -> None:
    """Load the decoder WKV CUDA kernel.

    Args:
        context_size: Context size.

    """
    from torch.utils.cpp_extension import load

    global wkv_kernel_decoder

    if wkv_kernel_decoder is not None and wkv_kernel_decoder.context_size == context_size:
        return

    if find_spec("ninja") is None:
        raise ImportError(
            "Ninja package was not found. WKV kernel module can't be loaded "
            "for training. Please 'pip install ninja' in your environment."
        )

    if not torch.cuda.is_available():
        raise ImportError(
            "CUDA is currently a requirement for WKV kernel loading. "
            "Please set your devices properly and launch again."
        )

    kernel_folder = Path(__file__).resolve().parent / "cuda_decoder"
    kernel_files = [kernel_folder / f for f in ["wkv_op.cpp", "wkv_cuda.cu"]]

    kernel_cflags = [
        "-res-usage",
        "--maxrregcount 60",
        "--use_fast_math",
        "-O3",
        "-Xptxas -O3",
        f"-DTmax={context_size}",
    ]
    wkv_kernel_decoder = load(
        name=f"decoder_wkv_{context_size}",
        sources=kernel_files,
        verbose=True,
        extra_cuda_cflags=kernel_cflags,
    )
    wkv_kernel_decoder.context_size = context_size


class SelfAttention(torch.nn.Module):
    """SelfAttention module definition.

    Args:
        size: Input/Output size.
        attention_size: Attention hidden size.
        block_id: Block index.
        dropout_rate: Dropout rate applied to the WKV output.
        num_blocks: Number of blocks in the architecture.

    """

    def __init__(
        self,
        size: int,
        attention_size: int,
        block_id: int,
        dropout_rate: float,
        num_blocks: int,
    ) -> None:
        """Construct a SelfAttention object."""
        super().__init__()
        self.time_shift = torch.nn.ZeroPad2d((0, 0, 1, -1))

        self.time_decay = torch.nn.Parameter(torch.empty(attention_size))
        self.time_first = torch.nn.Parameter(torch.empty(attention_size))

        self.time_mix_key = torch.nn.Parameter(torch.empty(1, 1, size))
        self.time_mix_value = torch.nn.Parameter(torch.empty(1, 1, size))
        self.time_mix_receptance = torch.nn.Parameter(torch.empty(1, 1, size))

        self.proj_key = torch.nn.Linear(size, attention_size, bias=True)
        self.proj_value = torch.nn.Linear(size, attention_size, bias=True)
        self.proj_receptance = torch.nn.Linear(size, attention_size, bias=True)

        self.proj_output = torch.nn.Linear(attention_size, size, bias=True)

        self.block_id = block_id

        self.reset_parameters(size, attention_size, block_id, num_blocks)
        self.dropout = torch.nn.Dropout(p=dropout_rate)

    def reset_parameters(
        self, size: int, attention_size: int, block_id: int, num_blocks: int
    ) -> None:
        """Reset module parameters.

        Args:
            size: Block size.
            attention_size: Attention hidden size.
            block_id: Block index.
            num_blocks: Number of blocks in the architecture.

        """
        ratio_0_to_1 = block_id / (num_blocks - 1)
        ratio_1_to_almost0 = 1.0 - (block_id / num_blocks)

        time_weight = torch.ones(1, 1, size)

        for i in range(size):
            time_weight[0, 0, i] = i / size

        decay_speed = [
            -5 + 8 * (h / (attention_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
            for h in range(attention_size)
        ]
        decay_speed = torch.tensor(
            decay_speed, dtype=self.time_decay.dtype, device=self.time_decay.device
        )

        zigzag = (
            torch.tensor(
                [(i + 1) % 3 - 1 for i in range(attention_size)],
                dtype=self.time_first.dtype,
                device=self.time_first.device,
            )
            * 0.5
        )

        with torch.no_grad():
            self.time_decay.data = decay_speed
            self.time_first.data = (
                torch.ones_like(self.time_first) * math.log(0.3) + zigzag
            )  # fixed parenthesization: ones_like previously wrapped the whole expression

            self.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
            self.time_mix_value.data = (
                torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
            )
            self.time_mix_receptance.data = torch.pow(
                time_weight, 0.5 * ratio_1_to_almost0
            )

    @torch.no_grad()
    def wkv_linear_attention(
        self,
        time_decay: torch.Tensor,
        time_first: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        state: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Compute WKV with state (i.e.: for inference).

        Args:
            time_decay: Channel-wise time decay vector. (D_att)
            time_first: Channel-wise time first vector. (D_att)
            key: Key tensor. (B, 1, D_att)
            value: Value tensor. (B, 1, D_att)
            state: Decoder hidden states. [3 x (B, D_att)]

        Returns:
            wkv: Weighted Key-Value. (B, 1, D_att)
            state: Updated decoder hidden states. [3 x (B, 1, D_att)]

        """
        num_state, den_state, max_state = state

        max_for_output = torch.maximum(max_state, (time_first + key))

        e1 = torch.exp(max_state - max_for_output)
        e2 = torch.exp((time_first + key) - max_for_output)

        numerator = e1 * num_state + e2 * value
        denominator = e1 * den_state + e2

        max_for_state = torch.maximum(key, (max_state + time_decay))

        e1 = torch.exp((max_state + time_decay) - max_for_state)
        e2 = torch.exp(key - max_for_state)

        wkv = numerator / denominator

        state = [e1 * num_state + e2 * value, e1 * den_state + e2, max_for_state]

        return wkv, state


class DecoderSelfAttention(SelfAttention):
    """Decoder SelfAttention module definition.

    Args:
        size: Input/Output size.
        attention_size: Attention hidden size.
        context_size: Context size for WKV kernel.
        block_id: Block index.
        dropout_rate: Dropout rate applied to the WKV output.
        num_blocks: Number of blocks in the architecture.

    """

    def __init__(
        self,
        size: int,
        attention_size: int,
        context_size: int,
        block_id: int,
        dropout_rate: float,
        num_blocks: int,
    ) -> None:
        """Construct a DecoderSelfAttention object."""
        super().__init__(
            size,
            attention_size,
            block_id,
            dropout_rate,
            num_blocks,
        )
        load_decoder_wkv_kernel(context_size)

    def forward(
        self,
        x: torch.Tensor,
        state: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
        """Compute time mixing.

        Args:
            x: SelfAttention input sequences. (B, U, size)
            state: Decoder hidden states. [5 x (B, 1, D_att, N)]

        Returns:
            x: SelfAttention output sequences. (B, U, size)
            state: Updated decoder hidden states. [5 x (B, 1, D_att, N)]

        """
        shifted_x = (
            self.time_shift(x) if state is None else state[1][..., self.block_id]
        )

        key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key)
        value = x * self.time_mix_value + shifted_x * (1 - self.time_mix_value)
        receptance = x * self.time_mix_receptance + shifted_x * (
            1 - self.time_mix_receptance
        )

        key = self.proj_key(key)
        value = self.proj_value(value)
        receptance = torch.sigmoid(self.proj_receptance(receptance))

        if state is not None:
            state[1][..., self.block_id] = x

            wkv, att_state = self.wkv_linear_attention(
                self.time_decay,
                self.time_first,
                key,
                value,
                tuple(s[..., self.block_id] for s in state[2:]),
            )

            state[2][..., self.block_id] = att_state[0]
            state[3][..., self.block_id] = att_state[1]
            state[4][..., self.block_id] = att_state[2]
        else:
            wkv = WKVLinearAttentionDecoder.apply(self.time_decay, self.time_first, key, value)

        wkv = self.dropout(wkv)
        x = self.proj_output(receptance * wkv)

        return x, state


class EncoderSelfAttention(SelfAttention):
    """Encoder SelfAttention module definition.

    Args:
        size: Input/Output size.
        attention_size: Attention hidden size.
        context_size: Context size for WKV kernel.
        block_id: Block index.
        dropout_rate: Dropout rate applied to the WKV output.
        num_blocks: Number of blocks in the architecture.

    """

    def __init__(
        self,
        size: int,
        attention_size: int,
        context_size: int,
        block_id: int,
        dropout_rate: float,
        num_blocks: int,
    ) -> None:
        """Construct an EncoderSelfAttention object."""
        super().__init__(
            size,
            attention_size,
            block_id,
            dropout_rate,
            num_blocks,
        )
        load_encoder_wkv_kernel(context_size)

    def forward(
        self,
        x: torch.Tensor,
        state: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
        """Compute time mixing.

        Args:
            x: SelfAttention input sequences. (B, U, size)
            state: Decoder hidden states. [5 x (B, 1, D_att, N)]

        Returns:
            x: SelfAttention output sequences. (B, U, size)
            state: Updated decoder hidden states. [5 x (B, 1, D_att, N)]

        """
        shifted_x = (
            self.time_shift(x) if state is None else state[1][..., self.block_id]
        )

        key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key)
        value = x * self.time_mix_value + shifted_x * (1 - self.time_mix_value)
        receptance = x * self.time_mix_receptance + shifted_x * (
            1 - self.time_mix_receptance
        )

        key = self.proj_key(key)
        value = self.proj_value(value)
        receptance = torch.sigmoid(self.proj_receptance(receptance))

        if state is not None:
            state[1][..., self.block_id] = x

            wkv, att_state = self.wkv_linear_attention(
                self.time_decay,
                self.time_first,
                key,
                value,
                tuple(s[..., self.block_id] for s in state[2:]),
            )

            state[2][..., self.block_id] = att_state[0]
            state[3][..., self.block_id] = att_state[1]
            state[4][..., self.block_id] = att_state[2]
        else:
            wkv = WKVLinearAttentionEncoder.apply(self.time_decay, self.time_first, key, value)

        wkv = self.dropout(wkv)
        x = self.proj_output(receptance * wkv)

        return x, state
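The CUDA kernel evaluates the WKV linear-attention recurrence over a whole sequence; `wkv_linear_attention` above applies the same numerically stabilized recurrence one step at a time. A hedged pure-PyTorch reference of that recurrence, useful for checking shapes and semantics on CPU (function and variable names here are illustrative, not FunASR API):

```python
# Hedged pure-PyTorch reference of the WKV recurrence the CUDA kernel implements,
# written from the stepwise math in SelfAttention.wkv_linear_attention.
import torch

def wkv_reference(time_decay, time_first, key, value):
    """time_decay/time_first: (D,); key/value: (B, T, D). Returns (B, T, D)."""
    time_decay = -torch.exp(time_decay.float())  # matches the kernel-side transform

    batch, length, dim = key.shape
    num = torch.zeros(batch, dim)         # numerator state
    den = torch.zeros(batch, dim)         # denominator state
    mx = torch.full((batch, dim), -1e30)  # running max, for numerical stability

    out = torch.empty_like(key)
    for t in range(length):
        k, v = key[:, t], value[:, t]

        # output at step t applies the "time_first" bonus to the current token
        m = torch.maximum(mx, time_first + k)
        e1, e2 = torch.exp(mx - m), torch.exp(time_first + k - m)
        out[:, t] = (e1 * num + e2 * v) / (e1 * den + e2)

        # state update applies the per-channel decay
        m = torch.maximum(mx + time_decay, k)
        e1, e2 = torch.exp(mx + time_decay - m), torch.exp(k - m)
        num, den, mx = e1 * num + e2 * v, e1 * den + e2, m
    return out
```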
funasr/modules/rwkv_feed_forward.py (new file, 97 lines)
@@ -0,0 +1,97 @@
"""Feed-forward (channel mixing) module for RWKV block.

Based/Modified from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/src/model.py

Some variables are renamed according to https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py.

"""  # noqa

from typing import List, Optional, Tuple

import torch


class FeedForward(torch.nn.Module):
    """FeedForward module definition.

    Args:
        size: Input/Output size.
        hidden_size: Hidden size.
        block_id: Block index.
        dropout_rate: Dropout rate applied inside the key projection.
        num_blocks: Number of blocks in the architecture.

    """

    def __init__(
        self, size: int, hidden_size: int, block_id: int, dropout_rate: float, num_blocks: int
    ) -> None:
        """Construct a FeedForward object."""
        super().__init__()

        self.time_shift = torch.nn.ZeroPad2d((0, 0, 1, -1))

        self.time_mix_key = torch.nn.Parameter(torch.empty(1, 1, size))
        self.time_mix_receptance = torch.nn.Parameter(torch.empty(1, 1, size))

        self.proj_key = torch.nn.Linear(size, hidden_size, bias=True)
        self.proj_value = torch.nn.Linear(hidden_size, size, bias=True)
        self.proj_receptance = torch.nn.Linear(size, size, bias=True)

        self.block_id = block_id

        self.reset_parameters(size, block_id, num_blocks)
        self.dropout = torch.nn.Dropout(p=dropout_rate)

    def reset_parameters(self, size: int, block_id: int, num_blocks: int) -> None:
        """Reset module parameters.

        Args:
            size: Block size.
            block_id: Block index.
            num_blocks: Number of blocks in the architecture.

        """
        ratio_1_to_almost0 = 1.0 - (block_id / num_blocks)

        time_weight = torch.ones(1, 1, size)

        for i in range(size):
            time_weight[0, 0, i] = i / size

        with torch.no_grad():
            self.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
            self.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0)

    def forward(
        self, x: torch.Tensor, state: Optional[List[torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
        """Compute channel mixing.

        Args:
            x: FeedForward input sequences. (B, U, size)
            state: Decoder hidden state. [5 x (B, 1, size, N)]

        Returns:
            x: FeedForward output sequences. (B, U, size)
            state: Updated decoder hidden state. [5 x (B, 1, size, N)]

        """
        shifted_x = (
            self.time_shift(x) if state is None else state[0][..., self.block_id]
        )

        key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key)
        receptance = x * self.time_mix_receptance + shifted_x * (
            1 - self.time_mix_receptance
        )

        key = torch.square(torch.relu(self.proj_key(key)))
        value = self.proj_value(self.dropout(key))
        receptance = torch.sigmoid(self.proj_receptance(receptance))

        if state is not None:
            state[0][..., self.block_id] = x

        x = receptance * value

        return x, state
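The channel-mixing module relies on a one-frame token shift implemented with `ZeroPad2d((0, 0, 1, -1))`: each position mixes its own features with the previous frame's, and position 0 sees zeros. A small check of that behavior:

```python
# Toy check of the token shift used by FeedForward (and SelfAttention):
# ZeroPad2d((0, 0, 1, -1)) shifts the sequence right by one frame.
import torch

time_shift = torch.nn.ZeroPad2d((0, 0, 1, -1))
x = torch.arange(1., 7.).view(1, 3, 2)  # (B=1, T=3, size=2)
print(time_shift(x))
# tensor([[[0., 0.],
#          [1., 2.],
#          [3., 4.]]])
```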
funasr/modules/rwkv_subsampling.py (new file, 190 lines)
@@ -0,0 +1,190 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Subsampling layer definition."""

import logging
import math
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F

from funasr.modules.embedding import PositionalEncoding
from funasr.modules.nets_utils import sub_factor_to_params, pad_to_len
from funasr.modules.streaming_utils.utils import sequence_mask


class TooShortUttError(Exception):
    """Raised when the utterance is too short for subsampling.

    Args:
        message (str): Message for error catch
        actual_size (int): the short size that cannot pass the subsampling
        limit (int): the limit size for subsampling

    """

    def __init__(self, message, actual_size, limit):
        """Construct a TooShortUttError for error handler."""
        super().__init__(message)
        self.actual_size = actual_size
        self.limit = limit


def check_short_utt(ins, size):
    """Check if the utterance is too short for subsampling."""
    # NOTE: the Conv2dSubsampling* classes are not defined in this file; they
    # are expected to be imported from the project's standard subsampling module.
    if isinstance(ins, Conv2dSubsampling2) and size < 3:
        return True, 3
    if isinstance(ins, Conv2dSubsampling) and size < 7:
        return True, 7
    if isinstance(ins, Conv2dSubsampling6) and size < 11:
        return True, 11
    if isinstance(ins, Conv2dSubsampling8) and size < 15:
        return True, 15
    return False, -1


class RWKVConvInput(torch.nn.Module):
    """Streaming ConvInput module definition.

    Args:
        input_size: Input size.
        conv_size: Convolution size.
        subsampling_factor: Subsampling factor.
        conv_kernel_size: Convolution kernel size.
        output_size: Block output dimension.

    """

    def __init__(
        self,
        input_size: int,
        conv_size: Union[int, Tuple],
        subsampling_factor: int = 4,
        conv_kernel_size: int = 3,
        output_size: Optional[int] = None,
    ) -> None:
        """Construct a ConvInput object."""
        super().__init__()

        if subsampling_factor == 1:
            conv_size1, conv_size2, conv_size3 = conv_size

            self.conv = torch.nn.Sequential(
                torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size - 1) // 2),
                torch.nn.ReLU(),
                torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=[1, 2], padding=(conv_kernel_size - 1) // 2),
                torch.nn.ReLU(),
                torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size - 1) // 2),
                torch.nn.ReLU(),
                torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=[1, 2], padding=(conv_kernel_size - 1) // 2),
                torch.nn.ReLU(),
                torch.nn.Conv2d(conv_size2, conv_size3, conv_kernel_size, stride=1, padding=(conv_kernel_size - 1) // 2),
                torch.nn.ReLU(),
                torch.nn.Conv2d(conv_size3, conv_size3, conv_kernel_size, stride=[1, 2], padding=(conv_kernel_size - 1) // 2),
                torch.nn.ReLU(),
            )

            output_proj = conv_size3 * ((input_size // 2) // 2)

            self.subsampling_factor = 1
            self.stride_1 = 1

            self.create_new_mask = self.create_new_vgg_mask

        else:
            conv_size1, conv_size2, conv_size3 = conv_size

            kernel_1 = int(subsampling_factor / 2)

            self.conv = torch.nn.Sequential(
                torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size - 1) // 2),
                torch.nn.ReLU(),
                torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=[kernel_1, 2], padding=(conv_kernel_size - 1) // 2),
                torch.nn.ReLU(),
                torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size - 1) // 2),
                torch.nn.ReLU(),
                torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=[2, 2], padding=(conv_kernel_size - 1) // 2),
                torch.nn.ReLU(),
                torch.nn.Conv2d(conv_size2, conv_size3, conv_kernel_size, stride=1, padding=(conv_kernel_size - 1) // 2),
                torch.nn.ReLU(),
                torch.nn.Conv2d(conv_size3, conv_size3, conv_kernel_size, stride=1, padding=(conv_kernel_size - 1) // 2),
                torch.nn.ReLU(),
            )

            output_proj = conv_size3 * ((input_size // 2) // 2)

            self.subsampling_factor = subsampling_factor

            self.create_new_mask = self.create_new_vgg_mask

            self.stride_1 = kernel_1

            self.min_frame_length = 7

        if output_size is not None:
            self.output = torch.nn.Linear(output_proj, output_size)
            self.output_size = output_size
        else:
            self.output = None
            self.output_size = output_proj

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor], chunk_size: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode input sequences.

        Args:
            x: ConvInput input sequences. (B, T, D_feats)
            mask: Mask of input sequences. (B, 1, T)
            chunk_size: Chunk size for chunk-wise processing, or None.

        Returns:
            x: ConvInput output sequences. (B, sub(T), D_out)
            mask: Mask of output sequences. (B, 1, sub(T))

        """
        if mask is not None:
            mask = self.create_new_mask(mask)
            olens = max(mask.eq(0).sum(1))

        b, t, f = x.size()
        x = x.unsqueeze(1)  # (B, 1, T, F)

        if chunk_size is not None:
            max_input_length = int(
                chunk_size
                * self.subsampling_factor
                * (math.ceil(float(t) / (chunk_size * self.subsampling_factor)))
            )
            x = map(lambda inputs: pad_to_len(inputs, max_input_length, 1), x)
            x = list(x)
            x = torch.stack(x, dim=0)
            n_chunks = max_input_length // (chunk_size * self.subsampling_factor)
            x = x.view(b * n_chunks, 1, chunk_size * self.subsampling_factor, f)

        x = self.conv(x)

        _, c, _, f = x.size()
        if chunk_size is not None:
            x = x.transpose(1, 2).contiguous().view(b, -1, c * f)[:, :olens, :]
        else:
            x = x.transpose(1, 2).contiguous().view(b, -1, c * f)

        if self.output is not None:
            x = self.output(x)

        return x, mask[:, :olens][:, : x.size(1)]

    def create_new_vgg_mask(self, mask: torch.Tensor) -> torch.Tensor:
        """Create a new mask for VGG output sequences.

        Args:
            mask: Mask of input sequences. (B, T)

        Returns:
            mask: Mask of output sequences. (B, sub(T))

        """
        if self.subsampling_factor > 1:
            return mask[:, ::2][:, :: self.stride_1]
        else:
            return mask

    def get_size_before_subsampling(self, size: int) -> int:
        """Return the original size before subsampling for a given size.

        Args:
            size: Number of frames after subsampling.

        Returns:
            : Number of frames before subsampling.

        """
        return size * self.subsampling_factor
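For `subsampling_factor=4`, `create_new_vgg_mask` decimates the time mask twice by 2 (`[:, ::2][:, ::self.stride_1]` with `stride_1 = 2`), so `T` input frames yield `ceil(T / 4)` output positions. A quick check:

```python
# Quick check of the mask subsampling for subsampling_factor=4 (stride_1 = 2).
import torch

T = 10
mask = torch.zeros(1, T, dtype=torch.bool)  # all-valid source mask (0 = keep)
sub = mask[:, ::2][:, ::2]                  # mask[:, ::2][:, ::self.stride_1]
print(sub.shape[1])                         # 3, i.e. ceil(10 / 4)
```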
@@ -39,6 +39,8 @@ public class MainActivity extends AppCompatActivity {
     public static final String TAG = MainActivity.class.getSimpleName();
     // WebSocket address
     public String ASR_HOST = "";
+    // Official (default) WebSocket address
+    public static final String DEFAULT_HOST = "wss://101.37.77.25:10088";
     // JSON payload constants
     public static final String MODE = "2pass";
     public static final String CHUNK_SIZE = "5, 10, 5";
@@ -61,7 +63,6 @@ public class MainActivity extends AppCompatActivity {
     // UI widgets
     private Button recordBtn;
     private TextView resultText;
-    private WebSocket webSocket;
 
     @SuppressLint("ClickableViewAccessibility")
     @Override
@@ -106,8 +107,8 @@ public class MainActivity extends AppCompatActivity {
         ASR_HOST = uri;
     }
     // Load saved hotwords
-    String hotWords = sharedPreferences.getString("hotwords", "");
-    if (!hotWords.equals("")) {
+    String hotWords = sharedPreferences.getString("hotwords", null);
+    if (hotWords != null) {
         this.hotWords = hotWords;
     }
 }
@@ -150,6 +151,14 @@ public class MainActivity extends AppCompatActivity {
         editor.apply();
     }
 });
+builder.setNeutralButton("使用官方服务", (dialog, id) -> {  // "Use the official service"
+    ASR_HOST = DEFAULT_HOST;
+    input.setText(DEFAULT_HOST);
+    Toast.makeText(MainActivity.this, "WebSocket地址:" + ASR_HOST, Toast.LENGTH_SHORT).show();
+    SharedPreferences.Editor editor = sharedPreferences.edit();
+    editor.putString("uri", ASR_HOST);
+    editor.apply();
+});
 AlertDialog dialog = builder.create();
 dialog.show();
 }
@@ -166,12 +175,10 @@ public class MainActivity extends AppCompatActivity {
 builder.setView(view);
 builder.setPositiveButton("确定", (dialog, id) -> {  // "OK"
     String hotwords = input.getText().toString();
-    if (!hotwords.equals("")) {
-        this.hotWords = hotwords;
-        SharedPreferences.Editor editor = sharedPreferences.edit();
-        editor.putString("hotwords", hotwords);
-        editor.apply();
-    }
+    this.hotWords = hotwords;
+    SharedPreferences.Editor editor = sharedPreferences.edit();
+    editor.putString("hotwords", hotwords);
+    editor.apply();
 });
 AlertDialog dialog = builder.create();
 dialog.show();
@@ -225,7 +232,7 @@ public class MainActivity extends AppCompatActivity {
 Request request = new Request.Builder()
         .url(ASR_HOST)
         .build();
-webSocket = client.newWebSocket(request, new WebSocketListener() {
+WebSocket webSocket = client.newWebSocket(request, new WebSocketListener() {
 
 @Override
 public void onOpen(@NonNull WebSocket webSocket, @NonNull Response response) {
@@ -311,7 +318,9 @@ public class MainActivity extends AppCompatActivity {
 obj.put("chunk_size", array);
 obj.put("chunk_interval", CHUNK_INTERVAL);
 obj.put("wav_name", "default");
-obj.put("hotwords", hotWords);
+if (!hotWords.equals("")) {
+    obj.put("hotwords", hotWords);
+}
 obj.put("wav_format", "pcm");
 obj.put("is_speaking", isSpeaking);
 return obj.toString();
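For reference, the handshake message assembled by the Java code above now omits the `hotwords` field when the user has not configured any. A hedged Python rendering of the same JSON (the value of `CHUNK_INTERVAL` is not shown in this diff, so the one below is an assumption):

```python
# Hedged Python rendering of the handshake JSON the Android client builds above.
import json

def build_handshake(hot_words: str, is_speaking: bool) -> str:
    msg = {
        "mode": "2pass",
        "chunk_size": [5, 10, 5],
        "chunk_interval": 10,       # assumption: CHUNK_INTERVAL's value is not shown here
        "wav_name": "default",
        "wav_format": "pcm",
        "is_speaking": is_speaking,
    }
    if hot_words:
        msg["hotwords"] = hot_words  # omitted when empty, matching the Java change
    return json.dumps(msg)
```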
@@ -83,7 +83,8 @@ nohup bash run_server.sh \
 --io-thread-num 8 \
 --port 10095 \
 --certfile ../../../ssl_key/server.crt \
---keyfile ../../../ssl_key/server.key > log.out 2>&1 &
+--keyfile ../../../ssl_key/server.key \
+--hotword ../../hotwords.txt > log.out 2>&1 &
 ```
 
 Introduction to run_server.sh parameters:
@@ -102,6 +103,7 @@ Introduction to run_server.sh parameters:
 --io-thread-num: Number of IO threads that the server starts. Default is 1.
 --certfile <string>: SSL certificate file. Default is ../../../ssl_key/server.crt. To disable SSL, set it to 0.
 --keyfile <string>: SSL key file. Default is ../../../ssl_key/server.key.
+--hotword: Hotword file path, one hotword per line. If the client also provides hotwords, the two sets are merged. Default is ../../hotwords.txt.
 ```
 
 ### Shutting Down the FunASR Service
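Operationally, the server now reads the `--hotword` file once at startup and merges it with any hotwords the client sends in its handshake (see the websocket-server changes later in this commit). A hedged sketch of the resulting merge semantics; `HOTWORD_SEP` stands in for the runtime's actual separator constant, which is defined elsewhere:

```python
# Hedged sketch of the hotword merge semantics introduced in this commit.
HOTWORD_SEP = " "  # assumption: the real separator constant lives in the C++ runtime

def effective_hotwords(client_hotwords: str, server_file_lines: list) -> str:
    server_hotwords = HOTWORD_SEP.join(server_file_lines)
    if client_hotwords:
        # client-provided hotwords are concatenated with the server-side list
        return client_hotwords + " " + server_hotwords
    return server_hotwords

print(effective_hotwords("FunASR", ["阿里巴巴", "通义实验室"]))
```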
@@ -79,7 +79,7 @@ nohup bash run_server.sh \
 --io-thread-num 8 \
 --port 10095 \
 --certfile ../../../ssl_key/server.crt \
---keyfile ../../../ssl_key/server.key
+--keyfile ../../../ssl_key/server.key > log.out 2>&1 &
 ```
 
 Introduction to run_server.sh parameters:
@@ -165,7 +165,8 @@ nohup bash run_server.sh \
 --io-thread-num 8 \
 --port 10095 \
 --certfile ../../../ssl_key/server.crt \
---keyfile ../../../ssl_key/server.key > log.out 2>&1 &
+--keyfile ../../../ssl_key/server.key \
+--hotword ../../hotwords.txt > log.out 2>&1 &
 ```
 **Introduction to the run_server.sh parameters**
 ```text
@@ -182,6 +183,7 @@ nohup bash run_server.sh \
 --io-thread-num  Number of IO threads started by the server. Default: 1
 --certfile       SSL certificate file. Default: ../../../ssl_key/server.crt. To disable SSL, set this parameter to 0
 --keyfile        SSL key file. Default: ../../../ssl_key/server.key
+--hotword        Hotword file path, one hotword per line. If the client also provides hotwords, they are merged with the client's hotwords. Default: ../../hotwords.txt
 ```
 
 ### Shutting down the FunASR service
@@ -72,7 +72,8 @@ nohup bash run_server_2pass.sh \
 --io-thread-num 8 \
 --port 10095 \
 --certfile ../../../ssl_key/server.crt \
---keyfile ../../../ssl_key/server.key > log.out 2>&1 &
+--keyfile ../../../ssl_key/server.key \
+--hotword ../../hotwords.txt > log.out 2>&1 &
 
 # If you want to disable SSL, please add: --certfile 0
 # If you want to deploy the timestamp or hotword model, please set --model-dir to the corresponding model:
@@ -97,6 +98,7 @@ nohup bash run_server_2pass.sh \
 --io-thread-num: Number of IO threads that the server starts. Default is 1.
 --certfile <string>: SSL certificate file. Default is ../../../ssl_key/server.crt. To disable SSL, set it to 0.
 --keyfile <string>: SSL key file. Default is ../../../ssl_key/server.key.
+--hotword: Hotword file path, one hotword per line. If the client also provides hotwords, the two sets are merged. Default is ../../hotwords.txt.
 ```
 
 ### Shutting Down the FunASR Service
@@ -31,7 +31,8 @@ nohup bash run_server_2pass.sh \
 --model-dir damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx \
 --online-model-dir damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online-onnx \
 --punc-dir damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727-onnx \
---itn-dir thuduj12/fst_itn_zh > log.out 2>&1 &
+--itn-dir thuduj12/fst_itn_zh \
+--hotwordsfile ../../hotwords.txt > log.out 2>&1 &
 
 # To disable SSL, add the parameter: --certfile 0
 # To deploy the timestamp or hotword model, set --model-dir to the corresponding model:
@@ -80,7 +81,8 @@ nohup bash run_server_2pass.sh \
 --io-thread-num 8 \
 --port 10095 \
 --certfile ../../../ssl_key/server.crt \
---keyfile ../../../ssl_key/server.key > log.out 2>&1 &
+--keyfile ../../../ssl_key/server.key \
+--hotword ../../hotwords.txt > log.out 2>&1 &
 ```
 **Introduction to the run_server_2pass.sh parameters**
 ```text
@@ -98,6 +100,7 @@ nohup bash run_server_2pass.sh \
 --io-thread-num  Number of IO threads started by the server. Default: 1
 --certfile       SSL certificate file. Default: ../../../ssl_key/server.crt. To disable SSL, set this parameter to 0
 --keyfile        SSL key file. Default: ../../../ssl_key/server.key
+--hotword        Hotword file path, one hotword per line. If the client also provides hotwords, they are merged with the client's hotwords. Default: ../../hotwords.txt
 ```
 
 ### Shutting down the FunASR service
@@ -3,7 +3,7 @@ import numpy as np
 
 def time_stamp_lfr6_onnx(us_cif_peak, char_list, begin_time=0.0, total_offset=-1.5):
     if not len(char_list):
-        return []
+        return '', []
     START_END_THRESHOLD = 5
     MAX_TOKEN_DURATION = 30
     TIME_RATE = 10.0 * 6 / 1000 / 3  # 3 times upsampled
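This fix makes the empty-input early return match the function's normal return shape (a timestamp string plus a per-token list), so callers can unpack the result unconditionally. A hedged illustration, assuming the non-empty path returns such a pair:

```python
# Hedged illustration; assumes time_stamp_lfr6_onnx returns (timestamp_str, timestamp_list).
res_str, timestamps = time_stamp_lfr6_onnx(us_cif_peak=[], char_list=[])
assert res_str == '' and timestamps == []  # unpacking no longer raises on empty input
```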
@@ -9,6 +9,7 @@ io_thread_num=8
 port=10095
 certfile="../../../ssl_key/server.crt"
 keyfile="../../../ssl_key/server.key"
+hotwordsfile="../../hotwords.txt"
 
 . ../../egs/aishell/transformer/utils/parse_options.sh || exit 1;
 
@@ -24,7 +25,8 @@ if [ -z "$certfile" ] || [ "$certfile" -eq 0 ]; then
     --io-thread-num ${io_thread_num} \
     --port ${port} \
     --certfile "" \
-    --keyfile ""
+    --keyfile "" \
+    --hotwordsfile ${hotwordsfile}
 else
   ./funasr-wss-server \
     --download-model-dir ${download_model_dir} \
@@ -36,5 +38,6 @@ else
     --io-thread-num ${io_thread_num} \
     --port ${port} \
     --certfile ${certfile} \
-    --keyfile ${keyfile}
+    --keyfile ${keyfile} \
+    --hotwordsfile ${hotwordsfile}
 fi
@@ -10,6 +10,7 @@ io_thread_num=8
 port=10095
 certfile="../../../ssl_key/server.crt"
 keyfile="../../../ssl_key/server.key"
+hotwordsfile="../../hotwords.txt"
 
 . ../../egs/aishell/transformer/utils/parse_options.sh || exit 1;
 
@@ -26,7 +27,8 @@ if [ -z "$certfile" ] || [ "$certfile" -eq 0 ]; then
     --io-thread-num ${io_thread_num} \
     --port ${port} \
     --certfile "" \
-    --keyfile ""
+    --keyfile "" \
+    --hotwordsfile ${hotwordsfile}
 else
   ./funasr-wss-server-2pass \
     --download-model-dir ${download_model_dir} \
@@ -39,5 +41,6 @@ else
     --io-thread-num ${io_thread_num} \
     --port ${port} \
     --certfile ${certfile} \
-    --keyfile ${keyfile}
+    --keyfile ${keyfile} \
+    --hotwordsfile ${hotwordsfile}
 fi
@@ -14,6 +14,9 @@
 #include <unistd.h>
 #include "websocket-server-2pass.h"
 
+#include <fstream>
+std::string hotwords = "";
+
 using namespace std;
 void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key,
               std::map<std::string, std::string>& model_path) {
@@ -109,6 +112,15 @@ int main(int argc, char* argv[]) {
       "connection",
       false, "../../../ssl_key/server.key", "string");
 
+  TCLAP::ValueArg<std::string> hotwordsfile(
+      "", "hotword",
+      "default: ../../hotwords.txt, path of the hotwords file",
+      false, "../../hotwords.txt", "string");
+
+  // add file
+  cmd.add(hotwordsfile);
+
   cmd.add(certfile);
   cmd.add(keyfile);
 
@@ -417,6 +429,21 @@ int main(int argc, char* argv[]) {
   std::string s_certfile = certfile.getValue();
   std::string s_keyfile = keyfile.getValue();
 
+  std::string s_hotwordsfile = hotwordsfile.getValue();
+  std::string line;
+  std::ifstream file(s_hotwordsfile);
+  LOG(INFO) << "hotwordsfile path: " << s_hotwordsfile;
+
+  if (file.is_open()) {
+    while (getline(file, line)) {
+      hotwords += line + HOTWORD_SEP;
+    }
+    LOG(INFO) << "hotwords: " << hotwords;
+    file.close();
+  } else {
+    LOG(ERROR) << "Unable to open hotwords file: " << s_hotwordsfile;
+  }
+
   bool is_ssl = false;
   if (!s_certfile.empty()) {
     is_ssl = true;
@@ -460,8 +487,7 @@ int main(int argc, char* argv[]) {
     websocket_srv.initAsr(model_path, s_model_thread_num);  // init asr model
   }
 
-  std::cout << "asr model init finished. listen on port:" << s_port
-            << std::endl;
+  LOG(INFO) << "asr model init finished. listen on port:" << s_port;
 
   // Start the ASIO network io_service run loop
   std::vector<std::thread> ts;
@@ -480,7 +506,7 @@ int main(int argc, char* argv[]) {
   }
 
 } catch (std::exception const& e) {
-  std::cerr << "Error: " << e.what() << std::endl;
+  LOG(ERROR) << "Error: " << e.what();
 }
 
 return 0;
@@ -13,6 +13,9 @@
 #include "websocket-server.h"
 #include <unistd.h>
 
+#include <fstream>
+std::string hotwords = "";
+
 using namespace std;
 void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key,
               std::map<std::string, std::string>& model_path) {
@@ -95,6 +98,15 @@ int main(int argc, char* argv[]) {
       "default: ../../../ssl_key/server.key, path of keyfile for WSS connection",
       false, "../../../ssl_key/server.key", "string");
 
+  TCLAP::ValueArg<std::string> hotwordsfile(
+      "", "hotword",
+      "default: ../../hotwords.txt, path of the hotwords file",
+      false, "../../hotwords.txt", "string");
+
+  // add file
+  cmd.add(hotwordsfile);
+
   cmd.add(certfile);
   cmd.add(keyfile);
 
@@ -331,6 +343,21 @@ int main(int argc, char* argv[]) {
   std::string s_certfile = certfile.getValue();
   std::string s_keyfile = keyfile.getValue();
 
+  std::string s_hotwordsfile = hotwordsfile.getValue();
+  std::string line;
+  std::ifstream file(s_hotwordsfile);
+  LOG(INFO) << "hotwordsfile path: " << s_hotwordsfile;
+
+  if (file.is_open()) {
+    while (getline(file, line)) {
+      hotwords += line + HOTWORD_SEP;
+    }
+    LOG(INFO) << "hotwords: " << hotwords;
+    file.close();
+  } else {
+    LOG(ERROR) << "Unable to open hotwords file: " << s_hotwordsfile;
+  }
+
   bool is_ssl = false;
   if (!s_certfile.empty()) {
     is_ssl = true;
@@ -374,8 +401,7 @@ int main(int argc, char* argv[]) {
     websocket_srv.initAsr(model_path, s_model_thread_num);  // init asr model
   }
 
-  std::cout << "asr model init finished. listen on port:" << s_port
-            << std::endl;
+  LOG(INFO) << "asr model init finished. listen on port:" << s_port;
 
   // Start the ASIO network io_service run loop
   std::vector<std::thread> ts;
@@ -394,7 +420,7 @@ int main(int argc, char* argv[]) {
   }
 
 } catch (std::exception const& e) {
-  std::cerr << "Error: " << e.what() << std::endl;
+  LOG(ERROR) << "Error: " << e.what();
 }
 
 return 0;
@@ -15,7 +15,9 @@
 #include <thread>
 #include <utility>
 #include <vector>
+#include <chrono>
 
+extern std::string hotwords;
 context_ptr WebSocketServer::on_tls_init(tls_mode mode,
                                          websocketpp::connection_hdl hdl,
                                          std::string& s_certfile,
@@ -354,7 +356,14 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
   unique_lock guard_decoder(*(thread_lock_p));  // mutex for one connection
   switch (msg->get_opcode()) {
     case websocketpp::frame::opcode::text: {
-      nlohmann::json jsonresult = nlohmann::json::parse(payload);
+      nlohmann::json jsonresult;
+      try {
+        jsonresult = nlohmann::json::parse(payload);
+      } catch (std::exception const& e) {
+        LOG(ERROR) << e.what();
+        break;
+      }
 
       if (jsonresult.contains("wav_name")) {
         msg_data->msg["wav_name"] = jsonresult["wav_name"];
@@ -370,17 +379,26 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
       msg_data->msg["hotwords"] = jsonresult["hotwords"];
       if (!msg_data->msg["hotwords"].empty()) {
         std::string hw = msg_data->msg["hotwords"];
-        LOG(INFO) << "hotwords: " << hw;
-        std::vector<std::vector<float>> new_hotwords_embedding = CompileHotwordEmbedding(tpass_handle, hw, ASR_TWO_PASS);
+        hw = hw + " " + hotwords;
+        LOG(INFO) << "hotwords: " << hw;
+        std::vector<std::vector<float>> new_hotwords_embedding = CompileHotwordEmbedding(tpass_handle, hw, ASR_TWO_PASS);
         msg_data->hotwords_embedding =
             std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
       }
-      }else{
+      } else {
+        if (hotwords.empty()) {
+          std::string hw = "";
+          LOG(INFO) << "hotwords: " << hw;
+          std::vector<std::vector<float>> new_hotwords_embedding = CompileHotwordEmbedding(tpass_handle, hw, ASR_TWO_PASS);
+          msg_data->hotwords_embedding =
+              std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
+        } else {
+          std::string hw = hotwords;
+          LOG(INFO) << "hotwords: " << hw;
+          std::vector<std::vector<float>> new_hotwords_embedding = CompileHotwordEmbedding(tpass_handle, hw, ASR_TWO_PASS);
+          msg_data->hotwords_embedding =
+              std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
+        }
       }
   }
   if (jsonresult.contains("audio_fs")) {
@@ -16,6 +16,8 @@
 #include <utility>
 #include <vector>
 
+extern std::string hotwords;
+
 context_ptr WebSocketServer::on_tls_init(tls_mode mode,
                                          websocketpp::connection_hdl hdl,
                                          std::string& s_certfile,
@@ -254,7 +256,15 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
   unique_lock guard_decoder(*(thread_lock_p));  // mutex for one connection
   switch (msg->get_opcode()) {
     case websocketpp::frame::opcode::text: {
-      nlohmann::json jsonresult = nlohmann::json::parse(payload);
+      nlohmann::json jsonresult;
+      try {
+        jsonresult = nlohmann::json::parse(payload);
+      } catch (std::exception const& e) {
+        LOG(ERROR) << e.what();
+        break;
+      }
 
       if (jsonresult["wav_name"] != nullptr) {
         msg_data->msg["wav_name"] = jsonresult["wav_name"];
       }
@@ -266,17 +276,26 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
       msg_data->msg["hotwords"] = jsonresult["hotwords"];
       if (!msg_data->msg["hotwords"].empty()) {
        std::string hw = msg_data->msg["hotwords"];
-        LOG(INFO) << "hotwords: " << hw;
-        std::vector<std::vector<float>> new_hotwords_embedding = CompileHotwordEmbedding(asr_hanlde, hw);
+        hw = hw + " " + hotwords;
+        LOG(INFO) << "hotwords: " << hw;
+        std::vector<std::vector<float>> new_hotwords_embedding = CompileHotwordEmbedding(asr_hanlde, hw);
         msg_data->hotwords_embedding =
            std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
       }
-      }else{
+      } else {
+        if (hotwords.empty()) {
+          std::string hw = "";
+          LOG(INFO) << "hotwords: " << hw;
+          std::vector<std::vector<float>> new_hotwords_embedding = CompileHotwordEmbedding(asr_hanlde, hw);
+          msg_data->hotwords_embedding =
+              std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
+        } else {
+          std::string hw = hotwords;
+          LOG(INFO) << "hotwords: " << hw;
+          std::vector<std::vector<float>> new_hotwords_embedding = CompileHotwordEmbedding(asr_hanlde, hw);
+          msg_data->hotwords_embedding =
+              std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
+        }
      }
   }
   if (jsonresult.contains("audio_fs")) {
funasr/runtime/websocket/hotwords.txt (new file, 2 lines)
@@ -0,0 +1,2 @@
阿里巴巴
通义实验室
funasr/utils/whisper_utils/__init__.py (new file, empty)