Mirror of https://github.com/modelscope/FunASR, synced 2025-09-15 14:48:36 +08:00
Merge pull request #218 from alibaba-damo-academy/dev_ts
update timestamp related codes and egs_modelscope
Commit 0a729038cf
@@ -0,0 +1,25 @@
# ModelScope Model

## How to finetune and infer using a pretrained ModelScope Model

### Inference

Or you can use the finetuned model for inference directly.

- Setting parameters in `infer.py`
    - <strong>audio_in:</strong> # supports wav, url, bytes, and parsed audio formats.
    - <strong>text_in:</strong> # supports plain text and a text url.
    - <strong>output_dir:</strong> # must be set if the input format is wav.scp.

- Then you can run the pipeline to infer with:
```shell
python infer.py
```

Modify inference-related parameters in `vad.yaml` (a short sketch of overriding them programmatically follows this list):

- max_end_silence_time: the trailing-silence duration used to decide that a sentence has ended; valid range 500 ms to 6000 ms, default 800 ms.
- speech_noise_thres: the balance between the speech and noise scores; valid range (-1, 1).
    - The closer the value is to -1, the more likely noise is judged as speech.
    - The closer the value is to 1, the more likely speech is judged as noise.
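The exact layout of `vad.yaml` varies with the downloaded model package, but as a minimal sketch (the file path and the `vad_post_conf` section name are assumptions, not taken from this PR) the two parameters above can be overridden programmatically before running inference:

```python
# Hedged sketch: adjust VAD endpointing in a local copy of vad.yaml.
# "./vad_model/vad.yaml" and the "vad_post_conf" key are assumed names;
# check the config actually shipped with your model.
import yaml

with open("./vad_model/vad.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

cfg["vad_post_conf"]["max_end_silence_time"] = 1200  # ms, valid range 500-6000, default 800
cfg["vad_post_conf"]["speech_noise_thres"] = 0.6     # closer to 1: more frames treated as noise

with open("./vad_model/vad.yaml", "w", encoding="utf-8") as f:
    yaml.safe_dump(cfg, f, allow_unicode=True)
```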
@@ -0,0 +1,12 @@
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

inference_pipline = pipeline(
    task=Tasks.speech_timestamp,
    model='damo/speech_timestamp_prediction-v1-16k-offline',
    output_dir='./tmp')

rec_result = inference_pipline(
    audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_timestamps.wav',
    text_in='一 个 东 太 平 洋 国 家 为 什 么 跑 到 西 太 平 洋 来 了 呢')
print(rec_result)
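For orientation, the result dict assembled by the underlying inference loop (see the `inference_modelscope` hunk further down in this diff) has the shape sketched below; the numbers are invented, and the exact keys exposed by the ModelScope pipeline wrapper may differ:

```python
# Illustration only -- values are made up; structure follows the tp inference code below:
# {
#     'key': 'asr_example_timestamps',
#     'value': '一 0.38 0.54;个 0.54 0.68;...',     # "token begin end;" string, times in seconds
#     'timestamp': [[380, 540], [540, 680], ...],   # [begin_ms, end_ms] per non-silence token
# }
```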
@@ -42,7 +42,7 @@ from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr.utils.timestamp_tools import time_stamp_lfr6_pl, time_stamp_sentence
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard


class Speech2Text:
@@ -245,7 +245,7 @@ class Speech2Text:
        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]

        if isinstance(self.asr_model, BiCifParaformer):
            _, _, us_alphas, us_cif_peak = self.asr_model.calc_predictor_timestamp(enc, enc_len,
            _, _, us_alphas, us_peaks = self.asr_model.calc_predictor_timestamp(enc, enc_len,
                                                                                pre_token_length)  # test no bias cif2

        results = []
@@ -291,7 +291,10 @@ class Speech2Text:
            text = None

            if isinstance(self.asr_model, BiCifParaformer):
                timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time)
                _, timestamp = ts_prediction_lfr6_standard(us_alphas[i],
                                                           us_peaks[i],
                                                           copy.copy(token),
                                                           vad_offset=begin_time)
                results.append((text, token, token_int, hyp, timestamp, enc_len_batch_total, lfr_factor))
            else:
                results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor))

@@ -44,11 +44,10 @@ from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tasks.vad import VADTask
from funasr.bin.vad_inference import Speech2VadSegment
from funasr.utils.timestamp_tools import time_stamp_lfr6_pl
from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
from funasr.bin.punctuation_infer import Text2Punc
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer

from funasr.utils.timestamp_tools import time_stamp_sentence

header_colors = '\033[95m'
end_colors = '\033[0m'
@@ -257,7 +256,7 @@ class Speech2Text:
        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]

        if isinstance(self.asr_model, BiCifParaformer):
            _, _, us_alphas, us_cif_peak = self.asr_model.calc_predictor_timestamp(enc, enc_len,
            _, _, us_alphas, us_peaks = self.asr_model.calc_predictor_timestamp(enc, enc_len,
                                                                                pre_token_length)  # test no bias cif2

        results = []
@@ -303,7 +302,10 @@ class Speech2Text:
            text = None

            if isinstance(self.asr_model, BiCifParaformer):
                timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time)
                _, timestamp = ts_prediction_lfr6_standard(us_alphas[i],
                                                           us_peaks[i],
                                                           copy.copy(token),
                                                           vad_offset=begin_time)
                results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor))
            else:
                results.append((text, token, token_int, enc_len_batch_total, lfr_factor))

@@ -28,7 +28,9 @@ def parse_args(mode):
    elif mode == "uniasr":
        from funasr.tasks.asr import ASRTaskUniASR as ASRTask
    elif mode == "mfcca":
        from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
        from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
    elif mode == "tp":
        from funasr.tasks.asr import ASRTaskAligner as ASRTask
    else:
        raise ValueError("Unknown mode: {}".format(mode))
    parser = ASRTask.get_parser()

@@ -28,6 +28,8 @@ from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.text.token_id_converter import TokenIDConverter
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard


header_colors = '\033[95m'
end_colors = '\033[0m'
@@ -38,61 +40,6 @@ global_sample_rate: Union[int, Dict[Any, int]] = {
    'model_fs': 16000
}

def time_stamp_lfr6_advance(us_alphas, us_cif_peak, char_list):
    START_END_THRESHOLD = 5
    MAX_TOKEN_DURATION = 12
    TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled
    if len(us_cif_peak.shape) == 2:
        alphas, cif_peak = us_alphas[0], us_cif_peak[0] # support inference batch_size=1 only
    else:
        alphas, cif_peak = us_alphas, us_cif_peak
    num_frames = cif_peak.shape[0]
    if char_list[-1] == '</s>':
        char_list = char_list[:-1]
    # char_list = [i for i in text]
    timestamp_list = []
    new_char_list = []
    # for bicif model trained with large data, cif2 actually fires when a character starts
    # so treat the frames between two peaks as the duration of the former token
    fire_place = torch.where(cif_peak>1.0-1e-4)[0].cpu().numpy() - 3.2 # total offset
    num_peak = len(fire_place)
    assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
    # begin silence
    if fire_place[0] > START_END_THRESHOLD:
        # char_list.insert(0, '<sil>')
        timestamp_list.append([0.0, fire_place[0]*TIME_RATE])
        new_char_list.append('<sil>')
    # tokens timestamp
    for i in range(len(fire_place)-1):
        new_char_list.append(char_list[i])
        if MAX_TOKEN_DURATION < 0 or fire_place[i+1] - fire_place[i] < MAX_TOKEN_DURATION:
            timestamp_list.append([fire_place[i]*TIME_RATE, fire_place[i+1]*TIME_RATE])
        else:
            # cut the duration to token and sil of the 0-weight frames last long
            _split = fire_place[i] + MAX_TOKEN_DURATION
            timestamp_list.append([fire_place[i]*TIME_RATE, _split*TIME_RATE])
            timestamp_list.append([_split*TIME_RATE, fire_place[i+1]*TIME_RATE])
            new_char_list.append('<sil>')
    # tail token and end silence
    # new_char_list.append(char_list[-1])
    if num_frames - fire_place[-1] > START_END_THRESHOLD:
        _end = (num_frames + fire_place[-1]) * 0.5
        # _end = fire_place[-1]
        timestamp_list[-1][1] = _end*TIME_RATE
        timestamp_list.append([_end*TIME_RATE, num_frames*TIME_RATE])
        new_char_list.append("<sil>")
    else:
        timestamp_list[-1][1] = num_frames*TIME_RATE
    assert len(new_char_list) == len(timestamp_list)
    res_str = ""
    for char, timestamp in zip(new_char_list, timestamp_list):
        res_str += "{} {} {};".format(char, str(timestamp[0]+0.0005)[:5], str(timestamp[1]+0.0005)[:5])
    res = []
    for char, timestamp in zip(new_char_list, timestamp_list):
        if char != '<sil>':
            res.append([int(timestamp[0] * 1000), int(timestamp[1] * 1000)])
    return res_str, res


class SpeechText2Timestamp:
    def __init__(
@@ -315,7 +262,7 @@ def inference_modelscope(
            for batch_id in range(_bs):
                key = keys[batch_id]
                token = speechtext2timestamp.converter.ids2tokens(batch['text'][batch_id])
                ts_str, ts_list = time_stamp_lfr6_advance(us_alphas[batch_id], us_cif_peak[batch_id], token)
                ts_str, ts_list = ts_prediction_lfr6_standard(us_alphas[batch_id], us_cif_peak[batch_id], token, force_time_shift=-3.0)
                logging.warning(ts_str)
                item = {'key': key, 'value': ts_str, 'timestamp':ts_list}
                tp_result_list.append(item)

@@ -926,10 +926,10 @@ class BiCifParaformer(Paraformer):
    def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        ds_alphas, ds_cif_peak, us_alphas, us_cif_peak = self.predictor.get_upsample_timestamp(encoder_out,
        ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
                                                                                            encoder_out_mask,
                                                                                            token_num)
        return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak
        return ds_alphas, ds_cif_peak, us_alphas, us_peaks

    def forward(
        self,
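For readers tracing the rename, a short hedged summary of how the four return values flow into the timestamp utilities after this change (the names follow the code above; the rate/shape notes are assumptions based on the `lfr6` naming and the 3x-upsampling comment in `timestamp_tools`):

```python
# Indicative only -- outputs of calc_predictor_timestamp after this change:
#   ds_alphas, ds_cif_peak : CIF weights and peaks at the encoder (LFR-6) frame rate
#   us_alphas, us_peaks    : the same signals upsampled (3x) for finer time resolution
#
# Downstream (see the Speech2Text hunks earlier in this diff), only the upsampled pair
# is consumed, one utterance at a time:
#   _, timestamp = ts_prediction_lfr6_standard(us_alphas[i], us_peaks[i],
#                                              copy.copy(token), vad_offset=begin_time)
```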
funasr/models/e2e_tp.py (new file, 175 lines)
@@ -0,0 +1,175 @@
import logging
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import torch
import numpy as np
from typeguard import check_argument_types

from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.predictor.cif import mae_loss
from funasr.modules.add_sos_eos import add_sos_eos
from funasr.modules.nets_utils import make_pad_mask, pad_list
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.models.predictor.cif import CifPredictorV3


if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield


class TimestampPredictor(AbsESPnetModel):
    """
    Author: Speech Lab, Alibaba Group, China
    """

    def __init__(
            self,
            frontend: Optional[AbsFrontend],
            encoder: AbsEncoder,
            predictor: CifPredictorV3,
            predictor_bias: int = 0,
            token_list=None,
    ):
        assert check_argument_types()

        super().__init__()
        # note that eos is the same as sos (equivalent ID)

        self.frontend = frontend
        self.encoder = encoder
        self.encoder.interctc_use_conditioning = False

        self.predictor = predictor
        self.predictor_bias = predictor_bias
        self.criterion_pre = mae_loss()
        self.token_list = token_list

    def forward(
            self,
            speech: torch.Tensor,
            speech_lengths: torch.Tensor,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]
        # for data-parallel
        text = text[:, : text_lengths.max()]
        speech = speech[:, :speech_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        if self.predictor_bias == 1:
            _, text = add_sos_eos(text, 1, 2, -1)
            text_lengths = text_lengths + self.predictor_bias
        _, _, _, _, pre_token_length2 = self.predictor(encoder_out, text, encoder_out_mask, ignore_id=-1)

        # loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
        loss_pre = self.criterion_pre(text_lengths.type_as(pre_token_length2), pre_token_length2)

        loss = loss_pre
        stats = dict()

        # Collect Attn branch stats
        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def encode(
            self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
        """
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

        # 4. Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)

        return encoder_out, encoder_out_lens

    def _extract_feats(
            self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]
        if self.frontend is not None:
            # Frontend
            # e.g. STFT and Feature extract
            # data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths

    def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
                                                                                            encoder_out_mask,
                                                                                            token_num)
        return ds_alphas, ds_cif_peak, us_alphas, us_peaks

    def collect_feats(
            self,
            speech: torch.Tensor,
            speech_lengths: torch.Tensor,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        if self.extract_feats_in_collect_stats:
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        else:
            # Generate dummy stats if extract_feats_in_collect_stats is False
            logging.warning(
                "Generating dummy stats for feats and feats_lengths, "
                "because encoder_conf.extract_feats_in_collect_stats is "
                f"{self.extract_feats_in_collect_stats}"
            )
            feats, feats_lengths = speech, speech_lengths
        return {"feats": feats, "feats_lengths": feats_lengths}
@@ -40,6 +40,7 @@ from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
from funasr.models.e2e_asr import ESPnetASRModel
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.encoder.abs_encoder import AbsEncoder
@@ -124,6 +125,7 @@ model_choices = ClassChoices(
        bicif_paraformer=BiCifParaformer,
        contextual_paraformer=ContextualParaformer,
        mfcca=MFCCA,
        timestamp_prediction=TimestampPredictor,
    ),
    type_check=AbsESPnetModel,
    default="asr",
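As a quick orientation, this registration lets the task resolve the new model type by name, the same way `ASRTaskAligner.build_model` does further down in this diff. A minimal sketch, assuming the `ClassChoices` instance shown here is bound to the module-level name `model_choices` as in other ESPnet-style task files:

```python
# Hedged sketch: look up the newly registered model class by its choice name.
from funasr.tasks.asr import model_choices           # assumed to be module-level
from funasr.models.e2e_tp import TimestampPredictor

assert model_choices.get_class("timestamp_prediction") is TimestampPredictor
```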
@@ -1245,9 +1247,87 @@ class ASRTaskMFCCA(ASRTask):


class ASRTaskAligner(ASRTaskParaformer):
    # If you need more than one optimizers, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --model and --model_conf
        model_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --decoder and --decoder_conf
        decoder_choices,
    ]

    # If you need to modify train() or eval() procedures, change Trainer class here
    trainer = Trainer

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]

            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == 'wav_frontend':
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)

        # 3. Predictor
        predictor_class = predictor_choices.get_class(args.predictor)
        predictor = predictor_class(**args.predictor_conf)

        # 10. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("asr")

        # 8. Build model
        model = model_class(
            frontend=frontend,
            encoder=encoder,
            predictor=predictor,
            token_list=token_list,
            **args.model_conf,
        )

        # 11. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model

    @classmethod
    def required_data_names(
            cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        retval = ("speech", "text")
        return retval
        return retval

@@ -5,55 +5,69 @@ import numpy as np
from typing import Any, List, Tuple, Union


def time_stamp_lfr6_pl(us_alphas, us_cif_peak, char_list, begin_time=0.0, end_time=None):
def ts_prediction_lfr6_standard(us_alphas,
                                us_peaks,
                                char_list,
                                vad_offset=0.0,
                                force_time_shift=-1.5
                                ):
    if not len(char_list):
        return []
    START_END_THRESHOLD = 5
    MAX_TOKEN_DURATION = 12
    TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled
    if len(us_alphas.shape) == 3:
        alphas, cif_peak = us_alphas[0], us_cif_peak[0] # support inference batch_size=1 only
    if len(us_alphas.shape) == 2:
        _, peaks = us_alphas[0], us_peaks[0] # support inference batch_size=1 only
    else:
        alphas, cif_peak = us_alphas, us_cif_peak
    num_frames = cif_peak.shape[0]
        _, peaks = us_alphas, us_peaks
    num_frames = peaks.shape[0]
    if char_list[-1] == '</s>':
        char_list = char_list[:-1]
    # char_list = [i for i in text]
    timestamp_list = []
    new_char_list = []
    # for bicif model trained with large data, cif2 actually fires when a character starts
    # so treat the frames between two peaks as the duration of the former token
    fire_place = torch.where(cif_peak>1.0-1e-4)[0].cpu().numpy() - 1.5
    fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset
    num_peak = len(fire_place)
    assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
    # begin silence
    if fire_place[0] > START_END_THRESHOLD:
        char_list.insert(0, '<sil>')
        # char_list.insert(0, '<sil>')
        timestamp_list.append([0.0, fire_place[0]*TIME_RATE])
        new_char_list.append('<sil>')
    # tokens timestamp
    for i in range(len(fire_place)-1):
        # the peak is always a little ahead of the start time
        # timestamp_list.append([(fire_place[i]-1.2)*TIME_RATE, fire_place[i+1]*TIME_RATE])
        timestamp_list.append([(fire_place[i])*TIME_RATE, fire_place[i+1]*TIME_RATE])
        # cut the duration to token and sil of the 0-weight frames last long
        new_char_list.append(char_list[i])
        if MAX_TOKEN_DURATION < 0 or fire_place[i+1] - fire_place[i] <= MAX_TOKEN_DURATION:
            timestamp_list.append([fire_place[i]*TIME_RATE, fire_place[i+1]*TIME_RATE])
        else:
            # cut the duration to token and sil of the 0-weight frames last long
            _split = fire_place[i] + MAX_TOKEN_DURATION
            timestamp_list.append([fire_place[i]*TIME_RATE, _split*TIME_RATE])
            timestamp_list.append([_split*TIME_RATE, fire_place[i+1]*TIME_RATE])
            new_char_list.append('<sil>')
    # tail token and end silence
    # new_char_list.append(char_list[-1])
    if num_frames - fire_place[-1] > START_END_THRESHOLD:
        _end = (num_frames + fire_place[-1]) / 2
        _end = (num_frames + fire_place[-1]) * 0.5
        # _end = fire_place[-1]
        timestamp_list[-1][1] = _end*TIME_RATE
        timestamp_list.append([_end*TIME_RATE, num_frames*TIME_RATE])
        char_list.append("<sil>")
        new_char_list.append("<sil>")
    else:
        timestamp_list[-1][1] = num_frames*TIME_RATE
    if begin_time: # add offset time in model with vad
    if vad_offset: # add offset time in model with vad
        for i in range(len(timestamp_list)):
            timestamp_list[i][0] = timestamp_list[i][0] + begin_time / 1000.0
            timestamp_list[i][1] = timestamp_list[i][1] + begin_time / 1000.0
            timestamp_list[i][0] = timestamp_list[i][0] + vad_offset / 1000.0
            timestamp_list[i][1] = timestamp_list[i][1] + vad_offset / 1000.0
    res_txt = ""
    for char, timestamp in zip(char_list, timestamp_list):
        res_txt += "{} {} {};".format(char, timestamp[0], timestamp[1])
    for char, timestamp in zip(new_char_list, timestamp_list):
        res_txt += "{} {} {};".format(char, str(timestamp[0]+0.0005)[:5], str(timestamp[1]+0.0005)[:5])
    res = []
    for char, timestamp in zip(char_list, timestamp_list):
    for char, timestamp in zip(new_char_list, timestamp_list):
        if char != '<sil>':
            res.append([int(timestamp[0] * 1000), int(timestamp[1] * 1000)])
    return res
    return res_txt, res


def time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed):
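To make the new helper's behaviour concrete, here is a small, self-contained sketch that feeds synthetic CIF peaks into `ts_prediction_lfr6_standard`; the peak positions and token list are invented, and real inputs come from `calc_predictor_timestamp` as in the hunks above:

```python
# Hedged sketch with synthetic data: three tokens plus the trailing boundary peak,
# firing at upsampled frames 10, 25, 40 and 55 of an 80-frame utterance.
import torch
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard

peaks = torch.zeros(80)
peaks[[10, 25, 40, 55]] = 1.0    # number of peaks must be number of tokens + 1
alphas = torch.ones_like(peaks)  # alphas are accepted but not used by the standard prediction

ts_str, ts_list = ts_prediction_lfr6_standard(alphas, peaks, ['一', '个', '东'])
print(ts_str)   # "token begin end;" string with times in seconds, including <sil> entries
print(ts_list)  # [[begin_ms, end_ms], ...] for the non-silence tokens only
```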