FunASR/funasr/models/paraformer/model.py
2024-06-06 10:08:17 +08:00

605 lines
23 KiB
Python

#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import time
import copy
import torch
import logging
from torch.cuda.amp import autocast
from typing import Union, Dict, List, Tuple, Optional
from funasr.register import tables
from funasr.models.ctc.ctc import CTC
from funasr.utils import postprocess_utils
from funasr.metrics.compute_acc import th_accuracy
from funasr.train_utils.device_funcs import to_device
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.paraformer.search import Hypothesis
from funasr.models.paraformer.cif_predictor import mae_loss
from funasr.train_utils.device_funcs import force_gatherable
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
@tables.register("model_classes", "Paraformer")
class Paraformer(torch.nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2206.08317
"""
def __init__(
    self,
    specaug: Optional[str] = None,
    specaug_conf: Optional[Dict] = None,
    normalize: Optional[str] = None,
    normalize_conf: Optional[Dict] = None,
    encoder: Optional[str] = None,
    encoder_conf: Optional[Dict] = None,
    decoder: Optional[str] = None,
    decoder_conf: Optional[Dict] = None,
    ctc: Optional[str] = None,
    ctc_conf: Optional[Dict] = None,
    predictor: Optional[str] = None,
    predictor_conf: Optional[Dict] = None,
    ctc_weight: float = 0.5,
    input_size: int = 80,
    vocab_size: int = -1,
    ignore_id: int = -1,
    blank_id: int = 0,
    sos: int = 1,
    eos: int = 2,
    lsm_weight: float = 0.0,
    length_normalized_loss: bool = False,
    predictor_weight: float = 0.0,
    predictor_bias: int = 0,
    sampling_ratio: float = 0.2,
    share_embedding: bool = False,
    use_1st_decoder_loss: bool = False,
    **kwargs,
):
    """Build the sub-modules of the model from registry names and configs.

    Args:
        specaug, normalize, encoder, decoder, predictor: names looked up in
            the corresponding ``tables.*_classes`` registry. ``specaug``,
            ``normalize``, ``decoder`` and ``predictor`` are skipped when
            ``None``; the encoder is always built.
        specaug_conf, normalize_conf, encoder_conf, decoder_conf, ctc_conf,
            predictor_conf: keyword arguments forwarded to the matching
            class. ``None`` is treated as an empty dict.
        ctc: not used for lookup; a ``CTC`` module is created whenever
            ``ctc_weight > 0.0``.
        ctc_weight: weight of the CTC loss. ``0.0`` disables the CTC branch
            (``self.ctc`` is ``None``); ``1.0`` drops the decoder
            (``self.decoder`` is ``None``).
        input_size: forwarded to the encoder as ``input_size``.
        vocab_size: output vocabulary size used by decoder, CTC and loss.
        ignore_id: padding label id, used as ``padding_idx`` of the loss.
        blank_id: id stored as ``self.blank_id``.
        sos, eos: start / end token ids; ``None`` falls back to
            ``vocab_size - 1``.
        lsm_weight: label-smoothing factor of the attention loss.
        length_normalized_loss: passed as ``normalize_length`` to both the
            attention loss and the predictor loss.
        predictor_weight: weight applied to the predictor loss in ``forward``.
        predictor_bias: stored as ``self.predictor_bias``.
        sampling_ratio: stored as ``self.sampling_ratio``.
        share_embedding: when true, ``decoder.embed`` is set to ``None``.
        use_1st_decoder_loss: stored as ``self.use_1st_decoder_loss``.
        **kwargs: accepted and ignored.
    """
    super().__init__()

    # A missing ``*_conf`` means "no extra keyword arguments"; unpacking
    # ``None`` with ``**`` would raise a TypeError.
    specaug_conf = {} if specaug_conf is None else specaug_conf
    normalize_conf = {} if normalize_conf is None else normalize_conf
    encoder_conf = {} if encoder_conf is None else encoder_conf
    decoder_conf = {} if decoder_conf is None else decoder_conf
    ctc_conf = {} if ctc_conf is None else ctc_conf
    predictor_conf = {} if predictor_conf is None else predictor_conf

    if specaug is not None:
        specaug_class = tables.specaug_classes.get(specaug)
        specaug = specaug_class(**specaug_conf)
    if normalize is not None:
        normalize_class = tables.normalize_classes.get(normalize)
        normalize = normalize_class(**normalize_conf)

    encoder_class = tables.encoder_classes.get(encoder)
    encoder = encoder_class(input_size=input_size, **encoder_conf)
    encoder_output_size = encoder.output_size()

    if decoder is not None:
        decoder_class = tables.decoder_classes.get(decoder)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            **decoder_conf,
        )
    if ctc_weight > 0.0:
        ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf)
    if predictor is not None:
        predictor_class = tables.predictor_classes.get(predictor)
        predictor = predictor_class(**predictor_conf)

    self.blank_id = blank_id
    # sos/eos only collapse to the same id (vocab_size - 1) when passed as None.
    self.sos = sos if sos is not None else vocab_size - 1
    self.eos = eos if eos is not None else vocab_size - 1
    self.vocab_size = vocab_size
    self.ignore_id = ignore_id
    self.ctc_weight = ctc_weight

    # NOTE: the assignment order below fixes the sub-module registration
    # order of this torch.nn.Module; keep it stable.
    self.specaug = specaug
    self.normalize = normalize
    self.encoder = encoder

    # A pure-CTC model keeps no attention decoder.
    self.decoder = None if ctc_weight == 1.0 else decoder

    self.criterion_att = LabelSmoothingLoss(
        size=vocab_size,
        padding_idx=ignore_id,
        smoothing=lsm_weight,
        normalize_length=length_normalized_loss,
    )

    # A model without CTC weight keeps no CTC head.
    self.ctc = None if ctc_weight == 0.0 else ctc

    self.predictor = predictor
    self.predictor_weight = predictor_weight
    self.predictor_bias = predictor_bias
    self.sampling_ratio = sampling_ratio
    self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)

    self.share_embedding = share_embedding
    # Guard against a missing decoder (ctc_weight == 1.0 or decoder=None).
    if self.share_embedding and self.decoder is not None:
        self.decoder.embed = None

    self.use_1st_decoder_loss = use_1st_decoder_loss
    self.length_normalized_loss = length_normalized_loss
    self.beam_search = None
    self.error_calculator = None
def forward(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    **kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Run encoder, CTC branch and attention branch, then combine the losses.

    Args:
        speech: (Batch, Length, ...)
        speech_lengths: (Batch, ) -- a 2-D tensor is reduced to column 0.
        text: (Batch, Length)
        text_lengths: (Batch,) -- a 2-D tensor is reduced to column 0.

    Returns:
        ``(loss, stats, weight)`` as produced by ``force_gatherable``.
        ``weight`` is the batch size, or the summed target lengths (plus
        ``predictor_bias`` per utterance) when ``length_normalized_loss``
        is set.
    """
    if text_lengths.dim() > 1:
        text_lengths = text_lengths[:, 0]
    if speech_lengths.dim() > 1:
        speech_lengths = speech_lengths[:, 0]

    batch_size = speech.shape[0]

    # Encoder
    encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

    stats = dict()
    loss_ctc = None
    use_ctc = self.ctc_weight != 0.0

    # CTC branch (only when it carries weight)
    if use_ctc:
        loss_ctc, cer_ctc = self._calc_ctc_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )
        stats["loss_ctc"] = None if loss_ctc is None else loss_ctc.detach()
        stats["cer_ctc"] = cer_ctc

    # Attention decoder branch (also yields the predictor loss)
    (
        loss_att,
        acc_att,
        cer_att,
        wer_att,
        loss_pre,
        pre_loss_att,
    ) = self._calc_att_loss(encoder_out, encoder_out_lens, text, text_lengths)

    # Weighted sum of the individual losses
    predictor_term = loss_pre * self.predictor_weight
    if use_ctc:
        loss = (
            self.ctc_weight * loss_ctc
            + (1 - self.ctc_weight) * loss_att
            + predictor_term
        )
    else:
        loss = loss_att + predictor_term

    # Attention-branch statistics (detached so logging holds no graph)
    stats.update(
        loss_att=None if loss_att is None else loss_att.detach(),
        pre_loss_att=None if pre_loss_att is None else pre_loss_att.detach(),
        acc=acc_att,
        cer=cer_att,
        wer=wer_att,
        loss_pre=None if loss_pre is None else loss_pre.detach().cpu(),
        loss=loss.detach().clone(),
        batch_size=batch_size,
    )

    # force_gatherable: to-device and to-tensor if scalar for DataParallel
    if self.length_normalized_loss:
        weight = (text_lengths + self.predictor_bias).sum()
    else:
        weight = batch_size
    loss, stats, weight = force_gatherable((loss, stats, weight), loss.device)
    return loss, stats, weight
def encode(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply the optional front-end steps and run the encoder.

    Args:
        speech: (Batch, Length, ...)
        speech_lengths: (Batch, )

    Returns:
        ``(encoder_out, encoder_out_lens)``. If the encoder's first output
        is a tuple, only its first element is returned as ``encoder_out``.
    """
    apply_specaug = self.specaug is not None and self.training

    # Augmentation and normalization run with autocast disabled.
    with autocast(False):
        if apply_specaug:
            speech, speech_lengths = self.specaug(speech, speech_lengths)
        # Feature normalization, e.g. Global-CMVN, Utterance-CMVN
        if self.normalize is not None:
            speech, speech_lengths = self.normalize(speech, speech_lengths)

    # The encoder returns three values; the third is not used here.
    hidden, hidden_lens, _unused = self.encoder(speech, speech_lengths)
    if isinstance(hidden, tuple):
        hidden = hidden[0]
    return hidden, hidden_lens
def calc_predictor(self, encoder_out, encoder_out_lens):
    """Run the predictor on encoder outputs without target labels.

    Builds a mask from ``encoder_out_lens`` (inverted ``make_pad_mask`` with
    an extra middle axis) and calls ``self.predictor`` with ``None`` in the
    target slot.

    Returns:
        The four predictor outputs, in order: acoustic embeddings, predicted
        token length, alphas and peak index.
    """
    pad_mask = make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))
    valid_mask = (~pad_mask[:, None, :]).to(encoder_out.device)
    acoustic_embeds, token_length, alphas, peak_index = self.predictor(
        encoder_out, None, valid_mask, ignore_id=self.ignore_id
    )
    return acoustic_embeds, token_length, alphas, peak_index
def cal_decoder_with_predictor(
    self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
):
    """Decode the given embeddings and return log-probabilities.

    Returns:
        ``(log_probs, ys_pad_lens)`` where ``log_probs`` is the log-softmax
        over the last axis of the decoder's first output and ``ys_pad_lens``
        is passed through unchanged.
    """
    logits = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens)[0]
    log_probs = torch.log_softmax(logits, dim=-1)
    return log_probs, ys_pad_lens
def _calc_att_loss(
    self,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ys_pad: torch.Tensor,
    ys_pad_lens: torch.Tensor,
):
    """Compute the attention-decoder loss and the predictor length loss.

    Returns:
        ``(loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att)``.
        ``pre_loss_att`` is always ``None`` here; ``cer_att`` / ``wer_att``
        are ``None`` while training or when no error calculator is set.
    """
    pad_mask = make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))
    encoder_out_mask = (~pad_mask[:, None, :]).to(encoder_out.device)

    # With predictor_bias == 1 the targets are rebuilt by add_sos_eos
    # (second output) and the target lengths grow by the bias.
    if self.predictor_bias == 1:
        _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_pad_lens = ys_pad_lens + self.predictor_bias

    pre_acoustic_embeds, pre_token_length, _, _peak_index = self.predictor(
        encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id
    )

    # 0. Sampler: optionally replace part of the acoustic embeddings
    pre_loss_att = None
    first_pass_out = None
    sematic_embeds = pre_acoustic_embeds
    if self.sampling_ratio > 0.0:
        sematic_embeds, first_pass_out = self.sampler(
            encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds
        )

    # 1. Forward decoder
    decoder_out = self.decoder(
        encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
    )[0]
    if first_pass_out is None:
        first_pass_out = decoder_out

    # 2. Attention loss on the final pass, accuracy on the first pass
    loss_att = self.criterion_att(decoder_out, ys_pad)
    acc_att = th_accuracy(
        first_pass_out.view(-1, self.vocab_size),
        ys_pad,
        ignore_label=self.ignore_id,
    )
    loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)

    # CER/WER are only evaluated outside training with an error calculator.
    cer_att = wer_att = None
    if not self.training and self.error_calculator is not None:
        ys_hat = first_pass_out.argmax(dim=-1)
        cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
    """Mix acoustic embeddings with target-token embeddings.

    A gradient-free decoder pass predicts tokens from the acoustic
    embeddings. For each utterance, a number of randomly chosen positions
    proportional to the count of mispredicted tokens (scaled by
    ``sampling_ratio``) has its acoustic embedding replaced by the embedding
    of the target token.

    Returns:
        ``(sematic_embeds, decoder_out)``, both multiplied by the target
        mask; ``decoder_out`` is the output of the gradient-free pass.
    """
    pad_mask = make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())
    tgt_mask = (~pad_mask[:, :, None]).to(ys_pad.device)

    # Multiplying by the mask zeroes padded label positions before lookup.
    ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
    if self.share_embedding:
        ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
    else:
        ys_pad_embed = self.decoder.embed(ys_pad_masked)

    with torch.no_grad():
        first_pass = self.decoder(
            encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
        )
        decoder_out = first_pass[0]
        pred_tokens = decoder_out.argmax(-1)

        is_token = ys_pad.ne(self.ignore_id)
        token_counts = is_token.sum(1)
        hit_counts = ((pred_tokens == ys_pad) & is_token).sum(1)

        # True = keep the acoustic embedding, False = use the target embedding.
        keep_acoustic = torch.ones_like(is_token)
        num_utts, _ = ys_pad.size()
        for row in range(num_utts):
            miss_count = (token_counts[row] - hit_counts[row]).float()
            target_num = (miss_count * self.sampling_ratio).long()
            if target_num > 0:
                chosen = torch.randperm(token_counts[row])[:target_num]
                keep_acoustic[row].scatter_(
                    dim=0,
                    index=chosen.to(keep_acoustic.device),
                    value=0,
                )
        # Padded positions never keep the acoustic embedding.
        keep_acoustic = keep_acoustic.eq(1) & is_token
        keep_expanded = keep_acoustic.unsqueeze(2).to(pre_acoustic_embeds.device)

    # Built outside no_grad so gradients flow through both embedding sources.
    acoustic_part = pre_acoustic_embeds.masked_fill(~keep_expanded, 0)
    target_part = ys_pad_embed.masked_fill(keep_expanded, 0)
    sematic_embeds = acoustic_part + target_part
    return sematic_embeds * tgt_mask, decoder_out * tgt_mask
def _calc_ctc_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
# Calc CTC loss
loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
# Calc CER using CTC
cer_ctc = None
if not self.training and self.error_calculator is not None:
ys_hat = self.ctc.argmax(encoder_out).data
cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
return loss_ctc, cer_ctc
def init_beam_search(
    self,
    **kwargs,
):
    """Create a ``BeamSearchPara`` instance and store it in ``self.beam_search``.

    Keyword Args:
        token_list: vocabulary list; its length is used as the vocabulary
            size and for the length-bonus scorer. Required.
        decoding_ctc_weight: CTC scorer weight (default 0.0); the decoder
            weight is ``1.0 - decoding_ctc_weight``.
        lm_weight, ngram_weight, penalty: further scorer weights
            (default 0.0 each).
        beam_size: beam width (default 2).
    """
    from funasr.models.paraformer.search import BeamSearchPara
    from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
    from funasr.models.transformer.scorers.length_bonus import LengthBonus

    # 1. Scorers
    scorers = {}
    if self.ctc is not None:
        scorers.update(ctc=CTCPrefixScorer(ctc=self.ctc, eos=self.eos))
    token_list = kwargs.get("token_list")
    scorers.update(
        length_bonus=LengthBonus(len(token_list)),
    )

    # 2. ngram model: not supported now
    scorers["ngram"] = None

    # 3. Weights. Read the CTC weight once with a default so LM-only
    # decoding (no "decoding_ctc_weight" key) does not fail on 1.0 - None.
    decoding_ctc_weight = kwargs.get("decoding_ctc_weight", 0.0)
    weights = dict(
        decoder=1.0 - decoding_ctc_weight,
        ctc=decoding_ctc_weight,
        lm=kwargs.get("lm_weight", 0.0),
        ngram=kwargs.get("ngram_weight", 0.0),
        length_bonus=kwargs.get("penalty", 0.0),
    )

    self.beam_search = BeamSearchPara(
        beam_size=kwargs.get("beam_size", 2),
        weights=weights,
        scorers=scorers,
        sos=self.sos,
        eos=self.eos,
        vocab_size=len(token_list),
        token_list=token_list,
        pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
    )
def inference(
    self,
    data_in,
    data_lengths=None,
    key: list = None,
    tokenizer=None,
    frontend=None,
    **kwargs,
):
    """Decode a batch of inputs into recognition results.

    Args:
        data_in: a feature tensor (used directly when
            ``kwargs["data_type"] == "fbank"``) or any input accepted by
            ``load_audio_text_image_video``.
        data_lengths: feature lengths for the fbank path; when ``None`` every
            item is assumed to span the full time axis.
        key: utterance ids; a nested first element is unwrapped, and a list
            shorter than the batch is repeated.
        tokenizer: optional; when given, ids are converted to tokens / text,
            otherwise results carry ``token_int`` only.
        frontend: feature extractor, required on the non-fbank path.
        **kwargs: must contain ``device``. Optional keys read here:
            ``decoding_ctc_weight``, ``lm_weight``, ``lm_file``, ``nbest``,
            ``pred_timestamp``, ``data_type``, ``fs``, ``fp16``,
            ``maxlenratio``, ``minlenratio``, ``output_dir``, ``begin_time``.

    Returns:
        ``(results, meta_data)`` where ``results`` is a list of dicts.
        NOTE: when no utterance has a predicted token count >= 1, a bare
        empty list is returned instead; this is kept as-is because callers
        may rely on it.
    """
    # Beam search is only set up when a CTC or LM scorer is requested.
    is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc is not None
    is_use_lm = (
        kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
    )
    pred_timestamp = kwargs.get("pred_timestamp", False)
    if self.beam_search is None and (is_use_lm or is_use_ctc):
        logging.info("enable beam_search")
        self.init_beam_search(**kwargs)
        self.nbest = kwargs.get("nbest", 1)

    meta_data = {}
    data_type = kwargs.get("data_type", "sound")
    if isinstance(data_in, torch.Tensor) and data_type == "fbank":
        # Pre-computed features
        speech, speech_lengths = data_in, data_lengths
        if len(speech.shape) < 3:
            speech = speech[None, :, :]
        if speech_lengths is not None:
            speech_lengths = speech_lengths.squeeze(-1)
        else:
            # Must be a tensor (one length per item): it is moved to the
            # device and consumed by the encoder below.
            speech_lengths = torch.full(
                (speech.shape[0],), speech.shape[1], dtype=torch.long
            )
    else:
        # Load the raw input and extract fbank features
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(
            data_in,
            fs=frontend.fs,
            audio_fs=kwargs.get("fs", 16000),
            data_type=data_type,
            tokenizer=tokenizer,
        )
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        speech, speech_lengths = extract_fbank(
            audio_sample_list, data_type=data_type, frontend=frontend
        )
        time3 = time.perf_counter()
        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
        meta_data["batch_data_time"] = (
            speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
        )

    speech = speech.to(device=kwargs["device"])
    speech_lengths = speech_lengths.to(device=kwargs["device"])

    # Encoder
    if kwargs.get("fp16", False):
        speech = speech.half()
    encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
    if isinstance(encoder_out, tuple):
        encoder_out = encoder_out[0]

    # Predictor
    pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.calc_predictor(
        encoder_out, encoder_out_lens
    )[:4]
    pre_token_length = pre_token_length.round().long()
    if torch.max(pre_token_length) < 1:
        return []

    # Decoder (log-probabilities)
    decoder_out = self.cal_decoder_with_predictor(
        encoder_out, encoder_out_lens, pre_acoustic_embeds, pre_token_length
    )[0]

    results = []
    batch_size, _, _ = decoder_out.size()
    if isinstance(key[0], (list, tuple)):
        key = key[0]
    if len(key) < batch_size:
        key = key * batch_size

    # Ids stripped from every hypothesis before detokenization.
    special_ids = (self.eos, self.sos, self.blank_id)

    for i in range(batch_size):
        x = encoder_out[i, : encoder_out_lens[i], :]
        am_scores = decoder_out[i, : pre_token_length[i], :]
        if self.beam_search is not None:
            nbest_hyps = self.beam_search(
                x=x,
                am_scores=am_scores,
                maxlenratio=kwargs.get("maxlenratio", 0.0),
                minlenratio=kwargs.get("minlenratio", 0.0),
            )
            nbest_hyps = nbest_hyps[: self.nbest]
        else:
            # Greedy decoding: best token per position, score = summed max log-prob
            yseq = am_scores.argmax(dim=-1)
            score = torch.sum(am_scores.max(dim=-1)[0], dim=-1)
            # Wrap with sos/eos so both decoding paths share the stripping below
            yseq = torch.tensor([self.sos] + yseq.tolist() + [self.eos], device=yseq.device)
            nbest_hyps = [Hypothesis(yseq=yseq, score=score)]

        for nbest_idx, hyp in enumerate(nbest_hyps):
            ibest_writer = None
            if kwargs.get("output_dir") is not None:
                if not hasattr(self, "writer"):
                    self.writer = DatadirWriter(kwargs.get("output_dir"))
                ibest_writer = self.writer[f"{nbest_idx+1}best_recog"]

            # Drop the leading sos and trailing eos
            if isinstance(hyp.yseq, list):
                token_int = hyp.yseq[1:-1]
            else:
                token_int = hyp.yseq[1:-1].tolist()

            # Remove any remaining sos / eos / blank ids
            token_int = [tok for tok in token_int if tok not in special_ids]

            if tokenizer is None:
                result_i = {"key": key[i], "token_int": token_int}
                results.append(result_i)
                continue

            # Change integer-ids to tokens
            token = tokenizer.ids2tokens(token_int)
            text_postprocessed = tokenizer.tokens2text(token)
            # Sentence post-processing is skipped for tokenizers that carry
            # a "bpemodel" attribute.
            needs_postprocess = not hasattr(tokenizer, "bpemodel")
            if pred_timestamp:
                timestamp_str, timestamp = ts_prediction_lfr6_standard(
                    pre_peak_index[i],
                    alphas[i],
                    copy.copy(token),
                    vad_offset=kwargs.get("begin_time", 0),
                    upsample_rate=1,
                )
                # Fall back to the raw timestamps when post-processing is
                # skipped, so the name is always bound.
                # NOTE(review): confirm raw timestamps suit BPE consumers.
                time_stamp_postprocessed = timestamp
                if needs_postprocess:
                    text_postprocessed, time_stamp_postprocessed, _ = (
                        postprocess_utils.sentence_postprocess(token, timestamp)
                    )
                result_i = {
                    "key": key[i],
                    "text": text_postprocessed,
                    "timestamp": time_stamp_postprocessed,
                }
            else:
                if needs_postprocess:
                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                result_i = {"key": key[i], "text": text_postprocessed}

            if ibest_writer is not None:
                ibest_writer["token"][key[i]] = " ".join(token)
                ibest_writer["text"][key[i]] = text_postprocessed

            results.append(result_i)

    return results, meta_data
def export(self, **kwargs):
    """Return the model(s) rebuilt for export by ``export_rebuild_model``.

    ``max_seq_len`` defaults to 512 when the caller does not provide it.
    """
    from .export_meta import export_rebuild_model

    kwargs.setdefault("max_seq_len", 512)
    return export_rebuild_model(model=self, **kwargs)