Merge pull request #1250 from alibaba-damo-academy/funasr1.0

Funasr1.0
Shi Xian 2024-01-16 11:34:04 +08:00 committed by GitHub
commit eba1fccfa0
20 changed files with 581 additions and 519 deletions

View File

@ -95,9 +95,9 @@ model = AutoModel(model="paraformer-zh", model_revision="v2.0.2", \
vad_model="fsmn-vad", vad_model_revision="v2.0.2", \
punc_model="ct-punc-c", punc_model_revision="v2.0.2", \
spk_model="cam++", spk_model_revision="v2.0.2")
res = model(input=f"{model.model_path}/example/asr_example.wav",
batch_size=64,
hotword='魔搭')
res = model.generate(input=f"{model.model_path}/example/asr_example.wav",
batch_size=64,
hotword='魔搭')
print(res)
```
Note: `model_hub` specifies the model repository; `ms` selects download from ModelScope and `hf` selects download from Hugging Face.
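For orientation, a minimal sketch of reading the returned list; the field names (`key`, `text`, `text_with_punc`, `sentence_info`, `start`, `end`, `sentence`) follow the result dict assembled in `funasr/auto/auto_model.py` later in this commit:
```python
# res is a list with one dict per input (see AutoModel.inference_with_vad in this commit)
for r in res:
    print(r["key"], r.get("text_with_punc", r["text"]))
    for sent in r.get("sentence_info", []):   # present when spk_model is configured
        print(sent["start"], sent["end"], sent["sentence"])
```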
@ -124,7 +124,7 @@ total_chunk_num = int(len((speech)-1)/chunk_stride+1)
for i in range(total_chunk_num):
speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
is_final = i == total_chunk_num - 1
res = model(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back, decoder_chunk_look_back=decoder_chunk_look_back)
res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back, decoder_chunk_look_back=decoder_chunk_look_back)
print(res)
```
Note: `chunk_size` is the streaming latency configuration. `[0,10,5]` means the real-time output granularity is `10*60=600ms` and the lookahead window is `5*60=300ms`. Each inference call takes `600ms` of audio as input (`16000*0.6=9600` sample points) and outputs the corresponding text. For the last speech chunk, set `is_final=True` to force the output of the final word.
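For reference, a self-contained sketch of the streaming loop around the snippet above. It is a sketch only: `soundfile` and the placeholder `asr_example_16k.wav` path are assumptions, the online Paraformer model name is the one referenced elsewhere in this commit, and `encoder_chunk_look_back = 4` is an example value.
```python
import soundfile
from funasr import AutoModel

chunk_size = [0, 10, 5]        # 600 ms output granularity, 300 ms lookahead
encoder_chunk_look_back = 4    # encoder chunks to look back for self-attention (example value)
decoder_chunk_look_back = 1    # encoder chunks to look back for decoder cross-attention

model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
                  model_revision="v2.0.2")

speech, sample_rate = soundfile.read("asr_example_16k.wav")  # any 16 kHz mono wav (placeholder path)
chunk_stride = chunk_size[1] * 960                           # 10 * 60 ms * 16 kHz = 9600 samples per chunk

cache = {}
total_chunk_num = int((len(speech) - 1) / chunk_stride + 1)
for i in range(total_chunk_num):
    speech_chunk = speech[i * chunk_stride:(i + 1) * chunk_stride]
    is_final = i == total_chunk_num - 1
    res = model.generate(input=speech_chunk, cache=cache, is_final=is_final,
                         chunk_size=chunk_size,
                         encoder_chunk_look_back=encoder_chunk_look_back,
                         decoder_chunk_look_back=decoder_chunk_look_back)
    print(res)
```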
@ -135,7 +135,7 @@ from funasr import AutoModel
model = AutoModel(model="fsmn-vad", model_revision="v2.0.2")
wav_file = f"{model.model_path}/example/asr_example.wav"
res = model(input=wav_file)
res = model.generate(input=wav_file)
print(res)
```
### Voice Activity Detection (Non-streaming)
@ -156,7 +156,7 @@ total_chunk_num = int(len((speech)-1)/chunk_stride+1)
for i in range(total_chunk_num):
speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
is_final = i == total_chunk_num - 1
res = model(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size)
res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size)
if len(res[0]["value"]):
print(res)
```
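A short sketch of collecting the streaming VAD output into segments. The `[start_ms, end_ms]` layout of `res[0]["value"]` and the use of `-1` for a boundary that has not been decided yet are assumptions, not stated in this diff:
```python
segments = []       # finished [start_ms, end_ms] pairs
open_start = None   # start of a segment whose end has not arrived yet
for i in range(total_chunk_num):
    speech_chunk = speech[i * chunk_stride:(i + 1) * chunk_stride]
    is_final = i == total_chunk_num - 1
    res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size)
    for beg, end in res[0]["value"]:
        if beg != -1 and end != -1:      # segment fully contained in this chunk
            segments.append([beg, end])
        elif beg != -1:                  # segment opened, end still pending
            open_start = beg
        elif open_start is not None:     # pending segment closed
            segments.append([open_start, end])
            open_start = None
print(segments)
```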
@ -165,7 +165,7 @@ for i in range(total_chunk_num):
from funasr import AutoModel
model = AutoModel(model="ct-punc", model_revision="v2.0.2")
res = model(input="那今天的会就到这里吧 happy new year 明年见")
res = model.generate(input="那今天的会就到这里吧 happy new year 明年见")
print(res)
```
### Timestamp Prediction
@ -175,7 +175,7 @@ from funasr import AutoModel
model = AutoModel(model="fa-zh", model_revision="v2.0.2")
wav_file = f"{model.model_path}/example/asr_example.wav"
text_file = f"{model.model_path}/example/text.txt"
res = model(input=(wav_file, text_file), data_type=("sound", "text"))
res = model.generate(input=(wav_file, text_file), data_type=("sound", "text"))
print(res)
```
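A small sketch of reading the prediction; the `timestamp` field with `[start_ms, end_ms]` pairs per token is an assumption based on the timestamp handling in `funasr/auto/auto_model.py` later in this commit:
```python
# res[0] is assumed to carry a "timestamp" list with one [start_ms, end_ms] pair per token
for start_ms, end_ms in res[0]["timestamp"]:
    print(f"{start_ms}-{end_ms} ms")
```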
[//]: # (FunASR supports inference and fine-tuning of models trained on industrial datasets of tens of thousands of hours. For more details, please refer to ([modelscope_egs](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_pipeline/quick_start.html)). It also supports training and fine-tuning of models on academic standard datasets. For more details, please refer to([egs](https://alibaba-damo-academy.github.io/FunASR/en/academic_recipe/asr_recipe.html)). The models include speech recognition (ASR), speech activity detection (VAD), punctuation recovery, language model, speaker verification, speaker separation, and multi-party conversation speech recognition. For a detailed list of models, please refer to the [Model Zoo](https://github.com/alibaba-damo-academy/FunASR/blob/main/docs/model_zoo/modelscope_models.md):)
@ -229,10 +229,16 @@ The use of pretrained models is subject to the [model license](./MODEL_LICENSE)
}
@inproceedings{gao22b_interspeech,
author={Zhifu Gao and ShiLiang Zhang and Ian McLoughlin and Zhijie Yan},
title={{Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition}},
title={Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition},
year=2022,
booktitle={Proc. Interspeech 2022},
pages={2063--2067},
doi={10.21437/Interspeech.2022-9996}
}
@inproceedings{shi2023seaco,
author={Xian Shi and Yexin Yang and Zerui Li and Yanni Chen and Zhifu Gao and Shiliang Zhang},
title={SeACo-Paraformer: A Non-Autoregressive ASR System with Flexible and Effective Hotword Customization Ability},
year={2023},
booktitle={ICASSP 2024}
}
```

View File

@ -91,7 +91,7 @@ model = AutoModel(model="paraformer-zh", model_revision="v2.0.2", \
vad_model="fsmn-vad", vad_model_revision="v2.0.2", \
punc_model="ct-punc-c", punc_model_revision="v2.0.2", \
spk_model="cam++", spk_model_revision="v2.0.2")
res = model(input=f"{model.model_path}/example/asr_example.wav",
res = model.generate(input=f"{model.model_path}/example/asr_example.wav",
batch_size=64,
hotword='魔搭')
print(res)
@ -121,7 +121,7 @@ total_chunk_num = int(len((speech)-1)/chunk_stride+1)
for i in range(total_chunk_num):
speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
is_final = i == total_chunk_num - 1
res = model(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back, decoder_chunk_look_back=decoder_chunk_look_back)
res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back, decoder_chunk_look_back=decoder_chunk_look_back)
print(res)
```
@ -134,7 +134,7 @@ from funasr import AutoModel
model = AutoModel(model="fsmn-vad", model_revision="v2.0.2")
wav_file = f"{model.model_path}/example/asr_example.wav"
res = model(input=wav_file)
res = model.generate(input=wav_file)
print(res)
```
@ -156,7 +156,7 @@ total_chunk_num = int(len((speech)-1)/chunk_stride+1)
for i in range(total_chunk_num):
speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
is_final = i == total_chunk_num - 1
res = model(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size)
res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size)
if len(res[0]["value"]):
print(res)
```
@ -167,7 +167,7 @@ from funasr import AutoModel
model = AutoModel(model="ct-punc", model_revision="v2.0.2")
res = model(input="那今天的会就到这里吧 happy new year 明年见")
res = model.generate(input="那今天的会就到这里吧 happy new year 明年见")
print(res)
```
@ -179,7 +179,7 @@ model = AutoModel(model="fa-zh", model_revision="v2.0.0")
wav_file = f"{model.model_path}/example/asr_example.wav"
text_file = f"{model.model_path}/example/text.txt"
res = model(input=(wav_file, text_file), data_type=("sound", "text"))
res = model.generate(input=(wav_file, text_file), data_type=("sound", "text"))
print(res)
```
More detailed usage ([examples](examples/industrial_data_pretraining))
@ -242,4 +242,10 @@ FunASR supports service deployment of pre-trained or further fine-tuned models. Currently
pages={2063--2067},
doi={10.21437/Interspeech.2022-9996}
}
@article{shi2023seaco,
author={Xian Shi and Yexin Yang and Zerui Li and Yanni Chen and Zhifu Gao and Shiliang Zhang},
title={{SeACo-Paraformer: A Non-Autoregressive ASR System with Flexible and Effective Hotword Customization Ability}},
year=2023,
journal={arXiv preprint arXiv:2308.03266 (accepted by ICASSP 2024)},
}
```

View File

@ -6,14 +6,14 @@
from funasr import AutoModel
model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
model_revision="v2.0.2",
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
vad_model_revision="v2.0.2",
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
punc_model_revision="v2.0.2",
spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
spk_model_revision="v2.0.2",
model_revision="v2.0.2",
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
vad_model_revision="v2.0.2",
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
punc_model_revision="v2.0.2",
spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
spk_model_revision="v2.0.2",
)
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_vad_punc_example.wav", batch_size_s=300, batch_size_threshold_s=60)
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_vad_punc_example.wav", batch_size_s=300, batch_size_threshold_s=60)
print(res)
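# Sketch of reading the speaker-attributed output (illustrative only): `sentence_info`, `start`,
# `end` and `sentence` are the fields assembled in funasr/auto/auto_model.py in this commit;
# the per-sentence speaker field name `spk` is an assumption about what distribute_spk attaches.
for utt in res:
    for sent in utt.get("sentence_info", []):
        print(f'[{sent["start"]}-{sent["end"]}] spk {sent.get("spk", "?")}: {sent["sentence"]}')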

View File

@ -9,5 +9,5 @@ model = AutoModel(model="damo/speech_campplus_sv_zh-cn_16k-common",
model_revision="v2.0.2",
)
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
print(res)

View File

@ -7,6 +7,6 @@ from funasr import AutoModel
model = AutoModel(model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404", model_revision="v2.0.2")
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
hotword='达摩院 魔搭')
print(res)

View File

@ -7,7 +7,7 @@ from funasr import AutoModel
model = AutoModel(model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", model_revision="v2.0.2")
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
print(res)
@ -15,5 +15,5 @@ from funasr import AutoModel
model = AutoModel(model="damo/punc_ct-transformer_cn-en-common-vocab471067-large", model_revision="v2.0.2")
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
print(res)

View File

@ -12,7 +12,7 @@ vads = inputs.split("|")
rec_result_all = "outputs: "
cache = {}
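# cache carries punctuation state from one '|'-separated piece to the next, so the VAD-split
# pieces are punctuated as one continuous transcript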
for vad in vads:
rec_result = model(input=vad, cache=cache)
rec_result = model.generate(input=vad, cache=cache)
print(rec_result)
rec_result_all += rec_result[0]['text']

View File

@ -7,5 +7,5 @@ from funasr import AutoModel
model = AutoModel(model="damo/emotion2vec_base", model_revision="v2.0.1")
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", output_dir="./outputs")
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", output_dir="./outputs")
print(res)

View File

@ -9,7 +9,7 @@ wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audi
chunk_size = 60000 # ms
model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.2")
res = model(input=wav_file, chunk_size=chunk_size, )
res = model.generate(input=wav_file, chunk_size=chunk_size, )
print(res)
@ -28,7 +28,7 @@ total_chunk_num = int(len((speech)-1)/chunk_stride+1)
for i in range(total_chunk_num):
speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
is_final = i == total_chunk_num - 1
res = model(input=speech_chunk,
res = model.generate(input=speech_chunk,
cache=cache,
is_final=is_final,
chunk_size=chunk_size,

View File

@ -7,7 +7,7 @@ from funasr import AutoModel
model = AutoModel(model="damo/speech_timestamp_prediction-v1-16k-offline", model_revision="v2.0.2")
res = model(input=("https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
res = model.generate(input=("https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
"欢迎大家来到魔搭社区进行体验"),
data_type=("sound", "text"),
batch_size=2,

View File

@ -15,6 +15,6 @@ model = AutoModel(model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-co
spk_model_revision="v2.0.2"
)
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
hotword='达摩院 磨搭')
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
hotword='达摩院 磨搭')
print(res)

View File

@ -7,7 +7,7 @@ from funasr import AutoModel
model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.2")
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
print(res)
@ -18,5 +18,5 @@ frontend = AutoFrontend(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-co
fbanks = frontend(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", batch_size=2)
for batch_idx, fbank_dict in enumerate(fbanks):
res = model(**fbank_dict)
res = model.generate(**fbank_dict)
print(res)

View File

@ -11,7 +11,7 @@ decoder_chunk_look_back = 1 #number of encoder chunks to lookback for decoder cr
model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", model_revision="v2.0.2")
cache = {}
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
chunk_size=chunk_size,
encoder_chunk_look_back=encoder_chunk_look_back,
decoder_chunk_look_back=decoder_chunk_look_back,
@ -32,11 +32,11 @@ total_chunk_num = int(len((speech)-1)/chunk_stride+1)
for i in range(total_chunk_num):
speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
is_final = i == total_chunk_num - 1
res = model(input=speech_chunk,
cache=cache,
is_final=is_final,
chunk_size=chunk_size,
encoder_chunk_look_back=encoder_chunk_look_back,
decoder_chunk_look_back=decoder_chunk_look_back,
)
res = model.generate(input=speech_chunk,
cache=cache,
is_final=is_final,
chunk_size=chunk_size,
encoder_chunk_look_back=encoder_chunk_look_back,
decoder_chunk_look_back=decoder_chunk_look_back,
)
print(res)

View File

@ -15,6 +15,6 @@ model = AutoModel(model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-co
spk_model_revision="v2.0.2",
)
res = model(input=f"{model.model_path}/example/asr_example.wav",
hotword='达摩院 魔搭')
res = model.generate(input=f"{model.model_path}/example/asr_example.wav",
hotword='达摩院 魔搭')
print(res)

View File

@ -30,4 +30,5 @@ def import_submodules(package, recursive=True):
import_submodules(__name__)
from funasr.bin.inference import AutoModel, AutoFrontend
from funasr.auto.auto_model import AutoModel
from funasr.auto.auto_frontend import AutoFrontend
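# User-facing imports are unchanged by this refactor; a quick sanity check (model name taken
# from the README examples in this commit):
#   from funasr import AutoModel, AutoFrontend
#   model = AutoModel(model="paraformer-zh", model_revision="v2.0.2")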

0 funasr/auto/__init__.py Normal file
View File

View File

@ -0,0 +1,95 @@
import json
import time
import torch
import hydra
import random
import string
import logging
import os.path
from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf, ListConfig
from funasr.register import tables
from funasr.utils.load_utils import load_bytes
from funasr.download.file import download_from_url
from funasr.download.download_from_hub import download_model
from funasr.utils.vad_utils import slice_padding_audio_samples
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils.timestamp_tools import timestamp_sentence
from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
from funasr.models.campplus.cluster_backend import ClusterBackend
from funasr.auto.auto_model import prepare_data_iterator
class AutoFrontend:
def __init__(self, **kwargs):
assert "model" in kwargs
if "model_conf" not in kwargs:
logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
kwargs = download_model(**kwargs)
# build frontend
frontend = kwargs.get("frontend", None)
if frontend is not None:
frontend_class = tables.frontend_classes.get(frontend)
frontend = frontend_class(**kwargs["frontend_conf"])
self.frontend = frontend
if "frontend" in kwargs:
del kwargs["frontend"]
self.kwargs = kwargs
def __call__(self, input, input_len=None, kwargs=None, **cfg):
kwargs = self.kwargs if kwargs is None else kwargs
kwargs.update(cfg)
key_list, data_list = prepare_data_iterator(input, input_len=input_len)
batch_size = kwargs.get("batch_size", 1)
device = kwargs.get("device", "cpu")
if device == "cpu":
batch_size = 1
meta_data = {}
result_list = []
num_samples = len(data_list)
pbar = tqdm(colour="blue", total=num_samples + 1, dynamic_ncols=True)
time0 = time.perf_counter()
for beg_idx in range(0, num_samples, batch_size):
end_idx = min(num_samples, beg_idx + batch_size)
data_batch = data_list[beg_idx:end_idx]
key_batch = key_list[beg_idx:end_idx]
# extract fbank feats
time1 = time.perf_counter()
audio_sample_list = load_audio_text_image_video(data_batch, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
frontend=self.frontend, **kwargs)
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
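# batch audio duration in seconds: total frames * frame shift (ms per frame) * LFR factor / 1000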
meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
speech, speech_lengths = speech.to(device=device), speech_lengths.to(device=device)  # move features to the target device
batch = {"input": speech, "input_len": speech_lengths, "key": key_batch}
result_list.append(batch)
pbar.update(1)
description = (
f"{meta_data}, "
)
pbar.set_description(description)
time_end = time.perf_counter()
pbar.set_description(f"time elapsed total: {time_end - time0:0.3f}")
return result_list
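# A minimal usage sketch (mirrors the AutoFrontend demo elsewhere in this commit): AutoFrontend
# pre-computes fbank batches that can be fed straight into AutoModel.generate.
#
#   from funasr import AutoModel, AutoFrontend
#   model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
#                     model_revision="v2.0.2")
#   frontend = AutoFrontend(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
#                           model_revision="v2.0.2")
#   fbanks = frontend(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
#                     batch_size=2)
#   for fbank_dict in fbanks:
#       res = model.generate(**fbank_dict)
#       print(res)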

416 funasr/auto/auto_model.py Normal file
View File

@ -0,0 +1,416 @@
import json
import time
import torch
import hydra
import random
import string
import logging
import os.path
from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf, ListConfig
from funasr.register import tables
from funasr.utils.load_utils import load_bytes
from funasr.download.file import download_from_url
from funasr.download.download_from_hub import download_model
from funasr.utils.vad_utils import slice_padding_audio_samples
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils.timestamp_tools import timestamp_sentence
from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
from funasr.models.campplus.cluster_backend import ClusterBackend
def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
"""Normalize heterogeneous inputs into parallel (key_list, data_list) lists.
:param data_in: a wav/text path, URL, list file (wav.scp / .txt / .jsonl), raw bytes, or a list/tuple of such inputs
:param input_len: optional length of the input
:param data_type: data type(s) of the input, e.g. ("sound", "text") for paired inputs
:param key: optional key to use for a single input
:return: (key_list, data_list)
"""
data_list = []
key_list = []
filelist = [".scp", ".txt", ".json", ".jsonl"]
chars = string.ascii_letters + string.digits
if isinstance(data_in, str) and data_in.startswith('http'): # url
data_in = download_from_url(data_in)
if isinstance(data_in, str) and os.path.exists(data_in): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
_, file_extension = os.path.splitext(data_in)
file_extension = file_extension.lower()
if file_extension in filelist: #filelist: wav.scp, file.jsonl;text.txt;
with open(data_in, encoding='utf-8') as fin:
for line in fin:
key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
if data_in.endswith(".jsonl"): #file.jsonl: json.dumps({"source": data})
lines = json.loads(line.strip())
data = lines["source"]
key = data["key"] if "key" in data else key
else: # filelist, wav.scp, text.txt: id \t data or data
lines = line.strip().split(maxsplit=1)
data = lines[1] if len(lines)>1 else lines[0]
key = lines[0] if len(lines)>1 else key
data_list.append(data)
key_list.append(key)
else:
key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
data_list = [data_in]
key_list = [key]
elif isinstance(data_in, (list, tuple)):
if data_type is not None and isinstance(data_type, (list, tuple)): # multiple inputs
data_list_tmp = []
for data_in_i, data_type_i in zip(data_in, data_type):
key_list, data_list_i = prepare_data_iterator(data_in=data_in_i, data_type=data_type_i)
data_list_tmp.append(data_list_i)
data_list = []
for item in zip(*data_list_tmp):
data_list.append(item)
else:
# [audio sample point, fbank, text]
data_list = data_in
key_list = ["rand_key_" + ''.join(random.choice(chars) for _ in range(13)) for _ in range(len(data_in))]
else: # raw text; audio sample point, fbank; bytes
if isinstance(data_in, bytes): # audio bytes
data_in = load_bytes(data_in)
if key is None:
key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
data_list = [data_in]
key_list = [key]
return key_list, data_list
class AutoModel:
def __init__(self, **kwargs):
tables.print()
model, kwargs = self.build_model(**kwargs)
# if vad_model is not None, build vad model else None
vad_model = kwargs.get("vad_model", None)
vad_kwargs = kwargs.get("vad_model_revision", None)
if vad_model is not None:
logging.info("Building VAD model.")
vad_kwargs = {"model": vad_model, "model_revision": vad_kwargs}
vad_model, vad_kwargs = self.build_model(**vad_kwargs)
# if punc_model is not None, build punc model else None
punc_model = kwargs.get("punc_model", None)
punc_kwargs = kwargs.get("punc_model_revision", None)
if punc_model is not None:
logging.info("Building punc model.")
punc_kwargs = {"model": punc_model, "model_revision": punc_kwargs}
punc_model, punc_kwargs = self.build_model(**punc_kwargs)
# if spk_model is not None, build spk model else None
spk_model = kwargs.get("spk_model", None)
spk_kwargs = kwargs.get("spk_model_revision", None)
if spk_model is not None:
logging.info("Building SPK model.")
spk_kwargs = {"model": spk_model, "model_revision": spk_kwargs}
spk_model, spk_kwargs = self.build_model(**spk_kwargs)
self.cb_model = ClusterBackend()
spk_mode = kwargs.get("spk_mode", 'punc_segment')
if spk_mode not in ["default", "vad_segment", "punc_segment"]:
logging.error("spk_mode should be one of default, vad_segment and punc_segment.")
self.spk_mode = spk_mode
self.preset_spk_num = kwargs.get("preset_spk_num", None)
if self.preset_spk_num:
logging.warning("Using preset speaker number: {}".format(self.preset_spk_num))
logging.warning("Verbose output is expected when using the speaker model...")
self.kwargs = kwargs
self.model = model
self.vad_model = vad_model
self.vad_kwargs = vad_kwargs
self.punc_model = punc_model
self.punc_kwargs = punc_kwargs
self.spk_model = spk_model
self.spk_kwargs = spk_kwargs
self.model_path = kwargs["model_path"]
def build_model(self, **kwargs):
assert "model" in kwargs
if "model_conf" not in kwargs:
logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
kwargs = download_model(**kwargs)
set_all_random_seed(kwargs.get("seed", 0))
device = kwargs.get("device", "cuda")
if not torch.cuda.is_available() or kwargs.get("ngpu", 0):
device = "cpu"
# kwargs["batch_size"] = 1
kwargs["device"] = device
if kwargs.get("ncpu", None):
torch.set_num_threads(kwargs.get("ncpu"))
# build tokenizer
tokenizer = kwargs.get("tokenizer", None)
if tokenizer is not None:
tokenizer_class = tables.tokenizer_classes.get(tokenizer)
tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
kwargs["tokenizer"] = tokenizer
kwargs["token_list"] = tokenizer.token_list
vocab_size = len(tokenizer.token_list)
else:
vocab_size = -1
# build frontend
frontend = kwargs.get("frontend", None)
if frontend is not None:
frontend_class = tables.frontend_classes.get(frontend)
frontend = frontend_class(**kwargs["frontend_conf"])
kwargs["frontend"] = frontend
kwargs["input_size"] = frontend.output_size()
# build model
model_class = tables.model_classes.get(kwargs["model"])
model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
model.eval()
model.to(device)
# init_param
init_param = kwargs.get("init_param", None)
if init_param is not None:
logging.info(f"Loading pretrained params from {init_param}")
load_pretrained_model(
model=model,
init_param=init_param,
ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
oss_bucket=kwargs.get("oss_bucket", None),
)
return model, kwargs
def __call__(self, *args, **cfg):
kwargs = self.kwargs
kwargs.update(cfg)
res = self.model(*args, kwargs)
return res
def generate(self, input, input_len=None, **cfg):
if self.vad_model is None:
return self.inference(input, input_len=input_len, **cfg)
else:
return self.inference_with_vad(input, input_len=input_len, **cfg)
def inference(self, input, input_len=None, model=None, kwargs=None, key=None, **cfg):
kwargs = self.kwargs if kwargs is None else kwargs
kwargs.update(cfg)
model = self.model if model is None else model
batch_size = kwargs.get("batch_size", 1)
# if kwargs.get("device", "cpu") == "cpu":
# batch_size = 1
key_list, data_list = prepare_data_iterator(input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key)
speed_stats = {}
asr_result_list = []
num_samples = len(data_list)
pbar = tqdm(colour="blue", total=num_samples+1, dynamic_ncols=True)
time_speech_total = 0.0
time_escape_total = 0.0
for beg_idx in range(0, num_samples, batch_size):
end_idx = min(num_samples, beg_idx + batch_size)
data_batch = data_list[beg_idx:end_idx]
key_batch = key_list[beg_idx:end_idx]
batch = {"data_in": data_batch, "key": key_batch}
if (end_idx - beg_idx) == 1 and isinstance(data_batch[0], torch.Tensor): # fbank
batch["data_in"] = data_batch[0]
batch["data_lengths"] = input_len
time1 = time.perf_counter()
with torch.no_grad():
results, meta_data = model.inference(**batch, **kwargs)
time2 = time.perf_counter()
asr_result_list.extend(results)
pbar.update(1)
# batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
batch_data_time = meta_data.get("batch_data_time", -1)
time_escape = time2 - time1
speed_stats["load_data"] = meta_data.get("load_data", 0.0)
speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
speed_stats["forward"] = f"{time_escape:0.3f}"
speed_stats["batch_size"] = f"{len(results)}"
speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}"
description = (
f"{speed_stats}, "
)
pbar.set_description(description)
time_speech_total += batch_data_time
time_escape_total += time_escape
pbar.update(1)
pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
torch.cuda.empty_cache()
return asr_result_list
def inference_with_vad(self, input, input_len=None, **cfg):
# step.1: compute the vad model
self.vad_kwargs.update(cfg)
beg_vad = time.time()
res = self.inference(input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg)
end_vad = time.time()
print(f"time cost vad: {end_vad - beg_vad:0.3f}")
# step.2 compute asr model
model = self.model
kwargs = self.kwargs
kwargs.update(cfg)
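# dynamic batching by audio duration: batch_size_s seconds of speech per batch, converted to ms here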
batch_size = int(kwargs.get("batch_size_s", 300))*1000
batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60))*1000
kwargs["batch_size"] = batch_size
key_list, data_list = prepare_data_iterator(input, input_len=input_len, data_type=kwargs.get("data_type", None))
results_ret_list = []
time_speech_total_all_samples = 0.0
beg_total = time.time()
pbar_total = tqdm(colour="red", total=len(res) + 1, dynamic_ncols=True)
for i in range(len(res)):
key = res[i]["key"]
vadsegments = res[i]["value"]
input_i = data_list[i]
speech = load_audio_text_image_video(input_i, fs=kwargs["frontend"].fs, audio_fs=kwargs.get("fs", 16000))
speech_lengths = len(speech)
n = len(vadsegments)
data_with_index = [(vadsegments[i], i) for i in range(n)]
sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
results_sorted = []
if not len(sorted_data):
logging.info("decoding, utt: {}, empty speech".format(key))
continue
if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
batch_size = max(batch_size, sorted_data[0][0][1] - sorted_data[0][0][0])
batch_size_ms_cum = 0
beg_idx = 0
beg_asr_total = time.time()
time_speech_total_per_sample = speech_lengths/16000
time_speech_total_all_samples += time_speech_total_per_sample
for j, _ in enumerate(range(0, n)):
batch_size_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
if j < n - 1 and (
batch_size_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size and (
sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size_threshold_ms:
continue
batch_size_ms_cum = 0
end_idx = j + 1
speech_j, speech_lengths_j = slice_padding_audio_samples(speech, speech_lengths, sorted_data[beg_idx:end_idx])
results = self.inference(speech_j, input_len=None, model=model, kwargs=kwargs, **cfg)
if self.spk_model is not None:
all_segments = []
# compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
for _b in range(len(speech_j)):
vad_segments = [[sorted_data[beg_idx:end_idx][_b][0][0]/1000.0, \
sorted_data[beg_idx:end_idx][_b][0][1]/1000.0, \
speech_j[_b]]]
segments = sv_chunk(vad_segments)
all_segments.extend(segments)
speech_b = [i[2] for i in segments]
spk_res = self.inference(speech_b, input_len=None, model=self.spk_model, kwargs=kwargs, **cfg)
results[_b]['spk_embedding'] = spk_res[0]['spk_embedding']
beg_idx = end_idx
if len(results) < 1:
continue
results_sorted.extend(results)
pbar_total.update(1)
end_asr_total = time.time()
time_escape_total_per_sample = end_asr_total - beg_asr_total
pbar_total.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
restored_data = [0] * n
for j in range(n):
index = sorted_data[j][1]
restored_data[index] = results_sorted[j]
result = {}
# results combine for texts, timestamps, speaker embeddings and others
# TODO: rewrite for clean code
for j in range(n):
for k, v in restored_data[j].items():
if k.startswith("timestamp"):
if k not in result:
result[k] = []
for t in restored_data[j][k]:
t[0] += vadsegments[j][0]
t[1] += vadsegments[j][0]
result[k].extend(restored_data[j][k])
elif k == 'spk_embedding':
if k not in result:
result[k] = restored_data[j][k]
else:
result[k] = torch.cat([result[k], restored_data[j][k]], dim=0)
elif k == 'text':
if k not in result:
result[k] = restored_data[j][k]
else:
result[k] += " " + restored_data[j][k]
else:
if k not in result:
result[k] = restored_data[j][k]
else:
result[k] += restored_data[j][k]
# step.3 compute punc model
if self.punc_model is not None:
self.punc_kwargs.update(cfg)
punc_res = self.inference(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg)
result["text_with_punc"] = punc_res[0]["text"]
# speaker embedding cluster after resorted
if self.spk_model is not None:
all_segments = sorted(all_segments, key=lambda x: x[0])
spk_embedding = result['spk_embedding']
labels = self.cb_model(spk_embedding, oracle_num=self.preset_spk_num)
del result['spk_embedding']
sv_output = postprocess(all_segments, None, labels, spk_embedding.cpu())
if self.spk_mode == 'vad_segment':
sentence_list = []
for res, vadsegment in zip(restored_data, vadsegments):
sentence_list.append({"start": vadsegment[0],\
"end": vadsegment[1],
"sentence": res['text'],
"timestamp": res['timestamp']})
else: # punc_segment
sentence_list = timestamp_sentence(punc_res[0]['punc_array'], \
result['timestamp'], \
result['text'])
distribute_spk(sentence_list, sv_output)
result['sentence_info'] = sentence_list
result["key"] = key
results_ret_list.append(result)
pbar_total.update(1)
pbar_total.update(1)
end_total = time.time()
time_escape_total_all_samples = end_total - beg_total
pbar_total.set_description(f"rtf_avg_all_samples: {time_escape_total_all_samples / time_speech_total_all_samples:0.3f}, "
f"time_speech_total_all_samples: {time_speech_total_all_samples: 0.3f}, "
f"time_escape_total_all_samples: {time_escape_total_all_samples:0.3f}")
return results_ret_list
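# Illustrative sketch (not part of the class above) of the duration-based batching used in
# inference_with_vad: VAD segments are sorted by length, then grouped until a batch holds roughly
# batch_size_s seconds of audio, while any segment longer than batch_size_threshold_s seconds is
# decoded on its own.
def group_segments_by_duration(segments_ms, batch_size_s=300, batch_size_threshold_s=60):
    """segments_ms: list of [start_ms, end_ms]; returns batches as lists of indices into the sorted order."""
    budget_ms = batch_size_s * 1000
    threshold_ms = batch_size_threshold_s * 1000
    order = sorted(range(len(segments_ms)), key=lambda i: segments_ms[i][1] - segments_ms[i][0])
    batches, current, current_ms = [], [], 0
    for i in order:
        dur = segments_ms[i][1] - segments_ms[i][0]
        if current and (current_ms + dur > budget_ms or dur > threshold_ms):
            batches.append(current)          # flush the batch before it overflows the budget
            current, current_ms = [], 0
        current.append(i)
        current_ms += dur
    if current:
        batches.append(current)
    return batches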

View File

@ -0,0 +1,8 @@
class AutoTokenizer:
"""
Placeholder: not yet implemented.
"""
def __init__(self):
pass

View File

@ -1,88 +1,10 @@
import json
import time
import torch
import hydra
import random
import string
import logging
import os.path
from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf, ListConfig
from funasr.register import tables
from funasr.utils.load_utils import load_bytes
from funasr.download.file import download_from_url
from funasr.download.download_from_hub import download_model
from funasr.utils.vad_utils import slice_padding_audio_samples
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils.timestamp_tools import timestamp_sentence
from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
from funasr.models.campplus.cluster_backend import ClusterBackend
from funasr.auto.auto_model import AutoModel
def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
"""
:param input:
:param input_len:
:param data_type:
:param frontend:
:return:
"""
data_list = []
key_list = []
filelist = [".scp", ".txt", ".json", ".jsonl"]
chars = string.ascii_letters + string.digits
if isinstance(data_in, str) and data_in.startswith('http'): # url
data_in = download_from_url(data_in)
if isinstance(data_in, str) and os.path.exists(data_in): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
_, file_extension = os.path.splitext(data_in)
file_extension = file_extension.lower()
if file_extension in filelist: #filelist: wav.scp, file.jsonl;text.txt;
with open(data_in, encoding='utf-8') as fin:
for line in fin:
key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
if data_in.endswith(".jsonl"): #file.jsonl: json.dumps({"source": data})
lines = json.loads(line.strip())
data = lines["source"]
key = data["key"] if "key" in data else key
else: # filelist, wav.scp, text.txt: id \t data or data
lines = line.strip().split(maxsplit=1)
data = lines[1] if len(lines)>1 else lines[0]
key = lines[0] if len(lines)>1 else key
data_list.append(data)
key_list.append(key)
else:
key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
data_list = [data_in]
key_list = [key]
elif isinstance(data_in, (list, tuple)):
if data_type is not None and isinstance(data_type, (list, tuple)): # mutiple inputs
data_list_tmp = []
for data_in_i, data_type_i in zip(data_in, data_type):
key_list, data_list_i = prepare_data_iterator(data_in=data_in_i, data_type=data_type_i)
data_list_tmp.append(data_list_i)
data_list = []
for item in zip(*data_list_tmp):
data_list.append(item)
else:
# [audio sample point, fbank, text]
data_list = data_in
key_list = ["rand_key_" + ''.join(random.choice(chars) for _ in range(13)) for _ in range(len(data_in))]
else: # raw text; audio sample point, fbank; bytes
if isinstance(data_in, bytes): # audio bytes
data_in = load_bytes(data_in)
if key is None:
key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
data_list = [data_in]
key_list = [key]
return key_list, data_list
@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
def to_plain_list(cfg_item):
@ -101,401 +23,9 @@ def main_hydra(cfg: DictConfig):
if kwargs.get("debug", False):
import pdb; pdb.set_trace()
model = AutoModel(**kwargs)
res = model(input=kwargs["input"])
res = model.generate(input=kwargs["input"])
print(res)
class AutoModel:
def __init__(self, **kwargs):
tables.print()
model, kwargs = self.build_model(**kwargs)
# if vad_model is not None, build vad model else None
vad_model = kwargs.get("vad_model", None)
vad_kwargs = kwargs.get("vad_model_revision", None)
if vad_model is not None:
logging.info("Building VAD model.")
vad_kwargs = {"model": vad_model, "model_revision": vad_kwargs}
vad_model, vad_kwargs = self.build_model(**vad_kwargs)
# if punc_model is not None, build punc model else None
punc_model = kwargs.get("punc_model", None)
punc_kwargs = kwargs.get("punc_model_revision", None)
if punc_model is not None:
logging.info("Building punc model.")
punc_kwargs = {"model": punc_model, "model_revision": punc_kwargs}
punc_model, punc_kwargs = self.build_model(**punc_kwargs)
# if spk_model is not None, build spk model else None
spk_model = kwargs.get("spk_model", None)
spk_kwargs = kwargs.get("spk_model_revision", None)
if spk_model is not None:
logging.info("Building SPK model.")
spk_kwargs = {"model": spk_model, "model_revision": spk_kwargs}
spk_model, spk_kwargs = self.build_model(**spk_kwargs)
self.cb_model = ClusterBackend()
spk_mode = kwargs.get("spk_mode", 'punc_segment')
if spk_mode not in ["default", "vad_segment", "punc_segment"]:
logging.error("spk_mode should be one of default, vad_segment and punc_segment.")
self.spk_mode = spk_mode
self.preset_spk_num = kwargs.get("preset_spk_num", None)
if self.preset_spk_num:
logging.warning("Using preset speaker number: {}".format(self.preset_spk_num))
logging.warning("Many to print when using speaker model...")
self.kwargs = kwargs
self.model = model
self.vad_model = vad_model
self.vad_kwargs = vad_kwargs
self.punc_model = punc_model
self.punc_kwargs = punc_kwargs
self.spk_model = spk_model
self.spk_kwargs = spk_kwargs
self.model_path = kwargs["model_path"]
def build_model(self, **kwargs):
assert "model" in kwargs
if "model_conf" not in kwargs:
logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
kwargs = download_model(**kwargs)
set_all_random_seed(kwargs.get("seed", 0))
device = kwargs.get("device", "cuda")
if not torch.cuda.is_available() or kwargs.get("ngpu", 0):
device = "cpu"
# kwargs["batch_size"] = 1
kwargs["device"] = device
if kwargs.get("ncpu", None):
torch.set_num_threads(kwargs.get("ncpu"))
# build tokenizer
tokenizer = kwargs.get("tokenizer", None)
if tokenizer is not None:
tokenizer_class = tables.tokenizer_classes.get(tokenizer)
tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
kwargs["tokenizer"] = tokenizer
kwargs["token_list"] = tokenizer.token_list
vocab_size = len(tokenizer.token_list)
else:
vocab_size = -1
# build frontend
frontend = kwargs.get("frontend", None)
if frontend is not None:
frontend_class = tables.frontend_classes.get(frontend)
frontend = frontend_class(**kwargs["frontend_conf"])
kwargs["frontend"] = frontend
kwargs["input_size"] = frontend.output_size()
# build model
model_class = tables.model_classes.get(kwargs["model"])
model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
model.eval()
model.to(device)
# init_param
init_param = kwargs.get("init_param", None)
if init_param is not None:
logging.info(f"Loading pretrained params from {init_param}")
load_pretrained_model(
model=model,
init_param=init_param,
ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
oss_bucket=kwargs.get("oss_bucket", None),
)
return model, kwargs
def __call__(self, input, input_len=None, **cfg):
if self.vad_model is None:
return self.generate(input, input_len=input_len, **cfg)
else:
return self.generate_with_vad(input, input_len=input_len, **cfg)
def generate(self, input, input_len=None, model=None, kwargs=None, key=None, **cfg):
kwargs = self.kwargs if kwargs is None else kwargs
kwargs.update(cfg)
model = self.model if model is None else model
batch_size = kwargs.get("batch_size", 1)
# if kwargs.get("device", "cpu") == "cpu":
# batch_size = 1
key_list, data_list = prepare_data_iterator(input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key)
speed_stats = {}
asr_result_list = []
num_samples = len(data_list)
pbar = tqdm(colour="blue", total=num_samples+1, dynamic_ncols=True)
time_speech_total = 0.0
time_escape_total = 0.0
for beg_idx in range(0, num_samples, batch_size):
end_idx = min(num_samples, beg_idx + batch_size)
data_batch = data_list[beg_idx:end_idx]
key_batch = key_list[beg_idx:end_idx]
batch = {"data_in": data_batch, "key": key_batch}
if (end_idx - beg_idx) == 1 and isinstance(data_batch[0], torch.Tensor): # fbank
batch["data_in"] = data_batch[0]
batch["data_lengths"] = input_len
time1 = time.perf_counter()
with torch.no_grad():
results, meta_data = model.inference(**batch, **kwargs)
time2 = time.perf_counter()
asr_result_list.extend(results)
pbar.update(1)
# batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
batch_data_time = meta_data.get("batch_data_time", -1)
time_escape = time2 - time1
speed_stats["load_data"] = meta_data.get("load_data", 0.0)
speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
speed_stats["forward"] = f"{time_escape:0.3f}"
speed_stats["batch_size"] = f"{len(results)}"
speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}"
description = (
f"{speed_stats}, "
)
pbar.set_description(description)
time_speech_total += batch_data_time
time_escape_total += time_escape
pbar.update(1)
pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
torch.cuda.empty_cache()
return asr_result_list
def generate_with_vad(self, input, input_len=None, **cfg):
# step.1: compute the vad model
self.vad_kwargs.update(cfg)
beg_vad = time.time()
res = self.generate(input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg)
end_vad = time.time()
print(f"time cost vad: {end_vad - beg_vad:0.3f}")
# step.2 compute asr model
model = self.model
kwargs = self.kwargs
kwargs.update(cfg)
batch_size = int(kwargs.get("batch_size_s", 300))*1000
batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60))*1000
kwargs["batch_size"] = batch_size
key_list, data_list = prepare_data_iterator(input, input_len=input_len, data_type=kwargs.get("data_type", None))
results_ret_list = []
time_speech_total_all_samples = 0.0
beg_total = time.time()
pbar_total = tqdm(colour="red", total=len(res) + 1, dynamic_ncols=True)
for i in range(len(res)):
key = res[i]["key"]
vadsegments = res[i]["value"]
input_i = data_list[i]
speech = load_audio_text_image_video(input_i, fs=kwargs["frontend"].fs, audio_fs=kwargs.get("fs", 16000))
speech_lengths = len(speech)
n = len(vadsegments)
data_with_index = [(vadsegments[i], i) for i in range(n)]
sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
results_sorted = []
if not len(sorted_data):
logging.info("decoding, utt: {}, empty speech".format(key))
continue
if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
batch_size = max(batch_size, sorted_data[0][0][1] - sorted_data[0][0][0])
batch_size_ms_cum = 0
beg_idx = 0
beg_asr_total = time.time()
time_speech_total_per_sample = speech_lengths/16000
time_speech_total_all_samples += time_speech_total_per_sample
for j, _ in enumerate(range(0, n)):
batch_size_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
if j < n - 1 and (
batch_size_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size and (
sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size_threshold_ms:
continue
batch_size_ms_cum = 0
end_idx = j + 1
speech_j, speech_lengths_j = slice_padding_audio_samples(speech, speech_lengths, sorted_data[beg_idx:end_idx])
results = self.generate(speech_j, input_len=None, model=model, kwargs=kwargs, **cfg)
if self.spk_model is not None:
all_segments = []
# compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
for _b in range(len(speech_j)):
vad_segments = [[sorted_data[beg_idx:end_idx][_b][0][0]/1000.0, \
sorted_data[beg_idx:end_idx][_b][0][1]/1000.0, \
speech_j[_b]]]
segments = sv_chunk(vad_segments)
all_segments.extend(segments)
speech_b = [i[2] for i in segments]
spk_res = self.generate(speech_b, input_len=None, model=self.spk_model, kwargs=kwargs, **cfg)
results[_b]['spk_embedding'] = spk_res[0]['spk_embedding']
beg_idx = end_idx
if len(results) < 1:
continue
results_sorted.extend(results)
pbar_total.update(1)
end_asr_total = time.time()
time_escape_total_per_sample = end_asr_total - beg_asr_total
pbar_total.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
restored_data = [0] * n
for j in range(n):
index = sorted_data[j][1]
restored_data[index] = results_sorted[j]
result = {}
# results combine for texts, timestamps, speaker embeddings and others
# TODO: rewrite for clean code
for j in range(n):
for k, v in restored_data[j].items():
if k.startswith("timestamp"):
if k not in result:
result[k] = []
for t in restored_data[j][k]:
t[0] += vadsegments[j][0]
t[1] += vadsegments[j][0]
result[k].extend(restored_data[j][k])
elif k == 'spk_embedding':
if k not in result:
result[k] = restored_data[j][k]
else:
result[k] = torch.cat([result[k], restored_data[j][k]], dim=0)
elif k == 'text':
if k not in result:
result[k] = restored_data[j][k]
else:
result[k] += " " + restored_data[j][k]
else:
if k not in result:
result[k] = restored_data[j][k]
else:
result[k] += restored_data[j][k]
# step.3 compute punc model
if self.punc_model is not None:
self.punc_kwargs.update(cfg)
punc_res = self.generate(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg)
result["text_with_punc"] = punc_res[0]["text"]
# speaker embedding cluster after resorted
if self.spk_model is not None:
all_segments = sorted(all_segments, key=lambda x: x[0])
spk_embedding = result['spk_embedding']
labels = self.cb_model(spk_embedding, oracle_num=self.preset_spk_num)
del result['spk_embedding']
sv_output = postprocess(all_segments, None, labels, spk_embedding.cpu())
if self.spk_mode == 'vad_segment':
sentence_list = []
for res, vadsegment in zip(restored_data, vadsegments):
sentence_list.append({"start": vadsegment[0],\
"end": vadsegment[1],
"sentence": res['text'],
"timestamp": res['timestamp']})
else: # punc_segment
sentence_list = timestamp_sentence(punc_res[0]['punc_array'], \
result['timestamp'], \
result['text'])
distribute_spk(sentence_list, sv_output)
result['sentence_info'] = sentence_list
result["key"] = key
results_ret_list.append(result)
pbar_total.update(1)
pbar_total.update(1)
end_total = time.time()
time_escape_total_all_samples = end_total - beg_total
pbar_total.set_description(f"rtf_avg_all_samples: {time_escape_total_all_samples / time_speech_total_all_samples:0.3f}, "
f"time_speech_total_all_samples: {time_speech_total_all_samples: 0.3f}, "
f"time_escape_total_all_samples: {time_escape_total_all_samples:0.3f}")
return results_ret_list
class AutoFrontend:
def __init__(self, **kwargs):
assert "model" in kwargs
if "model_conf" not in kwargs:
logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
kwargs = download_model(**kwargs)
# build frontend
frontend = kwargs.get("frontend", None)
if frontend is not None:
frontend_class = tables.frontend_classes.get(frontend)
frontend = frontend_class(**kwargs["frontend_conf"])
self.frontend = frontend
if "frontend" in kwargs:
del kwargs["frontend"]
self.kwargs = kwargs
def __call__(self, input, input_len=None, kwargs=None, **cfg):
kwargs = self.kwargs if kwargs is None else kwargs
kwargs.update(cfg)
key_list, data_list = prepare_data_iterator(input, input_len=input_len)
batch_size = kwargs.get("batch_size", 1)
device = kwargs.get("device", "cpu")
if device == "cpu":
batch_size = 1
meta_data = {}
result_list = []
num_samples = len(data_list)
pbar = tqdm(colour="blue", total=num_samples + 1, dynamic_ncols=True)
time0 = time.perf_counter()
for beg_idx in range(0, num_samples, batch_size):
end_idx = min(num_samples, beg_idx + batch_size)
data_batch = data_list[beg_idx:end_idx]
key_batch = key_list[beg_idx:end_idx]
# extract fbank feats
time1 = time.perf_counter()
audio_sample_list = load_audio_text_image_video(data_batch, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
frontend=self.frontend, **kwargs)
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
speech.to(device=device), speech_lengths.to(device=device)
batch = {"input": speech, "input_len": speech_lengths, "key": key_batch}
result_list.append(batch)
pbar.update(1)
description = (
f"{meta_data}, "
)
pbar.set_description(description)
time_end = time.perf_counter()
pbar.set_description(f"time escaped total: {time_end - time0:0.3f}")
return result_list
if __name__ == '__main__':
main_hydra()