From 0170f534b017653d504a32ad4a6da267f4db09ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Fri, 5 Jul 2024 00:17:06 +0800 Subject: [PATCH] sensevoice --- README.md | 32 +- README_zh.md | 2 + .../sense_voice/demo.py | 25 + .../sense_voice/finetune.sh | 69 ++ funasr/models/sense_voice/export_meta.py | 97 ++ funasr/models/sense_voice/model.py | 911 ++++++++++++++++++ funasr/version.txt | 2 +- 7 files changed, 1122 insertions(+), 16 deletions(-) create mode 100644 examples/industrial_data_pretraining/sense_voice/demo.py create mode 100644 examples/industrial_data_pretraining/sense_voice/finetune.sh create mode 100644 funasr/models/sense_voice/export_meta.py create mode 100644 funasr/models/sense_voice/model.py diff --git a/README.md b/README.md index 0cdbf64e6..66242f0b4 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,7 @@ ## What's new: +- 2024/07/04:[SenseVoice](https://github.com/FunAudioLLM/SenseVoice) is a speech foundation model with multiple speech understanding capabilities, including ASR, LID, SER, and AED. - 2024/07/01: Offline File Transcription Service GPU 1.1 released, optimize BladeDISC model compatibility issues; ref to ([docs](runtime/readme.md)) - 2024/06/27: Offline File Transcription Service GPU 1.0 released, supporting dynamic batch processing and multi-threading concurrency. In the long audio test set, the single-thread RTF is 0.0076, and multi-threads' speedup is 1200+ (compared to 330+ on CPU); ref to ([docs](runtime/readme.md)) - 2024/05/15:emotion recognition models are new supported. [emotion2vec+large](https://modelscope.cn/models/iic/emotion2vec_plus_large/summary),[emotion2vec+base](https://modelscope.cn/models/iic/emotion2vec_plus_base/summary),[emotion2vec+seed](https://modelscope.cn/models/iic/emotion2vec_plus_seed/summary). currently supports the following categories: 0: angry 1: happy 2: neutral 3: sad 4: unknown. @@ -90,21 +91,22 @@ FunASR has open-sourced a large number of pre-trained models on industrial data. (Note: ⭐ represents the ModelScope model zoo, 🤗 represents the Huggingface model zoo, 🍀 represents the OpenAI model zoo) -| Model Name | Task Details | Training Data | Parameters | -|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:-----------------------------------------------------:|:--------------------------------:|:----------:| -| paraformer-zh
([⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) [🤗](https://huggingface.co/funasr/paraformer-zh) ) | speech recognition, with timestamps, non-streaming | 60000 hours, Mandarin | 220M | -| paraformer-zh-streaming
( [⭐](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/summary) [🤗](https://huggingface.co/funasr/paraformer-zh-streaming) )
| speech recognition, streaming | 60000 hours, Mandarin | 220M | -| paraformer-en
( [⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020/summary) [🤗](https://huggingface.co/funasr/paraformer-en) ) | speech recognition, without timestamps, non-streaming | 50000 hours, English | 220M | -| conformer-en
( [⭐](https://modelscope.cn/models/damo/speech_conformer_asr-en-16k-vocab4199-pytorch/summary) [🤗](https://huggingface.co/funasr/conformer-en) ) | speech recognition, non-streaming | 50000 hours, English | 220M | -| ct-punc
( [⭐](https://modelscope.cn/models/damo/punc_ct-transformer_cn-en-common-vocab471067-large/summary) [🤗](https://huggingface.co/funasr/ct-punc) ) | punctuation restoration | 100M, Mandarin and English | 290M | -| fsmn-vad
( [⭐](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) [🤗](https://huggingface.co/funasr/fsmn-vad) ) | voice activity detection | 5000 hours, Mandarin and English | 0.4M | -| fa-zh
( [⭐](https://modelscope.cn/models/damo/speech_timestamp_prediction-v1-16k-offline/summary) [🤗](https://huggingface.co/funasr/fa-zh) ) | timestamp prediction | 5000 hours, Mandarin | 38M | -| cam++
( [⭐](https://modelscope.cn/models/iic/speech_campplus_sv_zh-cn_16k-common/summary) [🤗](https://huggingface.co/funasr/campplus) ) | speaker verification/diarization | 5000 hours | 7.2M | -| Whisper-large-v2
([⭐](https://www.modelscope.cn/models/iic/speech_whisper-large_asr_multilingual/summary) [🍀](https://github.com/openai/whisper) ) | speech recognition, with timestamps, non-streaming | multilingual | 1550 M | -| Whisper-large-v3
([⭐](https://www.modelscope.cn/models/iic/Whisper-large-v3/summary) [🍀](https://github.com/openai/whisper) ) | speech recognition, with timestamps, non-streaming | multilingual | 1550 M | -| Qwen-Audio
([⭐](examples/industrial_data_pretraining/qwen_audio/demo.py) [🤗](https://huggingface.co/Qwen/Qwen-Audio) ) | audio-text multimodal models (pretraining) | multilingual | 8B | -| Qwen-Audio-Chat
([⭐](examples/industrial_data_pretraining/qwen_audio/demo_chat.py) [🤗](https://huggingface.co/Qwen/Qwen-Audio-Chat) ) | audio-text multimodal models (chat) | multilingual | 8B | -| emotion2vec+large
([⭐](https://modelscope.cn/models/iic/emotion2vec_plus_large/summary) [🤗](https://huggingface.co/emotion2vec/emotion2vec_plus_large) ) | speech emotion recongintion | 40000 hours | 300M | +| Model Name | Task Details | Training Data | Parameters | +|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:--------------------------------------------------------------------------------:|:--------------------------------:|:----------:| +| SenseVoiceSmall
([⭐](https://www.modelscope.cn/models/iic/SenseVoiceSmall) [🤗](https://huggingface.co/FunAudioLLM/SenseVoiceSmall) ) | multiple speech understanding capabilities, including ASR, LID, SER, and AED. | 400000 hours | 330M | +| paraformer-zh
([⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) [🤗](https://huggingface.co/funasr/paraformer-zh) ) | speech recognition, with timestamps, non-streaming | 60000 hours, Mandarin | 220M | +| paraformer-zh-streaming
( [⭐](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/summary) [🤗](https://huggingface.co/funasr/paraformer-zh-streaming) )
| speech recognition, streaming | 60000 hours, Mandarin | 220M | +| paraformer-en
( [⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020/summary) [🤗](https://huggingface.co/funasr/paraformer-en) ) | speech recognition, without timestamps, non-streaming | 50000 hours, English | 220M | +| conformer-en
( [⭐](https://modelscope.cn/models/damo/speech_conformer_asr-en-16k-vocab4199-pytorch/summary) [🤗](https://huggingface.co/funasr/conformer-en) ) | speech recognition, non-streaming | 50000 hours, English | 220M | +| ct-punc
( [⭐](https://modelscope.cn/models/damo/punc_ct-transformer_cn-en-common-vocab471067-large/summary) [🤗](https://huggingface.co/funasr/ct-punc) ) | punctuation restoration | 100M, Mandarin and English | 290M | +| fsmn-vad
( [⭐](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) [🤗](https://huggingface.co/funasr/fsmn-vad) ) | voice activity detection | 5000 hours, Mandarin and English | 0.4M | +| fa-zh
( [⭐](https://modelscope.cn/models/damo/speech_timestamp_prediction-v1-16k-offline/summary) [🤗](https://huggingface.co/funasr/fa-zh) ) | timestamp prediction | 5000 hours, Mandarin | 38M | +| cam++
( [⭐](https://modelscope.cn/models/iic/speech_campplus_sv_zh-cn_16k-common/summary) [🤗](https://huggingface.co/funasr/campplus) ) | speaker verification/diarization | 5000 hours | 7.2M | +| Whisper-large-v2
([⭐](https://www.modelscope.cn/models/iic/speech_whisper-large_asr_multilingual/summary) [🍀](https://github.com/openai/whisper) ) | speech recognition, with timestamps, non-streaming | multilingual | 1550 M | +| Whisper-large-v3
([⭐](https://www.modelscope.cn/models/iic/Whisper-large-v3/summary) [🍀](https://github.com/openai/whisper) ) | speech recognition, with timestamps, non-streaming | multilingual | 1550 M | +| Qwen-Audio
([⭐](examples/industrial_data_pretraining/qwen_audio/demo.py) [🤗](https://huggingface.co/Qwen/Qwen-Audio) ) | audio-text multimodal models (pretraining) | multilingual | 8B | +| Qwen-Audio-Chat
([⭐](examples/industrial_data_pretraining/qwen_audio/demo_chat.py) [🤗](https://huggingface.co/Qwen/Qwen-Audio-Chat) ) | audio-text multimodal models (chat) | multilingual | 8B | +| emotion2vec+large
([⭐](https://modelscope.cn/models/iic/emotion2vec_plus_large/summary) [🤗](https://huggingface.co/emotion2vec/emotion2vec_plus_large) ) | speech emotion recongintion | 40000 hours | 300M | diff --git a/README_zh.md b/README_zh.md index fe05f13c0..275a30d6f 100644 --- a/README_zh.md +++ b/README_zh.md @@ -33,6 +33,7 @@ FunASR希望在语音识别的学术研究和工业应用之间架起一座桥 ## 最新动态 +- 2024/07/04:[SenseVoice](https://github.com/FunAudioLLM/SenseVoice) 是一个基础语音理解模型,具备多种语音理解能力,涵盖了自动语音识别(ASR)、语言识别(LID)、情感识别(SER)以及音频事件检测(AED)。 - 2024/07/01:中文离线文件转写服务GPU版本 1.1发布,优化bladedisc模型兼容性问题;详细信息参阅([部署文档](runtime/readme_cn.md)) - 2024/06/27:中文离线文件转写服务GPU版本 1.0发布,支持动态batch,支持多路并发,在长音频测试集上单线RTF为0.0076,多线加速比为1200+(CPU为330+);详细信息参阅([部署文档](runtime/readme_cn.md)) - 2024/05/15:新增加情感识别模型,[emotion2vec+large](https://modelscope.cn/models/iic/emotion2vec_plus_large/summary),[emotion2vec+base](https://modelscope.cn/models/iic/emotion2vec_plus_base/summary),[emotion2vec+seed](https://modelscope.cn/models/iic/emotion2vec_plus_seed/summary),输出情感类别为:生气/angry,开心/happy,中立/neutral,难过/sad。 @@ -99,6 +100,7 @@ FunASR开源了大量在工业数据上预训练模型,您可以在[模型许 | 模型名字 | 任务详情 | 训练数据 | 参数量 | |:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:------------------:|:--------------:|:------:| +| SenseVoiceSmall
([⭐](https://www.modelscope.cn/models/iic/SenseVoiceSmall) [🤗](https://huggingface.co/FunAudioLLM/SenseVoiceSmall) ) | 多种语音理解能力,涵盖了自动语音识别(ASR)、语言识别(LID)、情感识别(SER)以及音频事件检测(AED) | 400000小时,中文 | 330M | | paraformer-zh
([⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) [🤗](https://huggingface.co/funasr/paraformer-zh) ) | 语音识别,带时间戳输出,非实时 | 60000小时,中文 | 220M | | paraformer-zh-streaming
( [⭐](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/summary) [🤗](https://huggingface.co/funasr/paraformer-zh-streaming) ) | 语音识别,实时 | 60000小时,中文 | 220M | | paraformer-en
( [⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020/summary) [🤗](https://huggingface.co/funasr/paraformer-en) ) | 语音识别,非实时 | 50000小时,英文 | 220M | diff --git a/examples/industrial_data_pretraining/sense_voice/demo.py b/examples/industrial_data_pretraining/sense_voice/demo.py new file mode 100644 index 000000000..36355710e --- /dev/null +++ b/examples/industrial_data_pretraining/sense_voice/demo.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) + +import sys +from funasr import AutoModel + +model_dir = "iic/SenseVoiceSmall" +input_file = ( + "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav" +) + +model = AutoModel( + model=model_dir, +) + +res = model.generate( + input=input_file, + cache={}, + language="auto", # "zn", "en", "yue", "ja", "ko", "nospeech" + use_itn=False, +) + +print(res) diff --git a/examples/industrial_data_pretraining/sense_voice/finetune.sh b/examples/industrial_data_pretraining/sense_voice/finetune.sh new file mode 100644 index 000000000..be6c53a1a --- /dev/null +++ b/examples/industrial_data_pretraining/sense_voice/finetune.sh @@ -0,0 +1,69 @@ +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) + +workspace=`pwd` + +# which gpu to train or finetune +export CUDA_VISIBLE_DEVICES="0,1" +gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') + +# model_name from model_hub, or model_dir in local path + +## option 1, download model automatically +model_name_or_model_dir="iic/SenseVoiceCTC" + +## option 2, download model by git +#local_path_root=${workspace}/modelscope_models +#mkdir -p ${local_path_root}/${model_name_or_model_dir} +#git clone https://www.modelscope.cn/${model_name_or_model_dir}.git ${local_path_root}/${model_name_or_model_dir} +#model_name_or_model_dir=${local_path_root}/${model_name_or_model_dir} + + +# data dir, which contains: train.json, val.json +train_data=${workspace}/data/train_example.jsonl +val_data=${workspace}/data/val_example.jsonl + +# exp output dir +output_dir="./outputs" +log_file="${output_dir}/log.txt" + +deepspeed_config=${workspace}/../../ds_stage1.json + +mkdir -p ${output_dir} +echo "log_file: ${log_file}" + +DISTRIBUTED_ARGS=" + --nnodes ${WORLD_SIZE:-1} \ + --nproc_per_node $gpu_num \ + --node_rank ${RANK:-0} \ + --master_addr ${MASTER_ADDR:-127.0.0.1} \ + --master_port ${MASTER_PORT:-26669} +" + +echo $DISTRIBUTED_ARGS + +# funasr trainer path +train_tool=`dirname $(which funasr)`/train_ds.py + +torchrun $DISTRIBUTED_ARGS \ +${train_tool} \ +++model="${model_name_or_model_dir}" \ +++train_data_set_list="${train_data}" \ +++valid_data_set_list="${val_data}" \ +++dataset_conf.data_split_num=1 \ +++dataset_conf.batch_sampler="BatchSampler" \ +++dataset_conf.batch_size=6000 \ +++dataset_conf.sort_size=1024 \ +++dataset_conf.batch_type="token" \ +++dataset_conf.num_workers=4 \ +++train_conf.max_epoch=50 \ +++train_conf.log_interval=1 \ +++train_conf.resume=true \ +++train_conf.validate_interval=2000 \ +++train_conf.save_checkpoint_interval=2000 \ +++train_conf.keep_nbest_models=20 \ +++train_conf.avg_nbest_model=10 \ +++train_conf.use_deepspeed=false \ +++train_conf.deepspeed_config=${deepspeed_config} \ +++optim_conf.lr=0.0002 \ +++output_dir="${output_dir}" &> ${log_file} \ No newline at end of file diff --git a/funasr/models/sense_voice/export_meta.py b/funasr/models/sense_voice/export_meta.py new file mode 100644 index 000000000..fe09ee149 --- /dev/null +++ b/funasr/models/sense_voice/export_meta.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) + +import types +import torch +import torch.nn as nn +from funasr.register import tables + + +def export_rebuild_model(model, **kwargs): + model.device = kwargs.get("device") + is_onnx = kwargs.get("type", "onnx") == "onnx" + # encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export") + # model.encoder = encoder_class(model.encoder, onnx=is_onnx) + + from funasr.utils.torch_function import sequence_mask + + model.make_pad_mask = sequence_mask(kwargs["max_seq_len"], flip=False) + + model.forward = types.MethodType(export_forward, model) + model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model) + model.export_input_names = types.MethodType(export_input_names, model) + model.export_output_names = types.MethodType(export_output_names, model) + model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model) + model.export_name = types.MethodType(export_name, model) + + model.export_name = "model" + return model + + +def export_forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + language: torch.Tensor, + textnorm: torch.Tensor, + **kwargs, +): + speech = speech.to(device=kwargs["device"]) + speech_lengths = speech_lengths.to(device=kwargs["device"]) + + language_query = self.embed(language).to(speech.device) + + textnorm_query = self.embed(textnorm).to(speech.device) + speech = torch.cat((textnorm_query, speech), dim=1) + speech_lengths += 1 + + event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat( + speech.size(0), 1, 1 + ) + input_query = torch.cat((language_query, event_emo_query), dim=1) + speech = torch.cat((input_query, speech), dim=1) + speech_lengths += 3 + + # Encoder + encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths) + if isinstance(encoder_out, tuple): + encoder_out = encoder_out[0] + + # c. Passed the encoder result and the beam search + ctc_logits = self.ctc.log_softmax(encoder_out) + + return ctc_logits, encoder_out_lens + + +def export_dummy_inputs(self): + speech = torch.randn(2, 30, 560) + speech_lengths = torch.tensor([6, 30], dtype=torch.int32) + language = torch.tensor([0, 0], dtype=torch.int32) + textnorm = torch.tensor([15, 15], dtype=torch.int32) + return (speech, speech_lengths, language, textnorm) + + +def export_input_names(self): + return ["speech", "speech_lengths", "language", "textnorm"] + + +def export_output_names(self): + return ["ctc_logits", "encoder_out_lens"] + + +def export_dynamic_axes(self): + return { + "speech": {0: "batch_size", 1: "feats_length"}, + "speech_lengths": { + 0: "batch_size", + }, + "logits": {0: "batch_size", 1: "logits_length"}, + } + + +def export_name( + self, +): + return "model.onnx" diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py new file mode 100644 index 000000000..cf4f7fb6a --- /dev/null +++ b/funasr/models/sense_voice/model.py @@ -0,0 +1,911 @@ +from typing import Iterable, Optional +import types +import time +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor +from torch import nn +from torch.cuda.amp import autocast +from funasr.metrics.compute_acc import compute_accuracy, th_accuracy +from funasr.losses.label_smoothing_loss import LabelSmoothingLoss +from funasr.train_utils.device_funcs import force_gatherable + +from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank +from funasr.utils.datadir_writer import DatadirWriter +from funasr.models.ctc.ctc import CTC + +from funasr.register import tables + + +from funasr.models.paraformer.search import Hypothesis + + +class SinusoidalPositionEncoder(torch.nn.Module): + """ """ + + def __int__(self, d_model=80, dropout_rate=0.1): + pass + + def encode( + self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32 + ): + batch_size = positions.size(0) + positions = positions.type(dtype) + device = positions.device + log_timescale_increment = torch.log(torch.tensor([10000], dtype=dtype, device=device)) / ( + depth / 2 - 1 + ) + inv_timescales = torch.exp( + torch.arange(depth / 2, device=device).type(dtype) * (-log_timescale_increment) + ) + inv_timescales = torch.reshape(inv_timescales, [batch_size, -1]) + scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape( + inv_timescales, [1, 1, -1] + ) + encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2) + return encoding.type(dtype) + + def forward(self, x): + batch_size, timesteps, input_dim = x.size() + positions = torch.arange(1, timesteps + 1, device=x.device)[None, :] + position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device) + + return x + position_encoding + + +class PositionwiseFeedForward(torch.nn.Module): + """Positionwise feed forward layer. + + Args: + idim (int): Input dimenstion. + hidden_units (int): The number of hidden units. + dropout_rate (float): Dropout rate. + + """ + + def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()): + """Construct an PositionwiseFeedForward object.""" + super(PositionwiseFeedForward, self).__init__() + self.w_1 = torch.nn.Linear(idim, hidden_units) + self.w_2 = torch.nn.Linear(hidden_units, idim) + self.dropout = torch.nn.Dropout(dropout_rate) + self.activation = activation + + def forward(self, x): + """Forward function.""" + return self.w_2(self.dropout(self.activation(self.w_1(x)))) + + +class MultiHeadedAttentionSANM(nn.Module): + """Multi-Head Attention layer. + + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + + """ + + def __init__( + self, + n_head, + in_feat, + n_feat, + dropout_rate, + kernel_size, + sanm_shfit=0, + lora_list=None, + lora_rank=8, + lora_alpha=16, + lora_dropout=0.1, + ): + """Construct an MultiHeadedAttention object.""" + super().__init__() + assert n_feat % n_head == 0 + # We assume d_v always equals d_k + self.d_k = n_feat // n_head + self.h = n_head + # self.linear_q = nn.Linear(n_feat, n_feat) + # self.linear_k = nn.Linear(n_feat, n_feat) + # self.linear_v = nn.Linear(n_feat, n_feat) + + self.linear_out = nn.Linear(n_feat, n_feat) + self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3) + self.attn = None + self.dropout = nn.Dropout(p=dropout_rate) + + self.fsmn_block = nn.Conv1d( + n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False + ) + # padding + left_padding = (kernel_size - 1) // 2 + if sanm_shfit > 0: + left_padding = left_padding + sanm_shfit + right_padding = kernel_size - 1 - left_padding + self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0) + + def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None): + b, t, d = inputs.size() + if mask is not None: + mask = torch.reshape(mask, (b, -1, 1)) + if mask_shfit_chunk is not None: + mask = mask * mask_shfit_chunk + inputs = inputs * mask + + x = inputs.transpose(1, 2) + x = self.pad_fn(x) + x = self.fsmn_block(x) + x = x.transpose(1, 2) + x += inputs + x = self.dropout(x) + if mask is not None: + x = x * mask + return x + + def forward_qkv(self, x): + """Transform query, key and value. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + + Returns: + torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k). + torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k). + torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k). + + """ + b, t, d = x.size() + q_k_v = self.linear_q_k_v(x) + q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1) + q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose( + 1, 2 + ) # (batch, head, time2, d_k) + v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose( + 1, 2 + ) # (batch, head, time2, d_k) + + return q_h, k_h, v_h, v + + def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None): + """Compute attention context vector. + + Args: + value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k). + scores (torch.Tensor): Attention score (#batch, n_head, time1, time2). + mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2). + + Returns: + torch.Tensor: Transformed value (#batch, time1, d_model) + weighted by the attention score (#batch, time1, time2). + + """ + n_batch = value.size(0) + if mask is not None: + if mask_att_chunk_encoder is not None: + mask = mask * mask_att_chunk_encoder + + mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) + + min_value = -float( + "inf" + ) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min) + scores = scores.masked_fill(mask, min_value) + self.attn = torch.softmax(scores, dim=-1).masked_fill( + mask, 0.0 + ) # (batch, head, time1, time2) + else: + self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + + p_attn = self.dropout(self.attn) + x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) + x = ( + x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) + ) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None): + """Compute scaled dot product attention. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + + """ + q_h, k_h, v_h, v = self.forward_qkv(x) + fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk) + q_h = q_h * self.d_k ** (-0.5) + scores = torch.matmul(q_h, k_h.transpose(-2, -1)) + att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder) + return att_outs + fsmn_memory + + def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0): + """Compute scaled dot product attention. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + + """ + q_h, k_h, v_h, v = self.forward_qkv(x) + if chunk_size is not None and look_back > 0 or look_back == -1: + if cache is not None: + k_h_stride = k_h[:, :, : -(chunk_size[2]), :] + v_h_stride = v_h[:, :, : -(chunk_size[2]), :] + k_h = torch.cat((cache["k"], k_h), dim=2) + v_h = torch.cat((cache["v"], v_h), dim=2) + + cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2) + cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2) + if look_back != -1: + cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]) :, :] + cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]) :, :] + else: + cache_tmp = { + "k": k_h[:, :, : -(chunk_size[2]), :], + "v": v_h[:, :, : -(chunk_size[2]), :], + } + cache = cache_tmp + fsmn_memory = self.forward_fsmn(v, None) + q_h = q_h * self.d_k ** (-0.5) + scores = torch.matmul(q_h, k_h.transpose(-2, -1)) + att_outs = self.forward_attention(v_h, scores, None) + return att_outs + fsmn_memory, cache + + +class LayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.layer_norm( + input.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None): + if maxlen is None: + maxlen = lengths.max() + row_vector = torch.arange(0, maxlen, 1).to(lengths.device) + matrix = torch.unsqueeze(lengths, dim=-1) + mask = row_vector < matrix + mask = mask.detach() + + return mask.type(dtype).to(device) if device is not None else mask.type(dtype) + + +class EncoderLayerSANM(nn.Module): + def __init__( + self, + in_size, + size, + self_attn, + feed_forward, + dropout_rate, + normalize_before=True, + concat_after=False, + stochastic_depth_rate=0.0, + ): + """Construct an EncoderLayer object.""" + super(EncoderLayerSANM, self).__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.norm1 = LayerNorm(in_size) + self.norm2 = LayerNorm(size) + self.dropout = nn.Dropout(dropout_rate) + self.in_size = in_size + self.size = size + self.normalize_before = normalize_before + self.concat_after = concat_after + if self.concat_after: + self.concat_linear = nn.Linear(size + size, size) + self.stochastic_depth_rate = stochastic_depth_rate + self.dropout_rate = dropout_rate + + def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None): + """Compute encoded features. + + Args: + x_input (torch.Tensor): Input tensor (#batch, time, size). + mask (torch.Tensor): Mask tensor for the input (#batch, time). + cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). + + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time). + + """ + skip_layer = False + # with stochastic depth, residual connection `x + f(x)` becomes + # `x <- x + 1 / (1 - p) * f(x)` at training time. + stoch_layer_coeff = 1.0 + if self.training and self.stochastic_depth_rate > 0: + skip_layer = torch.rand(1).item() < self.stochastic_depth_rate + stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) + + if skip_layer: + if cache is not None: + x = torch.cat([cache, x], dim=1) + return x, mask + + residual = x + if self.normalize_before: + x = self.norm1(x) + + if self.concat_after: + x_concat = torch.cat( + ( + x, + self.self_attn( + x, + mask, + mask_shfit_chunk=mask_shfit_chunk, + mask_att_chunk_encoder=mask_att_chunk_encoder, + ), + ), + dim=-1, + ) + if self.in_size == self.size: + x = residual + stoch_layer_coeff * self.concat_linear(x_concat) + else: + x = stoch_layer_coeff * self.concat_linear(x_concat) + else: + if self.in_size == self.size: + x = residual + stoch_layer_coeff * self.dropout( + self.self_attn( + x, + mask, + mask_shfit_chunk=mask_shfit_chunk, + mask_att_chunk_encoder=mask_att_chunk_encoder, + ) + ) + else: + x = stoch_layer_coeff * self.dropout( + self.self_attn( + x, + mask, + mask_shfit_chunk=mask_shfit_chunk, + mask_att_chunk_encoder=mask_att_chunk_encoder, + ) + ) + if not self.normalize_before: + x = self.norm1(x) + + residual = x + if self.normalize_before: + x = self.norm2(x) + x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm2(x) + + return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder + + def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0): + """Compute encoded features. + + Args: + x_input (torch.Tensor): Input tensor (#batch, time, size). + mask (torch.Tensor): Mask tensor for the input (#batch, time). + cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). + + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time). + + """ + + residual = x + if self.normalize_before: + x = self.norm1(x) + + if self.in_size == self.size: + attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back) + x = residual + attn + else: + x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back) + + if not self.normalize_before: + x = self.norm1(x) + + residual = x + if self.normalize_before: + x = self.norm2(x) + x = residual + self.feed_forward(x) + if not self.normalize_before: + x = self.norm2(x) + + return x, cache + + +@tables.register("encoder_classes", "SenseVoiceEncoderSmall") +class SenseVoiceEncoderSmall(nn.Module): + """ + Author: Speech Lab of DAMO Academy, Alibaba Group + SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition + https://arxiv.org/abs/2006.01713 + """ + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + tp_blocks: int = 0, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + stochastic_depth_rate: float = 0.0, + input_layer: Optional[str] = "conv2d", + pos_enc_class=SinusoidalPositionEncoder, + normalize_before: bool = True, + concat_after: bool = False, + positionwise_layer_type: str = "linear", + positionwise_conv_kernel_size: int = 1, + padding_idx: int = -1, + kernel_size: int = 11, + sanm_shfit: int = 0, + selfattention_layer_type: str = "sanm", + **kwargs, + ): + super().__init__() + self._output_size = output_size + + self.embed = SinusoidalPositionEncoder() + + self.normalize_before = normalize_before + + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + output_size, + linear_units, + dropout_rate, + ) + + encoder_selfattn_layer = MultiHeadedAttentionSANM + encoder_selfattn_layer_args0 = ( + attention_heads, + input_size, + output_size, + attention_dropout_rate, + kernel_size, + sanm_shfit, + ) + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + output_size, + attention_dropout_rate, + kernel_size, + sanm_shfit, + ) + + self.encoders0 = nn.ModuleList( + [ + EncoderLayerSANM( + input_size, + output_size, + encoder_selfattn_layer(*encoder_selfattn_layer_args0), + positionwise_layer(*positionwise_layer_args), + dropout_rate, + ) + for i in range(1) + ] + ) + self.encoders = nn.ModuleList( + [ + EncoderLayerSANM( + output_size, + output_size, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + dropout_rate, + ) + for i in range(num_blocks - 1) + ] + ) + + self.tp_encoders = nn.ModuleList( + [ + EncoderLayerSANM( + output_size, + output_size, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + dropout_rate, + ) + for i in range(tp_blocks) + ] + ) + + self.after_norm = LayerNorm(output_size) + + self.tp_norm = LayerNorm(output_size) + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs_pad: torch.Tensor, + ilens: torch.Tensor, + ): + """Embed positions in tensor.""" + masks = sequence_mask(ilens, device=ilens.device)[:, None, :] + + xs_pad *= self.output_size() ** 0.5 + + xs_pad = self.embed(xs_pad) + + # forward encoder1 + for layer_idx, encoder_layer in enumerate(self.encoders0): + encoder_outs = encoder_layer(xs_pad, masks) + xs_pad, masks = encoder_outs[0], encoder_outs[1] + + for layer_idx, encoder_layer in enumerate(self.encoders): + encoder_outs = encoder_layer(xs_pad, masks) + xs_pad, masks = encoder_outs[0], encoder_outs[1] + + xs_pad = self.after_norm(xs_pad) + + # forward encoder2 + olens = masks.squeeze(1).sum(1).int() + + for layer_idx, encoder_layer in enumerate(self.tp_encoders): + encoder_outs = encoder_layer(xs_pad, masks) + xs_pad, masks = encoder_outs[0], encoder_outs[1] + + xs_pad = self.tp_norm(xs_pad) + return xs_pad, olens + + +@tables.register("model_classes", "SenseVoiceSmall") +class SenseVoiceSmall(nn.Module): + """CTC-attention hybrid Encoder-Decoder model""" + + def __init__( + self, + specaug: str = None, + specaug_conf: dict = None, + normalize: str = None, + normalize_conf: dict = None, + encoder: str = None, + encoder_conf: dict = None, + ctc_conf: dict = None, + input_size: int = 80, + vocab_size: int = -1, + ignore_id: int = -1, + blank_id: int = 0, + sos: int = 1, + eos: int = 2, + length_normalized_loss: bool = False, + **kwargs, + ): + + super().__init__() + + if specaug is not None: + specaug_class = tables.specaug_classes.get(specaug) + specaug = specaug_class(**specaug_conf) + if normalize is not None: + normalize_class = tables.normalize_classes.get(normalize) + normalize = normalize_class(**normalize_conf) + encoder_class = tables.encoder_classes.get(encoder) + encoder = encoder_class(input_size=input_size, **encoder_conf) + encoder_output_size = encoder.output_size() + + if ctc_conf is None: + ctc_conf = {} + ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf) + + self.blank_id = blank_id + self.sos = sos if sos is not None else vocab_size - 1 + self.eos = eos if eos is not None else vocab_size - 1 + self.vocab_size = vocab_size + self.ignore_id = ignore_id + self.specaug = specaug + self.normalize = normalize + self.encoder = encoder + self.error_calculator = None + + self.ctc = ctc + + self.length_normalized_loss = length_normalized_loss + self.encoder_output_size = encoder_output_size + + self.lid_dict = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13} + self.lid_int_dict = {24884: 3, 24885: 4, 24888: 7, 24892: 11, 24896: 12, 24992: 13} + self.textnorm_dict = {"withitn": 14, "woitn": 15} + self.textnorm_int_dict = {25016: 14, 25017: 15} + self.embed = torch.nn.Embedding( + 7 + len(self.lid_dict) + len(self.textnorm_dict), input_size + ) + + self.criterion_att = LabelSmoothingLoss( + size=self.vocab_size, + padding_idx=self.ignore_id, + smoothing=kwargs.get("lsm_weight", 0.0), + normalize_length=self.length_normalized_loss, + ) + + @staticmethod + def from_pretrained(model: str = None, **kwargs): + from funasr import AutoModel + + model, kwargs = AutoModel.build_model(model=model, trust_remote_code=True, **kwargs) + + return model, kwargs + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + **kwargs, + ): + """Encoder + Decoder + Calc loss + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + text: (Batch, Length) + text_lengths: (Batch,) + """ + # import pdb; + # pdb.set_trace() + if len(text_lengths.size()) > 1: + text_lengths = text_lengths[:, 0] + if len(speech_lengths.size()) > 1: + speech_lengths = speech_lengths[:, 0] + + batch_size = speech.shape[0] + + # 1. Encoder + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, text) + + loss_ctc, cer_ctc = None, None + loss_rich, acc_rich = None, None + stats = dict() + + loss_ctc, cer_ctc = self._calc_ctc_loss( + encoder_out[:, 4:, :], encoder_out_lens - 4, text[:, 4:], text_lengths - 4 + ) + + loss_rich, acc_rich = self._calc_rich_ce_loss(encoder_out[:, :4, :], text[:, :4]) + + loss = loss_ctc + # Collect total loss stats + stats["loss"] = torch.clone(loss.detach()) if loss_ctc is not None else None + stats["loss_rich"] = torch.clone(loss_rich.detach()) if loss_rich is not None else None + stats["acc_rich"] = acc_rich + + # force_gatherable: to-device and to-tensor if scalar for DataParallel + if self.length_normalized_loss: + batch_size = int((text_lengths + 1).sum()) + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + return loss, stats, weight + + def encode( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + **kwargs, + ): + """Frontend + Encoder. Note that this method is used by asr_inference.py + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + ind: int + """ + + # Data augmentation + if self.specaug is not None and self.training: + speech, speech_lengths = self.specaug(speech, speech_lengths) + + # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN + if self.normalize is not None: + speech, speech_lengths = self.normalize(speech, speech_lengths) + + lids = torch.LongTensor( + [ + [ + ( + self.lid_int_dict[int(lid)] + if torch.rand(1) > 0.2 and int(lid) in self.lid_int_dict + else 0 + ) + ] + for lid in text[:, 0] + ] + ).to(speech.device) + language_query = self.embed(lids) + + styles = torch.LongTensor( + [[self.textnorm_int_dict[int(style)]] for style in text[:, 3]] + ).to(speech.device) + style_query = self.embed(styles) + speech = torch.cat((style_query, speech), dim=1) + speech_lengths += 1 + + event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat( + speech.size(0), 1, 1 + ) + input_query = torch.cat((language_query, event_emo_query), dim=1) + speech = torch.cat((input_query, speech), dim=1) + speech_lengths += 3 + + encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths) + + return encoder_out, encoder_out_lens + + def _calc_ctc_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ): + # Calc CTC loss + loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) + + # Calc CER using CTC + cer_ctc = None + if not self.training and self.error_calculator is not None: + ys_hat = self.ctc.argmax(encoder_out).data + cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True) + return loss_ctc, cer_ctc + + def _calc_rich_ce_loss( + self, + encoder_out: torch.Tensor, + ys_pad: torch.Tensor, + ): + decoder_out = self.ctc.ctc_lo(encoder_out) + # 2. Compute attention loss + loss_rich = self.criterion_att(decoder_out, ys_pad.contiguous()) + acc_rich = th_accuracy( + decoder_out.view(-1, self.vocab_size), + ys_pad.contiguous(), + ignore_label=self.ignore_id, + ) + + return loss_rich, acc_rich + + def inference( + self, + data_in, + data_lengths=None, + key: list = ["wav_file_tmp_name"], + tokenizer=None, + frontend=None, + **kwargs, + ): + + meta_data = {} + if ( + isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank" + ): # fbank + speech, speech_lengths = data_in, data_lengths + if len(speech.shape) < 3: + speech = speech[None, :, :] + if speech_lengths is None: + speech_lengths = speech.shape[1] + else: + # extract fbank feats + time1 = time.perf_counter() + audio_sample_list = load_audio_text_image_video( + data_in, + fs=frontend.fs, + audio_fs=kwargs.get("fs", 16000), + data_type=kwargs.get("data_type", "sound"), + tokenizer=tokenizer, + ) + time2 = time.perf_counter() + meta_data["load_data"] = f"{time2 - time1:0.3f}" + speech, speech_lengths = extract_fbank( + audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend + ) + time3 = time.perf_counter() + meta_data["extract_feat"] = f"{time3 - time2:0.3f}" + meta_data["batch_data_time"] = ( + speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 + ) + + speech = speech.to(device=kwargs["device"]) + speech_lengths = speech_lengths.to(device=kwargs["device"]) + + language = kwargs.get("language", "auto") + language_query = self.embed( + torch.LongTensor([[self.lid_dict[language] if language in self.lid_dict else 0]]).to( + speech.device + ) + ).repeat(speech.size(0), 1, 1) + + use_itn = kwargs.get("use_itn", False) + textnorm = kwargs.get("text_norm", None) + if textnorm is None: + textnorm = "withitn" if use_itn else "woitn" + textnorm_query = self.embed( + torch.LongTensor([[self.textnorm_dict[textnorm]]]).to(speech.device) + ).repeat(speech.size(0), 1, 1) + speech = torch.cat((textnorm_query, speech), dim=1) + speech_lengths += 1 + + event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat( + speech.size(0), 1, 1 + ) + input_query = torch.cat((language_query, event_emo_query), dim=1) + speech = torch.cat((input_query, speech), dim=1) + speech_lengths += 3 + + # Encoder + encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths) + if isinstance(encoder_out, tuple): + encoder_out = encoder_out[0] + + # c. Passed the encoder result and the beam search + ctc_logits = self.ctc.log_softmax(encoder_out) + + results = [] + b, n, d = encoder_out.size() + if isinstance(key[0], (list, tuple)): + key = key[0] + if len(key) < b: + key = key * b + for i in range(b): + x = ctc_logits[i, : encoder_out_lens[i].item(), :] + yseq = x.argmax(dim=-1) + yseq = torch.unique_consecutive(yseq, dim=-1) + + ibest_writer = None + if kwargs.get("output_dir") is not None: + if not hasattr(self, "writer"): + self.writer = DatadirWriter(kwargs.get("output_dir")) + ibest_writer = self.writer[f"1best_recog"] + + mask = yseq != self.blank_id + token_int = yseq[mask].tolist() + + # Change integer-ids to tokens + text = tokenizer.decode(token_int) + + result_i = {"key": key[i], "text": text} + results.append(result_i) + + if ibest_writer is not None: + ibest_writer["text"][key[i]] = text + + return results, meta_data + + def export(self, **kwargs): + from .export_meta import export_rebuild_model + + if "max_seq_len" not in kwargs: + kwargs["max_seq_len"] = 512 + models = export_rebuild_model(model=self, **kwargs) + return models diff --git a/funasr/version.txt b/funasr/version.txt index fa7e3ca51..a7a8343f1 100644 --- a/funasr/version.txt +++ b/funasr/version.txt @@ -1 +1 @@ -1.0.29 \ No newline at end of file +1.0.30 \ No newline at end of file