mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Resolve merge conflict
This commit is contained in:
commit
8912e0696a
@ -27,7 +27,12 @@
|
||||
|
||||
|
||||
<a name="whats-new"></a>
|
||||
## What's new:
|
||||
## What's new:
|
||||
- 2024/01/09: The Funasr SDK for Windows version 2.0 has been released, featuring support for The offline file transcription service (CPU) of Mandarin 4.1, The offline file transcription service (CPU) of English 1.2, The real-time transcription service (CPU) of Mandarin 1.6. For more details, please refer to the official documentation or release notes([FunASR-Runtime-Windows](https://www.modelscope.cn/models/damo/funasr-runtime-win-cpu-x64/summary))
|
||||
- 2024/01/03: File Transcription Service 4.0 released, Added support for 8k models, optimized timestamp mismatch issues and added sentence-level timestamps, improved the effectiveness of English word FST hotwords, supported automated configuration of thread parameters, and fixed known crash issues as well as memory leak problems, refer to ([docs](runtime/readme.md#file-transcription-service-mandarin-cpu)).
|
||||
- 2024/01/03: Real-time Transcription Service 1.6 released,The 2pass-offline mode supports Ngram language model decoding and WFST hotwords, while also addressing known crash issues and memory leak problems, ([docs](runtime/readme.md#the-real-time-transcription-service-mandarin-cpu))
|
||||
- 2024/01/03: Fixed known crash issues as well as memory leak problems, ([docs](runtime/readme.md#file-transcription-service-english-cpu)).
|
||||
- 2023/12/04: The Funasr SDK for Windows version 1.0 has been released, featuring support for The offline file transcription service (CPU) of Mandarin, The offline file transcription service (CPU) of English, The real-time transcription service (CPU) of Mandarin. For more details, please refer to the official documentation or release notes([FunASR-Runtime-Windows](https://www.modelscope.cn/models/damo/funasr-runtime-win-cpu-x64/summary))
|
||||
- 2023/11/08: The offline file transcription service 3.0 (CPU) of Mandarin has been released, adding punctuation large model, Ngram language model, and wfst hot words. For detailed information, please refer to [docs](runtime#file-transcription-service-mandarin-cpu).
|
||||
- 2023/10/17: The offline file transcription service (CPU) of English has been released. For more details, please refer to ([docs](runtime#file-transcription-service-english-cpu)).
|
||||
- 2023/10/13: [SlideSpeech](https://slidespeech.github.io/): A large scale multi-modal audio-visual corpus with a significant amount of real-time synchronized slides.
|
||||
@ -50,7 +55,7 @@ FunASR has open-sourced a large number of pre-trained models on industrial data.
|
||||
(Note: 🤗 represents the Huggingface model zoo link, ⭐ represents the ModelScope model zoo link)
|
||||
|
||||
|
||||
| Model Name | Task Details | Training Date | Parameters |
|
||||
| Model Name | Task Details | Training Data | Parameters |
|
||||
|:------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:---------------------------------------------------------------------------:|:--------------------------------:|:----------:|
|
||||
| paraformer-zh <br> ([⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) [🤗]() ) | speech recognition, with timestamps, non-streaming | 60000 hours, Mandarin | 220M |
|
||||
| paraformer-zh-spk <br> ( [⭐](https://modelscope.cn/models/damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn/summary) [🤗]() ) | speech recognition with speaker diarization, with timestamps, non-streaming | 60000 hours, Mandarin | 220M |
|
||||
|
||||
@ -31,6 +31,11 @@ FunASR希望在语音识别的学术研究和工业应用之间架起一座桥
|
||||
|
||||
<a name="最新动态"></a>
|
||||
## 最新动态
|
||||
- 2024/01/09: funasr社区软件包windows 2.0版本发布,支持软件包中文离线文件转写4.1、英文离线文件转写1.2、中文实时听写服务1.6的最新功能,详细信息参阅([FunASR社区软件包windows版本](https://www.modelscope.cn/models/damo/funasr-runtime-win-cpu-x64/summary))
|
||||
- 2024/01/03: 中文离线文件转写服务 4.0 发布,新增支持8k模型、优化时间戳不匹配问题及增加句子级别时间戳、优化英文单词fst热词效果、支持自动化配置线程参数,同时修复已知的crash问题及内存泄漏问题,详细信息参阅([一键部署文档](runtime/readme_cn.md#中文离线文件转写服务cpu版本))
|
||||
- 2024/01/03: 中文实时语音听写服务 1.6 发布,2pass-offline模式支持Ngram语言模型解码、wfst热词,同时修复已知的crash问题及内存泄漏问题,详细信息参阅([一键部署文档](runtime/readme_cn.md#中文实时语音听写服务cpu版本))
|
||||
- 2024/01/03: 英文离线文件转写服务 1.2 发布,修复已知的crash问题及内存泄漏问题,详细信息参阅([一键部署文档](runtime/readme_cn.md#英文离线文件转写服务cpu版本))
|
||||
- 2023/12/04: funasr社区软件包windows 1.0版本发布,支持中文离线文件转写、英文离线文件转写、中文实时听写服务,详细信息参阅([FunASR社区软件包windows版本](https://www.modelscope.cn/models/damo/funasr-runtime-win-cpu-x64/summary))
|
||||
- 2023/11/08:中文离线文件转写服务3.0 CPU版本发布,新增标点大模型、Ngram语言模型与wfst热词,详细信息参阅([一键部署文档](runtime/readme_cn.md#中文离线文件转写服务cpu版本))
|
||||
- 2023/10/17: 英文离线文件转写服务一键部署的CPU版本发布,详细信息参阅([一键部署文档](runtime/readme_cn.md#英文离线文件转写服务cpu版本))
|
||||
- 2023/10/13: [SlideSpeech](https://slidespeech.github.io/): 一个大规模的多模态音视频语料库,主要是在线会议或者在线课程场景,包含了大量与发言人讲话实时同步的幻灯片。
|
||||
|
||||
@ -37,7 +37,7 @@ sudo systemctl start docker
|
||||
### Image Hub
|
||||
|
||||
#### CPU
|
||||
`registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.3.0`
|
||||
`registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.1`
|
||||
|
||||
#### GPU
|
||||
|
||||
|
||||
@ -37,7 +37,7 @@ sudo systemctl start docker
|
||||
### 镜像仓库
|
||||
|
||||
#### CPU
|
||||
`registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.3.0`
|
||||
`registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.1`
|
||||
|
||||
#### GPU
|
||||
|
||||
|
||||
@ -4,15 +4,16 @@ import numpy as np
|
||||
|
||||
from funasr.register import tables
|
||||
|
||||
|
||||
@tables.register("batch_sampler_classes", "DynamicBatchLocalShuffleSampler")
|
||||
class BatchSampler(torch.utils.data.BatchSampler):
|
||||
|
||||
def __init__(self, dataset,
|
||||
batch_type: str="example",
|
||||
batch_size: int=100,
|
||||
buffer_size: int=30,
|
||||
drop_last: bool=False,
|
||||
shuffle: bool=True,
|
||||
batch_type: str = "example",
|
||||
batch_size: int = 100,
|
||||
buffer_size: int = 30,
|
||||
drop_last: bool = False,
|
||||
shuffle: bool = True,
|
||||
**kwargs):
|
||||
|
||||
self.drop_last = drop_last
|
||||
@ -25,24 +26,23 @@ class BatchSampler(torch.utils.data.BatchSampler):
|
||||
self.max_token_length = kwargs.get("max_token_length", 5000)
|
||||
self.shuffle_idx = np.arange(self.total_samples)
|
||||
self.shuffle = shuffle
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return self.total_samples
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
np.random.seed(epoch)
|
||||
|
||||
|
||||
def __iter__(self):
|
||||
|
||||
if self.shuffle:
|
||||
np.random.shuffle(self.shuffle_idx)
|
||||
|
||||
|
||||
batch = []
|
||||
max_token = 0
|
||||
num_sample = 0
|
||||
|
||||
iter_num = (self.total_samples-1) // self.buffer_size + 1
|
||||
|
||||
iter_num = (self.total_samples - 1) // self.buffer_size + 1
|
||||
# print("iter_num: ", iter_num)
|
||||
for iter in range(self.pre_idx + 1, iter_num):
|
||||
datalen_with_index = []
|
||||
@ -50,12 +50,12 @@ class BatchSampler(torch.utils.data.BatchSampler):
|
||||
idx = iter * self.buffer_size + i
|
||||
if idx >= self.total_samples:
|
||||
continue
|
||||
|
||||
|
||||
idx_map = self.shuffle_idx[idx]
|
||||
# prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
|
||||
sample_len_cur = self.dataset.get_source_len(idx_map) + \
|
||||
self.dataset.get_target_len(idx_map)
|
||||
|
||||
|
||||
datalen_with_index.append([idx, sample_len_cur])
|
||||
|
||||
datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
|
||||
@ -63,7 +63,7 @@ class BatchSampler(torch.utils.data.BatchSampler):
|
||||
idx, sample_len_cur_raw = item
|
||||
if sample_len_cur_raw > self.max_token_length:
|
||||
continue
|
||||
|
||||
|
||||
max_token_cur = max(max_token, sample_len_cur_raw)
|
||||
max_token_padding = 1 + num_sample
|
||||
if self.batch_type == 'length':
|
||||
@ -77,5 +77,4 @@ class BatchSampler(torch.utils.data.BatchSampler):
|
||||
batch = [idx]
|
||||
max_token = sample_len_cur_raw
|
||||
num_sample = 1
|
||||
|
||||
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ def th_accuracy(pad_outputs, pad_targets, ignore_label):
|
||||
|
||||
Args:
|
||||
pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
|
||||
pad_targets (LongTensor): Target label tensors (B, Lmax).
|
||||
pad_targets (LongTensor): Target label tensors (B, Lmax, D).
|
||||
ignore_label (int): Ignore label id.
|
||||
|
||||
Returns:
|
||||
|
||||
@ -113,12 +113,12 @@ class RWKVEncoder(AbsEncoder):
|
||||
x = self.embed_norm(x)
|
||||
olens = mask.eq(0).sum(1)
|
||||
|
||||
# for training
|
||||
# for block in self.rwkv_blocks:
|
||||
# x, _ = block(x)
|
||||
|
||||
# for streaming inference
|
||||
x = self.rwkv_infer(x)
|
||||
if self.training:
|
||||
for block in self.rwkv_blocks:
|
||||
x, _ = block(x)
|
||||
else:
|
||||
x = self.rwkv_infer(x)
|
||||
|
||||
x = self.final_norm(x)
|
||||
|
||||
if self.time_reduction_factor > 1:
|
||||
|
||||
@ -443,7 +443,10 @@ class UniASR(FunASRModel):
|
||||
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
||||
if self.length_normalized_loss:
|
||||
batch_size = int((text_lengths + 1).sum())
|
||||
<<<<<<< HEAD:funasr/models/uniasr/e2e_uni_asr.py
|
||||
|
||||
=======
|
||||
>>>>>>> main:funasr/models/e2e_uni_asr.py
|
||||
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
||||
return loss, stats, weight
|
||||
|
||||
|
||||
@ -47,11 +47,11 @@ Use the following command to pull and launch the FunASR software package Docker
|
||||
|
||||
```shell
|
||||
sudo docker pull \
|
||||
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.5
|
||||
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.6
|
||||
mkdir -p ./funasr-runtime-resources/models
|
||||
sudo docker run -p 10096:10095 -it --privileged=true \
|
||||
-v $PWD/funasr-runtime-resources/models:/workspace/models \
|
||||
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.5
|
||||
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.6
|
||||
```
|
||||
|
||||
###### Server Start
|
||||
@ -93,11 +93,11 @@ Use the following command to pull and launch the FunASR software package Docker
|
||||
|
||||
```shell
|
||||
sudo docker pull \
|
||||
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.3.0
|
||||
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.1
|
||||
mkdir -p ./funasr-runtime-resources/models
|
||||
sudo docker run -p 10095:10095 -it --privileged=true \
|
||||
-v $PWD/funasr-runtime-resources/models:/workspace/models \
|
||||
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.3.0
|
||||
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.1
|
||||
```
|
||||
|
||||
###### Server Start
|
||||
|
||||
@ -48,11 +48,11 @@ sudo bash install_docker.sh
|
||||
|
||||
```shell
|
||||
sudo docker pull \
|
||||
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.5
|
||||
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.6
|
||||
mkdir -p ./funasr-runtime-resources/models
|
||||
sudo docker run -p 10096:10095 -it --privileged=true \
|
||||
-v $PWD/funasr-runtime-resources/models:/workspace/models \
|
||||
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.5
|
||||
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.6
|
||||
```
|
||||
|
||||
###### 服务端启动
|
||||
@ -92,11 +92,11 @@ python3 funasr_wss_client.py --host "127.0.0.1" --port 10096 --mode 2pass
|
||||
|
||||
```shell
|
||||
sudo docker pull \
|
||||
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.3.0
|
||||
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.1
|
||||
mkdir -p ./funasr-runtime-resources/models
|
||||
sudo docker run -p 10095:10095 -it --privileged=true \
|
||||
-v $PWD/funasr-runtime-resources/models:/workspace/models \
|
||||
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.3.0
|
||||
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.1
|
||||
```
|
||||
|
||||
###### 服务端启动
|
||||
|
||||
@ -20,4 +20,4 @@ scheduler_classes = dict(
|
||||
cycliclr=torch.optim.lr_scheduler.CyclicLR,
|
||||
onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
|
||||
CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
|
||||
)
|
||||
)
|
||||
|
||||
@ -11,89 +11,90 @@ import json
|
||||
|
||||
import numpy as np
|
||||
|
||||
class AbsTokenizer(ABC):
|
||||
@abstractmethod
|
||||
def text2tokens(self, line: str) -> List[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def tokens2text(self, tokens: Iterable[str]) -> str:
|
||||
raise NotImplementedError
|
||||
class AbsTokenizer(ABC):
|
||||
@abstractmethod
|
||||
def text2tokens(self, line: str) -> List[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def tokens2text(self, tokens: Iterable[str]) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class BaseTokenizer(ABC):
|
||||
def __init__(self, token_list: Union[Path, str, Iterable[str]]=None,
|
||||
unk_symbol: str = "<unk>",
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
if token_list is not None:
|
||||
if isinstance(token_list, (Path, str)) and token_list.endswith(".txt"):
|
||||
token_list = Path(token_list)
|
||||
self.token_list_repr = str(token_list)
|
||||
self.token_list: List[str] = []
|
||||
|
||||
with token_list.open("r", encoding="utf-8") as f:
|
||||
for idx, line in enumerate(f):
|
||||
line = line.rstrip()
|
||||
self.token_list.append(line)
|
||||
elif isinstance(token_list, (Path, str)) and token_list.endswith(".json"):
|
||||
token_list = Path(token_list)
|
||||
self.token_list_repr = str(token_list)
|
||||
self.token_list: List[str] = []
|
||||
|
||||
with open(token_list, 'r', encoding='utf-8') as f:
|
||||
self.token_list = json.load(f)
|
||||
|
||||
|
||||
else:
|
||||
self.token_list: List[str] = list(token_list)
|
||||
self.token_list_repr = ""
|
||||
for i, t in enumerate(self.token_list):
|
||||
if i == 3:
|
||||
break
|
||||
self.token_list_repr += f"{t}, "
|
||||
self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
|
||||
|
||||
self.token2id: Dict[str, int] = {}
|
||||
for i, t in enumerate(self.token_list):
|
||||
if t in self.token2id:
|
||||
raise RuntimeError(f'Symbol "{t}" is duplicated')
|
||||
self.token2id[t] = i
|
||||
|
||||
self.unk_symbol = unk_symbol
|
||||
if self.unk_symbol not in self.token2id:
|
||||
raise RuntimeError(
|
||||
f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
|
||||
)
|
||||
self.unk_id = self.token2id[self.unk_symbol]
|
||||
|
||||
def encode(self, text):
|
||||
tokens = self.text2tokens(text)
|
||||
text_ints = self.tokens2ids(tokens)
|
||||
|
||||
return text_ints
|
||||
|
||||
def decode(self, text_ints):
|
||||
token = self.ids2tokens(text_ints)
|
||||
text = self.tokens2text(token)
|
||||
return text
|
||||
|
||||
def get_num_vocabulary_size(self) -> int:
|
||||
return len(self.token_list)
|
||||
|
||||
def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
|
||||
if isinstance(integers, np.ndarray) and integers.ndim != 1:
|
||||
raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
|
||||
return [self.token_list[i] for i in integers]
|
||||
|
||||
def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
|
||||
return [self.token2id.get(i, self.unk_id) for i in tokens]
|
||||
|
||||
@abstractmethod
|
||||
def text2tokens(self, line: str) -> List[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def tokens2text(self, tokens: Iterable[str]) -> str:
|
||||
raise NotImplementedError
|
||||
def __init__(self, token_list: Union[Path, str, Iterable[str]] = None,
|
||||
unk_symbol: str = "<unk>",
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
if token_list is not None:
|
||||
if isinstance(token_list, (Path, str)) and token_list.endswith(".txt"):
|
||||
token_list = Path(token_list)
|
||||
self.token_list_repr = str(token_list)
|
||||
self.token_list: List[str] = []
|
||||
|
||||
with token_list.open("r", encoding="utf-8") as f:
|
||||
for idx, line in enumerate(f):
|
||||
line = line.rstrip()
|
||||
self.token_list.append(line)
|
||||
elif isinstance(token_list, (Path, str)) and token_list.endswith(".json"):
|
||||
token_list = Path(token_list)
|
||||
self.token_list_repr = str(token_list)
|
||||
self.token_list: List[str] = []
|
||||
|
||||
with open(token_list, 'r', encoding='utf-8') as f:
|
||||
self.token_list = json.load(f)
|
||||
|
||||
|
||||
else:
|
||||
self.token_list: List[str] = list(token_list)
|
||||
self.token_list_repr = ""
|
||||
for i, t in enumerate(self.token_list):
|
||||
if i == 3:
|
||||
break
|
||||
self.token_list_repr += f"{t}, "
|
||||
self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
|
||||
|
||||
self.token2id: Dict[str, int] = {}
|
||||
for i, t in enumerate(self.token_list):
|
||||
if t in self.token2id:
|
||||
raise RuntimeError(f'Symbol "{t}" is duplicated')
|
||||
self.token2id[t] = i
|
||||
|
||||
self.unk_symbol = unk_symbol
|
||||
if self.unk_symbol not in self.token2id:
|
||||
raise RuntimeError(
|
||||
f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
|
||||
)
|
||||
self.unk_id = self.token2id[self.unk_symbol]
|
||||
|
||||
def encode(self, text):
|
||||
tokens = self.text2tokens(text)
|
||||
text_ints = self.tokens2ids(tokens)
|
||||
|
||||
return text_ints
|
||||
|
||||
def decode(self, text_ints):
|
||||
token = self.ids2tokens(text_ints)
|
||||
text = self.tokens2text(token)
|
||||
return text
|
||||
|
||||
def get_num_vocabulary_size(self) -> int:
|
||||
return len(self.token_list)
|
||||
|
||||
def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
|
||||
if isinstance(integers, np.ndarray) and integers.ndim != 1:
|
||||
raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
|
||||
return [self.token_list[i] for i in integers]
|
||||
|
||||
def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
|
||||
return [self.token2id.get(i, self.unk_id) for i in tokens]
|
||||
|
||||
@abstractmethod
|
||||
def text2tokens(self, line: str) -> List[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def tokens2text(self, tokens: Iterable[str]) -> str:
|
||||
raise NotImplementedError
|
||||
@ -1,17 +1,7 @@
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
from typing import Union
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import Iterable
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from typing import Iterable
|
||||
from typing import List
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from funasr.tokenizer.abs_tokenizer import AbsTokenizer
|
||||
from funasr.tokenizer.char_tokenizer import CharTokenizer
|
||||
@ -28,8 +18,7 @@ def build_tokenizer(
|
||||
space_symbol: str = "<space>",
|
||||
delimiter: str = None,
|
||||
g2p_type: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
) -> AbsTokenizer:
|
||||
"""A helper function to instantiate Tokenizer"""
|
||||
if token_type == "bpe":
|
||||
if bpemodel is None:
|
||||
@ -39,7 +28,7 @@ def build_tokenizer(
|
||||
raise RuntimeError(
|
||||
"remove_non_linguistic_symbols is not implemented for token_type=bpe"
|
||||
)
|
||||
return SentencepiecesTokenizer(bpemodel, **kwargs)
|
||||
return SentencepiecesTokenizer(bpemodel)
|
||||
|
||||
elif token_type == "word":
|
||||
if remove_non_linguistic_symbols and non_linguistic_symbols is not None:
|
||||
@ -49,14 +38,13 @@ def build_tokenizer(
|
||||
remove_non_linguistic_symbols=True,
|
||||
)
|
||||
else:
|
||||
return WordTokenizer(delimiter=delimiter, **kwargs)
|
||||
return WordTokenizer(delimiter=delimiter)
|
||||
|
||||
elif token_type == "char":
|
||||
return CharTokenizer(
|
||||
non_linguistic_symbols=non_linguistic_symbols,
|
||||
space_symbol=space_symbol,
|
||||
remove_non_linguistic_symbols=remove_non_linguistic_symbols,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
elif token_type == "phn":
|
||||
@ -65,7 +53,6 @@ def build_tokenizer(
|
||||
non_linguistic_symbols=non_linguistic_symbols,
|
||||
space_symbol=space_symbol,
|
||||
remove_non_linguistic_symbols=remove_non_linguistic_symbols,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
@ -59,4 +59,4 @@ class CharTokenizer(BaseTokenizer):
|
||||
|
||||
def tokens2text(self, tokens: Iterable[str]) -> str:
|
||||
tokens = [t if t != self.space_symbol else " " for t in tokens]
|
||||
return "".join(tokens)
|
||||
return "".join(tokens)
|
||||
@ -1,75 +0,0 @@
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
from typing import Union
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import Iterable
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from typing import Iterable
|
||||
from typing import List
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from funasr.tokenizer.abs_tokenizer import AbsTokenizer
|
||||
from funasr.tokenizer.char_tokenizer import CharTokenizer
|
||||
from funasr.tokenizer.phoneme_tokenizer import PhonemeTokenizer
|
||||
from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer
|
||||
from funasr.tokenizer.word_tokenizer import WordTokenizer
|
||||
|
||||
def build_tokenizer(
|
||||
token_type: str,
|
||||
bpemodel: Union[Path, str, Iterable[str]] = None,
|
||||
non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
|
||||
remove_non_linguistic_symbols: bool = False,
|
||||
space_symbol: str = "<space>",
|
||||
delimiter: str = None,
|
||||
g2p_type: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""A helper function to instantiate Tokenizer"""
|
||||
# import pdb;
|
||||
# pdb.set_trace()
|
||||
if token_type == "bpe":
|
||||
if bpemodel is None:
|
||||
raise ValueError('bpemodel is required if token_type = "bpe"')
|
||||
|
||||
if remove_non_linguistic_symbols:
|
||||
raise RuntimeError(
|
||||
"remove_non_linguistic_symbols is not implemented for token_type=bpe"
|
||||
)
|
||||
return SentencepiecesTokenizer(bpemodel, **kwargs)
|
||||
|
||||
elif token_type == "word":
|
||||
if remove_non_linguistic_symbols and non_linguistic_symbols is not None:
|
||||
return WordTokenizer(
|
||||
delimiter=delimiter,
|
||||
non_linguistic_symbols=non_linguistic_symbols,
|
||||
remove_non_linguistic_symbols=True,
|
||||
)
|
||||
else:
|
||||
return WordTokenizer(delimiter=delimiter, **kwargs)
|
||||
|
||||
elif token_type == "char":
|
||||
return CharTokenizer(
|
||||
non_linguistic_symbols=non_linguistic_symbols,
|
||||
space_symbol=space_symbol,
|
||||
remove_non_linguistic_symbols=remove_non_linguistic_symbols,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
elif token_type == "phn":
|
||||
return PhonemeTokenizer(
|
||||
g2p_type=g2p_type,
|
||||
non_linguistic_symbols=non_linguistic_symbols,
|
||||
space_symbol=space_symbol,
|
||||
remove_non_linguistic_symbols=remove_non_linguistic_symbols,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"token_mode must be one of bpe, word, char or phn: " f"{token_type}"
|
||||
)
|
||||
@ -363,7 +363,6 @@ class PhonemeTokenizer(AbsTokenizer):
|
||||
non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
|
||||
space_symbol: str = "<space>",
|
||||
remove_non_linguistic_symbols: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
if g2p_type is None:
|
||||
self.g2p = split_by_space
|
||||
|
||||
@ -9,7 +9,7 @@ from funasr.tokenizer.abs_tokenizer import AbsTokenizer
|
||||
|
||||
|
||||
class SentencepiecesTokenizer(AbsTokenizer):
|
||||
def __init__(self, model: Union[Path, str], **kwargs):
|
||||
def __init__(self, model: Union[Path, str]):
|
||||
self.model = str(model)
|
||||
# NOTE(kamo):
|
||||
# Don't build SentencePieceProcessor in __init__()
|
||||
|
||||
@ -14,7 +14,6 @@ class WordTokenizer(AbsTokenizer):
|
||||
delimiter: str = None,
|
||||
non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
|
||||
remove_non_linguistic_symbols: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
self.delimiter = delimiter
|
||||
|
||||
|
||||
@ -12,6 +12,8 @@ This document serves as a development guide for the FunASR offline file transcri
|
||||
|
||||
| TIME | INFO | IMAGE VERSION | IMAGE ID |
|
||||
|------------|----------------------------------------------------------------------------------------------------------------------------------|------------------------------|--------------|
|
||||
| 2024.01.08 | optimized format sentence-level timestamps | funasr-runtime-sdk-cpu-0.4.1 | 0250f8ef981b |
|
||||
| 2024.01.03 | Added support for 8k models, optimized timestamp mismatch issues and added sentence-level timestamps, improved the effectiveness of English word FST hotwords, supported automated configuration of thread parameters, and fixed known crash issues as well as memory leak problems. | funasr-runtime-sdk-cpu-0.4.0 | c4483ee08f04 |
|
||||
| 2023.11.08 | supporting punc-large model, Ngram model, fst hotwords, server-side loading of hotwords, adaptation to runtime structure changes | funasr-runtime-sdk-cpu-0.3.0 | caa64bddbb43 |
|
||||
| 2023.09.19 | supporting ITN model | funasr-runtime-sdk-cpu-0.2.2 | 2c5286be13e9 |
|
||||
| 2023.08.22 | integrated ffmpeg to support various audio and video inputs, supporting nn-hotword model and timestamp model | funasr-runtime-sdk-cpu-0.2.0 | 1ad3d19e0707 |
|
||||
@ -30,9 +32,9 @@ If you do not have Docker installed, please refer to [Docker Installation](https
|
||||
### Pulling and launching images
|
||||
Use the following command to pull and launch the Docker image for the FunASR runtime-SDK:
|
||||
```shell
|
||||
sudo docker pull registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.3.0
|
||||
sudo docker pull registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.1
|
||||
|
||||
sudo docker run -p 10095:10095 -it --privileged=true -v /root:/workspace/models registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.3.0
|
||||
sudo docker run -p 10095:10095 -it --privileged=true -v /root:/workspace/models registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.1
|
||||
```
|
||||
|
||||
Introduction to command parameters:
|
||||
@ -84,11 +86,8 @@ Introduction to run_server.sh parameters:
|
||||
```text
|
||||
--download-model-dir: Model download address, download models from Modelscope by setting the model ID.
|
||||
--model-dir: modelscope model ID or local model path.
|
||||
--quantize: True for quantized ASR model, False for non-quantized ASR model. Default is True.
|
||||
--vad-dir: modelscope model ID or local model path.
|
||||
--vad-quant: True for quantized VAD model, False for non-quantized VAD model. Default is True.
|
||||
--punc-dir: modelscope model ID or local model path.
|
||||
--punc-quant: True for quantized PUNC model, False for non-quantized PUNC model. Default is True.
|
||||
--itn-dir modelscope model ID or local model path.
|
||||
--port: Port number that the server listens on. Default is 10095.
|
||||
--decoder-thread-num: The number of thread pools on the server side that can handle concurrent requests.
|
||||
|
||||
@ -6,6 +6,7 @@ This document serves as a development guide for the FunASR offline file transcri
|
||||
|
||||
| TIME | INFO | IMAGE VERSION | IMAGE ID |
|
||||
|------------|-----------------------------------------|---------------------------------|--------------|
|
||||
| 2024.01.03 | fixed known crash issues as well as memory leak problems | funasr-runtime-sdk-en-cpu-0.1.2 | 0cdd9f4a4bb5 |
|
||||
| 2023.11.08 | Adaptation to runtime structure changes | funasr-runtime-sdk-en-cpu-0.1.1 | 27017f70f72a |
|
||||
| 2023.10.16 | 1.0 released | funasr-runtime-sdk-en-cpu-0.1.0 | e0de03eb0163 |
|
||||
|
||||
@ -21,9 +22,9 @@ If you do not have Docker installed, please refer to [Docker Installation](https
|
||||
### Pulling and launching images
|
||||
Use the following command to pull and launch the Docker image for the FunASR runtime-SDK:
|
||||
```shell
|
||||
sudo docker pull registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-en-cpu-0.1.1
|
||||
sudo docker pull registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-en-cpu-0.1.2
|
||||
|
||||
sudo docker run -p 10097:10095 -it --privileged=true -v /root:/workspace/models registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-en-cpu-0.1.1
|
||||
sudo docker run -p 10097:10095 -it --privileged=true -v /root:/workspace/models registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-en-cpu-0.1.2
|
||||
```
|
||||
Introduction to command parameters:
|
||||
```text
|
||||
@ -63,11 +64,8 @@ Introduction to run_server.sh parameters:
|
||||
```text
|
||||
--download-model-dir: Model download address, download models from Modelscope by setting the model ID.
|
||||
--model-dir: modelscope model ID or local model path.
|
||||
--quantize: True for quantized ASR model, False for non-quantized ASR model. Default is True.
|
||||
--vad-dir: modelscope model ID or local model path.
|
||||
--vad-quant: True for quantized VAD model, False for non-quantized VAD model. Default is True.
|
||||
--punc-dir: modelscope model ID or local model path.
|
||||
--punc-quant: True for quantized PUNC model, False for non-quantized PUNC model. Default is True.
|
||||
--itn-dir modelscope model ID or local model path.
|
||||
--port: Port number that the server listens on. Default is 10095.
|
||||
--decoder-thread-num: The number of thread pools on the server side that can handle concurrent requests.
|
||||
|
||||
@ -6,6 +6,7 @@ FunASR提供可一键本地或者云端服务器部署的英文离线文件转
|
||||
|
||||
| 时间 | 详情 | 镜像版本 | 镜像ID |
|
||||
|------------|---------------|---------------------------------|--------------|
|
||||
| 2024.01.03 | 修复已知的crash问题及内存泄漏问题 | funasr-runtime-sdk-en-cpu-0.1.2 | 0cdd9f4a4bb5 |
|
||||
| 2023.11.08 | runtime结构变化适配 | funasr-runtime-sdk-en-cpu-0.1.1 | 27017f70f72a |
|
||||
| 2023.10.16 | 1.0 发布 | funasr-runtime-sdk-en-cpu-0.1.0 | e0de03eb0163 |
|
||||
|
||||
@ -36,11 +37,11 @@ docker安装失败请参考 [Docker Installation](https://alibaba-damo-academy.g
|
||||
通过下述命令拉取并启动FunASR runtime-SDK的docker镜像:
|
||||
```shell
|
||||
sudo docker pull \
|
||||
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-en-cpu-0.1.1
|
||||
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-en-cpu-0.1.2
|
||||
mkdir -p ./funasr-runtime-resources/models
|
||||
sudo docker run -p 10097:10095 -it --privileged=true \
|
||||
-v $PWD/funasr-runtime-resources/models:/workspace/models \
|
||||
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-en-cpu-0.1.1
|
||||
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-en-cpu-0.1.2
|
||||
```
|
||||
|
||||
### 服务端启动
|
||||
@ -148,11 +149,8 @@ nohup bash run_server.sh \
|
||||
```text
|
||||
--download-model-dir 模型下载地址,通过设置model ID从Modelscope下载模型
|
||||
--model-dir modelscope model ID 或者 本地模型路径
|
||||
--quantize True为量化ASR模型,False为非量化ASR模型,默认是True
|
||||
--vad-dir modelscope model ID 或者 本地模型路径
|
||||
--vad-quant True为量化VAD模型,False为非量化VAD模型,默认是True
|
||||
--punc-dir modelscope model ID 或者 本地模型路径
|
||||
--punc-quant True为量化PUNC模型,False为非量化PUNC模型,默认是True
|
||||
--itn-dir modelscope model ID 或者 本地模型路径
|
||||
--port 服务端监听的端口号,默认为 10095
|
||||
--decoder-thread-num 服务端线程池个数(支持的最大并发路数),
|
||||
|
||||
@ -10,6 +10,8 @@ FunASR离线文件转写软件包,提供了一款功能强大的语音离线
|
||||
|
||||
| 时间 | 详情 | 镜像版本 | 镜像ID |
|
||||
|------------|---------------------------------------------------|------------------------------|--------------|
|
||||
| 2024.01.08 | 优化句子级时间戳json格式 | funasr-runtime-sdk-cpu-0.4.1 | 0250f8ef981b |
|
||||
| 2024.01.03 | 新增支持8k模型、优化时间戳不匹配问题及增加句子级别时间戳、优化英文单词fst热词效果、支持自动化配置线程参数,同时修复已知的crash问题及内存泄漏问题 | funasr-runtime-sdk-cpu-0.4.0 | c4483ee08f04 |
|
||||
| 2023.11.08 | 支持标点大模型、支持Ngram模型、支持fst热词、支持服务端加载热词、runtime结构变化适配 | funasr-runtime-sdk-cpu-0.3.0 | caa64bddbb43 |
|
||||
| 2023.09.19 | 支持ITN模型 | funasr-runtime-sdk-cpu-0.2.2 | 2c5286be13e9 |
|
||||
| 2023.08.22 | 集成ffmpeg支持多种音视频输入、支持热词模型、支持时间戳模型 | funasr-runtime-sdk-cpu-0.2.0 | 1ad3d19e0707 |
|
||||
@ -44,11 +46,11 @@ docker安装失败请参考 [Docker Installation](https://alibaba-damo-academy.g
|
||||
|
||||
```shell
|
||||
sudo docker pull \
|
||||
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.3.0
|
||||
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.1
|
||||
mkdir -p ./funasr-runtime-resources/models
|
||||
sudo docker run -p 10095:10095 -it --privileged=true \
|
||||
-v $PWD/funasr-runtime-resources/models:/workspace/models \
|
||||
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.3.0
|
||||
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.4.1
|
||||
```
|
||||
|
||||
### 服务端启动
|
||||
@ -70,10 +72,23 @@ nohup bash run_server.sh \
|
||||
# damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-onnx(时间戳)
|
||||
# damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404-onnx(nn热词)
|
||||
# 如果您想在服务端加载热词,请在宿主机文件./funasr-runtime-resources/models/hotwords.txt配置热词(docker映射地址为/workspace/models/hotwords.txt):
|
||||
# 每行一个热词,格式(热词 权重):阿里巴巴 20
|
||||
# 每行一个热词,格式(热词 权重):阿里巴巴 20(注:热词理论上无限制,但为了兼顾性能和效果,建议热词长度不超过10,个数不超过1k,权重1~100)
|
||||
```
|
||||
如果您想定制ngram,参考文档([如何训练LM](./lm_train_tutorial.md))
|
||||
|
||||
如果您想部署8k的模型,请使用如下命令启动服务:
|
||||
```shell
|
||||
cd FunASR/runtime
|
||||
nohup bash run_server.sh \
|
||||
--download-model-dir /workspace/models \
|
||||
--vad-dir damo/speech_fsmn_vad_zh-cn-8k-common \
|
||||
--model-dir damo/speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1 \
|
||||
--punc-dir damo/punc_ct-transformer_cn-en-common-vocab471067-large-onnx \
|
||||
--lm-dir damo/speech_ngram_lm_zh-cn-ai-wesp-fst-token8358 \
|
||||
--itn-dir thuduj12/fst_itn_zh \
|
||||
--hotword /workspace/models/hotwords.txt > log.out 2>&1 &
|
||||
```
|
||||
|
||||
服务端详细参数介绍可参考[服务端用法详解](#服务端用法详解)
|
||||
|
||||
### 客户端测试与使用
|
||||
@ -165,11 +180,8 @@ nohup bash run_server.sh \
|
||||
```text
|
||||
--download-model-dir 模型下载地址,通过设置model ID从Modelscope下载模型
|
||||
--model-dir modelscope model ID 或者 本地模型路径
|
||||
--quantize True为量化ASR模型,False为非量化ASR模型,默认是True
|
||||
--vad-dir modelscope model ID 或者 本地模型路径
|
||||
--vad-quant True为量化VAD模型,False为非量化VAD模型,默认是True
|
||||
--punc-dir modelscope model ID 或者 本地模型路径
|
||||
--punc-quant True为量化PUNC模型,False为非量化PUNC模型,默认是True
|
||||
--lm-dir modelscope model ID 或者 本地模型路径
|
||||
--itn-dir modelscope model ID 或者 本地模型路径
|
||||
--port 服务端监听的端口号,默认为 10095
|
||||
|
||||
@ -8,6 +8,7 @@ FunASR Real-time Speech Recognition Software Package integrates real-time versio
|
||||
|
||||
| TIME | INFO | IMAGE VERSION | IMAGE ID |
|
||||
|------------|-------------------------------------------------------------------------------------|-------------------------------------|--------------|
|
||||
| 2024.01.03 | The 2pass-offline mode supports Ngram language model decoding and WFST hotwords, while also addressing known crash issues and memory leak problems | funasr-runtime-sdk-online-cpu-0.1.6 | f99925110d27 |
|
||||
| 2023.11.09 | fix bug: without online results | funasr-runtime-sdk-online-cpu-0.1.5 | b16584b6d38b |
|
||||
| 2023.11.08 | supporting server-side loading of hotwords, adaptation to runtime structure changes | funasr-runtime-sdk-online-cpu-0.1.4 | 691974017c38 |
|
||||
| 2023.09.19 | supporting hotwords, timestamps, and ITN model in 2pass mode | funasr-runtime-sdk-online-cpu-0.1.2 | 7222c5319bcf |
|
||||
@ -26,9 +27,9 @@ If you do not have Docker installed, please refer to [Docker Installation](https
|
||||
### Pull Docker Image
|
||||
Use the following command to pull and start the FunASR software package docker image:
|
||||
```shell
|
||||
sudo docker pull registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.5
|
||||
sudo docker pull registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.6
|
||||
mkdir -p ./funasr-runtime-resources/models
|
||||
sudo docker run -p 10096:10095 -it --privileged=true -v $PWD/funasr-runtime-resources/models:/workspace/models registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.5
|
||||
sudo docker run -p 10096:10095 -it --privileged=true -v $PWD/funasr-runtime-resources/models:/workspace/models registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.6
|
||||
```
|
||||
|
||||
### Launching the Server
|
||||
@ -42,6 +43,7 @@ nohup bash run_server_2pass.sh \
|
||||
--model-dir damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx \
|
||||
--online-model-dir damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online-onnx \
|
||||
--punc-dir damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727-onnx \
|
||||
--lm-dir damo/speech_ngram_lm_zh-cn-ai-wesp-fst \
|
||||
--itn-dir thuduj12/fst_itn_zh > log.out 2>&1 &
|
||||
|
||||
# If you want to close ssl,please add:--certfile 0
|
||||
@ -84,6 +86,7 @@ nohup bash run_server_2pass.sh \
|
||||
--online-model-dir damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online-onnx \
|
||||
--vad-dir damo/speech_fsmn_vad_zh-cn-16k-common-onnx \
|
||||
--punc-dir damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727-onnx \
|
||||
--lm-dir damo/speech_ngram_lm_zh-cn-ai-wesp-fst \
|
||||
--itn-dir thuduj12/fst_itn_zh \
|
||||
--certfile ../../../ssl_key/server.crt \
|
||||
--keyfile ../../../ssl_key/server.key \
|
||||
@ -101,11 +104,9 @@ nohup bash run_server_2pass.sh \
|
||||
--download-model-dir: Model download address, download models from Modelscope by setting the model ID.
|
||||
--model-dir: modelscope model ID or local model path.
|
||||
--online-model-dir modelscope model ID
|
||||
--quantize: True for quantized ASR model, False for non-quantized ASR model. Default is True.
|
||||
--vad-dir: modelscope model ID or local model path.
|
||||
--vad-quant: True for quantized VAD model, False for non-quantized VAD model. Default is True.
|
||||
--punc-dir: modelscope model ID or local model path.
|
||||
--punc-quant: True for quantized PUNC model, False for non-quantized PUNC model. Default is True.
|
||||
--lm-dir modelscope model ID or local model path.
|
||||
--itn-dir modelscope model ID or local model path.
|
||||
--port: Port number that the server listens on. Default is 10095.
|
||||
--decoder-thread-num: The number of thread pools on the server side that can handle concurrent requests.
|
||||
|
||||
@ -12,6 +12,7 @@ FunASR实时语音听写软件包,集成了实时版本的语音端点检测
|
||||
|
||||
| 时间 | 详情 | 镜像版本 | 镜像ID |
|
||||
|:-----------|:----------------------------------|--------------------------------------|--------------|
|
||||
| 2024.01.03 | 2pass-offline模式支持Ngram语言模型解码、wfst热词,同时修复已知的crash问题及内存泄漏问题 | funasr-runtime-sdk-online-cpu-0.1.6 | f99925110d27 |
|
||||
| 2023.11.09 | 修复无实时结果问题 | funasr-runtime-sdk-online-cpu-0.1.5 | b16584b6d38b |
|
||||
| 2023.11.08 | 支持服务端加载热词(更新热词通信协议)、runtime结构变化适配 | funasr-runtime-sdk-online-cpu-0.1.4 | 691974017c38 |
|
||||
| 2023.09.19 | 2pass模式支持热词、时间戳、ITN模型 | funasr-runtime-sdk-online-cpu-0.1.2 | 7222c5319bcf |
|
||||
@ -35,11 +36,11 @@ docker安装失败请参考 [Docker Installation](https://alibaba-damo-academy.g
|
||||
|
||||
```shell
|
||||
sudo docker pull \
|
||||
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.5
|
||||
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.6
|
||||
mkdir -p ./funasr-runtime-resources/models
|
||||
sudo docker run -p 10096:10095 -it --privileged=true \
|
||||
-v $PWD/funasr-runtime-resources/models:/workspace/models \
|
||||
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.5
|
||||
registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-online-cpu-0.1.6
|
||||
```
|
||||
|
||||
### 服务端启动
|
||||
@ -53,6 +54,7 @@ nohup bash run_server_2pass.sh \
|
||||
--model-dir damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx \
|
||||
--online-model-dir damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online-onnx \
|
||||
--punc-dir damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727-onnx \
|
||||
--lm-dir damo/speech_ngram_lm_zh-cn-ai-wesp-fst \
|
||||
--itn-dir thuduj12/fst_itn_zh \
|
||||
--hotword /workspace/models/hotwords.txt > log.out 2>&1 &
|
||||
|
||||
@ -61,7 +63,7 @@ nohup bash run_server_2pass.sh \
|
||||
# damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-onnx(时间戳)
|
||||
# damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404-onnx(nn热词)
|
||||
# 如果您想在服务端加载热词,请在宿主机文件./funasr-runtime-resources/models/hotwords.txt配置热词(docker映射地址为/workspace/models/hotwords.txt):
|
||||
# 每行一个热词,格式(热词 权重):阿里巴巴 20
|
||||
# 每行一个热词,格式(热词 权重):阿里巴巴 20(注:热词理论上无限制,但为了兼顾性能和效果,建议热词长度不超过10,个数不超过1k,权重1~100)
|
||||
```
|
||||
服务端详细参数介绍可参考[服务端用法详解](#服务端用法详解)
|
||||
### 客户端测试与使用
|
||||
@ -100,6 +102,7 @@ nohup bash run_server_2pass.sh \
|
||||
--online-model-dir damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online-onnx \
|
||||
--vad-dir damo/speech_fsmn_vad_zh-cn-16k-common-onnx \
|
||||
--punc-dir damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727-onnx \
|
||||
--lm-dir damo/speech_ngram_lm_zh-cn-ai-wesp-fst \
|
||||
--itn-dir thuduj12/fst_itn_zh \
|
||||
--certfile ../../../ssl_key/server.crt \
|
||||
--keyfile ../../../ssl_key/server.key \
|
||||
@ -110,11 +113,9 @@ nohup bash run_server_2pass.sh \
|
||||
--download-model-dir 模型下载地址,通过设置model ID从Modelscope下载模型
|
||||
--model-dir modelscope model ID 或者 本地模型路径
|
||||
--online-model-dir modelscope model ID 或者 本地模型路径
|
||||
--quantize True为量化ASR模型,False为非量化ASR模型,默认是True
|
||||
--vad-dir modelscope model ID 或者 本地模型路径
|
||||
--vad-quant True为量化VAD模型,False为非量化VAD模型,默认是True
|
||||
--punc-dir modelscope model ID 或者 本地模型路径
|
||||
--punc-quant True为量化PUNC模型,False为非量化PUNC模型,默认是True
|
||||
--lm-dir modelscope model ID 或者 本地模型路径
|
||||
--itn-dir modelscope model ID 或者 本地模型路径
|
||||
--port 服务端监听的端口号,默认为 10095
|
||||
--decoder-thread-num 服务端线程池个数(支持的最大并发路数),
|
||||
|
||||
@ -88,7 +88,9 @@ Command parameter description:
|
||||
--port specifies the deployment port number as 10095.
|
||||
--mode: `offline` indicates that the inference mode is one-sentence recognition; `online` indicates that the inference mode is real-time speech recognition; `2pass` indicates real-time speech recognition, and offline models are used for error correction at the end of each sentence.
|
||||
--chunk-size: indicates the latency configuration of the streaming model. [5,10,5] indicates that the current audio is 600ms, with a lookback of 300ms and a lookahead of 300ms.
|
||||
--record record is 1 means using record, fefault is 0
|
||||
--wav-path specifies the audio file to be transcribed, and supports file paths.
|
||||
--audio-fs the sample rate of the audio
|
||||
--threa-num sets the number of concurrent send threads, with a default value of 1.
|
||||
--is-ssl sets whether to enable SSL certificate verification, with a default value of 1 for enabling and 0 for disabling.
|
||||
--hotword: Hotword file path, one line for each hotword(e.g.:阿里巴巴 20)
|
||||
|
||||
@ -96,7 +96,9 @@ python3 funasr_wss_client.py --host "127.0.0.1" --port 10095 --mode 2pass
|
||||
--mode:`offline`表示推理模式为一句话识别;`online`表示推理模式为实时语音识别;`2pass`表示为实时语音识别,
|
||||
并且说话句尾采用离线模型进行纠错。
|
||||
--chunk-size:表示流式模型latency配置`[5,10,5]`,表示当前音频解码片段为600ms,并且回看300ms,右看300ms。
|
||||
--record 1表示使用麦克风作为输入,默认为0
|
||||
--wav-path 需要进行转写的音频文件,支持文件路径
|
||||
--audio-fs pcm音频采样率
|
||||
--thread-num 设置并发发送线程数,默认为1
|
||||
--is-ssl 设置是否开启ssl证书校验,默认1开启,设置为0关闭
|
||||
--hotword 热词文件,每行一个热词,格式(热词 权重):阿里巴巴 20
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
DOCKER:
|
||||
funasr-runtime-sdk-en-cpu-0.1.1
|
||||
funasr-runtime-sdk-en-cpu-0.1.2
|
||||
DEFAULT_ASR_MODEL:
|
||||
damo/speech_paraformer-large_asr_nat-en-16k-common-vocab10020-onnx
|
||||
DEFAULT_VAD_MODEL:
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
DOCKER:
|
||||
funasr-runtime-sdk-cpu-0.4.0
|
||||
funasr-runtime-sdk-cpu-0.3.0
|
||||
funasr-runtime-sdk-cpu-0.2.2
|
||||
funasr-runtime-sdk-cpu-0.2.1
|
||||
DEFAULT_ASR_MODEL:
|
||||
damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx
|
||||
damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-onnx
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
DOCKER:
|
||||
funasr-runtime-sdk-online-cpu-0.1.6
|
||||
funasr-runtime-sdk-online-cpu-0.1.5
|
||||
funasr-runtime-sdk-online-cpu-0.1.3
|
||||
funasr-runtime-sdk-online-cpu-0.1.2
|
||||
DEFAULT_ASR_MODEL:
|
||||
damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx
|
||||
damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-onnx
|
||||
|
||||
@ -6,6 +6,7 @@
|
||||
```shell
|
||||
# 下载: 示例训练语料text、lexicon 和 am建模单元units.txt
|
||||
wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/requirements/lm.tar.gz
|
||||
# 如果是匹配8k的am模型,使用 https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/requirements/lm_8358.tar.gz
|
||||
tar -zxvf lm.tar.gz
|
||||
```
|
||||
|
||||
|
||||
@ -36,7 +36,7 @@ After sending the audio data, an end-of-audio flag needs to be sent (which needs
|
||||
#### Sending Recognition Results
|
||||
The message (serialized in JSON) is:
|
||||
```text
|
||||
{"mode": "offline", "wav_name": "wav_name", "text": "asr ouputs", "is_final": True, "timestamp":"[[100,200], [200,500]]"}
|
||||
{"mode": "offline", "wav_name": "wav_name", "text": "asr ouputs", "is_final": True, "timestamp":"[[100,200], [200,500]]", "stamp_sents":[]}
|
||||
```
|
||||
Parameter explanation:
|
||||
```text
|
||||
@ -45,6 +45,7 @@ Parameter explanation:
|
||||
`text`: the text output of speech recognition
|
||||
`is_final`: indicating the end of recognition
|
||||
`timestamp`:If AM is a timestamp model, it will return this field, indicating the timestamp, in the format of "[[100,200], [200,500]]"
|
||||
`stamp_sents`:If AM is a timestamp model, it will return this field, indicating the stamp_sents, in the format of [{"text_seg":"正 是 因 为","punc":",","start":"430","end":"1130","ts_list":[[430,670],[670,810],[810,1030],[1030,1130]]}]
|
||||
```
|
||||
|
||||
## Real-time Speech Recognition
|
||||
@ -84,7 +85,7 @@ After sending the audio data, an end-of-audio flag needs to be sent (which needs
|
||||
The message (serialized in JSON) is:
|
||||
|
||||
```text
|
||||
{"mode": "2pass-online", "wav_name": "wav_name", "text": "asr ouputs", "is_final": True, "timestamp":"[[100,200], [200,500]]"}
|
||||
{"mode": "2pass-online", "wav_name": "wav_name", "text": "asr ouputs", "is_final": True, "timestamp":"[[100,200], [200,500]]", "stamp_sents":[]}
|
||||
```
|
||||
Parameter explanation:
|
||||
```text
|
||||
@ -93,4 +94,5 @@ Parameter explanation:
|
||||
`text`: the text output of speech recognition
|
||||
`is_final`: indicating the end of recognition
|
||||
`timestamp`:If AM is a timestamp model, it will return this field, indicating the timestamp, in the format of "[[100,200], [200,500]]"
|
||||
`stamp_sents`:If AM is a timestamp model, it will return this field, indicating the stamp_sents, in the format of [{"text_seg":"正 是 因 为","punc":",","start":"430","end":"1130","ts_list":[[430,670],[670,810],[810,1030],[1030,1130]]}]
|
||||
```
|
||||
|
||||
@ -37,7 +37,7 @@ pcm直接将音频数据,其他格式音频数据,连同头部信息与音
|
||||
#### 发送识别结果
|
||||
message为(采用json序列化)
|
||||
```text
|
||||
{"mode": "offline", "wav_name": "wav_name", "text": "asr ouputs", "is_final": True,"timestamp":"[[100,200], [200,500]]"}
|
||||
{"mode": "offline", "wav_name": "wav_name", "text": "asr ouputs", "is_final": True,"timestamp":"[[100,200], [200,500]]","stamp_sents":[]}
|
||||
```
|
||||
参数介绍:
|
||||
```text
|
||||
@ -46,6 +46,7 @@ message为(采用json序列化)
|
||||
`text`:表示语音识别输出文本
|
||||
`is_final`:表示识别结束
|
||||
`timestamp`:如果AM为时间戳模型,会返回此字段,表示时间戳,格式为 "[[100,200], [200,500]]"(ms)
|
||||
`stamp_sents`:如果AM为时间戳模型,会返回此字段,表示句子级别时间戳,格式为 [{"text_seg":"正 是 因 为","punc":",","start":"430","end":"1130","ts_list":[[430,670],[670,810],[810,1030],[1030,1130]]}]
|
||||
```
|
||||
|
||||
## 实时语音识别
|
||||
@ -86,7 +87,7 @@ message为(需要用json序列化):
|
||||
#### 发送识别结果
|
||||
message为(采用json序列化)
|
||||
```text
|
||||
{"mode": "2pass-online", "wav_name": "wav_name", "text": "asr ouputs", "is_final": True, "timestamp":"[[100,200], [200,500]]"}
|
||||
{"mode": "2pass-online", "wav_name": "wav_name", "text": "asr ouputs", "is_final": True, "timestamp":"[[100,200], [200,500]]","stamp_sents":[]}
|
||||
```
|
||||
参数介绍:
|
||||
```text
|
||||
@ -95,4 +96,5 @@ message为(采用json序列化)
|
||||
`text`:表示语音识别输出文本
|
||||
`is_final`:表示识别结束
|
||||
`timestamp`:如果AM为时间戳模型,会返回此字段,表示时间戳,格式为 "[[100,200], [200,500]]"(ms)
|
||||
`stamp_sents`:如果AM为时间戳模型,会返回此字段,表示句子级别时间戳,格式为 [{"text_seg":"正 是 因 为","punc":",","start":"430","end":"1130","ts_list":[[430,670],[670,810],[810,1030],[1030,1130]]}]
|
||||
```
|
||||
|
||||
@ -43,13 +43,18 @@ void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std:
|
||||
LOG(INFO)<< key << " : " << value_arg.getValue();
|
||||
}
|
||||
|
||||
void runReg(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size, vector<string> wav_list, vector<string> wav_ids,
|
||||
float* total_length, long* total_time, int core_id, ASR_TYPE asr_mode_, string nn_hotwords_) {
|
||||
void runReg(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size, vector<string> wav_list, vector<string> wav_ids, int audio_fs,
|
||||
float* total_length, long* total_time, int core_id, ASR_TYPE asr_mode_, string nn_hotwords_,
|
||||
float glob_beam, float lat_beam, float am_scale, int inc_bias, unordered_map<string, int> hws_map) {
|
||||
|
||||
struct timeval start, end;
|
||||
long seconds = 0;
|
||||
float n_total_length = 0.0f;
|
||||
long n_total_time = 0;
|
||||
|
||||
FUNASR_DEC_HANDLE decoder_handle = FunASRWfstDecoderInit(tpass_handle, ASR_TWO_PASS, glob_beam, lat_beam, am_scale);
|
||||
// load hotwords list and build graph
|
||||
FunWfstDecoderLoadHwsRes(decoder_handle, inc_bias, hws_map);
|
||||
|
||||
std::vector<std::vector<float>> hotwords_embedding = CompileHotwordEmbedding(tpass_handle, nn_hotwords_, ASR_TWO_PASS);
|
||||
|
||||
@ -59,7 +64,7 @@ void runReg(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size, vector<stri
|
||||
// warm up
|
||||
for (size_t i = 0; i < 2; i++)
|
||||
{
|
||||
int32_t sampling_rate_ = 16000;
|
||||
int32_t sampling_rate_ = audio_fs;
|
||||
funasr::Audio audio(1);
|
||||
if(is_target_file(wav_list[0].c_str(), "wav")){
|
||||
if(!audio.LoadWav2Char(wav_list[0].c_str(), &sampling_rate_)){
|
||||
@ -90,7 +95,8 @@ void runReg(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size, vector<stri
|
||||
} else {
|
||||
is_final = false;
|
||||
}
|
||||
FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm", (ASR_TYPE)asr_mode_, hotwords_embedding);
|
||||
FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final,
|
||||
sampling_rate_, "pcm", (ASR_TYPE)asr_mode_, hotwords_embedding, true, decoder_handle);
|
||||
if (result)
|
||||
{
|
||||
FunASRFreeResult(result);
|
||||
@ -104,7 +110,7 @@ void runReg(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size, vector<stri
|
||||
if (i >= wav_list.size()) {
|
||||
break;
|
||||
}
|
||||
int32_t sampling_rate_ = 16000;
|
||||
int32_t sampling_rate_ = audio_fs;
|
||||
funasr::Audio audio(1);
|
||||
if(is_target_file(wav_list[i].c_str(), "wav")){
|
||||
if(!audio.LoadWav2Char(wav_list[i].c_str(), &sampling_rate_)){
|
||||
@ -139,7 +145,8 @@ void runReg(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size, vector<stri
|
||||
is_final = false;
|
||||
}
|
||||
gettimeofday(&start, NULL);
|
||||
FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm", (ASR_TYPE)asr_mode_, hotwords_embedding);
|
||||
FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final,
|
||||
sampling_rate_, "pcm", (ASR_TYPE)asr_mode_, hotwords_embedding, true, decoder_handle);
|
||||
gettimeofday(&end, NULL);
|
||||
seconds = (end.tv_sec - start.tv_sec);
|
||||
long taking_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
|
||||
@ -197,6 +204,8 @@ void runReg(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size, vector<stri
|
||||
*total_time = n_total_time;
|
||||
}
|
||||
}
|
||||
FunWfstDecoderUnloadHwsRes(decoder_handle);
|
||||
FunASRWfstDecoderUninit(decoder_handle);
|
||||
FunTpassOnlineUninit(tpass_online_handle);
|
||||
}
|
||||
|
||||
@ -215,11 +224,17 @@ int main(int argc, char** argv)
|
||||
TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
|
||||
TCLAP::ValueArg<std::string> itn_dir("", ITN_DIR, "the itn model(fst) path, which contains zh_itn_tagger.fst and zh_itn_verbalizer.fst", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> lm_dir("", LM_DIR, "the lm model path, which contains compiled models: TLG.fst, config.yaml, lexicon.txt ", false, "", "string");
|
||||
TCLAP::ValueArg<float> global_beam("", GLOB_BEAM, "the decoding beam for beam searching ", false, 3.0, "float");
|
||||
TCLAP::ValueArg<float> lattice_beam("", LAT_BEAM, "the lattice generation beam for beam searching ", false, 3.0, "float");
|
||||
TCLAP::ValueArg<float> am_scale("", AM_SCALE, "the acoustic scale for beam searching ", false, 10.0, "float");
|
||||
TCLAP::ValueArg<std::int32_t> fst_inc_wts("", FST_INC_WTS, "the fst hotwords incremental bias", false, 20, "int32_t");
|
||||
|
||||
TCLAP::ValueArg<std::string> asr_mode("", ASR_MODE, "offline, online, 2pass", false, "2pass", "string");
|
||||
TCLAP::ValueArg<std::int32_t> onnx_thread("", "model-thread-num", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t");
|
||||
TCLAP::ValueArg<std::int32_t> thread_num_("", THREAD_NUM, "multi-thread num for rtf", false, 1, "int32_t");
|
||||
TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
|
||||
TCLAP::ValueArg<std::int32_t> audio_fs("", AUDIO_FS, "the sample rate of audio", false, 16000, "int32_t");
|
||||
TCLAP::ValueArg<std::string> hotword("", HOTWORD, "the hotword file, one hotword perline, Format: Hotword Weight (could be: 阿里巴巴 20)", false, "", "string");
|
||||
|
||||
cmd.add(offline_model_dir);
|
||||
@ -230,7 +245,13 @@ int main(int argc, char** argv)
|
||||
cmd.add(punc_dir);
|
||||
cmd.add(punc_quant);
|
||||
cmd.add(itn_dir);
|
||||
cmd.add(lm_dir);
|
||||
cmd.add(global_beam);
|
||||
cmd.add(lattice_beam);
|
||||
cmd.add(am_scale);
|
||||
cmd.add(fst_inc_wts);
|
||||
cmd.add(wav_path);
|
||||
cmd.add(audio_fs);
|
||||
cmd.add(asr_mode);
|
||||
cmd.add(onnx_thread);
|
||||
cmd.add(thread_num_);
|
||||
@ -246,6 +267,7 @@ int main(int argc, char** argv)
|
||||
GetValue(punc_dir, PUNC_DIR, model_path);
|
||||
GetValue(punc_quant, PUNC_QUANT, model_path);
|
||||
GetValue(itn_dir, ITN_DIR, model_path);
|
||||
GetValue(lm_dir, LM_DIR, model_path);
|
||||
GetValue(wav_path, WAV_PATH, model_path);
|
||||
GetValue(asr_mode, ASR_MODE, model_path);
|
||||
|
||||
@ -270,6 +292,14 @@ int main(int argc, char** argv)
|
||||
LOG(ERROR) << "FunTpassInit init failed";
|
||||
exit(-1);
|
||||
}
|
||||
float glob_beam = 3.0f;
|
||||
float lat_beam = 3.0f;
|
||||
float am_sc = 10.0f;
|
||||
if (lm_dir.isSet()) {
|
||||
glob_beam = global_beam.getValue();
|
||||
lat_beam = lattice_beam.getValue();
|
||||
am_sc = am_scale.getValue();
|
||||
}
|
||||
|
||||
gettimeofday(&end, NULL);
|
||||
long seconds = (end.tv_sec - start.tv_sec);
|
||||
@ -319,7 +349,8 @@ int main(int argc, char** argv)
|
||||
int rtf_threds = thread_num_.getValue();
|
||||
for (int i = 0; i < rtf_threds; i++)
|
||||
{
|
||||
threads.emplace_back(thread(runReg, tpass_hanlde, chunk_size, wav_list, wav_ids, &total_length, &total_time, i, (ASR_TYPE)asr_mode_, nn_hotwords_));
|
||||
threads.emplace_back(thread(runReg, tpass_hanlde, chunk_size, wav_list, wav_ids, audio_fs.getValue(), &total_length, &total_time, i, (ASR_TYPE)asr_mode_, nn_hotwords_,
|
||||
glob_beam, lat_beam, am_sc, fst_inc_wts.getValue(), hws_map));
|
||||
}
|
||||
|
||||
for (auto& thread : threads)
|
||||
|
||||
@ -51,10 +51,16 @@ int main(int argc, char** argv)
|
||||
TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
|
||||
TCLAP::ValueArg<std::string> itn_dir("", ITN_DIR, "the itn model(fst) path, which contains zh_itn_tagger.fst and zh_itn_verbalizer.fst", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> lm_dir("", LM_DIR, "the lm model path, which contains compiled models: TLG.fst, config.yaml, lexicon.txt ", false, "", "string");
|
||||
TCLAP::ValueArg<float> global_beam("", GLOB_BEAM, "the decoding beam for beam searching ", false, 3.0, "float");
|
||||
TCLAP::ValueArg<float> lattice_beam("", LAT_BEAM, "the lattice generation beam for beam searching ", false, 3.0, "float");
|
||||
TCLAP::ValueArg<float> am_scale("", AM_SCALE, "the acoustic scale for beam searching ", false, 10.0, "float");
|
||||
TCLAP::ValueArg<std::int32_t> fst_inc_wts("", FST_INC_WTS, "the fst hotwords incremental bias", false, 20, "int32_t");
|
||||
TCLAP::ValueArg<std::string> asr_mode("", ASR_MODE, "offline, online, 2pass", false, "2pass", "string");
|
||||
TCLAP::ValueArg<std::int32_t> onnx_thread("", "model-thread-num", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t");
|
||||
|
||||
TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
|
||||
TCLAP::ValueArg<std::int32_t> audio_fs("", AUDIO_FS, "the sample rate of audio", false, 16000, "int32_t");
|
||||
TCLAP::ValueArg<std::string> hotword("", HOTWORD, "the hotword file, one hotword perline, Format: Hotword Weight (could be: 阿里巴巴 20)", false, "", "string");
|
||||
|
||||
cmd.add(offline_model_dir);
|
||||
@ -64,8 +70,14 @@ int main(int argc, char** argv)
|
||||
cmd.add(vad_quant);
|
||||
cmd.add(punc_dir);
|
||||
cmd.add(punc_quant);
|
||||
cmd.add(lm_dir);
|
||||
cmd.add(global_beam);
|
||||
cmd.add(lattice_beam);
|
||||
cmd.add(am_scale);
|
||||
cmd.add(fst_inc_wts);
|
||||
cmd.add(itn_dir);
|
||||
cmd.add(wav_path);
|
||||
cmd.add(audio_fs);
|
||||
cmd.add(asr_mode);
|
||||
cmd.add(onnx_thread);
|
||||
cmd.add(hotword);
|
||||
@ -79,6 +91,7 @@ int main(int argc, char** argv)
|
||||
GetValue(vad_quant, VAD_QUANT, model_path);
|
||||
GetValue(punc_dir, PUNC_DIR, model_path);
|
||||
GetValue(punc_quant, PUNC_QUANT, model_path);
|
||||
GetValue(lm_dir, LM_DIR, model_path);
|
||||
GetValue(itn_dir, ITN_DIR, model_path);
|
||||
GetValue(wav_path, WAV_PATH, model_path);
|
||||
GetValue(asr_mode, ASR_MODE, model_path);
|
||||
@ -104,6 +117,16 @@ int main(int argc, char** argv)
|
||||
LOG(ERROR) << "FunTpassInit init failed";
|
||||
exit(-1);
|
||||
}
|
||||
float glob_beam = 3.0f;
|
||||
float lat_beam = 3.0f;
|
||||
float am_sc = 10.0f;
|
||||
if (lm_dir.isSet()) {
|
||||
glob_beam = global_beam.getValue();
|
||||
lat_beam = lattice_beam.getValue();
|
||||
am_sc = am_scale.getValue();
|
||||
}
|
||||
// init wfst decoder
|
||||
FUNASR_DEC_HANDLE decoder_handle = FunASRWfstDecoderInit(tpass_handle, ASR_TWO_PASS, glob_beam, lat_beam, am_sc);
|
||||
|
||||
gettimeofday(&end, NULL);
|
||||
long seconds = (end.tv_sec - start.tv_sec);
|
||||
@ -144,6 +167,9 @@ int main(int argc, char** argv)
|
||||
wav_ids.emplace_back(default_id);
|
||||
}
|
||||
|
||||
// load hotwords list and build graph
|
||||
FunWfstDecoderLoadHwsRes(decoder_handle, fst_inc_wts.getValue(), hws_map);
|
||||
|
||||
std::vector<std::vector<float>> hotwords_embedding = CompileHotwordEmbedding(tpass_handle, nn_hotwords_, ASR_TWO_PASS);
|
||||
// init online features
|
||||
std::vector<int> chunk_size = {5,10,5};
|
||||
@ -154,7 +180,7 @@ int main(int argc, char** argv)
|
||||
auto& wav_file = wav_list[i];
|
||||
auto& wav_id = wav_ids[i];
|
||||
|
||||
int32_t sampling_rate_ = 16000;
|
||||
int32_t sampling_rate_ = audio_fs.getValue();
|
||||
funasr::Audio audio(1);
|
||||
if(is_target_file(wav_file.c_str(), "wav")){
|
||||
if(!audio.LoadWav2Char(wav_file.c_str(), &sampling_rate_)){
|
||||
@ -189,7 +215,9 @@ int main(int argc, char** argv)
|
||||
is_final = false;
|
||||
}
|
||||
gettimeofday(&start, NULL);
|
||||
FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm", (ASR_TYPE)asr_mode_, hotwords_embedding);
|
||||
FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle,
|
||||
speech_buff+sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm",
|
||||
(ASR_TYPE)asr_mode_, hotwords_embedding, true, decoder_handle);
|
||||
gettimeofday(&end, NULL);
|
||||
seconds = (end.tv_sec - start.tv_sec);
|
||||
taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
|
||||
@ -233,10 +261,12 @@ int main(int argc, char** argv)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
FunWfstDecoderUnloadHwsRes(decoder_handle);
|
||||
LOG(INFO) << "Audio length: " << (double)snippet_time << " s";
|
||||
LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
|
||||
LOG(INFO) << "Model inference RTF: " << (double)taking_micros/ (snippet_time*1000000);
|
||||
FunASRWfstDecoderUninit(decoder_handle);
|
||||
FunTpassOnlineUninit(tpass_online_handle);
|
||||
FunTpassUninit(tpass_handle);
|
||||
return 0;
|
||||
|
||||
@ -29,7 +29,7 @@ using namespace std;
|
||||
std::atomic<int> wav_index(0);
|
||||
std::mutex mtx;
|
||||
|
||||
void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, vector<string> wav_ids,
|
||||
void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, vector<string> wav_ids, int audio_fs,
|
||||
float* total_length, long* total_time, int core_id, float glob_beam = 3.0f, float lat_beam = 3.0f, float am_sc = 10.0f,
|
||||
int fst_inc_wts = 20, string hotword_path = "") {
|
||||
|
||||
@ -54,8 +54,7 @@ void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, vector<string> wa
|
||||
// warm up
|
||||
for (size_t i = 0; i < 1; i++)
|
||||
{
|
||||
FunOfflineReset(asr_handle, decoder_handle);
|
||||
FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[0].c_str(), RASR_NONE, NULL, hotwords_embedding, 16000, false, decoder_handle);
|
||||
FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[0].c_str(), RASR_NONE, NULL, hotwords_embedding, audio_fs, true, decoder_handle);
|
||||
if(result){
|
||||
FunASRFreeResult(result);
|
||||
}
|
||||
@ -69,7 +68,7 @@ void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, vector<string> wa
|
||||
}
|
||||
|
||||
gettimeofday(&start, NULL);
|
||||
FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[i].c_str(), RASR_NONE, NULL, hotwords_embedding, 16000, false, decoder_handle);
|
||||
FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[i].c_str(), RASR_NONE, NULL, hotwords_embedding, audio_fs, true, decoder_handle);
|
||||
|
||||
gettimeofday(&end, NULL);
|
||||
seconds = (end.tv_sec - start.tv_sec);
|
||||
@ -83,6 +82,10 @@ void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, vector<string> wa
|
||||
if(stamp !=""){
|
||||
LOG(INFO) << "Thread: " << this_thread::get_id() << "," << wav_ids[i] << " : " << stamp;
|
||||
}
|
||||
string stamp_sents = FunASRGetStampSents(result);
|
||||
if(stamp_sents !=""){
|
||||
LOG(INFO)<< wav_ids[i] <<" : "<<stamp_sents;
|
||||
}
|
||||
float snippet_time = FunASRGetRetSnippetTime(result);
|
||||
n_total_length += snippet_time;
|
||||
FunASRFreeResult(result);
|
||||
@ -138,6 +141,7 @@ int main(int argc, char *argv[])
|
||||
TCLAP::ValueArg<std::string> itn_dir("", ITN_DIR, "the itn model(fst) path, which contains zh_itn_tagger.fst and zh_itn_verbalizer.fst", false, "", "string");
|
||||
|
||||
TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
|
||||
TCLAP::ValueArg<std::int32_t> audio_fs("", AUDIO_FS, "the sample rate of audio", false, 16000, "int32_t");
|
||||
TCLAP::ValueArg<std::int32_t> thread_num("", THREAD_NUM, "multi-thread num for rtf", true, 0, "int32_t");
|
||||
TCLAP::ValueArg<std::string> hotword("", HOTWORD, "the hotword file, one hotword perline, Format: Hotword Weight (could be: 阿里巴巴 20)", false, "", "string");
|
||||
|
||||
@ -155,6 +159,7 @@ int main(int argc, char *argv[])
|
||||
cmd.add(hotword);
|
||||
cmd.add(fst_inc_wts);
|
||||
cmd.add(wav_path);
|
||||
cmd.add(audio_fs);
|
||||
cmd.add(thread_num);
|
||||
cmd.parse(argc, argv);
|
||||
|
||||
@ -234,7 +239,7 @@ int main(int argc, char *argv[])
|
||||
}
|
||||
for (int i = 0; i < rtf_threds; i++)
|
||||
{
|
||||
threads.emplace_back(thread(runReg, asr_handle, wav_list, wav_ids, &total_length, &total_time, i, glob_beam, lat_beam, am_sc, value_bias, hotword_path));
|
||||
threads.emplace_back(thread(runReg, asr_handle, wav_list, wav_ids, audio_fs.getValue(), &total_length, &total_time, i, glob_beam, lat_beam, am_sc, value_bias, hotword_path));
|
||||
}
|
||||
|
||||
for (auto& thread : threads)
|
||||
|
||||
@ -68,10 +68,12 @@ int main(int argc, char *argv[])
|
||||
TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "true (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string");
|
||||
|
||||
TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
|
||||
TCLAP::ValueArg<std::int32_t> audio_fs("", AUDIO_FS, "the sample rate of audio", false, 16000, "int32_t");
|
||||
|
||||
cmd.add(model_dir);
|
||||
cmd.add(quantize);
|
||||
cmd.add(wav_path);
|
||||
cmd.add(audio_fs);
|
||||
cmd.parse(argc, argv);
|
||||
|
||||
std::map<std::string, std::string> model_path;
|
||||
@ -131,7 +133,7 @@ int main(int argc, char *argv[])
|
||||
auto& wav_file = wav_list[i];
|
||||
auto& wav_id = wav_ids[i];
|
||||
gettimeofday(&start, NULL);
|
||||
FUNASR_RESULT result=FsmnVadInfer(vad_hanlde, wav_file.c_str(), NULL, 16000);
|
||||
FUNASR_RESULT result=FsmnVadInfer(vad_hanlde, wav_file.c_str(), NULL, audio_fs.getValue());
|
||||
gettimeofday(&end, NULL);
|
||||
seconds = (end.tv_sec - start.tv_sec);
|
||||
taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
|
||||
|
||||
@ -50,13 +50,14 @@ int main(int argc, char** argv)
|
||||
TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "true (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
|
||||
TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
|
||||
TCLAP::ValueArg<std::string> lm_dir("", LM_DIR, "the lm model path, which contains compiled models: TLG.fst, config.yaml ", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> lm_dir("", LM_DIR, "the lm model path, which contains compiled models: TLG.fst, config.yaml, lexicon.txt ", false, "", "string");
|
||||
TCLAP::ValueArg<float> global_beam("", GLOB_BEAM, "the decoding beam for beam searching ", false, 3.0, "float");
|
||||
TCLAP::ValueArg<float> lattice_beam("", LAT_BEAM, "the lattice generation beam for beam searching ", false, 3.0, "float");
|
||||
TCLAP::ValueArg<float> am_scale("", AM_SCALE, "the acoustic scale for beam searching ", false, 10.0, "float");
|
||||
TCLAP::ValueArg<std::int32_t> fst_inc_wts("", FST_INC_WTS, "the fst hotwords incremental bias", false, 20, "int32_t");
|
||||
TCLAP::ValueArg<std::string> itn_dir("", ITN_DIR, "the itn model(fst) path, which contains zh_itn_tagger.fst and zh_itn_verbalizer.fst", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
|
||||
TCLAP::ValueArg<std::int32_t> audio_fs("", AUDIO_FS, "the sample rate of audio", false, 16000, "int32_t");
|
||||
TCLAP::ValueArg<std::string> hotword("", HOTWORD, "the hotword file, one hotword perline, Format: Hotword Weight (could be: 阿里巴巴 20)", false, "", "string");
|
||||
|
||||
cmd.add(model_dir);
|
||||
@ -72,6 +73,7 @@ int main(int argc, char** argv)
|
||||
cmd.add(am_scale);
|
||||
cmd.add(fst_inc_wts);
|
||||
cmd.add(wav_path);
|
||||
cmd.add(audio_fs);
|
||||
cmd.add(hotword);
|
||||
cmd.parse(argc, argv);
|
||||
|
||||
@ -157,7 +159,7 @@ int main(int argc, char** argv)
|
||||
auto& wav_file = wav_list[i];
|
||||
auto& wav_id = wav_ids[i];
|
||||
gettimeofday(&start, NULL);
|
||||
FUNASR_RESULT result=FunOfflineInfer(asr_hanlde, wav_file.c_str(), RASR_NONE, NULL, hotwords_embedding, 16000, false, decoder_handle);
|
||||
FUNASR_RESULT result=FunOfflineInfer(asr_hanlde, wav_file.c_str(), RASR_NONE, NULL, hotwords_embedding, audio_fs.getValue(), true, decoder_handle);
|
||||
gettimeofday(&end, NULL);
|
||||
seconds = (end.tv_sec - start.tv_sec);
|
||||
taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
|
||||
@ -170,6 +172,10 @@ int main(int argc, char** argv)
|
||||
if(stamp !=""){
|
||||
LOG(INFO)<< wav_id <<" : "<<stamp;
|
||||
}
|
||||
string stamp_sents = FunASRGetStampSents(result);
|
||||
if(stamp_sents !=""){
|
||||
LOG(INFO)<< wav_id <<" : "<<stamp_sents;
|
||||
}
|
||||
snippet_time += FunASRGetRetSnippetTime(result);
|
||||
FunASRFreeResult(result);
|
||||
}
|
||||
|
||||
@ -49,10 +49,12 @@ int main(int argc, char *argv[])
|
||||
TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "true (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string");
|
||||
|
||||
TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
|
||||
TCLAP::ValueArg<std::int32_t> audio_fs("", AUDIO_FS, "the sample rate of audio", false, 16000, "int32_t");
|
||||
|
||||
cmd.add(model_dir);
|
||||
cmd.add(quantize);
|
||||
cmd.add(wav_path);
|
||||
cmd.add(audio_fs);
|
||||
cmd.parse(argc, argv);
|
||||
|
||||
std::map<std::string, std::string> model_path;
|
||||
@ -110,7 +112,7 @@ int main(int argc, char *argv[])
|
||||
auto& wav_file = wav_list[i];
|
||||
auto& wav_id = wav_ids[i];
|
||||
|
||||
int32_t sampling_rate_ = -1;
|
||||
int32_t sampling_rate_ = audio_fs.getValue();
|
||||
funasr::Audio audio(1);
|
||||
if(is_target_file(wav_file.c_str(), "wav")){
|
||||
if(!audio.LoadWav2Char(wav_file.c_str(), &sampling_rate_)){
|
||||
@ -143,7 +145,7 @@ int main(int argc, char *argv[])
|
||||
is_final = false;
|
||||
}
|
||||
gettimeofday(&start, NULL);
|
||||
FUNASR_RESULT result = FunASRInferBuffer(online_handle, speech_buff+sample_offset, step, RASR_NONE, NULL, is_final, 16000);
|
||||
FUNASR_RESULT result = FunASRInferBuffer(online_handle, speech_buff+sample_offset, step, RASR_NONE, NULL, is_final, sampling_rate_);
|
||||
gettimeofday(&end, NULL);
|
||||
seconds = (end.tv_sec - start.tv_sec);
|
||||
taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
|
||||
|
||||
@ -38,7 +38,7 @@ bool is_target_file(const std::string& filename, const std::string target) {
|
||||
return (extension == target);
|
||||
}
|
||||
|
||||
void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, vector<string> wav_ids,
|
||||
void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, vector<string> wav_ids, int audio_fs,
|
||||
float* total_length, long* total_time, int core_id) {
|
||||
|
||||
struct timeval start, end;
|
||||
@ -52,7 +52,7 @@ void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, vector<string> wa
|
||||
// warm up
|
||||
for (size_t i = 0; i < 10; i++)
|
||||
{
|
||||
int32_t sampling_rate_ = -1;
|
||||
int32_t sampling_rate_ = audio_fs;
|
||||
funasr::Audio audio(1);
|
||||
if(is_target_file(wav_list[0].c_str(), "wav")){
|
||||
if(!audio.LoadWav2Char(wav_list[0].c_str(), &sampling_rate_)){
|
||||
@ -84,7 +84,7 @@ void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, vector<string> wa
|
||||
} else {
|
||||
is_final = false;
|
||||
}
|
||||
FUNASR_RESULT result = FunASRInferBuffer(online_handle, speech_buff+sample_offset, step, RASR_NONE, NULL, is_final, 16000);
|
||||
FUNASR_RESULT result = FunASRInferBuffer(online_handle, speech_buff+sample_offset, step, RASR_NONE, NULL, is_final, sampling_rate_);
|
||||
if (result)
|
||||
{
|
||||
FunASRFreeResult(result);
|
||||
@ -98,7 +98,7 @@ void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, vector<string> wa
|
||||
if (i >= wav_list.size()) {
|
||||
break;
|
||||
}
|
||||
int32_t sampling_rate_ = -1;
|
||||
int32_t sampling_rate_ = audio_fs;
|
||||
funasr::Audio audio(1);
|
||||
if(is_target_file(wav_list[i].c_str(), "wav")){
|
||||
if(!audio.LoadWav2Char(wav_list[i].c_str(), &sampling_rate_)){
|
||||
@ -131,7 +131,7 @@ void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, vector<string> wa
|
||||
is_final = false;
|
||||
}
|
||||
gettimeofday(&start, NULL);
|
||||
FUNASR_RESULT result = FunASRInferBuffer(online_handle, speech_buff+sample_offset, step, RASR_NONE, NULL, is_final, 16000);
|
||||
FUNASR_RESULT result = FunASRInferBuffer(online_handle, speech_buff+sample_offset, step, RASR_NONE, NULL, is_final, sampling_rate_);
|
||||
gettimeofday(&end, NULL);
|
||||
seconds = (end.tv_sec - start.tv_sec);
|
||||
long taking_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
|
||||
@ -186,6 +186,7 @@ int main(int argc, char *argv[])
|
||||
TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
|
||||
|
||||
TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
|
||||
TCLAP::ValueArg<std::int32_t> audio_fs("", AUDIO_FS, "the sample rate of audio", false, 16000, "int32_t");
|
||||
TCLAP::ValueArg<std::int32_t> thread_num("", THREAD_NUM, "multi-thread num for rtf", true, 0, "int32_t");
|
||||
|
||||
cmd.add(model_dir);
|
||||
@ -195,6 +196,7 @@ int main(int argc, char *argv[])
|
||||
cmd.add(punc_dir);
|
||||
cmd.add(punc_quant);
|
||||
cmd.add(wav_path);
|
||||
cmd.add(audio_fs);
|
||||
cmd.add(thread_num);
|
||||
cmd.parse(argc, argv);
|
||||
|
||||
@ -260,7 +262,7 @@ int main(int argc, char *argv[])
|
||||
int rtf_threds = thread_num.getValue();
|
||||
for (int i = 0; i < rtf_threds; i++)
|
||||
{
|
||||
threads.emplace_back(thread(runReg, asr_handle, wav_list, wav_ids, &total_length, &total_time, i));
|
||||
threads.emplace_back(thread(runReg, asr_handle, wav_list, wav_ids, audio_fs.getValue(), &total_length, &total_time, i));
|
||||
}
|
||||
|
||||
for (auto& thread : threads)
|
||||
|
||||
@ -75,10 +75,12 @@ int main(int argc, char *argv[])
|
||||
TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string");
|
||||
|
||||
TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
|
||||
TCLAP::ValueArg<std::int32_t> audio_fs("", AUDIO_FS, "the sample rate of audio", false, 16000, "int32_t");
|
||||
|
||||
cmd.add(model_dir);
|
||||
cmd.add(quantize);
|
||||
cmd.add(wav_path);
|
||||
cmd.add(audio_fs);
|
||||
cmd.parse(argc, argv);
|
||||
|
||||
std::map<std::string, std::string> model_path;
|
||||
@ -139,10 +141,9 @@ int main(int argc, char *argv[])
|
||||
auto& wav_file = wav_list[i];
|
||||
auto& wav_id = wav_ids[i];
|
||||
|
||||
int32_t sampling_rate_ = -1;
|
||||
int32_t sampling_rate_ = audio_fs.getValue();
|
||||
funasr::Audio audio(1);
|
||||
if(is_target_file(wav_file.c_str(), "wav")){
|
||||
int32_t sampling_rate_ = -1;
|
||||
if(!audio.LoadWav2Char(wav_file.c_str(), &sampling_rate_)){
|
||||
LOG(ERROR)<<"Failed to load "<< wav_file;
|
||||
exit(-1);
|
||||
@ -170,7 +171,7 @@ int main(int argc, char *argv[])
|
||||
is_final = false;
|
||||
}
|
||||
gettimeofday(&start, NULL);
|
||||
FUNASR_RESULT result = FsmnVadInferBuffer(online_hanlde, speech_buff+sample_offset, step, NULL, is_final, 16000);
|
||||
FUNASR_RESULT result = FsmnVadInferBuffer(online_hanlde, speech_buff+sample_offset, step, NULL, is_final, sampling_rate_);
|
||||
gettimeofday(&end, NULL);
|
||||
seconds = (end.tv_sec - start.tv_sec);
|
||||
taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
|
||||
|
||||
@ -52,18 +52,20 @@ class Audio {
|
||||
queue<AudioFrame *> frame_queue;
|
||||
queue<AudioFrame *> asr_online_queue;
|
||||
queue<AudioFrame *> asr_offline_queue;
|
||||
|
||||
int dest_sample_rate;
|
||||
public:
|
||||
Audio(int data_type);
|
||||
Audio(int data_type, int size);
|
||||
Audio(int model_sample_rate,int data_type);
|
||||
Audio(int model_sample_rate,int data_type, int size);
|
||||
~Audio();
|
||||
void ClearQueue(std::queue<AudioFrame*>& q);
|
||||
void Disp();
|
||||
void WavResample(int32_t sampling_rate, const float *waveform, int32_t n);
|
||||
bool LoadWav(const char* buf, int n_len, int32_t* sampling_rate);
|
||||
bool LoadWav(const char* filename, int32_t* sampling_rate);
|
||||
bool LoadWav(const char* filename, int32_t* sampling_rate, bool resample=true);
|
||||
bool LoadWav2Char(const char* filename, int32_t* sampling_rate);
|
||||
bool LoadPcmwav(const char* buf, int n_file_len, int32_t* sampling_rate);
|
||||
bool LoadPcmwav(const char* filename, int32_t* sampling_rate);
|
||||
bool LoadPcmwav(const char* filename, int32_t* sampling_rate, bool resample=true);
|
||||
bool LoadPcmwav2Char(const char* filename, int32_t* sampling_rate);
|
||||
bool LoadOthers2Char(const char* filename);
|
||||
bool FfmpegLoad(const char *filename, bool copy2char=false);
|
||||
|
||||
@ -34,6 +34,7 @@ namespace funasr {
|
||||
#define THREAD_NUM "thread-num"
|
||||
#define PORT_ID "port-id"
|
||||
#define HOTWORD_SEP " "
|
||||
#define AUDIO_FS "audio-fs"
|
||||
|
||||
// #define VAD_MODEL_PATH "vad-model"
|
||||
// #define VAD_CMVN_PATH "vad-cmvn"
|
||||
@ -68,6 +69,7 @@ namespace funasr {
|
||||
#define QUANT_DECODER_NAME "decoder_quant.onnx"
|
||||
|
||||
#define LM_FST_RES "TLG.fst"
|
||||
#define LEX_PATH "lexicon.txt"
|
||||
|
||||
// vad
|
||||
#ifndef VAD_SILENCE_DURATION
|
||||
|
||||
@ -68,6 +68,7 @@ _FUNASRAPI FUNASR_RESULT FunASRInfer(FUNASR_HANDLE handle, const char* sz_filena
|
||||
|
||||
_FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT result,int n_index);
|
||||
_FUNASRAPI const char* FunASRGetStamp(FUNASR_RESULT result);
|
||||
_FUNASRAPI const char* FunASRGetStampSents(FUNASR_RESULT result);
|
||||
_FUNASRAPI const char* FunASRGetTpassResult(FUNASR_RESULT result,int n_index);
|
||||
_FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT result);
|
||||
_FUNASRAPI void FunASRFreeResult(FUNASR_RESULT result);
|
||||
@ -118,7 +119,7 @@ _FUNASRAPI FUNASR_HANDLE FunTpassOnlineInit(FUNASR_HANDLE tpass_handle, std::
|
||||
_FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf,
|
||||
int n_len, std::vector<std::vector<std::string>> &punc_cache, bool input_finished=true,
|
||||
int sampling_rate=16000, std::string wav_format="pcm", ASR_TYPE mode=ASR_TWO_PASS,
|
||||
const std::vector<std::vector<float>> &hw_emb={{0.0}}, bool itn=true);
|
||||
const std::vector<std::vector<float>> &hw_emb={{0.0}}, bool itn=true, FUNASR_DEC_HANDLE dec_handle=nullptr);
|
||||
_FUNASRAPI void FunTpassUninit(FUNASR_HANDLE handle);
|
||||
_FUNASRAPI void FunTpassOnlineUninit(FUNASR_HANDLE handle);
|
||||
|
||||
|
||||
@ -15,7 +15,7 @@ class Model {
|
||||
virtual void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
|
||||
virtual void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
|
||||
virtual void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
|
||||
virtual void InitLm(const std::string &lm_file, const std::string &lm_config){};
|
||||
virtual void InitLm(const std::string &lm_file, const std::string &lm_config, const std::string &lex_file){};
|
||||
virtual void InitFstDecoder(){};
|
||||
virtual std::string Forward(float *din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr){return "";};
|
||||
virtual std::string Rescoring() = 0;
|
||||
@ -23,6 +23,8 @@ class Model {
|
||||
virtual void InitSegDict(const std::string &seg_dict_model){};
|
||||
virtual std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords){return std::vector<std::vector<float>>();};
|
||||
virtual std::string GetLang(){return "";};
|
||||
virtual int GetAsrSampleRate() = 0;
|
||||
|
||||
};
|
||||
|
||||
Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num=1, ASR_TYPE type=ASR_OFFLINE);
|
||||
|
||||
@ -12,6 +12,7 @@ class VadModel {
|
||||
virtual ~VadModel(){};
|
||||
virtual void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num)=0;
|
||||
virtual std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished=true)=0;
|
||||
virtual int GetVadSampleRate() = 0;
|
||||
};
|
||||
|
||||
VadModel *CreateVadModel(std::map<std::string, std::string>& model_path, int thread_num);
|
||||
|
||||
@ -38,12 +38,12 @@ make -j 4
|
||||
### Download onnxruntime
|
||||
https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/dep_libs/onnxruntime-win-x64-1.16.1.zip
|
||||
|
||||
Download and unzip to d:/ffmpeg-master-latest-win64-gpl-shared
|
||||
Download and unzip to d:/onnxruntime-win-x64-1.16.1
|
||||
|
||||
### Download ffmpeg
|
||||
https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/dep_libs/ffmpeg-master-latest-win64-gpl-shared.zip
|
||||
|
||||
Download and unzip to d:/onnxruntime-win-x64-1.16.1
|
||||
Download and unzip to d:/ffmpeg-master-latest-win64-gpl-shared
|
||||
|
||||
### Build runtime
|
||||
```
|
||||
|
||||
@ -193,18 +193,28 @@ int AudioFrame::Disp()
|
||||
return 0;
|
||||
}
|
||||
|
||||
Audio::Audio(int data_type) : data_type(data_type)
|
||||
Audio::Audio(int data_type) : dest_sample_rate(MODEL_SAMPLE_RATE), data_type(data_type)
|
||||
{
|
||||
speech_buff = NULL;
|
||||
speech_data = NULL;
|
||||
align_size = 1360;
|
||||
seg_sample = dest_sample_rate / 1000;
|
||||
}
|
||||
|
||||
Audio::Audio(int data_type, int size) : data_type(data_type)
|
||||
Audio::Audio(int model_sample_rate, int data_type) : dest_sample_rate(model_sample_rate), data_type(data_type)
|
||||
{
|
||||
speech_buff = NULL;
|
||||
speech_data = NULL;
|
||||
align_size = 1360;
|
||||
seg_sample = dest_sample_rate / 1000;
|
||||
}
|
||||
|
||||
Audio::Audio(int model_sample_rate, int data_type, int size) : dest_sample_rate(model_sample_rate), data_type(data_type)
|
||||
{
|
||||
speech_buff = NULL;
|
||||
speech_data = NULL;
|
||||
align_size = (float)size;
|
||||
seg_sample = dest_sample_rate / 1000;
|
||||
}
|
||||
|
||||
Audio::~Audio()
|
||||
@ -218,32 +228,43 @@ Audio::~Audio()
|
||||
if (speech_char != NULL) {
|
||||
free(speech_char);
|
||||
}
|
||||
ClearQueue(frame_queue);
|
||||
ClearQueue(asr_online_queue);
|
||||
ClearQueue(asr_offline_queue);
|
||||
}
|
||||
|
||||
void Audio::ClearQueue(std::queue<AudioFrame*>& q) {
|
||||
while (!q.empty()) {
|
||||
AudioFrame* frame = q.front();
|
||||
delete frame;
|
||||
q.pop();
|
||||
}
|
||||
}
|
||||
|
||||
void Audio::Disp()
|
||||
{
|
||||
LOG(INFO) << "Audio time is " << (float)speech_len / MODEL_SAMPLE_RATE << " s. len is " << speech_len;
|
||||
LOG(INFO) << "Audio time is " << (float)speech_len / dest_sample_rate << " s. len is " << speech_len;
|
||||
}
|
||||
|
||||
float Audio::GetTimeLen()
|
||||
{
|
||||
return (float)speech_len / MODEL_SAMPLE_RATE;
|
||||
return (float)speech_len / dest_sample_rate;
|
||||
}
|
||||
|
||||
void Audio::WavResample(int32_t sampling_rate, const float *waveform,
|
||||
int32_t n)
|
||||
{
|
||||
LOG(INFO) << "Creating a resampler:\n"
|
||||
<< " in_sample_rate: "<< sampling_rate << "\n"
|
||||
<< " output_sample_rate: " << static_cast<int32_t>(MODEL_SAMPLE_RATE);
|
||||
LOG(INFO) << "Creating a resampler: "
|
||||
<< " in_sample_rate: "<< sampling_rate
|
||||
<< " output_sample_rate: " << static_cast<int32_t>(dest_sample_rate);
|
||||
float min_freq =
|
||||
std::min<int32_t>(sampling_rate, MODEL_SAMPLE_RATE);
|
||||
std::min<int32_t>(sampling_rate, dest_sample_rate);
|
||||
float lowpass_cutoff = 0.99 * 0.5 * min_freq;
|
||||
|
||||
int32_t lowpass_filter_width = 6;
|
||||
|
||||
auto resampler = std::make_unique<LinearResample>(
|
||||
sampling_rate, MODEL_SAMPLE_RATE, lowpass_cutoff, lowpass_filter_width);
|
||||
sampling_rate, dest_sample_rate, lowpass_cutoff, lowpass_filter_width);
|
||||
std::vector<float> samples;
|
||||
resampler->Resample(waveform, n, true, &samples);
|
||||
//reset speech_data
|
||||
@ -311,7 +332,7 @@ bool Audio::FfmpegLoad(const char *filename, bool copy2char){
|
||||
nullptr, // allocate a new context
|
||||
AV_CH_LAYOUT_MONO, // output channel layout (stereo)
|
||||
AV_SAMPLE_FMT_S16, // output sample format (signed 16-bit)
|
||||
16000, // output sample rate (same as input)
|
||||
dest_sample_rate, // output sample rate (same as input)
|
||||
av_get_default_channel_layout(codecContext->channels), // input channel layout
|
||||
codecContext->sample_fmt, // input sample format
|
||||
codecContext->sample_rate, // input sample rate
|
||||
@ -344,30 +365,28 @@ bool Audio::FfmpegLoad(const char *filename, bool copy2char){
|
||||
while (avcodec_receive_frame(codecContext, frame) >= 0) {
|
||||
// Resample audio if necessary
|
||||
std::vector<uint8_t> resampled_buffer;
|
||||
int in_samples = frame->nb_samples;
|
||||
uint8_t **in_data = frame->extended_data;
|
||||
int out_samples = av_rescale_rnd(in_samples,
|
||||
16000,
|
||||
int out_samples = av_rescale_rnd(swr_get_delay(swr_ctx, codecContext->sample_rate) + frame->nb_samples,
|
||||
dest_sample_rate,
|
||||
codecContext->sample_rate,
|
||||
AV_ROUND_DOWN);
|
||||
|
||||
int resampled_size = out_samples * av_get_bytes_per_sample(AV_SAMPLE_FMT_S16);
|
||||
if (resampled_buffer.size() < resampled_size) {
|
||||
resampled_buffer.resize(resampled_size);
|
||||
}
|
||||
}
|
||||
uint8_t *resampled_data = resampled_buffer.data();
|
||||
int ret = swr_convert(
|
||||
swr_ctx,
|
||||
&resampled_data, // output buffer
|
||||
resampled_size, // output buffer size
|
||||
(const uint8_t **)(frame->data), //(const uint8_t **)(frame->extended_data)
|
||||
in_samples // input buffer size
|
||||
out_samples, // output buffer size
|
||||
(const uint8_t **)(frame->data), // choose channel
|
||||
frame->nb_samples // input buffer size
|
||||
);
|
||||
if (ret < 0) {
|
||||
LOG(ERROR) << "Error resampling audio";
|
||||
break;
|
||||
}
|
||||
std::copy(resampled_buffer.begin(), resampled_buffer.end(), std::back_inserter(resampled_buffers));
|
||||
resampled_buffers.insert(resampled_buffers.end(), resampled_buffer.begin(), resampled_buffer.begin() + resampled_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -443,6 +462,10 @@ bool Audio::FfmpegLoad(const char* buf, int n_file_len){
|
||||
nullptr, // write callback (not used here)
|
||||
nullptr // seek callback (not used here)
|
||||
);
|
||||
if (!avio_ctx) {
|
||||
av_free(buf_copy);
|
||||
return false;
|
||||
}
|
||||
AVFormatContext* formatContext = avformat_alloc_context();
|
||||
formatContext->pb = avio_ctx;
|
||||
if (avformat_open_input(&formatContext, "", NULL, NULL) != 0) {
|
||||
@ -494,7 +517,7 @@ bool Audio::FfmpegLoad(const char* buf, int n_file_len){
|
||||
nullptr, // allocate a new context
|
||||
AV_CH_LAYOUT_MONO, // output channel layout (stereo)
|
||||
AV_SAMPLE_FMT_S16, // output sample format (signed 16-bit)
|
||||
16000, // output sample rate (same as input)
|
||||
dest_sample_rate, // output sample rate (same as input)
|
||||
av_get_default_channel_layout(codecContext->channels), // input channel layout
|
||||
codecContext->sample_fmt, // input sample format
|
||||
codecContext->sample_rate, // input sample rate
|
||||
@ -529,37 +552,37 @@ bool Audio::FfmpegLoad(const char* buf, int n_file_len){
|
||||
while (avcodec_receive_frame(codecContext, frame) >= 0) {
|
||||
// Resample audio if necessary
|
||||
std::vector<uint8_t> resampled_buffer;
|
||||
int in_samples = frame->nb_samples;
|
||||
uint8_t **in_data = frame->extended_data;
|
||||
int out_samples = av_rescale_rnd(in_samples,
|
||||
16000,
|
||||
int out_samples = av_rescale_rnd(swr_get_delay(swr_ctx, codecContext->sample_rate) + frame->nb_samples,
|
||||
dest_sample_rate,
|
||||
codecContext->sample_rate,
|
||||
AV_ROUND_DOWN);
|
||||
|
||||
int resampled_size = out_samples * av_get_bytes_per_sample(AV_SAMPLE_FMT_S16);
|
||||
if (resampled_buffer.size() < resampled_size) {
|
||||
resampled_buffer.resize(resampled_size);
|
||||
}
|
||||
}
|
||||
uint8_t *resampled_data = resampled_buffer.data();
|
||||
int ret = swr_convert(
|
||||
swr_ctx,
|
||||
&resampled_data, // output buffer
|
||||
resampled_size, // output buffer size
|
||||
(const uint8_t **)(frame->data), //(const uint8_t **)(frame->extended_data)
|
||||
in_samples // input buffer size
|
||||
out_samples, // output buffer size
|
||||
(const uint8_t **)(frame->data), // choose channel: channel_data
|
||||
frame->nb_samples // input buffer size
|
||||
);
|
||||
if (ret < 0) {
|
||||
LOG(ERROR) << "Error resampling audio";
|
||||
break;
|
||||
}
|
||||
std::copy(resampled_buffer.begin(), resampled_buffer.end(), std::back_inserter(resampled_buffers));
|
||||
resampled_buffers.insert(resampled_buffers.end(), resampled_buffer.begin(), resampled_buffer.begin() + resampled_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
av_packet_unref(packet);
|
||||
}
|
||||
|
||||
avio_context_free(&avio_ctx);
|
||||
//avio_context_free(&avio_ctx);
|
||||
av_freep(&avio_ctx ->buffer);
|
||||
av_freep(&avio_ctx);
|
||||
avformat_close_input(&formatContext);
|
||||
avformat_free_context(formatContext);
|
||||
avcodec_free_context(&codecContext);
|
||||
@ -604,7 +627,7 @@ bool Audio::FfmpegLoad(const char* buf, int n_file_len){
|
||||
}
|
||||
|
||||
|
||||
bool Audio::LoadWav(const char *filename, int32_t* sampling_rate)
|
||||
bool Audio::LoadWav(const char *filename, int32_t* sampling_rate, bool resample)
|
||||
{
|
||||
WaveHeader header;
|
||||
if (speech_data != NULL) {
|
||||
@ -666,7 +689,7 @@ bool Audio::LoadWav(const char *filename, int32_t* sampling_rate)
|
||||
}
|
||||
|
||||
//resample
|
||||
if(*sampling_rate != MODEL_SAMPLE_RATE){
|
||||
if(resample && *sampling_rate != dest_sample_rate){
|
||||
WavResample(*sampling_rate, speech_data, speech_len);
|
||||
}
|
||||
|
||||
@ -752,7 +775,7 @@ bool Audio::LoadWav(const char* buf, int n_file_len, int32_t* sampling_rate)
|
||||
}
|
||||
|
||||
//resample
|
||||
if(*sampling_rate != MODEL_SAMPLE_RATE){
|
||||
if(*sampling_rate != dest_sample_rate){
|
||||
WavResample(*sampling_rate, speech_data, speech_len);
|
||||
}
|
||||
|
||||
@ -795,7 +818,7 @@ bool Audio::LoadPcmwav(const char* buf, int n_buf_len, int32_t* sampling_rate)
|
||||
}
|
||||
|
||||
//resample
|
||||
if(*sampling_rate != MODEL_SAMPLE_RATE){
|
||||
if(*sampling_rate != dest_sample_rate){
|
||||
WavResample(*sampling_rate, speech_data, speech_len);
|
||||
}
|
||||
|
||||
@ -840,7 +863,7 @@ bool Audio::LoadPcmwavOnline(const char* buf, int n_buf_len, int32_t* sampling_r
|
||||
}
|
||||
|
||||
//resample
|
||||
if(*sampling_rate != MODEL_SAMPLE_RATE){
|
||||
if(*sampling_rate != dest_sample_rate){
|
||||
WavResample(*sampling_rate, speech_data, speech_len);
|
||||
}
|
||||
|
||||
@ -857,7 +880,7 @@ bool Audio::LoadPcmwavOnline(const char* buf, int n_buf_len, int32_t* sampling_r
|
||||
return false;
|
||||
}
|
||||
|
||||
bool Audio::LoadPcmwav(const char* filename, int32_t* sampling_rate)
|
||||
bool Audio::LoadPcmwav(const char* filename, int32_t* sampling_rate, bool resample)
|
||||
{
|
||||
if (speech_data != NULL) {
|
||||
free(speech_data);
|
||||
@ -898,7 +921,7 @@ bool Audio::LoadPcmwav(const char* filename, int32_t* sampling_rate)
|
||||
}
|
||||
|
||||
//resample
|
||||
if(*sampling_rate != MODEL_SAMPLE_RATE){
|
||||
if(resample && *sampling_rate != dest_sample_rate){
|
||||
WavResample(*sampling_rate, speech_data, speech_len);
|
||||
}
|
||||
|
||||
@ -1009,7 +1032,7 @@ int Audio::Fetch(float *&dout, int &len, int &flag, float &start_time)
|
||||
AudioFrame *frame = frame_queue.front();
|
||||
frame_queue.pop();
|
||||
|
||||
start_time = (float)(frame->GetStart())/MODEL_SAMPLE_RATE;
|
||||
start_time = (float)(frame->GetStart())/ dest_sample_rate;
|
||||
dout = speech_data + frame->GetStart();
|
||||
len = frame->GetLen();
|
||||
delete frame;
|
||||
@ -1248,7 +1271,7 @@ void Audio::Split(VadModel* vad_obj, int chunk_len, bool input_finished, ASR_TYP
|
||||
}
|
||||
|
||||
// erase all_samples
|
||||
int vector_cache = MODEL_SAMPLE_RATE*2;
|
||||
int vector_cache = dest_sample_rate*2;
|
||||
if(speech_offline_start == -1){
|
||||
if(all_samples.size() > vector_cache){
|
||||
int erase_num = all_samples.size() - vector_cache;
|
||||
|
||||
@ -65,12 +65,17 @@ class BiasLm {
|
||||
if (text.size() > 1) {
|
||||
score = std::stof(text[1]);
|
||||
}
|
||||
Utf8ToCharset(text[0], split_str);
|
||||
SplitChiEngCharacters(text[0], split_str);
|
||||
for (auto &str : split_str) {
|
||||
split_id.push_back(phn_set_.String2Id(str));
|
||||
if (!phn_set_.Find(str)) {
|
||||
is_oov = true;
|
||||
break;
|
||||
std::vector<string> lex_vec;
|
||||
std::string lex_str = vocab_.Word2Lex(str);
|
||||
SplitStringToVector(lex_str, " ", true, &lex_vec);
|
||||
for (auto &token : lex_vec) {
|
||||
split_id.push_back(phn_set_.String2Id(token));
|
||||
if (!phn_set_.Find(token)) {
|
||||
is_oov = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!is_oov) {
|
||||
@ -103,12 +108,17 @@ class BiasLm {
|
||||
std::vector<std::string> split_str;
|
||||
std::vector<int> split_id;
|
||||
score = kv.second;
|
||||
Utf8ToCharset(kv.first, split_str);
|
||||
SplitChiEngCharacters(kv.first, split_str);
|
||||
for (auto &str : split_str) {
|
||||
split_id.push_back(phn_set_.String2Id(str));
|
||||
if (!phn_set_.Find(str)) {
|
||||
is_oov = true;
|
||||
break;
|
||||
std::vector<string> lex_vec;
|
||||
std::string lex_str = vocab_.Word2Lex(str);
|
||||
SplitStringToVector(lex_str, " ", true, &lex_vec);
|
||||
for (auto &token : lex_vec) {
|
||||
split_id.push_back(phn_set_.String2Id(token));
|
||||
if (!phn_set_.Find(token)) {
|
||||
is_oov = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!is_oov) {
|
||||
|
||||
@ -9,6 +9,7 @@ typedef struct
|
||||
{
|
||||
std::string msg;
|
||||
std::string stamp;
|
||||
std::string stamp_sents;
|
||||
std::string tpass_msg;
|
||||
float snippet_time;
|
||||
}FUNASR_RECOG_RESULT;
|
||||
|
||||
@ -187,8 +187,11 @@ void FsmnVadOnline::InitOnline(std::shared_ptr<Ort::Session> &vad_session,
|
||||
vad_max_len_ = vad_max_len;
|
||||
vad_speech_noise_thres_ = vad_speech_noise_thres;
|
||||
|
||||
frame_sample_length_ = vad_sample_rate_ / 1000 * 25;;
|
||||
frame_shift_sample_length_ = vad_sample_rate_ / 1000 * 10;
|
||||
|
||||
// 2pass
|
||||
audio_handle = make_unique<Audio>(1);
|
||||
audio_handle = make_unique<Audio>(vad_sample_rate,1);
|
||||
}
|
||||
|
||||
FsmnVadOnline::~FsmnVadOnline() {
|
||||
|
||||
@ -21,6 +21,8 @@ public:
|
||||
std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished);
|
||||
void ExtractFeats(float sample_rate, vector<vector<float>> &vad_feats, vector<float> &waves, bool input_finished);
|
||||
void Reset();
|
||||
int GetVadSampleRate() { return vad_sample_rate_; };
|
||||
|
||||
// 2pass
|
||||
std::unique_ptr<Audio> audio_handle = nullptr;
|
||||
|
||||
|
||||
@ -28,6 +28,8 @@ public:
|
||||
std::vector<std::vector<float>> *in_cache,
|
||||
bool is_final);
|
||||
void Reset();
|
||||
|
||||
int GetVadSampleRate() { return vad_sample_rate_; };
|
||||
|
||||
std::shared_ptr<Ort::Session> vad_session_ = nullptr;
|
||||
Ort::Env env_;
|
||||
|
||||
@ -57,7 +57,7 @@
|
||||
if (!recog_obj)
|
||||
return nullptr;
|
||||
|
||||
funasr::Audio audio(1);
|
||||
funasr::Audio audio(recog_obj->GetAsrSampleRate(),1);
|
||||
if(wav_format == "pcm" || wav_format == "PCM"){
|
||||
if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
|
||||
return nullptr;
|
||||
@ -93,7 +93,7 @@
|
||||
if (!recog_obj)
|
||||
return nullptr;
|
||||
|
||||
funasr::Audio audio(1);
|
||||
funasr::Audio audio(recog_obj->GetAsrSampleRate(),1);
|
||||
if(funasr::is_target_file(sz_filename, "wav")){
|
||||
int32_t sampling_rate_ = -1;
|
||||
if(!audio.LoadWav(sz_filename, &sampling_rate_))
|
||||
@ -134,7 +134,7 @@
|
||||
if (!vad_obj)
|
||||
return nullptr;
|
||||
|
||||
funasr::Audio audio(1);
|
||||
funasr::Audio audio(vad_obj->GetVadSampleRate(),1);
|
||||
if(wav_format == "pcm" || wav_format == "PCM"){
|
||||
if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
|
||||
return nullptr;
|
||||
@ -146,6 +146,7 @@
|
||||
funasr::FUNASR_VAD_RESULT* p_result = new funasr::FUNASR_VAD_RESULT;
|
||||
p_result->snippet_time = audio.GetTimeLen();
|
||||
if(p_result->snippet_time == 0){
|
||||
p_result->segments = new vector<std::vector<int>>();
|
||||
return p_result;
|
||||
}
|
||||
|
||||
@ -162,7 +163,7 @@
|
||||
if (!vad_obj)
|
||||
return nullptr;
|
||||
|
||||
funasr::Audio audio(1);
|
||||
funasr::Audio audio(vad_obj->GetVadSampleRate(),1);
|
||||
if(funasr::is_target_file(sz_filename, "wav")){
|
||||
int32_t sampling_rate_ = -1;
|
||||
if(!audio.LoadWav(sz_filename, &sampling_rate_))
|
||||
@ -178,6 +179,7 @@
|
||||
funasr::FUNASR_VAD_RESULT* p_result = new funasr::FUNASR_VAD_RESULT;
|
||||
p_result->snippet_time = audio.GetTimeLen();
|
||||
if(p_result->snippet_time == 0){
|
||||
p_result->segments = new vector<std::vector<int>>();
|
||||
return p_result;
|
||||
}
|
||||
|
||||
@ -222,7 +224,7 @@
|
||||
if (!offline_stream)
|
||||
return nullptr;
|
||||
|
||||
funasr::Audio audio(1);
|
||||
funasr::Audio audio(offline_stream->asr_handle->GetAsrSampleRate(),1);
|
||||
try{
|
||||
if(wav_format == "pcm" || wav_format == "PCM"){
|
||||
if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
|
||||
@ -294,10 +296,18 @@
|
||||
#if !defined(__APPLE__)
|
||||
if(offline_stream->UseITN() && itn){
|
||||
string msg_itn = offline_stream->itn_handle->Normalize(p_result->msg);
|
||||
if(!(p_result->stamp).empty()){
|
||||
std::string new_stamp = funasr::TimestampSmooth(p_result->msg, msg_itn, p_result->stamp);
|
||||
if(!new_stamp.empty()){
|
||||
p_result->stamp = new_stamp;
|
||||
}
|
||||
}
|
||||
p_result->msg = msg_itn;
|
||||
}
|
||||
#endif
|
||||
|
||||
if (!(p_result->stamp).empty()){
|
||||
p_result->stamp_sents = funasr::TimestampSentence(p_result->msg, p_result->stamp);
|
||||
}
|
||||
return p_result;
|
||||
}
|
||||
|
||||
@ -308,7 +318,7 @@
|
||||
if (!offline_stream)
|
||||
return nullptr;
|
||||
|
||||
funasr::Audio audio(1);
|
||||
funasr::Audio audio((offline_stream->asr_handle)->GetAsrSampleRate(),1);
|
||||
try{
|
||||
if(funasr::is_target_file(sz_filename, "wav")){
|
||||
int32_t sampling_rate_ = -1;
|
||||
@ -384,9 +394,18 @@
|
||||
#if !defined(__APPLE__)
|
||||
if(offline_stream->UseITN() && itn){
|
||||
string msg_itn = offline_stream->itn_handle->Normalize(p_result->msg);
|
||||
if(!(p_result->stamp).empty()){
|
||||
std::string new_stamp = funasr::TimestampSmooth(p_result->msg, msg_itn, p_result->stamp);
|
||||
if(!new_stamp.empty()){
|
||||
p_result->stamp = new_stamp;
|
||||
}
|
||||
}
|
||||
p_result->msg = msg_itn;
|
||||
}
|
||||
#endif
|
||||
if (!(p_result->stamp).empty()){
|
||||
p_result->stamp_sents = funasr::TimestampSentence(p_result->msg, p_result->stamp);
|
||||
}
|
||||
return p_result;
|
||||
}
|
||||
|
||||
@ -420,7 +439,7 @@
|
||||
_FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf,
|
||||
int n_len, std::vector<std::vector<std::string>> &punc_cache, bool input_finished,
|
||||
int sampling_rate, std::string wav_format, ASR_TYPE mode,
|
||||
const std::vector<std::vector<float>> &hw_emb, bool itn)
|
||||
const std::vector<std::vector<float>> &hw_emb, bool itn, FUNASR_DEC_HANDLE dec_handle)
|
||||
{
|
||||
funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
|
||||
funasr::TpassOnlineStream* tpass_online_stream = (funasr::TpassOnlineStream*)online_handle;
|
||||
@ -494,7 +513,12 @@
|
||||
// timestamp
|
||||
std::string cur_stamp = "[";
|
||||
while(audio->FetchTpass(frame) > 0){
|
||||
string msg = ((funasr::Paraformer*)asr_handle)->Forward(frame->data, frame->len, frame->is_final, hw_emb);
|
||||
// dec reset
|
||||
funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle;
|
||||
if (wfst_decoder){
|
||||
wfst_decoder->StartUtterance();
|
||||
}
|
||||
string msg = ((funasr::Paraformer*)asr_handle)->Forward(frame->data, frame->len, frame->is_final, hw_emb, dec_handle);
|
||||
|
||||
std::vector<std::string> msg_vec = funasr::split(msg, '|'); // split with timestamp
|
||||
if(msg_vec.size()==0){
|
||||
@ -524,10 +548,19 @@
|
||||
#if !defined(__APPLE__)
|
||||
if(tpass_stream->UseITN() && itn){
|
||||
string msg_itn = tpass_stream->itn_handle->Normalize(msg_punc);
|
||||
// TimestampSmooth
|
||||
if(!(p_result->stamp).empty()){
|
||||
std::string new_stamp = funasr::TimestampSmooth(p_result->tpass_msg, msg_itn, p_result->stamp);
|
||||
if(!new_stamp.empty()){
|
||||
p_result->stamp = new_stamp;
|
||||
}
|
||||
}
|
||||
p_result->tpass_msg = msg_itn;
|
||||
}
|
||||
#endif
|
||||
|
||||
if (!(p_result->stamp).empty()){
|
||||
p_result->stamp_sents = funasr::TimestampSentence(p_result->tpass_msg, p_result->stamp);
|
||||
}
|
||||
if(frame != NULL){
|
||||
delete frame;
|
||||
frame = NULL;
|
||||
@ -584,6 +617,15 @@
|
||||
return p_result->stamp.c_str();
|
||||
}
|
||||
|
||||
_FUNASRAPI const char* FunASRGetStampSents(FUNASR_RESULT result)
|
||||
{
|
||||
funasr::FUNASR_RECOG_RESULT * p_result = (funasr::FUNASR_RECOG_RESULT*)result;
|
||||
if(!p_result)
|
||||
return nullptr;
|
||||
|
||||
return p_result->stamp_sents.c_str();
|
||||
}
|
||||
|
||||
_FUNASRAPI const char* FunASRGetTpassResult(FUNASR_RESULT result,int n_index)
|
||||
{
|
||||
funasr::FUNASR_RECOG_RESULT * p_result = (funasr::FUNASR_RECOG_RESULT*)result;
|
||||
@ -726,9 +768,15 @@
|
||||
if (asr_type == ASR_OFFLINE) {
|
||||
funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle;
|
||||
funasr::Paraformer* paraformer = (funasr::Paraformer*)offline_stream->asr_handle.get();
|
||||
if (paraformer->lm_)
|
||||
mm = new funasr::WfstDecoder(paraformer->lm_.get(),
|
||||
paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
|
||||
} else if (asr_type == ASR_TWO_PASS){
|
||||
funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
|
||||
funasr::Paraformer* paraformer = (funasr::Paraformer*)tpass_stream->asr_handle.get();
|
||||
if (paraformer->lm_)
|
||||
mm = new funasr::WfstDecoder(paraformer->lm_.get(),
|
||||
paraformer->GetPhoneSet(), paraformer->GetVocab(), glob_beam, lat_beam, am_scale);
|
||||
paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
|
||||
}
|
||||
return mm;
|
||||
}
|
||||
|
||||
@ -63,10 +63,16 @@ OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int
|
||||
|
||||
// Lm resource
|
||||
if (model_path.find(LM_DIR) != model_path.end() && model_path.at(LM_DIR) != "") {
|
||||
string fst_path, lm_config_path, hws_path;
|
||||
string fst_path, lm_config_path, lex_path;
|
||||
fst_path = PathAppend(model_path.at(LM_DIR), LM_FST_RES);
|
||||
lm_config_path = PathAppend(model_path.at(LM_DIR), LM_CONFIG_NAME);
|
||||
asr_handle->InitLm(fst_path, lm_config_path);
|
||||
lex_path = PathAppend(model_path.at(LM_DIR), LEX_PATH);
|
||||
if (access(lex_path.c_str(), F_OK) != 0 )
|
||||
{
|
||||
LOG(ERROR) << "Lexicon.txt file is not exist, please use the latest version. Skip load LM model.";
|
||||
}else{
|
||||
asr_handle->InitLm(fst_path, lm_config_path, lex_path);
|
||||
}
|
||||
}
|
||||
|
||||
// PUNC model
|
||||
|
||||
@ -61,7 +61,11 @@ void ParaformerOnline::InitOnline(
|
||||
for(int i=0; i<fsmn_lorder*fsmn_dims; i++){
|
||||
fsmn_init_cache_.emplace_back(0);
|
||||
}
|
||||
chunk_len = chunk_size[1]*frame_shift*lfr_n*MODEL_SAMPLE_RATE/1000;
|
||||
chunk_len = chunk_size[1]*frame_shift*lfr_n*para_handle_->asr_sample_rate/1000;
|
||||
|
||||
frame_sample_length_ = para_handle_->asr_sample_rate / 1000 * frame_length;
|
||||
frame_shift_sample_length_ = para_handle_->asr_sample_rate / 1000 * frame_shift;
|
||||
|
||||
}
|
||||
|
||||
void ParaformerOnline::FbankKaldi(float sample_rate, std::vector<std::vector<float>> &wav_feats,
|
||||
@ -489,7 +493,7 @@ string ParaformerOnline::Forward(float* din, int len, bool input_finished, const
|
||||
if(is_first_chunk){
|
||||
is_first_chunk = false;
|
||||
}
|
||||
ExtractFeats(MODEL_SAMPLE_RATE, wav_feats, waves, input_finished);
|
||||
ExtractFeats(para_handle_->asr_sample_rate, wav_feats, waves, input_finished);
|
||||
if(wav_feats.size() == 0){
|
||||
return result;
|
||||
}
|
||||
|
||||
@ -111,6 +111,9 @@ namespace funasr {
|
||||
string ForwardChunk(std::vector<std::vector<float>> &wav_feats, bool input_finished);
|
||||
string Forward(float* din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr);
|
||||
string Rescoring();
|
||||
|
||||
int GetAsrSampleRate() { return para_handle_->asr_sample_rate; };
|
||||
|
||||
// 2pass
|
||||
std::string online_res;
|
||||
int chunk_len;
|
||||
|
||||
@ -19,10 +19,11 @@ Paraformer::Paraformer()
|
||||
|
||||
// offline
|
||||
void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
|
||||
LoadConfigFromYaml(am_config.c_str());
|
||||
// knf options
|
||||
fbank_opts_.frame_opts.dither = 0;
|
||||
fbank_opts_.mel_opts.num_bins = n_mels;
|
||||
fbank_opts_.frame_opts.samp_freq = MODEL_SAMPLE_RATE;
|
||||
fbank_opts_.frame_opts.samp_freq = asr_sample_rate;
|
||||
fbank_opts_.frame_opts.window_type = window_type;
|
||||
fbank_opts_.frame_opts.frame_shift_ms = frame_shift;
|
||||
fbank_opts_.frame_opts.frame_length_ms = frame_length;
|
||||
@ -65,7 +66,6 @@ void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn
|
||||
for (auto& item : m_strOutputNames)
|
||||
m_szOutputNames.push_back(item.c_str());
|
||||
vocab = new Vocab(am_config.c_str());
|
||||
LoadConfigFromYaml(am_config.c_str());
|
||||
phone_set_ = new PhoneSet(am_config.c_str());
|
||||
LoadCmvn(am_cmvn.c_str());
|
||||
}
|
||||
@ -77,7 +77,7 @@ void Paraformer::InitAsr(const std::string &en_model, const std::string &de_mode
|
||||
// knf options
|
||||
fbank_opts_.frame_opts.dither = 0;
|
||||
fbank_opts_.mel_opts.num_bins = n_mels;
|
||||
fbank_opts_.frame_opts.samp_freq = MODEL_SAMPLE_RATE;
|
||||
fbank_opts_.frame_opts.samp_freq = asr_sample_rate;
|
||||
fbank_opts_.frame_opts.window_type = window_type;
|
||||
fbank_opts_.frame_opts.frame_shift_ms = frame_shift;
|
||||
fbank_opts_.frame_opts.frame_length_ms = frame_length;
|
||||
@ -187,13 +187,13 @@ void Paraformer::InitAsr(const std::string &am_model, const std::string &en_mode
|
||||
}
|
||||
|
||||
void Paraformer::InitLm(const std::string &lm_file,
|
||||
const std::string &lm_cfg_file) {
|
||||
const std::string &lm_cfg_file,
|
||||
const std::string &lex_file) {
|
||||
try {
|
||||
lm_ = std::shared_ptr<fst::Fst<fst::StdArc>>(
|
||||
fst::Fst<fst::StdArc>::Read(lm_file));
|
||||
if (lm_){
|
||||
if (vocab) { delete vocab; }
|
||||
vocab = new Vocab(lm_cfg_file.c_str());
|
||||
lm_vocab = new Vocab(lm_cfg_file.c_str(), lex_file.c_str());
|
||||
LOG(INFO) << "Successfully load lm file " << lm_file;
|
||||
}else{
|
||||
LOG(ERROR) << "Failed to load lm file " << lm_file;
|
||||
@ -215,6 +215,9 @@ void Paraformer::LoadConfigFromYaml(const char* filename){
|
||||
}
|
||||
|
||||
try{
|
||||
YAML::Node frontend_conf = config["frontend_conf"];
|
||||
this->asr_sample_rate = frontend_conf["fs"].as<int>();
|
||||
|
||||
YAML::Node lang_conf = config["lang"];
|
||||
if (lang_conf.IsDefined()){
|
||||
language = lang_conf.as<string>();
|
||||
@ -257,6 +260,9 @@ void Paraformer::LoadOnlineConfigFromYaml(const char* filename){
|
||||
this->cif_threshold = predictor_conf["threshold"].as<double>();
|
||||
this->tail_alphas = predictor_conf["tail_threshold"].as<double>();
|
||||
|
||||
this->asr_sample_rate = frontend_conf["fs"].as<int>();
|
||||
|
||||
|
||||
}catch(exception const &e){
|
||||
LOG(ERROR) << "Error when load argument from vad config YAML.";
|
||||
exit(-1);
|
||||
@ -300,10 +306,18 @@ void Paraformer::InitSegDict(const std::string &seg_dict_model) {
|
||||
|
||||
Paraformer::~Paraformer()
|
||||
{
|
||||
if(vocab)
|
||||
if(vocab){
|
||||
delete vocab;
|
||||
if(seg_dict)
|
||||
}
|
||||
if(lm_vocab){
|
||||
delete lm_vocab;
|
||||
}
|
||||
if(seg_dict){
|
||||
delete seg_dict;
|
||||
}
|
||||
if(phone_set_){
|
||||
delete phone_set_;
|
||||
}
|
||||
}
|
||||
|
||||
void Paraformer::StartUtterance()
|
||||
@ -454,7 +468,7 @@ string Paraformer::Forward(float* din, int len, bool input_finished, const std::
|
||||
int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;
|
||||
|
||||
std::vector<std::vector<float>> asr_feats;
|
||||
FbankKaldi(MODEL_SAMPLE_RATE, din, len, asr_feats);
|
||||
FbankKaldi(asr_sample_rate, din, len, asr_feats);
|
||||
if(asr_feats.size() == 0){
|
||||
return "";
|
||||
}
|
||||
@ -675,6 +689,11 @@ Vocab* Paraformer::GetVocab()
|
||||
return vocab;
|
||||
}
|
||||
|
||||
Vocab* Paraformer::GetLmVocab()
|
||||
{
|
||||
return lm_vocab;
|
||||
}
|
||||
|
||||
PhoneSet* Paraformer::GetPhoneSet()
|
||||
{
|
||||
return phone_set_;
|
||||
|
||||
@ -20,6 +20,7 @@ namespace funasr {
|
||||
*/
|
||||
private:
|
||||
Vocab* vocab = nullptr;
|
||||
Vocab* lm_vocab = nullptr;
|
||||
SegDict* seg_dict = nullptr;
|
||||
PhoneSet* phone_set_ = nullptr;
|
||||
//const float scale = 22.6274169979695;
|
||||
@ -57,14 +58,15 @@ namespace funasr {
|
||||
|
||||
string Rescoring();
|
||||
string GetLang(){return language;};
|
||||
|
||||
int GetAsrSampleRate() { return asr_sample_rate; };
|
||||
void StartUtterance();
|
||||
void EndUtterance();
|
||||
void InitLm(const std::string &lm_file, const std::string &lm_cfg_file);
|
||||
void InitLm(const std::string &lm_file, const std::string &lm_cfg_file, const std::string &lex_file);
|
||||
string BeamSearch(WfstDecoder* &wfst_decoder, float* in, int n_len, int64_t token_nums);
|
||||
string FinalizeDecode(WfstDecoder* &wfst_decoder,
|
||||
bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
|
||||
Vocab* GetVocab();
|
||||
Vocab* GetLmVocab();
|
||||
PhoneSet* GetPhoneSet();
|
||||
|
||||
knf::FbankOptions fbank_opts_;
|
||||
@ -107,8 +109,7 @@ namespace funasr {
|
||||
int fsmn_dims = 512;
|
||||
float cif_threshold = 1.0;
|
||||
float tail_alphas = 0.45;
|
||||
|
||||
|
||||
int asr_sample_rate = MODEL_SAMPLE_RATE;
|
||||
};
|
||||
|
||||
} // namespace funasr
|
||||
|
||||
@ -66,6 +66,20 @@ TpassStream::TpassStream(std::map<std::string, std::string>& model_path, int thr
|
||||
LOG(ERROR) <<"Can not find offline-model-dir or online-model-dir";
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Lm resource
|
||||
if (model_path.find(LM_DIR) != model_path.end() && model_path.at(LM_DIR) != "") {
|
||||
string fst_path, lm_config_path, lex_path;
|
||||
fst_path = PathAppend(model_path.at(LM_DIR), LM_FST_RES);
|
||||
lm_config_path = PathAppend(model_path.at(LM_DIR), LM_CONFIG_NAME);
|
||||
lex_path = PathAppend(model_path.at(LM_DIR), LEX_PATH);
|
||||
if (access(lex_path.c_str(), F_OK) != 0 )
|
||||
{
|
||||
LOG(ERROR) << "Lexicon.txt file is not exist, please use the latest version. Skip load LM model.";
|
||||
}else{
|
||||
asr_handle->InitLm(fst_path, lm_config_path, lex_path);
|
||||
}
|
||||
}
|
||||
|
||||
// PUNC model
|
||||
if(model_path.find(PUNC_DIR) != model_path.end()){
|
||||
|
||||
@ -247,6 +247,395 @@ void SplitChiEngCharacters(const std::string &input_str,
|
||||
}
|
||||
}
|
||||
|
||||
// Timestamp Smooth
|
||||
void TimestampAdd(std::deque<string> &alignment_str1, std::string str_word){
|
||||
if(!TimestampIsPunctuation(str_word)){
|
||||
alignment_str1.push_front(str_word);
|
||||
}
|
||||
}
|
||||
|
||||
bool TimestampIsPunctuation(const std::string& str) {
|
||||
const std::string punctuation = u8",。?、,?";
|
||||
// const std::string punctuation = u8",。?、,.?";
|
||||
for (char ch : str) {
|
||||
if (punctuation.find(ch) == std::string::npos) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
vector<vector<int>> ParseTimestamps(const std::string& str) {
|
||||
vector<vector<int>> timestamps;
|
||||
std::istringstream ss(str);
|
||||
std::string segment;
|
||||
|
||||
// skip first'['
|
||||
ss.ignore(1);
|
||||
|
||||
while (std::getline(ss, segment, ']')) {
|
||||
std::istringstream segmentStream(segment);
|
||||
std::string number;
|
||||
vector<int> ts;
|
||||
|
||||
// skip'['
|
||||
segmentStream.ignore(1);
|
||||
|
||||
while (std::getline(segmentStream, number, ',')) {
|
||||
ts.push_back(std::stoi(number));
|
||||
}
|
||||
if(ts.size() != 2){
|
||||
LOG(ERROR) << "ParseTimestamps Failed";
|
||||
timestamps.clear();
|
||||
return timestamps;
|
||||
}
|
||||
timestamps.push_back(ts);
|
||||
ss.ignore(1);
|
||||
}
|
||||
|
||||
return timestamps;
|
||||
}
|
||||
|
||||
bool TimestampIsDigit(U16CHAR_T &u16) {
|
||||
return u16 >= L'0' && u16 <= L'9';
|
||||
}
|
||||
|
||||
bool TimestampIsAlpha(U16CHAR_T &u16) {
|
||||
return (u16 >= L'A' && u16 <= L'Z') || (u16 >= L'a' && u16 <= L'z');
|
||||
}
|
||||
|
||||
bool TimestampIsPunctuation(U16CHAR_T &u16) {
|
||||
// (& ' -) in the dict
|
||||
if (u16 == 0x26 || u16 == 0x27 || u16 == 0x2D){
|
||||
return false;
|
||||
}
|
||||
return (u16 >= 0x21 && u16 <= 0x2F) // 标准ASCII标点
|
||||
|| (u16 >= 0x3A && u16 <= 0x40) // 标准ASCII标点
|
||||
|| (u16 >= 0x5B && u16 <= 0x60) // 标准ASCII标点
|
||||
|| (u16 >= 0x7B && u16 <= 0x7E) // 标准ASCII标点
|
||||
|| (u16 >= 0x2000 && u16 <= 0x206F) // 常用的Unicode标点
|
||||
|| (u16 >= 0x3000 && u16 <= 0x303F); // CJK符号和标点
|
||||
}
|
||||
|
||||
void TimestampSplitChiEngCharacters(const std::string &input_str,
|
||||
std::vector<std::string> &characters) {
|
||||
characters.resize(0);
|
||||
std::string eng_word = "";
|
||||
U16CHAR_T space = 0x0020;
|
||||
std::vector<U16CHAR_T> u16_buf;
|
||||
u16_buf.resize(std::max(u16_buf.size(), input_str.size() + 1));
|
||||
U16CHAR_T* pu16 = u16_buf.data();
|
||||
U8CHAR_T * pu8 = (U8CHAR_T*)input_str.data();
|
||||
size_t ilen = input_str.size();
|
||||
size_t len = EncodeConverter::Utf8ToUtf16(pu8, ilen, pu16, ilen + 1);
|
||||
for (size_t i = 0; i < len; i++) {
|
||||
if (EncodeConverter::IsChineseCharacter(pu16[i])) {
|
||||
if(!eng_word.empty()){
|
||||
characters.push_back(eng_word);
|
||||
eng_word = "";
|
||||
}
|
||||
U8CHAR_T u8buf[4];
|
||||
size_t n = EncodeConverter::Utf16ToUtf8(pu16 + i, u8buf);
|
||||
u8buf[n] = '\0';
|
||||
characters.push_back((const char*)u8buf);
|
||||
} else if (TimestampIsDigit(pu16[i]) || TimestampIsPunctuation(pu16[i])){
|
||||
if(!eng_word.empty()){
|
||||
characters.push_back(eng_word);
|
||||
eng_word = "";
|
||||
}
|
||||
U8CHAR_T u8buf[4];
|
||||
size_t n = EncodeConverter::Utf16ToUtf8(pu16 + i, u8buf);
|
||||
u8buf[n] = '\0';
|
||||
characters.push_back((const char*)u8buf);
|
||||
} else if (pu16[i] == space){
|
||||
if(!eng_word.empty()){
|
||||
characters.push_back(eng_word);
|
||||
eng_word = "";
|
||||
}
|
||||
}else{
|
||||
U8CHAR_T u8buf[4];
|
||||
size_t n = EncodeConverter::Utf16ToUtf8(pu16 + i, u8buf);
|
||||
u8buf[n] = '\0';
|
||||
eng_word += (const char*)u8buf;
|
||||
}
|
||||
}
|
||||
if(!eng_word.empty()){
|
||||
characters.push_back(eng_word);
|
||||
eng_word = "";
|
||||
}
|
||||
}
|
||||
|
||||
std::string VectorToString(const std::vector<std::vector<int>>& vec, bool out_empty) {
|
||||
if(vec.size() == 0){
|
||||
if(out_empty){
|
||||
return "";
|
||||
}else{
|
||||
return "[]";
|
||||
}
|
||||
}
|
||||
std::ostringstream out;
|
||||
out << "[";
|
||||
|
||||
for (size_t i = 0; i < vec.size(); ++i) {
|
||||
out << "[";
|
||||
for (size_t j = 0; j < vec[i].size(); ++j) {
|
||||
out << vec[i][j];
|
||||
if (j < vec[i].size() - 1) {
|
||||
out << ",";
|
||||
}
|
||||
}
|
||||
out << "]";
|
||||
if (i < vec.size() - 1) {
|
||||
out << ",";
|
||||
}
|
||||
}
|
||||
|
||||
out << "]";
|
||||
return out.str();
|
||||
}
|
||||
|
||||
std::string TimestampSmooth(std::string &text, std::string &text_itn, std::string &str_time){
|
||||
vector<vector<int>> timestamps_out;
|
||||
std::string timestamps_str = "";
|
||||
// process string to vector<string>
|
||||
std::vector<std::string> characters;
|
||||
funasr::TimestampSplitChiEngCharacters(text, characters);
|
||||
|
||||
std::vector<std::string> characters_itn;
|
||||
funasr::TimestampSplitChiEngCharacters(text_itn, characters_itn);
|
||||
|
||||
//convert string to vector<vector<int>>
|
||||
vector<vector<int>> timestamps = funasr::ParseTimestamps(str_time);
|
||||
|
||||
if (timestamps.size() == 0){
|
||||
LOG(ERROR) << "Timestamp Smooth Failed: Length of timestamp is zero";
|
||||
return timestamps_str;
|
||||
}
|
||||
|
||||
// edit distance
|
||||
int m = characters.size();
|
||||
int n = characters_itn.size();
|
||||
std::vector<std::vector<int>> dp(m + 1, std::vector<int>(n + 1, 0));
|
||||
|
||||
// init
|
||||
for (int i = 0; i <= m; ++i) {
|
||||
dp[i][0] = i;
|
||||
}
|
||||
for (int j = 0; j <= n; ++j) {
|
||||
dp[0][j] = j;
|
||||
}
|
||||
|
||||
// dp
|
||||
for (int i = 1; i <= m; ++i) {
|
||||
for (int j = 1; j <= n; ++j) {
|
||||
if (characters[i - 1] == characters_itn[j - 1]) {
|
||||
dp[i][j] = dp[i - 1][j - 1];
|
||||
} else {
|
||||
dp[i][j] = std::min({dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]}) + 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// backtrack
|
||||
std::deque<string> alignment_str1, alignment_str2;
|
||||
int i = m, j = n;
|
||||
while (i > 0 || j > 0) {
|
||||
if (i > 0 && j > 0 && dp[i][j] == dp[i - 1][j - 1]) {
|
||||
funasr::TimestampAdd(alignment_str1, characters[i - 1]);
|
||||
funasr::TimestampAdd(alignment_str2, characters_itn[j - 1]);
|
||||
i -= 1;
|
||||
j -= 1;
|
||||
} else if (i > 0 && dp[i][j] == dp[i - 1][j] + 1) {
|
||||
funasr::TimestampAdd(alignment_str1, characters[i - 1]);
|
||||
alignment_str2.push_front("");
|
||||
i -= 1;
|
||||
} else if (j > 0 && dp[i][j] == dp[i][j - 1] + 1) {
|
||||
alignment_str1.push_front("");
|
||||
funasr::TimestampAdd(alignment_str2, characters_itn[j - 1]);
|
||||
j -= 1;
|
||||
} else{
|
||||
funasr::TimestampAdd(alignment_str1, characters[i - 1]);
|
||||
funasr::TimestampAdd(alignment_str2, characters_itn[j - 1]);
|
||||
i -= 1;
|
||||
j -= 1;
|
||||
}
|
||||
}
|
||||
|
||||
// smooth
|
||||
int itn_count = 0;
|
||||
int idx_tp = 0;
|
||||
int idx_itn = 0;
|
||||
vector<vector<int>> timestamps_tmp;
|
||||
for(int index = 0; index < alignment_str1.size(); index++){
|
||||
if (alignment_str1[index] == alignment_str2[index]){
|
||||
bool subsidy = false;
|
||||
if (itn_count > 0 && timestamps_tmp.size() == 0){
|
||||
if(idx_tp >= timestamps.size()){
|
||||
LOG(ERROR) << "Timestamp Smooth Failed: Index of tp is out of range. ";
|
||||
return timestamps_str;
|
||||
}
|
||||
timestamps_tmp.push_back(timestamps[idx_tp]);
|
||||
subsidy = true;
|
||||
itn_count++;
|
||||
}
|
||||
|
||||
if (timestamps_tmp.size() > 0){
|
||||
if (itn_count > 0){
|
||||
int begin = timestamps_tmp[0][0];
|
||||
int end = timestamps_tmp.back()[1];
|
||||
int total_time = end - begin;
|
||||
int interval = total_time / itn_count;
|
||||
for(int idx_cnt=0; idx_cnt < itn_count; idx_cnt++){
|
||||
vector<int> ts;
|
||||
ts.push_back(begin + interval*idx_cnt);
|
||||
if(idx_cnt == itn_count-1){
|
||||
ts.push_back(end);
|
||||
}else {
|
||||
ts.push_back(begin + interval*(idx_cnt + 1));
|
||||
}
|
||||
timestamps_out.push_back(ts);
|
||||
}
|
||||
}
|
||||
timestamps_tmp.clear();
|
||||
}
|
||||
if(!subsidy){
|
||||
if(idx_tp >= timestamps.size()){
|
||||
LOG(ERROR) << "Timestamp Smooth Failed: Index of tp is out of range. ";
|
||||
return timestamps_str;
|
||||
}
|
||||
timestamps_out.push_back(timestamps[idx_tp]);
|
||||
}
|
||||
idx_tp++;
|
||||
itn_count = 0;
|
||||
}else{
|
||||
if (!alignment_str1[index].empty()){
|
||||
if(idx_tp >= timestamps.size()){
|
||||
LOG(ERROR) << "Timestamp Smooth Failed: Index of tp is out of range. ";
|
||||
return timestamps_str;
|
||||
}
|
||||
timestamps_tmp.push_back(timestamps[idx_tp]);
|
||||
idx_tp++;
|
||||
}
|
||||
if (!alignment_str2[index].empty()){
|
||||
itn_count++;
|
||||
}
|
||||
}
|
||||
// count length of itn
|
||||
if (!alignment_str2[index].empty()){
|
||||
idx_itn++;
|
||||
}
|
||||
}
|
||||
{
|
||||
if (itn_count > 0 && timestamps_tmp.size() == 0){
|
||||
if (timestamps_out.size() > 0){
|
||||
timestamps_tmp.push_back(timestamps_out.back());
|
||||
itn_count++;
|
||||
timestamps_out.pop_back();
|
||||
} else{
|
||||
LOG(ERROR) << "Timestamp Smooth Failed: Last itn has no timestamp.";
|
||||
return timestamps_str;
|
||||
}
|
||||
}
|
||||
|
||||
if (timestamps_tmp.size() > 0){
|
||||
if (itn_count > 0){
|
||||
int begin = timestamps_tmp[0][0];
|
||||
int end = timestamps_tmp.back()[1];
|
||||
int total_time = end - begin;
|
||||
int interval = total_time / itn_count;
|
||||
for(int idx_cnt=0; idx_cnt < itn_count; idx_cnt++){
|
||||
vector<int> ts;
|
||||
ts.push_back(begin + interval*idx_cnt);
|
||||
if(idx_cnt == itn_count-1){
|
||||
ts.push_back(end);
|
||||
}else {
|
||||
ts.push_back(begin + interval*(idx_cnt + 1));
|
||||
}
|
||||
timestamps_out.push_back(ts);
|
||||
}
|
||||
}
|
||||
timestamps_tmp.clear();
|
||||
}
|
||||
}
|
||||
if(timestamps_out.size() != idx_itn){
|
||||
LOG(ERROR) << "Timestamp Smooth Failed: Timestamp length does not matched.";
|
||||
return timestamps_str;
|
||||
}
|
||||
|
||||
timestamps_str = VectorToString(timestamps_out);
|
||||
return timestamps_str;
|
||||
}
|
||||
|
||||
std::string TimestampSentence(std::string &text, std::string &str_time){
|
||||
std::vector<std::string> characters;
|
||||
funasr::TimestampSplitChiEngCharacters(text, characters);
|
||||
vector<vector<int>> timestamps = funasr::ParseTimestamps(str_time);
|
||||
|
||||
int idx_str = 0, idx_ts = 0;
|
||||
int start = -1, end = -1;
|
||||
std::string text_seg = "";
|
||||
std::string ts_sentences = "";
|
||||
std::string ts_sent = "";
|
||||
vector<vector<int>> ts_seg;
|
||||
while(idx_str < characters.size()){
|
||||
if (TimestampIsPunctuation(characters[idx_str])){
|
||||
if(ts_seg.size() >0){
|
||||
if (ts_seg[0].size() == 2){
|
||||
start = ts_seg[0][0];
|
||||
}
|
||||
if (ts_seg[ts_seg.size()-1].size() == 2){
|
||||
end = ts_seg[ts_seg.size()-1][1];
|
||||
}
|
||||
}
|
||||
// format
|
||||
ts_sent += "{\"text_seg\":\"" + text_seg + "\",";
|
||||
ts_sent += "\"punc\":\"" + characters[idx_str] + "\",";
|
||||
ts_sent += "\"start\":\"" + to_string(start) + "\",";
|
||||
ts_sent += "\"end\":\"" + to_string(end) + "\",";
|
||||
ts_sent += "\"ts_list\":" + VectorToString(ts_seg, false) + "}";
|
||||
|
||||
if (idx_str == characters.size()-1){
|
||||
ts_sentences += ts_sent;
|
||||
} else{
|
||||
ts_sentences += ts_sent + ",";
|
||||
}
|
||||
// clear
|
||||
text_seg = "";
|
||||
ts_sent = "";
|
||||
start = 0;
|
||||
end = 0;
|
||||
ts_seg.clear();
|
||||
} else if(idx_ts < timestamps.size()) {
|
||||
if (text_seg.empty()){
|
||||
text_seg = characters[idx_str];
|
||||
}else{
|
||||
text_seg += " " + characters[idx_str];
|
||||
}
|
||||
ts_seg.push_back(timestamps[idx_ts]);
|
||||
idx_ts++;
|
||||
}
|
||||
idx_str++;
|
||||
}
|
||||
// for none punc results
|
||||
if(ts_seg.size() >0){
|
||||
if (ts_seg[0].size() == 2){
|
||||
start = ts_seg[0][0];
|
||||
}
|
||||
if (ts_seg[ts_seg.size()-1].size() == 2){
|
||||
end = ts_seg[ts_seg.size()-1][1];
|
||||
}
|
||||
// format
|
||||
ts_sent += "{\"text_seg\":\"" + text_seg + "\",";
|
||||
ts_sent += "\"punc\":\"\",";
|
||||
ts_sent += "\"start\":\"" + to_string(start) + "\",";
|
||||
ts_sent += "\"end\":\"" + to_string(end) + "\",";
|
||||
ts_sent += "\"ts_list\":" + VectorToString(ts_seg, false) + "}";
|
||||
ts_sentences += ts_sent;
|
||||
}
|
||||
|
||||
return "[" +ts_sentences + "]";
|
||||
}
|
||||
|
||||
std::vector<std::string> split(const std::string &s, char delim) {
|
||||
std::vector<std::string> elems;
|
||||
std::stringstream ss(s);
|
||||
@ -333,12 +722,23 @@ string PostProcess(std::vector<string> &raw_char, std::vector<std::vector<float>
|
||||
int sub_word = !(word.find("@@") == string::npos);
|
||||
// process word start and middle part
|
||||
if (sub_word) {
|
||||
combine += word.erase(word.length() - 2);
|
||||
if(!is_combining){
|
||||
begin = timestamp_list[i][0];
|
||||
// if badcase: lo@@ chinese
|
||||
if (i == raw_char.size()-1 || i<raw_char.size()-1 && IsChinese(raw_char[i+1])){
|
||||
word = word.erase(word.length() - 2) + " ";
|
||||
if (is_combining) {
|
||||
combine += word;
|
||||
is_combining = false;
|
||||
word = combine;
|
||||
combine = "";
|
||||
}
|
||||
}else{
|
||||
combine += word.erase(word.length() - 2);
|
||||
if(!is_combining){
|
||||
begin = timestamp_list[i][0];
|
||||
}
|
||||
is_combining = true;
|
||||
continue;
|
||||
}
|
||||
is_combining = true;
|
||||
continue;
|
||||
}
|
||||
// process word end part
|
||||
else if (is_combining) {
|
||||
@ -669,4 +1069,9 @@ void ExtractHws(string hws_file, unordered_map<string, int> &hws_map, string& nn
|
||||
ifs_hws.close();
|
||||
}
|
||||
|
||||
void SmoothTimestamps(std::string &str_punc, std::string &str_itn, std::string &str_timetamp){
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
} // namespace funasr
|
||||
|
||||
@ -3,11 +3,13 @@
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <deque>
|
||||
#include "tensor.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace funasr {
|
||||
typedef unsigned short U16CHAR_T;
|
||||
extern float *LoadParams(const char *filename);
|
||||
|
||||
extern void SaveDataFile(const char *filename, void *data, uint32_t len);
|
||||
@ -35,7 +37,17 @@ void KeepChineseCharacterAndSplit(const std::string &input_str,
|
||||
std::vector<std::string> &chinese_characters);
|
||||
void SplitChiEngCharacters(const std::string &input_str,
|
||||
std::vector<std::string> &characters);
|
||||
|
||||
void TimestampAdd(std::deque<string> &alignment_str1, std::string str_word);
|
||||
vector<vector<int>> ParseTimestamps(const std::string& str);
|
||||
bool TimestampIsDigit(U16CHAR_T &u16);
|
||||
bool TimestampIsAlpha(U16CHAR_T &u16);
|
||||
bool TimestampIsPunctuation(U16CHAR_T &u16);
|
||||
bool TimestampIsPunctuation(const std::string& str);
|
||||
void TimestampSplitChiEngCharacters(const std::string &input_str,
|
||||
std::vector<std::string> &characters);
|
||||
std::string VectorToString(const std::vector<std::vector<int>>& vec, bool out_empty=true);
|
||||
std::string TimestampSmooth(std::string &text, std::string &text_itn, std::string &str_time);
|
||||
std::string TimestampSentence(std::string &text, std::string &str_time);
|
||||
std::vector<std::string> split(const std::string &s, char delim);
|
||||
|
||||
template<typename T>
|
||||
|
||||
@ -16,6 +16,12 @@ Vocab::Vocab(const char *filename)
|
||||
ifstream in(filename);
|
||||
LoadVocabFromYaml(filename);
|
||||
}
|
||||
Vocab::Vocab(const char *filename, const char *lex_file)
|
||||
{
|
||||
ifstream in(filename);
|
||||
LoadVocabFromYaml(filename);
|
||||
LoadLex(lex_file);
|
||||
}
|
||||
Vocab::~Vocab()
|
||||
{
|
||||
}
|
||||
@ -37,11 +43,37 @@ void Vocab::LoadVocabFromYaml(const char* filename){
|
||||
}
|
||||
}
|
||||
|
||||
int Vocab::GetIdByToken(const std::string &token) {
|
||||
if (token_id.count(token)) {
|
||||
return token_id[token];
|
||||
void Vocab::LoadLex(const char* filename){
|
||||
std::ifstream file(filename);
|
||||
std::string line;
|
||||
while (std::getline(file, line)) {
|
||||
std::string key, value;
|
||||
std::istringstream iss(line);
|
||||
std::getline(iss, key, '\t');
|
||||
std::getline(iss, value);
|
||||
|
||||
if (!key.empty() && !value.empty()) {
|
||||
lex_map[key] = value;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
|
||||
file.close();
|
||||
}
|
||||
|
||||
string Vocab::Word2Lex(const std::string &word) const {
|
||||
auto it = lex_map.find(word);
|
||||
if (it != lex_map.end()) {
|
||||
return it->second;
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
int Vocab::GetIdByToken(const std::string &token) const {
|
||||
auto it = token_id.find(token);
|
||||
if (it != token_id.end()) {
|
||||
return it->second;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
void Vocab::Vector2String(vector<int> in, std::vector<std::string> &preds)
|
||||
@ -120,8 +152,8 @@ string Vocab::Vector2StringV2(vector<int> in, std::string language)
|
||||
std::string combine = "";
|
||||
std::string unicodeChar = "▁";
|
||||
|
||||
for (auto it = in.begin(); it != in.end(); it++) {
|
||||
string word = vocab[*it];
|
||||
for (i=0; i<in.size(); i++){
|
||||
string word = vocab[in[i]];
|
||||
// step1 space character skips
|
||||
if (word == "<s>" || word == "</s>" || word == "<unk>")
|
||||
continue;
|
||||
@ -146,9 +178,20 @@ string Vocab::Vector2StringV2(vector<int> in, std::string language)
|
||||
int sub_word = !(word.find("@@") == string::npos);
|
||||
// process word start and middle part
|
||||
if (sub_word) {
|
||||
combine += word.erase(word.length() - 2);
|
||||
is_combining = true;
|
||||
continue;
|
||||
// if badcase: lo@@ chinese
|
||||
if (i == in.size()-1 || i<in.size()-1 && IsChinese(vocab[in[i+1]])){
|
||||
word = word.erase(word.length() - 2) + " ";
|
||||
if (is_combining) {
|
||||
combine += word;
|
||||
is_combining = false;
|
||||
word = combine;
|
||||
combine = "";
|
||||
}
|
||||
}else{
|
||||
combine += word.erase(word.length() - 2);
|
||||
is_combining = true;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
// process word end part
|
||||
else if (is_combining) {
|
||||
|
||||
@ -13,11 +13,14 @@ class Vocab {
|
||||
private:
|
||||
vector<string> vocab;
|
||||
std::map<string, int> token_id;
|
||||
std::map<string, string> lex_map;
|
||||
bool IsEnglish(string ch);
|
||||
void LoadVocabFromYaml(const char* filename);
|
||||
void LoadLex(const char* filename);
|
||||
|
||||
public:
|
||||
Vocab(const char *filename);
|
||||
Vocab(const char *filename, const char *lex_file);
|
||||
~Vocab();
|
||||
int Size() const;
|
||||
bool IsChinese(string ch);
|
||||
@ -26,7 +29,8 @@ class Vocab {
|
||||
string Vector2StringV2(vector<int> in, std::string language="");
|
||||
string Id2String(int id) const;
|
||||
string WordFormat(std::string word);
|
||||
int GetIdByToken(const std::string &token);
|
||||
int GetIdByToken(const std::string &token) const;
|
||||
string Word2Lex(const std::string &word) const;
|
||||
};
|
||||
|
||||
} // namespace funasr
|
||||
|
||||
@ -41,6 +41,10 @@ parser.add_argument("--audio_in",
|
||||
type=str,
|
||||
default=None,
|
||||
help="audio_in")
|
||||
parser.add_argument("--audio_fs",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="audio_fs")
|
||||
parser.add_argument("--send_without_sleep",
|
||||
action="store_true",
|
||||
default=True,
|
||||
@ -164,7 +168,7 @@ async def record_from_scp(chunk_begin, chunk_size):
|
||||
hotword_msg=json.dumps(fst_dict)
|
||||
print (hotword_msg)
|
||||
|
||||
sample_rate = 16000
|
||||
sample_rate = args.audio_fs
|
||||
wav_format = "pcm"
|
||||
use_itn=True
|
||||
if args.use_itn == 0:
|
||||
@ -182,20 +186,12 @@ async def record_from_scp(chunk_begin, chunk_size):
|
||||
if wav_path.endswith(".pcm"):
|
||||
with open(wav_path, "rb") as f:
|
||||
audio_bytes = f.read()
|
||||
elif wav_path.endswith(".wav"):
|
||||
import wave
|
||||
with wave.open(wav_path, "rb") as wav_file:
|
||||
params = wav_file.getparams()
|
||||
sample_rate = wav_file.getframerate()
|
||||
frames = wav_file.readframes(wav_file.getnframes())
|
||||
audio_bytes = bytes(frames)
|
||||
else:
|
||||
wav_format = "others"
|
||||
with open(wav_path, "rb") as f:
|
||||
audio_bytes = f.read()
|
||||
|
||||
# stride = int(args.chunk_size/1000*16000*2)
|
||||
stride = int(60 * args.chunk_size[1] / args.chunk_interval / 1000 * 16000 * 2)
|
||||
stride = int(60 * args.chunk_size[1] / args.chunk_interval / 1000 * sample_rate * 2)
|
||||
chunk_num = (len(audio_bytes) - 1) // stride + 1
|
||||
# print(stride)
|
||||
|
||||
@ -253,6 +249,7 @@ async def message(id):
|
||||
wav_name = meg.get("wav_name", "demo")
|
||||
text = meg["text"]
|
||||
timestamp=""
|
||||
offline_msg_done = meg.get("is_final", False)
|
||||
if "timestamp" in meg:
|
||||
timestamp = meg["timestamp"]
|
||||
|
||||
@ -262,7 +259,9 @@ async def message(id):
|
||||
else:
|
||||
text_write_line = "{}\t{}\n".format(wav_name, text)
|
||||
ibest_writer.write(text_write_line)
|
||||
|
||||
|
||||
if 'mode' not in meg:
|
||||
continue
|
||||
if meg["mode"] == "online":
|
||||
text_print += "{}".format(text)
|
||||
text_print = text_print[-args.words_max_print:]
|
||||
@ -289,7 +288,7 @@ async def message(id):
|
||||
text_print = text_print[-args.words_max_print:]
|
||||
os.system('clear')
|
||||
print("\rpid" + str(id) + ": " + text_print)
|
||||
offline_msg_done=True
|
||||
# offline_msg_done=True
|
||||
|
||||
except Exception as e:
|
||||
print("Exception:", e)
|
||||
|
||||
@ -17,6 +17,7 @@ Currently, the FunASR runtime-SDK supports the deployment of file transcription
|
||||
To meet the needs of different users, we have prepared different tutorials with text and images for both novice and advanced developers.
|
||||
|
||||
### Whats-new
|
||||
- 2024/01/03: Fixed known crash issues as well as memory leak problems, docker image version funasr-runtime-sdk-en-cpu-0.1.2 (0cdd9f4a4bb5).
|
||||
- 2023/11/08: Adaptation to runtime structure changes (FunASR/funasr/runtime -> FunASR/runtime), docker image version funasr-runtime-sdk-en-cpu-0.1.1 (27017f70f72a).
|
||||
- 2023/10/16: English File Transcription Service 1.0 released, docker image version funasr-runtime-sdk-en-cpu-0.1.0 (e0de03eb0163), refer to the detailed documentation([here](https://mp.weixin.qq.com/s/DZZUTj-6xwFfi-96ml--4A))
|
||||
|
||||
@ -39,6 +40,7 @@ The FunASR real-time speech-to-text service software package not only performs r
|
||||
In order to meet the needs of different users for different scenarios, different tutorials are prepared:
|
||||
|
||||
### Whats-new
|
||||
- 2024/01/03: Real-time Transcription Service 1.6 released,The 2pass-offline mode supports Ngram language model decoding and WFST hotwords, while also addressing known crash issues and memory leak problems, docker image version funasr-runtime-sdk-online-cpu-0.1.6 (f99925110d27)
|
||||
- 2023/11/09: Real-time Transcription Service 1.5 released,fix bug: without online results, docker image version funasr-runtime-sdk-online-cpu-0.1.5 (b16584b6d38b)
|
||||
- 2023/11/08: Real-time Transcription Service 1.4 released, supporting server-side loading of hotwords (updated hotword communication protocol), adaptation to runtime structure changes (FunASR/funasr/runtime -> FunASR/runtime), docker image version funasr-runtime-sdk-online-cpu-0.1.4(691974017c38).
|
||||
- 2023/09/19: Real-time Transcription Service 1.2 released, supporting hotwords, timestamps, and ITN model in 2pass mode, docker image version funasr-runtime-sdk-online-cpu-0.1.2 (7222c5319bcf).
|
||||
@ -66,10 +68,12 @@ Currently, the FunASR runtime-SDK supports the deployment of file transcription
|
||||
To meet the needs of different users, we have prepared different tutorials with text and images for both novice and advanced developers.
|
||||
|
||||
### Whats-new
|
||||
2023/11/08: File Transcription Service 3.0 released, supporting punctuation large model, Ngram model, fst hotwords (updated hotword communication protocol), server-side loading of hotwords, adaptation to runtime structure changes (FunASR/funasr/runtime -> FunASR/runtime), docker image version funasr-runtime-sdk-cpu-0.3.0 (caa64bddbb43), refer to the detailed documentation ([here]())
|
||||
2023/09/19: File Transcription Service 2.2 released, supporting ITN model, docker image version funasr-runtime-sdk-cpu-0.2.2 (2c5286be13e9).
|
||||
2023/08/22: File Transcription Service 2.0 released, integrated ffmpeg to support various audio and video inputs, supporting hotword model and timestamp model, docker image version funasr-runtime-sdk-cpu-0.2.0 (1ad3d19e0707), refer to the detailed documentation ([here](https://mp.weixin.qq.com/s/oJHe0MKDqTeuIFH-F7GHMg))
|
||||
2023/07/03: File Transcription Service 1.0 released, docker image version funasr-runtime-sdk-cpu-0.1.0 (1ad3d19e0707), refer to the detailed documentation ([here](https://mp.weixin.qq.com/s/DHQwbgdBWcda0w_L60iUww))
|
||||
- 2024/01/08: File Transcription Service 4.1 released, optimized format sentence-level timestamps, docker image version funasr-runtime-sdk-cpu-0.4.1 (0250f8ef981b)
|
||||
- 2024/01/03: File Transcription Service 4.0 released, Added support for 8k models, optimized timestamp mismatch issues and added sentence-level timestamps, improved the effectiveness of English word FST hotwords, supported automated configuration of thread parameters, and fixed known crash issues as well as memory leak problems, docker image version funasr-runtime-sdk-cpu-0.4.0 (c4483ee08f04)
|
||||
- 2023/11/08: File Transcription Service 3.0 released, supporting punctuation large model, Ngram model, fst hotwords (updated hotword communication protocol), server-side loading of hotwords, adaptation to runtime structure changes (FunASR/funasr/runtime -> FunASR/runtime), docker image version funasr-runtime-sdk-cpu-0.3.0 (caa64bddbb43), refer to the detailed documentation ([here]())
|
||||
- 2023/09/19: File Transcription Service 2.2 released, supporting ITN model, docker image version funasr-runtime-sdk-cpu-0.2.2 (2c5286be13e9).
|
||||
- 2023/08/22: File Transcription Service 2.0 released, integrated ffmpeg to support various audio and video inputs, supporting hotword model and timestamp model, docker image version funasr-runtime-sdk-cpu-0.2.0 (1ad3d19e0707), refer to the detailed documentation ([here](https://mp.weixin.qq.com/s/oJHe0MKDqTeuIFH-F7GHMg))
|
||||
- 2023/07/03: File Transcription Service 1.0 released, docker image version funasr-runtime-sdk-cpu-0.1.0 (1ad3d19e0707), refer to the detailed documentation ([here](https://mp.weixin.qq.com/s/DHQwbgdBWcda0w_L60iUww))
|
||||
|
||||
### Technical Principles
|
||||
|
||||
|
||||
@ -19,6 +19,7 @@ FunASR是由阿里巴巴通义实验室语音团队开源的一款语音识别
|
||||
为了支持不同用户的需求,针对不同场景,准备了不同的图文教程:
|
||||
|
||||
### 最新动态
|
||||
- 2024/01/03: 英文离线文件转写服务 1.2 发布,修复已知的crash问题及内存泄漏问题,dokcer镜像版本funasr-runtime-sdk-en-cpu-0.1.2 (0cdd9f4a4bb5)
|
||||
- 2023/11/08: 英文离线文件转写服务 1.1 发布,runtime结构变化适配(FunASR/funasr/runtime->FunASR/runtime),dokcer镜像版本funasr-runtime-sdk-en-cpu-0.1.1 (27017f70f72a)
|
||||
- 2023/10/16: 英文离线文件转写服务 1.0 发布,dokcer镜像版本funasr-runtime-sdk-en-cpu-0.1.0 (e0de03eb0163),原理介绍文档([点击此处](https://mp.weixin.qq.com/s/DZZUTj-6xwFfi-96ml--4A))
|
||||
|
||||
@ -33,6 +34,7 @@ FunASR实时语音听写服务软件包,既可以实时地进行语音转文
|
||||
为了支持不同用户的需求,针对不同场景,准备了不同的图文教程:
|
||||
|
||||
### 最新动态
|
||||
- 2024/01/03: 中文实时语音听写服务 1.6 发布,2pass-offline模式支持Ngram语言模型解码、wfst热词,同时修复已知的crash问题及内存泄漏问题,dokcer镜像版本funasr-runtime-sdk-online-cpu-0.1.6 (f99925110d27)
|
||||
- 2023/11/09: 中文实时语音听写服务 1.5 发布,修复无实时结果的问题,dokcer镜像版本funasr-runtime-sdk-online-cpu-0.1.5 (b16584b6d38b)
|
||||
- 2023/11/08: 中文实时语音听写服务 1.4 发布,支持服务端加载热词(更新热词通信协议)、runtime结构变化适配(FunASR/funasr/runtime->FunASR/runtime),dokcer镜像版本funasr-runtime-sdk-online-cpu-0.1.4 (691974017c38)
|
||||
- 2023/09/19: 中文实时语音听写服务 1.2 发布,2pass模式支持热词、时间戳、ITN模型,dokcer镜像版本funasr-runtime-sdk-online-cpu-0.1.2 (7222c5319bcf)
|
||||
@ -52,7 +54,8 @@ FunASR实时语音听写服务软件包,既可以实时地进行语音转文
|
||||
为了支持不同用户的需求,针对不同场景,准备了不同的图文教程:
|
||||
|
||||
### 最新动态
|
||||
|
||||
- 2024/01/08: 中文离线文件转写服务 4.1 发布,优化句子级时间戳json格式,dokcer镜像版本funasr-runtime-sdk-cpu-0.4.1 (0250f8ef981b)
|
||||
- 2024/01/03: 中文离线文件转写服务 4.0 发布,新增支持8k模型、优化时间戳不匹配问题及增加句子级别时间戳、优化英文单词fst热词效果、支持自动化配置线程参数,同时修复已知的crash问题及内存泄漏问题,dokcer镜像版本funasr-runtime-sdk-cpu-0.4.0 (c4483ee08f04)
|
||||
- 2023/11/08: 中文离线文件转写服务 3.0 发布,支持标点大模型、支持Ngram模型、支持fst热词(更新热词通信协议)、支持服务端加载热词、runtime结构变化适配(FunASR/funasr/runtime->FunASR/runtime),dokcer镜像版本funasr-runtime-sdk-cpu-0.3.0 (caa64bddbb43),原理介绍文档([点击此处](https://mp.weixin.qq.com/s/jSbnKw_m31BUUbTukPSOIw))
|
||||
- 2023/09/19: 中文离线文件转写服务 2.2 发布,支持ITN模型,dokcer镜像版本funasr-runtime-sdk-cpu-0.2.2 (2c5286be13e9)
|
||||
- 2023/08/22: 中文离线文件转写服务 2.0 发布,集成ffmpeg支持多种音视频输入、支持热词模型、支持时间戳模型,dokcer镜像版本funasr-runtime-sdk-cpu-0.2.0 (1ad3d19e0707),原理介绍文档([点击此处](https://mp.weixin.qq.com/s/oJHe0MKDqTeuIFH-F7GHMg))
|
||||
|
||||
@ -5,6 +5,7 @@ online_model_dir="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab840
|
||||
vad_dir="damo/speech_fsmn_vad_zh-cn-16k-common-onnx"
|
||||
punc_dir="damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727-onnx"
|
||||
itn_dir="thuduj12/fst_itn_zh"
|
||||
lm_dir="damo/speech_ngram_lm_zh-cn-ai-wesp-fst"
|
||||
port=10095
|
||||
certfile="../../../ssl_key/server.crt"
|
||||
keyfile="../../../ssl_key/server.key"
|
||||
@ -30,6 +31,7 @@ cd /workspace/FunASR/runtime/websocket/build/bin
|
||||
--vad-dir "${vad_dir}" \
|
||||
--punc-dir "${punc_dir}" \
|
||||
--itn-dir "${itn_dir}" \
|
||||
--lm-dir "${lm_dir}" \
|
||||
--decoder-thread-num ${decoder_thread_num} \
|
||||
--model-thread-num ${model_thread_num} \
|
||||
--io-thread-num ${io_thread_num} \
|
||||
|
||||
64
runtime/triton_gpu/README_ONLINE.md
Executable file
64
runtime/triton_gpu/README_ONLINE.md
Executable file
@ -0,0 +1,64 @@
|
||||
### Steps:
|
||||
1. Prepare model repo files
|
||||
* git clone https://www.modelscope.cn/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online-onnx.git
|
||||
* Convert lfr_cmvn_pe.onnx model. For example: python export_lfr_cmvn_pe_onnx.py
|
||||
* If you export to onnx, you should have several model files in `${MODEL_DIR}`:
|
||||
```
|
||||
├── README.md
|
||||
└── model_repo_paraformer_large_online
|
||||
├── cif_search
|
||||
│ ├── 1
|
||||
│ │ └── model.py
|
||||
│ └── config.pbtxt
|
||||
├── decoder
|
||||
│ ├── 1
|
||||
│ │ └── decoder.onnx
|
||||
│ └── config.pbtxt
|
||||
├── encoder
|
||||
│ ├── 1
|
||||
│ │ └── model.onnx
|
||||
│ └── config.pbtxt
|
||||
├── feature_extractor
|
||||
│ ├── 1
|
||||
│ │ └── model.py
|
||||
│ ├── config.pbtxt
|
||||
│ └── config.yaml
|
||||
├── lfr_cmvn_pe
|
||||
│ ├── 1
|
||||
│ │ └── lfr_cmvn_pe.onnx
|
||||
│ ├── am.mvn
|
||||
│ ├── config.pbtxt
|
||||
│ └── export_lfr_cmvn_pe_onnx.py
|
||||
└── streaming_paraformer
|
||||
├── 1
|
||||
└── config.pbtxt
|
||||
```
|
||||
|
||||
2. Follow below instructions to launch triton server
|
||||
```sh
|
||||
# using docker image Dockerfile/Dockerfile.server
|
||||
docker build . -f Dockerfile/Dockerfile.server -t triton-paraformer:23.01
|
||||
docker run -it --rm --name "paraformer_triton_server" --gpus all -v <path_host/model_repo_paraformer_large_online>:/workspace/ --shm-size 1g --net host triton-paraformer:23.01
|
||||
|
||||
# launch the service
|
||||
cd /workspace
|
||||
tritonserver --model-repository model_repo_paraformer_large_online \
|
||||
--pinned-memory-pool-byte-size=512000000 \
|
||||
--cuda-memory-pool-byte-size=0:1024000000
|
||||
|
||||
```
|
||||
|
||||
### Performance benchmark with a single A10
|
||||
|
||||
* FP32, onnx, [paraformer larger online](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online-onnx/summary
|
||||
),Our chunksize is 10 * 960 / 16000 = 0.6 s, so we should care about the perf of latency less than 0.6s so that it can be a realtime application.
|
||||
|
||||
|
||||
| Concurrency | Throughput | Latency_p50 (ms) | Latency_p90 (ms) | Latency_p95 (ms) | Latency_p99 (ms) |
|
||||
|-------------|------------|------------------|------------------|------------------|------------------|
|
||||
| 20 | 309.252 | 56.913 | 76.267 | 85.598 | 138.462 |
|
||||
| 40 | 391.058 | 97.911 | 145.509 | 150.545 | 185.399 |
|
||||
| 60 | 426.269 | 138.244 | 185.855 | 201.016 | 236.528 |
|
||||
| 80 | 431.781 | 170.991 | 227.983 | 252.453 | 412.273 |
|
||||
| 100 | 473.351 | 206.205 | 262.612 | 288.964 | 463.337 |
|
||||
|
||||
268
runtime/triton_gpu/model_repo_paraformer_large_online/cif_search/1/model.py
Executable file
268
runtime/triton_gpu/model_repo_paraformer_large_online/cif_search/1/model.py
Executable file
@ -0,0 +1,268 @@
|
||||
# Created on 2024-01-01
|
||||
# Author: GuAn Zhu
|
||||
|
||||
import triton_python_backend_utils as pb_utils
|
||||
import numpy as np
|
||||
from torch.utils.dlpack import from_dlpack
|
||||
import json
|
||||
import yaml
|
||||
import asyncio
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class LimitedDict(OrderedDict):
|
||||
def __init__(self, max_length):
|
||||
super().__init__()
|
||||
self.max_length = max_length
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
if len(self) >= self.max_length:
|
||||
self.popitem(last=False)
|
||||
super().__setitem__(key, value)
|
||||
|
||||
|
||||
class CIFSearch:
|
||||
"""CIFSearch: https://github.com/alibaba-damo-academy/FunASR/blob/main/runtime/python/onnxruntime/funasr_onnx
|
||||
/paraformer_online_bin.py """
|
||||
def __init__(self):
|
||||
self.cache = {"cif_hidden": np.zeros((1, 1, 512)).astype(np.float32),
|
||||
"cif_alphas": np.zeros((1, 1)).astype(np.float32), "last_chunk": False}
|
||||
self.chunk_size = [5, 10, 5]
|
||||
self.tail_threshold = 0.45
|
||||
self.cif_threshold = 1.0
|
||||
|
||||
def infer(self, hidden, alphas):
|
||||
batch_size, len_time, hidden_size = hidden.shape
|
||||
token_length = []
|
||||
list_fires = []
|
||||
list_frames = []
|
||||
cache_alphas = []
|
||||
cache_hiddens = []
|
||||
alphas[:, :self.chunk_size[0]] = 0.0
|
||||
alphas[:, sum(self.chunk_size[:2]):] = 0.0
|
||||
|
||||
if self.cache is not None and "cif_alphas" in self.cache and "cif_hidden" in self.cache:
|
||||
hidden = np.concatenate((self.cache["cif_hidden"], hidden), axis=1)
|
||||
alphas = np.concatenate((self.cache["cif_alphas"], alphas), axis=1)
|
||||
if self.cache is not None and "last_chunk" in self.cache and self.cache["last_chunk"]:
|
||||
tail_hidden = np.zeros((batch_size, 1, hidden_size)).astype(np.float32)
|
||||
tail_alphas = np.array([[self.tail_threshold]]).astype(np.float32)
|
||||
tail_alphas = np.tile(tail_alphas, (batch_size, 1))
|
||||
hidden = np.concatenate((hidden, tail_hidden), axis=1)
|
||||
alphas = np.concatenate((alphas, tail_alphas), axis=1)
|
||||
|
||||
len_time = alphas.shape[1]
|
||||
for b in range(batch_size):
|
||||
integrate = 0.0
|
||||
frames = np.zeros(hidden_size).astype(np.float32)
|
||||
list_frame = []
|
||||
list_fire = []
|
||||
for t in range(len_time):
|
||||
alpha = alphas[b][t]
|
||||
if alpha + integrate < self.cif_threshold:
|
||||
integrate += alpha
|
||||
list_fire.append(integrate)
|
||||
frames += alpha * hidden[b][t]
|
||||
else:
|
||||
frames += (self.cif_threshold - integrate) * hidden[b][t]
|
||||
list_frame.append(frames)
|
||||
integrate += alpha
|
||||
list_fire.append(integrate)
|
||||
integrate -= self.cif_threshold
|
||||
frames = integrate * hidden[b][t]
|
||||
|
||||
cache_alphas.append(integrate)
|
||||
if integrate > 0.0:
|
||||
cache_hiddens.append(frames / integrate)
|
||||
else:
|
||||
cache_hiddens.append(frames)
|
||||
|
||||
token_length.append(len(list_frame))
|
||||
list_fires.append(list_fire)
|
||||
list_frames.append(list_frame)
|
||||
|
||||
max_token_len = max(token_length)
|
||||
list_ls = []
|
||||
for b in range(batch_size):
|
||||
pad_frames = np.zeros((max_token_len - token_length[b], hidden_size)).astype(np.float32)
|
||||
if token_length[b] == 0:
|
||||
list_ls.append(pad_frames)
|
||||
else:
|
||||
list_ls.append(np.concatenate((list_frames[b], pad_frames), axis=0))
|
||||
|
||||
self.cache["cif_alphas"] = np.stack(cache_alphas, axis=0)
|
||||
self.cache["cif_alphas"] = np.expand_dims(self.cache["cif_alphas"], axis=0)
|
||||
self.cache["cif_hidden"] = np.stack(cache_hiddens, axis=0)
|
||||
self.cache["cif_hidden"] = np.expand_dims(self.cache["cif_hidden"], axis=0)
|
||||
|
||||
return np.stack(list_ls, axis=0).astype(np.float32), np.stack(token_length, axis=0).astype(np.int32)
|
||||
|
||||
|
||||
class TritonPythonModel:
|
||||
"""Your Python model must use the same class name. Every Python model
|
||||
that is created must have "TritonPythonModel" as the class name.
|
||||
"""
|
||||
|
||||
def initialize(self, args):
|
||||
"""`initialize` is called only once when the model is being loaded.
|
||||
Implementing `initialize` function is optional. This function allows
|
||||
the model to initialize any state associated with this model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
args : dict
|
||||
Both keys and values are strings. The dictionary keys and values are:
|
||||
* model_config: A JSON string containing the model configuration
|
||||
* model_instance_kind: A string containing model instance kind
|
||||
* model_instance_device_id: A string containing model instance device ID
|
||||
* model_repository: Model repository path
|
||||
* model_version: Model version
|
||||
* model_name: Model name
|
||||
"""
|
||||
self.model_config = model_config = json.loads(args['model_config'])
|
||||
self.max_batch_size = max(model_config["max_batch_size"], 1)
|
||||
|
||||
# # Get OUTPUT0 configuration
|
||||
output0_config = pb_utils.get_output_config_by_name(
|
||||
model_config, "transcripts")
|
||||
# # Convert Triton types to numpy types
|
||||
self.out0_dtype = pb_utils.triton_string_to_numpy(
|
||||
output0_config['data_type'])
|
||||
|
||||
self.init_vocab(self.model_config['parameters'])
|
||||
|
||||
self.cif_search_cache = LimitedDict(1024)
|
||||
self.start = LimitedDict(1024)
|
||||
|
||||
def init_vocab(self, parameters):
|
||||
for li in parameters.items():
|
||||
key, value = li
|
||||
value = value["string_value"]
|
||||
if key == "vocabulary":
|
||||
self.vocab_dict = self.load_vocab(value)
|
||||
|
||||
def load_vocab(self, vocab_file):
|
||||
with open(str(vocab_file), 'rb') as f:
|
||||
config = yaml.load(f, Loader=yaml.Loader)
|
||||
return config['token_list']
|
||||
|
||||
async def execute(self, requests):
|
||||
"""`execute` must be implemented in every Python model. `execute`
|
||||
function receives a list of pb_utils.InferenceRequest as the only
|
||||
argument. This function is called when an inference is requested
|
||||
for this model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
requests : list
|
||||
A list of pb_utils.InferenceRequest
|
||||
|
||||
Returns
|
||||
-------
|
||||
list
|
||||
A list of pb_utils.InferenceResponse. The length of this list must
|
||||
be the same as `requests`
|
||||
"""
|
||||
# Every Python backend must iterate through list of requests and create
|
||||
# an instance of pb_utils.InferenceResponse class for each of them. You
|
||||
# should avoid storing any of the input Tensors in the class attributes
|
||||
# as they will be overridden in subsequent inference requests. You can
|
||||
# make a copy of the underlying NumPy array and store it if it is
|
||||
# required.
|
||||
|
||||
batch_end = []
|
||||
responses = []
|
||||
batch_corrid = []
|
||||
qualified_corrid = []
|
||||
batch_result = {}
|
||||
inference_response_awaits = []
|
||||
|
||||
for request in requests:
|
||||
hidden = pb_utils.get_input_tensor_by_name(request, "enc")
|
||||
hidden = from_dlpack(hidden.to_dlpack()).cpu().numpy()
|
||||
alphas = pb_utils.get_input_tensor_by_name(request, "alphas")
|
||||
alphas = from_dlpack(alphas.to_dlpack()).cpu().numpy()
|
||||
hidden_len = pb_utils.get_input_tensor_by_name(request, "enc_len")
|
||||
hidden_len = from_dlpack(hidden_len.to_dlpack()).cpu().numpy()
|
||||
|
||||
in_start = pb_utils.get_input_tensor_by_name(request, "START")
|
||||
start = in_start.as_numpy()[0][0]
|
||||
|
||||
in_corrid = pb_utils.get_input_tensor_by_name(request, "CORRID")
|
||||
corrid = in_corrid.as_numpy()[0][0]
|
||||
|
||||
in_end = pb_utils.get_input_tensor_by_name(request, "END")
|
||||
end = in_end.as_numpy()[0][0]
|
||||
|
||||
batch_end.append(end)
|
||||
batch_corrid.append(corrid)
|
||||
|
||||
if start:
|
||||
self.cif_search_cache[corrid] = CIFSearch()
|
||||
self.start[corrid] = 1
|
||||
if end:
|
||||
self.cif_search_cache[corrid].cache["last_chunk"] = True
|
||||
|
||||
acoustic, acoustic_len = self.cif_search_cache[corrid].infer(hidden, alphas)
|
||||
batch_result[corrid] = ''
|
||||
if acoustic.shape[1] == 0:
|
||||
continue
|
||||
else:
|
||||
qualified_corrid.append(corrid)
|
||||
input_tensor0 = pb_utils.Tensor("enc", hidden)
|
||||
input_tensor1 = pb_utils.Tensor("enc_len", np.array([hidden_len], dtype=np.int32))
|
||||
input_tensor2 = pb_utils.Tensor("acoustic_embeds", acoustic)
|
||||
input_tensor3 = pb_utils.Tensor("acoustic_embeds_len", np.array([acoustic_len], dtype=np.int32))
|
||||
input_tensors = [input_tensor0, input_tensor1, input_tensor2, input_tensor3]
|
||||
|
||||
if self.start[corrid] and end:
|
||||
flag = 3
|
||||
elif end:
|
||||
flag = 2
|
||||
elif self.start[corrid]:
|
||||
flag = 1
|
||||
self.start[corrid] = 0
|
||||
else:
|
||||
flag = 0
|
||||
inference_request = pb_utils.InferenceRequest(
|
||||
model_name='decoder',
|
||||
requested_output_names=['sample_ids'],
|
||||
inputs=input_tensors,
|
||||
request_id='',
|
||||
correlation_id=corrid,
|
||||
flags=flag
|
||||
)
|
||||
inference_response_awaits.append(inference_request.async_exec())
|
||||
|
||||
inference_responses = await asyncio.gather(*inference_response_awaits)
|
||||
|
||||
for index_corrid, inference_response in zip(qualified_corrid, inference_responses):
|
||||
if inference_response.has_error():
|
||||
raise pb_utils.TritonModelException(inference_response.error().message())
|
||||
else:
|
||||
sample_ids = pb_utils.get_output_tensor_by_name(inference_response, 'sample_ids')
|
||||
token_ids = from_dlpack(sample_ids.to_dlpack()).cpu().numpy()[0]
|
||||
|
||||
# Change integer-ids to tokens
|
||||
tokens = [self.vocab_dict[token_id] for token_id in token_ids]
|
||||
batch_result[index_corrid] = "".join(tokens)
|
||||
|
||||
for i, index_corrid in enumerate(batch_corrid):
|
||||
sent = np.array([batch_result[index_corrid]])
|
||||
out0 = pb_utils.Tensor("transcripts", sent.astype(self.out0_dtype))
|
||||
inference_response = pb_utils.InferenceResponse(output_tensors=[out0])
|
||||
responses.append(inference_response)
|
||||
|
||||
if batch_end[i]:
|
||||
del self.cif_search_cache[index_corrid]
|
||||
del self.start[index_corrid]
|
||||
|
||||
return responses
|
||||
|
||||
def finalize(self):
|
||||
"""`finalize` is called only once when the model is being unloaded.
|
||||
Implementing `finalize` function is optional. This function allows
|
||||
the model to perform any necessary clean ups before exit.
|
||||
"""
|
||||
print('Cleaning up...')
|
||||
|
||||
111
runtime/triton_gpu/model_repo_paraformer_large_online/cif_search/config.pbtxt
Executable file
111
runtime/triton_gpu/model_repo_paraformer_large_online/cif_search/config.pbtxt
Executable file
@ -0,0 +1,111 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Created on 2024-01-01
|
||||
# Author: GuAn Zhu
|
||||
|
||||
name: "cif_search"
|
||||
backend: "python"
|
||||
max_batch_size: 128
|
||||
|
||||
sequence_batching{
|
||||
max_sequence_idle_microseconds: 15000000
|
||||
oldest {
|
||||
max_candidate_sequences: 1024
|
||||
preferred_batch_size: [32, 64, 128]
|
||||
}
|
||||
control_input [
|
||||
{
|
||||
name: "START",
|
||||
control [
|
||||
{
|
||||
kind: CONTROL_SEQUENCE_START
|
||||
fp32_false_true: [0, 1]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
name: "READY"
|
||||
control [
|
||||
{
|
||||
kind: CONTROL_SEQUENCE_READY
|
||||
fp32_false_true: [0, 1]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
name: "CORRID",
|
||||
control [
|
||||
{
|
||||
kind: CONTROL_SEQUENCE_CORRID
|
||||
data_type: TYPE_UINT64
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
name: "END",
|
||||
control [
|
||||
{
|
||||
kind: CONTROL_SEQUENCE_END
|
||||
fp32_false_true: [0, 1]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
parameters [
|
||||
{
|
||||
key: "vocabulary",
|
||||
value: { string_value: "model_repo_paraformer_large_online/feature_extractor/config.yaml"}
|
||||
},
|
||||
{ key: "FORCE_CPU_ONLY_INPUT_TENSORS"
|
||||
value: {string_value:"no"}
|
||||
}
|
||||
]
|
||||
|
||||
input [
|
||||
{
|
||||
name: "enc"
|
||||
data_type: TYPE_FP32
|
||||
dims: [-1, 512]
|
||||
},
|
||||
{
|
||||
name: "enc_len"
|
||||
data_type: TYPE_INT32
|
||||
dims: [1]
|
||||
reshape: { shape: [ ] }
|
||||
},
|
||||
{
|
||||
name: 'alphas'
|
||||
data_type: TYPE_FP32
|
||||
dims: [-1]
|
||||
}
|
||||
]
|
||||
|
||||
output [
|
||||
{
|
||||
name: "transcripts"
|
||||
data_type: TYPE_STRING
|
||||
dims: [1]
|
||||
}
|
||||
]
|
||||
|
||||
instance_group [
|
||||
{
|
||||
count: 6
|
||||
kind: KIND_CPU
|
||||
}
|
||||
]
|
||||
274
runtime/triton_gpu/model_repo_paraformer_large_online/decoder/config.pbtxt
Executable file
274
runtime/triton_gpu/model_repo_paraformer_large_online/decoder/config.pbtxt
Executable file
@ -0,0 +1,274 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Created on 2024-01-01
|
||||
# Author: GuAn Zhu
|
||||
|
||||
name: "decoder"
|
||||
backend: "onnxruntime"
|
||||
default_model_filename: "decoder.onnx"
|
||||
|
||||
max_batch_size: 128
|
||||
|
||||
sequence_batching{
|
||||
max_sequence_idle_microseconds: 15000000
|
||||
oldest {
|
||||
max_candidate_sequences: 1024
|
||||
preferred_batch_size: [16, 32, 64]
|
||||
}
|
||||
control_input [
|
||||
]
|
||||
state [
|
||||
{
|
||||
input_name: "in_cache_0"
|
||||
output_name: "out_cache_0"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 512, 10 ]
|
||||
initial_state: {
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 512, 10]
|
||||
zero_data: true
|
||||
name: "initial state"
|
||||
}
|
||||
},
|
||||
{
|
||||
input_name: "in_cache_1"
|
||||
output_name: "out_cache_1"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 512, 10 ]
|
||||
initial_state: {
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 512, 10]
|
||||
zero_data: true
|
||||
name: "initial state"
|
||||
}
|
||||
},
|
||||
{
|
||||
input_name: "in_cache_2"
|
||||
output_name: "out_cache_2"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 512, 10 ]
|
||||
initial_state: {
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 512, 10]
|
||||
zero_data: true
|
||||
name: "initial state"
|
||||
}
|
||||
},
|
||||
{
|
||||
input_name: "in_cache_3"
|
||||
output_name: "out_cache_3"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 512, 10 ]
|
||||
initial_state: {
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 512, 10]
|
||||
zero_data: true
|
||||
name: "initial state"
|
||||
}
|
||||
},
|
||||
{
|
||||
input_name: "in_cache_4"
|
||||
output_name: "out_cache_4"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 512, 10 ]
|
||||
initial_state: {
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 512, 10]
|
||||
zero_data: true
|
||||
name: "initial state"
|
||||
}
|
||||
},
|
||||
{
|
||||
input_name: "in_cache_5"
|
||||
output_name: "out_cache_5"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 512, 10 ]
|
||||
initial_state: {
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 512, 10]
|
||||
zero_data: true
|
||||
name: "initial state"
|
||||
}
|
||||
},
|
||||
{
|
||||
input_name: "in_cache_6"
|
||||
output_name: "out_cache_6"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 512, 10 ]
|
||||
initial_state: {
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 512, 10]
|
||||
zero_data: true
|
||||
name: "initial state"
|
||||
}
|
||||
},
|
||||
{
|
||||
input_name: "in_cache_7"
|
||||
output_name: "out_cache_7"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 512, 10 ]
|
||||
initial_state: {
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 512, 10]
|
||||
zero_data: true
|
||||
name: "initial state"
|
||||
}
|
||||
},
|
||||
{
|
||||
input_name: "in_cache_8"
|
||||
output_name: "out_cache_8"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 512, 10 ]
|
||||
initial_state: {
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 512, 10]
|
||||
zero_data: true
|
||||
name: "initial state"
|
||||
}
|
||||
},
|
||||
{
|
||||
input_name: "in_cache_9"
|
||||
output_name: "out_cache_9"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 512, 10 ]
|
||||
initial_state: {
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 512, 10]
|
||||
zero_data: true
|
||||
name: "initial state"
|
||||
}
|
||||
},
|
||||
{
|
||||
input_name: "in_cache_10"
|
||||
output_name: "out_cache_10"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 512, 10 ]
|
||||
initial_state: {
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 512, 10]
|
||||
zero_data: true
|
||||
name: "initial state"
|
||||
}
|
||||
},
|
||||
{
|
||||
input_name: "in_cache_11"
|
||||
output_name: "out_cache_11"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 512, 10 ]
|
||||
initial_state: {
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 512, 10]
|
||||
zero_data: true
|
||||
name: "initial state"
|
||||
}
|
||||
},
|
||||
{
|
||||
input_name: "in_cache_12"
|
||||
output_name: "out_cache_12"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 512, 10 ]
|
||||
initial_state: {
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 512, 10]
|
||||
zero_data: true
|
||||
name: "initial state"
|
||||
}
|
||||
},
|
||||
{
|
||||
input_name: "in_cache_13"
|
||||
output_name: "out_cache_13"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 512, 10 ]
|
||||
initial_state: {
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 512, 10]
|
||||
zero_data: true
|
||||
name: "initial state"
|
||||
}
|
||||
},
|
||||
{
|
||||
input_name: "in_cache_14"
|
||||
output_name: "out_cache_14"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 512, 10 ]
|
||||
initial_state: {
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 512, 10]
|
||||
zero_data: true
|
||||
name: "initial state"
|
||||
}
|
||||
},
|
||||
{
|
||||
input_name: "in_cache_15"
|
||||
output_name: "out_cache_15"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 512, 10 ]
|
||||
initial_state: {
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 512, 10]
|
||||
zero_data: true
|
||||
name: "initial state"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
input [
|
||||
{
|
||||
name: "enc"
|
||||
data_type: TYPE_FP32
|
||||
dims: [-1, 512]
|
||||
},
|
||||
{
|
||||
name: "enc_len"
|
||||
data_type: TYPE_INT32
|
||||
dims: [1]
|
||||
reshape: { shape: [ ] }
|
||||
},
|
||||
{
|
||||
name: "acoustic_embeds"
|
||||
data_type: TYPE_FP32
|
||||
dims: [-1, 512]
|
||||
},
|
||||
{
|
||||
name: "acoustic_embeds_len"
|
||||
data_type: TYPE_INT32
|
||||
dims: [1]
|
||||
reshape: { shape: [ ] }
|
||||
}
|
||||
]
|
||||
|
||||
output [
|
||||
{
|
||||
name: "logits"
|
||||
data_type: TYPE_FP32
|
||||
dims: [-1, 8404]
|
||||
},
|
||||
{
|
||||
name: "sample_ids"
|
||||
data_type: TYPE_INT64
|
||||
dims: [-1]
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
instance_group [
|
||||
{
|
||||
count: 1
|
||||
kind: KIND_GPU
|
||||
}
|
||||
]
|
||||
|
||||
77
runtime/triton_gpu/model_repo_paraformer_large_online/encoder/config.pbtxt
Executable file
77
runtime/triton_gpu/model_repo_paraformer_large_online/encoder/config.pbtxt
Executable file
@ -0,0 +1,77 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Created on 2024-01-01
|
||||
# Author: GuAn Zhu
|
||||
|
||||
name: "encoder"
|
||||
backend: "onnxruntime"
|
||||
default_model_filename: "model.onnx"
|
||||
|
||||
max_batch_size: 128
|
||||
|
||||
|
||||
sequence_batching{
|
||||
max_sequence_idle_microseconds: 15000000
|
||||
oldest {
|
||||
max_candidate_sequences: 1024
|
||||
preferred_batch_size: [32, 64, 128]
|
||||
max_queue_delay_microseconds: 300
|
||||
}
|
||||
control_input [
|
||||
]
|
||||
state [
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
input [
|
||||
{
|
||||
name: "speech"
|
||||
data_type: TYPE_FP32
|
||||
dims: [-1, 560]
|
||||
},
|
||||
{
|
||||
name: "speech_lengths"
|
||||
data_type: TYPE_INT32
|
||||
dims: [1]
|
||||
reshape: { shape: [ ] }
|
||||
}
|
||||
]
|
||||
|
||||
output [
|
||||
{
|
||||
name: "enc"
|
||||
data_type: TYPE_FP32
|
||||
dims: [-1, 512]
|
||||
},
|
||||
{
|
||||
name: "enc_len"
|
||||
data_type: TYPE_INT32
|
||||
dims: [1]
|
||||
reshape: { shape: [ ] }
|
||||
},
|
||||
{
|
||||
name: "alphas"
|
||||
data_type: TYPE_FP32
|
||||
dims: [-1]
|
||||
}
|
||||
]
|
||||
|
||||
instance_group [
|
||||
{
|
||||
count: 1
|
||||
kind: KIND_GPU
|
||||
}
|
||||
]
|
||||
@ -0,0 +1,221 @@
|
||||
# Created on 2024-01-01
|
||||
# Author: GuAn Zhu
|
||||
|
||||
# Modified from NVIDIA(https://github.com/wenet-e2e/wenet/blob/main/runtime/gpu/
|
||||
# model_repo_stateful/feature_extractor/1/model.py)
|
||||
|
||||
import triton_python_backend_utils as pb_utils
|
||||
from torch.utils.dlpack import from_dlpack
|
||||
import torch
|
||||
import kaldifeat
|
||||
from typing import List
|
||||
import json
|
||||
import numpy as np
|
||||
import yaml
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class LimitedDict(OrderedDict):
|
||||
def __init__(self, max_length):
|
||||
super().__init__()
|
||||
self.max_length = max_length
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
if len(self) >= self.max_length:
|
||||
self.popitem(last=False)
|
||||
super().__setitem__(key, value)
|
||||
|
||||
|
||||
class Fbank(torch.nn.Module):
|
||||
def __init__(self, opts):
|
||||
super(Fbank, self).__init__()
|
||||
self.fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
def forward(self, waves: List[torch.Tensor]):
|
||||
return self.fbank(waves)
|
||||
|
||||
|
||||
class Feat(object):
|
||||
def __init__(self, seqid, offset_ms, sample_rate, frame_stride, device='cpu'):
|
||||
self.seqid = seqid
|
||||
self.sample_rate = sample_rate
|
||||
self.wav = torch.tensor([], device=device)
|
||||
self.offset = int(offset_ms / 1000 * sample_rate)
|
||||
self.frames = None
|
||||
self.frame_stride = int(frame_stride)
|
||||
self.device = device
|
||||
self.lfr_m = 7
|
||||
|
||||
def add_wavs(self, wav: torch.tensor):
|
||||
wav = wav.to(self.device)
|
||||
self.wav = torch.cat((self.wav, wav), axis=0)
|
||||
|
||||
def get_seg_wav(self):
|
||||
seg = self.wav[:]
|
||||
self.wav = self.wav[-self.offset:]
|
||||
return seg
|
||||
|
||||
def add_frames(self, frames: torch.tensor):
|
||||
"""
|
||||
frames: seq_len x feat_sz
|
||||
"""
|
||||
if self.frames is None:
|
||||
self.frames = torch.cat((frames[0, :].repeat((self.lfr_m - 1) // 2, 1),
|
||||
frames), axis=0)
|
||||
else:
|
||||
self.frames = torch.cat([self.frames, frames], axis=0)
|
||||
|
||||
def get_frames(self, num_frames: int):
|
||||
seg = self.frames[0: num_frames]
|
||||
self.frames = self.frames[self.frame_stride:]
|
||||
return seg
|
||||
|
||||
|
||||
class TritonPythonModel:
|
||||
"""Your Python model must use the same class name. Every Python model
|
||||
that is created must have "TritonPythonModel" as the class name.
|
||||
"""
|
||||
|
||||
def initialize(self, args):
|
||||
"""`initialize` is called only once when the model is being loaded.
|
||||
Implementing `initialize` function is optional. This function allows
|
||||
the model to initialize any state associated with this model.
|
||||
Parameters
|
||||
----------
|
||||
args : dict
|
||||
Both keys and values are strings. The dictionary keys and values are:
|
||||
* model_config: A JSON string containing the model configuration
|
||||
* model_instance_kind: A string containing model instance kind
|
||||
* model_instance_device_id: A string containing model instance device ID
|
||||
* model_repository: Model repository path
|
||||
* model_version: Model version
|
||||
* model_name: Model name
|
||||
"""
|
||||
self.model_config = model_config = json.loads(args['model_config'])
|
||||
self.max_batch_size = max(model_config["max_batch_size"], 1)
|
||||
|
||||
if "GPU" in model_config["instance_group"][0]["kind"]:
|
||||
self.device = "cuda"
|
||||
else:
|
||||
self.device = "cpu"
|
||||
|
||||
# Get OUTPUT0 configuration
|
||||
output0_config = pb_utils.get_output_config_by_name(
|
||||
model_config, "speech")
|
||||
# Convert Triton types to numpy types
|
||||
self.output0_dtype = pb_utils.triton_string_to_numpy(
|
||||
output0_config['data_type'])
|
||||
|
||||
if self.output0_dtype == np.float32:
|
||||
self.dtype = torch.float32
|
||||
else:
|
||||
self.dtype = torch.float16
|
||||
|
||||
self.feature_size = output0_config['dims'][-1]
|
||||
self.decoding_window = output0_config['dims'][-2]
|
||||
|
||||
params = self.model_config['parameters']
|
||||
for li in params.items():
|
||||
key, value = li
|
||||
value = value["string_value"]
|
||||
if key == "config_path":
|
||||
with open(str(value), 'rb') as f:
|
||||
config = yaml.load(f, Loader=yaml.Loader)
|
||||
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.frame_opts.dither = 0.0
|
||||
opts.frame_opts.window_type = config['frontend_conf']['window']
|
||||
opts.mel_opts.num_bins = int(config['frontend_conf']['n_mels'])
|
||||
opts.frame_opts.frame_shift_ms = float(config['frontend_conf']['frame_shift'])
|
||||
opts.frame_opts.frame_length_ms = float(config['frontend_conf']['frame_length'])
|
||||
opts.frame_opts.samp_freq = int(config['frontend_conf']['fs'])
|
||||
opts.device = torch.device(self.device)
|
||||
self.opts = opts
|
||||
self.feature_extractor = Fbank(self.opts)
|
||||
|
||||
self.seq_feat = LimitedDict(1024)
|
||||
chunk_size_s = float(params["chunk_size_s"]["string_value"])
|
||||
|
||||
sample_rate = opts.frame_opts.samp_freq
|
||||
frame_shift_ms = opts.frame_opts.frame_shift_ms
|
||||
frame_length_ms = opts.frame_opts.frame_length_ms
|
||||
|
||||
self.chunk_size = int(chunk_size_s * sample_rate)
|
||||
self.frame_stride = (chunk_size_s * 1000) // frame_shift_ms
|
||||
self.offset_ms = self.get_offset(frame_length_ms, frame_shift_ms)
|
||||
self.sample_rate = sample_rate
|
||||
|
||||
def get_offset(self, frame_length_ms, frame_shift_ms):
|
||||
offset_ms = 0
|
||||
while offset_ms + frame_shift_ms < frame_length_ms:
|
||||
offset_ms += frame_shift_ms
|
||||
return offset_ms
|
||||
|
||||
def execute(self, requests):
|
||||
"""`execute` must be implemented in every Python model. `execute`
|
||||
function receives a list of pb_utils.InferenceRequest as the only
|
||||
argument. This function is called when an inference is requested
|
||||
for this model.
|
||||
Parameters
|
||||
----------
|
||||
requests : list
|
||||
A list of pb_utils.InferenceRequest
|
||||
Returns
|
||||
-------
|
||||
list
|
||||
A list of pb_utils.InferenceResponse. The length of this list must
|
||||
be the same as `requests`
|
||||
"""
|
||||
total_waves = []
|
||||
responses = []
|
||||
batch_seqid = []
|
||||
end_seqid = {}
|
||||
for request in requests:
|
||||
input0 = pb_utils.get_input_tensor_by_name(request, "wav")
|
||||
wav = from_dlpack(input0.to_dlpack())[0]
|
||||
# input1 = pb_utils.get_input_tensor_by_name(request, "wav_lens")
|
||||
# wav_len = from_dlpack(input1.to_dlpack())[0]
|
||||
wav_len = len(wav)
|
||||
if wav_len < self.chunk_size:
|
||||
temp = torch.zeros(self.chunk_size, dtype=torch.float32,
|
||||
device=self.device)
|
||||
temp[0:wav_len] = wav[:]
|
||||
wav = temp
|
||||
|
||||
in_start = pb_utils.get_input_tensor_by_name(request, "START")
|
||||
start = in_start.as_numpy()[0][0]
|
||||
in_ready = pb_utils.get_input_tensor_by_name(request, "READY")
|
||||
ready = in_ready.as_numpy()[0][0]
|
||||
in_corrid = pb_utils.get_input_tensor_by_name(request, "CORRID")
|
||||
corrid = in_corrid.as_numpy()[0][0]
|
||||
in_end = pb_utils.get_input_tensor_by_name(request, "END")
|
||||
end = in_end.as_numpy()[0][0]
|
||||
|
||||
if start:
|
||||
self.seq_feat[corrid] = Feat(corrid, self.offset_ms,
|
||||
self.sample_rate,
|
||||
self.frame_stride,
|
||||
self.device)
|
||||
if ready:
|
||||
self.seq_feat[corrid].add_wavs(wav)
|
||||
|
||||
batch_seqid.append(corrid)
|
||||
if end:
|
||||
end_seqid[corrid] = 1
|
||||
|
||||
wav = self.seq_feat[corrid].get_seg_wav() * 32768
|
||||
total_waves.append(wav)
|
||||
features = self.feature_extractor(total_waves)
|
||||
for corrid, frames in zip(batch_seqid, features):
|
||||
self.seq_feat[corrid].add_frames(frames)
|
||||
speech = self.seq_feat[corrid].get_frames(self.decoding_window)
|
||||
out_tensor0 = pb_utils.Tensor("speech", torch.unsqueeze(speech, 0).to("cpu").numpy())
|
||||
output_tensors = [out_tensor0]
|
||||
response = pb_utils.InferenceResponse(output_tensors=output_tensors)
|
||||
responses.append(response)
|
||||
if corrid in end_seqid:
|
||||
del self.seq_feat[corrid]
|
||||
return responses
|
||||
|
||||
def finalize(self):
|
||||
print("Remove feature extractor!")
|
||||
@ -0,0 +1,109 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Created on 2024-01-01
|
||||
# Author: GuAn Zhu
|
||||
|
||||
name: "feature_extractor"
|
||||
backend: "python"
|
||||
max_batch_size: 128
|
||||
|
||||
parameters [
|
||||
{
|
||||
key: "chunk_size_s",
|
||||
value: { string_value: "0.6"}
|
||||
},
|
||||
{
|
||||
key: "config_path"
|
||||
value: { string_value: "model_repo_paraformer_large_online/feature_extractor/config.yaml"}
|
||||
}
|
||||
]
|
||||
|
||||
sequence_batching{
|
||||
max_sequence_idle_microseconds: 15000000
|
||||
oldest {
|
||||
max_candidate_sequences: 1024
|
||||
preferred_batch_size: [32, 64, 128]
|
||||
max_queue_delay_microseconds: 300
|
||||
}
|
||||
control_input [
|
||||
{
|
||||
name: "START",
|
||||
control [
|
||||
{
|
||||
kind: CONTROL_SEQUENCE_START
|
||||
fp32_false_true: [0, 1]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
name: "READY"
|
||||
control [
|
||||
{
|
||||
kind: CONTROL_SEQUENCE_READY
|
||||
fp32_false_true: [0, 1]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
name: "CORRID",
|
||||
control [
|
||||
{
|
||||
kind: CONTROL_SEQUENCE_CORRID
|
||||
data_type: TYPE_UINT64
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
name: "END",
|
||||
control [
|
||||
{
|
||||
kind: CONTROL_SEQUENCE_END
|
||||
fp32_false_true: [0, 1]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
input [
|
||||
{
|
||||
name: "wav"
|
||||
data_type: TYPE_FP32
|
||||
dims: [-1]
|
||||
},
|
||||
{
|
||||
name: "wav_lens"
|
||||
data_type: TYPE_INT32
|
||||
dims: [1]
|
||||
}
|
||||
]
|
||||
|
||||
output [
|
||||
{
|
||||
name: "speech"
|
||||
data_type: TYPE_FP32
|
||||
dims: [61, 80] # 80
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
instance_group [
|
||||
{
|
||||
count: 1
|
||||
kind: KIND_GPU
|
||||
}
|
||||
]
|
||||
|
||||
8639
runtime/triton_gpu/model_repo_paraformer_large_online/feature_extractor/config.yaml
Executable file
8639
runtime/triton_gpu/model_repo_paraformer_large_online/feature_extractor/config.yaml
Executable file
File diff suppressed because it is too large
Load Diff
8
runtime/triton_gpu/model_repo_paraformer_large_online/lfr_cmvn_pe/am.mvn
Executable file
8
runtime/triton_gpu/model_repo_paraformer_large_online/lfr_cmvn_pe/am.mvn
Executable file
File diff suppressed because one or more lines are too long
@ -0,0 +1,85 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Created on 2024-01-01
|
||||
# Author: GuAn Zhu
|
||||
|
||||
name: "lfr_cmvn_pe"
|
||||
backend: "onnxruntime"
|
||||
default_model_filename: "lfr_cmvn_pe.onnx"
|
||||
|
||||
max_batch_size: 128
|
||||
|
||||
sequence_batching{
|
||||
max_sequence_idle_microseconds: 15000000
|
||||
oldest {
|
||||
max_candidate_sequences: 1024
|
||||
preferred_batch_size: [32, 64, 128]
|
||||
max_queue_delay_microseconds: 300
|
||||
}
|
||||
control_input [
|
||||
]
|
||||
state [
|
||||
{
|
||||
input_name: "cache"
|
||||
output_name: "r_cache"
|
||||
data_type: TYPE_FP32
|
||||
dims: [10, 560]
|
||||
initial_state: {
|
||||
data_type: TYPE_FP32
|
||||
dims: [10, 560]
|
||||
zero_data: true
|
||||
name: "initial state"
|
||||
}
|
||||
},
|
||||
{
|
||||
input_name: "offset"
|
||||
output_name: "r_offset"
|
||||
data_type: TYPE_INT32
|
||||
dims: [1]
|
||||
initial_state: {
|
||||
data_type: TYPE_INT32
|
||||
dims: [1]
|
||||
zero_data: true
|
||||
name: "initial state"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
input [
|
||||
{
|
||||
name: "chunk_xs"
|
||||
data_type: TYPE_FP32
|
||||
dims: [61, 80]
|
||||
}
|
||||
]
|
||||
output [
|
||||
{
|
||||
name: "chunk_xs_out"
|
||||
data_type: TYPE_FP32
|
||||
dims: [-1, 560]
|
||||
},
|
||||
{
|
||||
name: "chunk_xs_out_len"
|
||||
data_type: TYPE_INT32
|
||||
dims: [-1]
|
||||
}
|
||||
]
|
||||
instance_group [
|
||||
{
|
||||
count: 1
|
||||
kind: KIND_GPU
|
||||
}
|
||||
]
|
||||
|
||||
@ -0,0 +1,131 @@
|
||||
# Created on 2024-01-01
|
||||
# Author: GuAn Zhu
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import math
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class LFR_CMVN_PE(torch.nn.Module):
|
||||
def __init__(self,
|
||||
mean: torch.Tensor,
|
||||
istd: torch.Tensor,
|
||||
m: int = 7,
|
||||
n: int = 6,
|
||||
max_len: int = 5000,
|
||||
encoder_input_size: int = 560,
|
||||
encoder_output_size: int = 512):
|
||||
super().__init__()
|
||||
|
||||
# LRF
|
||||
self.m = m
|
||||
self.n = n
|
||||
self.subsample = (m - 1) // 2
|
||||
|
||||
# CMVN
|
||||
assert mean.shape == istd.shape
|
||||
# The buffer can be accessed from this module using self.mean
|
||||
self.register_buffer("mean", mean)
|
||||
self.register_buffer("istd", istd)
|
||||
|
||||
# PE
|
||||
self.encoder_input_size = encoder_input_size
|
||||
self.encoder_output_size = encoder_output_size
|
||||
self.max_len = max_len
|
||||
self.pe = torch.zeros(self.max_len, self.encoder_input_size)
|
||||
position = torch.arange(0, self.max_len,
|
||||
dtype=torch.float32).unsqueeze(1)
|
||||
div_term = torch.exp(
|
||||
torch.arange((self.encoder_input_size/2), dtype=torch.float32) *
|
||||
-(math.log(10000.0) / (self.encoder_input_size/2-1)))
|
||||
self.pe[:, 0::1] = torch.cat((torch.sin(position * div_term), torch.cos(position * div_term)), dim=1)
|
||||
|
||||
def forward(self, x, cache, offset):
|
||||
"""
|
||||
Args:
|
||||
x (torch.Tensor): (batch, max_len, feat_dim)
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): normalized feature
|
||||
"""
|
||||
B, _, D = x.size()
|
||||
x = x.unfold(1, self.m, step=self.n).transpose(2, 3)
|
||||
x = x.view(B, -1, D * self.m)
|
||||
|
||||
x = (x + self.mean) * self.istd
|
||||
x = x * (self.encoder_output_size ** 0.5)
|
||||
|
||||
index = offset + torch.arange(1, x.size(1)+1).to(dtype=torch.int32)
|
||||
pos_emb = F.embedding(index, self.pe) # B X T X d_model
|
||||
r_cache = x + pos_emb
|
||||
|
||||
r_x = torch.cat((cache, r_cache), dim=1)
|
||||
r_offset = offset + x.size(1)
|
||||
r_x_len = torch.ones((B, 1), dtype=torch.int32) * r_x.size(1)
|
||||
|
||||
return r_x, r_x_len, r_cache, r_offset
|
||||
|
||||
|
||||
def load_cmvn(cmvn_file):
|
||||
with open(cmvn_file, 'r', encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
means_list = []
|
||||
vars_list = []
|
||||
for i in range(len(lines)):
|
||||
line_item = lines[i].split()
|
||||
if line_item[0] == '<AddShift>':
|
||||
line_item = lines[i + 1].split()
|
||||
if line_item[0] == '<LearnRateCoef>':
|
||||
add_shift_line = line_item[3:(len(line_item) - 1)]
|
||||
means_list = list(add_shift_line)
|
||||
continue
|
||||
elif line_item[0] == '<Rescale>':
|
||||
line_item = lines[i + 1].split()
|
||||
if line_item[0] == '<LearnRateCoef>':
|
||||
rescale_line = line_item[3:(len(line_item) - 1)]
|
||||
vars_list = list(rescale_line)
|
||||
continue
|
||||
|
||||
means = np.array(means_list).astype(np.float32)
|
||||
vars = np.array(vars_list).astype(np.float32)
|
||||
means = torch.from_numpy(means)
|
||||
vars = torch.from_numpy(vars)
|
||||
return means, vars
|
||||
|
||||
if __name__ == "__main__":
|
||||
means, vars = load_cmvn("am.mvn")
|
||||
means = torch.tile(means, (10, 1))
|
||||
vars = torch.tile(vars, (10, 1))
|
||||
|
||||
model = LFR_CMVN_PE(means, vars)
|
||||
model.eval()
|
||||
|
||||
all_names = ['chunk_xs', 'cache', 'offset', 'chunk_xs_out', 'chunk_xs_out_len', 'r_cache', 'r_offset']
|
||||
dynamic_axes = {}
|
||||
|
||||
for name in all_names:
|
||||
dynamic_axes[name] = {0: 'B'}
|
||||
|
||||
input_data1 = torch.randn(4, 61, 80).to(torch.float32)
|
||||
input_data2 = torch.randn(4, 10, 560).to(torch.float32)
|
||||
input_data3 = torch.randn(4, 1).to(torch.int32)
|
||||
|
||||
onnx_path = "./1/lfr_cmvn_pe.onnx"
|
||||
torch.onnx.export(model,
|
||||
(input_data1, input_data2, input_data3),
|
||||
onnx_path,
|
||||
export_params=True,
|
||||
opset_version=11,
|
||||
do_constant_folding=True,
|
||||
input_names=['chunk_xs', 'cache', 'offset'],
|
||||
output_names=['chunk_xs_out', 'chunk_xs_out_len', 'r_cache', 'r_offset'],
|
||||
dynamic_axes=dynamic_axes,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
print("export to onnx model succeed!")
|
||||
|
||||
|
||||
|
||||
@ -0,0 +1,122 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Created on 2024-01-01
|
||||
# Author: GuAn Zhu
|
||||
|
||||
name: "streaming_paraformer"
|
||||
platform: "ensemble"
|
||||
max_batch_size: 128 #MAX_BATCH
|
||||
|
||||
input [
|
||||
{
|
||||
name: "WAV"
|
||||
data_type: TYPE_FP32
|
||||
dims: [-1]
|
||||
},
|
||||
{
|
||||
name: "WAV_LENS"
|
||||
data_type: TYPE_INT32
|
||||
dims: [1]
|
||||
}
|
||||
]
|
||||
|
||||
output [
|
||||
{
|
||||
name: "TRANSCRIPTS"
|
||||
data_type: TYPE_STRING
|
||||
dims: [1]
|
||||
}
|
||||
]
|
||||
|
||||
ensemble_scheduling {
|
||||
step [
|
||||
{
|
||||
model_name: "feature_extractor"
|
||||
model_version: -1
|
||||
input_map {
|
||||
key: "wav"
|
||||
value: "WAV"
|
||||
}
|
||||
input_map {
|
||||
key: "wav_lens"
|
||||
value: "WAV_LENS"
|
||||
}
|
||||
output_map {
|
||||
key: "speech"
|
||||
value: "SPEECH"
|
||||
}
|
||||
},
|
||||
{
|
||||
model_name: "lfr_cmvn_pe"
|
||||
model_version: -1
|
||||
input_map {
|
||||
key: "chunk_xs"
|
||||
value: "SPEECH"
|
||||
}
|
||||
output_map {
|
||||
key: "chunk_xs_out"
|
||||
value: "CHUNK_XS_OUT"
|
||||
}
|
||||
output_map {
|
||||
key: "chunk_xs_out_len"
|
||||
value: "CHUNK_XS_OUT_LEN"
|
||||
}
|
||||
},
|
||||
{
|
||||
model_name: "encoder"
|
||||
model_version: -1
|
||||
input_map {
|
||||
key: "speech"
|
||||
value: "CHUNK_XS_OUT"
|
||||
}
|
||||
input_map {
|
||||
key: "speech_lengths"
|
||||
value: "CHUNK_XS_OUT_LEN"
|
||||
}
|
||||
output_map {
|
||||
key: "enc"
|
||||
value: "ENC"
|
||||
}
|
||||
output_map {
|
||||
key: "enc_len"
|
||||
value: "ENC_LEN"
|
||||
}
|
||||
output_map {
|
||||
key: "alphas"
|
||||
value: "ALPHAS"
|
||||
}
|
||||
},
|
||||
{
|
||||
model_name: "cif_search"
|
||||
model_version: -1
|
||||
input_map {
|
||||
key: "enc"
|
||||
value: "ENC"
|
||||
}
|
||||
input_map {
|
||||
key: "enc_len"
|
||||
value: "ENC_LEN"
|
||||
}
|
||||
input_map {
|
||||
key: "alphas"
|
||||
value: "ALPHAS"
|
||||
}
|
||||
output_map {
|
||||
key: "transcripts"
|
||||
value: "TRANSCRIPTS"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
@ -122,7 +122,7 @@ class WebsocketClient {
|
||||
|
||||
// This method will block until the connection is complete
|
||||
void run(const std::string& uri, const std::vector<string>& wav_list,
|
||||
const std::vector<string>& wav_ids, std::string asr_mode,
|
||||
const std::vector<string>& wav_ids, int audio_fs, std::string asr_mode,
|
||||
std::vector<int> chunk_size, const std::unordered_map<std::string, int>& hws_map,
|
||||
bool is_record=false, int use_itn=1) {
|
||||
// Create a new connection to the given URI
|
||||
@ -148,7 +148,7 @@ class WebsocketClient {
|
||||
if(is_record){
|
||||
send_rec_data(asr_mode, chunk_size, hws_map, use_itn);
|
||||
}else{
|
||||
send_wav_data(wav_list[0], wav_ids[0], asr_mode, chunk_size, hws_map, use_itn);
|
||||
send_wav_data(wav_list[0], wav_ids[0], audio_fs, asr_mode, chunk_size, hws_map, use_itn);
|
||||
}
|
||||
|
||||
WaitABit();
|
||||
@ -183,20 +183,20 @@ class WebsocketClient {
|
||||
m_done = true;
|
||||
}
|
||||
// send wav to server
|
||||
void send_wav_data(string wav_path, string wav_id, std::string asr_mode,
|
||||
void send_wav_data(string wav_path, string wav_id, int audio_fs, std::string asr_mode,
|
||||
std::vector<int> chunk_vector, const std::unordered_map<std::string, int>& hws_map,
|
||||
int use_itn) {
|
||||
uint64_t count = 0;
|
||||
std::stringstream val;
|
||||
|
||||
funasr::Audio audio(1);
|
||||
int32_t sampling_rate = 16000;
|
||||
int32_t sampling_rate = audio_fs;
|
||||
std::string wav_format = "pcm";
|
||||
if (funasr::IsTargetFile(wav_path.c_str(), "wav")) {
|
||||
int32_t sampling_rate = -1;
|
||||
if (!audio.LoadWav(wav_path.c_str(), &sampling_rate)) return;
|
||||
if (!audio.LoadWav(wav_path.c_str(), &sampling_rate, false))
|
||||
return;
|
||||
} else if (funasr::IsTargetFile(wav_path.c_str(), "pcm")) {
|
||||
if (!audio.LoadPcmwav(wav_path.c_str(), &sampling_rate)) return;
|
||||
if (!audio.LoadPcmwav(wav_path.c_str(), &sampling_rate, false)) return;
|
||||
} else {
|
||||
wav_format = "others";
|
||||
if (!audio.LoadOthers2Char(wav_path.c_str())) return;
|
||||
@ -238,6 +238,7 @@ class WebsocketClient {
|
||||
jsonbegin["chunk_size"] = chunk_size;
|
||||
jsonbegin["wav_name"] = wav_id;
|
||||
jsonbegin["wav_format"] = wav_format;
|
||||
jsonbegin["audio_fs"] = sampling_rate;
|
||||
jsonbegin["is_speaking"] = true;
|
||||
jsonbegin["itn"] = true;
|
||||
if(use_itn == 0){
|
||||
@ -360,6 +361,7 @@ class WebsocketClient {
|
||||
}
|
||||
websocketpp::lib::error_code ec;
|
||||
|
||||
float sample_rate = 16000;
|
||||
nlohmann::json jsonbegin;
|
||||
nlohmann::json chunk_size = nlohmann::json::array();
|
||||
chunk_size.push_back(chunk_vector[0]);
|
||||
@ -369,6 +371,7 @@ class WebsocketClient {
|
||||
jsonbegin["chunk_size"] = chunk_size;
|
||||
jsonbegin["wav_name"] = "record";
|
||||
jsonbegin["wav_format"] = "pcm";
|
||||
jsonbegin["audio_fs"] = sample_rate;
|
||||
jsonbegin["is_speaking"] = true;
|
||||
jsonbegin["itn"] = true;
|
||||
if(use_itn == 0){
|
||||
@ -408,7 +411,6 @@ class WebsocketClient {
|
||||
|
||||
param.suggestedLatency = info->defaultLowInputLatency;
|
||||
param.hostApiSpecificStreamInfo = nullptr;
|
||||
float sample_rate = 16000;
|
||||
|
||||
PaStream *stream;
|
||||
std::vector<float> buffer;
|
||||
@ -473,6 +475,10 @@ class WebsocketClient {
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
#ifdef _WIN32
|
||||
#include <windows.h>
|
||||
SetConsoleOutputCP(65001);
|
||||
#endif
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
FLAGS_logtostderr = true;
|
||||
|
||||
@ -486,6 +492,7 @@ int main(int argc, char* argv[]) {
|
||||
"the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: "
|
||||
"asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)",
|
||||
false, "", "string");
|
||||
TCLAP::ValueArg<std::int32_t> audio_fs_("", "audio-fs", "the sample rate of audio", false, 16000, "int32_t");
|
||||
TCLAP::ValueArg<int> record_(
|
||||
"", "record",
|
||||
"record is 1 means use record", false, 0,
|
||||
@ -511,6 +518,7 @@ int main(int argc, char* argv[]) {
|
||||
cmd.add(server_ip_);
|
||||
cmd.add(port_);
|
||||
cmd.add(wav_path_);
|
||||
cmd.add(audio_fs_);
|
||||
cmd.add(asr_mode_);
|
||||
cmd.add(record_);
|
||||
cmd.add(chunk_size_);
|
||||
@ -558,6 +566,7 @@ int main(int argc, char* argv[]) {
|
||||
funasr::ExtractHws(hotword_path, hws_map);
|
||||
}
|
||||
|
||||
int audio_fs = audio_fs_.getValue();
|
||||
if(is_record == 1){
|
||||
std::vector<string> tmp_wav_list;
|
||||
std::vector<string> tmp_wav_ids;
|
||||
@ -567,11 +576,11 @@ int main(int argc, char* argv[]) {
|
||||
|
||||
c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
|
||||
|
||||
c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size, hws_map, true, use_itn);
|
||||
c.run(uri, tmp_wav_list, tmp_wav_ids, audio_fs, asr_mode, chunk_size, hws_map, true, use_itn);
|
||||
} else {
|
||||
WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
|
||||
|
||||
c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size, hws_map, true, use_itn);
|
||||
c.run(uri, tmp_wav_list, tmp_wav_ids, audio_fs, asr_mode, chunk_size, hws_map, true, use_itn);
|
||||
}
|
||||
|
||||
}else{
|
||||
@ -612,17 +621,17 @@ int main(int argc, char* argv[]) {
|
||||
tmp_wav_ids.emplace_back(wav_ids[wav_i + i]);
|
||||
|
||||
client_threads.emplace_back(
|
||||
[uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size, is_ssl, hws_map, use_itn]() {
|
||||
[uri, tmp_wav_list, tmp_wav_ids, audio_fs, asr_mode, chunk_size, is_ssl, hws_map, use_itn]() {
|
||||
if (is_ssl == 1) {
|
||||
WebsocketClient<websocketpp::config::asio_tls_client> c(is_ssl);
|
||||
|
||||
c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
|
||||
|
||||
c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size, hws_map, false, use_itn);
|
||||
c.run(uri, tmp_wav_list, tmp_wav_ids, audio_fs, asr_mode, chunk_size, hws_map, false, use_itn);
|
||||
} else {
|
||||
WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
|
||||
|
||||
c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size, hws_map, false, use_itn);
|
||||
c.run(uri, tmp_wav_list, tmp_wav_ids, audio_fs, asr_mode, chunk_size, hws_map, false, use_itn);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@ -98,11 +98,12 @@ class WebsocketClient {
|
||||
switch (msg->get_opcode()) {
|
||||
case websocketpp::frame::opcode::text:
|
||||
total_recv=total_recv+1;
|
||||
LOG(INFO)<< "Thread: " << this_thread::get_id() <<", on_message = " << payload;
|
||||
LOG(INFO)<< "Thread: " << this_thread::get_id() << ", total_recv=" << total_recv << " total_send=" <<total_send;
|
||||
if(total_recv==total_send)
|
||||
LOG(INFO)<< "Thread: " << this_thread::get_id() << ", total_recv=" << total_recv <<", on_message = " << payload;
|
||||
std::unique_lock<std::mutex> lock(msg_lock);
|
||||
cv.notify_one();
|
||||
if(close_client)
|
||||
{
|
||||
LOG(INFO)<< "Thread: " << this_thread::get_id() << ", close client";
|
||||
LOG(INFO)<< "Thread: " << this_thread::get_id() << ", close client thread";
|
||||
websocketpp::lib::error_code ec;
|
||||
m_client.close(m_hdl, websocketpp::close::status::going_away, "", ec);
|
||||
if (ec){
|
||||
@ -114,7 +115,7 @@ class WebsocketClient {
|
||||
|
||||
// This method will block until the connection is complete
|
||||
void run(const std::string& uri, const std::vector<string>& wav_list, const std::vector<string>& wav_ids,
|
||||
const std::unordered_map<std::string, int>& hws_map, int use_itn=1) {
|
||||
int audio_fs, const std::unordered_map<std::string, int>& hws_map, int use_itn=1) {
|
||||
// Create a new connection to the given URI
|
||||
websocketpp::lib::error_code ec;
|
||||
typename websocketpp::client<T>::connection_ptr con =
|
||||
@ -141,14 +142,17 @@ class WebsocketClient {
|
||||
if (i >= wav_list.size()) {
|
||||
break;
|
||||
}
|
||||
if (total_send !=0){
|
||||
std::unique_lock<std::mutex> lock(msg_lock);
|
||||
cv.wait(lock);
|
||||
}
|
||||
total_send += 1;
|
||||
send_wav_data(wav_list[i], wav_ids[i], hws_map, send_hotword, use_itn);
|
||||
send_wav_data(wav_list[i], wav_ids[i], audio_fs, hws_map, send_hotword, use_itn);
|
||||
if(send_hotword){
|
||||
send_hotword = false;
|
||||
}
|
||||
}
|
||||
WaitABit();
|
||||
|
||||
close_client = true;
|
||||
asio_thread.join();
|
||||
|
||||
}
|
||||
@ -180,21 +184,20 @@ class WebsocketClient {
|
||||
m_done = true;
|
||||
}
|
||||
// send wav to server
|
||||
void send_wav_data(string wav_path, string wav_id,
|
||||
void send_wav_data(string wav_path, string wav_id, int audio_fs,
|
||||
const std::unordered_map<std::string, int>& hws_map,
|
||||
bool send_hotword, bool use_itn) {
|
||||
uint64_t count = 0;
|
||||
std::stringstream val;
|
||||
|
||||
funasr::Audio audio(1);
|
||||
int32_t sampling_rate = 16000;
|
||||
int32_t sampling_rate = audio_fs;
|
||||
std::string wav_format = "pcm";
|
||||
if(funasr::IsTargetFile(wav_path.c_str(), "wav")){
|
||||
int32_t sampling_rate = -1;
|
||||
if(!audio.LoadWav(wav_path.c_str(), &sampling_rate))
|
||||
return ;
|
||||
}else if(funasr::IsTargetFile(wav_path.c_str(), "pcm")){
|
||||
if (!audio.LoadPcmwav(wav_path.c_str(), &sampling_rate))
|
||||
if (funasr::IsTargetFile(wav_path.c_str(), "wav")) {
|
||||
if (!audio.LoadWav(wav_path.c_str(), &sampling_rate, false))
|
||||
return;
|
||||
} else if(funasr::IsTargetFile(wav_path.c_str(), "pcm")){
|
||||
if (!audio.LoadPcmwav(wav_path.c_str(), &sampling_rate, false))
|
||||
return ;
|
||||
}else{
|
||||
wav_format = "others";
|
||||
@ -237,6 +240,7 @@ class WebsocketClient {
|
||||
jsonbegin["chunk_interval"] = 10;
|
||||
jsonbegin["wav_name"] = wav_id;
|
||||
jsonbegin["wav_format"] = wav_format;
|
||||
jsonbegin["audio_fs"] = sampling_rate;
|
||||
jsonbegin["itn"] = true;
|
||||
if(use_itn == 0){
|
||||
jsonbegin["itn"] = false;
|
||||
@ -334,14 +338,20 @@ class WebsocketClient {
|
||||
private:
|
||||
websocketpp::connection_hdl m_hdl;
|
||||
websocketpp::lib::mutex m_lock;
|
||||
websocketpp::lib::mutex msg_lock;
|
||||
websocketpp::lib::condition_variable cv;
|
||||
bool m_open;
|
||||
bool m_done;
|
||||
bool close_client=false;
|
||||
int total_send=0;
|
||||
int total_recv=0;
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <windows.h>
|
||||
SetConsoleOutputCP(65001);
|
||||
#endif
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
FLAGS_logtostderr = true;
|
||||
|
||||
@ -352,6 +362,7 @@ int main(int argc, char* argv[]) {
|
||||
TCLAP::ValueArg<std::string> wav_path_("", "wav-path",
|
||||
"the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)",
|
||||
true, "", "string");
|
||||
TCLAP::ValueArg<std::int32_t> audio_fs_("", "audio-fs", "the sample rate of audio", false, 16000, "int32_t");
|
||||
TCLAP::ValueArg<int> thread_num_("", "thread-num", "thread-num",
|
||||
false, 1, "int");
|
||||
TCLAP::ValueArg<int> is_ssl_(
|
||||
@ -366,6 +377,7 @@ int main(int argc, char* argv[]) {
|
||||
cmd.add(server_ip_);
|
||||
cmd.add(port_);
|
||||
cmd.add(wav_path_);
|
||||
cmd.add(audio_fs_);
|
||||
cmd.add(thread_num_);
|
||||
cmd.add(is_ssl_);
|
||||
cmd.add(use_itn_);
|
||||
@ -420,18 +432,19 @@ int main(int argc, char* argv[]) {
|
||||
wav_ids.emplace_back(default_id);
|
||||
}
|
||||
|
||||
int audio_fs = audio_fs_.getValue();
|
||||
for (size_t i = 0; i < threads_num; i++) {
|
||||
client_threads.emplace_back([uri, wav_list, wav_ids, is_ssl, hws_map, use_itn]() {
|
||||
client_threads.emplace_back([uri, wav_list, wav_ids, audio_fs, is_ssl, hws_map, use_itn]() {
|
||||
if (is_ssl == 1) {
|
||||
WebsocketClient<websocketpp::config::asio_tls_client> c(is_ssl);
|
||||
|
||||
c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
|
||||
|
||||
c.run(uri, wav_list, wav_ids, hws_map, use_itn);
|
||||
c.run(uri, wav_list, wav_ids, audio_fs, hws_map, use_itn);
|
||||
} else {
|
||||
WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
|
||||
|
||||
c.run(uri, wav_list, wav_ids, hws_map, use_itn);
|
||||
c.run(uri, wav_list, wav_ids, audio_fs, hws_map, use_itn);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
// hotwords
|
||||
std::unordered_map<std::string, int> hws_map_;
|
||||
int fst_inc_wts_=20;
|
||||
float global_beam_, lattice_beam_, am_scale_;
|
||||
|
||||
using namespace std;
|
||||
void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key,
|
||||
@ -25,6 +26,10 @@ void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key,
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
#ifdef _WIN32
|
||||
#include <windows.h>
|
||||
SetConsoleOutputCP(65001);
|
||||
#endif
|
||||
try {
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
FLAGS_logtostderr = true;
|
||||
@ -116,6 +121,14 @@ int main(int argc, char* argv[]) {
|
||||
"connection",
|
||||
false, "../../../ssl_key/server.key", "string");
|
||||
|
||||
TCLAP::ValueArg<float> global_beam("", GLOB_BEAM, "the decoding beam for beam searching ", false, 3.0, "float");
|
||||
TCLAP::ValueArg<float> lattice_beam("", LAT_BEAM, "the lattice generation beam for beam searching ", false, 3.0, "float");
|
||||
TCLAP::ValueArg<float> am_scale("", AM_SCALE, "the acoustic scale for beam searching ", false, 10.0, "float");
|
||||
|
||||
TCLAP::ValueArg<std::string> lm_dir("", LM_DIR,
|
||||
"the LM model path, which contains compiled models: TLG.fst, config.yaml ", false, "damo/speech_ngram_lm_zh-cn-ai-wesp-fst", "string");
|
||||
TCLAP::ValueArg<std::string> lm_revision(
|
||||
"", "lm-revision", "LM model revision", false, "v1.0.2", "string");
|
||||
TCLAP::ValueArg<std::string> hotword("", HOTWORD,
|
||||
"the hotword file, one hotword perline, Format: Hotword Weight (could be: 阿里巴巴 20)",
|
||||
false, "/workspace/resources/hotwords.txt", "string");
|
||||
@ -124,6 +137,10 @@ int main(int argc, char* argv[]) {
|
||||
|
||||
// add file
|
||||
cmd.add(hotword);
|
||||
cmd.add(fst_inc_wts);
|
||||
cmd.add(global_beam);
|
||||
cmd.add(lattice_beam);
|
||||
cmd.add(am_scale);
|
||||
|
||||
cmd.add(certfile);
|
||||
cmd.add(keyfile);
|
||||
@ -142,6 +159,8 @@ int main(int argc, char* argv[]) {
|
||||
cmd.add(punc_quant);
|
||||
cmd.add(itn_dir);
|
||||
cmd.add(itn_revision);
|
||||
cmd.add(lm_dir);
|
||||
cmd.add(lm_revision);
|
||||
|
||||
cmd.add(listen_ip);
|
||||
cmd.add(port);
|
||||
@ -159,6 +178,7 @@ int main(int argc, char* argv[]) {
|
||||
GetValue(punc_dir, PUNC_DIR, model_path);
|
||||
GetValue(punc_quant, PUNC_QUANT, model_path);
|
||||
GetValue(itn_dir, ITN_DIR, model_path);
|
||||
GetValue(lm_dir, LM_DIR, model_path);
|
||||
GetValue(hotword, HOTWORD, model_path);
|
||||
|
||||
GetValue(offline_model_revision, "offline-model-revision", model_path);
|
||||
@ -166,6 +186,11 @@ int main(int argc, char* argv[]) {
|
||||
GetValue(vad_revision, "vad-revision", model_path);
|
||||
GetValue(punc_revision, "punc-revision", model_path);
|
||||
GetValue(itn_revision, "itn-revision", model_path);
|
||||
GetValue(lm_revision, "lm-revision", model_path);
|
||||
|
||||
global_beam_ = global_beam.getValue();
|
||||
lattice_beam_ = lattice_beam.getValue();
|
||||
am_scale_ = am_scale.getValue();
|
||||
|
||||
// Download model form Modelscope
|
||||
try {
|
||||
@ -179,6 +204,7 @@ int main(int argc, char* argv[]) {
|
||||
std::string s_punc_path = model_path[PUNC_DIR];
|
||||
std::string s_punc_quant = model_path[PUNC_QUANT];
|
||||
std::string s_itn_path = model_path[ITN_DIR];
|
||||
std::string s_lm_path = model_path[LM_DIR];
|
||||
|
||||
std::string python_cmd =
|
||||
"python -m funasr.utils.runtime_sdk_download_tool --type onnx --quantize True ";
|
||||
@ -237,11 +263,18 @@ int main(int argc, char* argv[]) {
|
||||
size_t found = s_offline_asr_path.find("speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404");
|
||||
if (found != std::string::npos) {
|
||||
model_path["offline-model-revision"]="v1.2.4";
|
||||
} else{
|
||||
found = s_offline_asr_path.find("speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404");
|
||||
if (found != std::string::npos) {
|
||||
model_path["offline-model-revision"]="v1.0.5";
|
||||
}
|
||||
}
|
||||
|
||||
found = s_offline_asr_path.find("speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404");
|
||||
if (found != std::string::npos) {
|
||||
model_path["offline-model-revision"]="v1.0.5";
|
||||
}
|
||||
|
||||
found = s_offline_asr_path.find("speech_paraformer-large_asr_nat-en-16k-common-vocab10020");
|
||||
if (found != std::string::npos) {
|
||||
model_path["model-revision"]="v1.0.0";
|
||||
s_itn_path="";
|
||||
s_lm_path="";
|
||||
}
|
||||
|
||||
if (access(s_offline_asr_path.c_str(), F_OK) == 0) {
|
||||
@ -328,6 +361,49 @@ int main(int argc, char* argv[]) {
|
||||
LOG(INFO) << "ASR online model is not set, use default.";
|
||||
}
|
||||
|
||||
if (!s_lm_path.empty() && s_lm_path != "NONE" && s_lm_path != "none") {
|
||||
std::string python_cmd_lm;
|
||||
std::string down_lm_path;
|
||||
std::string down_lm_model;
|
||||
|
||||
if (access(s_lm_path.c_str(), F_OK) == 0) {
|
||||
// local
|
||||
python_cmd_lm = python_cmd + " --model-name " + s_lm_path +
|
||||
" --export-dir ./ " + " --model_revision " +
|
||||
model_path["lm-revision"] + " --export False ";
|
||||
down_lm_path = s_lm_path;
|
||||
} else {
|
||||
// modelscope
|
||||
LOG(INFO) << "Download model: " << s_lm_path
|
||||
<< " from modelscope : ";
|
||||
python_cmd_lm = python_cmd + " --model-name " +
|
||||
s_lm_path +
|
||||
" --export-dir " + s_download_model_dir +
|
||||
" --model_revision " + model_path["lm-revision"]
|
||||
+ " --export False ";
|
||||
down_lm_path =
|
||||
s_download_model_dir +
|
||||
"/" + s_lm_path;
|
||||
}
|
||||
|
||||
int ret = system(python_cmd_lm.c_str());
|
||||
if (ret != 0) {
|
||||
LOG(INFO) << "Failed to download model from modelscope. If you set local lm model path, you can ignore the errors.";
|
||||
}
|
||||
down_lm_model = down_lm_path + "/TLG.fst";
|
||||
|
||||
if (access(down_lm_model.c_str(), F_OK) != 0) {
|
||||
LOG(ERROR) << down_lm_model << " do not exists.";
|
||||
exit(-1);
|
||||
} else {
|
||||
model_path[LM_DIR] = down_lm_path;
|
||||
LOG(INFO) << "Set " << LM_DIR << " : " << model_path[LM_DIR];
|
||||
}
|
||||
} else {
|
||||
LOG(INFO) << "LM model is not set, not executed.";
|
||||
model_path[LM_DIR] = "";
|
||||
}
|
||||
|
||||
if (!s_punc_path.empty()) {
|
||||
std::string python_cmd_punc;
|
||||
std::string down_punc_path;
|
||||
|
||||
@ -26,6 +26,10 @@ void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key,
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
#ifdef _WIN32
|
||||
#include <windows.h>
|
||||
SetConsoleOutputCP(65001);
|
||||
#endif
|
||||
try {
|
||||
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
@ -111,7 +115,7 @@ int main(int argc, char* argv[]) {
|
||||
TCLAP::ValueArg<std::string> lm_dir("", LM_DIR,
|
||||
"the LM model path, which contains compiled models: TLG.fst, config.yaml ", false, "damo/speech_ngram_lm_zh-cn-ai-wesp-fst", "string");
|
||||
TCLAP::ValueArg<std::string> lm_revision(
|
||||
"", "lm-revision", "LM model revision", false, "v1.0.1", "string");
|
||||
"", "lm-revision", "LM model revision", false, "v1.0.2", "string");
|
||||
TCLAP::ValueArg<std::string> hotword("", HOTWORD,
|
||||
"the hotword file, one hotword perline, Format: Hotword Weight (could be: 阿里巴巴 20)",
|
||||
false, "/workspace/resources/hotwords.txt", "string");
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
|
||||
extern std::unordered_map<std::string, int> hws_map_;
|
||||
extern int fst_inc_wts_;
|
||||
extern float global_beam_, lattice_beam_, am_scale_;
|
||||
|
||||
context_ptr WebSocketServer::on_tls_init(tls_mode mode,
|
||||
websocketpp::connection_hdl hdl,
|
||||
@ -80,6 +81,19 @@ nlohmann::json handle_result(FUNASR_RESULT result) {
|
||||
jsonresult["timestamp"] = tmp_stamp_msg;
|
||||
}
|
||||
|
||||
std::string tmp_stamp_sents = FunASRGetStampSents(result);
|
||||
if (tmp_stamp_sents != "") {
|
||||
try{
|
||||
nlohmann::json json_stamp = nlohmann::json::parse(tmp_stamp_sents);
|
||||
LOG(INFO) << "offline stamp_sents : " << json_stamp;
|
||||
jsonresult["stamp_sents"] = json_stamp;
|
||||
}catch (std::exception const &e)
|
||||
{
|
||||
LOG(ERROR)<< tmp_stamp_sents << e.what();
|
||||
jsonresult["stamp_sents"] = "";
|
||||
}
|
||||
}
|
||||
|
||||
return jsonresult;
|
||||
}
|
||||
// feed buffer to asr engine for decoder
|
||||
@ -96,7 +110,8 @@ void WebSocketServer::do_decoder(
|
||||
bool itn,
|
||||
int audio_fs,
|
||||
std::string wav_format,
|
||||
FUNASR_HANDLE& tpass_online_handle) {
|
||||
FUNASR_HANDLE& tpass_online_handle,
|
||||
FUNASR_DEC_HANDLE& decoder_handle) {
|
||||
// lock for each connection
|
||||
if(!tpass_online_handle){
|
||||
scoped_lock guard(thread_lock);
|
||||
@ -125,7 +140,7 @@ void WebSocketServer::do_decoder(
|
||||
subvector.data(), subvector.size(),
|
||||
punc_cache, false, audio_fs,
|
||||
wav_format, (ASR_TYPE)asr_mode_,
|
||||
hotwords_embedding, itn);
|
||||
hotwords_embedding, itn, decoder_handle);
|
||||
|
||||
} else {
|
||||
scoped_lock guard(thread_lock);
|
||||
@ -162,7 +177,7 @@ void WebSocketServer::do_decoder(
|
||||
buffer.data(), buffer.size(), punc_cache,
|
||||
is_final, audio_fs,
|
||||
wav_format, (ASR_TYPE)asr_mode_,
|
||||
hotwords_embedding, itn);
|
||||
hotwords_embedding, itn, decoder_handle);
|
||||
} else {
|
||||
scoped_lock guard(thread_lock);
|
||||
msg["access_num"]=(int)msg["access_num"]-1;
|
||||
@ -232,9 +247,12 @@ void WebSocketServer::on_open(websocketpp::connection_hdl hdl) {
|
||||
data_msg->msg["wav_name"] = "wav-default-id";
|
||||
data_msg->msg["mode"] = "2pass";
|
||||
data_msg->msg["itn"] = true;
|
||||
data_msg->msg["audio_fs"] = 16000;
|
||||
data_msg->msg["audio_fs"] = 16000; // default is 16k
|
||||
data_msg->msg["access_num"] = 0; // the number of access for this object, when it is 0, we can free it saftly
|
||||
data_msg->msg["is_eof"]=false; // if this connection is closed
|
||||
FUNASR_DEC_HANDLE decoder_handle =
|
||||
FunASRWfstDecoderInit(tpass_handle, ASR_TWO_PASS, global_beam_, lattice_beam_, am_scale_);
|
||||
data_msg->decoder_handle = decoder_handle;
|
||||
data_msg->punc_cache =
|
||||
std::make_shared<std::vector<std::vector<std::string>>>(2);
|
||||
data_msg->strand_ = std::make_shared<asio::io_context::strand>(io_decoder_);
|
||||
@ -261,6 +279,9 @@ void remove_hdl(
|
||||
// finished and avoid access freed tpass_online_handle
|
||||
unique_lock guard_decoder(*(data_msg->thread_lock));
|
||||
if (data_msg->msg["access_num"]==0 && data_msg->msg["is_eof"]==true) {
|
||||
FunWfstDecoderUnloadHwsRes(data_msg->decoder_handle);
|
||||
FunASRWfstDecoderUninit(data_msg->decoder_handle);
|
||||
data_msg->decoder_handle = nullptr;
|
||||
FunTpassOnlineUninit(data_msg->tpass_online_handle);
|
||||
data_msg->tpass_online_handle = nullptr;
|
||||
data_map.erase(hdl);
|
||||
@ -318,7 +339,7 @@ void WebSocketServer::check_and_clean_connection() {
|
||||
data_msg->msg["is_eof"]=true;
|
||||
guard_decoder.unlock();
|
||||
to_remove.push_back(hdl);
|
||||
LOG(INFO)<<"connection is closed: "<<e.what();
|
||||
LOG(INFO)<<"connection is closed.";
|
||||
|
||||
}
|
||||
iter++;
|
||||
@ -425,7 +446,7 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
|
||||
nn_hotwords += " " + pair.first;
|
||||
LOG(INFO) << pair.first << " : " << pair.second;
|
||||
}
|
||||
// FunWfstDecoderLoadHwsRes(msg_data->decoder_handle, fst_inc_wts_, merged_hws_map);
|
||||
FunWfstDecoderLoadHwsRes(msg_data->decoder_handle, fst_inc_wts_, merged_hws_map);
|
||||
|
||||
// nn
|
||||
std::vector<std::vector<float>> new_hotwords_embedding = CompileHotwordEmbedding(tpass_handle, nn_hotwords, ASR_TWO_PASS);
|
||||
@ -477,7 +498,8 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
|
||||
msg_data->msg["itn"],
|
||||
msg_data->msg["audio_fs"],
|
||||
msg_data->msg["wav_format"],
|
||||
std::ref(msg_data->tpass_online_handle)));
|
||||
std::ref(msg_data->tpass_online_handle),
|
||||
std::ref(msg_data->decoder_handle)));
|
||||
msg_data->msg["access_num"]=(int)(msg_data->msg["access_num"])+1;
|
||||
}
|
||||
catch (std::exception const &e)
|
||||
@ -524,7 +546,8 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
|
||||
msg_data->msg["itn"],
|
||||
msg_data->msg["audio_fs"],
|
||||
msg_data->msg["wav_format"],
|
||||
std::ref(msg_data->tpass_online_handle)));
|
||||
std::ref(msg_data->tpass_online_handle),
|
||||
std::ref(msg_data->decoder_handle)));
|
||||
msg_data->msg["access_num"]=(int)(msg_data->msg["access_num"])+1;
|
||||
}
|
||||
}
|
||||
|
||||
@ -60,7 +60,8 @@ typedef struct {
|
||||
FUNASR_HANDLE tpass_online_handle=NULL;
|
||||
std::string online_res = "";
|
||||
std::string tpass_res = "";
|
||||
std::shared_ptr<asio::io_context::strand> strand_; // for data execute in order
|
||||
std::shared_ptr<asio::io_context::strand> strand_; // for data execute in order
|
||||
FUNASR_DEC_HANDLE decoder_handle=NULL;
|
||||
} FUNASR_MESSAGE;
|
||||
|
||||
// See https://wiki.mozilla.org/Security/Server_Side_TLS for more details about
|
||||
@ -123,7 +124,8 @@ class WebSocketServer {
|
||||
bool itn,
|
||||
int audio_fs,
|
||||
std::string wav_format,
|
||||
FUNASR_HANDLE& tpass_online_handle);
|
||||
FUNASR_HANDLE& tpass_online_handle,
|
||||
FUNASR_DEC_HANDLE& decoder_handle);
|
||||
|
||||
void initAsr(std::map<std::string, std::string>& model_path, int thread_num);
|
||||
void on_message(websocketpp::connection_hdl hdl, message_ptr msg);
|
||||
|
||||
@ -72,19 +72,23 @@ void WebSocketServer::do_decoder(const std::vector<char>& buffer,
|
||||
int num_samples = buffer.size(); // the size of the buf
|
||||
|
||||
if (!buffer.empty() && hotwords_embedding.size() > 0) {
|
||||
std::string asr_result;
|
||||
std::string stamp_res;
|
||||
std::string asr_result="";
|
||||
std::string stamp_res="";
|
||||
std::string stamp_sents="";
|
||||
try{
|
||||
FUNASR_RESULT Result = FunOfflineInferBuffer(
|
||||
asr_handle, buffer.data(), buffer.size(), RASR_NONE, NULL,
|
||||
hotwords_embedding, audio_fs, wav_format, itn, decoder_handle);
|
||||
|
||||
asr_result = ((FUNASR_RECOG_RESULT*)Result)->msg; // get decode result
|
||||
stamp_res = ((FUNASR_RECOG_RESULT*)Result)->stamp;
|
||||
FunASRFreeResult(Result);
|
||||
if (Result != NULL){
|
||||
asr_result = FunASRGetResult(Result, 0); // get decode result
|
||||
stamp_res = FunASRGetStamp(Result);
|
||||
stamp_sents = FunASRGetStampSents(Result);
|
||||
FunASRFreeResult(Result);
|
||||
} else{
|
||||
LOG(ERROR) << "FUNASR_RESULT is NULL.";
|
||||
}
|
||||
}catch (std::exception const& e) {
|
||||
LOG(ERROR) << e.what();
|
||||
return;
|
||||
}
|
||||
|
||||
websocketpp::lib::error_code ec;
|
||||
@ -95,6 +99,16 @@ void WebSocketServer::do_decoder(const std::vector<char>& buffer,
|
||||
if(stamp_res != ""){
|
||||
jsonresult["timestamp"] = stamp_res;
|
||||
}
|
||||
if(stamp_sents != ""){
|
||||
try{
|
||||
nlohmann::json json_stamp = nlohmann::json::parse(stamp_sents);
|
||||
jsonresult["stamp_sents"] = json_stamp;
|
||||
}catch (std::exception const &e)
|
||||
{
|
||||
LOG(ERROR)<<e.what();
|
||||
jsonresult["stamp_sents"] = "";
|
||||
}
|
||||
}
|
||||
jsonresult["wav_name"] = wav_name;
|
||||
|
||||
// send the json to client
|
||||
@ -144,7 +158,7 @@ void WebSocketServer::on_open(websocketpp::connection_hdl hdl) {
|
||||
data_msg->msg["wav_format"] = "pcm";
|
||||
data_msg->msg["wav_name"] = "wav-default-id";
|
||||
data_msg->msg["itn"] = true;
|
||||
data_msg->msg["audio_fs"] = 16000;
|
||||
data_msg->msg["audio_fs"] = 16000; // default is 16k
|
||||
data_msg->msg["access_num"] = 0; // the number of access for this object, when it is 0, we can free it saftly
|
||||
data_msg->msg["is_eof"]=false;
|
||||
FUNASR_DEC_HANDLE decoder_handle =
|
||||
@ -227,7 +241,7 @@ void WebSocketServer::check_and_clean_connection() {
|
||||
data_msg->msg["is_eof"]=true;
|
||||
guard_decoder.unlock();
|
||||
to_remove.push_back(hdl);
|
||||
LOG(INFO)<<"connection is closed: "<<e.what();
|
||||
LOG(INFO)<<"connection is closed.";
|
||||
|
||||
}
|
||||
iter++;
|
||||
|
||||
@ -50,6 +50,7 @@ typedef websocketpp::lib::shared_ptr<websocketpp::lib::asio::ssl::context>
|
||||
typedef struct {
|
||||
std::string msg="";
|
||||
std::string stamp="";
|
||||
std::string stamp_sents;
|
||||
std::string tpass_msg="";
|
||||
float snippet_time=0;
|
||||
} FUNASR_RECOG_RESULT;
|
||||
|
||||
@ -40,12 +40,12 @@ make -j 4
|
||||
### Download onnxruntime
|
||||
https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/dep_libs/onnxruntime-win-x64-1.16.1.zip
|
||||
|
||||
Download to d:\ffmpeg-master-latest-win64-gpl-shared
|
||||
Download to d:\onnxruntime-win-x64-1.16.1
|
||||
|
||||
### Download ffmpeg
|
||||
https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/dep_libs/ffmpeg-master-latest-win64-gpl-shared.zip
|
||||
|
||||
Download to d:\onnxruntime-win-x64-1.16.1
|
||||
Download to d:\ffmpeg-master-latest-win64-gpl-shared
|
||||
|
||||
### Download openssl
|
||||
https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/dep_libs/openssl-1.1.1w.zip
|
||||
|
||||
@ -41,12 +41,12 @@ make -j 4
|
||||
### 下载 onnxruntime
|
||||
https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/dep_libs/onnxruntime-win-x64-1.16.1.zip
|
||||
|
||||
下载并解压到 d:/ffmpeg-master-latest-win64-gpl-shared
|
||||
下载并解压到 d:/onnxruntime-win-x64-1.16.1
|
||||
|
||||
### 下载 ffmpeg
|
||||
https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/dep_libs/ffmpeg-master-latest-win64-gpl-shared.zip
|
||||
|
||||
下载并解压到 d:/onnxruntime-win-x64-1.16.1
|
||||
下载并解压到 d:/ffmpeg-master-latest-win64-gpl-shared
|
||||
|
||||
### 编译 openssl
|
||||
https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/dep_libs/openssl-1.1.1w.zip
|
||||
|
||||
@ -39,7 +39,7 @@ html设计可以参考whisper( https://openai.com/research/whisper )
|
||||
FunASR离线文件转写软件包,提供了一款功能强大的语音离线文件转写服务。拥有完整的语音识别链路,结合了语音端点检测、语音识别、标点等模型,可以将几十个小时的长音频与视频识别成带标点的文字,而且支持上百路请求同时进行转写。输出为带标点的文字,含有字级别时间戳,支持ITN与用户自定义热词等。服务端集成有ffmpeg,支持各种音视频格式输入。软件包提供有html、python、c++、java与c#等多种编程语言客户端,用户可以直接使用与进一步开发。
|
||||
|
||||
在线体验:
|
||||
https://101.37.77.25:1335/static/index.html
|
||||
https://121.43.113.106:1335/static/index.html
|
||||
|
||||
安装:
|
||||
|
||||
@ -63,7 +63,7 @@ html设计可以参考whisper( https://openai.com/research/whisper )
|
||||
FunASR实时语音听写软件包,集成了实时版本的语音端点检测模型、语音识别、语音识别、标点预测模型等。采用多模型协同,既可以实时的进行语音转文字,也可以在说话句尾用高精度转写文字修正输出,输出文字带有标点,支持多路请求。依据使用者场景不同,支持实时语音听写服务(online)、非实时一句话转写(offline)与实时与非实时一体化协同(2pass)3种服务模式。软件包提供有html、python、c++、java与c#等多种编程语言客户端,用户可以直接使用与进一步开发。
|
||||
|
||||
在线体验:
|
||||
https://101.37.77.25:1336/static/index.html
|
||||
https://121.43.113.106:1336/static/index.html
|
||||
|
||||
安装:
|
||||
|
||||
@ -94,4 +94,4 @@ npm install
|
||||
# 开发模式
|
||||
npm run dev
|
||||
# 产品模式
|
||||
npm run example
|
||||
npm run example
|
||||
|
||||
@ -43,24 +43,6 @@
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<!-- <div class="line-box">
|
||||
<div></div>
|
||||
</div>
|
||||
|
||||
<div class="ba-box">
|
||||
<ul>
|
||||
<li>
|
||||
Copyright @ 1998 - 2023 Tancant. All Rights Reserved.北科软件公司 版权所有
|
||||
</li>
|
||||
<li>
|
||||
公司地址:武汉市洪山区野芷湖西路16号创意天地10号高层13楼
|
||||
</li>
|
||||
<li>
|
||||
联系电话: 400 862 6126
|
||||
</li>
|
||||
</ul>
|
||||
</div> -->
|
||||
</footer>
|
||||
|
||||
</div>
|
||||
|
||||
@ -130,7 +130,7 @@ export default {
|
||||
{
|
||||
icon: require('./assets/images/lxwj-zxty.png'),
|
||||
title: '在线体验',
|
||||
link: 'https://101.37.77.25:1335/static/index.html'
|
||||
link: 'https://www.funasr.com:1335/static/index.html'
|
||||
},
|
||||
{
|
||||
icon: require('./assets/images/lxwj-az.png'),
|
||||
|
||||
@ -148,7 +148,7 @@ export default {
|
||||
{
|
||||
icon: require('./assets/images/lxwj-zxty.png'),
|
||||
title: '在线体验',
|
||||
link: 'https://101.37.77.25:1336/static/index.html'
|
||||
link: 'https://www.funasr.com:1336/static/index.html'
|
||||
},
|
||||
{
|
||||
icon: require('./assets/images/lxwj-az.png'),
|
||||
|
||||
Loading…
Reference in New Issue
Block a user