Merge branch 'main' into dev

This commit is contained in:
lzr265946 2023-02-03 14:11:22 +08:00
commit 1d97d628f2
14 changed files with 219 additions and 384 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 187 KiB

After

Width:  |  Height:  |  Size: 180 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 183 KiB

After

Width:  |  Height:  |  Size: 186 KiB

View File

@ -368,7 +368,7 @@ class Speech2Text:
# except TooShortUttError as e:
# logging.warning(f"Utterance {keys} {e}")
# hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
# results = [[" ", ["<space>"], [2], hyp]] * nbest
# results = [[" ", ["sil"], [2], hyp]] * nbest
#
# # Only supporting batch_size==1
# key = keys[0]
@ -577,7 +577,7 @@ def inference_modelscope(
except TooShortUttError as e:
logging.warning(f"Utterance {keys} {e}")
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
results = [[" ", ["<space>"], [2], hyp]] * nbest
results = [[" ", ["sil"], [2], hyp]] * nbest
# Only supporting batch_size==1
key = keys[0]

View File

@ -227,6 +227,8 @@ class Speech2Text:
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
predictor_outs[2], predictor_outs[3]
pre_token_length = pre_token_length.round().long()
if torch.max(pre_token_length) < 1:
return []
decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
@ -394,7 +396,7 @@ class Speech2Text:
# results = speech2text(**batch)
# if len(results) < 1:
# hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
# results = [[" ", ["<space>"], [2], hyp, 10, 6]] * nbest
# results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest
# time_end = time.time()
# forward_time = time_end - time_beg
# lfr_factor = results[0][-1]
@ -623,7 +625,7 @@ def inference_modelscope(
results = speech2text(**batch)
if len(results) < 1:
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
results = [[" ", ["<space>"], [2], hyp, 10, 6]] * nbest
results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest
time_end = time.time()
forward_time = time_end - time_beg
lfr_factor = results[0][-1]

View File

@ -410,7 +410,7 @@ def inference(
results = speech2text(**batch)
if len(results) < 1:
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
results = [[" ", ["<space>"], [2], hyp, 10, 6]] * nbest
results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest
time_end = time.time()
forward_time = time_end - time_beg
lfr_factor = results[0][-1]

View File

@ -1,9 +1,10 @@
#!/usr/bin/env python3
import json
import argparse
import logging
import sys
import time
import json
from pathlib import Path
from typing import Optional
from typing import Sequence
@ -38,10 +39,10 @@ from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tasks.vad import VADTask
from funasr.utils.timestamp_tools import time_stamp_lfr6
from funasr.tasks.punctuation import PunctuationTask
from funasr.bin.punctuation_infer import Text2Punc
from funasr.torch_utils.forward_adaptor import ForwardAdaptor
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.punctuation.text_preprocessor import split_words, split_to_mini_sentence
from funasr.punctuation.text_preprocessor import split_to_mini_sentence
header_colors = '\033[95m'
end_colors = '\033[0m'
@ -236,6 +237,8 @@ class Speech2Text:
predictor_outs = self.asr_model.calc_predictor(enc, enc_len)
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], predictor_outs[2], predictor_outs[3]
pre_token_length = pre_token_length.round().long()
if torch.max(pre_token_length) < 1:
return []
decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
@ -604,7 +607,7 @@ def inference_modelscope(
results = speech2text(**batch)
if len(results) < 1:
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
results = [[" ", ["<space>"], [2], 0, 1, 6]] * nbest
results = [[" ", ["sil"], [2], 0, 1, 6]] * nbest
time_end = time.time()
forward_time = time_end - time_beg
lfr_factor = results[0][-1]
@ -680,102 +683,6 @@ def inference_modelscope(
return asr_result_list
return _forward
def Text2Punc(
train_config: Optional[str],
model_file: Optional[str],
device: str = "cpu",
dtype: str = "float32",
):
# 2. Build Model
model, train_args = PunctuationTask.build_model_from_file(
train_config, model_file, device)
# Wrape model to make model.nll() data-parallel
wrapped_model = ForwardAdaptor(model, "inference")
wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
# logging.info(f"Model:\n{model}")
punc_list = train_args.punc_list
period = 0
for i in range(len(punc_list)):
if punc_list[i] == ",":
punc_list[i] = ""
elif punc_list[i] == "?":
punc_list[i] = ""
elif punc_list[i] == "":
period = i
preprocessor = CommonPreprocessor(
train=False,
token_type="word",
token_list=train_args.token_list,
bpemodel=train_args.bpemodel,
text_cleaner=train_args.cleaner,
g2p_type=train_args.g2p,
text_name="text",
non_linguistic_symbols=train_args.non_linguistic_symbols,
)
print("start decoding!!!")
def _forward(words, split_size = 20):
cache_sent = []
mini_sentences = split_to_mini_sentence(words, split_size)
new_mini_sentence = ""
new_mini_sentence_punc = []
cache_pop_trigger_limit = 200
for mini_sentence_i in range(len(mini_sentences)):
mini_sentence = mini_sentences[mini_sentence_i]
mini_sentence = cache_sent + mini_sentence
data = {"text": " ".join(mini_sentence)}
batch = preprocessor(data=data, uid="12938712838719")
batch["text_lengths"] = torch.from_numpy(np.array([len(batch["text"])], dtype='int32'))
batch["text"] = torch.from_numpy(batch["text"])
# Extend one dimension to fake a batch dim.
batch["text"] = torch.unsqueeze(batch["text"], 0)
batch = to_device(batch, device)
y, _ = wrapped_model(**batch)
_, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
punctuations = indices
if indices.size()[0] != 1:
punctuations = torch.squeeze(indices)
assert punctuations.size()[0] == len(mini_sentence)
# Search for the last Period/QuestionMark as cache
if mini_sentence_i < len(mini_sentences) - 1:
sentenceEnd = -1
last_comma_index = -1
for i in range(len(punctuations) - 2, 1, -1):
if punc_list[punctuations[i]] == "" or punc_list[punctuations[i]] == "":
sentenceEnd = i
break
if last_comma_index < 0 and punc_list[punctuations[i]] == "":
last_comma_index = i
if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
# The sentence it too long, cut off at a comma.
sentenceEnd = last_comma_index
punctuations[sentenceEnd] = period
cache_sent = mini_sentence[sentenceEnd + 1:]
mini_sentence = mini_sentence[0:sentenceEnd + 1]
punctuations = punctuations[0:sentenceEnd + 1]
# if len(punctuations) == 0:
# continue
punctuations_np = punctuations.cpu().numpy()
new_mini_sentence_punc += [int(x) for x in punctuations_np]
words_with_punc = []
for i in range(len(mini_sentence)):
if i > 0:
if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
mini_sentence[i] = " " + mini_sentence[i]
words_with_punc.append(mini_sentence[i])
if punc_list[punctuations[i]] != "_":
words_with_punc.append(punc_list[punctuations[i]])
new_mini_sentence += "".join(words_with_punc)
return new_mini_sentence, new_mini_sentence_punc
return _forward
def get_parser():
parser = config_argparse.ArgumentParser(
description="ASR Decoding",

View File

@ -391,7 +391,7 @@ class Speech2Text:
# except TooShortUttError as e:
# logging.warning(f"Utterance {keys} {e}")
# hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
# results = [[" ", ["<space>"], [2], hyp]] * nbest
# results = [[" ", ["sil"], [2], hyp]] * nbest
#
# # Only supporting batch_size==1
# key = keys[0]
@ -618,7 +618,7 @@ def inference_modelscope(
except TooShortUttError as e:
logging.warning(f"Utterance {keys} {e}")
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
results = [[" ", ["<space>"], [2], hyp]] * nbest
results = [[" ", ["sil"], [2], hyp]] * nbest
# Only supporting batch_size==1
key = keys[0]

View File

@ -59,26 +59,18 @@ def get_parser():
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
action="append",
required=False
)
group.add_argument(
"--raw_inputs",
type=str,
required=False
)
group.add_argument("--data_path_and_name_and_type", type=str2triple_str, action="append", required=False)
group.add_argument("--raw_inputs", type=str, required=False)
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--cache", type=list, required=False)
group.add_argument("--param_dict", type=dict, required=False)
group = parser.add_argument_group("The model configuration related")
group.add_argument("--train_config", type=str)
group.add_argument("--model_file", type=str)
group.add_argument("--mode", type=str, default="punc")
return parser
def inference_launch(mode, **kwargs):
if mode == "punc":
from funasr.bin.punctuation_infer import inference_modelscope

View File

@ -3,33 +3,141 @@ import argparse
import logging
from pathlib import Path
import sys
import os
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
from typing import Any
from typing import List
import numpy as np
import torch
from torch.nn.parallel import data_parallel
from typeguard import check_argument_types
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
from funasr.utils.cli_utils import get_commandline_args
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.tasks.punctuation import PunctuationTask
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.forward_adaptor import ForwardAdaptor
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.types import float_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.punctuation.text_preprocessor import split_words, split_to_mini_sentence
from funasr.punctuation.text_preprocessor import split_to_mini_sentence
class Text2Punc:
def __init__(
self,
train_config: Optional[str],
model_file: Optional[str],
device: str = "cpu",
dtype: str = "float32",
):
# Build Model
model, train_args = PunctuationTask.build_model_from_file(train_config, model_file, device)
self.device = device
# Wrape model to make model.nll() data-parallel
self.wrapped_model = ForwardAdaptor(model, "inference")
self.wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
# logging.info(f"Model:\n{model}")
self.punc_list = train_args.punc_list
self.period = 0
for i in range(len(self.punc_list)):
if self.punc_list[i] == ",":
self.punc_list[i] = ""
elif self.punc_list[i] == "?":
self.punc_list[i] = ""
elif self.punc_list[i] == "":
self.period = i
self.preprocessor = CodeMixTokenizerCommonPreprocessor(
train=False,
token_type=train_args.token_type,
token_list=train_args.token_list,
bpemodel=train_args.bpemodel,
text_cleaner=train_args.cleaner,
g2p_type=train_args.g2p,
text_name="text",
non_linguistic_symbols=train_args.non_linguistic_symbols,
)
print("start decoding!!!")
@torch.no_grad()
def __call__(self, text: Union[list, str], split_size=20):
data = {"text": text}
result = self.preprocessor(data=data, uid="12938712838719")
split_text = self.preprocessor.pop_split_text_data(result)
mini_sentences = split_to_mini_sentence(split_text, split_size)
mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
assert len(mini_sentences) == len(mini_sentences_id)
cache_sent = []
cache_sent_id = torch.from_numpy(np.array([], dtype='int32'))
new_mini_sentence = ""
new_mini_sentence_punc = []
cache_pop_trigger_limit = 200
for mini_sentence_i in range(len(mini_sentences)):
mini_sentence = mini_sentences[mini_sentence_i]
mini_sentence_id = mini_sentences_id[mini_sentence_i]
mini_sentence = cache_sent + mini_sentence
mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
data = {
"text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
"text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
}
data = to_device(data, self.device)
y, _ = self.wrapped_model(**data)
_, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
punctuations = indices
if indices.size()[0] != 1:
punctuations = torch.squeeze(indices)
assert punctuations.size()[0] == len(mini_sentence)
# Search for the last Period/QuestionMark as cache
if mini_sentence_i < len(mini_sentences) - 1:
sentenceEnd = -1
last_comma_index = -1
for i in range(len(punctuations) - 2, 1, -1):
if self.punc_list[punctuations[i]] == "" or self.punc_list[punctuations[i]] == "":
sentenceEnd = i
break
if last_comma_index < 0 and self.punc_list[punctuations[i]] == "":
last_comma_index = i
if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
# The sentence it too long, cut off at a comma.
sentenceEnd = last_comma_index
punctuations[sentenceEnd] = self.period
cache_sent = mini_sentence[sentenceEnd + 1:]
cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
mini_sentence = mini_sentence[0:sentenceEnd + 1]
punctuations = punctuations[0:sentenceEnd + 1]
# if len(punctuations) == 0:
# continue
punctuations_np = punctuations.cpu().numpy()
new_mini_sentence_punc += [int(x) for x in punctuations_np]
words_with_punc = []
for i in range(len(mini_sentence)):
if i > 0:
if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
mini_sentence[i] = " " + mini_sentence[i]
words_with_punc.append(mini_sentence[i])
if self.punc_list[punctuations[i]] != "_":
words_with_punc.append(self.punc_list[punctuations[i]])
new_mini_sentence += "".join(words_with_punc)
# Add Period for the end of the sentence
new_mini_sentence_out = new_mini_sentence
new_mini_sentence_punc_out = new_mini_sentence_punc
if mini_sentence_i == len(mini_sentences) - 1:
if new_mini_sentence[-1] == "" or new_mini_sentence[-1] == "":
new_mini_sentence_out = new_mini_sentence[:-1] + ""
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
elif new_mini_sentence[-1] != "" and new_mini_sentence[-1] != "":
new_mini_sentence_out = new_mini_sentence + ""
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
return new_mini_sentence_out, new_mini_sentence_punc_out
def inference(
@ -45,12 +153,12 @@ def inference(
key_file: Optional[str] = None,
data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
raw_inputs: Union[List[Any], bytes, str] = None,
cache: List[Any] = None,
param_dict: dict = None,
**kwargs,
):
inference_pipeline = inference_modelscope(
output_dir=output_dir,
raw_inputs=raw_inputs,
batch_size=batch_size,
dtype=dtype,
ngpu=ngpu,
@ -60,6 +168,7 @@ def inference(
key_file=key_file,
train_config=train_config,
model_file=model_file,
param_dict=param_dict,
**kwargs,
)
return inference_pipeline(data_path_and_name_and_type, raw_inputs)
@ -76,6 +185,7 @@ def inference_modelscope(
train_config: Optional[str],
model_file: Optional[str],
output_dir: Optional[str] = None,
param_dict: dict = None,
**kwargs,
):
assert check_argument_types()
@ -91,41 +201,14 @@ def inference_modelscope(
# 1. Set random-seed
set_all_random_seed(seed)
# 2. Build Model
model, train_args = PunctuationTask.build_model_from_file(
train_config, model_file, device)
# Wrape model to make model.nll() data-parallel
wrapped_model = ForwardAdaptor(model, "inference")
wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
logging.info(f"Model:\n{model}")
punc_list = train_args.punc_list
period = 0
for i in range(len(punc_list)):
if punc_list[i] == ",":
punc_list[i] = ""
elif punc_list[i] == "?":
punc_list[i] = ""
elif punc_list[i] == "":
period = i
preprocessor = CommonPreprocessor(
train=False,
token_type="word",
token_list=train_args.token_list,
bpemodel=train_args.bpemodel,
text_cleaner=train_args.cleaner,
g2p_type=train_args.g2p,
text_name="text",
non_linguistic_symbols=train_args.non_linguistic_symbols,
)
print("start decoding!!!")
text2punc = Text2Punc(train_config, model_file, device)
def _forward(
data_path_and_name_and_type,
raw_inputs: Union[List[Any], bytes, str] = None,
output_dir_v2: Optional[str] = None,
cache: List[Any] = None,
param_dict: dict = None,
):
results = []
split_size = 20
@ -133,77 +216,14 @@ def inference_modelscope(
if raw_inputs != None:
line = raw_inputs.strip()
key = "demo"
if line=="":
if line == "":
item = {'key': key, 'value': ""}
results.append(item)
return results
cache_sent = []
words = split_words(line)
new_mini_sentence = ""
new_mini_sentence_punc = ""
cache_pop_trigger_limit = 200
mini_sentences = split_to_mini_sentence(words, split_size)
for mini_sentence_i in range(len(mini_sentences)):
mini_sentence = mini_sentences[mini_sentence_i]
mini_sentence = cache_sent + mini_sentence
data = {"text": " ".join(mini_sentence)}
batch = preprocessor(data=data, uid="12938712838719")
batch["text_lengths"] = torch.from_numpy(
np.array([len(batch["text"])], dtype='int32'))
batch["text"] = torch.from_numpy(batch["text"])
# Extend one dimension to fake a batch dim.
batch["text"] = torch.unsqueeze(batch["text"], 0)
batch = to_device(batch, device)
y, _ = wrapped_model(**batch)
_, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
punctuations = indices
if indices.size()[0] != 1:
punctuations = torch.squeeze(indices)
assert punctuations.size()[0] == len(mini_sentence)
# Search for the last Period/QuestionMark as cache
if mini_sentence_i < len(mini_sentences)-1:
sentenceEnd = -1
last_comma_index = -1
for i in range(len(punctuations)-2,1,-1):
if punc_list[punctuations[i]] == "" or punc_list[punctuations[i]] == "":
sentenceEnd = i
break
if last_comma_index < 0 and punc_list[punctuations[i]] == "":
last_comma_index = i
if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
# The sentence it too long, cut off at a comma.
sentenceEnd = last_comma_index
punctuations[sentenceEnd] = period
cache_sent = mini_sentence[sentenceEnd+1:]
mini_sentence = mini_sentence[0:sentenceEnd+1]
punctuations = punctuations[0:sentenceEnd+1]
punctuations_np = punctuations.cpu().numpy()
new_mini_sentence_punc += "".join([str(x) for x in punctuations_np])
words_with_punc = []
for i in range(len(mini_sentence)):
if i>0:
if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i-1][0].encode()) == 1:
mini_sentence[i] = " "+ mini_sentence[i]
words_with_punc.append(mini_sentence[i])
if punc_list[punctuations[i]] != "_":
words_with_punc.append(punc_list[punctuations[i]])
new_mini_sentence += "".join(words_with_punc)
# Add Period for the end of the sentence
new_mini_sentence_out = new_mini_sentence
new_mini_sentence_punc_out = new_mini_sentence_punc
if mini_sentence_i == len(mini_sentences)-1:
if new_mini_sentence[-1]=="" or new_mini_sentence[-1]=="":
new_mini_sentence_out = new_mini_sentence[:-1] + ""
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + str(period)
elif new_mini_sentence[-1]!="" and new_mini_sentence[-1]!="":
new_mini_sentence_out=new_mini_sentence+""
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + str(period)
item = {'key': key, 'value': new_mini_sentence_out}
results.append(item)
result, _ = text2punc(line)
item = {'key': key, 'value': result}
results.append(item)
print(results)
return results
for inference_text, _, _ in data_path_and_name_and_type:
@ -216,72 +236,9 @@ def inference_modelscope(
key = segs[0]
if len(segs[1]) == 0:
continue
cache_sent = []
words = split_words(segs[1])
new_mini_sentence = ""
new_mini_sentence_punc = ""
cache_pop_trigger_limit = 200
mini_sentences = split_to_mini_sentence(words, split_size)
for mini_sentence_i in range(len(mini_sentences)):
mini_sentence = mini_sentences[mini_sentence_i]
mini_sentence = cache_sent + mini_sentence
data = {"text": " ".join(mini_sentence)}
batch = preprocessor(data=data, uid="12938712838719")
batch["text_lengths"] = torch.from_numpy(
np.array([len(batch["text"])], dtype='int32'))
batch["text"] = torch.from_numpy(batch["text"])
# Extend one dimension to fake a batch dim.
batch["text"] = torch.unsqueeze(batch["text"], 0)
batch = to_device(batch, device)
y, _ = wrapped_model(**batch)
_, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
punctuations = indices
if indices.size()[0] != 1:
punctuations = torch.squeeze(indices)
assert punctuations.size()[0] == len(mini_sentence)
# Search for the last Period/QuestionMark as cache
if mini_sentence_i < len(mini_sentences)-1:
sentenceEnd = -1
last_comma_index = -1
for i in range(len(punctuations)-2,1,-1):
if punc_list[punctuations[i]] == "" or punc_list[punctuations[i]] == "":
sentenceEnd = i
break
if last_comma_index < 0 and punc_list[punctuations[i]] == "":
last_comma_index = i
if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
# The sentence it too long, cut off at a comma.
sentenceEnd = last_comma_index
punctuations[sentenceEnd] = period
cache_sent = mini_sentence[sentenceEnd+1:]
mini_sentence = mini_sentence[0:sentenceEnd+1]
punctuations = punctuations[0:sentenceEnd+1]
punctuations_np = punctuations.cpu().numpy()
new_mini_sentence_punc += "".join([str(x) for x in punctuations_np])
words_with_punc = []
for i in range(len(mini_sentence)):
if i>0:
if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i-1][0].encode()) == 1:
mini_sentence[i] = " "+ mini_sentence[i]
words_with_punc.append(mini_sentence[i])
if punc_list[punctuations[i]] != "_":
words_with_punc.append(punc_list[punctuations[i]])
new_mini_sentence += "".join(words_with_punc)
# Add Period for the end of the sentence
new_mini_sentence_out = new_mini_sentence
new_mini_sentence_punc_out = new_mini_sentence_punc
if mini_sentence_i == len(mini_sentences)-1:
if new_mini_sentence[-1]=="" or new_mini_sentence[-1]=="":
new_mini_sentence_out = new_mini_sentence[:-1] + ""
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + str(period)
elif new_mini_sentence[-1]!="" and new_mini_sentence[-1]!="":
new_mini_sentence_out=new_mini_sentence+""
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + str(period)
item = {'key': key, 'value': new_mini_sentence_out}
results.append(item)
result, _ = text2punc(segs[1])
item = {'key': key, 'value': result}
results.append(item)
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
if output_path != None:
output_file_name = "infer.out"
@ -293,6 +250,7 @@ def inference_modelscope(
value_out = item_i["value"]
fout.write(f"{key_out}\t{value_out}\n")
return results
return _forward
@ -338,20 +296,12 @@ def get_parser():
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
action="append",
required=False
)
group.add_argument(
"--raw_inputs",
type=str,
required=False
)
group.add_argument("--data_path_and_name_and_type", type=str2triple_str, action="append", required=False)
group.add_argument("--raw_inputs", type=str, required=False)
group.add_argument("--cache", type=list, required=False)
group.add_argument("--param_dict", type=dict, required=False)
group.add_argument("--key_file", type=str_or_none)
group = parser.add_argument_group("The model configuration related")
group.add_argument("--train_config", type=str)
group.add_argument("--model_file", type=str)
@ -364,11 +314,9 @@ def main(cmd=None):
parser = get_parser()
args = parser.parse_args(cmd)
kwargs = vars(args)
# kwargs.pop("config", None)
# kwargs.pop("config", None)
inference(**kwargs)
if __name__ == "__main__":
main()

View File

@ -23,7 +23,5 @@ class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC):
"""
@abstractmethod
def forward(
self, input: torch.Tensor, hidden: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError

View File

@ -13,6 +13,7 @@ from funasr.train.abs_espnet_model import AbsESPnetModel
class ESPnetPunctuationModel(AbsESPnetModel):
def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0):
assert check_argument_types()
super().__init__()
@ -43,8 +44,8 @@ class ESPnetPunctuationModel(AbsESPnetModel):
batch_size = text.size(0)
# For data parallel
if max_length is None:
text = text[:, : text_lengths.max()]
punc = punc[:, : text_lengths.max()]
text = text[:, :text_lengths.max()]
punc = punc[:, :text_lengths.max()]
else:
text = text[:, :max_length]
punc = punc[:, :max_length]
@ -63,9 +64,11 @@ class ESPnetPunctuationModel(AbsESPnetModel):
# 3. Calc negative log likelihood
# nll: (BxL,)
if self.training == False:
_, indices = y.view(-1, y.shape[-1]).topk(1,dim=1)
_, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
from sklearn.metrics import f1_score
f1_score = f1_score(punc.view(-1).detach().cpu().numpy(), indices.squeeze(-1).detach().cpu().numpy(), average='micro')
f1_score = f1_score(punc.view(-1).detach().cpu().numpy(),
indices.squeeze(-1).detach().cpu().numpy(),
average='micro')
nll = torch.Tensor([f1_score]).repeat(text_lengths.sum())
return nll, text_lengths
else:
@ -82,14 +85,12 @@ class ESPnetPunctuationModel(AbsESPnetModel):
nll = nll.view(batch_size, -1)
return nll, text_lengths
def batchify_nll(
self,
text: torch.Tensor,
punc: torch.Tensor,
text_lengths: torch.Tensor,
punc_lengths: torch.Tensor,
batch_size: int = 100
) -> Tuple[torch.Tensor, torch.Tensor]:
def batchify_nll(self,
text: torch.Tensor,
punc: torch.Tensor,
text_lengths: torch.Tensor,
punc_lengths: torch.Tensor,
batch_size: int = 100) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute negative log likelihood(nll) from transformer language model
To avoid OOM, this fuction seperate the input into batches.
@ -117,9 +118,7 @@ class ESPnetPunctuationModel(AbsESPnetModel):
batch_punc = punc[start_idx:end_idx, :]
batch_text_lengths = text_lengths[start_idx:end_idx]
# batch_nll: [B * T]
batch_nll, batch_x_lengths = self.nll(
batch_text, batch_punc, batch_text_lengths, max_length=max_length
)
batch_nll, batch_x_lengths = self.nll(batch_text, batch_punc, batch_text_lengths, max_length=max_length)
nlls.append(batch_nll)
x_lengths.append(batch_x_lengths)
start_idx = end_idx
@ -131,21 +130,19 @@ class ESPnetPunctuationModel(AbsESPnetModel):
assert x_lengths.size(0) == total_num
return nll, x_lengths
def forward(
self, text: torch.Tensor, punc: torch.Tensor, text_lengths: torch.Tensor, punc_lengths: torch.Tensor
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
def forward(self, text: torch.Tensor, punc: torch.Tensor, text_lengths: torch.Tensor,
punc_lengths: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths)
ntokens = y_lengths.sum()
loss = nll.sum() / ntokens
stats = dict(loss=loss.detach())
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
return loss, stats, weight
def collect_feats(
self, text: torch.Tensor, punc: torch.Tensor, text_lengths: torch.Tensor
) -> Dict[str, torch.Tensor]:
def collect_feats(self, text: torch.Tensor, punc: torch.Tensor,
text_lengths: torch.Tensor) -> Dict[str, torch.Tensor]:
return {}
def inference(self, text: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:

View File

@ -14,6 +14,7 @@ from funasr.punctuation.abs_model import AbsPunctuation
class TargetDelayTransformer(AbsPunctuation):
def __init__(
self,
vocab_size: int,
@ -28,7 +29,7 @@ class TargetDelayTransformer(AbsPunctuation):
):
super().__init__()
if pos_enc == "sinusoidal":
# pos_enc_class = PositionalEncoding
# pos_enc_class = PositionalEncoding
pos_enc_class = SinusoidalPositionEncoder
elif pos_enc is None:
@ -47,17 +48,17 @@ class TargetDelayTransformer(AbsPunctuation):
num_blocks=layer,
dropout_rate=dropout_rate,
input_layer="pe",
# pos_enc_class=pos_enc_class,
# pos_enc_class=pos_enc_class,
padding_idx=0,
)
self.decoder = nn.Linear(att_unit, punc_size)
# def _target_mask(self, ys_in_pad):
# ys_mask = ys_in_pad != 0
# m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0)
# return ys_mask.unsqueeze(-2) & m
def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
"""Compute loss value from buffer sequences.
@ -67,14 +68,12 @@ class TargetDelayTransformer(AbsPunctuation):
"""
x = self.embed(input)
# mask = self._target_mask(input)
# mask = self._target_mask(input)
h, _, _ = self.encoder(x, text_lengths)
y = self.decoder(h)
return y, None
def score(
self, y: torch.Tensor, state: Any, x: torch.Tensor
) -> Tuple[torch.Tensor, Any]:
def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
"""Score new token.
Args:
@ -89,16 +88,12 @@ class TargetDelayTransformer(AbsPunctuation):
"""
y = y.unsqueeze(0)
h, _, cache = self.encoder.forward_one_step(
self.embed(y), self._target_mask(y), cache=state
)
h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
h = self.decoder(h[:, -1])
logp = h.log_softmax(dim=-1).squeeze(0)
return logp, cache
def batch_score(
self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
) -> Tuple[torch.Tensor, List[Any]]:
def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
"""Score new token batch.
Args:
@ -120,15 +115,10 @@ class TargetDelayTransformer(AbsPunctuation):
batch_state = None
else:
# transpose state of [batch, layer] into [layer, batch]
batch_state = [
torch.stack([states[b][i] for b in range(n_batch)])
for i in range(n_layers)
]
batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]
# batch decoding
h, _, states = self.encoder.forward_one_step(
self.embed(ys), self._target_mask(ys), cache=batch_state
)
h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
h = self.decoder(h[:, -1])
logp = h.log_softmax(dim=-1)

View File

@ -1,24 +1,3 @@
def split_words(text: str):
words = []
segs = text.split()
for seg in segs:
# There is no space in seg.
current_word = ""
for c in seg:
if len(c.encode()) == 1:
# This is an ASCII char.
current_word += c
else:
# This is a Chinese char.
if len(current_word) > 0:
words.append(current_word)
current_word = ""
words.append(c)
if len(current_word) > 0:
words.append(current_word)
return words
def split_to_mini_sentence(words: list, word_limit: int = 20):
assert word_limit > 1
if len(words) <= word_limit:

View File

@ -5,30 +5,52 @@ The audio data is in streaming, the asr inference process is in offline.
## Steps
Step 1) Prepare server environment (on server).
Step 1) Prepare server environment (on server).
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; Install modelscope and funasr with pip or with cuda-docker image.
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; Option 1: Install modelscope and funasr with [pip](https://github.com/alibaba-damo-academy/FunASR#installation)
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; Option 2: or install with cuda-docker image as:
```
# Optional, modelscope cuda docker is preferred.
CID=`docker run --network host -d -it --gpus '"device=0"' registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-cuda11.3.0-py37-torch1.11.0-tf1.15.5-1.2.0`
echo $CID
docker exec -it $CID /bin/bash
cd /opt/conda/lib/python3.7/site-packages/funasr/runtime/python/grpc
```
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; Get funasr source code and get into grpc directory.
```
git clone https://github.com/alibaba-damo-academy/FunASR
cd FunASR/funasr/runtime/python/grpc/
```
Step 2) Generate protobuf file (for server and client).
Step 2) Optional, generate protobuf file (run on server, the two generated pb files are both used for server and client).
```
# Optional, paraformer_pb2.py and paraformer_pb2_grpc.py are already generated.
# Optional, Install dependency.
python -m pip install grpcio grpcio-tools
```
```
# paraformer_pb2.py and paraformer_pb2_grpc.py are already generated,
# regenerate it only when you make changes to ./proto/paraformer.proto file.
python -m grpc_tools.protoc --proto_path=./proto -I ./proto --python_out=. --grpc_python_out=./ ./proto/paraformer.proto
```
Step 3) Start grpc server (on server).
```
# Optional, Install dependency.
python -m pip install grpcio grpcio-tools
```
```
# Start server.
python grpc_main_server.py --port 10095
```
Step 4) Start grpc client (on client with microphone).
```
# Install dependency. Optional.
python -m pip install pyaudio webrtcvad
# Optional, Install dependency.
python -m pip install pyaudio webrtcvad grpcio grpcio-tools
```
```
# Start client.
@ -41,7 +63,7 @@ python grpc_main_client_mic.py --host 127.0.0.1 --port 10095
## Reference
We borrow or refer to some code from:
We borrow from or refer to some code as:
1)https://github.com/wenet-e2e/wenet/tree/main/runtime/core/grpc