FunASR/funasr/models/fsmn_kws/model.py
zhifu gao 2196844d1d
Dev kws (#2105)
* multi tokenizer

* support fsmn_kws, fsmn_kws_mt, sanm_kws, sanm_kws_streaming training

* kws

---------

Co-authored-by: pengteng.spt <pengteng.spt@alibaba-inc.com>
2024-09-25 15:10:50 +08:00

286 lines
9.7 KiB
Python

#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import time
import torch
import logging
from torch.cuda.amp import autocast
from typing import Union, Dict, List, Tuple, Optional
from funasr.register import tables
from funasr.models.ctc.ctc import CTC
from funasr.utils import postprocess_utils
from funasr.metrics.compute_acc import th_accuracy
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.paraformer.search import Hypothesis
from funasr.models.paraformer.cif_predictor import mae_loss
from funasr.train_utils.device_funcs import force_gatherable
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
@tables.register("model_classes", "FsmnKWS")
class FsmnKWS(torch.nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030

    CTC-based keyword-spotting model: optional SpecAugment and feature
    normalization, an encoder looked up in ``tables.encoder_classes`` and a
    ``CTC`` output layer. ``forward`` returns the weighted CTC loss for
    training; ``inference`` detects keywords with ``KwsCtcPrefixDecoder``.
    """

    def __init__(
        self,
        specaug: Optional[str] = None,
        specaug_conf: Optional[Dict] = None,
        normalize: str = None,
        normalize_conf: Optional[Dict] = None,
        encoder: str = None,
        encoder_conf: Optional[Dict] = None,
        ctc: str = None,
        ctc_conf: Optional[Dict] = None,
        ctc_weight: float = 1.0,
        input_size: int = 360,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        **kwargs,
    ):
        """Build the sub-modules from registry names and their configs.

        Args:
            specaug: registry name in ``tables.specaug_classes`` (optional).
            specaug_conf: keyword arguments for the specaug class; None means
                "use the class defaults".
            normalize: registry name in ``tables.normalize_classes`` (optional).
            normalize_conf: keyword arguments for the normalize class; None
                means "use the class defaults".
            encoder: registry name in ``tables.encoder_classes``.
            encoder_conf: keyword arguments for the encoder class.
            ctc: accepted for config compatibility but unused; the output
                layer is always ``CTC``.
            ctc_conf: extra keyword arguments for ``CTC``.
            ctc_weight: multiplier applied to the CTC loss in ``forward``.
            input_size: accepted for config compatibility but unused here.
            vocab_size: number of CTC output units.
            ignore_id: stored on the instance; not read by this class.
            blank_id: stored on the instance; not read by this class.
        """
        super().__init__()
        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            # Missing conf -> class defaults instead of a TypeError from ``**None``.
            specaug = specaug_class(**(specaug_conf or {}))
        if normalize is not None:
            normalize_class = tables.normalize_classes.get(normalize)
            normalize = normalize_class(**(normalize_conf or {}))
        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(**encoder_conf)
        encoder_output_size = encoder.output_size()
        if ctc_conf is None:
            ctc_conf = {}
        ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf)

        self.blank_id = blank_id
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        # Registration order (specaug, normalize, encoder, ctc) is kept so the
        # state_dict layout is unchanged.
        self.specaug = specaug
        self.normalize = normalize
        self.encoder = encoder
        self.ctc = ctc
        # Only consulted in eval mode by ``_calc_ctc_loss``; nothing here sets it.
        self.error_calculator = None

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + CTC loss.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, ) -- a 2-D tensor is reduced to its first column
            text: (Batch, Length)
            text_lengths: (Batch,) -- a 2-D tensor is reduced to its first column

        Returns:
            ``(loss, stats, weight)`` after ``force_gatherable``; ``weight`` is
            the batch size.
        """
        # Some data loaders emit lengths as (Batch, 1); keep only column 0.
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        loss_ctc, cer_ctc = self._calc_ctc_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )

        stats = dict()
        stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
        stats["cer_ctc"] = cer_ctc
        loss = self.ctc_weight * loss_ctc
        stats["cer"] = cer_ctc
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run specaug (training only), normalization and the encoder.

        Also called by ``inference``.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )

        Returns:
            ``(encoder_out, encoder_out_lens)``. The lengths are the (possibly
            augmented/normalized) input lengths, returned unchanged. This
            assumes the encoder preserves the time length (TODO confirm).
        """
        # Feature processing and the encoder run with CUDA autocast disabled.
        with autocast(False):
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)
            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

            encoder_out = self.encoder(speech)
            encoder_out_lens = speech_lengths
            # Some encoders return a tuple; the first element is the output.
            if isinstance(encoder_out, tuple):
                encoder_out = encoder_out[0]

        return encoder_out, encoder_out_lens

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Return ``(ctc_loss, cer)``; cer is None unless in eval mode with an error calculator."""
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

        cer_ctc = None
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(encoder_out).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc

    @staticmethod
    def _prepare_fbank_input(speech, speech_lengths):
        """Add a batch axis to unbatched fbank input and normalise its lengths.

        Returns ``(speech, speech_lengths)`` where ``speech_lengths`` is always
        a 1-D tensor, so the caller can move it to a device and index it per
        utterance.
        """
        if len(speech.shape) < 3:
            speech = speech[None, :, :]
        if speech_lengths is not None:
            speech_lengths = torch.as_tensor(speech_lengths).squeeze(-1)
            # A (1,) lengths tensor squeezes to 0-dim; keep it indexable.
            if speech_lengths.dim() == 0:
                speech_lengths = speech_lengths.reshape(1)
        else:
            # No lengths given: every item spans the whole time axis. A tensor
            # (not a plain int) is required by the later ``.to(device=...)``.
            speech_lengths = torch.full(
                (speech.shape[0],), speech.shape[1], dtype=torch.long
            )
        return speech, speech_lengths

    @staticmethod
    def _format_detection(detect_result):
        """Render a decoder result ``(is_detected, keyword, score, ...)`` as text."""
        is_detected, keyword, score = detect_result[0], detect_result[1], detect_result[2]
        if is_detected:
            return "detected " + keyword + " " + str(score)
        return "rejected"

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Run keyword detection on a batch of inputs.

        Args:
            data_in: a feature tensor when ``kwargs["data_type"] == "fbank"``;
                otherwise anything accepted by ``load_audio_text_image_video``.
            data_lengths: lengths for fbank input; if None each item is assumed
                to span the full time axis.
            key: per-utterance identifiers, indexed once per batch item.
            tokenizer: must expose ``token_list`` and ``seg_dict``.
            frontend: feature extractor; ``fs``, ``frame_shift`` and ``lfr_n``
                are read on the non-fbank path.
            **kwargs: ``keywords`` (passed to the decoder), ``device``
                (required), ``data_type``, ``fs``, and optional ``output_dir``.
                When ``output_dir`` is set, each result is also written to
                ``self.writer["detect"]``.

        Returns:
            ``(results, meta_data)``. ``results`` is a list of
            ``{"key": ..., "text": ...}`` dicts, where ``text`` is
            ``"detected <keyword> <score>"`` or ``"rejected"``.
        """
        from funasr.utils.kws_utils import KwsCtcPrefixDecoder

        self.kws_decoder = KwsCtcPrefixDecoder(
            ctc=self.ctc,
            keywords=kwargs.get("keywords"),
            token_list=tokenizer.token_list,
            seg_dict=tokenizer.seg_dict,
        )

        meta_data = {}
        data_type = kwargs.get("data_type", "sound")
        if isinstance(data_in, torch.Tensor) and data_type == "fbank":
            speech, speech_lengths = self._prepare_fbank_input(data_in, data_lengths)
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(
                data_in,
                fs=frontend.fs,
                audio_fs=kwargs.get("fs", 16000),
                data_type=data_type,
                tokenizer=tokenizer,
            )
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(
                audio_sample_list, data_type=data_type, frontend=frontend
            )
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data["batch_data_time"] = (
                speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
            )

        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        # ``encode`` already unwraps tuple encoder outputs.
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        # Only write to disk when an output directory was requested. The
        # original wrote unconditionally and crashed without ``output_dir``.
        # NOTE: the writer is cached on the instance, so a later call with a
        # different ``output_dir`` still writes to the first one.
        writer = None
        if kwargs.get("output_dir") is not None:
            if not hasattr(self, "writer"):
                self.writer = DatadirWriter(kwargs.get("output_dir"))
            writer = self.writer

        results = []
        for i in range(encoder_out.size(0)):
            x = encoder_out[i, : encoder_out_lens[i], :]
            det_info = self._format_detection(self.kws_decoder.decode(x))
            if writer is not None:
                writer["detect"][key[i]] = det_info
            results.append({"key": key[i], "text": det_info})

        return results, meta_data
@tables.register("model_classes", "FsmnKWSConvert")
class FsmnKWSConvert(torch.nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030

    Conversion wrapper: builds the same encoder + ``CTC`` layout as the KWS
    model. Its conversion methods simply delegate to the encoder's own
    ``to_kaldi_net`` / ``to_pytorch_net``.
    """

    def __init__(
        self,
        encoder: str = None,
        encoder_conf: Optional[Dict] = None,
        ctc: str = None,
        ctc_conf: Optional[Dict] = None,
        ctc_weight: float = 1.0,
        input_size: int = 360,
        vocab_size: int = -1,
        blank_id: int = 0,
        **kwargs,
    ):
        """Instantiate the registered encoder and a matching CTC head.

        ``ctc`` and ``input_size`` are accepted for config compatibility but
        not used.
        """
        super().__init__()
        net = tables.encoder_classes.get(encoder)(**encoder_conf)
        ctc_kwargs = {} if ctc_conf is None else ctc_conf
        head = CTC(odim=vocab_size, encoder_output_size=net.output_size(), **ctc_kwargs)

        self.blank_id, self.vocab_size, self.ctc_weight = blank_id, vocab_size, ctc_weight
        # Encoder is registered before the CTC head, matching the state_dict layout.
        self.encoder = net
        self.ctc = head
        self.error_calculator = None

    def to_kaldi_net(self):
        """Return whatever the encoder's ``to_kaldi_net()`` produces."""
        exported = self.encoder.to_kaldi_net()
        return exported

    def to_pytorch_net(self, kaldi_file):
        """Forward *kaldi_file* to the encoder's ``to_pytorch_net`` and return its result."""
        restored = self.encoder.to_pytorch_net(kaldi_file)
        return restored