Merge pull request #218 from alibaba-damo-academy/dev_ts

update timestamp related codes and egs_modelscope
This commit is contained in:
zhifu gao 2023-03-13 17:47:56 +08:00 committed by GitHub
commit 0a729038cf
10 changed files with 348 additions and 88 deletions

View File

@ -0,0 +1,25 @@
# ModelScope Model
## How to finetune and infer using a pretrained ModelScope Model
### Inference
You can also use the finetuned model for inference directly.
- Set the parameters in `infer.py` (a sketch for the `wav.scp` case follows the run command below):
  - <strong>audio_in:</strong> the input audio; a wav file path, URL, bytes, or already-parsed audio data is supported.
  - <strong>text_in:</strong> the transcript to align; raw text or a text URL is supported.
  - <strong>output_dir:</strong> must be set when the input is a `wav.scp` list.
- Then run the pipeline for inference:
```shell
python infer.py
```
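For illustration only, a hedged sketch of how those parameters might be filled in for a batch (`wav.scp`) run; the paths are placeholders, and treating `text_in` as a matching transcript list is an assumption rather than something this commit specifies:
```python
# Hypothetical settings inside infer.py for a wav.scp batch run (paths are placeholders).
audio_in = './data/test/wav.scp'   # one "utt_id wav_path" pair per line
text_in = './data/test/text'       # matching transcripts to align, one utterance per line (assumption)
output_dir = './results'           # required when the input is a wav.scp list
```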
Modify the inference-related parameters in `vad.yaml` (see the snippet after this list for overriding them programmatically):
- max_end_silence_time: trailing-silence duration used to decide that a sentence has ended; valid range 500 ms~6000 ms, default 800 ms.
- speech_noise_thres: threshold balancing the speech and noise scores; valid range (-1, 1).
  - The closer the value is to -1, the more likely noise is judged as speech.
  - The closer the value is to 1, the more likely speech is judged as noise.
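A minimal sketch for overriding the two values above programmatically, assuming `vad.yaml` is plain YAML readable with PyYAML and that the keys sit at the top level (in the real config they may be nested under a post-processing section):
```python
import yaml  # PyYAML

# Read vad.yaml, adjust the two end-pointing parameters, and write the file back.
with open('vad.yaml', 'r', encoding='utf-8') as f:
    cfg = yaml.safe_load(f)

cfg['max_end_silence_time'] = 1200   # ms: wait longer before declaring the end of a sentence
cfg['speech_noise_thres'] = 0.6      # values toward 1 bias borderline frames toward noise

with open('vad.yaml', 'w', encoding='utf-8') as f:
    yaml.safe_dump(cfg, f)
```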

View File

@ -0,0 +1,12 @@
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

inference_pipline = pipeline(
    task=Tasks.speech_timestamp,
    model='damo/speech_timestamp_prediction-v1-16k-offline',
    output_dir='./tmp')

rec_result = inference_pipline(
    audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_timestamps.wav',
    text_in='一 个 东 太 平 洋 国 家 为 什 么 跑 到 西 太 平 洋 来 了 呢')

print(rec_result)
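(The sample transcript means roughly "why has a country from the East Pacific come to the West Pacific?".) A hedged sketch of consuming the result; the key names mirror the `{'key': ..., 'value': ts_str, 'timestamp': ts_list}` item assembled in the timestamp inference code later in this commit, but the exact keys exposed by the ModelScope pipeline are an assumption here:
```python
# Hedged sketch: 'timestamp' is assumed to hold per-token [start_ms, end_ms] pairs,
# as produced by ts_prediction_lfr6_standard (silences already dropped).
for start_ms, end_ms in rec_result.get('timestamp', []):
    print(f'{start_ms} ms -> {end_ms} ms')
```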

View File

@ -42,7 +42,7 @@ from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
- from funasr.utils.timestamp_tools import time_stamp_lfr6_pl, time_stamp_sentence
+ from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
class Speech2Text:
@ -245,7 +245,7 @@ class Speech2Text:
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
if isinstance(self.asr_model, BiCifParaformer):
- _, _, us_alphas, us_cif_peak = self.asr_model.calc_predictor_timestamp(enc, enc_len,
+ _, _, us_alphas, us_peaks = self.asr_model.calc_predictor_timestamp(enc, enc_len,
pre_token_length) # test no bias cif2
results = []
@ -291,7 +291,10 @@ class Speech2Text:
text = None
if isinstance(self.asr_model, BiCifParaformer):
- timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time)
+ _, timestamp = ts_prediction_lfr6_standard(us_alphas[i],
+ us_peaks[i],
+ copy.copy(token),
+ vad_offset=begin_time)
results.append((text, token, token_int, hyp, timestamp, enc_len_batch_total, lfr_factor))
else:
results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor))

View File

@ -44,11 +44,10 @@ from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tasks.vad import VADTask
from funasr.bin.vad_inference import Speech2VadSegment
- from funasr.utils.timestamp_tools import time_stamp_lfr6_pl
+ from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
from funasr.bin.punctuation_infer import Text2Punc
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
- from funasr.utils.timestamp_tools import time_stamp_sentence
header_colors = '\033[95m'
end_colors = '\033[0m'
@ -257,7 +256,7 @@ class Speech2Text:
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
if isinstance(self.asr_model, BiCifParaformer):
- _, _, us_alphas, us_cif_peak = self.asr_model.calc_predictor_timestamp(enc, enc_len,
+ _, _, us_alphas, us_peaks = self.asr_model.calc_predictor_timestamp(enc, enc_len,
pre_token_length) # test no bias cif2
results = []
@ -303,7 +302,10 @@ class Speech2Text:
text = None
if isinstance(self.asr_model, BiCifParaformer):
- timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time)
+ _, timestamp = ts_prediction_lfr6_standard(us_alphas[i],
+ us_peaks[i],
+ copy.copy(token),
+ vad_offset=begin_time)
results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor))
else:
results.append((text, token, token_int, enc_len_batch_total, lfr_factor))

View File

@ -28,7 +28,9 @@ def parse_args(mode):
elif mode == "uniasr":
from funasr.tasks.asr import ASRTaskUniASR as ASRTask
elif mode == "mfcca":
- from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
+ from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
+ elif mode == "tp":
+ from funasr.tasks.asr import ASRTaskAligner as ASRTask
else:
raise ValueError("Unknown mode: {}".format(mode))
parser = ASRTask.get_parser()

View File

@ -28,6 +28,8 @@ from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.text.token_id_converter import TokenIDConverter
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
header_colors = '\033[95m'
end_colors = '\033[0m'
@ -38,61 +40,6 @@ global_sample_rate: Union[int, Dict[Any, int]] = {
'model_fs': 16000
}
def time_stamp_lfr6_advance(us_alphas, us_cif_peak, char_list):
START_END_THRESHOLD = 5
MAX_TOKEN_DURATION = 12
TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled
if len(us_cif_peak.shape) == 2:
alphas, cif_peak = us_alphas[0], us_cif_peak[0] # support inference batch_size=1 only
else:
alphas, cif_peak = us_alphas, us_cif_peak
num_frames = cif_peak.shape[0]
if char_list[-1] == '</s>':
char_list = char_list[:-1]
# char_list = [i for i in text]
timestamp_list = []
new_char_list = []
# for bicif model trained with large data, cif2 actually fires when a character starts
# so treat the frames between two peaks as the duration of the former token
fire_place = torch.where(cif_peak>1.0-1e-4)[0].cpu().numpy() - 3.2 # total offset
num_peak = len(fire_place)
assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
# begin silence
if fire_place[0] > START_END_THRESHOLD:
# char_list.insert(0, '<sil>')
timestamp_list.append([0.0, fire_place[0]*TIME_RATE])
new_char_list.append('<sil>')
# tokens timestamp
for i in range(len(fire_place)-1):
new_char_list.append(char_list[i])
if MAX_TOKEN_DURATION < 0 or fire_place[i+1] - fire_place[i] < MAX_TOKEN_DURATION:
timestamp_list.append([fire_place[i]*TIME_RATE, fire_place[i+1]*TIME_RATE])
else:
# cut the duration to token and sil of the 0-weight frames last long
_split = fire_place[i] + MAX_TOKEN_DURATION
timestamp_list.append([fire_place[i]*TIME_RATE, _split*TIME_RATE])
timestamp_list.append([_split*TIME_RATE, fire_place[i+1]*TIME_RATE])
new_char_list.append('<sil>')
# tail token and end silence
# new_char_list.append(char_list[-1])
if num_frames - fire_place[-1] > START_END_THRESHOLD:
_end = (num_frames + fire_place[-1]) * 0.5
# _end = fire_place[-1]
timestamp_list[-1][1] = _end*TIME_RATE
timestamp_list.append([_end*TIME_RATE, num_frames*TIME_RATE])
new_char_list.append("<sil>")
else:
timestamp_list[-1][1] = num_frames*TIME_RATE
assert len(new_char_list) == len(timestamp_list)
res_str = ""
for char, timestamp in zip(new_char_list, timestamp_list):
res_str += "{} {} {};".format(char, str(timestamp[0]+0.0005)[:5], str(timestamp[1]+0.0005)[:5])
res = []
for char, timestamp in zip(new_char_list, timestamp_list):
if char != '<sil>':
res.append([int(timestamp[0] * 1000), int(timestamp[1] * 1000)])
return res_str, res
class SpeechText2Timestamp:
def __init__(
@ -315,7 +262,7 @@ def inference_modelscope(
for batch_id in range(_bs):
key = keys[batch_id]
token = speechtext2timestamp.converter.ids2tokens(batch['text'][batch_id])
- ts_str, ts_list = time_stamp_lfr6_advance(us_alphas[batch_id], us_cif_peak[batch_id], token)
+ ts_str, ts_list = ts_prediction_lfr6_standard(us_alphas[batch_id], us_cif_peak[batch_id], token, force_time_shift=-3.0)
logging.warning(ts_str)
item = {'key': key, 'value': ts_str, 'timestamp':ts_list}
tp_result_list.append(item)
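For orientation: the frame-to-seconds conversion used by these helpers is TIME_RATE = 10.0 * 6 / 1000 / 3 = 0.02 s per upsampled frame (10 ms frames, LFR factor 6, 3x upsampling per the code comment), so the `force_time_shift=-3.0` passed above moves every fire point three upsampled frames, i.e. 60 ms, earlier. A quick arithmetic check:
```python
# Arithmetic check of the conversion constant used by ts_prediction_lfr6_standard.
TIME_RATE = 10.0 * 6 / 1000 / 3        # 0.02 s per 3x-upsampled LFR6 frame
fire_frame = 50                        # a CIF peak firing at upsampled frame 50
print(fire_frame * TIME_RATE)          # -> about 1.0 s without any shift
print((fire_frame - 3.0) * TIME_RATE)  # -> about 0.94 s with force_time_shift=-3.0 (60 ms earlier)
```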

View File

@ -926,10 +926,10 @@ class BiCifParaformer(Paraformer):
def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
- ds_alphas, ds_cif_peak, us_alphas, us_cif_peak = self.predictor.get_upsample_timestamp(encoder_out,
+ ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
encoder_out_mask,
token_num)
- return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak
+ return ds_alphas, ds_cif_peak, us_alphas, us_peaks
def forward(
self,

funasr/models/e2e_tp.py (new file, 175 lines)
View File

@ -0,0 +1,175 @@
import logging
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import torch
import numpy as np
from typeguard import check_argument_types
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.predictor.cif import mae_loss
from funasr.modules.add_sos_eos import add_sos_eos
from funasr.modules.nets_utils import make_pad_mask, pad_list
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.models.predictor.cif import CifPredictorV3
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
else:
# Nothing to do if torch<1.6.0
@contextmanager
def autocast(enabled=True):
yield
class TimestampPredictor(AbsESPnetModel):
"""
Author: Speech Lab, Alibaba Group, China
"""
def __init__(
self,
frontend: Optional[AbsFrontend],
encoder: AbsEncoder,
predictor: CifPredictorV3,
predictor_bias: int = 0,
token_list=None,
):
assert check_argument_types()
super().__init__()
# note that eos is the same as sos (equivalent ID)
self.frontend = frontend
self.encoder = encoder
self.encoder.interctc_use_conditioning = False
self.predictor = predictor
self.predictor_bias = predictor_bias
self.criterion_pre = mae_loss()
self.token_list = token_list
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
assert (
speech.shape[0]
== speech_lengths.shape[0]
== text.shape[0]
== text_lengths.shape[0]
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
batch_size = speech.shape[0]
# for data-parallel
text = text[:, : text_lengths.max()]
speech = speech[:, :speech_lengths.max()]
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
if self.predictor_bias == 1:
_, text = add_sos_eos(text, 1, 2, -1)
text_lengths = text_lengths + self.predictor_bias
_, _, _, _, pre_token_length2 = self.predictor(encoder_out, text, encoder_out_mask, ignore_id=-1)
# loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
loss_pre = self.criterion_pre(text_lengths.type_as(pre_token_length2), pre_token_length2)
loss = loss_pre
stats = dict()
# Collect Attn branch stats
stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
stats["loss"] = torch.clone(loss.detach())
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
def encode(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
"""
with autocast(False):
# 1. Extract feats
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
# 4. Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
return encoder_out, encoder_out_lens
def _extract_feats(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
assert speech_lengths.dim() == 1, speech_lengths.shape
# for data-parallel
speech = speech[:, : speech_lengths.max()]
if self.frontend is not None:
# Frontend
# e.g. STFT and Feature extract
# data_loader may send time-domain signal in this case
# speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
feats, feats_lengths = self.frontend(speech, speech_lengths)
else:
# No frontend and no feature extract
feats, feats_lengths = speech, speech_lengths
return feats, feats_lengths
def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
encoder_out_mask,
token_num)
return ds_alphas, ds_cif_peak, us_alphas, us_peaks
def collect_feats(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
) -> Dict[str, torch.Tensor]:
if self.extract_feats_in_collect_stats:
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
else:
# Generate dummy stats if extract_feats_in_collect_stats is False
logging.warning(
"Generating dummy stats for feats and feats_lengths, "
"because encoder_conf.extract_feats_in_collect_stats is "
f"{self.extract_feats_in_collect_stats}"
)
feats, feats_lengths = speech, speech_lengths
return {"feats": feats, "feats_lengths": feats_lengths}

View File

@ -40,6 +40,7 @@ from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
from funasr.models.e2e_asr import ESPnetASRModel
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.encoder.abs_encoder import AbsEncoder
@ -124,6 +125,7 @@ model_choices = ClassChoices(
bicif_paraformer=BiCifParaformer,
contextual_paraformer=ContextualParaformer,
mfcca=MFCCA,
timestamp_prediction=TimestampPredictor,
),
type_check=AbsESPnetModel,
default="asr",
@ -1245,9 +1247,87 @@ class ASRTaskMFCCA(ASRTask):
class ASRTaskAligner(ASRTaskParaformer):
# If you need more than one optimizers, change this value
num_optimizers: int = 1
# Add variable objects configurations
class_choices_list = [
# --frontend and --frontend_conf
frontend_choices,
# --model and --model_conf
model_choices,
# --encoder and --encoder_conf
encoder_choices,
# --decoder and --decoder_conf
decoder_choices,
]
# If you need to modify train() or eval() procedures, change Trainer class here
trainer = Trainer
@classmethod
def build_model(cls, args: argparse.Namespace):
assert check_argument_types()
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
# Overwriting token_list to keep it as "portable".
args.token_list = list(token_list)
elif isinstance(args.token_list, (tuple, list)):
token_list = list(args.token_list)
else:
raise RuntimeError("token_list must be str or list")
# 1. frontend
if args.input_size is None:
# Extract features in the model
frontend_class = frontend_choices.get_class(args.frontend)
if args.frontend == 'wav_frontend':
frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
else:
frontend = frontend_class(**args.frontend_conf)
input_size = frontend.output_size()
else:
# Give features from data-loader
args.frontend = None
args.frontend_conf = {}
frontend = None
input_size = args.input_size
# 2. Encoder
encoder_class = encoder_choices.get_class(args.encoder)
encoder = encoder_class(input_size=input_size, **args.encoder_conf)
# 3. Predictor
predictor_class = predictor_choices.get_class(args.predictor)
predictor = predictor_class(**args.predictor_conf)
# 10. Build model
try:
model_class = model_choices.get_class(args.model)
except AttributeError:
model_class = model_choices.get_class("asr")
# 8. Build model
model = model_class(
frontend=frontend,
encoder=encoder,
predictor=predictor,
token_list=token_list,
**args.model_conf,
)
# 11. Initialize
if args.init is not None:
initialize(model, args.init)
assert check_return_type(model)
return model
@classmethod
def required_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
retval = ("speech", "text")
return retval
return retval

View File

@ -5,55 +5,69 @@ import numpy as np
from typing import Any, List, Tuple, Union
- def time_stamp_lfr6_pl(us_alphas, us_cif_peak, char_list, begin_time=0.0, end_time=None):
+ def ts_prediction_lfr6_standard(us_alphas,
+ us_peaks,
+ char_list,
+ vad_offset=0.0,
+ force_time_shift=-1.5
+ ):
if not len(char_list):
return []
START_END_THRESHOLD = 5
MAX_TOKEN_DURATION = 12
TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled
- if len(us_alphas.shape) == 3:
- alphas, cif_peak = us_alphas[0], us_cif_peak[0] # support inference batch_size=1 only
+ if len(us_alphas.shape) == 2:
+ _, peaks = us_alphas[0], us_peaks[0] # support inference batch_size=1 only
else:
- alphas, cif_peak = us_alphas, us_cif_peak
- num_frames = cif_peak.shape[0]
+ _, peaks = us_alphas, us_peaks
+ num_frames = peaks.shape[0]
if char_list[-1] == '</s>':
char_list = char_list[:-1]
# char_list = [i for i in text]
timestamp_list = []
new_char_list = []
# for bicif model trained with large data, cif2 actually fires when a character starts
# so treat the frames between two peaks as the duration of the former token
- fire_place = torch.where(cif_peak>1.0-1e-4)[0].cpu().numpy() - 1.5
+ fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset
num_peak = len(fire_place)
assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
# begin silence
if fire_place[0] > START_END_THRESHOLD:
- char_list.insert(0, '<sil>')
+ # char_list.insert(0, '<sil>')
timestamp_list.append([0.0, fire_place[0]*TIME_RATE])
new_char_list.append('<sil>')
# tokens timestamp
for i in range(len(fire_place)-1):
- # the peak is always a little ahead of the start time
- # timestamp_list.append([(fire_place[i]-1.2)*TIME_RATE, fire_place[i+1]*TIME_RATE])
- timestamp_list.append([(fire_place[i])*TIME_RATE, fire_place[i+1]*TIME_RATE])
- # cut the duration to token and sil of the 0-weight frames last long
+ new_char_list.append(char_list[i])
+ if MAX_TOKEN_DURATION < 0 or fire_place[i+1] - fire_place[i] <= MAX_TOKEN_DURATION:
+ timestamp_list.append([fire_place[i]*TIME_RATE, fire_place[i+1]*TIME_RATE])
+ else:
+ # cut the duration to token and sil of the 0-weight frames last long
+ _split = fire_place[i] + MAX_TOKEN_DURATION
+ timestamp_list.append([fire_place[i]*TIME_RATE, _split*TIME_RATE])
+ timestamp_list.append([_split*TIME_RATE, fire_place[i+1]*TIME_RATE])
+ new_char_list.append('<sil>')
# tail token and end silence
# new_char_list.append(char_list[-1])
if num_frames - fire_place[-1] > START_END_THRESHOLD:
- _end = (num_frames + fire_place[-1]) / 2
+ _end = (num_frames + fire_place[-1]) * 0.5
# _end = fire_place[-1]
timestamp_list[-1][1] = _end*TIME_RATE
timestamp_list.append([_end*TIME_RATE, num_frames*TIME_RATE])
char_list.append("<sil>")
new_char_list.append("<sil>")
else:
timestamp_list[-1][1] = num_frames*TIME_RATE
- if begin_time: # add offset time in model with vad
+ if vad_offset: # add offset time in model with vad
for i in range(len(timestamp_list)):
- timestamp_list[i][0] = timestamp_list[i][0] + begin_time / 1000.0
- timestamp_list[i][1] = timestamp_list[i][1] + begin_time / 1000.0
+ timestamp_list[i][0] = timestamp_list[i][0] + vad_offset / 1000.0
+ timestamp_list[i][1] = timestamp_list[i][1] + vad_offset / 1000.0
res_txt = ""
- for char, timestamp in zip(char_list, timestamp_list):
- res_txt += "{} {} {};".format(char, timestamp[0], timestamp[1])
+ for char, timestamp in zip(new_char_list, timestamp_list):
+ res_txt += "{} {} {};".format(char, str(timestamp[0]+0.0005)[:5], str(timestamp[1]+0.0005)[:5])
res = []
- for char, timestamp in zip(char_list, timestamp_list):
+ for char, timestamp in zip(new_char_list, timestamp_list):
if char != '<sil>':
res.append([int(timestamp[0] * 1000), int(timestamp[1] * 1000)])
- return res
+ return res_txt, res
def time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed):
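To make the renamed helper concrete, here is a hedged usage sketch with synthetic tensors (batch size 1, made-up peak positions; the values shown in the comments are approximate):
```python
import torch
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard

num_frames = 60                    # 60 upsampled frames = 1.2 s at 20 ms per frame
peaks = torch.zeros(num_frames)
peaks[[10, 20, 30]] = 1.0          # three CIF-2 fire points -> two tokens plus the final boundary
alphas = torch.zeros(num_frames)   # alphas are passed through but not used for the timestamps here

ts_str, ts_list = ts_prediction_lfr6_standard(alphas, peaks, ['你', '好'], vad_offset=0.0)
print(ts_str)   # roughly "<sil> 0.000 0.170;你 0.170 0.370;好 0.370 0.885;<sil> 0.885 1.200;"
print(ts_list)  # roughly [[170, 370], [370, 885]] -- per-token [start_ms, end_ms], silences dropped
```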