punc vad realtime infer

This commit is contained in:
mengzhe.cmz 2023-03-07 18:25:47 +08:00
parent c8c3f04a91
commit 4de9fe6cf6
4 changed files with 365 additions and 1 deletions

View File

@ -0,0 +1,26 @@
##################text二进制数据#####################
inputs = "跨境河流是养育沿岸|人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员|在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险|向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流问题上的关切|愿意进一步完善双方联合工作机制|凡是|中方能做的我们|都会去做而且会做得更好我请印度朋友们放心中国在上游的|任何开发利用都会经过科学|规划和论证兼顾上下游的利益"
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
inference_pipline = pipeline(
task=Tasks.punctuation,
model='damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727',
model_revision="v1.0.0",
output_dir="./tmp/"
)
vads = inputs.split("|")
cache_out = []
rec_result_all="outputs:"
for vad in vads:
rec_result = inference_pipline(text_in=vad, cache=cache_out)
#print(rec_result)
cache_out = rec_result['cache']
rec_result_all += rec_result['text']
print(rec_result_all)

View File

@ -15,7 +15,7 @@ from modelscope.utils.constant import Tasks
inference_pipline = pipeline(
task=Tasks.punctuation,
model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',
model_revision="v1.1.6",
model_revision="v1.1.7",
output_dir="./tmp/"
)

View File

@ -75,6 +75,9 @@ def inference_launch(mode, **kwargs):
if mode == "punc":
from funasr.bin.punctuation_infer import inference_modelscope
return inference_modelscope(**kwargs)
if mode == "punc_VadRealtime":
from funasr.bin.punctuation_infer_vadrealtime import inference_modelscope
return inference_modelscope(**kwargs)
else:
logging.info("Unknown decoding mode: {}".format(mode))
return None

View File

@ -0,0 +1,335 @@
#!/usr/bin/env python3
import argparse
import logging
from pathlib import Path
import sys
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Any
from typing import List
import numpy as np
import torch
from typeguard import check_argument_types
from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
from funasr.utils.cli_utils import get_commandline_args
from funasr.tasks.punctuation import PunctuationTask
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.forward_adaptor import ForwardAdaptor
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.punctuation.text_preprocessor import split_to_mini_sentence
class Text2Punc:
def __init__(
self,
train_config: Optional[str],
model_file: Optional[str],
device: str = "cpu",
dtype: str = "float32",
):
# Build Model
model, train_args = PunctuationTask.build_model_from_file(train_config, model_file, device)
self.device = device
# Wrape model to make model.nll() data-parallel
self.wrapped_model = ForwardAdaptor(model, "inference")
self.wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
# logging.info(f"Model:\n{model}")
self.punc_list = train_args.punc_list
self.period = 0
for i in range(len(self.punc_list)):
if self.punc_list[i] == ",":
self.punc_list[i] = ""
elif self.punc_list[i] == "?":
self.punc_list[i] = ""
elif self.punc_list[i] == "":
self.period = i
self.preprocessor = CodeMixTokenizerCommonPreprocessor(
train=False,
token_type=train_args.token_type,
token_list=train_args.token_list,
bpemodel=train_args.bpemodel,
text_cleaner=train_args.cleaner,
g2p_type=train_args.g2p,
text_name="text",
non_linguistic_symbols=train_args.non_linguistic_symbols,
)
print("start decoding!!!")
@torch.no_grad()
def __call__(self, text: Union[list, str], cache: list, split_size=20):
if cache is not None and len(cache) > 0:
precache = "".join(cache)
else:
precache = ""
data = {"text": precache + text}
result = self.preprocessor(data=data, uid="12938712838719")
split_text = self.preprocessor.pop_split_text_data(result)
mini_sentences = split_to_mini_sentence(split_text, split_size)
mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
assert len(mini_sentences) == len(mini_sentences_id)
cache_sent = []
cache_sent_id = torch.from_numpy(np.array([], dtype='int32'))
sentence_punc_list = []
sentence_words_list= []
cache_pop_trigger_limit = 200
skip_num = 0
for mini_sentence_i in range(len(mini_sentences)):
mini_sentence = mini_sentences[mini_sentence_i]
mini_sentence_id = mini_sentences_id[mini_sentence_i]
mini_sentence = cache_sent + mini_sentence
mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
data = {
"text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
"text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
"vad_indexes": torch.from_numpy(np.array([len(cache)-1], dtype='int32')),
}
data = to_device(data, self.device)
y, _ = self.wrapped_model(**data)
_, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
punctuations = indices
if indices.size()[0] != 1:
punctuations = torch.squeeze(indices)
assert punctuations.size()[0] == len(mini_sentence)
# Search for the last Period/QuestionMark as cache
if mini_sentence_i < len(mini_sentences) - 1:
sentenceEnd = -1
last_comma_index = -1
for i in range(len(punctuations) - 2, 1, -1):
if self.punc_list[punctuations[i]] == "" or self.punc_list[punctuations[i]] == "":
sentenceEnd = i
break
if last_comma_index < 0 and self.punc_list[punctuations[i]] == "":
last_comma_index = i
if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
# The sentence it too long, cut off at a comma.
sentenceEnd = last_comma_index
punctuations[sentenceEnd] = self.period
cache_sent = mini_sentence[sentenceEnd + 1:]
cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
mini_sentence = mini_sentence[0:sentenceEnd + 1]
punctuations = punctuations[0:sentenceEnd + 1]
punctuations_np = punctuations.cpu().numpy()
sentence_punc_list += [self.punc_list[int(x)] for x in punctuations_np]
sentence_words_list += mini_sentence
assert len(sentence_punc_list) == len(sentence_words_list)
words_with_punc = []
sentence_punc_list_out = []
for i in range(0, len(sentence_words_list)):
if i > 0:
if len(sentence_words_list[i][0].encode()) == 1 and len(sentence_words_list[i - 1][-1].encode()) == 1:
sentence_words_list[i] = " " + sentence_words_list[i]
if skip_num < len(cache):
skip_num += 1
else:
words_with_punc.append(sentence_words_list[i])
if skip_num >= len(cache):
sentence_punc_list_out.append(sentence_punc_list[i])
if sentence_punc_list[i] != "_":
words_with_punc.append(sentence_punc_list[i])
sentence_out = "".join(words_with_punc)
sentenceEnd = -1
for i in range(len(sentence_punc_list) - 2, 1, -1):
if sentence_punc_list[i] == "" or sentence_punc_list[i] == "":
sentenceEnd = i
break
cache_out = sentence_words_list[sentenceEnd + 1 :]
if sentence_out[-1] in self.punc_list:
sentence_out = sentence_out[:-1]
sentence_punc_list_out[-1] = "_"
return sentence_out, sentence_punc_list_out, cache_out
def inference(
batch_size: int,
dtype: str,
ngpu: int,
seed: int,
num_workers: int,
output_dir: str,
log_level: Union[int, str],
train_config: Optional[str],
model_file: Optional[str],
key_file: Optional[str] = None,
data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
raw_inputs: Union[List[Any], bytes, str] = None,
cache: List[Any] = None,
param_dict: dict = None,
**kwargs,
):
inference_pipeline = inference_modelscope(
output_dir=output_dir,
batch_size=batch_size,
dtype=dtype,
ngpu=ngpu,
seed=seed,
num_workers=num_workers,
log_level=log_level,
key_file=key_file,
train_config=train_config,
model_file=model_file,
param_dict=param_dict,
**kwargs,
)
return inference_pipeline(data_path_and_name_and_type, raw_inputs, cache)
def inference_modelscope(
batch_size: int,
dtype: str,
ngpu: int,
seed: int,
num_workers: int,
log_level: Union[int, str],
#cache: list,
key_file: Optional[str],
train_config: Optional[str],
model_file: Optional[str],
output_dir: Optional[str] = None,
param_dict: dict = None,
**kwargs,
):
assert check_argument_types()
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
# 1. Set random-seed
set_all_random_seed(seed)
text2punc = Text2Punc(train_config, model_file, device)
def _forward(
data_path_and_name_and_type,
raw_inputs: Union[List[Any], bytes, str] = None,
output_dir_v2: Optional[str] = None,
cache: List[Any] = None,
param_dict: dict = None,
):
results = []
split_size = 10
if raw_inputs != None:
line = raw_inputs.strip()
key = "demo"
if line == "":
item = {'key': key, 'value': ""}
results.append(item)
return results
#import pdb;pdb.set_trace()
result, _, cache = text2punc(line, cache)
item = {'key': key, 'value': result, 'cache': cache}
results.append(item)
return results
for inference_text, _, _ in data_path_and_name_and_type:
with open(inference_text, "r", encoding="utf-8") as fin:
for line in fin:
line = line.strip()
segs = line.split("\t")
if len(segs) != 2:
continue
key = segs[0]
if len(segs[1]) == 0:
continue
result, _ = text2punc(segs[1])
item = {'key': key, 'value': result}
results.append(item)
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
if output_path != None:
output_file_name = "infer.out"
Path(output_path).mkdir(parents=True, exist_ok=True)
output_file_path = (Path(output_path) / output_file_name).absolute()
with open(output_file_path, "w", encoding="utf-8") as fout:
for item_i in results:
key_out = item_i["key"]
value_out = item_i["value"]
fout.write(f"{key_out}\t{value_out}\n")
return results
return _forward
def get_parser():
parser = config_argparse.ArgumentParser(
description="Punctuation inference",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=False)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
parser.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
group = parser.add_argument_group("Input data related")
group.add_argument("--data_path_and_name_and_type", type=str2triple_str, action="append", required=False)
group.add_argument("--raw_inputs", type=str, required=False)
group.add_argument("--cache", type=list, required=False)
group.add_argument("--param_dict", type=dict, required=False)
group.add_argument("--key_file", type=str_or_none)
group = parser.add_argument_group("The model configuration related")
group.add_argument("--train_config", type=str)
group.add_argument("--model_file", type=str)
return parser
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
kwargs = vars(args)
# kwargs.pop("config", None)
inference(**kwargs)
if __name__ == "__main__":
main()