update code

This commit is contained in:
shixian.shi 2024-01-05 16:11:04 +08:00
parent e9a015e79a
commit ab122d5652

View File

@ -4,20 +4,17 @@ import torch
from torch.cuda.amp import autocast from torch.cuda.amp import autocast
from typing import Union, Dict, List, Tuple, Optional from typing import Union, Dict, List, Tuple, Optional
from funasr.models.paraformer.cif_predictor import mae_loss from funasr.register import tables
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.metrics.compute_acc import th_accuracy
from funasr.train_utils.device_funcs import force_gatherable
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.utils import postprocess_utils from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter from funasr.utils.datadir_writer import DatadirWriter
from funasr.register import tables from funasr.models.paraformer.cif_predictor import mae_loss
from funasr.models.ctc.ctc import CTC from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.utils.load_utils import load_audio_and_text_image_video, extract_fbank, load_audio_and_text_image_video from funasr.utils.load_utils import load_audio_and_text_image_video, extract_fbank, load_audio_and_text_image_video
@tables.register("model_classes", "monotonicaligner") @tables.register("model_classes", "monotonicaligner")
class MonotonicAligner(torch.nn.Module): class MonotonicAligner(torch.nn.Module):
""" """
@ -25,7 +22,6 @@ class MonotonicAligner(torch.nn.Module):
Achieving timestamp prediction while recognizing with non-autoregressive end-to-end ASR model Achieving timestamp prediction while recognizing with non-autoregressive end-to-end ASR model
https://arxiv.org/abs/2301.12343 https://arxiv.org/abs/2301.12343
""" """
def __init__( def __init__(
self, self,
input_size: int = 80, input_size: int = 80,
@ -41,7 +37,6 @@ class MonotonicAligner(torch.nn.Module):
length_normalized_loss: bool = False, length_normalized_loss: bool = False,
**kwargs, **kwargs,
): ):
super().__init__() super().__init__()
if specaug is not None: if specaug is not None:
@ -155,7 +150,6 @@ class MonotonicAligner(torch.nn.Module):
frontend=None, frontend=None,
**kwargs, **kwargs,
): ):
meta_data = {} meta_data = {}
# extract fbank feats # extract fbank feats
time1 = time.perf_counter() time1 = time.perf_counter()
@ -190,8 +184,7 @@ class MonotonicAligner(torch.nn.Module):
timestamp_str, timestamp = ts_prediction_lfr6_standard(us_alpha[:encoder_out_lens[i] * 3], timestamp_str, timestamp = ts_prediction_lfr6_standard(us_alpha[:encoder_out_lens[i] * 3],
us_peak[:encoder_out_lens[i] * 3], us_peak[:encoder_out_lens[i] * 3],
copy.copy(token)) copy.copy(token))
text_postprocessed, time_stamp_postprocessed, word_lists = postprocess_utils.sentence_postprocess( text_postprocessed, time_stamp_postprocessed, _ = postprocess_utils.sentence_postprocess(token, timestamp)
token, timestamp)
result_i = {"key": key[i], "text": text_postprocessed, result_i = {"key": key[i], "text": text_postprocessed,
"timestamp": time_stamp_postprocessed, "timestamp": time_stamp_postprocessed,
} }