update repo

This commit is contained in:
嘉渊 2023-06-15 15:56:19 +08:00
parent 0c4fbea66b
commit ca6b2e29fd
2 changed files with 65 additions and 100 deletions

View File

@ -1,46 +1,32 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import argparse
import logging
from pathlib import Path
import sys
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Any
from typing import List
import numpy as np
import torch
from typeguard import check_argument_types
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
from funasr.utils.cli_utils import get_commandline_args
from funasr.tasks.punctuation import PunctuationTask
from funasr.datasets.preprocessor import split_to_mini_sentence
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.forward_adaptor import ForwardAdaptor
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.datasets.preprocessor import split_to_mini_sentence
class Text2Punc:
def __init__(
self,
train_config: Optional[str],
model_file: Optional[str],
device: str = "cpu",
dtype: str = "float32",
self,
train_config: Optional[str],
model_file: Optional[str],
device: str = "cpu",
dtype: str = "float32",
):
# Build Model
model, train_args = PunctuationTask.build_model_from_file(train_config, model_file, device)
model, train_args = build_model_from_file(train_config, model_file, None, device, task_name="punc")
self.device = device
# Wrape model to make model.nll() data-parallel
self.wrapped_model = ForwardAdaptor(model, "inference")
@ -144,16 +130,16 @@ class Text2Punc:
class Text2PuncVADRealtime:
def __init__(
self,
train_config: Optional[str],
model_file: Optional[str],
device: str = "cpu",
dtype: str = "float32",
self,
train_config: Optional[str],
model_file: Optional[str],
device: str = "cpu",
dtype: str = "float32",
):
# Build Model
model, train_args = PunctuationTask.build_model_from_file(train_config, model_file, device)
model, train_args = build_model_from_file(train_config, model_file, None, device, task_name="punc")
self.device = device
# Wrape model to make model.nll() data-parallel
self.wrapped_model = ForwardAdaptor(model, "inference")
@ -178,7 +164,7 @@ class Text2PuncVADRealtime:
text_name="text",
non_linguistic_symbols=train_args.non_linguistic_symbols,
)
@torch.no_grad()
def __call__(self, text: Union[list, str], cache: list, split_size=20):
if cache is not None and len(cache) > 0:
@ -215,7 +201,7 @@ class Text2PuncVADRealtime:
if indices.size()[0] != 1:
punctuations = torch.squeeze(indices)
assert punctuations.size()[0] == len(mini_sentence)
# Search for the last Period/QuestionMark as cache
if mini_sentence_i < len(mini_sentences) - 1:
sentenceEnd = -1
@ -226,7 +212,7 @@ class Text2PuncVADRealtime:
break
if last_comma_index < 0 and self.punc_list[punctuations[i]] == "":
last_comma_index = i
if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
# The sentence it too long, cut off at a comma.
sentenceEnd = last_comma_index
@ -235,11 +221,11 @@ class Text2PuncVADRealtime:
cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
mini_sentence = mini_sentence[0:sentenceEnd + 1]
punctuations = punctuations[0:sentenceEnd + 1]
punctuations_np = punctuations.cpu().numpy()
sentence_punc_list += [self.punc_list[int(x)] for x in punctuations_np]
sentence_words_list += mini_sentence
assert len(sentence_punc_list) == len(sentence_words_list)
words_with_punc = []
sentence_punc_list_out = []
@ -256,7 +242,7 @@ class Text2PuncVADRealtime:
if sentence_punc_list[i] != "_":
words_with_punc.append(sentence_punc_list[i])
sentence_out = "".join(words_with_punc)
sentenceEnd = -1
for i in range(len(sentence_punc_list) - 2, 1, -1):
if sentence_punc_list[i] == "" or sentence_punc_list[i] == "":
@ -267,5 +253,3 @@ class Text2PuncVADRealtime:
sentence_out = sentence_out[:-1]
sentence_punc_list_out[-1] = "_"
return sentence_out, sentence_punc_list_out, cache_out

View File

@ -1,5 +1,5 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
@ -7,55 +7,36 @@ import argparse
import logging
import os
import sys
from typing import Union, Dict, Any
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils.types import float_or_none
import argparse
import logging
from pathlib import Path
import sys
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Any
from typing import List
from typing import Optional
from typing import Union
import numpy as np
import torch
from typeguard import check_argument_types
from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
from funasr.utils.cli_utils import get_commandline_args
from funasr.tasks.punctuation import PunctuationTask
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.forward_adaptor import ForwardAdaptor
from funasr.bin.punc_infer import Text2Punc, Text2PuncVADRealtime
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.datasets.preprocessor import split_to_mini_sentence
from funasr.bin.punc_infer import Text2Punc, Text2PuncVADRealtime
def inference_punc(
batch_size: int,
dtype: str,
ngpu: int,
seed: int,
num_workers: int,
log_level: Union[int, str],
key_file: Optional[str],
train_config: Optional[str],
model_file: Optional[str],
output_dir: Optional[str] = None,
param_dict: dict = None,
**kwargs,
batch_size: int,
dtype: str,
ngpu: int,
seed: int,
num_workers: int,
log_level: Union[int, str],
key_file: Optional[str],
train_config: Optional[str],
model_file: Optional[str],
output_dir: Optional[str] = None,
param_dict: dict = None,
**kwargs,
):
assert check_argument_types()
logging.basicConfig(
@ -73,11 +54,11 @@ def inference_punc(
text2punc = Text2Punc(train_config, model_file, device)
def _forward(
data_path_and_name_and_type,
raw_inputs: Union[List[Any], bytes, str] = None,
output_dir_v2: Optional[str] = None,
cache: List[Any] = None,
param_dict: dict = None,
data_path_and_name_and_type,
raw_inputs: Union[List[Any], bytes, str] = None,
output_dir_v2: Optional[str] = None,
cache: List[Any] = None,
param_dict: dict = None,
):
results = []
split_size = 20
@ -121,20 +102,21 @@ def inference_punc(
return _forward
def inference_punc_vad_realtime(
batch_size: int,
dtype: str,
ngpu: int,
seed: int,
num_workers: int,
log_level: Union[int, str],
#cache: list,
key_file: Optional[str],
train_config: Optional[str],
model_file: Optional[str],
output_dir: Optional[str] = None,
param_dict: dict = None,
**kwargs,
batch_size: int,
dtype: str,
ngpu: int,
seed: int,
num_workers: int,
log_level: Union[int, str],
# cache: list,
key_file: Optional[str],
train_config: Optional[str],
model_file: Optional[str],
output_dir: Optional[str] = None,
param_dict: dict = None,
**kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
@ -150,11 +132,11 @@ def inference_punc_vad_realtime(
text2punc = Text2PuncVADRealtime(train_config, model_file, device)
def _forward(
data_path_and_name_and_type,
raw_inputs: Union[List[Any], bytes, str] = None,
output_dir_v2: Optional[str] = None,
cache: List[Any] = None,
param_dict: dict = None,
data_path_and_name_and_type,
raw_inputs: Union[List[Any], bytes, str] = None,
output_dir_v2: Optional[str] = None,
cache: List[Any] = None,
param_dict: dict = None,
):
results = []
split_size = 10
@ -177,7 +159,6 @@ def inference_punc_vad_realtime(
return _forward
def inference_launch(mode, **kwargs):
if mode == "punc":
return inference_punc(**kwargs)
@ -187,6 +168,7 @@ def inference_launch(mode, **kwargs):
logging.info("Unknown decoding mode: {}".format(mode))
return None
def get_parser():
parser = config_argparse.ArgumentParser(
description="Punctuation inference",
@ -269,6 +251,5 @@ def main(cmd=None):
return inference_pipeline(kwargs["data_path_and_name_and_type"])
if __name__ == "__main__":
main()