update code

shixian.shi 2024-01-05 16:11:04 +08:00
parent e9a015e79a
commit ab122d5652


@@ -4,20 +4,17 @@ import torch
from torch.cuda.amp import autocast
from typing import Union, Dict, List, Tuple, Optional
from funasr.models.paraformer.cif_predictor import mae_loss
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.metrics.compute_acc import th_accuracy
from funasr.train_utils.device_funcs import force_gatherable
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.register import tables
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.ctc.ctc import CTC
from funasr.utils.load_utils import load_audio_and_text_image_video, extract_fbank

@tables.register("model_classes", "monotonicaligner")
class MonotonicAligner(torch.nn.Module):
"""
@@ -25,7 +22,6 @@ class MonotonicAligner(torch.nn.Module):
Achieving timestamp prediction while recognizing with non-autoregressive end-to-end ASR model
https://arxiv.org/abs/2301.12343
"""
def __init__(
self,
input_size: int = 80,
@@ -41,7 +37,6 @@ class MonotonicAligner(torch.nn.Module):
length_normalized_loss: bool = False,
**kwargs,
):
super().__init__()
if specaug is not None:
@@ -155,7 +150,6 @@ class MonotonicAligner(torch.nn.Module):
frontend=None,
**kwargs,
):
meta_data = {}
# extract fbank feats
time1 = time.perf_counter()
@@ -190,8 +184,7 @@ class MonotonicAligner(torch.nn.Module):
timestamp_str, timestamp = ts_prediction_lfr6_standard(us_alpha[:encoder_out_lens[i] * 3],
us_peak[:encoder_out_lens[i] * 3],
copy.copy(token))
text_postprocessed, time_stamp_postprocessed, _ = postprocess_utils.sentence_postprocess(token, timestamp)
result_i = {"key": key[i], "text": text_postprocessed,
"timestamp": time_stamp_postprocessed,
}
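
For reference, a minimal usage sketch of the timestamp predictor this file implements, assuming FunASR's AutoModel wrapper resolves a timestamp-prediction checkpoint such as "fa-zh" to the MonotonicAligner registered above; the model id and file paths here are assumptions, not part of this commit.

# Usage sketch (assumptions: "fa-zh" maps to MonotonicAligner; the wav/text
# paths below are placeholders for a paired audio file and its transcript).
from funasr import AutoModel

model = AutoModel(model="fa-zh")
res = model.generate(
    input=("speech.wav", "transcript.txt"),  # audio + text pair for forced alignment
    data_type=("sound", "text"),
)
# Each entry mirrors result_i above: a key, the post-processed text, and the
# predicted word-level timestamps.
print(res[0]["text"])
print(res[0]["timestamp"])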