diff --git a/README.md b/README.md
index 8b093bcf4..ba23f3fcd 100644
--- a/README.md
+++ b/README.md
@@ -28,6 +28,7 @@
## What's new:
+- 2024/05/15: Offline File Transcription Service 4.5, Offline File Transcription Service of English 1.6, Real-time Transcription Service 1.10 released, adapting to FunASR 1.0 model structure; ([docs](runtime/readme.md))
- 2024/03/05:Added the Qwen-Audio and Qwen-Audio-Chat large-scale audio-text multimodal models, which have topped multiple audio domain leaderboards. These models support speech dialogue, [usage](examples/industrial_data_pretraining/qwen_audio).
- 2024/03/05:Added support for the Whisper-large-v3 model, a multitasking model that can perform multilingual speech recognition, speech translation, and language identification. It can be downloaded from the[modelscope](examples/industrial_data_pretraining/whisper/demo.py), and [openai](examples/industrial_data_pretraining/whisper/demo_from_openai.py).
- 2024/03/05: Offline File Transcription Service 4.4, Offline File Transcription Service of English 1.5,Real-time Transcription Service 1.9 released,docker image supports ARM64 platform, update modelscope;([docs](runtime/readme.md))
diff --git a/README_zh.md b/README_zh.md
index 963469a43..44f92e63b 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -29,6 +29,7 @@ FunASR希望在语音识别的学术研究和工业应用之间架起一座桥
## 最新动态
+- 2024/05/15: 中文离线文件转写服务 4.5、英文离线文件转写服务 1.6、中文实时语音听写服务 1.10 发布,适配FunASR 1.0模型结构;详细信息参阅([部署文档](runtime/readme_cn.md))
- 2024/03/05:新增加Qwen-Audio与Qwen-Audio-Chat音频文本模态大模型,在多个音频领域测试榜单刷榜,中支持语音对话,详细用法见 [示例](examples/industrial_data_pretraining/qwen_audio)。
- 2024/03/05:新增加Whisper-large-v3模型支持,多语言语音识别/翻译/语种识别,支持从 [modelscope](examples/industrial_data_pretraining/whisper/demo.py)仓库下载,也支持从 [openai](examples/industrial_data_pretraining/whisper/demo_from_openai.py)仓库下载模型。
- 2024/03/05: 中文离线文件转写服务 4.4、英文离线文件转写服务 1.5、中文实时语音听写服务 1.9 发布,docker镜像支持arm64平台,升级modelscope版本;详细信息参阅([部署文档](runtime/readme_cn.md))
diff --git a/funasr/download/download_from_hub.py b/funasr/download/download_from_hub.py
index 43f5b67e6..075b13118 100644
--- a/funasr/download/download_from_hub.py
+++ b/funasr/download/download_from_hub.py
@@ -10,7 +10,7 @@ def download_model(**kwargs):
if hub == "ms":
kwargs = download_from_ms(**kwargs)
elif hub == "hf":
- pass
+ kwargs = download_from_hf(**kwargs)
elif hub == "openai":
model_or_path = kwargs.get("model")
if os.path.exists(model_or_path):
@@ -87,6 +87,67 @@ def download_from_ms(**kwargs):
return kwargs
+def download_from_hf(**kwargs):
+ model_or_path = kwargs.get("model")
+ if model_or_path in name_maps_hf:
+ model_or_path = name_maps_hf[model_or_path]
+ model_revision = kwargs.get("model_revision", "master")
+ if not os.path.exists(model_or_path) and "model_path" not in kwargs:
+ try:
+ model_or_path = get_or_download_model_dir_hf(
+ model_or_path,
+ model_revision,
+ is_training=kwargs.get("is_training"),
+ check_latest=kwargs.get("check_latest", True),
+ )
+ except Exception as e:
+ print(f"Download: {model_or_path} failed!: {e}")
+
+ kwargs["model_path"] = model_or_path if "model_path" not in kwargs else kwargs["model_path"]
+
+ if os.path.exists(os.path.join(model_or_path, "configuration.json")):
+ with open(os.path.join(model_or_path, "configuration.json"), "r", encoding="utf-8") as f:
+ conf_json = json.load(f)
+
+ cfg = {}
+ if "file_path_metas" in conf_json:
+ add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
+ cfg.update(kwargs)
+ if "config" in cfg:
+ config = OmegaConf.load(cfg["config"])
+ kwargs = OmegaConf.merge(config, cfg)
+ kwargs["model"] = config["model"]
+ elif os.path.exists(os.path.join(model_or_path, "config.yaml")) and os.path.exists(
+ os.path.join(model_or_path, "model.pt")
+ ):
+ config = OmegaConf.load(os.path.join(model_or_path, "config.yaml"))
+ kwargs = OmegaConf.merge(config, kwargs)
+ init_param = os.path.join(model_or_path, "model.pb")
+ kwargs["init_param"] = init_param
+ if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
+ kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
+ if os.path.exists(os.path.join(model_or_path, "tokens.json")):
+ kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.json")
+ if os.path.exists(os.path.join(model_or_path, "seg_dict")):
+ kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
+ if os.path.exists(os.path.join(model_or_path, "bpe.model")):
+ kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
+ kwargs["model"] = config["model"]
+ if os.path.exists(os.path.join(model_or_path, "am.mvn")):
+ kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
+ if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
+ kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")
+ if isinstance(kwargs, DictConfig):
+ kwargs = OmegaConf.to_container(kwargs, resolve=True)
+ if os.path.exists(os.path.join(model_or_path, "requirements.txt")):
+ requirements = os.path.join(model_or_path, "requirements.txt")
+ print(f"Detect model requirements, begin to install it: {requirements}")
+ from funasr.utils.install_model_requirements import install_requirements
+
+ install_requirements(requirements)
+ return kwargs
+
+
def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg={}):
if isinstance(file_path_metas, dict):
@@ -136,3 +197,22 @@ def get_or_download_model_dir(
model, revision=model_revision, user_agent={Invoke.KEY: key, ThirdParty.KEY: "funasr"}
)
return model_cache_dir
+
+
+def get_or_download_model_dir_hf(
+ model,
+ model_revision=None,
+ is_training=False,
+ check_latest=True,
+):
+ """Get local model directory or download model if necessary.
+
+ Args:
+ model (str): model id or path to local model directory.
+ model_revision (str, optional): model version number.
+ :param is_training:
+ """
+ from huggingface_hub import snapshot_download
+
+ model_cache_dir = snapshot_download(model)
+ return model_cache_dir
diff --git a/funasr/download/name_maps_from_hub.py b/funasr/download/name_maps_from_hub.py
index 87a89fcc6..3bb25a7ca 100644
--- a/funasr/download/name_maps_from_hub.py
+++ b/funasr/download/name_maps_from_hub.py
@@ -14,7 +14,9 @@ name_maps_ms = {
"Qwen-Audio": "Qwen/Qwen-Audio",
}
-name_maps_hf = {}
+name_maps_hf = {
+ "": "",
+}
name_maps_openai = {
"Whisper-tiny.en": "tiny.en",
diff --git a/funasr/download/runtime_sdk_download_tool.py b/funasr/download/runtime_sdk_download_tool.py
index 7776a7116..96c67355e 100644
--- a/funasr/download/runtime_sdk_download_tool.py
+++ b/funasr/download/runtime_sdk_download_tool.py
@@ -20,6 +20,7 @@ def main():
args = parser.parse_args()
model_dir = args.model_name
+ output_dir = args.model_name
if not Path(args.model_name).exists():
from modelscope.hub.snapshot_download import snapshot_download
@@ -27,6 +28,7 @@ def main():
model_dir = snapshot_download(
args.model_name, cache_dir=args.export_dir, revision=args.model_revision
)
+ output_dir = os.path.join(args.export_dir, args.model_name)
except:
raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
model_dir
@@ -37,15 +39,13 @@ def main():
model_file = os.path.join(model_dir, "model_quant.onnx")
if not os.path.exists(model_file):
print(".onnx is not exist, begin to export onnx")
- from funasr.bin.export_model import ModelExport
+ from funasr import AutoModel
- export_model = ModelExport(
- cache_dir=args.export_dir,
- onnx=True,
- device="cpu",
- quant=args.quantize,
- )
- export_model.export(model_dir)
+ export_model = AutoModel(model=args.model_name, output_dir=output_dir)
+ export_model.export(
+ quantize=args.quantize,
+ type=args.type,
+ )
if __name__ == "__main__":
diff --git a/funasr/models/emotion2vec/model.py b/funasr/models/emotion2vec/model.py
index 48b87160f..d18e1844c 100644
--- a/funasr/models/emotion2vec/model.py
+++ b/funasr/models/emotion2vec/model.py
@@ -249,10 +249,17 @@ class Emotion2vec(torch.nn.Module):
if self.proj:
x = x.mean(dim=1)
x = self.proj(x)
+ for idx, lab in enumerate(labels):
+ x[:,idx] = -np.inf if lab.startswith("unuse") else x[:,idx]
x = torch.softmax(x, dim=-1)
scores = x[0].tolist()
- result_i = {"key": key[i], "labels": labels, "scores": scores}
+ select_label = [lb for lb in labels if not lb.startswith("unuse")]
+ select_score = [scores[idx] for idx, lb in enumerate(labels) if not lb.startswith("unuse")]
+
+ # result_i = {"key": key[i], "labels": labels, "scores": scores}
+ result_i = {"key": key[i], "labels": select_label, "scores": select_score}
+
if extract_embedding:
result_i["feats"] = feats
results.append(result_i)
diff --git a/funasr/models/sense_voice/whisper_lib/decoding.py b/funasr/models/sense_voice/whisper_lib/decoding.py
index 609d6a607..a468efaa9 100644
--- a/funasr/models/sense_voice/whisper_lib/decoding.py
+++ b/funasr/models/sense_voice/whisper_lib/decoding.py
@@ -63,8 +63,8 @@ def detect_language(
else:
x = x.to(mel.device)
# FIX(funasr): sense vocie
- # logits = model.logits(x[:, :-1], mel)[:, -1]
- logits = model.logits(x[:, :], mel)[:, -1]
+ logits = model.logits(x[:, :-1], mel)[:, -1]
+ # logits = model.logits(x[:, :], mel)[:, -1]
# collect detected languages; suppress all non-language tokens
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
diff --git a/funasr/version.txt b/funasr/version.txt
index 7717884db..3f11ef630 100644
--- a/funasr/version.txt
+++ b/funasr/version.txt
@@ -1 +1 @@
-1.0.26
\ No newline at end of file
+1.0.27
\ No newline at end of file
diff --git a/runtime/docs/SDK_advanced_guide_offline.md b/runtime/docs/SDK_advanced_guide_offline.md
index d975b53d9..799727f8f 100644
--- a/runtime/docs/SDK_advanced_guide_offline.md
+++ b/runtime/docs/SDK_advanced_guide_offline.md
@@ -12,6 +12,7 @@ This document serves as a development guide for the FunASR offline file transcri
| TIME | INFO | IMAGE VERSION | IMAGE ID |
|------------|----------------------------------------------------------------------------------------------------------------------------------|------------------------------|--------------|
+| 2024.05.15 | Adapting to FunASR 1.0 model structure | funasr-runtime-sdk-cpu-0.4.5 | 058b9882ae67 |
| 2024.03.05 | docker image supports ARM64 platform, update modelscope | funasr-runtime-sdk-cpu-0.4.4 | 2dc87b86dc49 |
| 2024.01.25 | Optimized the VAD (Voice Activity Detection) data processing method, significantly reducing peak memory usage; memory leak optimization| funasr-runtime-sdk-cpu-0.4.2 | befdc7b179ed |
| 2024.01.08 | optimized format sentence-level timestamps | funasr-runtime-sdk-cpu-0.4.1 | 0250f8ef981b |
@@ -34,9 +35,9 @@ If you do not have Docker installed, please refer to [Docker Installation](https
### Pulling and launching images
Use the following command to pull and launch the Docker image for the FunASR runtime-SDK:
```shell
-sudo docker pull registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.4
+sudo docker pull registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.5
-sudo docker run -p 10095:10095 -it --privileged=true -v /root:/workspace/models registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.4
+sudo docker run -p 10095:10095 -it --privileged=true -v /root:/workspace/models registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.5
```
Introduction to command parameters:
diff --git a/runtime/docs/SDK_advanced_guide_offline_en.md b/runtime/docs/SDK_advanced_guide_offline_en.md
index 220c10c68..4f61416a3 100644
--- a/runtime/docs/SDK_advanced_guide_offline_en.md
+++ b/runtime/docs/SDK_advanced_guide_offline_en.md
@@ -6,6 +6,7 @@ This document serves as a development guide for the FunASR offline file transcri
| TIME | INFO | IMAGE VERSION | IMAGE ID |
|------------|-----------------------------------------|---------------------------------|--------------|
+| 2024.05.15 | Adapting to FunASR 1.0 model structure | funasr-runtime-sdk-en-cpu-0.1.6 | 84d781d07997 |
| 2024.03.05 | docker image supports ARM64 platform, update modelscope | funasr-runtime-sdk-en-cpu-0.1.5 | 7cca2abc5901 |
| 2024.01.25 | Optimized the VAD (Voice Activity Detection) data processing method, significantly reducing peak memory usage; memory leak optimization| funasr-runtime-sdk-en-cpu-0.1.3 | c00f9ce7a195 |
| 2024.01.03 | fixed known crash issues as well as memory leak problems | funasr-runtime-sdk-en-cpu-0.1.2 | 0cdd9f4a4bb5 |
@@ -24,9 +25,9 @@ If you do not have Docker installed, please refer to [Docker Installation](https
### Pulling and launching images
Use the following command to pull and launch the Docker image for the FunASR runtime-SDK:
```shell
-sudo docker pull registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-en-cpu-0.1.5
+sudo docker pull registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-en-cpu-0.1.6
-sudo docker run -p 10097:10095 -it --privileged=true -v /root:/workspace/models registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-en-cpu-0.1.5
+sudo docker run -p 10097:10095 -it --privileged=true -v /root:/workspace/models registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-en-cpu-0.1.6
```
Introduction to command parameters:
```text
diff --git a/runtime/docs/SDK_advanced_guide_offline_en_zh.md b/runtime/docs/SDK_advanced_guide_offline_en_zh.md
index b1ce1ee5d..3d6534312 100644
--- a/runtime/docs/SDK_advanced_guide_offline_en_zh.md
+++ b/runtime/docs/SDK_advanced_guide_offline_en_zh.md
@@ -6,6 +6,7 @@ FunASR提供可一键本地或者云端服务器部署的英文离线文件转
| 时间 | 详情 | 镜像版本 | 镜像ID |
|------------|---------------|---------------------------------|--------------|
+| 2024.05.15 | 适配FunASR 1.0模型结构 | funasr-runtime-sdk-en-cpu-0.1.6 | 84d781d07997 |
| 2024.03.05 | docker镜像支持arm64平台,升级modelscope版本 | funasr-runtime-sdk-en-cpu-0.1.5 | 7cca2abc5901 |
| 2024.01.25 | 优化vad数据处理方式,大幅降低峰值内存占用;内存泄漏优化 | funasr-runtime-sdk-en-cpu-0.1.3 | c00f9ce7a195 |
| 2024.01.03 | 修复已知的crash问题及内存泄漏问题 | funasr-runtime-sdk-en-cpu-0.1.2 | 0cdd9f4a4bb5 |
@@ -39,11 +40,11 @@ docker安装失败请参考 [Docker Installation](https://alibaba-damo-academy.g
通过下述命令拉取并启动FunASR runtime-SDK的docker镜像:
```shell
sudo docker pull \
- registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-en-cpu-0.1.5
+ registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-en-cpu-0.1.6
mkdir -p ./funasr-runtime-resources/models
sudo docker run -p 10097:10095 -it --privileged=true \
-v $PWD/funasr-runtime-resources/models:/workspace/models \
- registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-en-cpu-0.1.5
+ registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-en-cpu-0.1.6
```
### 服务端启动
diff --git a/runtime/docs/SDK_advanced_guide_offline_zh.md b/runtime/docs/SDK_advanced_guide_offline_zh.md
index ef4cfd23e..1cecb8881 100644
--- a/runtime/docs/SDK_advanced_guide_offline_zh.md
+++ b/runtime/docs/SDK_advanced_guide_offline_zh.md
@@ -10,6 +10,7 @@ FunASR离线文件转写软件包,提供了一款功能强大的语音离线
| 时间 | 详情 | 镜像版本 | 镜像ID |
|------------|---------------------------------------------------|------------------------------|--------------|
+| 2024.05.15 | 适配FunASR 1.0模型结构 | funasr-runtime-sdk-cpu-0.4.5 | 058b9882ae67 |
| 2024.03.05 | docker镜像支持arm64平台,升级modelscope版本 | funasr-runtime-sdk-cpu-0.4.4 | 2dc87b86dc49 |
| 2024.01.25 | 优化vad数据处理方式,大幅降低峰值内存占用;内存泄漏优化| funasr-runtime-sdk-cpu-0.4.2 | befdc7b179ed |
| 2024.01.08 | 优化句子级时间戳json格式 | funasr-runtime-sdk-cpu-0.4.1 | 0250f8ef981b |
@@ -48,11 +49,11 @@ docker安装失败请参考 [Docker Installation](https://alibaba-damo-academy.g
```shell
sudo docker pull \
- registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.4
+ registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.5
mkdir -p ./funasr-runtime-resources/models
sudo docker run -p 10095:10095 -it --privileged=true \
-v $PWD/funasr-runtime-resources/models:/workspace/models \
- registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.4
+ registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.5
```
### 服务端启动
diff --git a/runtime/docs/SDK_advanced_guide_online.md b/runtime/docs/SDK_advanced_guide_online.md
index 34b601c31..be9e5e8e8 100644
--- a/runtime/docs/SDK_advanced_guide_online.md
+++ b/runtime/docs/SDK_advanced_guide_online.md
@@ -8,6 +8,7 @@ FunASR Real-time Speech Recognition Software Package integrates real-time versio
| TIME | INFO | IMAGE VERSION | IMAGE ID |
|------------|-------------------------------------------------------------------------------------|-------------------------------------|--------------|
+| 2024.05.15 | Adapting to FunASR 1.0 model structure | funasr-runtime-sdk-online-cpu-0.1.10 | 1c2adfcff84d |
| 2024.03.05 | docker image supports ARM64 platform, update modelscope | funasr-runtime-sdk-online-cpu-0.1.9 | 4a875e08c7a2 |
| 2024.01.25 | Optimization of the client-side | funasr-runtime-sdk-online-cpu-0.1.7 | 2aa23805572e |
| 2024.01.03 | The 2pass-offline mode supports Ngram language model decoding and WFST hotwords, while also addressing known crash issues and memory leak problems | funasr-runtime-sdk-online-cpu-0.1.6 | f99925110d27 |
@@ -29,9 +30,9 @@ If you do not have Docker installed, please refer to [Docker Installation](https
### Pull Docker Image
Use the following command to pull and start the FunASR software package docker image:
```shell
-sudo docker pull registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.9
+sudo docker pull registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.10
mkdir -p ./funasr-runtime-resources/models
-sudo docker run -p 10096:10095 -it --privileged=true -v $PWD/funasr-runtime-resources/models:/workspace/models registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.9
+sudo docker run -p 10096:10095 -it --privileged=true -v $PWD/funasr-runtime-resources/models:/workspace/models registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.10
```
### Launching the Server
diff --git a/runtime/docs/SDK_advanced_guide_online_zh.md b/runtime/docs/SDK_advanced_guide_online_zh.md
index 4b72e6d31..26ca4bcb3 100644
--- a/runtime/docs/SDK_advanced_guide_online_zh.md
+++ b/runtime/docs/SDK_advanced_guide_online_zh.md
@@ -12,6 +12,7 @@ FunASR实时语音听写软件包,集成了实时版本的语音端点检测
| 时间 | 详情 | 镜像版本 | 镜像ID |
|:-----------|:----------------------------------|--------------------------------------|--------------|
+| 2024.05.15 | 适配FunASR 1.0模型结构 | funasr-runtime-sdk-online-cpu-0.1.10 | 1c2adfcff84d |
| 2024.03.05 | docker镜像支持arm64平台,升级modelscope版本 | funasr-runtime-sdk-online-cpu-0.1.9 | 4a875e08c7a2 |
| 2024.01.25 | 客户端优化| funasr-runtime-sdk-online-cpu-0.1.7 | 2aa23805572e |
| 2024.01.03 | 2pass-offline模式支持Ngram语言模型解码、wfst热词,同时修复已知的crash问题及内存泄漏问题 | funasr-runtime-sdk-online-cpu-0.1.6 | f99925110d27 |
@@ -38,11 +39,11 @@ docker安装失败请参考 [Docker Installation](https://alibaba-damo-academy.g
```shell
sudo docker pull \
- registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.9
+ registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.10
mkdir -p ./funasr-runtime-resources/models
sudo docker run -p 10096:10095 -it --privileged=true \
-v $PWD/funasr-runtime-resources/models:/workspace/models \
- registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.9
+ registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.10
```
### 服务端启动
diff --git a/runtime/docs/docker_offline_cpu_en_lists b/runtime/docs/docker_offline_cpu_en_lists
index 52b14a884..40e5852df 100644
--- a/runtime/docs/docker_offline_cpu_en_lists
+++ b/runtime/docs/docker_offline_cpu_en_lists
@@ -1,7 +1,7 @@
DOCKER:
+ funasr-runtime-sdk-en-cpu-0.1.6
funasr-runtime-sdk-en-cpu-0.1.5
funasr-runtime-sdk-en-cpu-0.1.4
- funasr-runtime-sdk-en-cpu-0.1.3
DEFAULT_ASR_MODEL:
damo/speech_paraformer-large_asr_nat-en-16k-common-vocab10020-onnx
DEFAULT_VAD_MODEL:
diff --git a/runtime/docs/docker_offline_cpu_zh_lists b/runtime/docs/docker_offline_cpu_zh_lists
index ccd5d95db..e3d9efcf7 100644
--- a/runtime/docs/docker_offline_cpu_zh_lists
+++ b/runtime/docs/docker_offline_cpu_zh_lists
@@ -1,5 +1,5 @@
DOCKER:
- funasr-runtime-sdk-cpu-0.4.4
+ funasr-runtime-sdk-cpu-0.4.5
funasr-runtime-sdk-cpu-0.3.0
funasr-runtime-sdk-cpu-0.2.2
DEFAULT_ASR_MODEL:
diff --git a/runtime/docs/docker_online_cpu_zh_lists b/runtime/docs/docker_online_cpu_zh_lists
index c4ac16b64..4cb5ca045 100644
--- a/runtime/docs/docker_online_cpu_zh_lists
+++ b/runtime/docs/docker_online_cpu_zh_lists
@@ -1,7 +1,7 @@
DOCKER:
+ funasr-runtime-sdk-online-cpu-0.1.10
funasr-runtime-sdk-online-cpu-0.1.9
funasr-runtime-sdk-online-cpu-0.1.8
- funasr-runtime-sdk-online-cpu-0.1.7
DEFAULT_ASR_MODEL:
damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx
damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-onnx
diff --git a/runtime/onnxruntime/CMakeLists.txt b/runtime/onnxruntime/CMakeLists.txt
index 3450be78a..d8e623e1d 100644
--- a/runtime/onnxruntime/CMakeLists.txt
+++ b/runtime/onnxruntime/CMakeLists.txt
@@ -18,6 +18,17 @@ else()
message("Little endian system")
endif()
+# json: fetch nlohmann/json v3.11.2 into third_party/json unless third_party/json/ChangeLog.md already exists
+include(FetchContent)
+if(NOT EXISTS ${PROJECT_SOURCE_DIR}/third_party/json/ChangeLog.md )
+FetchContent_Declare(json
+ URL https://github.com/nlohmann/json/archive/refs/tags/v3.11.2.tar.gz
+SOURCE_DIR ${PROJECT_SOURCE_DIR}/third_party/json
+)
+
+FetchContent_MakeAvailable(json)
+endif()
+
# for onnxruntime
IF(WIN32)
file(REMOVE ${PROJECT_SOURCE_DIR}/third_party/glog/src/config.h
@@ -36,6 +47,7 @@ include_directories(${PROJECT_SOURCE_DIR}/third_party/yaml-cpp/include)
include_directories(${PROJECT_SOURCE_DIR}/third_party/jieba/include)
include_directories(${PROJECT_SOURCE_DIR}/third_party/jieba/include/limonp/include)
include_directories(${PROJECT_SOURCE_DIR}/third_party/kaldi)
+include_directories(${PROJECT_SOURCE_DIR}/third_party/json/include)
if(ENABLE_GLOG)
include_directories(${PROJECT_SOURCE_DIR}/third_party/glog/src)
diff --git a/runtime/onnxruntime/include/com-define.h b/runtime/onnxruntime/include/com-define.h
index 9cb1f2c96..d4edd5bc5 100644
--- a/runtime/onnxruntime/include/com-define.h
+++ b/runtime/onnxruntime/include/com-define.h
@@ -49,13 +49,14 @@ namespace funasr {
// hotword embedding compile model
#define MODEL_EB_NAME "model_eb.onnx"
#define QUANT_MODEL_NAME "model_quant.onnx"
-#define VAD_CMVN_NAME "vad.mvn"
-#define VAD_CONFIG_NAME "vad.yaml"
+#define VAD_CMVN_NAME "am.mvn"
+#define VAD_CONFIG_NAME "config.yaml"
#define AM_CMVN_NAME "am.mvn"
#define AM_CONFIG_NAME "config.yaml"
#define LM_CONFIG_NAME "config.yaml"
-#define PUNC_CONFIG_NAME "punc.yaml"
+#define PUNC_CONFIG_NAME "config.yaml"
#define MODEL_SEG_DICT "seg_dict"
+#define TOKEN_PATH "tokens.json"
#define HOTWORD "hotword"
// #define NN_HOTWORD "nn-hotword"
diff --git a/runtime/onnxruntime/include/model.h b/runtime/onnxruntime/include/model.h
index 33caec806..f5c4027d2 100644
--- a/runtime/onnxruntime/include/model.h
+++ b/runtime/onnxruntime/include/model.h
@@ -12,9 +12,9 @@ class Model {
virtual void StartUtterance() = 0;
virtual void EndUtterance() = 0;
virtual void Reset() = 0;
- virtual void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
- virtual void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
- virtual void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
+ virtual void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){};
+ virtual void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){};
+ virtual void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){};
virtual void InitLm(const std::string &lm_file, const std::string &lm_config, const std::string &lex_file){};
virtual void InitFstDecoder(){};
virtual std::string Forward(float *din, int len, bool input_finished, const std::vector> &hw_emb={{0.0}}, void* wfst_decoder=nullptr){return "";};
diff --git a/runtime/onnxruntime/include/punc-model.h b/runtime/onnxruntime/include/punc-model.h
index 214c7700a..3cec2c1cf 100644
--- a/runtime/onnxruntime/include/punc-model.h
+++ b/runtime/onnxruntime/include/punc-model.h
@@ -11,7 +11,7 @@ namespace funasr {
class PuncModel {
public:
virtual ~PuncModel(){};
- virtual void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num)=0;
+ virtual void InitPunc(const std::string &punc_model, const std::string &punc_config, const std::string &token_file, int thread_num)=0;
virtual std::string AddPunc(const char* sz_input, std::string language="zh-cn"){return "";};
virtual std::string AddPunc(const char* sz_input, std::vector& arr_cache, std::string language="zh-cn"){return "";};
};
diff --git a/runtime/onnxruntime/src/ct-transformer-online.cpp b/runtime/onnxruntime/src/ct-transformer-online.cpp
index 4e9136ed4..92fe41e96 100644
--- a/runtime/onnxruntime/src/ct-transformer-online.cpp
+++ b/runtime/onnxruntime/src/ct-transformer-online.cpp
@@ -11,7 +11,7 @@ CTTransformerOnline::CTTransformerOnline()
{
}
-void CTTransformerOnline::InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num){
+void CTTransformerOnline::InitPunc(const std::string &punc_model, const std::string &punc_config, const std::string &token_file, int thread_num){
session_options.SetIntraOpNumThreads(thread_num);
session_options.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
session_options.DisableCpuMemArena();
@@ -43,7 +43,7 @@ void CTTransformerOnline::InitPunc(const std::string &punc_model, const std::str
for (auto& item : m_strOutputNames)
m_szOutputNames.push_back(item.c_str());
- m_tokenizer.OpenYaml(punc_config.c_str());
+ m_tokenizer.OpenYaml(punc_config.c_str(), token_file.c_str());
}
CTTransformerOnline::~CTTransformerOnline()
diff --git a/runtime/onnxruntime/src/ct-transformer-online.h b/runtime/onnxruntime/src/ct-transformer-online.h
index ea7edb7fa..13f40a0c1 100644
--- a/runtime/onnxruntime/src/ct-transformer-online.h
+++ b/runtime/onnxruntime/src/ct-transformer-online.h
@@ -26,7 +26,7 @@ private:
public:
CTTransformerOnline();
- void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num);
+ void InitPunc(const std::string &punc_model, const std::string &punc_config, const std::string &token_file, int thread_num);
~CTTransformerOnline();
vector Infer(vector input_data, int nCacheSize);
string AddPunc(const char* sz_input, vector &arr_cache, std::string language="zh-cn");
diff --git a/runtime/onnxruntime/src/ct-transformer.cpp b/runtime/onnxruntime/src/ct-transformer.cpp
index 8f8d95310..d1a7813b1 100644
--- a/runtime/onnxruntime/src/ct-transformer.cpp
+++ b/runtime/onnxruntime/src/ct-transformer.cpp
@@ -11,7 +11,7 @@ CTTransformer::CTTransformer()
{
}
-void CTTransformer::InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num){
+void CTTransformer::InitPunc(const std::string &punc_model, const std::string &punc_config, const std::string &token_file, int thread_num){
session_options.SetIntraOpNumThreads(thread_num);
session_options.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
session_options.DisableCpuMemArena();
@@ -39,7 +39,7 @@ void CTTransformer::InitPunc(const std::string &punc_model, const std::string &p
for (auto& item : m_strOutputNames)
m_szOutputNames.push_back(item.c_str());
- m_tokenizer.OpenYaml(punc_config.c_str());
+ m_tokenizer.OpenYaml(punc_config.c_str(), token_file.c_str());
m_tokenizer.JiebaInit(punc_config);
}
diff --git a/runtime/onnxruntime/src/ct-transformer.h b/runtime/onnxruntime/src/ct-transformer.h
index b33dcf55b..f38fe12a4 100644
--- a/runtime/onnxruntime/src/ct-transformer.h
+++ b/runtime/onnxruntime/src/ct-transformer.h
@@ -26,7 +26,7 @@ private:
public:
CTTransformer();
- void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num);
+ void InitPunc(const std::string &punc_model, const std::string &punc_config, const std::string &token_file, int thread_num);
~CTTransformer();
vector Infer(vector input_data);
string AddPunc(const char* sz_input, std::string language="zh-cn");
diff --git a/runtime/onnxruntime/src/fsmn-vad.cpp b/runtime/onnxruntime/src/fsmn-vad.cpp
index c83227405..42ce83b45 100644
--- a/runtime/onnxruntime/src/fsmn-vad.cpp
+++ b/runtime/onnxruntime/src/fsmn-vad.cpp
@@ -30,7 +30,7 @@ void FsmnVad::LoadConfigFromYaml(const char* filename){
try{
YAML::Node frontend_conf = config["frontend_conf"];
- YAML::Node post_conf = config["vad_post_conf"];
+ YAML::Node post_conf = config["model_conf"];
this->vad_sample_rate_ = frontend_conf["fs"].as();
this->vad_silence_duration_ = post_conf["max_end_silence_time"].as();
diff --git a/runtime/onnxruntime/src/model.cpp b/runtime/onnxruntime/src/model.cpp
index 646f26029..8b5e33f65 100644
--- a/runtime/onnxruntime/src/model.cpp
+++ b/runtime/onnxruntime/src/model.cpp
@@ -8,6 +8,7 @@ Model *CreateModel(std::map& model_path, int thread_nu
string am_model_path;
string am_cmvn_path;
string am_config_path;
+ string token_path;
am_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
@@ -15,10 +16,11 @@ Model *CreateModel(std::map& model_path, int thread_nu
}
am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME);
am_config_path = PathAppend(model_path.at(MODEL_DIR), AM_CONFIG_NAME);
+ token_path = PathAppend(model_path.at(MODEL_DIR), TOKEN_PATH);
Model *mm;
mm = new Paraformer();
- mm->InitAsr(am_model_path, am_cmvn_path, am_config_path, thread_num);
+ mm->InitAsr(am_model_path, am_cmvn_path, am_config_path, token_path, thread_num);
return mm;
}else if(type == ASR_ONLINE){
// online
@@ -26,6 +28,7 @@ Model *CreateModel(std::map& model_path, int thread_nu
string de_model_path;
string am_cmvn_path;
string am_config_path;
+ string token_path;
en_model_path = PathAppend(model_path.at(MODEL_DIR), ENCODER_NAME);
de_model_path = PathAppend(model_path.at(MODEL_DIR), DECODER_NAME);
@@ -35,10 +38,11 @@ Model *CreateModel(std::map& model_path, int thread_nu
}
am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME);
am_config_path = PathAppend(model_path.at(MODEL_DIR), AM_CONFIG_NAME);
+ token_path = PathAppend(model_path.at(MODEL_DIR), TOKEN_PATH);
Model *mm;
mm = new Paraformer();
- mm->InitAsr(en_model_path, de_model_path, am_cmvn_path, am_config_path, thread_num);
+ mm->InitAsr(en_model_path, de_model_path, am_cmvn_path, am_config_path, token_path, thread_num);
return mm;
}else{
LOG(ERROR)<<"Wrong ASR_TYPE : " << type;
diff --git a/runtime/onnxruntime/src/offline-stream.cpp b/runtime/onnxruntime/src/offline-stream.cpp
index ae8cf184f..7d86f9bc1 100644
--- a/runtime/onnxruntime/src/offline-stream.cpp
+++ b/runtime/onnxruntime/src/offline-stream.cpp
@@ -32,6 +32,7 @@ OfflineStream::OfflineStream(std::map& model_path, int
string am_model_path;
string am_cmvn_path;
string am_config_path;
+ string token_path;
string hw_compile_model_path;
string seg_dict_path;
@@ -57,8 +58,9 @@ OfflineStream::OfflineStream(std::map& model_path, int
}
am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME);
am_config_path = PathAppend(model_path.at(MODEL_DIR), AM_CONFIG_NAME);
+ token_path = PathAppend(model_path.at(MODEL_DIR), TOKEN_PATH);
- asr_handle->InitAsr(am_model_path, am_cmvn_path, am_config_path, thread_num);
+ asr_handle->InitAsr(am_model_path, am_cmvn_path, am_config_path, token_path, thread_num);
}
// Lm resource
@@ -79,20 +81,23 @@ OfflineStream::OfflineStream(std::map& model_path, int
if(model_path.find(PUNC_DIR) != model_path.end()){
string punc_model_path;
string punc_config_path;
+ string token_path;
punc_model_path = PathAppend(model_path.at(PUNC_DIR), MODEL_NAME);
if(model_path.find(PUNC_QUANT) != model_path.end() && model_path.at(PUNC_QUANT) == "true"){
punc_model_path = PathAppend(model_path.at(PUNC_DIR), QUANT_MODEL_NAME);
}
punc_config_path = PathAppend(model_path.at(PUNC_DIR), PUNC_CONFIG_NAME);
+ token_path = PathAppend(model_path.at(PUNC_DIR), TOKEN_PATH);
if (access(punc_model_path.c_str(), F_OK) != 0 ||
- access(punc_config_path.c_str(), F_OK) != 0 )
+ access(punc_config_path.c_str(), F_OK) != 0 ||
+ access(token_path.c_str(), F_OK) != 0)
{
LOG(INFO) << "PUNC model file is not exist, skip load punc model.";
}else{
punc_handle = make_unique();
- punc_handle->InitPunc(punc_model_path, punc_config_path, thread_num);
+ punc_handle->InitPunc(punc_model_path, punc_config_path, token_path, thread_num);
use_punc = true;
}
}
diff --git a/runtime/onnxruntime/src/paraformer.cpp b/runtime/onnxruntime/src/paraformer.cpp
index c56421cca..a57fb9b84 100644
--- a/runtime/onnxruntime/src/paraformer.cpp
+++ b/runtime/onnxruntime/src/paraformer.cpp
@@ -18,7 +18,7 @@ Paraformer::Paraformer()
}
// offline
-void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
+void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){
LoadConfigFromYaml(am_config.c_str());
// knf options
fbank_opts_.frame_opts.dither = 0;
@@ -65,13 +65,13 @@ void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn
m_szInputNames.push_back(item.c_str());
for (auto& item : m_strOutputNames)
m_szOutputNames.push_back(item.c_str());
- vocab = new Vocab(am_config.c_str());
- phone_set_ = new PhoneSet(am_config.c_str());
+ vocab = new Vocab(token_file.c_str());
+ phone_set_ = new PhoneSet(token_file.c_str());
LoadCmvn(am_cmvn.c_str());
}
// online
-void Paraformer::InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
+void Paraformer::InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){
LoadOnlineConfigFromYaml(am_config.c_str());
// knf options
@@ -143,15 +143,15 @@ void Paraformer::InitAsr(const std::string &en_model, const std::string &de_mode
for (auto& item : de_strOutputNames)
de_szOutputNames_.push_back(item.c_str());
- vocab = new Vocab(am_config.c_str());
- phone_set_ = new PhoneSet(am_config.c_str());
+ vocab = new Vocab(token_file.c_str());
+ phone_set_ = new PhoneSet(token_file.c_str());
LoadCmvn(am_cmvn.c_str());
}
// 2pass
-void Paraformer::InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
+void Paraformer::InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){
// online
- InitAsr(en_model, de_model, am_cmvn, am_config, thread_num);
+ InitAsr(en_model, de_model, am_cmvn, am_config, token_file, thread_num);
// offline
try {
diff --git a/runtime/onnxruntime/src/paraformer.h b/runtime/onnxruntime/src/paraformer.h
index 5bb9477bf..417c2d7b8 100644
--- a/runtime/onnxruntime/src/paraformer.h
+++ b/runtime/onnxruntime/src/paraformer.h
@@ -42,11 +42,11 @@ namespace funasr {
public:
Paraformer();
~Paraformer();
- void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
+ void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
// online
- void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
+ void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
// 2pass
- void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
+ void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
void InitHwCompiler(const std::string &hw_model, int thread_num);
void InitSegDict(const std::string &seg_dict_model);
std::vector> CompileHotwordEmbedding(std::string &hotwords);
diff --git a/runtime/onnxruntime/src/phone-set.cpp b/runtime/onnxruntime/src/phone-set.cpp
index 167fa010a..60eb1019d 100644
--- a/runtime/onnxruntime/src/phone-set.cpp
+++ b/runtime/onnxruntime/src/phone-set.cpp
@@ -13,7 +13,7 @@ using namespace std;
namespace funasr {
PhoneSet::PhoneSet(const char *filename) {
ifstream in(filename);
- LoadPhoneSetFromYaml(filename);
+ LoadPhoneSetFromJson(filename);
}
PhoneSet::~PhoneSet()
{
@@ -35,6 +35,25 @@ void PhoneSet::LoadPhoneSetFromYaml(const char* filename) {
}
}
+// Loads the phone list from the JSON token file `filename`; each entry's id is its
+// position in the file. A missing or malformed file is fatal: logs and exits.
+void PhoneSet::LoadPhoneSetFromJson(const char* filename) {
+    std::ifstream file(filename);
+    // allow_exceptions=false: a parse failure yields a discarded value, not a throw.
+    const nlohmann::json json_array = nlohmann::json::parse(file, nullptr, false);
+    if (!file.is_open() || json_array.is_discarded()) {
+        LOG(ERROR) << "Error loading token file, token file error or not exist.";
+        exit(-1);
+    }
+
+    int id = 0;
+    for (const auto& element : json_array) {
+        phone_.push_back(element);
+        phn2Id_.emplace(element, id);
+        id++;
+    }
+}
+
int PhoneSet::Size() const {
return phone_.size();
}
diff --git a/runtime/onnxruntime/src/phone-set.h b/runtime/onnxruntime/src/phone-set.h
index 972910408..cb1a2a755 100644
--- a/runtime/onnxruntime/src/phone-set.h
+++ b/runtime/onnxruntime/src/phone-set.h
@@ -5,6 +5,7 @@
#include
#include
#include
+#include "nlohmann/json.hpp"
#define UNIT_BEG_SIL_SYMBOL ""
#define UNIT_END_SIL_SYMBOL ""
#define UNIT_BLK_SYMBOL ""
@@ -28,6 +29,7 @@ class PhoneSet {
vector phone_;
unordered_map phn2Id_;
void LoadPhoneSetFromYaml(const char* filename);
+ void LoadPhoneSetFromJson(const char* filename);
};
} // namespace funasr
diff --git a/runtime/onnxruntime/src/punc-model.cpp b/runtime/onnxruntime/src/punc-model.cpp
index 54b8d6a46..9af03db29 100644
--- a/runtime/onnxruntime/src/punc-model.cpp
+++ b/runtime/onnxruntime/src/punc-model.cpp
@@ -14,14 +14,16 @@ PuncModel *CreatePuncModel(std::map& model_path, int t
}
string punc_model_path;
string punc_config_path;
+ string token_file;
punc_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
punc_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
}
punc_config_path = PathAppend(model_path.at(MODEL_DIR), PUNC_CONFIG_NAME);
+ token_file = PathAppend(model_path.at(MODEL_DIR), TOKEN_PATH);
- mm->InitPunc(punc_model_path, punc_config_path, thread_num);
+ mm->InitPunc(punc_model_path, punc_config_path, token_file, thread_num);
return mm;
}
diff --git a/runtime/onnxruntime/src/tokenizer.cpp b/runtime/onnxruntime/src/tokenizer.cpp
index 761828269..06d64d86f 100644
--- a/runtime/onnxruntime/src/tokenizer.cpp
+++ b/runtime/onnxruntime/src/tokenizer.cpp
@@ -127,6 +127,61 @@ bool CTokenizer::OpenYaml(const char* sz_yamlfile)
return m_ready;
}
+// Loads the punctuation list from the YAML config `sz_yamlfile` (model_conf.punc_list)
+// and the token table from the JSON file `token_file`; a token's id is its position in
+// that file. Returns true (and sets m_ready) on success and false when the token file
+// or a config entry cannot be read. Exits the process if the YAML file fails to load.
+bool CTokenizer::OpenYaml(const char* sz_yamlfile, const char* token_file)
+{
+    YAML::Node m_Config;
+    try{
+        m_Config = YAML::LoadFile(sz_yamlfile);
+    }catch(exception const &e){
+        LOG(ERROR) << "Error loading file, yaml file error or not exist.";
+        exit(-1);
+    }
+
+    try {
+        YAML::Node conf_seg_jieba = m_Config["seg_jieba"];
+        if (conf_seg_jieba.IsDefined()){
+            seg_jieba = conf_seg_jieba.as();
+        }
+
+        auto Puncs = m_Config["model_conf"]["punc_list"];
+        if (Puncs.IsSequence()) {
+            for (size_t i = 0; i < Puncs.size(); ++i) {
+                if (Puncs[i].IsScalar()) {
+                    m_id2punc.push_back(Puncs[i].as());
+                    m_punc2id.insert(make_pair(Puncs[i].as(), i));
+                }
+            }
+        }
+
+        std::ifstream file(token_file);
+        if (!file.is_open()) {
+            LOG(ERROR) << "Error loading token file, token file error or not exist.";
+            return false;
+        }
+        // parse() and the element conversions below throw on malformed content;
+        // the handler at the end of this block turns that into a `false` return.
+        const nlohmann::json json_array = nlohmann::json::parse(file);
+
+        int i = 0;
+        for (const auto& element : json_array) {
+            m_id2token.push_back(element);
+            m_token2id[element] = i;
+            i++;
+        }
+    }
+    catch (const std::exception& e) {
+        // yaml-cpp and nlohmann::json exceptions both derive from std::exception.
+        LOG(ERROR) << "Read error! " << e.what();
+        return false;
+    }
+    m_ready = true;
+    return m_ready;
+}
+
vector CTokenizer::Id2String(vector input)
{
vector result;
diff --git a/runtime/onnxruntime/src/tokenizer.h b/runtime/onnxruntime/src/tokenizer.h
index 166061bd4..81aea7ed3 100644
--- a/runtime/onnxruntime/src/tokenizer.h
+++ b/runtime/onnxruntime/src/tokenizer.h
@@ -8,6 +8,7 @@
#include "cppjieba/DictTrie.hpp"
#include "cppjieba/HMMModel.hpp"
#include "cppjieba/Jieba.hpp"
+#include "nlohmann/json.hpp"
namespace funasr {
class CTokenizer {
@@ -27,6 +28,7 @@ public:
CTokenizer();
~CTokenizer();
bool OpenYaml(const char* sz_yamlfile);
+ bool OpenYaml(const char* sz_yamlfile, const char* token_file);
void ReadYaml(const YAML::Node& node);
vector Id2String(vector input);
vector String2Ids(vector input);
diff --git a/runtime/onnxruntime/src/tpass-stream.cpp b/runtime/onnxruntime/src/tpass-stream.cpp
index b723e0fa1..7681a4db0 100644
--- a/runtime/onnxruntime/src/tpass-stream.cpp
+++ b/runtime/onnxruntime/src/tpass-stream.cpp
@@ -35,6 +35,7 @@ TpassStream::TpassStream(std::map& model_path, int thr
string de_model_path;
string am_cmvn_path;
string am_config_path;
+ string token_path;
string hw_compile_model_path;
string seg_dict_path;
@@ -60,8 +61,9 @@ TpassStream::TpassStream(std::map& model_path, int thr
}
am_cmvn_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), AM_CMVN_NAME);
am_config_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), AM_CONFIG_NAME);
+ token_path = PathAppend(model_path.at(MODEL_DIR), TOKEN_PATH);
- asr_handle->InitAsr(am_model_path, en_model_path, de_model_path, am_cmvn_path, am_config_path, thread_num);
+ asr_handle->InitAsr(am_model_path, en_model_path, de_model_path, am_cmvn_path, am_config_path, token_path, thread_num);
}else{
LOG(ERROR) <<"Can not find offline-model-dir or online-model-dir";
exit(-1);
@@ -85,20 +87,23 @@ TpassStream::TpassStream(std::map& model_path, int thr
if(model_path.find(PUNC_DIR) != model_path.end()){
string punc_model_path;
string punc_config_path;
+ string token_path;
punc_model_path = PathAppend(model_path.at(PUNC_DIR), MODEL_NAME);
if(model_path.find(PUNC_QUANT) != model_path.end() && model_path.at(PUNC_QUANT) == "true"){
punc_model_path = PathAppend(model_path.at(PUNC_DIR), QUANT_MODEL_NAME);
}
punc_config_path = PathAppend(model_path.at(PUNC_DIR), PUNC_CONFIG_NAME);
+ token_path = PathAppend(model_path.at(PUNC_DIR), TOKEN_PATH);
if (access(punc_model_path.c_str(), F_OK) != 0 ||
- access(punc_config_path.c_str(), F_OK) != 0 )
+ access(punc_config_path.c_str(), F_OK) != 0 ||
+ access(token_path.c_str(), F_OK) != 0)
{
LOG(INFO) << "PUNC model file is not exist, skip load punc model.";
}else{
punc_online_handle = make_unique();
- punc_online_handle->InitPunc(punc_model_path, punc_config_path, thread_num);
+ punc_online_handle->InitPunc(punc_model_path, punc_config_path, token_path, thread_num);
use_punc = true;
}
}
diff --git a/runtime/onnxruntime/src/vocab.cpp b/runtime/onnxruntime/src/vocab.cpp
index 6991376a9..1416dd314 100644
--- a/runtime/onnxruntime/src/vocab.cpp
+++ b/runtime/onnxruntime/src/vocab.cpp
@@ -14,7 +14,7 @@ namespace funasr {
Vocab::Vocab(const char *filename)
{
ifstream in(filename);
- LoadVocabFromYaml(filename);
+ LoadVocabFromJson(filename);
}
Vocab::Vocab(const char *filename, const char *lex_file)
{
@@ -43,6 +43,25 @@ void Vocab::LoadVocabFromYaml(const char* filename){
}
}
+// Loads the vocabulary from the JSON token file `filename`; each token's id is its
+// position in the file. A missing or malformed file is fatal: logs and exits.
+void Vocab::LoadVocabFromJson(const char* filename){
+    std::ifstream file(filename);
+    // allow_exceptions=false: a parse failure yields a discarded value, not a throw.
+    const nlohmann::json json_array = nlohmann::json::parse(file, nullptr, false);
+    if (!file.is_open() || json_array.is_discarded()) {
+        LOG(ERROR) << "Error loading token file, token file error or not exist.";
+        exit(-1);
+    }
+
+    int i = 0;
+    for (const auto& element : json_array) {
+        vocab.push_back(element);
+        token_id[element] = i;
+        i++;
+    }
+}
+
void Vocab::LoadLex(const char* filename){
std::ifstream file(filename);
std::string line;
diff --git a/runtime/onnxruntime/src/vocab.h b/runtime/onnxruntime/src/vocab.h
index 19e364867..36fabf446 100644
--- a/runtime/onnxruntime/src/vocab.h
+++ b/runtime/onnxruntime/src/vocab.h
@@ -6,6 +6,7 @@
#include
#include
#include