Merge branch 'dev_gzf_deepspeed' of github.com:alibaba-damo-academy/FunASR into dev_gzf_deepspeed

merge
This commit is contained in:
游雁 2024-05-15 19:02:48 +08:00
commit 67733a2a92
45 changed files with 318 additions and 87 deletions

View File

@ -28,6 +28,7 @@
<a name="whats-new"></a>
## What's new:
- 2024/05/15: Offline File Transcription Service 4.5, Offline File Transcription Service of English 1.6Real-time Transcription Service 1.10 releasedadapting to FunASR 1.0 model structure([docs](runtime/readme.md))
- 2024/03/05Added the Qwen-Audio and Qwen-Audio-Chat large-scale audio-text multimodal models, which have topped multiple audio domain leaderboards. These models support speech dialogue, [usage](examples/industrial_data_pretraining/qwen_audio).
- 2024/03/05Added support for the Whisper-large-v3 model, a multitasking model that can perform multilingual speech recognition, speech translation, and language identification. It can be downloaded from the[modelscope](examples/industrial_data_pretraining/whisper/demo.py), and [openai](examples/industrial_data_pretraining/whisper/demo_from_openai.py).
- 2024/03/05: Offline File Transcription Service 4.4, Offline File Transcription Service of English 1.5Real-time Transcription Service 1.9 releaseddocker image supports ARM64 platform, update modelscope([docs](runtime/readme.md))

View File

@ -29,6 +29,7 @@ FunASR希望在语音识别的学术研究和工业应用之间架起一座桥
<a name="最新动态"></a>
## 最新动态
- 2024/05/15: 中文离线文件转写服务 4.5、英文离线文件转写服务 1.6、中文实时语音听写服务 1.10 发布适配FunASR 1.0模型结构;详细信息参阅([部署文档](runtime/readme_cn.md))
- 2024/03/05新增加Qwen-Audio与Qwen-Audio-Chat音频文本模态大模型在多个音频领域测试榜单刷榜中支持语音对话详细用法见 [示例](examples/industrial_data_pretraining/qwen_audio)。
- 2024/03/05新增加Whisper-large-v3模型支持多语言语音识别/翻译/语种识别,支持从 [modelscope](examples/industrial_data_pretraining/whisper/demo.py)仓库下载,也支持从 [openai](examples/industrial_data_pretraining/whisper/demo_from_openai.py)仓库下载模型。
- 2024/03/05: 中文离线文件转写服务 4.4、英文离线文件转写服务 1.5、中文实时语音听写服务 1.9 发布docker镜像支持arm64平台升级modelscope版本详细信息参阅([部署文档](runtime/readme_cn.md))

View File

@ -10,7 +10,7 @@ def download_model(**kwargs):
if hub == "ms":
kwargs = download_from_ms(**kwargs)
elif hub == "hf":
pass
kwargs = download_from_hf(**kwargs)
elif hub == "openai":
model_or_path = kwargs.get("model")
if os.path.exists(model_or_path):
@ -87,6 +87,67 @@ def download_from_ms(**kwargs):
return kwargs
def download_from_hf(**kwargs):
model_or_path = kwargs.get("model")
if model_or_path in name_maps_hf:
model_or_path = name_maps_hf[model_or_path]
model_revision = kwargs.get("model_revision", "master")
if not os.path.exists(model_or_path) and "model_path" not in kwargs:
try:
model_or_path = get_or_download_model_dir_hf(
model_or_path,
model_revision,
is_training=kwargs.get("is_training"),
check_latest=kwargs.get("check_latest", True),
)
except Exception as e:
print(f"Download: {model_or_path} failed!: {e}")
kwargs["model_path"] = model_or_path if "model_path" not in kwargs else kwargs["model_path"]
if os.path.exists(os.path.join(model_or_path, "configuration.json")):
with open(os.path.join(model_or_path, "configuration.json"), "r", encoding="utf-8") as f:
conf_json = json.load(f)
cfg = {}
if "file_path_metas" in conf_json:
add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
cfg.update(kwargs)
if "config" in cfg:
config = OmegaConf.load(cfg["config"])
kwargs = OmegaConf.merge(config, cfg)
kwargs["model"] = config["model"]
elif os.path.exists(os.path.join(model_or_path, "config.yaml")) and os.path.exists(
os.path.join(model_or_path, "model.pt")
):
config = OmegaConf.load(os.path.join(model_or_path, "config.yaml"))
kwargs = OmegaConf.merge(config, kwargs)
init_param = os.path.join(model_or_path, "model.pb")
kwargs["init_param"] = init_param
if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
if os.path.exists(os.path.join(model_or_path, "tokens.json")):
kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.json")
if os.path.exists(os.path.join(model_or_path, "seg_dict")):
kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
if os.path.exists(os.path.join(model_or_path, "bpe.model")):
kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
kwargs["model"] = config["model"]
if os.path.exists(os.path.join(model_or_path, "am.mvn")):
kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")
if isinstance(kwargs, DictConfig):
kwargs = OmegaConf.to_container(kwargs, resolve=True)
if os.path.exists(os.path.join(model_or_path, "requirements.txt")):
requirements = os.path.join(model_or_path, "requirements.txt")
print(f"Detect model requirements, begin to install it: {requirements}")
from funasr.utils.install_model_requirements import install_requirements
install_requirements(requirements)
return kwargs
def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg={}):
if isinstance(file_path_metas, dict):
@ -136,3 +197,22 @@ def get_or_download_model_dir(
model, revision=model_revision, user_agent={Invoke.KEY: key, ThirdParty.KEY: "funasr"}
)
return model_cache_dir
def get_or_download_model_dir_hf(
model,
model_revision=None,
is_training=False,
check_latest=True,
):
"""Get local model directory or download model if necessary.
Args:
model (str): model id or path to local model directory.
model_revision (str, optional): model version number.
:param is_training:
"""
from huggingface_hub import snapshot_download
model_cache_dir = snapshot_download(model)
return model_cache_dir

View File

@ -14,7 +14,9 @@ name_maps_ms = {
"Qwen-Audio": "Qwen/Qwen-Audio",
}
name_maps_hf = {}
name_maps_hf = {
"": "",
}
name_maps_openai = {
"Whisper-tiny.en": "tiny.en",

View File

@ -20,6 +20,7 @@ def main():
args = parser.parse_args()
model_dir = args.model_name
output_dir = args.model_name
if not Path(args.model_name).exists():
from modelscope.hub.snapshot_download import snapshot_download
@ -27,6 +28,7 @@ def main():
model_dir = snapshot_download(
args.model_name, cache_dir=args.export_dir, revision=args.model_revision
)
output_dir = os.path.join(args.export_dir, args.model_name)
except:
raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
model_dir
@ -37,15 +39,13 @@ def main():
model_file = os.path.join(model_dir, "model_quant.onnx")
if not os.path.exists(model_file):
print(".onnx is not exist, begin to export onnx")
from funasr.bin.export_model import ModelExport
from funasr import AutoModel
export_model = ModelExport(
cache_dir=args.export_dir,
onnx=True,
device="cpu",
quant=args.quantize,
)
export_model.export(model_dir)
export_model = AutoModel(model=args.model_name, output_dir=output_dir)
export_model.export(
quantize=args.quantize,
type=args.type,
)
if __name__ == "__main__":

View File

@ -249,10 +249,17 @@ class Emotion2vec(torch.nn.Module):
if self.proj:
x = x.mean(dim=1)
x = self.proj(x)
for idx, lab in enumerate(labels):
x[:,idx] = -np.inf if lab.startswith("unuse") else x[:,idx]
x = torch.softmax(x, dim=-1)
scores = x[0].tolist()
result_i = {"key": key[i], "labels": labels, "scores": scores}
select_label = [lb for lb in labels if not lb.startswith("unuse")]
select_score = [scores[idx] for idx, lb in enumerate(labels) if not lb.startswith("unuse")]
# result_i = {"key": key[i], "labels": labels, "scores": scores}
result_i = {"key": key[i], "labels": select_label, "scores": select_score}
if extract_embedding:
result_i["feats"] = feats
results.append(result_i)

View File

@ -63,8 +63,8 @@ def detect_language(
else:
x = x.to(mel.device)
# FIX(funasr): sense vocie
# logits = model.logits(x[:, :-1], mel)[:, -1]
logits = model.logits(x[:, :], mel)[:, -1]
logits = model.logits(x[:, :-1], mel)[:, -1]
# logits = model.logits(x[:, :], mel)[:, -1]
# collect detected languages; suppress all non-language tokens
mask = torch.ones(logits.shape[-1], dtype=torch.bool)

View File

@ -1 +1 @@
1.0.26
1.0.27

View File

@ -12,6 +12,7 @@ This document serves as a development guide for the FunASR offline file transcri
| TIME | INFO | IMAGE VERSION | IMAGE ID |
|------------|----------------------------------------------------------------------------------------------------------------------------------|------------------------------|--------------|
| 2024.05.15 | Adapting to FunASR 1.0 model structure | funasr-runtime-sdk-cpu-0.4.5 | 058b9882ae67 |
| 2024.03.05 | docker image supports ARM64 platform, update modelscope | funasr-runtime-sdk-cpu-0.4.4 | 2dc87b86dc49 |
| 2024.01.25 | Optimized the VAD (Voice Activity Detection) data processing method, significantly reducing peak memory usage; memory leak optimization| funasr-runtime-sdk-cpu-0.4.2 | befdc7b179ed |
| 2024.01.08 | optimized format sentence-level timestamps | funasr-runtime-sdk-cpu-0.4.1 | 0250f8ef981b |
@ -34,9 +35,9 @@ If you do not have Docker installed, please refer to [Docker Installation](https
### Pulling and launching images
Use the following command to pull and launch the Docker image for the FunASR runtime-SDK:
```shell
sudo docker pull registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.4
sudo docker pull registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.5
sudo docker run -p 10095:10095 -it --privileged=true -v /root:/workspace/models registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.4
sudo docker run -p 10095:10095 -it --privileged=true -v /root:/workspace/models registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.5
```
Introduction to command parameters:

View File

@ -6,6 +6,7 @@ This document serves as a development guide for the FunASR offline file transcri
| TIME | INFO | IMAGE VERSION | IMAGE ID |
|------------|-----------------------------------------|---------------------------------|--------------|
| 2024.05.15 | Adapting to FunASR 1.0 model structure | funasr-runtime-sdk-en-cpu-0.1.6 | 84d781d07997 |
| 2024.03.05 | docker image supports ARM64 platform, update modelscope | funasr-runtime-sdk-en-cpu-0.1.5 | 7cca2abc5901 |
| 2024.01.25 | Optimized the VAD (Voice Activity Detection) data processing method, significantly reducing peak memory usage; memory leak optimization| funasr-runtime-sdk-en-cpu-0.1.3 | c00f9ce7a195 |
| 2024.01.03 | fixed known crash issues as well as memory leak problems | funasr-runtime-sdk-en-cpu-0.1.2 | 0cdd9f4a4bb5 |
@ -24,9 +25,9 @@ If you do not have Docker installed, please refer to [Docker Installation](https
### Pulling and launching images
Use the following command to pull and launch the Docker image for the FunASR runtime-SDK:
```shell
sudo docker pull registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-en-cpu-0.1.5
sudo docker pull registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-en-cpu-0.1.6
sudo docker run -p 10097:10095 -it --privileged=true -v /root:/workspace/models registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-en-cpu-0.1.5
sudo docker run -p 10097:10095 -it --privileged=true -v /root:/workspace/models registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-en-cpu-0.1.6
```
Introduction to command parameters:
```text

View File

@ -6,6 +6,7 @@ FunASR提供可一键本地或者云端服务器部署的英文离线文件转
| 时间 | 详情 | 镜像版本 | 镜像ID |
|------------|---------------|---------------------------------|--------------|
| 2024.05.15 | 适配FunASR 1.0模型结构 | funasr-runtime-sdk-en-cpu-0.1.6 | 84d781d07997 |
| 2024.03.05 | docker镜像支持arm64平台升级modelscope版本 | funasr-runtime-sdk-en-cpu-0.1.5 | 7cca2abc5901 |
| 2024.01.25 | 优化vad数据处理方式大幅降低峰值内存占用内存泄漏优化 | funasr-runtime-sdk-en-cpu-0.1.3 | c00f9ce7a195 |
| 2024.01.03 | 修复已知的crash问题及内存泄漏问题 | funasr-runtime-sdk-en-cpu-0.1.2 | 0cdd9f4a4bb5 |
@ -39,11 +40,11 @@ docker安装失败请参考 [Docker Installation](https://alibaba-damo-academy.g
通过下述命令拉取并启动FunASR runtime-SDK的docker镜像
```shell
sudo docker pull \
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-en-cpu-0.1.5
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-en-cpu-0.1.6
mkdir -p ./funasr-runtime-resources/models
sudo docker run -p 10097:10095 -it --privileged=true \
-v $PWD/funasr-runtime-resources/models:/workspace/models \
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-en-cpu-0.1.5
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-en-cpu-0.1.6
```
### 服务端启动

View File

@ -10,6 +10,7 @@ FunASR离线文件转写软件包提供了一款功能强大的语音离线
| 时间 | 详情 | 镜像版本 | 镜像ID |
|------------|---------------------------------------------------|------------------------------|--------------|
| 2024.05.15 | 适配FunASR 1.0模型结构 | funasr-runtime-sdk-cpu-0.4.5 | 058b9882ae67 |
| 2024.03.05 | docker镜像支持arm64平台升级modelscope版本 | funasr-runtime-sdk-cpu-0.4.4 | 2dc87b86dc49 |
| 2024.01.25 | 优化vad数据处理方式大幅降低峰值内存占用内存泄漏优化| funasr-runtime-sdk-cpu-0.4.2 | befdc7b179ed |
| 2024.01.08 | 优化句子级时间戳json格式 | funasr-runtime-sdk-cpu-0.4.1 | 0250f8ef981b |
@ -48,11 +49,11 @@ docker安装失败请参考 [Docker Installation](https://alibaba-damo-academy.g
```shell
sudo docker pull \
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.4
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.5
mkdir -p ./funasr-runtime-resources/models
sudo docker run -p 10095:10095 -it --privileged=true \
-v $PWD/funasr-runtime-resources/models:/workspace/models \
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.4
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.5
```
### 服务端启动

View File

@ -8,6 +8,7 @@ FunASR Real-time Speech Recognition Software Package integrates real-time versio
| TIME | INFO | IMAGE VERSION | IMAGE ID |
|------------|-------------------------------------------------------------------------------------|-------------------------------------|--------------|
| 2024.05.15 | Adapting to FunASR 1.0 model structure | funasr-runtime-sdk-online-cpu-0.1.10 | 1c2adfcff84d |
| 2024.03.05 | docker image supports ARM64 platform, update modelscope | funasr-runtime-sdk-online-cpu-0.1.9 | 4a875e08c7a2 |
| 2024.01.25 | Optimization of the client-side | funasr-runtime-sdk-online-cpu-0.1.7 | 2aa23805572e |
| 2024.01.03 | The 2pass-offline mode supports Ngram language model decoding and WFST hotwords, while also addressing known crash issues and memory leak problems | funasr-runtime-sdk-online-cpu-0.1.6 | f99925110d27 |
@ -29,9 +30,9 @@ If you do not have Docker installed, please refer to [Docker Installation](https
### Pull Docker Image
Use the following command to pull and start the FunASR software package docker image:
```shell
sudo docker pull registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.9
sudo docker pull registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.10
mkdir -p ./funasr-runtime-resources/models
sudo docker run -p 10096:10095 -it --privileged=true -v $PWD/funasr-runtime-resources/models:/workspace/models registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.9
sudo docker run -p 10096:10095 -it --privileged=true -v $PWD/funasr-runtime-resources/models:/workspace/models registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.10
```
### Launching the Server

View File

@ -12,6 +12,7 @@ FunASR实时语音听写软件包集成了实时版本的语音端点检测
| 时间 | 详情 | 镜像版本 | 镜像ID |
|:-----------|:----------------------------------|--------------------------------------|--------------|
| 2024.05.15 | 适配FunASR 1.0模型结构 | funasr-runtime-sdk-online-cpu-0.1.10 | 1c2adfcff84d |
| 2024.03.05 | docker镜像支持arm64平台升级modelscope版本 | funasr-runtime-sdk-online-cpu-0.1.9 | 4a875e08c7a2 |
| 2024.01.25 | 客户端优化| funasr-runtime-sdk-online-cpu-0.1.7 | 2aa23805572e |
| 2024.01.03 | 2pass-offline模式支持Ngram语言模型解码、wfst热词同时修复已知的crash问题及内存泄漏问题 | funasr-runtime-sdk-online-cpu-0.1.6 | f99925110d27 |
@ -38,11 +39,11 @@ docker安装失败请参考 [Docker Installation](https://alibaba-damo-academy.g
```shell
sudo docker pull \
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.9
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.10
mkdir -p ./funasr-runtime-resources/models
sudo docker run -p 10096:10095 -it --privileged=true \
-v $PWD/funasr-runtime-resources/models:/workspace/models \
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.9
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.10
```
### 服务端启动

View File

@ -1,7 +1,7 @@
DOCKER:
funasr-runtime-sdk-en-cpu-0.1.6
funasr-runtime-sdk-en-cpu-0.1.5
funasr-runtime-sdk-en-cpu-0.1.4
funasr-runtime-sdk-en-cpu-0.1.3
DEFAULT_ASR_MODEL:
damo/speech_paraformer-large_asr_nat-en-16k-common-vocab10020-onnx
DEFAULT_VAD_MODEL:

View File

@ -1,5 +1,5 @@
DOCKER:
funasr-runtime-sdk-cpu-0.4.4
funasr-runtime-sdk-cpu-0.4.5
funasr-runtime-sdk-cpu-0.3.0
funasr-runtime-sdk-cpu-0.2.2
DEFAULT_ASR_MODEL:

View File

@ -1,7 +1,7 @@
DOCKER:
funasr-runtime-sdk-online-cpu-0.1.10
funasr-runtime-sdk-online-cpu-0.1.9
funasr-runtime-sdk-online-cpu-0.1.8
funasr-runtime-sdk-online-cpu-0.1.7
DEFAULT_ASR_MODEL:
damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx
damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-onnx

View File

@ -18,6 +18,17 @@ else()
message("Little endian system")
endif()
# json
include(FetchContent)
if(NOT EXISTS ${PROJECT_SOURCE_DIR}/third_party/json/ChangeLog.md )
FetchContent_Declare(json
URL https://github.com/nlohmann/json/archive/refs/tags/v3.11.2.tar.gz
SOURCE_DIR ${PROJECT_SOURCE_DIR}/third_party/json
)
FetchContent_MakeAvailable(json)
endif()
# for onnxruntime
IF(WIN32)
file(REMOVE ${PROJECT_SOURCE_DIR}/third_party/glog/src/config.h
@ -36,6 +47,7 @@ include_directories(${PROJECT_SOURCE_DIR}/third_party/yaml-cpp/include)
include_directories(${PROJECT_SOURCE_DIR}/third_party/jieba/include)
include_directories(${PROJECT_SOURCE_DIR}/third_party/jieba/include/limonp/include)
include_directories(${PROJECT_SOURCE_DIR}/third_party/kaldi)
include_directories(${PROJECT_SOURCE_DIR}/third_party/json/include)
if(ENABLE_GLOG)
include_directories(${PROJECT_SOURCE_DIR}/third_party/glog/src)

View File

@ -49,13 +49,14 @@ namespace funasr {
// hotword embedding compile model
#define MODEL_EB_NAME "model_eb.onnx"
#define QUANT_MODEL_NAME "model_quant.onnx"
#define VAD_CMVN_NAME "vad.mvn"
#define VAD_CONFIG_NAME "vad.yaml"
#define VAD_CMVN_NAME "am.mvn"
#define VAD_CONFIG_NAME "config.yaml"
#define AM_CMVN_NAME "am.mvn"
#define AM_CONFIG_NAME "config.yaml"
#define LM_CONFIG_NAME "config.yaml"
#define PUNC_CONFIG_NAME "punc.yaml"
#define PUNC_CONFIG_NAME "config.yaml"
#define MODEL_SEG_DICT "seg_dict"
#define TOKEN_PATH "tokens.json"
#define HOTWORD "hotword"
// #define NN_HOTWORD "nn-hotword"

View File

@ -12,9 +12,9 @@ class Model {
virtual void StartUtterance() = 0;
virtual void EndUtterance() = 0;
virtual void Reset() = 0;
virtual void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
virtual void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
virtual void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
virtual void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){};
virtual void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){};
virtual void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){};
virtual void InitLm(const std::string &lm_file, const std::string &lm_config, const std::string &lex_file){};
virtual void InitFstDecoder(){};
virtual std::string Forward(float *din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr){return "";};

View File

@ -11,7 +11,7 @@ namespace funasr {
class PuncModel {
public:
virtual ~PuncModel(){};
virtual void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num)=0;
virtual void InitPunc(const std::string &punc_model, const std::string &punc_config, const std::string &token_file, int thread_num)=0;
virtual std::string AddPunc(const char* sz_input, std::string language="zh-cn"){return "";};
virtual std::string AddPunc(const char* sz_input, std::vector<std::string>& arr_cache, std::string language="zh-cn"){return "";};
};

View File

@ -11,7 +11,7 @@ CTTransformerOnline::CTTransformerOnline()
{
}
void CTTransformerOnline::InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num){
void CTTransformerOnline::InitPunc(const std::string &punc_model, const std::string &punc_config, const std::string &token_file, int thread_num){
session_options.SetIntraOpNumThreads(thread_num);
session_options.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
session_options.DisableCpuMemArena();
@ -43,7 +43,7 @@ void CTTransformerOnline::InitPunc(const std::string &punc_model, const std::str
for (auto& item : m_strOutputNames)
m_szOutputNames.push_back(item.c_str());
m_tokenizer.OpenYaml(punc_config.c_str());
m_tokenizer.OpenYaml(punc_config.c_str(), token_file.c_str());
}
CTTransformerOnline::~CTTransformerOnline()

View File

@ -26,7 +26,7 @@ private:
public:
CTTransformerOnline();
void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num);
void InitPunc(const std::string &punc_model, const std::string &punc_config, const std::string &token_file, int thread_num);
~CTTransformerOnline();
vector<int> Infer(vector<int32_t> input_data, int nCacheSize);
string AddPunc(const char* sz_input, vector<string> &arr_cache, std::string language="zh-cn");

View File

@ -11,7 +11,7 @@ CTTransformer::CTTransformer()
{
}
void CTTransformer::InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num){
void CTTransformer::InitPunc(const std::string &punc_model, const std::string &punc_config, const std::string &token_file, int thread_num){
session_options.SetIntraOpNumThreads(thread_num);
session_options.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
session_options.DisableCpuMemArena();
@ -39,7 +39,7 @@ void CTTransformer::InitPunc(const std::string &punc_model, const std::string &p
for (auto& item : m_strOutputNames)
m_szOutputNames.push_back(item.c_str());
m_tokenizer.OpenYaml(punc_config.c_str());
m_tokenizer.OpenYaml(punc_config.c_str(), token_file.c_str());
m_tokenizer.JiebaInit(punc_config);
}

View File

@ -26,7 +26,7 @@ private:
public:
CTTransformer();
void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num);
void InitPunc(const std::string &punc_model, const std::string &punc_config, const std::string &token_file, int thread_num);
~CTTransformer();
vector<int> Infer(vector<int32_t> input_data);
string AddPunc(const char* sz_input, std::string language="zh-cn");

View File

@ -30,7 +30,7 @@ void FsmnVad::LoadConfigFromYaml(const char* filename){
try{
YAML::Node frontend_conf = config["frontend_conf"];
YAML::Node post_conf = config["vad_post_conf"];
YAML::Node post_conf = config["model_conf"];
this->vad_sample_rate_ = frontend_conf["fs"].as<int>();
this->vad_silence_duration_ = post_conf["max_end_silence_time"].as<int>();

View File

@ -8,6 +8,7 @@ Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_nu
string am_model_path;
string am_cmvn_path;
string am_config_path;
string token_path;
am_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
@ -15,10 +16,11 @@ Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_nu
}
am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME);
am_config_path = PathAppend(model_path.at(MODEL_DIR), AM_CONFIG_NAME);
token_path = PathAppend(model_path.at(MODEL_DIR), TOKEN_PATH);
Model *mm;
mm = new Paraformer();
mm->InitAsr(am_model_path, am_cmvn_path, am_config_path, thread_num);
mm->InitAsr(am_model_path, am_cmvn_path, am_config_path, token_path, thread_num);
return mm;
}else if(type == ASR_ONLINE){
// online
@ -26,6 +28,7 @@ Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_nu
string de_model_path;
string am_cmvn_path;
string am_config_path;
string token_path;
en_model_path = PathAppend(model_path.at(MODEL_DIR), ENCODER_NAME);
de_model_path = PathAppend(model_path.at(MODEL_DIR), DECODER_NAME);
@ -35,10 +38,11 @@ Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_nu
}
am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME);
am_config_path = PathAppend(model_path.at(MODEL_DIR), AM_CONFIG_NAME);
token_path = PathAppend(model_path.at(MODEL_DIR), TOKEN_PATH);
Model *mm;
mm = new Paraformer();
mm->InitAsr(en_model_path, de_model_path, am_cmvn_path, am_config_path, thread_num);
mm->InitAsr(en_model_path, de_model_path, am_cmvn_path, am_config_path, token_path, thread_num);
return mm;
}else{
LOG(ERROR)<<"Wrong ASR_TYPE : " << type;

View File

@ -32,6 +32,7 @@ OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int
string am_model_path;
string am_cmvn_path;
string am_config_path;
string token_path;
string hw_compile_model_path;
string seg_dict_path;
@ -57,8 +58,9 @@ OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int
}
am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME);
am_config_path = PathAppend(model_path.at(MODEL_DIR), AM_CONFIG_NAME);
token_path = PathAppend(model_path.at(MODEL_DIR), TOKEN_PATH);
asr_handle->InitAsr(am_model_path, am_cmvn_path, am_config_path, thread_num);
asr_handle->InitAsr(am_model_path, am_cmvn_path, am_config_path, token_path, thread_num);
}
// Lm resource
@ -79,20 +81,23 @@ OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int
if(model_path.find(PUNC_DIR) != model_path.end()){
string punc_model_path;
string punc_config_path;
string token_path;
punc_model_path = PathAppend(model_path.at(PUNC_DIR), MODEL_NAME);
if(model_path.find(PUNC_QUANT) != model_path.end() && model_path.at(PUNC_QUANT) == "true"){
punc_model_path = PathAppend(model_path.at(PUNC_DIR), QUANT_MODEL_NAME);
}
punc_config_path = PathAppend(model_path.at(PUNC_DIR), PUNC_CONFIG_NAME);
token_path = PathAppend(model_path.at(PUNC_DIR), TOKEN_PATH);
if (access(punc_model_path.c_str(), F_OK) != 0 ||
access(punc_config_path.c_str(), F_OK) != 0 )
access(punc_config_path.c_str(), F_OK) != 0 ||
access(token_path.c_str(), F_OK) != 0)
{
LOG(INFO) << "PUNC model file is not exist, skip load punc model.";
}else{
punc_handle = make_unique<CTTransformer>();
punc_handle->InitPunc(punc_model_path, punc_config_path, thread_num);
punc_handle->InitPunc(punc_model_path, punc_config_path, token_path, thread_num);
use_punc = true;
}
}

View File

@ -18,7 +18,7 @@ Paraformer::Paraformer()
}
// offline
void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){
LoadConfigFromYaml(am_config.c_str());
// knf options
fbank_opts_.frame_opts.dither = 0;
@ -65,13 +65,13 @@ void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn
m_szInputNames.push_back(item.c_str());
for (auto& item : m_strOutputNames)
m_szOutputNames.push_back(item.c_str());
vocab = new Vocab(am_config.c_str());
phone_set_ = new PhoneSet(am_config.c_str());
vocab = new Vocab(token_file.c_str());
phone_set_ = new PhoneSet(token_file.c_str());
LoadCmvn(am_cmvn.c_str());
}
// online
void Paraformer::InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
void Paraformer::InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){
LoadOnlineConfigFromYaml(am_config.c_str());
// knf options
@ -143,15 +143,15 @@ void Paraformer::InitAsr(const std::string &en_model, const std::string &de_mode
for (auto& item : de_strOutputNames)
de_szOutputNames_.push_back(item.c_str());
vocab = new Vocab(am_config.c_str());
phone_set_ = new PhoneSet(am_config.c_str());
vocab = new Vocab(token_file.c_str());
phone_set_ = new PhoneSet(token_file.c_str());
LoadCmvn(am_cmvn.c_str());
}
// 2pass
void Paraformer::InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
void Paraformer::InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){
// online
InitAsr(en_model, de_model, am_cmvn, am_config, thread_num);
InitAsr(en_model, de_model, am_cmvn, am_config, token_file, thread_num);
// offline
try {

View File

@ -42,11 +42,11 @@ namespace funasr {
public:
Paraformer();
~Paraformer();
void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
// online
void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
// 2pass
void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
void InitHwCompiler(const std::string &hw_model, int thread_num);
void InitSegDict(const std::string &seg_dict_model);
std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);

View File

@ -13,7 +13,7 @@ using namespace std;
namespace funasr {
PhoneSet::PhoneSet(const char *filename) {
ifstream in(filename);
LoadPhoneSetFromYaml(filename);
LoadPhoneSetFromJson(filename);
}
PhoneSet::~PhoneSet()
{
@ -35,6 +35,25 @@ void PhoneSet::LoadPhoneSetFromYaml(const char* filename) {
}
}
void PhoneSet::LoadPhoneSetFromJson(const char* filename) {
nlohmann::json json_array;
std::ifstream file(filename);
if (file.is_open()) {
file >> json_array;
file.close();
} else {
LOG(INFO) << "Error loading token file, token file error or not exist.";
exit(-1);
}
int id = 0;
for (const auto& element : json_array) {
phone_.push_back(element);
phn2Id_.emplace(element, id);
id++;
}
}
int PhoneSet::Size() const {
return phone_.size();
}

View File

@ -5,6 +5,7 @@
#include <string>
#include <vector>
#include <unordered_map>
#include "nlohmann/json.hpp"
#define UNIT_BEG_SIL_SYMBOL "<s>"
#define UNIT_END_SIL_SYMBOL "</s>"
#define UNIT_BLK_SYMBOL "<blank>"
@ -28,6 +29,7 @@ class PhoneSet {
vector<string> phone_;
unordered_map<string, int> phn2Id_;
void LoadPhoneSetFromYaml(const char* filename);
void LoadPhoneSetFromJson(const char* filename);
};
} // namespace funasr

View File

@ -14,14 +14,16 @@ PuncModel *CreatePuncModel(std::map<std::string, std::string>& model_path, int t
}
string punc_model_path;
string punc_config_path;
string token_file;
punc_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
punc_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
}
punc_config_path = PathAppend(model_path.at(MODEL_DIR), PUNC_CONFIG_NAME);
token_file = PathAppend(model_path.at(MODEL_DIR), TOKEN_PATH);
mm->InitPunc(punc_model_path, punc_config_path, thread_num);
mm->InitPunc(punc_model_path, punc_config_path, token_file, thread_num);
return mm;
}

View File

@ -127,6 +127,61 @@ bool CTokenizer::OpenYaml(const char* sz_yamlfile)
return m_ready;
}
bool CTokenizer::OpenYaml(const char* sz_yamlfile, const char* token_file)
{
YAML::Node m_Config;
try{
m_Config = YAML::LoadFile(sz_yamlfile);
}catch(exception const &e){
LOG(INFO) << "Error loading file, yaml file error or not exist.";
exit(-1);
}
try
{
YAML::Node conf_seg_jieba = m_Config["seg_jieba"];
if (conf_seg_jieba.IsDefined()){
seg_jieba = conf_seg_jieba.as<bool>();
}
auto Puncs = m_Config["model_conf"]["punc_list"];
if (Puncs.IsSequence())
{
for (size_t i = 0; i < Puncs.size(); ++i)
{
if (Puncs[i].IsScalar())
{
m_id2punc.push_back(Puncs[i].as<string>());
m_punc2id.insert(make_pair<string, int>(Puncs[i].as<string>(), i));
}
}
}
nlohmann::json json_array;
std::ifstream file(token_file);
if (file.is_open()) {
file >> json_array;
file.close();
} else {
LOG(INFO) << "Error loading token file, token file error or not exist.";
return false;
}
int i = 0;
for (const auto& element : json_array) {
m_id2token.push_back(element);
m_token2id[element] = i;
i++;
}
}
catch (YAML::BadFile& e) {
LOG(ERROR) << "Read error!";
return false;
}
m_ready = true;
return m_ready;
}
vector<string> CTokenizer::Id2String(vector<int> input)
{
vector<string> result;

View File

@ -8,6 +8,7 @@
#include "cppjieba/DictTrie.hpp"
#include "cppjieba/HMMModel.hpp"
#include "cppjieba/Jieba.hpp"
#include "nlohmann/json.hpp"
namespace funasr {
class CTokenizer {
@ -27,6 +28,7 @@ public:
CTokenizer();
~CTokenizer();
bool OpenYaml(const char* sz_yamlfile);
bool OpenYaml(const char* sz_yamlfile, const char* token_file);
void ReadYaml(const YAML::Node& node);
vector<string> Id2String(vector<int> input);
vector<int> String2Ids(vector<string> input);

View File

@ -35,6 +35,7 @@ TpassStream::TpassStream(std::map<std::string, std::string>& model_path, int thr
string de_model_path;
string am_cmvn_path;
string am_config_path;
string token_path;
string hw_compile_model_path;
string seg_dict_path;
@ -60,8 +61,9 @@ TpassStream::TpassStream(std::map<std::string, std::string>& model_path, int thr
}
am_cmvn_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), AM_CMVN_NAME);
am_config_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), AM_CONFIG_NAME);
token_path = PathAppend(model_path.at(MODEL_DIR), TOKEN_PATH);
asr_handle->InitAsr(am_model_path, en_model_path, de_model_path, am_cmvn_path, am_config_path, thread_num);
asr_handle->InitAsr(am_model_path, en_model_path, de_model_path, am_cmvn_path, am_config_path, token_path, thread_num);
}else{
LOG(ERROR) <<"Can not find offline-model-dir or online-model-dir";
exit(-1);
@ -85,20 +87,23 @@ TpassStream::TpassStream(std::map<std::string, std::string>& model_path, int thr
if(model_path.find(PUNC_DIR) != model_path.end()){
string punc_model_path;
string punc_config_path;
string token_path;
punc_model_path = PathAppend(model_path.at(PUNC_DIR), MODEL_NAME);
if(model_path.find(PUNC_QUANT) != model_path.end() && model_path.at(PUNC_QUANT) == "true"){
punc_model_path = PathAppend(model_path.at(PUNC_DIR), QUANT_MODEL_NAME);
}
punc_config_path = PathAppend(model_path.at(PUNC_DIR), PUNC_CONFIG_NAME);
token_path = PathAppend(model_path.at(PUNC_DIR), TOKEN_PATH);
if (access(punc_model_path.c_str(), F_OK) != 0 ||
access(punc_config_path.c_str(), F_OK) != 0 )
access(punc_config_path.c_str(), F_OK) != 0 ||
access(token_path.c_str(), F_OK) != 0)
{
LOG(INFO) << "PUNC model file is not exist, skip load punc model.";
}else{
punc_online_handle = make_unique<CTTransformerOnline>();
punc_online_handle->InitPunc(punc_model_path, punc_config_path, thread_num);
punc_online_handle->InitPunc(punc_model_path, punc_config_path, token_path, thread_num);
use_punc = true;
}
}

View File

@ -14,7 +14,7 @@ namespace funasr {
Vocab::Vocab(const char *filename)
{
ifstream in(filename);
LoadVocabFromYaml(filename);
LoadVocabFromJson(filename);
}
Vocab::Vocab(const char *filename, const char *lex_file)
{
@ -43,6 +43,25 @@ void Vocab::LoadVocabFromYaml(const char* filename){
}
}
void Vocab::LoadVocabFromJson(const char* filename){
nlohmann::json json_array;
std::ifstream file(filename);
if (file.is_open()) {
file >> json_array;
file.close();
} else {
LOG(INFO) << "Error loading token file, token file error or not exist.";
exit(-1);
}
int i = 0;
for (const auto& element : json_array) {
vocab.push_back(element);
token_id[element] = i;
i++;
}
}
void Vocab::LoadLex(const char* filename){
std::ifstream file(filename);
std::string line;

View File

@ -6,6 +6,7 @@
#include <string>
#include <vector>
#include <map>
#include "nlohmann/json.hpp"
using namespace std;
namespace funasr {
@ -16,6 +17,7 @@ class Vocab {
std::map<string, string> lex_map;
bool IsEnglish(string ch);
void LoadVocabFromYaml(const char* filename);
void LoadVocabFromJson(const char* filename);
void LoadLex(const char* filename);
public:

View File

@ -117,7 +117,7 @@ async def api_recognition(audio: UploadFile = File(..., description="audio file"
for sentence in rec_result["sentence_info"]:
# 每句话的时间戳
sentences.append(
{"text": sentence["text"], "start": sentence["start"], "end": sentence["start"]}
{"text": sentence["text"], "start": sentence["start"], "end": sentence["end"]}
)
ret = {"text": text, "sentences": sentences, "code": 0}
logger.info(f"识别结果:{ret}")

View File

@ -47,11 +47,11 @@ Use the following command to pull and launch the FunASR software package Docker
```shell
sudo docker pull \
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.9
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.10
mkdir -p ./funasr-runtime-resources/models
sudo docker run -p 10096:10095 -it --privileged=true \
-v $PWD/funasr-runtime-resources/models:/workspace/models \
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.9
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.10
```
###### Server Start
@ -93,11 +93,11 @@ Use the following command to pull and launch the FunASR software package Docker
```shell
sudo docker pull \
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.4
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.5
mkdir -p ./funasr-runtime-resources/models
sudo docker run -p 10095:10095 -it --privileged=true \
-v $PWD/funasr-runtime-resources/models:/workspace/models \
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.4
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.5
```
###### Server Start

View File

@ -48,11 +48,11 @@ sudo bash install_docker.sh
```shell
sudo docker pull \
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.9
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.10
mkdir -p ./funasr-runtime-resources/models
sudo docker run -p 10096:10095 -it --privileged=true \
-v $PWD/funasr-runtime-resources/models:/workspace/models \
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.9
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.10
```
###### 服务端启动
@ -92,11 +92,11 @@ python3 funasr_wss_client.py --host "127.0.0.1" --port 10096 --mode 2pass
```shell
sudo docker pull \
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.4
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.5
mkdir -p ./funasr-runtime-resources/models
sudo docker run -p 10095:10095 -it --privileged=true \
-v $PWD/funasr-runtime-resources/models:/workspace/models \
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.4
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.5
```
###### 服务端启动

View File

@ -17,6 +17,7 @@ Currently, the FunASR runtime-SDK supports the deployment of file transcription
To meet the needs of different users, we have prepared different tutorials with text and images for both novice and advanced developers.
### Whats-new
- 2024/05/15: Adapting to FunASR 1.0 model structure, docker image version funasr-runtime-sdk-en-cpu-0.1.6 (84d781d07997).
- 2024/03/05: docker image supports ARM64 platform, update modelscope, docker image version funasr-runtime-sdk-en-cpu-0.1.5 (7cca2abc5901).
- 2024/01/25: Optimized the VAD (Voice Activity Detection) data processing method,significantly reducing peak memory usage,memory leak optimization, docker image version funasr-runtime-sdk-en-cpu-0.1.3 (c00f9ce7a195).
- 2024/01/03: Fixed known crash issues as well as memory leak problems, docker image version funasr-runtime-sdk-en-cpu-0.1.2 (0cdd9f4a4bb5).
@ -42,6 +43,7 @@ The FunASR real-time speech-to-text service software package not only performs r
In order to meet the needs of different users for different scenarios, different tutorials are prepared:
### Whats-new
- 2024/05/15: Real-time Transcription Service 1.10 releasedadapting to FunASR 1.0 model structure, docker image version funasr-runtime-sdk-online-cpu-0.1.10 (1c2adfcff84d)
- 2024/03/05: Real-time Transcription Service 1.9 releaseddocker image supports ARM64 platform, update modelscope, docker image version funasr-runtime-sdk-online-cpu-0.1.9 (4a875e08c7a2)
- 2024/01/25: Real-time Transcription Service 1.7 releasedoptimization of the client-side, docker image version funasr-runtime-sdk-online-cpu-0.1.7 (2aa23805572e)
- 2024/01/03: Real-time Transcription Service 1.6 releasedThe 2pass-offline mode supports Ngram language model decoding and WFST hotwords, while also addressing known crash issues and memory leak problems, docker image version funasr-runtime-sdk-online-cpu-0.1.6 (f99925110d27)
@ -72,6 +74,7 @@ Currently, the FunASR runtime-SDK supports the deployment of file transcription
To meet the needs of different users, we have prepared different tutorials with text and images for both novice and advanced developers.
### Whats-new
- 2024/05/15: File Transcription Service 4.5 released, adapting to FunASR 1.0 model structure, docker image version funasr-runtime-sdk-cpu-0.4.5 (058b9882ae67)
- 2024/03/05: File Transcription Service 4.4 released, docker image supports ARM64 platform, update modelscope, docker image version funasr-runtime-sdk-cpu-0.4.4 (2dc87b86dc49)
- 2024/01/25: File Transcription Service 4.2 released, optimized the VAD (Voice Activity Detection) data processing method, significantly reducing peak memory usage, memory leak optimization, docker image version funasr-runtime-sdk-cpu-0.4.2 (befdc7b179ed)
- 2024/01/08: File Transcription Service 4.1 released, optimized format sentence-level timestamps, docker image version funasr-runtime-sdk-cpu-0.4.1 (0250f8ef981b)

View File

@ -19,6 +19,7 @@ FunASR是由阿里巴巴通义实验室语音团队开源的一款语音识别
为了支持不同用户的需求,针对不同场景,准备了不同的图文教程:
### 最新动态
- 2024/05/15: 英文离线文件转写服务 1.6 发布适配FunASR 1.0模型结构dokcer镜像版本funasr-runtime-sdk-en-cpu-0.1.6 (84d781d07997)
- 2024/03/05: 英文离线文件转写服务 1.5 发布docker镜像支持arm64平台升级modelscope版本dokcer镜像版本funasr-runtime-sdk-en-cpu-0.1.5 (7cca2abc5901)
- 2024/01/25: 英文离线文件转写服务 1.3 发布优化vad数据处理方式大幅降低峰值内存占用内存泄漏优化dokcer镜像版本funasr-runtime-sdk-en-cpu-0.1.3 (c00f9ce7a195)
- 2024/01/03: 英文离线文件转写服务 1.2 发布修复已知的crash问题及内存泄漏问题dokcer镜像版本funasr-runtime-sdk-en-cpu-0.1.2 (0cdd9f4a4bb5)
@ -36,6 +37,7 @@ FunASR实时语音听写服务软件包既可以实时地进行语音转文
为了支持不同用户的需求,针对不同场景,准备了不同的图文教程:
### 最新动态
- 2024/05/15: 中文实时语音听写服务 1.10 发布适配FunASR 1.0模型结构dokcer镜像版本funasr-runtime-sdk-online-cpu-0.1.10 (1c2adfcff84d)
- 2024/03/05: 中文实时语音听写服务 1.9 发布docker镜像支持arm64平台升级modelscope版本dokcer镜像版本funasr-runtime-sdk-online-cpu-0.1.9 (4a875e08c7a2)
- 2024/01/25: 中文实时语音听写服务 1.7 发布客户端优化dokcer镜像版本funasr-runtime-sdk-online-cpu-0.1.7 (2aa23805572e)
- 2024/01/03: 中文实时语音听写服务 1.6 发布2pass-offline模式支持Ngram语言模型解码、wfst热词同时修复已知的crash问题及内存泄漏问题dokcer镜像版本funasr-runtime-sdk-online-cpu-0.1.6 (f99925110d27)
@ -58,6 +60,7 @@ FunASR实时语音听写服务软件包既可以实时地进行语音转文
为了支持不同用户的需求,针对不同场景,准备了不同的图文教程:
### 最新动态
- 2024/05/15: 中文离线文件转写服务 4.5 发布适配FunASR 1.0模型结构dokcer镜像版本funasr-runtime-sdk-cpu-0.4.5 (058b9882ae67)
- 2024/03/05: 中文离线文件转写服务 4.4 发布docker镜像支持arm64平台升级modelscope版本dokcer镜像版本funasr-runtime-sdk-cpu-0.4.4 (2dc87b86dc49)
- 2024/01/25: 中文离线文件转写服务 4.2 发布优化vad数据处理方式大幅降低峰值内存占用内存泄漏优化dokcer镜像版本funasr-runtime-sdk-cpu-0.4.2 (befdc7b179ed)
- 2024/01/08: 中文离线文件转写服务 4.1 发布优化句子级时间戳json格式dokcer镜像版本funasr-runtime-sdk-cpu-0.4.1 (0250f8ef981b)

View File

@ -55,11 +55,11 @@ int main(int argc, char* argv[]) {
TCLAP::ValueArg<std::string> offline_model_revision(
"", "offline-model-revision", "ASR offline model revision", false,
"v1.2.1", "string");
"v2.0.4", "string");
TCLAP::ValueArg<std::string> online_model_revision(
"", "online-model-revision", "ASR online model revision", false,
"v1.0.6", "string");
"v2.0.4", "string");
TCLAP::ValueArg<std::string> quantize(
"", QUANTIZE,
@ -73,7 +73,7 @@ int main(int argc, char* argv[]) {
"model_quant.onnx, vad.yaml, vad.mvn",
false, "damo/speech_fsmn_vad_zh-cn-16k-common-onnx", "string");
TCLAP::ValueArg<std::string> vad_revision(
"", "vad-revision", "VAD model revision", false, "v1.2.0", "string");
"", "vad-revision", "VAD model revision", false, "v2.0.4", "string");
TCLAP::ValueArg<std::string> vad_quant(
"", VAD_QUANT,
"true (Default), load the model of model_quant.onnx in vad_dir. If set "
@ -85,7 +85,7 @@ int main(int argc, char* argv[]) {
"model_quant.onnx, punc.yaml",
false, "damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727-onnx", "string");
TCLAP::ValueArg<std::string> punc_revision(
"", "punc-revision", "PUNC model revision", false, "v1.0.2", "string");
"", "punc-revision", "PUNC model revision", false, "v2.0.4", "string");
TCLAP::ValueArg<std::string> punc_quant(
"", PUNC_QUANT,
"true (Default), load the model of model_quant.onnx in punc_dir. If "
@ -262,17 +262,17 @@ int main(int argc, char* argv[]) {
size_t found = s_offline_asr_path.find("speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404");
if (found != std::string::npos) {
model_path["offline-model-revision"]="v1.2.4";
model_path["offline-model-revision"]="v2.0.4";
}
found = s_offline_asr_path.find("speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404");
if (found != std::string::npos) {
model_path["offline-model-revision"]="v1.0.5";
model_path["offline-model-revision"]="v2.0.5";
}
found = s_offline_asr_path.find("speech_paraformer-large_asr_nat-en-16k-common-vocab10020");
if (found != std::string::npos) {
model_path["model-revision"]="v1.0.0";
model_path["model-revision"]="v2.0.4";
s_itn_path="";
s_lm_path="";
}

View File

@ -50,7 +50,7 @@ int main(int argc, char* argv[]) {
TCLAP::ValueArg<std::string> model_revision(
"", "model-revision",
"ASR model revision",
false, "v1.2.1", "string");
false, "v2.0.4", "string");
TCLAP::ValueArg<std::string> quantize(
"", QUANTIZE,
"true (Default), load the model of model_quant.onnx in model_dir. If set "
@ -63,7 +63,7 @@ int main(int argc, char* argv[]) {
TCLAP::ValueArg<std::string> vad_revision(
"", "vad-revision",
"VAD model revision",
false, "v1.2.0", "string");
false, "v2.0.4", "string");
TCLAP::ValueArg<std::string> vad_quant(
"", VAD_QUANT,
"true (Default), load the model of model_quant.onnx in vad_dir. If set "
@ -77,7 +77,7 @@ int main(int argc, char* argv[]) {
TCLAP::ValueArg<std::string> punc_revision(
"", "punc-revision",
"PUNC model revision",
false, "v1.1.7", "string");
false, "v2.0.4", "string");
TCLAP::ValueArg<std::string> punc_quant(
"", PUNC_QUANT,
"true (Default), load the model of model_quant.onnx in punc_dir. If set "
@ -233,17 +233,17 @@ int main(int argc, char* argv[]) {
// modify model-revision by model name
size_t found = s_asr_path.find("speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404");
if (found != std::string::npos) {
model_path["model-revision"]="v1.2.4";
model_path["model-revision"]="v2.0.4";
}
found = s_asr_path.find("speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404");
if (found != std::string::npos) {
model_path["model-revision"]="v1.0.5";
model_path["model-revision"]="v2.0.5";
}
found = s_asr_path.find("speech_paraformer-large_asr_nat-en-16k-common-vocab10020");
if (found != std::string::npos) {
model_path["model-revision"]="v1.0.0";
model_path["model-revision"]="v2.0.4";
s_itn_path="";
s_lm_path="";
}