torchscripts

This commit is contained in:
游雁 2023-03-02 19:34:46 +08:00
parent 458296d105
commit 905a1f2585
10 changed files with 877 additions and 0 deletions

View File

@ -0,0 +1,70 @@
## Using paraformer with libtorch
### Introduction
- Model comes from [speech_paraformer](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary).
### Steps:
1. Export the model.
- Command: (`Tips`: torch >= 1.11.0 is required.)
```shell
python -m funasr.export.export_model [model_name] [export_dir] [true]
```
`model_name`: the model is to export.
`export_dir`: the dir where the onnx is export.
More details ref to ([export docs](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export))
- `e.g.`, Export model from modelscope
```shell
python -m funasr.export.export_model 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" true
```
- `e.g.`, Export model from local path, the model'name must be `model.pb`.
```shell
python -m funasr.export.export_model '/mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" true
```
2. Install the `torch_paraformer`.
- Build the torch_paraformer `whl`
```shell
git clone https://github.com/alibaba/FunASR.git && cd FunASR
cd funasr/runtime/python/libtorch
python setup.py bdist_wheel
```
- Install the build `whl`
```bash
pip install dist/torch_paraformer-0.0.1-py3-none-any.whl
```
3. Run the demo.
- Model_dir: the model path, which contains `model.torchscripts`, `config.yaml`, `am.mvn`.
- Input: wav formt file, support formats: `str, np.ndarray, List[str]`
- Output: `List[str]`: recognition result.
- Example:
```python
from torch_paraformer import Paraformer
model_dir = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
model = Paraformer(model_dir, batch_size=1)
wav_path = ['/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav']
result = model(wav_path)
print(result)
```
## Speed
EnvironmentIntel(R) Xeon(R) Platinum 8163 CPU @ 2.50GHz
Test [wav, 5.53s, 100 times avg.](https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav)
| Backend | RTF |
|:-------:|:-----------------:|
| Pytorch | 0.110 |
| Onnx | 0.038 |
## Acknowledge

View File

@ -0,0 +1,11 @@
from torch_paraformer import Paraformer
model_dir = "/Users/shixian/code/funasr2/export/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
model_dir = "/Users/shixian/code/funasr2/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
model = Paraformer(model_dir, batch_size=1)
wav_path = ['/Users/shixian/code/funasr2/export/damo/speech_paraformer-tiny-commandword_asr_nat-zh-cn-16k-vocab544-pytorch/example/asr_example.wav']
result = model(wav_path)
print(result)

View File

@ -0,0 +1,43 @@
# -*- encoding: utf-8 -*-
from pathlib import Path
import setuptools
def get_readme():
root_dir = Path(__file__).resolve().parent
readme_path = str(root_dir / 'README.md')
print(readme_path)
with open(readme_path, 'r', encoding='utf-8') as f:
readme = f.read()
return readme
setuptools.setup(
name='torch_paraformer',
version='0.0.1',
platforms="Any",
url="https://github.com/alibaba-damo-academy/FunASR.git",
author="Speech Lab, Alibaba Group, China",
author_email="funasr@list.alibaba-inc.com",
description="FunASR: A Fundamental End-to-End Speech Recognition Toolkit",
license="The MIT License",
long_description=get_readme(),
long_description_content_type='text/markdown',
include_package_data=True,
install_requires=["librosa", "onnxruntime>=1.7.0",
"scipy", "numpy>=1.19.3",
"typeguard", "kaldi-native-fbank",
"PyYAML>=5.1.2"],
packages=['torch_paraformer'],
keywords=[
'funasr,paraformer'
],
classifiers=[
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
],
)

View File

@ -0,0 +1,2 @@
# -*- encoding: utf-8 -*-
from .paraformer_bin import Paraformer

View File

@ -0,0 +1,155 @@
# -*- encoding: utf-8 -*-
import os.path
from pathlib import Path
from typing import List, Union, Tuple
import copy
import librosa
import numpy as np
from .utils.utils import (CharTokenizer, Hypothesis,
TokenIDConverter, get_logger,
read_yaml)
from .utils.postprocess_utils import sentence_postprocess
from .utils.frontend import WavFrontend
from funasr.utils.timestamp_tools import time_stamp_lfr6_pl
logging = get_logger()
import torch
class Paraformer():
def __init__(self, model_dir: Union[str, Path] = None,
batch_size: int = 1,
device_id: Union[str, int] = "-1",
):
if not Path(model_dir).exists():
raise FileNotFoundError(f'{model_dir} does not exist.')
model_file = os.path.join(model_dir, 'model.onnx')
config_file = os.path.join(model_dir, 'config.yaml')
cmvn_file = os.path.join(model_dir, 'am.mvn')
config = read_yaml(config_file)
self.converter = TokenIDConverter(config['token_list'])
self.tokenizer = CharTokenizer()
self.frontend = WavFrontend(
cmvn_file=cmvn_file,
**config['frontend_conf']
)
self.ort_infer = torch.jit.load(model_file)
self.batch_size = batch_size
def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List:
waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
waveform_nums = len(waveform_list)
asr_res = []
for beg_idx in range(0, waveform_nums, self.batch_size):
res = {}
end_idx = min(waveform_nums, beg_idx + self.batch_size)
feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
try:
outputs = self.infer(feats, feats_len)
outs = outputs[0], outputs[1]
am_scores, valid_token_lens = outs[0], outs[1]
if len(outputs) == 4:
# for BiCifParaformer Inference
us_alphas, us_cif_peak = outputs[2], outputs[3]
else:
us_alphas, us_cif_peak = None, None
except:
#logging.warning(traceback.format_exc())
logging.warning("input wav is silence or noise")
preds = ['']
else:
am_scores, valid_token_lens = am_scores.cpu().numpy(), valid_token_lens.cpu().numpy()
preds, raw_token = self.decode(am_scores, valid_token_lens)[0]
res['preds'] = preds
if us_cif_peak is not None:
us_alphas, us_cif_peak = us_alphas.cpu().numpy(), us_cif_peak.cpu().numpy()
timestamp = time_stamp_lfr6_pl(us_alphas, us_cif_peak, copy.copy(raw_token), log=False)
res['timestamp'] = timestamp
asr_res.append(res)
return asr_res
def load_data(self,
wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
def load_wav(path: str) -> np.ndarray:
waveform, _ = librosa.load(path, sr=fs)
return waveform
if isinstance(wav_content, np.ndarray):
return [wav_content]
if isinstance(wav_content, str):
return [load_wav(wav_content)]
if isinstance(wav_content, list):
return [load_wav(path) for path in wav_content]
raise TypeError(
f'The type of {wav_content} is not in [str, np.ndarray, list]')
def extract_feat(self,
waveform_list: List[np.ndarray]
) -> Tuple[np.ndarray, np.ndarray]:
feats, feats_len = [], []
for waveform in waveform_list:
speech, _ = self.frontend.fbank(waveform)
feat, feat_len = self.frontend.lfr_cmvn(speech)
feats.append(feat)
feats_len.append(feat_len)
feats = self.pad_feats(feats, np.max(feats_len))
feats_len = np.array(feats_len).astype(np.int32)
return feats, feats_len
@staticmethod
def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
pad_width = ((0, max_feat_len - cur_len), (0, 0))
return np.pad(feat, pad_width, 'constant', constant_values=0)
feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
feats = np.array(feat_res).astype(np.float32)
return feats
def infer(self, feats: np.ndarray,
feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
outputs = self.ort_infer([feats, feats_len])
return outputs
def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
return [self.decode_one(am_score, token_num)
for am_score, token_num in zip(am_scores, token_nums)]
def decode_one(self,
am_score: np.ndarray,
valid_token_num: int) -> List[str]:
yseq = am_score.argmax(axis=-1)
score = am_score.max(axis=-1)
score = np.sum(score, axis=-1)
# pad with mask tokens to ensure compatibility with sos/eos tokens
# asr_model.sos:1 asr_model.eos:2
yseq = np.array([1] + yseq.tolist() + [2])
hyp = Hypothesis(yseq=yseq, score=score)
# remove sos/eos and get results
last_pos = -1
token_int = hyp.yseq[1:last_pos].tolist()
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x not in (0, 2), token_int))
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
# token = token[:valid_token_num-1]
texts = sentence_postprocess(token)
text = texts[0]
# text = self.tokenizer.tokens2text(token)
return text, token

View File

@ -0,0 +1,191 @@
# -*- encoding: utf-8 -*-
from pathlib import Path
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
import numpy as np
from typeguard import check_argument_types
import kaldi_native_fbank as knf
root_dir = Path(__file__).resolve().parent
logger_initialized = {}
class WavFrontend():
"""Conventional frontend structure for ASR.
"""
def __init__(
self,
cmvn_file: str = None,
fs: int = 16000,
window: str = 'hamming',
n_mels: int = 80,
frame_length: int = 25,
frame_shift: int = 10,
lfr_m: int = 1,
lfr_n: int = 1,
dither: float = 1.0,
**kwargs,
) -> None:
check_argument_types()
opts = knf.FbankOptions()
opts.frame_opts.samp_freq = fs
opts.frame_opts.dither = dither
opts.frame_opts.window_type = window
opts.frame_opts.frame_shift_ms = float(frame_shift)
opts.frame_opts.frame_length_ms = float(frame_length)
opts.mel_opts.num_bins = n_mels
opts.energy_floor = 0
opts.frame_opts.snip_edges = True
opts.mel_opts.debug_mel = False
self.opts = opts
self.lfr_m = lfr_m
self.lfr_n = lfr_n
self.cmvn_file = cmvn_file
if self.cmvn_file:
self.cmvn = self.load_cmvn()
self.fbank_fn = None
self.fbank_beg_idx = 0
self.reset_status()
def fbank(self,
waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
waveform = waveform * (1 << 15)
self.fbank_fn = knf.OnlineFbank(self.opts)
self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
frames = self.fbank_fn.num_frames_ready
mat = np.empty([frames, self.opts.mel_opts.num_bins])
for i in range(frames):
mat[i, :] = self.fbank_fn.get_frame(i)
feat = mat.astype(np.float32)
feat_len = np.array(mat.shape[0]).astype(np.int32)
return feat, feat_len
def fbank_online(self,
waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
waveform = waveform * (1 << 15)
# self.fbank_fn = knf.OnlineFbank(self.opts)
self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
frames = self.fbank_fn.num_frames_ready
mat = np.empty([frames, self.opts.mel_opts.num_bins])
for i in range(self.fbank_beg_idx, frames):
mat[i, :] = self.fbank_fn.get_frame(i)
# self.fbank_beg_idx += (frames-self.fbank_beg_idx)
feat = mat.astype(np.float32)
feat_len = np.array(mat.shape[0]).astype(np.int32)
return feat, feat_len
def reset_status(self):
self.fbank_fn = knf.OnlineFbank(self.opts)
self.fbank_beg_idx = 0
def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
if self.lfr_m != 1 or self.lfr_n != 1:
feat = self.apply_lfr(feat, self.lfr_m, self.lfr_n)
if self.cmvn_file:
feat = self.apply_cmvn(feat)
feat_len = np.array(feat.shape[0]).astype(np.int32)
return feat, feat_len
@staticmethod
def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
LFR_inputs = []
T = inputs.shape[0]
T_lfr = int(np.ceil(T / lfr_n))
left_padding = np.tile(inputs[0], ((lfr_m - 1) // 2, 1))
inputs = np.vstack((left_padding, inputs))
T = T + (lfr_m - 1) // 2
for i in range(T_lfr):
if lfr_m <= T - i * lfr_n:
LFR_inputs.append(
(inputs[i * lfr_n:i * lfr_n + lfr_m]).reshape(1, -1))
else:
# process last LFR frame
num_padding = lfr_m - (T - i * lfr_n)
frame = inputs[i * lfr_n:].reshape(-1)
for _ in range(num_padding):
frame = np.hstack((frame, inputs[-1]))
LFR_inputs.append(frame)
LFR_outputs = np.vstack(LFR_inputs).astype(np.float32)
return LFR_outputs
def apply_cmvn(self, inputs: np.ndarray) -> np.ndarray:
"""
Apply CMVN with mvn data
"""
frame, dim = inputs.shape
means = np.tile(self.cmvn[0:1, :dim], (frame, 1))
vars = np.tile(self.cmvn[1:2, :dim], (frame, 1))
inputs = (inputs + means) * vars
return inputs
def load_cmvn(self,) -> np.ndarray:
with open(self.cmvn_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
means_list = []
vars_list = []
for i in range(len(lines)):
line_item = lines[i].split()
if line_item[0] == '<AddShift>':
line_item = lines[i + 1].split()
if line_item[0] == '<LearnRateCoef>':
add_shift_line = line_item[3:(len(line_item) - 1)]
means_list = list(add_shift_line)
continue
elif line_item[0] == '<Rescale>':
line_item = lines[i + 1].split()
if line_item[0] == '<LearnRateCoef>':
rescale_line = line_item[3:(len(line_item) - 1)]
vars_list = list(rescale_line)
continue
means = np.array(means_list).astype(np.float64)
vars = np.array(vars_list).astype(np.float64)
cmvn = np.array([means, vars])
return cmvn
def load_bytes(input):
middle_data = np.frombuffer(input, dtype=np.int16)
middle_data = np.asarray(middle_data)
if middle_data.dtype.kind not in 'iu':
raise TypeError("'middle_data' must be an array of integers")
dtype = np.dtype('float32')
if dtype.kind != 'f':
raise TypeError("'dtype' must be a floating point type")
i = np.iinfo(middle_data.dtype)
abs_max = 2 ** (i.bits - 1)
offset = i.min + abs_max
array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
return array
def test():
path = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav"
import librosa
cmvn_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/am.mvn"
config_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/config.yaml"
from funasr.runtime.python.onnxruntime.rapid_paraformer.utils.utils import read_yaml
config = read_yaml(config_file)
waveform, _ = librosa.load(path, sr=None)
frontend = WavFrontend(
cmvn_file=cmvn_file,
**config['frontend_conf'],
)
speech, _ = frontend.fbank_online(waveform) #1d, (sample,), numpy
feat, feat_len = frontend.lfr_cmvn(speech) # 2d, (frame, 450), np.float32 -> torch, torch.from_numpy(), dtype, (1, frame, 450)
frontend.reset_status() # clear cache
return feat, feat_len
if __name__ == '__main__':
test()

View File

@ -0,0 +1,240 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import string
import logging
from typing import Any, List, Union
def isChinese(ch: str):
if '\u4e00' <= ch <= '\u9fff' or '\u0030' <= ch <= '\u0039':
return True
return False
def isAllChinese(word: Union[List[Any], str]):
word_lists = []
for i in word:
cur = i.replace(' ', '')
cur = cur.replace('</s>', '')
cur = cur.replace('<s>', '')
word_lists.append(cur)
if len(word_lists) == 0:
return False
for ch in word_lists:
if isChinese(ch) is False:
return False
return True
def isAllAlpha(word: Union[List[Any], str]):
word_lists = []
for i in word:
cur = i.replace(' ', '')
cur = cur.replace('</s>', '')
cur = cur.replace('<s>', '')
word_lists.append(cur)
if len(word_lists) == 0:
return False
for ch in word_lists:
if ch.isalpha() is False and ch != "'":
return False
elif ch.isalpha() is True and isChinese(ch) is True:
return False
return True
# def abbr_dispose(words: List[Any]) -> List[Any]:
def abbr_dispose(words: List[Any], time_stamp: List[List] = None) -> List[Any]:
words_size = len(words)
word_lists = []
abbr_begin = []
abbr_end = []
last_num = -1
ts_lists = []
ts_nums = []
ts_index = 0
for num in range(words_size):
if num <= last_num:
continue
if len(words[num]) == 1 and words[num].encode('utf-8').isalpha():
if num + 1 < words_size and words[
num + 1] == ' ' and num + 2 < words_size and len(
words[num +
2]) == 1 and words[num +
2].encode('utf-8').isalpha():
# found the begin of abbr
abbr_begin.append(num)
num += 2
abbr_end.append(num)
# to find the end of abbr
while True:
num += 1
if num < words_size and words[num] == ' ':
num += 1
if num < words_size and len(
words[num]) == 1 and words[num].encode(
'utf-8').isalpha():
abbr_end.pop()
abbr_end.append(num)
last_num = num
else:
break
else:
break
for num in range(words_size):
if words[num] == ' ':
ts_nums.append(ts_index)
else:
ts_nums.append(ts_index)
ts_index += 1
last_num = -1
for num in range(words_size):
if num <= last_num:
continue
if num in abbr_begin:
if time_stamp is not None:
begin = time_stamp[ts_nums[num]][0]
word_lists.append(words[num].upper())
num += 1
while num < words_size:
if num in abbr_end:
word_lists.append(words[num].upper())
last_num = num
break
else:
if words[num].encode('utf-8').isalpha():
word_lists.append(words[num].upper())
num += 1
if time_stamp is not None:
end = time_stamp[ts_nums[num]][1]
ts_lists.append([begin, end])
else:
word_lists.append(words[num])
if time_stamp is not None and words[num] != ' ':
begin = time_stamp[ts_nums[num]][0]
end = time_stamp[ts_nums[num]][1]
ts_lists.append([begin, end])
begin = end
if time_stamp is not None:
return word_lists, ts_lists
else:
return word_lists
def sentence_postprocess(words: List[Any], time_stamp: List[List] = None):
middle_lists = []
word_lists = []
word_item = ''
ts_lists = []
# wash words lists
for i in words:
word = ''
if isinstance(i, str):
word = i
else:
word = i.decode('utf-8')
if word in ['<s>', '</s>', '<unk>']:
continue
else:
middle_lists.append(word)
# all chinese characters
if isAllChinese(middle_lists):
for i, ch in enumerate(middle_lists):
word_lists.append(ch.replace(' ', ''))
if time_stamp is not None:
ts_lists = time_stamp
# all alpha characters
elif isAllAlpha(middle_lists):
ts_flag = True
for i, ch in enumerate(middle_lists):
if ts_flag and time_stamp is not None:
begin = time_stamp[i][0]
end = time_stamp[i][1]
word = ''
if '@@' in ch:
word = ch.replace('@@', '')
word_item += word
if time_stamp is not None:
ts_flag = False
end = time_stamp[i][1]
else:
word_item += ch
word_lists.append(word_item)
word_lists.append(' ')
word_item = ''
if time_stamp is not None:
ts_flag = True
end = time_stamp[i][1]
ts_lists.append([begin, end])
begin = end
# mix characters
else:
alpha_blank = False
ts_flag = True
begin = -1
end = -1
for i, ch in enumerate(middle_lists):
if ts_flag and time_stamp is not None:
begin = time_stamp[i][0]
end = time_stamp[i][1]
word = ''
if isAllChinese(ch):
if alpha_blank is True:
word_lists.pop()
word_lists.append(ch)
alpha_blank = False
if time_stamp is not None:
ts_flag = True
ts_lists.append([begin, end])
begin = end
elif '@@' in ch:
word = ch.replace('@@', '')
word_item += word
alpha_blank = False
if time_stamp is not None:
ts_flag = False
end = time_stamp[i][1]
elif isAllAlpha(ch):
word_item += ch
word_lists.append(word_item)
word_lists.append(' ')
word_item = ''
alpha_blank = True
if time_stamp is not None:
ts_flag = True
end = time_stamp[i][1]
ts_lists.append([begin, end])
begin = end
else:
raise ValueError('invalid character: {}'.format(ch))
if time_stamp is not None:
word_lists, ts_lists = abbr_dispose(word_lists, ts_lists)
real_word_lists = []
for ch in word_lists:
if ch != ' ':
real_word_lists.append(ch)
sentence = ' '.join(real_word_lists).strip()
return sentence, ts_lists, real_word_lists
else:
word_lists = abbr_dispose(word_lists)
real_word_lists = []
for ch in word_lists:
if ch != ' ':
real_word_lists.append(ch)
sentence = ''.join(word_lists).strip()
return sentence, real_word_lists

View File

@ -0,0 +1,165 @@
# -*- encoding: utf-8 -*-
import functools
import logging
import pickle
from pathlib import Path
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
import numpy as np
import yaml
from typeguard import check_argument_types
import warnings
root_dir = Path(__file__).resolve().parent
logger_initialized = {}
class TokenIDConverter():
def __init__(self, token_list: Union[List, str],
):
check_argument_types()
# self.token_list = self.load_token(token_path)
self.token_list = token_list
self.unk_symbol = token_list[-1]
def get_num_vocabulary_size(self) -> int:
return len(self.token_list)
def ids2tokens(self,
integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
if isinstance(integers, np.ndarray) and integers.ndim != 1:
raise TokenIDConverterError(
f"Must be 1 dim ndarray, but got {integers.ndim}")
return [self.token_list[i] for i in integers]
def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
token2id = {v: i for i, v in enumerate(self.token_list)}
if self.unk_symbol not in token2id:
raise TokenIDConverterError(
f"Unknown symbol '{self.unk_symbol}' doesn't exist in the token_list"
)
unk_id = token2id[self.unk_symbol]
return [token2id.get(i, unk_id) for i in tokens]
class CharTokenizer():
def __init__(
self,
symbol_value: Union[Path, str, Iterable[str]] = None,
space_symbol: str = "<space>",
remove_non_linguistic_symbols: bool = False,
):
check_argument_types()
self.space_symbol = space_symbol
self.non_linguistic_symbols = self.load_symbols(symbol_value)
self.remove_non_linguistic_symbols = remove_non_linguistic_symbols
@staticmethod
def load_symbols(value: Union[Path, str, Iterable[str]] = None) -> Set:
if value is None:
return set()
if isinstance(value, Iterable[str]):
return set(value)
file_path = Path(value)
if not file_path.exists():
logging.warning("%s doesn't exist.", file_path)
return set()
with file_path.open("r", encoding="utf-8") as f:
return set(line.rstrip() for line in f)
def text2tokens(self, line: Union[str, list]) -> List[str]:
tokens = []
while len(line) != 0:
for w in self.non_linguistic_symbols:
if line.startswith(w):
if not self.remove_non_linguistic_symbols:
tokens.append(line[: len(w)])
line = line[len(w):]
break
else:
t = line[0]
if t == " ":
t = "<space>"
tokens.append(t)
line = line[1:]
return tokens
def tokens2text(self, tokens: Iterable[str]) -> str:
tokens = [t if t != self.space_symbol else " " for t in tokens]
return "".join(tokens)
def __repr__(self):
return (
f"{self.__class__.__name__}("
f'space_symbol="{self.space_symbol}"'
f'non_linguistic_symbols="{self.non_linguistic_symbols}"'
f")"
)
class Hypothesis(NamedTuple):
"""Hypothesis data type."""
yseq: np.ndarray
score: Union[float, np.ndarray] = 0
scores: Dict[str, Union[float, np.ndarray]] = dict()
states: Dict[str, Any] = dict()
def asdict(self) -> dict:
"""Convert data to JSON-friendly dict."""
return self._replace(
yseq=self.yseq.tolist(),
score=float(self.score),
scores={k: float(v) for k, v in self.scores.items()},
)._asdict()
def read_yaml(yaml_path: Union[str, Path]) -> Dict:
if not Path(yaml_path).exists():
raise FileExistsError(f'The {yaml_path} does not exist.')
with open(str(yaml_path), 'rb') as f:
data = yaml.load(f, Loader=yaml.Loader)
return data
@functools.lru_cache()
def get_logger(name='torch_paraformer'):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
be directly returned. During initialization, a StreamHandler will always be
added.
Args:
name (str): Logger name.
Returns:
logging.Logger: The expected logger.
"""
logger = logging.getLogger(name)
if name in logger_initialized:
return logger
for logger_name in logger_initialized:
if name.startswith(logger_name):
return logger
formatter = logging.Formatter(
'[%(asctime)s] %(name)s %(levelname)s: %(message)s',
datefmt="%Y/%m/%d %H:%M:%S")
sh = logging.StreamHandler()
sh.setFormatter(formatter)
logger.addHandler(sh)
logger_initialized[name] = True
logger.propagate = False
return logger