diff --git a/funasr/build_utils/build_asr_model.py b/funasr/build_utils/build_asr_model.py
index 5e9344469..fd47bd336 100644
--- a/funasr/build_utils/build_asr_model.py
+++ b/funasr/build_utils/build_asr_model.py
@@ -42,6 +42,7 @@ from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
 from funasr.models.encoder.branchformer_encoder import BranchformerEncoder
 from funasr.models.encoder.e_branchformer_encoder import EBranchformerEncoder
 from funasr.models.encoder.transformer_encoder import TransformerEncoder
+from funasr.models.encoder.rwkv_encoder import RWKVEncoder
 from funasr.models.frontend.default import DefaultFrontend
 from funasr.models.frontend.default import MultiChannelFrontend
 from funasr.models.frontend.fused import FusedFrontends
@@ -119,6 +120,7 @@ encoder_choices = ClassChoices(
         e_branchformer=EBranchformerEncoder,
         mfcca_enc=MFCCAEncoder,
         chunk_conformer=ConformerChunkEncoder,
+        rwkv=RWKVEncoder,
     ),
     default="rnn",
 )
diff --git a/funasr/models/encoder/rwkv_encoder.py b/funasr/models/encoder/rwkv_encoder.py
new file mode 100644
index 000000000..8a33520e9
--- /dev/null
+++ b/funasr/models/encoder/rwkv_encoder.py
@@ -0,0 +1,155 @@
+"""RWKV encoder definition for Transducer models."""
+
+import math
+from typing import Dict, List, Optional, Tuple
+
+import torch
+
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.modules.rwkv import RWKV
+from funasr.modules.layer_norm import LayerNorm
+from funasr.modules.rwkv_subsampling import RWKVConvInput
+from funasr.modules.nets_utils import make_source_mask
+
+class RWKVEncoder(AbsEncoder):
+    """RWKV encoder module.
+
+    Based on https://arxiv.org/pdf/2305.13048.pdf.
+
+    Args:
+        input_size: Input feature size.
+        output_size: Input/Output size.
+        context_size: Context size for WKV computation.
+        linear_size: FeedForward hidden size.
+        attention_size: SelfAttention hidden size.
+        num_blocks: Number of RWKV blocks.
+        att_dropout_rate: Dropout rate for the attention module.
+        ffn_dropout_rate: Dropout rate for the feed-forward module.
+        dropout_rate: Dropout rate for the residual branches.
+        subsampling_factor: Subsampling factor of the convolution frontend.
+        time_reduction_factor: Additional frame-rate reduction applied to the output.
+        kernel: Convolution kernel size of the frontend.
+ """ + + def __init__( + self, + input_size: int, + output_size: int = 512, + context_size: int = 1024, + linear_size: Optional[int] = None, + attention_size: Optional[int] = None, + num_blocks: int = 4, + att_dropout_rate: float = 0.0, + ffn_dropout_rate: float = 0.0, + dropout_rate: float = 0.0, + subsampling_factor: int =4, + time_reduction_factor: int = 1, + kernel: int = 3, + ) -> None: + """Construct a RWKVEncoder object.""" + super().__init__() + + self.embed = RWKVConvInput( + input_size, + [output_size//4, output_size//2, output_size], + subsampling_factor, + conv_kernel_size=kernel, + output_size=output_size, + ) + + self.subsampling_factor = subsampling_factor + + linear_size = output_size * 4 if linear_size is None else linear_size + attention_size = output_size if attention_size is None else attention_size + + self.rwkv_blocks = torch.nn.ModuleList( + [ + RWKV( + output_size, + linear_size, + attention_size, + context_size, + block_id, + num_blocks, + att_dropout_rate=att_dropout_rate, + ffn_dropout_rate=ffn_dropout_rate, + dropout_rate=dropout_rate, + ) + for block_id in range(num_blocks) + ] + ) + + self.embed_norm = LayerNorm(output_size) + self.final_norm = LayerNorm(output_size) + + self._output_size = output_size + self.context_size = context_size + + self.num_blocks = num_blocks + self.time_reduction_factor = time_reduction_factor + + def output_size(self) -> int: + return self._output_size + + def forward(self, x: torch.Tensor, x_len) -> torch.Tensor: + """Encode source label sequences. + + Args: + x: Encoder input sequences. (B, L) + + Returns: + out: Encoder output sequences. (B, U, D) + + """ + _, length, _ = x.size() + + assert ( + length <= self.context_size * self.subsampling_factor + ), "Context size is too short for current length: %d versus %d" % ( + length, + self.context_size * self.subsampling_factor, + ) + mask = make_source_mask(x_len).to(x.device) + x, mask = self.embed(x, mask, None) + x = self.embed_norm(x) + olens = mask.eq(0).sum(1) + + for block in self.rwkv_blocks: + x, _ = block(x) + # for streaming inference + # xs_pad = self.rwkv_infer(xs_pad) + + x = self.final_norm(x) + + if self.time_reduction_factor > 1: + x = x[:,::self.time_reduction_factor,:] + olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1 + + return x, olens, None + + def rwkv_infer(self, xs_pad): + + batch_size = xs_pad.shape[0] + + hidden_sizes = [ + self._output_size for i in range(5) + ] + + state = [ + torch.zeros( + (batch_size, 1, hidden_sizes[i], self.num_rwkv_blocks), + dtype=torch.float32, + device=self.device, + ) + for i in range(5) + ] + + state[4] -= 1e-30 + + xs_out = [] + for t in range(xs_pad.shape[1]): + x_t = xs_pad[:,t,:] + for idx, block in enumerate(self.rwkv_blocks): + x_t, state = block(x_t, state=state) + xs_out.append(x_t) + xs_out = torch.stack(xs_out, dim=1) + return xs_out diff --git a/funasr/models/whisper_models/__init__.py b/funasr/models/whisper_models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/funasr/modules/rwkv.py b/funasr/modules/rwkv.py new file mode 100644 index 000000000..f020828f7 --- /dev/null +++ b/funasr/modules/rwkv.py @@ -0,0 +1,145 @@ +"""Receptance Weighted Key Value (RWKV) block definition. 
diff --git a/funasr/models/whisper_models/__init__.py b/funasr/models/whisper_models/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/funasr/modules/rwkv.py b/funasr/modules/rwkv.py
new file mode 100644
index 000000000..f020828f7
--- /dev/null
+++ b/funasr/modules/rwkv.py
@@ -0,0 +1,145 @@
+"""Receptance Weighted Key Value (RWKV) block definition.
+
+Based/modified from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/src/model.py
+
+"""
+
+from typing import Dict, Optional, Tuple
+
+import torch
+
+from funasr.modules.rwkv_attention import EncoderSelfAttention, DecoderSelfAttention
+from funasr.modules.rwkv_feed_forward import FeedForward
+from funasr.modules.layer_norm import LayerNorm
+
+class RWKV(torch.nn.Module):
+    """RWKV module.
+
+    Args:
+        size: Input/Output size.
+        linear_size: Feed-forward hidden size.
+        attention_size: SelfAttention hidden size.
+        context_size: Context size for WKV computation.
+        block_id: Block index.
+        num_blocks: Number of blocks in the architecture.
+        att_dropout_rate: Dropout rate for the attention module.
+        ffn_dropout_rate: Dropout rate for the feed-forward module.
+        dropout_rate: Dropout rate for the residual branches.
+
+    """
+
+    def __init__(
+        self,
+        size: int,
+        linear_size: int,
+        attention_size: int,
+        context_size: int,
+        block_id: int,
+        num_blocks: int,
+        att_dropout_rate: float = 0.0,
+        ffn_dropout_rate: float = 0.0,
+        dropout_rate: float = 0.0,
+    ) -> None:
+        """Construct a RWKV object."""
+        super().__init__()
+
+        self.layer_norm_att = LayerNorm(size)
+        self.layer_norm_ffn = LayerNorm(size)
+
+        self.att = EncoderSelfAttention(
+            size, attention_size, context_size, block_id, att_dropout_rate, num_blocks
+        )
+        self.dropout_att = torch.nn.Dropout(p=dropout_rate)
+
+        self.ffn = FeedForward(size, linear_size, block_id, ffn_dropout_rate, num_blocks)
+        self.dropout_ffn = torch.nn.Dropout(p=dropout_rate)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        state: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Compute receptance weighted key value.
+
+        Args:
+            x: RWKV input sequences. (B, L, size)
+            state: Decoder hidden states. [5 x (B, D_att/size, N)]
+
+        Returns:
+            x: RWKV output sequences. (B, L, size)
+            state: Decoder hidden states. [5 x (B, D_att/size, N)]
+
+        """
+        att, state = self.att(self.layer_norm_att(x), state=state)
+        x = x + self.dropout_att(att)
+        ffn, state = self.ffn(self.layer_norm_ffn(x), state=state)
+        x = x + self.dropout_ffn(ffn)
+        return x, state
+
+class RWKVDecoderLayer(torch.nn.Module):
+    """RWKV decoder layer module.
+
+    Args:
+        size: Input/Output size.
+        linear_size: Feed-forward hidden size.
+        attention_size: SelfAttention hidden size.
+        context_size: Context size for WKV computation.
+        block_id: Block index.
+        num_blocks: Number of blocks in the architecture.
+        att_dropout_rate: Dropout rate for the attention module.
+        ffn_dropout_rate: Dropout rate for the feed-forward module.
+        dropout_rate: Dropout rate for the residual branches.
+
+    """
+
+    def __init__(
+        self,
+        size: int,
+        linear_size: int,
+        attention_size: int,
+        context_size: int,
+        block_id: int,
+        num_blocks: int,
+        att_dropout_rate: float = 0.0,
+        ffn_dropout_rate: float = 0.0,
+        dropout_rate: float = 0.0,
+    ) -> None:
+        """Construct a RWKVDecoderLayer object."""
+        super().__init__()
+
+        self.layer_norm_att = LayerNorm(size)
+        self.layer_norm_ffn = LayerNorm(size)
+
+        self.att = DecoderSelfAttention(
+            size, attention_size, context_size, block_id, att_dropout_rate, num_blocks
+        )
+        self.dropout_att = torch.nn.Dropout(p=dropout_rate)
+
+        self.ffn = FeedForward(size, linear_size, block_id, ffn_dropout_rate, num_blocks)
+        self.dropout_ffn = torch.nn.Dropout(p=dropout_rate)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        state: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Compute receptance weighted key value.
+
+        Args:
+            x: RWKV input sequences. (B, L, size)
+            state: Decoder hidden states. [5 x (B, D_att/size, N)]
+
+        Returns:
+            x: RWKV output sequences. (B, L, size)
+            state: Decoder hidden states. [5 x (B, D_att/size, N)]
+
+        """
+        att, state = self.att(self.layer_norm_att(x), state=state)
+        x = x + self.dropout_att(att)
+
+        ffn, state = self.ffn(self.layer_norm_ffn(x), state=state)
+        x = x + self.dropout_ffn(ffn)
+
+        return x, state
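+
+
+# Streaming state layout used by the modules above: `state` is a list of five
+# tensors with one slot per block along the last axis:
+#     state[0]: previous input of the feed-forward (channel mixing) module
+#     state[1]: previous input of the attention (time mixing) module
+#     state[2]: WKV numerator accumulator
+#     state[3]: WKV denominator accumulator
+#     state[4]: WKV max-tracking state (initialised to a very negative value)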
+ + """ + + def __init__( + self, + size: int, + linear_size: int, + attention_size: int, + context_size: int, + block_id: int, + num_blocks: int, + att_dropout_rate: float = 0.0, + ffn_dropout_rate: float = 0.0, + dropout_rate: float = 0.0, + ) -> None: + """Construct a RWKV object.""" + super().__init__() + + self.layer_norm_att = LayerNorm(size) + self.layer_norm_ffn = LayerNorm(size) + + self.att = DecoderSelfAttention( + size, attention_size, context_size, block_id, att_dropout_rate, num_blocks + ) + self.dropout_att = torch.nn.Dropout(p=dropout_rate) + + self.ffn = FeedForward(size, linear_size, block_id, ffn_dropout_rate, num_blocks) + self.dropout_ffn = torch.nn.Dropout(p=dropout_rate) + + def forward( + self, + x: torch.Tensor, + state: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Compute receptance weighted key value. + + Args: + x: RWKV input sequences. (B, L, size) + state: Decoder hidden states. [5 x (B, D_att/size, N)] + + Returns: + x: RWKV output sequences. (B, L, size) + x: Decoder hidden states. [5 x (B, D_att/size, N)] + + """ + att, state = self.att(self.layer_norm_att(x), state=state) + x = x + self.dropout_att(att) + + ffn, state = self.ffn(self.layer_norm_ffn(x), state=state) + x = x + self.dropout_ffn(ffn) + + return x, state diff --git a/funasr/modules/rwkv_attention.py b/funasr/modules/rwkv_attention.py new file mode 100644 index 000000000..f0c7da39e --- /dev/null +++ b/funasr/modules/rwkv_attention.py @@ -0,0 +1,632 @@ +"""Attention (time mixing) modules for RWKV block. + +Based/Modified from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/src/model.py. + +Some variables are renamed according to https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py. + +""" # noqa + +import math +from importlib.util import find_spec +from pathlib import Path +from typing import List, Optional, Tuple, Union + +import torch + +wkv_kernel_encoder = None +wkv_kernel_decoder = None + +class WKVLinearAttentionEncoder(torch.autograd.Function): + """WKVLinearAttention function definition.""" + + @staticmethod + def forward( + ctx, + time_decay: torch.Tensor, + time_first: torch.Tensor, + key: torch.Tensor, + value: torch.tensor, + ) -> torch.Tensor: + """WKVLinearAttention function forward pass. + + Args: + time_decay: Channel-wise time decay vector. (D_att) + time_first: Channel-wise time first vector. (D_att) + key: Key tensor. (B, U, D_att) + value: Value tensor. (B, U, D_att) + + Returns: + out: Weighted Key-Value tensor. (B, U, D_att) + + """ + batch, length, dim = key.size() + + assert length <= wkv_kernel_encoder.context_size, ( + f"Cannot process key of length {length} while context_size " + f"is ({wkv_kernel_encoder.context_size}). Limit should be increased." 
+        )
+
+        assert batch * dim % min(dim, 32) == 0, (
+            f"batch size ({batch}) by dimension ({dim}) should be a multiple of "
+            f"{min(dim, 32)}"
+        )
+
+        ctx.input_dtype = key.dtype
+
+        time_decay = -torch.exp(time_decay.float().contiguous())
+        time_first = time_first.float().contiguous()
+
+        key = key.float().contiguous()
+        value = value.float().contiguous()
+
+        out = torch.empty_like(key, memory_format=torch.contiguous_format)
+
+        wkv_kernel_encoder.forward(time_decay, time_first, key, value, out)
+        ctx.save_for_backward(time_decay, time_first, key, value, out)
+
+        return out
+
+    @staticmethod
+    def backward(
+        ctx, grad_output: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """WKVLinearAttention function backward pass.
+
+        Args:
+            grad_output: Output gradient. (B, U, D_att)
+
+        Returns:
+            grad_time_decay: Gradient for channel-wise time decay vector. (D_att)
+            grad_time_first: Gradient for channel-wise time first vector. (D_att)
+            grad_key: Gradient for key tensor. (B, U, D_att)
+            grad_value: Gradient for value tensor. (B, U, D_att)
+
+        """
+        time_decay, time_first, key, value, output = ctx.saved_tensors
+        grad_dtype = ctx.input_dtype
+
+        batch, _, dim = key.size()
+
+        grad_time_decay = torch.empty(
+            (batch, dim),
+            memory_format=torch.contiguous_format,
+            dtype=time_decay.dtype,
+            device=time_decay.device,
+        )
+
+        grad_time_first = torch.empty(
+            (batch, dim),
+            memory_format=torch.contiguous_format,
+            dtype=time_decay.dtype,
+            device=time_decay.device,
+        )
+
+        grad_key = torch.empty_like(key, memory_format=torch.contiguous_format)
+        grad_value = torch.empty_like(value, memory_format=torch.contiguous_format)
+
+        wkv_kernel_encoder.backward(
+            time_decay,
+            time_first,
+            key,
+            value,
+            output,
+            grad_output.contiguous(),
+            grad_time_decay,
+            grad_time_first,
+            grad_key,
+            grad_value,
+        )
+
+        grad_time_decay = torch.sum(grad_time_decay, dim=0)
+        grad_time_first = torch.sum(grad_time_first, dim=0)
+
+        return (
+            grad_time_decay,
+            grad_time_first,
+            grad_key,
+            grad_value,
+        )
+
+class WKVLinearAttentionDecoder(torch.autograd.Function):
+    """WKVLinearAttention function definition."""
+
+    @staticmethod
+    def forward(
+        ctx,
+        time_decay: torch.Tensor,
+        time_first: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+    ) -> torch.Tensor:
+        """WKVLinearAttention function forward pass.
+
+        Args:
+            time_decay: Channel-wise time decay vector. (D_att)
+            time_first: Channel-wise time first vector. (D_att)
+            key: Key tensor. (B, U, D_att)
+            value: Value tensor. (B, U, D_att)
+
+        Returns:
+            out: Weighted Key-Value tensor. (B, U, D_att)
+
+        """
+        batch, length, dim = key.size()
+
+        assert length <= wkv_kernel_decoder.context_size, (
+            f"Cannot process key of length {length} while context_size "
+            f"is ({wkv_kernel_decoder.context_size}). Limit should be increased."
+ ) + + assert batch * dim % min(dim, 32) == 0, ( + f"batch size ({batch}) by dimension ({dim}) should be a multiple of " + f"{min(dim, 32)}" + ) + + ctx.input_dtype = key.dtype + + time_decay = -torch.exp(time_decay.float().contiguous()) + time_first = time_first.float().contiguous() + + key = key.float().contiguous() + value = value.float().contiguous() + + out = torch.empty_like(key, memory_format=torch.contiguous_format) + + wkv_kernel_decoder.forward(time_decay, time_first, key, value, out) + ctx.save_for_backward(time_decay, time_first, key, value, out) + + return out + + @staticmethod + def backward( + ctx, grad_output: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """WKVLinearAttention function backward pass. + + Args: + grad_output: Output gradient. (B, U, D_att) + + Returns: + grad_time_decay: Gradient for channel-wise time decay vector. (D_att) + grad_time_first: Gradient for channel-wise time first vector. (D_att) + grad_key: Gradient for key tensor. (B, U, D_att) + grad_value: Gradient for value tensor. (B, U, D_att) + + """ + time_decay, time_first, key, value, output = ctx.saved_tensors + grad_dtype = ctx.input_dtype + + batch, _, dim = key.size() + + grad_time_decay = torch.empty( + (batch, dim), + memory_format=torch.contiguous_format, + dtype=time_decay.dtype, + device=time_decay.device, + ) + + grad_time_first = torch.empty( + (batch, dim), + memory_format=torch.contiguous_format, + dtype=time_decay.dtype, + device=time_decay.device, + ) + + grad_key = torch.empty_like(key, memory_format=torch.contiguous_format) + grad_value = torch.empty_like(value, memory_format=torch.contiguous_format) + + wkv_kernel_decoder.backward( + time_decay, + time_first, + key, + value, + output, + grad_output.contiguous(), + grad_time_decay, + grad_time_first, + grad_key, + grad_value, + ) + + grad_time_decay = torch.sum(grad_time_decay, dim=0) + grad_time_first = torch.sum(grad_time_first, dim=0) + + return ( + grad_time_decay, + grad_time_first, + grad_key, + grad_value, + ) + +def load_encoder_wkv_kernel(context_size: int) -> None: + """Load WKV CUDA kernel. + + Args: + context_size: Context size. + + """ + from torch.utils.cpp_extension import load + + global wkv_kernel_encoder + + if wkv_kernel_encoder is not None and wkv_kernel_encoder.context_size == context_size: + return + + if find_spec("ninja") is None: + raise ImportError( + "Ninja package was not found. WKV kernel module can't be loaded " + "for training. Please, 'pip install ninja' in your environment." + ) + + if not torch.cuda.is_available(): + raise ImportError( + "CUDA is currently a requirement for WKV kernel loading. " + "Please set your devices properly and launch again." + ) + + kernel_folder = Path(__file__).resolve().parent / "cuda_encoder" + kernel_files = [kernel_folder / f for f in ["wkv_op.cpp", "wkv_cuda.cu"]] + + kernel_cflags = [ + "-res-usage", + "--maxrregcount 60", + "--use_fast_math", + "-O3", + "-Xptxas -O3", + f"-DTmax={context_size}", + ] + wkv_kernel_encoder = load( + name=f"encoder_wkv_{context_size}", + sources=kernel_files, + verbose=True, + extra_cuda_cflags=kernel_cflags, + ) + wkv_kernel_encoder.context_size = context_size + +def load_decoder_wkv_kernel(context_size: int) -> None: + """Load WKV CUDA kernel. + + Args: + context_size: Context size. 
+ + """ + from torch.utils.cpp_extension import load + + global wkv_kernel_decoder + + if wkv_kernel_decoder is not None and wkv_kernel_decoder.context_size == context_size: + return + + if find_spec("ninja") is None: + raise ImportError( + "Ninja package was not found. WKV kernel module can't be loaded " + "for training. Please, 'pip install ninja' in your environment." + ) + + if not torch.cuda.is_available(): + raise ImportError( + "CUDA is currently a requirement for WKV kernel loading. " + "Please set your devices properly and launch again." + ) + + kernel_folder = Path(__file__).resolve().parent / "cuda_decoder" + kernel_files = [kernel_folder / f for f in ["wkv_op.cpp", "wkv_cuda.cu"]] + + kernel_cflags = [ + "-res-usage", + "--maxrregcount 60", + "--use_fast_math", + "-O3", + "-Xptxas -O3", + f"-DTmax={context_size}", + ] + wkv_kernel_decoder = load( + name=f"decoder_wkv_{context_size}", + sources=kernel_files, + verbose=True, + extra_cuda_cflags=kernel_cflags, + ) + wkv_kernel_decoder.context_size = context_size + +class SelfAttention(torch.nn.Module): + """SelfAttention module definition. + + Args: + size: Input/Output size. + attention_size: Attention hidden size. + context_size: Context size for WKV kernel. + block_id: Block index. + num_blocks: Number of blocks in the architecture. + + """ + + def __init__( + self, + size: int, + attention_size: int, + block_id: int, + dropout_rate: float, + num_blocks: int, + ) -> None: + """Construct a SelfAttention object.""" + super().__init__() + self.time_shift = torch.nn.ZeroPad2d((0, 0, 1, -1)) + + self.time_decay = torch.nn.Parameter(torch.empty(attention_size)) + self.time_first = torch.nn.Parameter(torch.empty(attention_size)) + + self.time_mix_key = torch.nn.Parameter(torch.empty(1, 1, size)) + self.time_mix_value = torch.nn.Parameter(torch.empty(1, 1, size)) + self.time_mix_receptance = torch.nn.Parameter(torch.empty(1, 1, size)) + + self.proj_key = torch.nn.Linear(size, attention_size, bias=True) + self.proj_value = torch.nn.Linear(size, attention_size, bias=True) + self.proj_receptance = torch.nn.Linear(size, attention_size, bias=True) + + self.proj_output = torch.nn.Linear(attention_size, size, bias=True) + + self.block_id = block_id + + self.reset_parameters(size, attention_size, block_id, num_blocks) + self.dropout = torch.nn.Dropout(p=dropout_rate) + + def reset_parameters( + self, size: int, attention_size: int, block_id: int, num_blocks: int + ) -> None: + """Reset module parameters. + + Args: + size: Block size. + attention_size: Attention hidden size. + block_id: Block index. + num_blocks: Number of blocks in the architecture. 
+ + """ + ratio_0_to_1 = block_id / (num_blocks - 1) + ratio_1_to_almost0 = 1.0 - (block_id / num_blocks) + + time_weight = torch.ones(1, 1, size) + + for i in range(size): + time_weight[0, 0, i] = i / size + + decay_speed = [ + -5 + 8 * (h / (attention_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1) + for h in range(attention_size) + ] + decay_speed = torch.tensor( + decay_speed, dtype=self.time_decay.dtype, device=self.time_decay.device + ) + + zigzag = ( + torch.tensor( + [(i + 1) % 3 - 1 for i in range(attention_size)], + dtype=self.time_first.dtype, + device=self.time_first.device, + ) + * 0.5 + ) + + with torch.no_grad(): + self.time_decay.data = decay_speed + self.time_first.data = torch.ones_like( + self.time_first * math.log(0.3) + zigzag + ) + + self.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0) + self.time_mix_value.data = ( + torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1 + ) + self.time_mix_receptance.data = torch.pow( + time_weight, 0.5 * ratio_1_to_almost0 + ) + + @torch.no_grad() + def wkv_linear_attention( + self, + time_decay: torch.Tensor, + time_first: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + state: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: + """Compute WKV with state (i.e.: for inference). + + Args: + time_decay: Channel-wise time decay vector. (D_att) + time_first: Channel-wise time first vector. (D_att) + key: Key tensor. (B, 1, D_att) + value: Value tensor. (B, 1, D_att) + state: Decoder hidden states. [3 x (B, D_att)] + + Returns: + output: Weighted Key-Value. (B, 1, D_att) + state: Decoder hidden states. [3 x (B, 1, D_att)] + + """ + num_state, den_state, max_state = state + + max_for_output = torch.maximum(max_state, (time_first + key)) + + e1 = torch.exp(max_state - max_for_output) + e2 = torch.exp((time_first + key) - max_for_output) + + numerator = e1 * num_state + e2 * value + denominator = e1 * den_state + e2 + + max_for_state = torch.maximum(key, (max_state + time_decay)) + + e1 = torch.exp((max_state + time_decay) - max_for_state) + e2 = torch.exp(key - max_for_state) + + wkv = numerator / denominator + + state = [e1 * num_state + e2 * value, e1 * den_state + e2, max_for_state] + + return wkv, state + + +class DecoderSelfAttention(SelfAttention): + """SelfAttention module definition. + + Args: + size: Input/Output size. + attention_size: Attention hidden size. + context_size: Context size for WKV kernel. + block_id: Block index. + num_blocks: Number of blocks in the architecture. + + """ + + def __init__( + self, + size: int, + attention_size: int, + context_size: int, + block_id: int, + dropout_rate: float, + num_blocks: int, + ) -> None: + """Construct a SelfAttention object.""" + super().__init__( + size, + attention_size, + block_id, + dropout_rate, + num_blocks + ) + load_decoder_wkv_kernel(context_size) + + def forward( + self, + x: torch.Tensor, + state: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]: + """Compute time mixing. + + Args: + x: SelfAttention input sequences. (B, U, size) + state: Decoder hidden states. [5 x (B, 1, D_att, N)] + + Returns: + x: SelfAttention output sequences. 
+
+
+class DecoderSelfAttention(SelfAttention):
+    """DecoderSelfAttention module definition.
+
+    Args:
+        size: Input/Output size.
+        attention_size: Attention hidden size.
+        context_size: Context size for WKV kernel.
+        block_id: Block index.
+        dropout_rate: Dropout rate for the WKV output.
+        num_blocks: Number of blocks in the architecture.
+
+    """
+
+    def __init__(
+        self,
+        size: int,
+        attention_size: int,
+        context_size: int,
+        block_id: int,
+        dropout_rate: float,
+        num_blocks: int,
+    ) -> None:
+        """Construct a DecoderSelfAttention object."""
+        super().__init__(
+            size,
+            attention_size,
+            block_id,
+            dropout_rate,
+            num_blocks
+        )
+        load_decoder_wkv_kernel(context_size)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        state: Optional[List[torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
+        """Compute time mixing.
+
+        Args:
+            x: SelfAttention input sequences. (B, U, size)
+            state: Decoder hidden states. [5 x (B, 1, D_att, N)]
+
+        Returns:
+            x: SelfAttention output sequences. (B, U, size)
+            state: Decoder hidden states. [5 x (B, 1, D_att, N)]
+
+        """
+        shifted_x = (
+            self.time_shift(x) if state is None else state[1][..., self.block_id]
+        )
+
+        key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key)
+        value = x * self.time_mix_value + shifted_x * (1 - self.time_mix_value)
+        receptance = x * self.time_mix_receptance + shifted_x * (
+            1 - self.time_mix_receptance
+        )
+
+        key = self.proj_key(key)
+        value = self.proj_value(value)
+        receptance = torch.sigmoid(self.proj_receptance(receptance))
+
+        if state is not None:
+            state[1][..., self.block_id] = x
+
+            wkv, att_state = self.wkv_linear_attention(
+                self.time_decay,
+                self.time_first,
+                key,
+                value,
+                tuple(s[..., self.block_id] for s in state[2:]),
+            )
+
+            state[2][..., self.block_id] = att_state[0]
+            state[3][..., self.block_id] = att_state[1]
+            state[4][..., self.block_id] = att_state[2]
+        else:
+            wkv = WKVLinearAttentionDecoder.apply(self.time_decay, self.time_first, key, value)
+
+        wkv = self.dropout(wkv)
+        x = self.proj_output(receptance * wkv)
+
+        return x, state
+
+class EncoderSelfAttention(SelfAttention):
+    """EncoderSelfAttention module definition.
+
+    Args:
+        size: Input/Output size.
+        attention_size: Attention hidden size.
+        context_size: Context size for WKV kernel.
+        block_id: Block index.
+        dropout_rate: Dropout rate for the WKV output.
+        num_blocks: Number of blocks in the architecture.
+
+    """
+
+    def __init__(
+        self,
+        size: int,
+        attention_size: int,
+        context_size: int,
+        block_id: int,
+        dropout_rate: float,
+        num_blocks: int,
+    ) -> None:
+        """Construct an EncoderSelfAttention object."""
+        super().__init__(
+            size,
+            attention_size,
+            block_id,
+            dropout_rate,
+            num_blocks
+        )
+        load_encoder_wkv_kernel(context_size)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        state: Optional[List[torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
+        """Compute time mixing.
+
+        Args:
+            x: SelfAttention input sequences. (B, U, size)
+            state: Decoder hidden states. [5 x (B, 1, D_att, N)]
+
+        Returns:
+            x: SelfAttention output sequences. (B, U, size)
+            state: Decoder hidden states. [5 x (B, 1, D_att, N)]
+
+        """
+        shifted_x = (
+            self.time_shift(x) if state is None else state[1][..., self.block_id]
+        )
+
+        key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key)
+        value = x * self.time_mix_value + shifted_x * (1 - self.time_mix_value)
+        receptance = x * self.time_mix_receptance + shifted_x * (
+            1 - self.time_mix_receptance
+        )
+
+        key = self.proj_key(key)
+        value = self.proj_value(value)
+        receptance = torch.sigmoid(self.proj_receptance(receptance))
+
+        if state is not None:
+            state[1][..., self.block_id] = x
+
+            wkv, att_state = self.wkv_linear_attention(
+                self.time_decay,
+                self.time_first,
+                key,
+                value,
+                tuple(s[..., self.block_id] for s in state[2:]),
+            )
+
+            state[2][..., self.block_id] = att_state[0]
+            state[3][..., self.block_id] = att_state[1]
+            state[4][..., self.block_id] = att_state[2]
+        else:
+            wkv = WKVLinearAttentionEncoder.apply(self.time_decay, self.time_first, key, value)
+
+        wkv = self.dropout(wkv)
+        x = self.proj_output(receptance * wkv)
+
+        return x, state
diff --git a/funasr/modules/rwkv_feed_forward.py b/funasr/modules/rwkv_feed_forward.py
new file mode 100644
index 000000000..ddb42859e
--- /dev/null
+++ b/funasr/modules/rwkv_feed_forward.py
@@ -0,0 +1,97 @@
+"""Feed-forward (channel mixing) module for RWKV block.
+
+Based/Modified from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/src/model.py
+
+Some variables are renamed according to https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py.
+ +""" # noqa + +from typing import List, Optional, Tuple + +import torch + + +class FeedForward(torch.nn.Module): + """FeedForward module definition. + + Args: + size: Input/Output size. + hidden_size: Hidden size. + block_id: Block index. + num_blocks: Number of blocks in the architecture. + + """ + + def __init__( + self, size: int, hidden_size: int, block_id: int, dropout_rate: float, num_blocks: int + ) -> None: + """Construct a FeedForward object.""" + super().__init__() + + self.time_shift = torch.nn.ZeroPad2d((0, 0, 1, -1)) + + self.time_mix_key = torch.nn.Parameter(torch.empty(1, 1, size)) + self.time_mix_receptance = torch.nn.Parameter(torch.empty(1, 1, size)) + + self.proj_key = torch.nn.Linear(size, hidden_size, bias=True) + self.proj_value = torch.nn.Linear(hidden_size, size, bias=True) + self.proj_receptance = torch.nn.Linear(size, size, bias=True) + + self.block_id = block_id + + self.reset_parameters(size, block_id, num_blocks) + self.dropout = torch.nn.Dropout(p=dropout_rate) + + def reset_parameters(self, size: int, block_id: int, num_blocks: int) -> None: + """Reset module parameters. + + Args: + size: Block size. + block_id: Block index. + num_blocks: Number of blocks in the architecture. + + """ + ratio_1_to_almost0 = 1.0 - (block_id / num_blocks) + + time_weight = torch.ones(1, 1, size) + + for i in range(size): + time_weight[0, 0, i] = i / size + + with torch.no_grad(): + self.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0) + self.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0) + + def forward( + self, x: torch.Tensor, state: Optional[List[torch.Tensor]] = None + ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]: + """Compute channel mixing. + + Args: + x: FeedForward input sequences. (B, U, size) + state: Decoder hidden state. [5 x (B, 1, size, N)] + + Returns: + x: FeedForward output sequences. (B, U, size) + state: Decoder hidden state. [5 x (B, 1, size, N)] + + """ + shifted_x = ( + self.time_shift(x) if state is None else state[0][..., self.block_id] + ) + + key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key) + receptance = x * self.time_mix_receptance + shifted_x * ( + 1 - self.time_mix_receptance + ) + + key = torch.square(torch.relu(self.proj_key(key))) + value = self.proj_value(self.dropout(key)) + receptance = torch.sigmoid(self.proj_receptance(receptance)) + + if state is not None: + state[0][..., self.block_id] = x + + x = receptance * value + + return x, state diff --git a/funasr/modules/rwkv_subsampling.py b/funasr/modules/rwkv_subsampling.py new file mode 100644 index 000000000..427709312 --- /dev/null +++ b/funasr/modules/rwkv_subsampling.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Subsampling layer definition.""" +import numpy as np +import torch +import torch.nn.functional as F +from funasr.modules.embedding import PositionalEncoding +import logging +from funasr.modules.streaming_utils.utils import sequence_mask +from funasr.modules.nets_utils import sub_factor_to_params, pad_to_len +from typing import Optional, Tuple, Union +import math + +class TooShortUttError(Exception): + """Raised when the utt is too short for subsampling. 
+
+    Args:
+        message (str): Message for error catch
+        actual_size (int): the short size that cannot pass the subsampling
+        limit (int): the limit size for subsampling
+
+    """
+
+    def __init__(self, message, actual_size, limit):
+        """Construct a TooShortUttError for error handler."""
+        super().__init__(message)
+        self.actual_size = actual_size
+        self.limit = limit
+
+
+def check_short_utt(ins, size):
+    """Check if the utterance is too short for subsampling."""
+    # imported locally: these classes are defined outside this module
+    from funasr.modules.subsampling import (
+        Conv2dSubsampling,
+        Conv2dSubsampling2,
+        Conv2dSubsampling6,
+        Conv2dSubsampling8,
+    )
+
+    if isinstance(ins, Conv2dSubsampling2) and size < 3:
+        return True, 3
+    if isinstance(ins, Conv2dSubsampling) and size < 7:
+        return True, 7
+    if isinstance(ins, Conv2dSubsampling6) and size < 11:
+        return True, 11
+    if isinstance(ins, Conv2dSubsampling8) and size < 15:
+        return True, 15
+    return False, -1
+
+
+class RWKVConvInput(torch.nn.Module):
+    """Streaming ConvInput module definition.
+
+    Args:
+        input_size: Input size.
+        conv_size: Convolution size.
+        subsampling_factor: Subsampling factor.
+        conv_kernel_size: Size of the convolution kernels.
+        output_size: Block output dimension.
+
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        conv_size: Union[int, Tuple],
+        subsampling_factor: int = 4,
+        conv_kernel_size: int = 3,
+        output_size: Optional[int] = None,
+    ) -> None:
+        """Construct a ConvInput object."""
+        super().__init__()
+        if subsampling_factor == 1:
+            conv_size1, conv_size2, conv_size3 = conv_size
+
+            self.conv = torch.nn.Sequential(
+                torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size - 1) // 2),
+                torch.nn.ReLU(),
+                torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=[1, 2], padding=(conv_kernel_size - 1) // 2),
+                torch.nn.ReLU(),
+                torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size - 1) // 2),
+                torch.nn.ReLU(),
+                torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=[1, 2], padding=(conv_kernel_size - 1) // 2),
+                torch.nn.ReLU(),
+                torch.nn.Conv2d(conv_size2, conv_size3, conv_kernel_size, stride=1, padding=(conv_kernel_size - 1) // 2),
+                torch.nn.ReLU(),
+                torch.nn.Conv2d(conv_size3, conv_size3, conv_kernel_size, stride=[1, 2], padding=(conv_kernel_size - 1) // 2),
+                torch.nn.ReLU(),
+            )
+
+            # frequency is halved by each of the three stride-[1, 2] convolutions
+            output_proj = conv_size3 * (((input_size // 2) // 2) // 2)
+
+            self.subsampling_factor = 1
+
+            self.stride_1 = 1
+
+            self.create_new_mask = self.create_new_vgg_mask
+
+        else:
+            conv_size1, conv_size2, conv_size3 = conv_size
+
+            kernel_1 = int(subsampling_factor / 2)
+
+            self.conv = torch.nn.Sequential(
+                torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size - 1) // 2),
+                torch.nn.ReLU(),
+                torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=[kernel_1, 2], padding=(conv_kernel_size - 1) // 2),
+                torch.nn.ReLU(),
+                torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size - 1) // 2),
+                torch.nn.ReLU(),
+                torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=[2, 2], padding=(conv_kernel_size - 1) // 2),
+                torch.nn.ReLU(),
+                torch.nn.Conv2d(conv_size2, conv_size3, conv_kernel_size, stride=1, padding=(conv_kernel_size - 1) // 2),
+                torch.nn.ReLU(),
+                torch.nn.Conv2d(conv_size3, conv_size3, conv_kernel_size, stride=1, padding=(conv_kernel_size - 1) // 2),
+                torch.nn.ReLU(),
+            )
+
+            output_proj = conv_size3 * ((input_size // 2) // 2)
+
+            self.subsampling_factor = subsampling_factor
+
+            self.create_new_mask = self.create_new_vgg_mask
+
+            self.stride_1 = kernel_1
+
+            self.min_frame_length = 7
+
+        if output_size is not None:
+            self.output = torch.nn.Linear(output_proj, output_size)
+            self.output_size = output_size
+        else:
+            self.output = None
+            self.output_size = output_proj
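+
+    # Shape sketch for the default configuration (subsampling_factor=4, so
+    # kernel_1=2): time is reduced 4x by the two stride-2 steps and frequency
+    # 4x, i.e. (B, T, D_feats) -> (B, T//4, conv_size3 * (D_feats // 4))
+    # before the optional output projection.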
+
+    def forward(
+        self, x: torch.Tensor, mask: Optional[torch.Tensor], chunk_size: Optional[int]
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encode input sequences.
+
+        Args:
+            x: ConvInput input sequences. (B, T, D_feats)
+            mask: Mask of input sequences. (B, 1, T)
+            chunk_size: Chunk size used for chunk-wise processing, if any.
+
+        Returns:
+            x: ConvInput output sequences. (B, sub(T), D_out)
+            mask: Mask of output sequences. (B, 1, sub(T))
+
+        """
+        if mask is not None:
+            mask = self.create_new_mask(mask)
+            olens = max(mask.eq(0).sum(1))
+
+        b, t, f = x.size()
+        x = x.unsqueeze(1)  # (B, 1, T, F)
+
+        if chunk_size is not None:
+            max_input_length = int(
+                chunk_size * self.subsampling_factor * (math.ceil(float(t) / (chunk_size * self.subsampling_factor)))
+            )
+            x = map(lambda inputs: pad_to_len(inputs, max_input_length, 1), x)
+            x = list(x)
+            x = torch.stack(x, dim=0)
+            N_chunks = max_input_length // (chunk_size * self.subsampling_factor)
+            x = x.view(b * N_chunks, 1, chunk_size * self.subsampling_factor, f)
+
+        x = self.conv(x)
+
+        _, c, _, f = x.size()
+        if chunk_size is not None:
+            x = x.transpose(1, 2).contiguous().view(b, -1, c * f)[:, :olens, :]
+        else:
+            x = x.transpose(1, 2).contiguous().view(b, -1, c * f)
+
+        if self.output is not None:
+            x = self.output(x)
+
+        return x, mask[:, :olens][:, :x.size(1)]
+
+    def create_new_vgg_mask(self, mask: torch.Tensor) -> torch.Tensor:
+        """Create a new mask for VGG output sequences.
+
+        Args:
+            mask: Mask of input sequences. (B, T)
+
+        Returns:
+            mask: Mask of output sequences. (B, sub(T))
+
+        """
+        if self.subsampling_factor > 1:
+            return mask[:, ::2][:, ::self.stride_1]
+        else:
+            return mask
+
+    def get_size_before_subsampling(self, size: int) -> int:
+        """Return the original size before subsampling for a given size.
+
+        Args:
+            size: Number of frames after subsampling.
+
+        Returns:
+            : Number of frames before subsampling.
+ """ + return size * self.subsampling_factor diff --git a/funasr/runtime/android/AndroidClient/app/src/main/java/com/yeyupiaoling/androidclient/MainActivity.java b/funasr/runtime/android/AndroidClient/app/src/main/java/com/yeyupiaoling/androidclient/MainActivity.java index be14bd3b3..f45877c01 100644 --- a/funasr/runtime/android/AndroidClient/app/src/main/java/com/yeyupiaoling/androidclient/MainActivity.java +++ b/funasr/runtime/android/AndroidClient/app/src/main/java/com/yeyupiaoling/androidclient/MainActivity.java @@ -39,6 +39,8 @@ public class MainActivity extends AppCompatActivity { public static final String TAG = MainActivity.class.getSimpleName(); // WebSocket地址 public String ASR_HOST = ""; + // 官方WebSocket地址 + public static final String DEFAULT_HOST = "wss://101.37.77.25:10088"; // 发送的JSON数据 public static final String MODE = "2pass"; public static final String CHUNK_SIZE = "5, 10, 5"; @@ -61,7 +63,6 @@ public class MainActivity extends AppCompatActivity { // 控件 private Button recordBtn; private TextView resultText; - private WebSocket webSocket; @SuppressLint("ClickableViewAccessibility") @Override @@ -106,8 +107,8 @@ public class MainActivity extends AppCompatActivity { ASR_HOST = uri; } // 读取热词 - String hotWords = sharedPreferences.getString("hotwords", ""); - if (!hotWords.equals("")) { + String hotWords = sharedPreferences.getString("hotwords", null); + if (hotWords != null) { this.hotWords = hotWords; } } @@ -150,6 +151,14 @@ public class MainActivity extends AppCompatActivity { editor.apply(); } }); + builder.setNeutralButton("使用官方服务", (dialog, id) -> { + ASR_HOST = DEFAULT_HOST; + input.setText(DEFAULT_HOST); + Toast.makeText(MainActivity.this, "WebSocket地址:" + ASR_HOST, Toast.LENGTH_SHORT).show(); + SharedPreferences.Editor editor = sharedPreferences.edit(); + editor.putString("uri", ASR_HOST); + editor.apply(); + }); AlertDialog dialog = builder.create(); dialog.show(); } @@ -166,12 +175,10 @@ public class MainActivity extends AppCompatActivity { builder.setView(view); builder.setPositiveButton("确定", (dialog, id) -> { String hotwords = input.getText().toString(); - if (!hotwords.equals("")) { - this.hotWords = hotwords; - SharedPreferences.Editor editor = sharedPreferences.edit(); - editor.putString("hotwords", hotwords); - editor.apply(); - } + this.hotWords = hotwords; + SharedPreferences.Editor editor = sharedPreferences.edit(); + editor.putString("hotwords", hotwords); + editor.apply(); }); AlertDialog dialog = builder.create(); dialog.show(); @@ -225,7 +232,7 @@ public class MainActivity extends AppCompatActivity { Request request = new Request.Builder() .url(ASR_HOST) .build(); - webSocket = client.newWebSocket(request, new WebSocketListener() { + WebSocket webSocket = client.newWebSocket(request, new WebSocketListener() { @Override public void onOpen(@NonNull WebSocket webSocket, @NonNull Response response) { @@ -311,7 +318,9 @@ public class MainActivity extends AppCompatActivity { obj.put("chunk_size", array); obj.put("chunk_interval", CHUNK_INTERVAL); obj.put("wav_name", "default"); - obj.put("hotwords", hotWords); + if (!hotWords.equals("")) { + obj.put("hotwords", hotWords); + } obj.put("wav_format", "pcm"); obj.put("is_speaking", isSpeaking); return obj.toString(); diff --git a/funasr/runtime/docs/SDK_advanced_guide_offline.md b/funasr/runtime/docs/SDK_advanced_guide_offline.md index 0348308db..43a69cd83 100644 --- a/funasr/runtime/docs/SDK_advanced_guide_offline.md +++ b/funasr/runtime/docs/SDK_advanced_guide_offline.md @@ -83,7 +83,8 @@ nohup bash 
run_server.sh \
  --io-thread-num 8 \
  --port 10095 \
  --certfile ../../../ssl_key/server.crt \
-  --keyfile ../../../ssl_key/server.key > log.out 2>&1 &
+  --keyfile ../../../ssl_key/server.key \
+  --hotword ../../hotwords.txt > log.out 2>&1 &
 ```
 
 Introduction to run_server.sh parameters:
@@ -102,6 +103,7 @@ Introduction to run_server.sh parameters:
 --io-thread-num: Number of IO threads that the server starts. Default is 1.
 --certfile : SSL certificate file. Default is ../../../ssl_key/server.crt. If you want to close ssl,set 0
 --keyfile : SSL key file. Default is ../../../ssl_key/server.key.
+--hotword : Hotword file path, one hotword per line. If the client also provides hotwords, the two sets are merged. Default is ../../hotwords.txt
 ```
 
 ### Shutting Down the FunASR Service
diff --git a/funasr/runtime/docs/SDK_advanced_guide_offline_en.md b/funasr/runtime/docs/SDK_advanced_guide_offline_en.md
index c2599ecfe..b829e67c7 100644
--- a/funasr/runtime/docs/SDK_advanced_guide_offline_en.md
+++ b/funasr/runtime/docs/SDK_advanced_guide_offline_en.md
@@ -79,7 +79,7 @@ nohup bash run_server.sh \
  --io-thread-num 8 \
  --port 10095 \
  --certfile ../../../ssl_key/server.crt \
-  --keyfile ../../../ssl_key/server.key
+  --keyfile ../../../ssl_key/server.key > log.out 2>&1 &
 ```
 
 Introduction to run_server.sh parameters:
diff --git a/funasr/runtime/docs/SDK_advanced_guide_offline_zh.md b/funasr/runtime/docs/SDK_advanced_guide_offline_zh.md
index c63109793..ee6550108 100644
--- a/funasr/runtime/docs/SDK_advanced_guide_offline_zh.md
+++ b/funasr/runtime/docs/SDK_advanced_guide_offline_zh.md
@@ -165,7 +165,8 @@ nohup bash run_server.sh \
  --io-thread-num 8 \
  --port 10095 \
  --certfile ../../../ssl_key/server.crt \
-  --keyfile ../../../ssl_key/server.key > log.out 2>&1 &
+  --keyfile ../../../ssl_key/server.key \
+  --hotword ../../hotwords.txt > log.out 2>&1 &
 ```
 **run_server.sh命令参数介绍**
 ```text
 --io-thread-num 服务端启动的IO线程数,默认为 1
 --certfile ssl的证书文件,默认为:../../../ssl_key/server.crt,如果需要关闭ssl,参数设置为0
 --keyfile ssl的密钥文件,默认为:../../../ssl_key/server.key
+--hotword 热词文件路径,每一个热词一行,如果客户端提供热词,则与客户端提供的热词合并一起使用。默认为:../../hotwords.txt
 ```
 
 ### 关闭FunASR服务
diff --git a/funasr/runtime/docs/SDK_advanced_guide_online.md b/funasr/runtime/docs/SDK_advanced_guide_online.md
index 17fb8916e..ddc02cf89 100644
--- a/funasr/runtime/docs/SDK_advanced_guide_online.md
+++ b/funasr/runtime/docs/SDK_advanced_guide_online.md
@@ -72,7 +72,8 @@ nohup bash run_server_2pass.sh \
  --io-thread-num 8 \
  --port 10095 \
  --certfile ../../../ssl_key/server.crt \
-  --keyfile ../../../ssl_key/server.key > log.out 2>&1 &
+  --keyfile ../../../ssl_key/server.key \
+  --hotword ../../hotwords.txt > log.out 2>&1 &
 
 # If you want to close ssl,please add:--certfile 0
 # If you want to deploy the timestamp or hotword model, please set --model-dir to the corresponding model:
@@ -97,6 +98,7 @@ nohup bash run_server_2pass.sh \
 --io-thread-num: Number of IO threads that the server starts. Default is 1.
 --certfile : SSL certificate file. Default is ../../../ssl_key/server.crt. If you want to close ssl,set 0
 --keyfile : SSL key file. Default is ../../../ssl_key/server.key.
+--hotword : Hotword file path, one hotword per line. If the client also provides hotwords, the two sets are merged. Default is ../../hotwords.txt
 ```
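+
+An example hotword file (one hotword per line; the hotwords.txt shipped with this patch contains the two entries below):
+
+```text
+阿里巴巴
+通义实验室
+```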
 
 ### Shutting Down the FunASR Service
diff --git a/funasr/runtime/docs/SDK_advanced_guide_online_zh.md b/funasr/runtime/docs/SDK_advanced_guide_online_zh.md
index 232701e2c..902ae7a6c 100644
--- a/funasr/runtime/docs/SDK_advanced_guide_online_zh.md
+++ b/funasr/runtime/docs/SDK_advanced_guide_online_zh.md
@@ -31,7 +31,8 @@ nohup bash run_server_2pass.sh \
  --model-dir damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx \
  --online-model-dir damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online-onnx \
  --punc-dir damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727-onnx \
-  --itn-dir thuduj12/fst_itn_zh > log.out 2>&1 &
+  --itn-dir thuduj12/fst_itn_zh \
+  --hotword ../../hotwords.txt > log.out 2>&1 &
 
 # 如果您想关闭ssl,增加参数:--certfile 0
 # 如果您想使用时间戳或者热词模型进行部署,请设置--model-dir为对应模型:
@@ -80,7 +81,8 @@ nohup bash run_server_2pass.sh \
  --io-thread-num 8 \
  --port 10095 \
  --certfile ../../../ssl_key/server.crt \
-  --keyfile ../../../ssl_key/server.key > log.out 2>&1 &
+  --keyfile ../../../ssl_key/server.key \
+  --hotword ../../hotwords.txt > log.out 2>&1 &
 ```
 **run_server_2pass.sh命令参数介绍**
 ```text
@@ -98,6 +100,7 @@ nohup bash run_server_2pass.sh \
 --io-thread-num 服务端启动的IO线程数,默认为 1
 --certfile ssl的证书文件,默认为:../../../ssl_key/server.crt,如果需要关闭ssl,参数设置为0
 --keyfile ssl的密钥文件,默认为:../../../ssl_key/server.key
+--hotword 热词文件路径,每一个热词一行,如果客户端提供热词,则与客户端提供的热词合并一起使用。默认为:../../hotwords.txt
 ```
 
 ### 关闭FunASR服务
diff --git a/funasr/runtime/python/libtorch/funasr_torch/utils/timestamp_utils.py b/funasr/runtime/python/libtorch/funasr_torch/utils/timestamp_utils.py
index 3a01812e8..7d0060cc5 100644
--- a/funasr/runtime/python/libtorch/funasr_torch/utils/timestamp_utils.py
+++ b/funasr/runtime/python/libtorch/funasr_torch/utils/timestamp_utils.py
@@ -3,7 +3,7 @@ import numpy as np
 
 def time_stamp_lfr6_onnx(us_cif_peak, char_list, begin_time=0.0, total_offset=-1.5):
     if not len(char_list):
-        return []
+        return '', []
     START_END_THRESHOLD = 5
     MAX_TOKEN_DURATION = 30
     TIME_RATE = 10.0 * 6 / 1000 / 3  # 3 times upsampled
diff --git a/funasr/runtime/run_server.sh b/funasr/runtime/run_server.sh
index 6869fd9f7..f75f15946 100644
--- a/funasr/runtime/run_server.sh
+++ b/funasr/runtime/run_server.sh
@@ -9,6 +9,7 @@ io_thread_num=8
 port=10095
 certfile="../../../ssl_key/server.crt"
 keyfile="../../../ssl_key/server.key"
+hotword="../../hotwords.txt"
 
 . ../../egs/aishell/transformer/utils/parse_options.sh || exit 1;
 
@@ -24,7 +25,8 @@ if [ -z "$certfile" ] || [ "$certfile" -eq 0 ]; then
       --io-thread-num ${io_thread_num} \
       --port ${port} \
       --certfile "" \
-      --keyfile ""
+      --keyfile "" \
+      --hotword ${hotword}
 else
   ./funasr-wss-server \
       --download-model-dir ${download_model_dir} \
@@ -36,5 +38,6 @@ else
       --io-thread-num ${io_thread_num} \
       --port ${port} \
       --certfile ${certfile} \
-      --keyfile ${keyfile}
+      --keyfile ${keyfile} \
+      --hotword ${hotword}
 fi
diff --git a/funasr/runtime/run_server_2pass.sh b/funasr/runtime/run_server_2pass.sh
index 63c2041cd..941064cd3 100644
--- a/funasr/runtime/run_server_2pass.sh
+++ b/funasr/runtime/run_server_2pass.sh
@@ -10,6 +10,7 @@ io_thread_num=8
 port=10095
 certfile="../../../ssl_key/server.crt"
 keyfile="../../../ssl_key/server.key"
+hotword="../../hotwords.txt"
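+# one hotword per line; merged with any hotwords sent by the client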
 
 . ../../egs/aishell/transformer/utils/parse_options.sh || exit 1;
 
@@ -26,7 +27,8 @@ if [ -z "$certfile" ] || [ "$certfile" -eq 0 ]; then
       --io-thread-num ${io_thread_num} \
       --port ${port} \
       --certfile "" \
-      --keyfile ""
+      --keyfile "" \
+      --hotword ${hotword}
 else
   ./funasr-wss-server-2pass \
       --download-model-dir ${download_model_dir} \
@@ -39,5 +41,6 @@ else
       --io-thread-num ${io_thread_num} \
       --port ${port} \
      --certfile ${certfile} \
-      --keyfile ${keyfile}
+      --keyfile ${keyfile} \
+      --hotword ${hotword}
 fi
diff --git a/funasr/runtime/websocket/bin/funasr-wss-server-2pass.cpp b/funasr/runtime/websocket/bin/funasr-wss-server-2pass.cpp
index 1f8b63269..1c879578c 100644
--- a/funasr/runtime/websocket/bin/funasr-wss-server-2pass.cpp
+++ b/funasr/runtime/websocket/bin/funasr-wss-server-2pass.cpp
@@ -14,6 +14,9 @@
 #include
 #include "websocket-server-2pass.h"
+#include <fstream>
+std::string hotwords = "";
+
 using namespace std;
 void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std::string, std::string>& model_path) {
@@ -109,6 +112,15 @@ int main(int argc, char* argv[]) {
       "connection", false, "../../../ssl_key/server.key", "string");
+  TCLAP::ValueArg<std::string> hotwordsfile(
+      "", "hotword",
+      "default: ../../hotwords.txt, path of the hotword file",
+      false, "../../hotwords.txt", "string");
+
+  // register the hotword option
+  cmd.add(hotwordsfile);
+
   cmd.add(certfile);
   cmd.add(keyfile);
@@ -417,6 +429,21 @@ int main(int argc, char* argv[]) {
   std::string s_certfile = certfile.getValue();
   std::string s_keyfile = keyfile.getValue();
+  std::string s_hotwordsfile = hotwordsfile.getValue();
+  std::string line;
+  std::ifstream file(s_hotwordsfile);
+  LOG(INFO) << "hotwordsfile path: " << s_hotwordsfile;
+
+  if (file.is_open()) {
+    while (getline(file, line)) {
+      hotwords += line + HOTWORD_SEP;
+    }
+    LOG(INFO) << "hotwords: " << hotwords;
+    file.close();
+  } else {
+    LOG(ERROR) << "Unable to open hotwords file: " << s_hotwordsfile;
+  }
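+  // `hotwords` now holds the file entries joined by HOTWORD_SEP; it is merged
+  // with any client-provided hotwords in websocket-server-2pass.cpp.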
+
   bool is_ssl = false;
   if (!s_certfile.empty()) {
     is_ssl = true;
@@ -460,8 +487,7 @@ int main(int argc, char* argv[]) {
     websocket_srv.initAsr(model_path, s_model_thread_num);  // init asr model
   }
-  std::cout << "asr model init finished. listen on port:" << s_port
-            << std::endl;
+  LOG(INFO) << "asr model init finished. listen on port:" << s_port;
 
   // Start the ASIO network io_service run loop
   std::vector<std::thread> ts;
@@ -480,7 +506,7 @@ int main(int argc, char* argv[]) {
     }
   } catch (std::exception const& e) {
-    std::cerr << "Error: " << e.what() << std::endl;
+    LOG(ERROR) << "Error: " << e.what();
   }
 
   return 0;
diff --git a/funasr/runtime/websocket/bin/funasr-wss-server.cpp b/funasr/runtime/websocket/bin/funasr-wss-server.cpp
index 55ce07ba5..b571dbef2 100644
--- a/funasr/runtime/websocket/bin/funasr-wss-server.cpp
+++ b/funasr/runtime/websocket/bin/funasr-wss-server.cpp
@@ -13,6 +13,9 @@
 #include "websocket-server.h"
 #include
+#include <fstream>
+std::string hotwords = "";
+
 using namespace std;
 void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std::string, std::string>& model_path) {
@@ -95,6 +98,15 @@ int main(int argc, char* argv[]) {
       "default: ../../../ssl_key/server.key, path of keyfile for WSS connection",
       false, "../../../ssl_key/server.key", "string");
+  TCLAP::ValueArg<std::string> hotwordsfile(
+      "", "hotword",
+      "default: ../../hotwords.txt, path of the hotword file",
+      false, "../../hotwords.txt", "string");
+
+  // register the hotword option
+  cmd.add(hotwordsfile);
+
   cmd.add(certfile);
   cmd.add(keyfile);
@@ -331,6 +343,21 @@ int main(int argc, char* argv[]) {
   std::string s_certfile = certfile.getValue();
   std::string s_keyfile = keyfile.getValue();
+  std::string s_hotwordsfile = hotwordsfile.getValue();
+  std::string line;
+  std::ifstream file(s_hotwordsfile);
+  LOG(INFO) << "hotwordsfile path: " << s_hotwordsfile;
+
+  if (file.is_open()) {
+    while (getline(file, line)) {
+      hotwords += line + HOTWORD_SEP;
+    }
+    LOG(INFO) << "hotwords: " << hotwords;
+    file.close();
+  } else {
+    LOG(ERROR) << "Unable to open hotwords file: " << s_hotwordsfile;
+  }
+
   bool is_ssl = false;
   if (!s_certfile.empty()) {
     is_ssl = true;
@@ -374,8 +401,7 @@ int main(int argc, char* argv[]) {
     websocket_srv.initAsr(model_path, s_model_thread_num);  // init asr model
   }
-  std::cout << "asr model init finished. listen on port:" << s_port
-            << std::endl;
+  LOG(INFO) << "asr model init finished. listen on port:" << s_port;
 
   // Start the ASIO network io_service run loop
   std::vector<std::thread> ts;
@@ -394,7 +420,7 @@ int main(int argc, char* argv[]) {
     }
   } catch (std::exception const& e) {
-    std::cerr << "Error: " << e.what() << std::endl;
+    LOG(ERROR) << "Error: " << e.what();
   }
 
   return 0;
diff --git a/funasr/runtime/websocket/bin/websocket-server-2pass.cpp b/funasr/runtime/websocket/bin/websocket-server-2pass.cpp
index 107be409a..9e0668f0e 100644
--- a/funasr/runtime/websocket/bin/websocket-server-2pass.cpp
+++ b/funasr/runtime/websocket/bin/websocket-server-2pass.cpp
@@ -15,7 +15,9 @@
 #include
 #include
 #include
-#include
+
+extern std::string hotwords;
+
 context_ptr WebSocketServer::on_tls_init(tls_mode mode,
                                          websocketpp::connection_hdl hdl,
                                          std::string& s_certfile,
@@ -354,7 +356,14 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
   unique_lock<mutex> guard_decoder(*(thread_lock_p));  // mutex for one connection
   switch (msg->get_opcode()) {
     case websocketpp::frame::opcode::text: {
-      nlohmann::json jsonresult = nlohmann::json::parse(payload);
+      nlohmann::json jsonresult;
+      try {
+        jsonresult = nlohmann::json::parse(payload);
+      } catch (std::exception const& e) {
+        LOG(ERROR) << e.what();
+        break;
+      }
 
       msg_data->msg["wav_name"] = jsonresult["wav_name"];
@@ -370,17 +379,26 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
         msg_data->msg["hotwords"] = jsonresult["hotwords"];
         if (!msg_data->msg["hotwords"].empty()) {
           std::string hw = msg_data->msg["hotwords"];
-          LOG(INFO)<<"hotwords: " << hw;
-          std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(tpass_handle, hw, ASR_TWO_PASS);
+          hw = hw + " " + hotwords;
+          LOG(INFO) << "hotwords: " << hw;
+          std::vector<std::vector<float>> new_hotwords_embedding = CompileHotwordEmbedding(tpass_handle, hw, ASR_TWO_PASS);
           msg_data->hotwords_embedding =
               std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
         }
-      }else{
+      } else {
+        if (hotwords.empty()) {
           std::string hw = "";
           LOG(INFO)<<"hotwords: " << hw;
           std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(tpass_handle, hw, ASR_TWO_PASS);
           msg_data->hotwords_embedding =
               std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
+        } else {
+          std::string hw = hotwords;
+          LOG(INFO) << "hotwords: " << hw;
+          std::vector<std::vector<float>> new_hotwords_embedding = CompileHotwordEmbedding(tpass_handle, hw, ASR_TWO_PASS);
+          msg_data->hotwords_embedding =
+              std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
+        }
       }
     }
     if (jsonresult.contains("audio_fs")) {
diff --git a/funasr/runtime/websocket/bin/websocket-server.cpp b/funasr/runtime/websocket/bin/websocket-server.cpp
index da1ffa57f..134f5fadc 100644
--- a/funasr/runtime/websocket/bin/websocket-server.cpp
+++ b/funasr/runtime/websocket/bin/websocket-server.cpp
@@ -16,6 +16,8 @@
 #include
 #include
+extern std::string hotwords;
+
 context_ptr WebSocketServer::on_tls_init(tls_mode mode,
                                          websocketpp::connection_hdl hdl,
                                          std::string& s_certfile,
@@ -254,7 +256,15 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
   unique_lock<mutex> guard_decoder(*(thread_lock_p));  // mutex for one connection
   switch (msg->get_opcode()) {
    case websocketpp::frame::opcode::text: {
-      nlohmann::json jsonresult = nlohmann::json::parse(payload);
+      nlohmann::json jsonresult;
+      try {
+        jsonresult = nlohmann::json::parse(payload);
+      } catch (std::exception const& e) {
+        LOG(ERROR) << e.what();
+        break;
+      }
       if (jsonresult["wav_name"] != nullptr) {
         msg_data->msg["wav_name"] = jsonresult["wav_name"];
       }
@@ -266,17 +276,26 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
         msg_data->msg["hotwords"] = jsonresult["hotwords"];
         if (!msg_data->msg["hotwords"].empty()) {
           std::string hw = msg_data->msg["hotwords"];
-          LOG(INFO)<<"hotwords: " << hw;
LOG(INFO)<<"hotwords: " << hw; - std::vector> new_hotwords_embedding= CompileHotwordEmbedding(asr_hanlde, hw); + hw = hw + " " + hotwords; + LOG(INFO) << "hotwords: " << hw; + std::vector> new_hotwords_embedding = CompileHotwordEmbedding(asr_hanlde, hw); msg_data->hotwords_embedding = std::make_shared>>(new_hotwords_embedding); } - }else{ + } else { + if (hotwords.empty()) { std::string hw = ""; LOG(INFO)<<"hotwords: " << hw; std::vector> new_hotwords_embedding= CompileHotwordEmbedding(asr_hanlde, hw); msg_data->hotwords_embedding = std::make_shared>>(new_hotwords_embedding); + }else { + std::string hw = hotwords; + LOG(INFO) << "hotwords: " << hw; + std::vector> new_hotwords_embedding= CompileHotwordEmbedding(asr_hanlde, hw); + msg_data->hotwords_embedding = + std::make_shared>>(new_hotwords_embedding); + } } } if (jsonresult.contains("audio_fs")) { diff --git a/funasr/runtime/websocket/hotwords.txt b/funasr/runtime/websocket/hotwords.txt new file mode 100644 index 000000000..6179cbc6a --- /dev/null +++ b/funasr/runtime/websocket/hotwords.txt @@ -0,0 +1,2 @@ +阿里巴巴 +通义实验室 diff --git a/funasr/utils/whisper_utils/__init__.py b/funasr/utils/whisper_utils/__init__.py new file mode 100644 index 000000000..e69de29bb