FunASR/funasr/models/sense_voice/model.py
2024-06-17 13:36:22 +08:00

1663 lines
60 KiB
Python

import logging
from dataclasses import dataclass
from typing import Dict
from typing import Iterable, Optional
import types
import time
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch import nn
from torch.cuda.amp import autocast
from funasr.metrics.compute_acc import compute_accuracy
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.train_utils.device_funcs import force_gatherable
from . import whisper_lib as whisper
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.ctc.ctc import CTC
from funasr.register import tables
@tables.register("model_classes", "SenseVoice")
class SenseVoice(nn.Module):
    """SenseVoice model built on a Whisper encoder/decoder.

    A stock ``whisper.model.Whisper`` is constructed from ``dims`` and its encoder and
    decoder ``forward`` methods are rebound to the SenseVoice implementations
    (``sense_voice_encode_forward`` / ``sense_voice_decode_forward``). Training uses a
    label-smoothed attention loss restricted by ``target_mask``; inference delegates to
    ``whisper.decode``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

        dims = kwargs.get("dims", {})
        dims = whisper.model.ModelDimensions(**dims)
        model = whisper.model.Whisper(dims=dims)

        # encoder: replace the stock forward with the SenseVoice one
        model.encoder.downsample_rate = kwargs.get("downsample_rate", 4)
        model.encoder.use_padmask = kwargs.get("use_padmask", True)
        from .encoder import sense_voice_encode_forward

        model.encoder.forward = types.MethodType(sense_voice_encode_forward, model.encoder)

        # decoder: replace the stock forward with the SenseVoice one
        model.decoder.use_padmask = kwargs.get("use_padmask", True)
        from .decoder import sense_voice_decode_forward

        model.decoder.forward = types.MethodType(sense_voice_decode_forward, model.decoder)

        self.model = model

        self.encoder_output_size = self.model.dims.n_audio_state

        self.activation_checkpoint = kwargs.get("activation_checkpoint", False)
        self.ignore_id = kwargs.get("ignore_id", -1)
        # Default to the model vocabulary (as SenseVoiceFSMN does) instead of -1, which
        # cannot size the label-smoothing criterion. Presumably the Whisper decoder built
        # from ``dims`` emits dims.n_vocab logits.
        self.vocab_size = kwargs.get("vocab_size", dims.n_vocab)
        self.length_normalized_loss = kwargs.get("length_normalized_loss", True)
        self.criterion_att = LabelSmoothingLoss(
            size=self.vocab_size,
            padding_idx=self.ignore_id,
            smoothing=kwargs.get("lsm_weight", 0.0),
            normalize_length=self.length_normalized_loss,
        )

        specaug = kwargs.get("specaug", None)
        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**(kwargs.get("specaug_conf") or {}))
        self.specaug = specaug

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ):
        """Compute the attention loss for one training batch.

        Args:
            speech: (Batch, Length, Dim) input features.
            speech_lengths: (Batch,) or (Batch, 1) valid feature lengths.
            text: padded token ids fed to the decoder.
            text_lengths: (Batch,) or (Batch, 1) token counts.
            **kwargs: ``target_mask`` marks the token positions that contribute to the loss.

        Returns:
            ``(loss, stats, weight)`` after ``force_gatherable``.
        """
        target_mask = kwargs.get("target_mask", None)

        # Collapse (Batch, 1) length tensors to (Batch,).
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        if self.activation_checkpoint:
            from torch.utils.checkpoint import checkpoint

            encoder_out, encoder_out_lens = checkpoint(
                self.encode, speech, speech_lengths, use_reentrant=False
            )
        else:
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        loss_att, acc_att, _, _ = self._calc_att_loss(
            encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask
        )
        loss = loss_att

        stats = {}
        stats["acc"] = acc_att
        stats["loss"] = torch.clone(loss.detach())
        stats["batch_size"] = batch_size

        # force_gatherable: to-device and to-tensor if scalar for DataParallel.
        # With length normalization the weight is the summed (text_lengths + 1).
        if self.length_normalized_loss:
            batch_size = int((text_lengths + 1).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ):
        """Apply SpecAugment (training only) and run the encoder.

        Args:
            speech: (Batch, Length, Dim) features.
            speech_lengths: (Batch,)

        Returns:
            ``(encoder_out, encoder_out_lens)`` from the SenseVoice encoder forward.
        """
        with autocast(False):
            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

        # Forward encoder; it is fed (Batch, Dim, Length).
        encoder_out, encoder_out_lens = self.model.encoder(speech.permute(0, 2, 1), speech_lengths)

        return encoder_out, encoder_out_lens

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        **kwargs,
    ):
        """Compute the label-smoothed attention loss and token accuracy.

        Args:
            encoder_out: encoder output used as decoder cross-attention memory.
            encoder_out_lens: valid lengths of ``encoder_out``.
            ys_pad: padded token ids (decoder input and, shifted by one, the targets).
            ys_pad_lens: token counts per sequence.
            **kwargs: ``target_mask`` (same shape as ``ys_pad``); 1 marks loss positions.

        Returns:
            ``(loss_att, acc_att, None, None)``.
        """
        target_mask = kwargs.get("target_mask", None)

        # 1. Forward decoder
        decoder_out = self.model.decoder(
            x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
        )

        # 2. Compute attention loss.
        # Non-target positions, and any position whose id is 0, become ignore_id so that
        # both the criterion (padding_idx=ignore_id) and the accuracy metric skip them.
        mask = torch.ones_like(ys_pad) * self.ignore_id
        ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(torch.int64)
        ys_pad_mask[ys_pad_mask == 0] = self.ignore_id
        # Decoder output at step t is scored against the token at step t + 1.
        loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:])

        with torch.no_grad():
            preds = torch.argmax(decoder_out, -1)
            acc_att = compute_accuracy(
                preds[:, :-1], ys_pad_mask[:, 1:], ignore_label=self.ignore_id
            )

        return loss_att, acc_att, None, None

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Decode a single utterance with ``whisper.decode``.

        Args:
            data_in: raw input for ``load_audio_text_image_video``, or a feature tensor
                when ``kwargs["data_type"] == "fbank"``.
            data_lengths: feature lengths for the fbank path (optional).
            key: utterance ids; ``key[0]`` labels the result.
            tokenizer: forwarded to the data loader.
            frontend: feature extractor; a ``WhisperFrontend`` is created and cached on
                ``self`` when none is supplied or cached.
            **kwargs: must contain ``device`` and ``tokenizer_conf``; may contain
                ``DecodingOptions``, ``initial_prompt``, ``fs``, ``data_type``, ``do_pad_trim``.

        Returns:
            ``([{"key": key[0], "text": text}], meta_data)``.

        Raises:
            NotImplementedError: if ``batch_size > 1``.
        """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        if frontend is None and not hasattr(self, "frontend"):
            frontend_class = tables.frontend_classes.get("WhisperFrontend")
            frontend = frontend_class(
                n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)
            )
            self.frontend = frontend
        else:
            frontend = frontend if frontend is not None else self.frontend

        meta_data = {}
        if (
            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
        ):  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                # Must be a tensor: it is moved to the device below.
                speech_lengths = torch.tensor([speech.shape[1]], dtype=torch.int64)
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(
                data_in,
                fs=frontend.fs if hasattr(frontend, "fs") else 16000,
                audio_fs=kwargs.get("fs", 16000),
                data_type=kwargs.get("data_type", "sound"),
                tokenizer=tokenizer,
            )
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(
                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
            )
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
            lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000

        # Keep only the first (and only) item of the batch.
        speech = speech.to(device=kwargs["device"])[0, :, :]
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        # Work on a copy so the caller's (possibly reused) DecodingOptions is not mutated.
        DecodingOptions = dict(kwargs.get("DecodingOptions") or {})
        task = DecodingOptions.get("task", "ASR")
        if isinstance(task, str):
            task = [task]
        task = "".join([f"<|{x}|>" for x in task])
        initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
        DecodingOptions["initial_prompt"] = initial_prompt

        language = DecodingOptions.get("language", None)
        language = None if language == "auto" else language
        DecodingOptions["language"] = language

        DecodingOptions["vocab_path"] = kwargs["tokenizer_conf"].get("vocab_path", None)

        if "without_timestamps" not in DecodingOptions:
            DecodingOptions["without_timestamps"] = True

        options = whisper.DecodingOptions(**DecodingOptions)

        result = whisper.decode(self.model, speech, options)
        text = f"{result.text}"

        results = [{"key": key[0], "text": text}]
        return results, meta_data
@tables.register("model_classes", "SenseVoiceRWKV")
class SenseVoiceRWKV(nn.Module):
    """Whisper encoder (SenseVoice forward) paired with a registry-built decoder.

    The stock Whisper decoder is deleted and replaced by the class named by the
    ``decoder`` kwarg (default ``"SenseVoiceDecoder"``). Inference runs ``.search.BeamSearch``
    configured with SenseVoice rich-decoding (emotion / event) parameters.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

        dims = kwargs.get("dims", {})
        dims = whisper.model.ModelDimensions(**dims)
        model = whisper.model.Whisper(dims=dims)

        # encoder: replace the stock forward with the SenseVoice one
        model.encoder.downsample_rate = kwargs.get("downsample_rate", 4)
        model.encoder.use_padmask = kwargs.get("use_padmask", True)
        from .encoder import sense_voice_encode_forward

        model.encoder.forward = types.MethodType(sense_voice_encode_forward, model.encoder)

        # decoder: swap the stock Whisper decoder for a registered implementation
        del model.decoder
        decoder = kwargs.get("decoder", "SenseVoiceDecoder")
        decoder_class = tables.decoder_classes.get(decoder)
        decoder = decoder_class(
            n_vocab=dims.n_vocab,
            n_ctx=dims.n_text_ctx,
            n_state=dims.n_text_state,
            n_head=dims.n_text_head,
            n_layer=dims.n_text_layer,
            # decoder_conf may be absent; ``**None`` would raise TypeError.
            **(kwargs.get("decoder_conf") or {}),
        )
        model.decoder = decoder

        self.model = model

        self.encoder_output_size = self.model.dims.n_audio_state

        self.activation_checkpoint = kwargs.get("activation_checkpoint", False)
        self.ignore_id = kwargs.get("ignore_id", -1)
        # The decoder above is built with n_vocab=dims.n_vocab, so default the loss /
        # beam-search vocabulary to that instead of -1 (consistent with SenseVoiceFSMN).
        self.vocab_size = kwargs.get("vocab_size", dims.n_vocab)
        self.length_normalized_loss = kwargs.get("length_normalized_loss", True)
        self.criterion_att = LabelSmoothingLoss(
            size=self.vocab_size,
            padding_idx=self.ignore_id,
            smoothing=kwargs.get("lsm_weight", 0.0),
            normalize_length=self.length_normalized_loss,
        )

        specaug = kwargs.get("specaug", None)
        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**(kwargs.get("specaug_conf") or {}))
        self.specaug = specaug

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ):
        """Compute the attention loss and padding statistics for one training batch.

        Args:
            speech: (Batch, Frames, Dim) input features.
            speech_lengths: (Batch,) or (Batch, 1) valid frame counts.
            text: (Batch, Tokens) padded token ids.
            text_lengths: (Batch,) or (Batch, 1) token counts.
            **kwargs: ``target_mask`` marks the token positions that contribute to the loss.

        Returns:
            ``(loss, stats, weight)`` after ``force_gatherable``.
        """
        target_mask = kwargs.get("target_mask", None)

        # Collapse (Batch, 1) length tensors to (Batch,).
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size, frames, _ = speech.shape
        _, text_tokens = text.shape

        if self.activation_checkpoint:
            from torch.utils.checkpoint import checkpoint

            encoder_out, encoder_out_lens = checkpoint(
                self.encode, speech, speech_lengths, use_reentrant=False
            )
        else:
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        loss_att, acc_att, _, _ = self._calc_att_loss(
            encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask
        )
        loss = loss_att

        stats = {}
        stats["acc"] = acc_att
        stats["loss"] = torch.clone(loss.detach())
        stats["batch_size"] = batch_size
        stats["batch_size_x_frames"] = frames * batch_size
        stats["batch_size_real_frames"] = speech_lengths.sum().item()
        stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
        stats["batch_size_x_tokens"] = text_tokens * batch_size
        stats["batch_size_real_tokens"] = text_lengths.sum().item()
        stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"]
        stats["batch_size_x_frames_plus_tokens"] = (text_tokens + frames) * batch_size

        # force_gatherable: to-device and to-tensor if scalar for DataParallel.
        # With length normalization the weight is the summed (text_lengths + 1).
        if self.length_normalized_loss:
            batch_size = int((text_lengths + 1).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ):
        """Apply SpecAugment (training only) and run the encoder.

        Args:
            speech: (Batch, Length, Dim) features.
            speech_lengths: (Batch,)

        Returns:
            ``(encoder_out, encoder_out_lens)`` from the SenseVoice encoder forward.
        """
        with autocast(False):
            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

        # Forward encoder; it is fed (Batch, Dim, Length).
        encoder_out, encoder_out_lens = self.model.encoder(speech.permute(0, 2, 1), speech_lengths)

        return encoder_out, encoder_out_lens

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        **kwargs,
    ):
        """Compute the label-smoothed attention loss and token accuracy.

        Args:
            encoder_out: encoder output used as decoder cross-attention memory.
            encoder_out_lens: valid lengths of ``encoder_out``.
            ys_pad: padded token ids laid out as [sos, task, lid, text, eos].
            ys_pad_lens: token counts per sequence.
            **kwargs: ``target_mask`` (same shape as ``ys_pad``); 1 marks loss positions.

        Returns:
            ``(loss_att, acc_att, None, None)``.
        """
        target_mask = kwargs.get("target_mask", None)

        # 1. Forward decoder
        # ys_pad: [sos, task, lid, text, eos]
        decoder_out = self.model.decoder(
            x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
        )

        # 2. Compute attention loss.
        # Non-target positions (and any id 0) become ignore_id, e.g.
        # [sos, task, lid, text, eos] with target_mask [0, 0, 1, 1, 1]
        #   -> [ignore, ignore, lid, text, eos]
        mask = torch.ones_like(ys_pad) * self.ignore_id
        ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(torch.int64)
        ys_pad_mask[ys_pad_mask == 0] = self.ignore_id
        # Shift by one: decoder output at step t is scored against token t + 1.
        # decoder_out: [sos, task, lid, text]
        # ys_pad_mask: [ignore, lid, text, eos]
        loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:])

        with torch.no_grad():
            preds = torch.argmax(decoder_out, -1)
            acc_att = compute_accuracy(
                preds[:, :-1], ys_pad_mask[:, 1:], ignore_label=self.ignore_id
            )

        return loss_att, acc_att, None, None

    def init_beam_search(
        self,
        **kwargs,
    ):
        """Build a decoder-only ``BeamSearch`` and store it on ``self.beam_search``.

        sos/eos are left as None here; ``inference`` assigns them from the tokenizer.
        """
        from .search import BeamSearch
        from funasr.models.transformer.scorers.length_bonus import LengthBonus

        scorers = {}
        scorers.update(
            decoder=self.model.decoder,
            length_bonus=LengthBonus(self.vocab_size),
        )

        weights = dict(
            decoder=1.0,
            ctc=0.0,
            lm=0.0,
            ngram=0.0,
            length_bonus=kwargs.get("penalty", 0.0),
        )
        beam_search = BeamSearch(
            beam_size=kwargs.get("beam_size", 5),
            weights=weights,
            scorers=scorers,
            sos=None,
            eos=None,
            vocab_size=self.vocab_size,
            token_list=None,
            pre_beam_score_key="full",
        )

        self.beam_search = beam_search

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Decode a single utterance with beam search.

        Args:
            data_in: raw input for ``load_audio_text_image_video``, or a feature tensor
                when ``kwargs["data_type"] == "fbank"``.
            data_lengths: feature lengths for the fbank path (optional).
            key: utterance ids used to label results.
            tokenizer: used to encode the prompt / special tokens and decode hypotheses.
            frontend: feature extractor; a ``WhisperFrontend`` is created and cached on
                ``self`` when none is supplied or cached.
            **kwargs: must contain ``device`` and ``model_conf`` (with ``eos``); may contain
                ``DecodingOptions``, ``initial_prompt``, ``nbest``, ``beam_size``,
                ``maxlenratio``, ``minlenratio``, ``output_dir``.

        Returns:
            ``(results, meta_data)`` where ``results`` is a list of ``{"key", "text"}``.

        Raises:
            NotImplementedError: if ``batch_size > 1``.
        """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        # init beamsearch
        if not hasattr(self, "beam_search") or self.beam_search is None:
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

        if frontend is None and not hasattr(self, "frontend"):
            frontend_class = tables.frontend_classes.get("WhisperFrontend")
            frontend = frontend_class(
                n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)
            )
            self.frontend = frontend
        else:
            frontend = frontend if frontend is not None else self.frontend

        meta_data = {}
        if (
            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
        ):  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                # Must be a tensor: it is moved to the device below.
                speech_lengths = torch.tensor([speech.shape[1]], dtype=torch.int64)
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(
                data_in,
                fs=frontend.fs if hasattr(frontend, "fs") else 16000,
                audio_fs=kwargs.get("fs", 16000),
                data_type=kwargs.get("data_type", "sound"),
                tokenizer=tokenizer,
            )
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(
                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
            )
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
            lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000

        # Keep only the first (and only) item of the batch.
        speech = speech.to(device=kwargs["device"])[0, :, :]
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        DecodingOptions = kwargs.get("DecodingOptions") or {}
        task = DecodingOptions.get("task", "ASR")
        if isinstance(task, str):
            task = [task]
        task = "".join([f"<|{x}|>" for x in task])
        initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")

        language = DecodingOptions.get("language", None)
        language = None if language == "auto" else language

        # The decoding prefix is the prompt plus the language tag when one is forced.
        sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
        sos_int = tokenizer.encode(sos, allowed_special="all")
        eos = kwargs.get("model_conf").get("eos")
        eos_int = tokenizer.encode(eos, allowed_special="all")
        self.beam_search.sos = sos_int
        self.beam_search.eos = eos_int[0]

        # Parameters for rich decoding
        self.beam_search.emo_unk = tokenizer.encode(
            DecodingOptions.get("emo_unk_token", "<|SPECIAL_TOKEN_1|>"), allowed_special="all"
        )[0]
        self.beam_search.emo_unk_score = 1
        self.beam_search.emo_tokens = tokenizer.encode(
            DecodingOptions.get("emo_target_tokens", "<|HAPPY|><|SAD|><|ANGRY|>"),
            allowed_special="all",
        )
        self.beam_search.emo_scores = DecodingOptions.get("emo_target_threshold", [0.1, 0.1, 0.1])

        self.beam_search.event_bg_token = tokenizer.encode(
            DecodingOptions.get("gain_tokens_bg", "<|Speech|><|BGM|><|Applause|><|Laughter|>"),
            allowed_special="all",
        )
        self.beam_search.event_ed_token = tokenizer.encode(
            DecodingOptions.get("gain_tokens_ed", "<|/Speech|><|/BGM|><|/Applause|><|/Laughter|>"),
            allowed_special="all",
        )
        self.beam_search.event_score_ga = DecodingOptions.get("gain_tokens_score", [1, 1, 1, 1])

        # encode() permutes back, so the encoder receives (1, Dim, Length).
        encoder_out, encoder_out_lens = self.encode(
            speech[None, :, :].permute(0, 2, 1), speech_lengths
        )

        # Run beam search on the encoder output.
        nbest_hyps = self.beam_search(
            x=encoder_out[0],
            maxlenratio=kwargs.get("maxlenratio", 0.0),
            minlenratio=kwargs.get("minlenratio", 0.0),
        )
        nbest_hyps = nbest_hyps[: self.nbest]

        results = []
        b, _, _ = encoder_out.size()
        for i in range(b):
            for nbest_idx, hyp in enumerate(nbest_hyps):
                ibest_writer = None
                if kwargs.get("output_dir") is not None:
                    if not hasattr(self, "writer"):
                        self.writer = DatadirWriter(kwargs.get("output_dir"))
                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]

                # Drop the first and last entries of yseq (sos/eos).
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()

                text = tokenizer.decode(token_int)

                result_i = {"key": key[i], "text": text}
                results.append(result_i)

                if ibest_writer is not None:
                    ibest_writer["text"][key[i]] = text

        return results, meta_data
@tables.register("model_classes", "SenseVoiceFSMN")
class SenseVoiceFSMN(nn.Module):
    """Whisper encoder (SenseVoice forward) paired with a registry-built decoder.

    Like ``SenseVoiceRWKV`` but sizes the vocabulary from ``dims.n_vocab`` and, at
    inference, supports a teacher-forced path when the loader also yields text tokens.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

        dims = kwargs.get("dims", {})
        dims = whisper.model.ModelDimensions(**dims)
        model = whisper.model.Whisper(dims=dims)

        # encoder: replace the stock forward with the SenseVoice one
        model.encoder.downsample_rate = kwargs.get("downsample_rate", 4)
        model.encoder.use_padmask = kwargs.get("use_padmask", True)
        from .encoder import sense_voice_encode_forward

        model.encoder.forward = types.MethodType(sense_voice_encode_forward, model.encoder)

        # decoder: swap the stock Whisper decoder for a registered implementation
        del model.decoder
        decoder = kwargs.get("decoder", "SenseVoiceDecoder")
        decoder_class = tables.decoder_classes.get(decoder)
        decoder = decoder_class(
            n_vocab=dims.n_vocab,
            n_ctx=dims.n_text_ctx,
            n_state=dims.n_text_state,
            n_head=dims.n_text_head,
            n_layer=dims.n_text_layer,
            # decoder_conf may be absent; ``**None`` would raise TypeError.
            **(kwargs.get("decoder_conf") or {}),
        )
        model.decoder = decoder

        self.model = model

        self.encoder_output_size = self.model.dims.n_audio_state

        self.activation_checkpoint = kwargs.get("activation_checkpoint", False)
        self.ignore_id = kwargs.get("ignore_id", -1)
        self.vocab_size = dims.n_vocab
        self.length_normalized_loss = kwargs.get("length_normalized_loss", True)
        self.criterion_att = LabelSmoothingLoss(
            size=self.vocab_size,
            padding_idx=self.ignore_id,
            smoothing=kwargs.get("lsm_weight", 0.0),
            normalize_length=self.length_normalized_loss,
        )

        specaug = kwargs.get("specaug", None)
        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**(kwargs.get("specaug_conf") or {}))
        self.specaug = specaug

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ):
        """Compute the attention loss and padding statistics for one training batch.

        Args:
            speech: (Batch, Frames, Dim) input features.
            speech_lengths: (Batch,) or (Batch, 1) valid frame counts.
            text: (Batch, Tokens) padded token ids.
            text_lengths: (Batch,) or (Batch, 1) token counts.
            **kwargs: ``target_mask`` marks the token positions that contribute to the loss.

        Returns:
            ``(loss, stats, weight)`` after ``force_gatherable``.
        """
        target_mask = kwargs.get("target_mask", None)

        # Collapse (Batch, 1) length tensors to (Batch,).
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size, frames, _ = speech.shape
        _, text_tokens = text.shape

        if self.activation_checkpoint:
            from torch.utils.checkpoint import checkpoint

            encoder_out, encoder_out_lens = checkpoint(
                self.encode, speech, speech_lengths, use_reentrant=False
            )
        else:
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        loss_att, acc_att, _, _ = self._calc_att_loss(
            encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask
        )
        loss = loss_att

        stats = {}
        stats["acc"] = acc_att
        stats["loss"] = torch.clone(loss.detach())
        stats["batch_size"] = batch_size
        stats["batch_size_x_frames"] = frames * batch_size
        stats["batch_size_real_frames"] = speech_lengths.sum().item()
        stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
        stats["batch_size_x_tokens"] = text_tokens * batch_size
        stats["batch_size_real_tokens"] = text_lengths.sum().item()
        stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"]
        stats["batch_size_x_frames_plus_tokens"] = (text_tokens + frames) * batch_size

        # force_gatherable: to-device and to-tensor if scalar for DataParallel.
        # With length normalization the weight is the summed (text_lengths + 1).
        if self.length_normalized_loss:
            batch_size = int((text_lengths + 1).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ):
        """Apply SpecAugment (training only) and run the encoder.

        Args:
            speech: (Batch, Length, Dim) features.
            speech_lengths: (Batch,)

        Returns:
            ``(encoder_out, encoder_out_lens)`` from the SenseVoice encoder forward.
        """
        with autocast(False):
            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

        # Forward encoder; it is fed (Batch, Dim, Length).
        encoder_out, encoder_out_lens = self.model.encoder(speech.permute(0, 2, 1), speech_lengths)

        return encoder_out, encoder_out_lens

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        **kwargs,
    ):
        """Compute the label-smoothed attention loss and token accuracy.

        Args:
            encoder_out: encoder output used as decoder cross-attention memory.
            encoder_out_lens: valid lengths of ``encoder_out``.
            ys_pad: padded token ids (decoder input and, shifted by one, the targets).
            ys_pad_lens: token counts per sequence.
            **kwargs: ``target_mask`` (same shape as ``ys_pad``); 1 marks loss positions.

        Returns:
            ``(loss_att, acc_att, None, None)``.
        """
        target_mask = kwargs.get("target_mask", None)

        # 1. Forward decoder
        decoder_out = self.model.decoder(
            x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
        )

        # 2. Compute attention loss.
        # Non-target positions, and any position whose id is 0, become ignore_id so that
        # both the criterion (padding_idx=ignore_id) and the accuracy metric skip them.
        mask = torch.ones_like(ys_pad) * self.ignore_id
        ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(torch.int64)
        ys_pad_mask[ys_pad_mask == 0] = self.ignore_id
        # Decoder output at step t is scored against the token at step t + 1.
        loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:])

        with torch.no_grad():
            preds = torch.argmax(decoder_out, -1)
            acc_att = compute_accuracy(
                preds[:, :-1], ys_pad_mask[:, 1:], ignore_label=self.ignore_id
            )

        return loss_att, acc_att, None, None

    def init_beam_search(
        self,
        **kwargs,
    ):
        """Build a decoder-only ``BeamSearch`` and store it on ``self.beam_search``.

        sos/eos are left as None here; ``inference`` assigns them from the tokenizer.
        """
        from .search import BeamSearch
        from funasr.models.transformer.scorers.length_bonus import LengthBonus

        scorers = {}
        scorers.update(
            decoder=self.model.decoder,
            length_bonus=LengthBonus(self.vocab_size),
        )

        weights = dict(
            decoder=1.0,
            ctc=0.0,
            lm=0.0,
            ngram=0.0,
            length_bonus=kwargs.get("penalty", 0.0),
        )
        beam_search = BeamSearch(
            beam_size=kwargs.get("beam_size", 5),
            weights=weights,
            scorers=scorers,
            sos=None,
            eos=None,
            vocab_size=self.vocab_size,
            token_list=None,
            pre_beam_score_key="full",
        )

        self.beam_search = beam_search

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Decode a single utterance.

        When the loader also returns text tokens (``data_type`` is a list/tuple with more
        than one entry), the decoder is run once by teacher forcing on prompt + reference
        tokens and its argmax is returned. Otherwise beam search is used.

        Args:
            data_in: raw input for ``load_audio_text_image_video``, or a feature tensor
                when ``kwargs["data_type"] == "fbank"``.
            data_lengths: feature lengths for the fbank path (optional).
            key: utterance ids used to label results.
            tokenizer: used to encode the prompt / special tokens and decode outputs.
            frontend: feature extractor; a ``WhisperFrontend`` is created and cached on
                ``self`` when none is supplied or cached.
            **kwargs: must contain ``device`` and ``model_conf`` (with ``eos``); may contain
                ``DecodingOptions``, ``initial_prompt``, ``nbest``, ``beam_size``,
                ``maxlenratio``, ``minlenratio``, ``output_dir``.

        Returns:
            ``(results, meta_data)`` where ``results`` is a list of ``{"key", "text"}``.

        Raises:
            NotImplementedError: if ``batch_size > 1``.
        """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        # init beamsearch
        if not hasattr(self, "beam_search") or self.beam_search is None:
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

        if frontend is None and not hasattr(self, "frontend"):
            frontend_class = tables.frontend_classes.get("WhisperFrontend")
            frontend = frontend_class(
                n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)
            )
            self.frontend = frontend
        else:
            frontend = frontend if frontend is not None else self.frontend

        meta_data = {}
        # Bound on every path: the fbank branch never yields reference tokens, and leaving
        # this unset there made the check further down raise UnboundLocalError.
        text_token_int = None
        if (
            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
        ):  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                # Must be a tensor: it is moved to the device below.
                speech_lengths = torch.tensor([speech.shape[1]], dtype=torch.int64)
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(
                data_in,
                fs=frontend.fs if hasattr(frontend, "fs") else 16000,
                audio_fs=kwargs.get("fs", 16000),
                data_type=kwargs.get("data_type", "sound"),
                tokenizer=tokenizer,
            )
            # With several data types the loader returns (audio, text tokens).
            if (
                isinstance(kwargs.get("data_type", None), (list, tuple))
                and len(kwargs.get("data_type", [])) > 1
            ):
                audio_sample_list, text_token_int_list = audio_sample_list
                text_token_int = text_token_int_list[0]
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(
                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
            )
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
            lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000

        # Keep only the first (and only) item of the batch.
        speech = speech.to(device=kwargs["device"])[0, :, :]
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        DecodingOptions = kwargs.get("DecodingOptions") or {}
        task = DecodingOptions.get("task", "ASR")
        if isinstance(task, str):
            task = [task]
        task = "".join([f"<|{x}|>" for x in task])
        initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")

        language = DecodingOptions.get("language", None)
        language = None if language == "auto" else language

        # The decoding prefix is the prompt plus the language tag when one is forced.
        sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
        sos_int = tokenizer.encode(sos, allowed_special="all")
        eos = kwargs.get("model_conf").get("eos")
        eos_int = tokenizer.encode(eos, allowed_special="all")
        self.beam_search.sos = sos_int
        self.beam_search.eos = eos_int[0]

        # Parameters for rich decoding
        self.beam_search.emo_unk = tokenizer.encode(
            DecodingOptions.get("emo_unk_token", "<|SPECIAL_TOKEN_1|>"), allowed_special="all"
        )[0]
        self.beam_search.emo_unk_score = 1
        self.beam_search.emo_tokens = tokenizer.encode(
            DecodingOptions.get("emo_target_tokens", "<|HAPPY|><|SAD|><|ANGRY|>"),
            allowed_special="all",
        )
        self.beam_search.emo_scores = DecodingOptions.get("emo_target_threshold", [0.1, 0.1, 0.1])

        self.beam_search.event_bg_token = tokenizer.encode(
            DecodingOptions.get("gain_tokens_bg", "<|Speech|><|BGM|><|Applause|><|Laughter|>"),
            allowed_special="all",
        )
        self.beam_search.event_ed_token = tokenizer.encode(
            DecodingOptions.get("gain_tokens_ed", "<|/Speech|><|/BGM|><|/Applause|><|/Laughter|>"),
            allowed_special="all",
        )
        self.beam_search.event_score_ga = DecodingOptions.get("gain_tokens_score", [1, 1, 1, 1])

        # encode() permutes back, so the encoder receives (1, Dim, Length).
        encoder_out, encoder_out_lens = self.encode(
            speech[None, :, :].permute(0, 2, 1), speech_lengths
        )

        if text_token_int is not None:
            # Teacher-forced decoding on prompt + reference tokens (no beam search).
            i = 0
            results = []
            ibest_writer = None
            if kwargs.get("output_dir") is not None:
                if not hasattr(self, "writer"):
                    self.writer = DatadirWriter(kwargs.get("output_dir"))
                ibest_writer = self.writer["1best_recog"]

            # 1. Forward decoder
            prompt_and_text = sos_int + text_token_int
            ys_pad = torch.tensor(prompt_and_text, dtype=torch.int64).to(kwargs["device"])[
                None, :
            ]
            # (Batch,) lengths, matching the 1-D lengths forward() passes in training.
            ys_pad_lens = torch.tensor([len(prompt_and_text)], dtype=torch.int64).to(
                kwargs["device"]
            )
            decoder_out = self.model.decoder(
                x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
            )

            token_int = decoder_out.argmax(-1)[0, :].tolist()
            text = tokenizer.decode(token_int)

            result_i = {"key": key[i], "text": text}
            results.append(result_i)

            if ibest_writer is not None:
                ibest_writer["text"][key[i]] = text
            return results, meta_data

        # Run beam search on the encoder output.
        nbest_hyps = self.beam_search(
            x=encoder_out[0],
            maxlenratio=kwargs.get("maxlenratio", 0.0),
            minlenratio=kwargs.get("minlenratio", 0.0),
        )
        nbest_hyps = nbest_hyps[: self.nbest]

        results = []
        b, _, _ = encoder_out.size()
        for i in range(b):
            for nbest_idx, hyp in enumerate(nbest_hyps):
                ibest_writer = None
                if kwargs.get("output_dir") is not None:
                    if not hasattr(self, "writer"):
                        self.writer = DatadirWriter(kwargs.get("output_dir"))
                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]

                # Drop the first and last entries of yseq (sos/eos).
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()

                text = tokenizer.decode(token_int)

                result_i = {"key": key[i], "text": text}
                results.append(result_i)

                if ibest_writer is not None:
                    ibest_writer["text"][key[i]] = text

        return results, meta_data
@tables.register("model_classes", "SenseVoiceSANM")
class SenseVoiceSANM(nn.Module):
def __init__(
    self,
    specaug: str = None,
    specaug_conf: dict = None,
    normalize: str = None,
    normalize_conf: dict = None,
    encoder: str = None,
    encoder_conf: dict = None,
    decoder: str = None,
    decoder_conf: dict = None,
    input_size: int = 80,
    vocab_size: int = -1,
    ignore_id: int = -1,
    blank_id: int = 0,
    sos: int = 1,
    eos: int = 2,
    lsm_weight: float = 0.0,
    length_normalized_loss: bool = False,
    report_cer: bool = True,
    report_wer: bool = True,
    sym_space: str = "<space>",
    sym_blank: str = "<blank>",
    # extract_feats_in_collect_stats: bool = True,
    share_embedding: bool = False,
    # preencoder: Optional[AbsPreEncoder] = None,
    # postencoder: Optional[AbsPostEncoder] = None,
    **kwargs,
):
    """Build SpecAugment, encoder, decoder and the attention criterion from the registry.

    Args:
        specaug / specaug_conf: registry name and kwargs of the SpecAugment module (optional).
        encoder / encoder_conf: registry name and kwargs of the encoder.
        decoder / decoder_conf: registry name and kwargs of the decoder.
        input_size: feature dimension passed to the encoder.
        vocab_size: output vocabulary size (loss and decoder).
        ignore_id: label id skipped by the loss and accuracy.
        sos / eos: special token ids; ``None`` falls back to ``vocab_size - 1``.
        lsm_weight: label-smoothing weight.
        length_normalized_loss: normalize the loss by token count.
        **kwargs: ``activation_checkpoint`` toggles checkpointed encoding.

    ``normalize``, ``report_cer``, ``report_wer``, ``sym_*`` and ``share_embedding`` are
    accepted for config compatibility but unused here.
    """
    super().__init__()

    # ``*_conf`` default to None; ``**None`` raises TypeError, so treat None as "no kwargs".
    if specaug is not None:
        specaug_class = tables.specaug_classes.get(specaug)
        specaug = specaug_class(**(specaug_conf or {}))

    encoder_class = tables.encoder_classes.get(encoder)
    encoder = encoder_class(input_size=input_size, **(encoder_conf or {}))
    encoder_output_size = encoder.output_size()

    decoder_class = tables.decoder_classes.get(decoder)
    decoder = decoder_class(
        vocab_size=vocab_size,
        encoder_output_size=encoder_output_size,
        **(decoder_conf or {}),
    )

    self.blank_id = blank_id
    self.sos = sos if sos is not None else vocab_size - 1
    self.eos = eos if eos is not None else vocab_size - 1
    self.vocab_size = vocab_size
    self.ignore_id = ignore_id
    self.specaug = specaug
    self.encoder = encoder
    self.decoder = decoder

    self.criterion_att = LabelSmoothingLoss(
        size=vocab_size,
        padding_idx=ignore_id,
        smoothing=lsm_weight,
        normalize_length=length_normalized_loss,
    )

    self.error_calculator = None

    self.length_normalized_loss = length_normalized_loss
    # Built lazily by init_beam_search().
    self.beam_search = None
    self.activation_checkpoint = kwargs.get("activation_checkpoint", False)
    self.encoder_output_size = encoder_output_size
def forward(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    **kwargs,
):
    """Run encoder + attention loss for a training batch and collect padding statistics.

    Args:
        speech: (Batch, Frames, Dim) features.
        speech_lengths: (Batch,) or (Batch, 1) valid frame counts.
        text: (Batch, Tokens) padded token ids.
        text_lengths: (Batch,) or (Batch, 1) token counts.
        **kwargs: optional ``target_mask`` forwarded to ``_calc_att_loss``.

    Returns:
        ``(loss, stats, weight)`` as produced by ``force_gatherable``.
    """
    target_mask = kwargs.get("target_mask", None)

    # Length tensors may arrive as (Batch, 1); keep only the first column.
    if text_lengths.dim() > 1:
        text_lengths = text_lengths[:, 0]
    if speech_lengths.dim() > 1:
        speech_lengths = speech_lengths[:, 0]

    batch_size, frames, _ = speech.shape
    _, text_tokens = text.shape

    if not self.activation_checkpoint:
        encoded = self.encode(speech, speech_lengths)
    else:
        from torch.utils.checkpoint import checkpoint

        encoded = checkpoint(self.encode, speech, speech_lengths, use_reentrant=False)
    encoder_out, encoder_out_lens = encoded

    loss, acc_att, _, _ = self._calc_att_loss(
        encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask
    )

    padded_frames = frames * batch_size
    real_frames = speech_lengths.sum().item()
    padded_tokens = text_tokens * batch_size
    real_tokens = text_lengths.sum().item()
    stats = {
        "acc": acc_att,
        "loss": torch.clone(loss.detach()),
        "batch_size": batch_size,
        "batch_size_x_frames": padded_frames,
        "batch_size_real_frames": real_frames,
        "padding_frames": padded_frames - real_frames,
        "batch_size_x_tokens": padded_tokens,
        "batch_size_real_tokens": real_tokens,
        "padding_tokens": padded_tokens - real_tokens,
        "batch_size_x_frames_plus_tokens": (text_tokens + frames) * batch_size,
    }

    # Gradient weight: summed (text_lengths + 1) when length-normalized, else batch size.
    weight = int((text_lengths + 1).sum()) if self.length_normalized_loss else batch_size
    # force_gatherable: to-device and to-tensor if scalar for DataParallel
    return force_gatherable((loss, stats, weight), loss.device)
def encode(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    **kwargs,
):
    """Optionally augment the features, then run the registered encoder.

    Args:
        speech: (Batch, Length, Dim) features.
        speech_lengths: (Batch,)

    Returns:
        ``(encoder_out, encoder_out_lens)``; when the encoder returns a tuple/list as its
        first output, only its first element is kept.
    """
    with autocast(False):
        # SpecAugment is applied only in training mode.
        if self.training and self.specaug is not None:
            speech, speech_lengths = self.specaug(speech, speech_lengths)

    # (Batch, Length, Dim) -> (Batch, Length2, Dim2); the third output is discarded.
    hidden, hidden_lens, _ = self.encoder(speech, speech_lengths)
    hidden = hidden[0] if isinstance(hidden, (tuple, list)) else hidden
    return hidden, hidden_lens
def _calc_att_loss(
    self,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ys_pad: torch.Tensor,
    ys_pad_lens: torch.Tensor,
    **kwargs,
):
    """Compute the attention-decoder loss and token accuracy.

    Args:
        encoder_out: encoder output passed to ``self.decoder``.
        encoder_out_lens: valid encoder lengths.
        ys_pad: padded label ids; entries equal to -1 are treated as padding.
            This tensor is NOT modified.
        ys_pad_lens: label lengths.
        **kwargs: ``target_mask`` (optional) — same shape as ``ys_pad``; 1 marks
            positions that contribute to the loss, 0 marks positions to ignore
            (e.g. prompt tokens). When omitted, every non-padding position is
            a target.

    Returns:
        ``(loss_att, acc_att, None, None)``. The two trailing ``None`` values
        stand in for CER/WER, which are not computed here.
    """
    target_mask = kwargs.get("target_mask", None)

    # 1. Forward decoder. The decoder needs valid ids, so -1 padding is mapped
    # to 0. This is done on a copy so the caller's label tensor is left intact.
    ys_in = ys_pad.masked_fill(ys_pad == -1, 0)
    decoder_out = self.decoder(encoder_out, encoder_out_lens, ys_in, ys_pad_lens)
    if isinstance(decoder_out, (list, tuple)):
        decoder_out = decoder_out[0]

    # 2. Build loss targets: positions outside target_mask and (former)
    # padding positions, now id 0, are set to the ignore label that both
    # criterion_att and compute_accuracy skip.
    if target_mask is None:
        target_mask = torch.ones_like(ys_in)
    ignore = torch.full_like(ys_in, self.ignore_id)
    ys_target = (ys_in * target_mask + ignore * (1 - target_mask)).to(torch.int64)
    ys_target[ys_target == 0] = self.ignore_id

    # Teacher forcing: the output at step t is scored against label t + 1.
    loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_target[:, 1:])
    with torch.no_grad():
        preds = torch.argmax(decoder_out, -1)
        acc_att = compute_accuracy(
            preds[:, :-1], ys_target[:, 1:], ignore_label=self.ignore_id
        )
    return loss_att, acc_att, None, None
def init_beam_search(
    self,
    **kwargs,
):
    """Build ``self.beam_search`` for attention-decoder decoding.

    Only two scorers are registered: the decoder (weight 1.0) and a length
    bonus whose weight is ``kwargs["penalty"]`` (default 0.0). The ctc, lm
    and ngram weights are present but zero and have no scorer. ``sos`` and
    ``eos`` are left as None and are set later by the caller.

    Args:
        **kwargs: ``penalty`` (length-bonus weight) and ``beam_size`` (default 5).
    """
    from .search import BeamSearch
    from funasr.models.transformer.scorers.length_bonus import LengthBonus

    # Scorers that actually contribute to hypothesis scores.
    scorers = {
        "decoder": self.decoder,
        "length_bonus": LengthBonus(self.vocab_size),
    }
    weights = {
        "decoder": 1.0,
        "ctc": 0.0,
        "lm": 0.0,
        "ngram": 0.0,
        "length_bonus": kwargs.get("penalty", 0.0),
    }
    self.beam_search = BeamSearch(
        beam_size=kwargs.get("beam_size", 5),
        weights=weights,
        scorers=scorers,
        sos=None,
        eos=None,
        vocab_size=self.vocab_size,
        token_list=None,
        pre_beam_score_key="full",
    )
def inference(
    self,
    data_in,
    data_lengths=None,
    key: list = None,
    tokenizer=None,
    frontend=None,
    **kwargs,
):
    """Decode one utterance with the attention decoder.

    Two modes:

    * If the data loader returns reference token ids alongside the audio
      (``data_type`` is a list/tuple with more than one entry), the decoder is
      run once with teacher forcing on ``sos + reference`` and the per-position
      argmax is decoded.
    * Otherwise ``self.beam_search`` (built lazily via ``init_beam_search``)
      decodes from the encoder output.

    Args:
        data_in: raw input accepted by ``load_audio_text_image_video``, or,
            when ``kwargs["data_type"] == "fbank"``, a 2-D or 3-D feature tensor
            (a batch dimension is added to 2-D input).
        data_lengths: feature lengths for the fbank path; derived from the
            tensor shape when None.
        key: utterance ids; ``key[i]`` labels the i-th result.
        tokenizer: must provide ``encode(text, allowed_special=...)`` and
            ``decode(ids)``.
        frontend: feature extractor. When neither this nor ``self.frontend``
            is set, a ``WhisperFrontend`` is built and cached on ``self``.
        **kwargs: ``device`` and ``model_conf`` (with ``sos``/``eos``) are
            required. Optional: ``DecodingOptions``, ``beam_size``, ``nbest``,
            ``penalty``, ``output_dir``, ``maxlenratio``, ``minlenratio``.

    Returns:
        ``(results, meta_data)``: a list of ``{"key", "text"}`` dicts and a
        dict of timing info.

    Raises:
        NotImplementedError: if ``batch_size`` > 1.
    """
    if kwargs.get("batch_size", 1) > 1:
        raise NotImplementedError("batch decoding is not implemented")

    # init beamsearch
    if not hasattr(self, "beam_search") or self.beam_search is None:
        logging.info("enable beam_search")
        self.init_beam_search(**kwargs)
    # Set on every call so that a pre-built beam search also has an nbest value.
    self.nbest = kwargs.get("nbest", 1)

    if frontend is None and not hasattr(self, "frontend"):
        frontend_class = tables.frontend_classes.get("WhisperFrontend")
        frontend = frontend_class(
            n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)
        )
        self.frontend = frontend
    else:
        frontend = frontend if frontend is not None else self.frontend

    meta_data = {}
    # Reference token ids for teacher-forced decoding. Only the raw-input path
    # can supply them; defaulting to None keeps the fbank path from hitting a
    # NameError further down.
    text_token_int = None
    if (
        isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
    ):  # fbank
        speech, speech_lengths = data_in, data_lengths
        if len(speech.shape) < 3:
            speech = speech[None, :, :]
        # Lengths must be a tensor because they are moved to the device below.
        if speech_lengths is None:
            speech_lengths = torch.full(
                (speech.size(0),), speech.size(1), dtype=torch.int64
            )
        else:
            speech_lengths = torch.as_tensor(speech_lengths).reshape(-1)
    else:
        # extract fbank feats
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(
            data_in,
            fs=frontend.fs if hasattr(frontend, "fs") else 16000,
            audio_fs=kwargs.get("fs", 16000),
            data_type=kwargs.get("data_type", "sound"),
            tokenizer=tokenizer,
        )
        if (
            isinstance(kwargs.get("data_type", None), (list, tuple))
            and len(kwargs.get("data_type", [])) > 1
        ):
            # The loader returned (audio samples, reference token ids).
            audio_sample_list, text_token_int_list = audio_sample_list
            text_token_int = text_token_int_list[0]
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        speech, speech_lengths = extract_fbank(
            audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
        )
        time3 = time.perf_counter()
        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
        frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
        lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
        meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000

    # Only the first utterance is kept (batch_size > 1 is rejected above).
    speech = speech.to(device=kwargs["device"])[0, :, :]
    speech_lengths = speech_lengths.to(device=kwargs["device"])

    # Build the decoding prompt (sos sequence): task tag(s) plus optional language tag.
    DecodingOptions = kwargs.get("DecodingOptions", {})
    task = DecodingOptions.get("task", "ASR")
    if isinstance(task, str):
        task = [task]
    task = "".join([f"<|{x}|>" for x in task])

    sos = kwargs.get("model_conf").get("sos")
    if isinstance(sos, str):
        initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
        language = DecodingOptions.get("language", None)
        language = None if language == "auto" else language
        sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
        sos_int = tokenizer.encode(sos, allowed_special="all")
    else:
        language = DecodingOptions.get("language", None)
        language = None if language == "auto" else language
        initial_prompt = kwargs.get("initial_prompt", f"{task}")
        initial_prompt_lid = (
            f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
        )
        initial_prompt_lid_int = tokenizer.encode(initial_prompt_lid, allowed_special="all")
        sos_int = [sos] + initial_prompt_lid_int
    eos = kwargs.get("model_conf").get("eos")
    if isinstance(eos, str):
        eos_int = tokenizer.encode(eos, allowed_special="all")
    else:
        eos_int = [eos]

    self.beam_search.sos = sos_int
    self.beam_search.eos = eos_int[0]

    # Parameters for rich decoding (emotion / audio-event tokens and their scores)
    self.beam_search.emo_unk = tokenizer.encode(
        DecodingOptions.get("emo_unk_token", "<|SPECIAL_TOKEN_1|>"), allowed_special="all"
    )[0]
    self.beam_search.emo_unk_score = 1
    self.beam_search.emo_tokens = tokenizer.encode(
        DecodingOptions.get("emo_target_tokens", "<|HAPPY|><|SAD|><|ANGRY|>"),
        allowed_special="all",
    )
    self.beam_search.emo_scores = DecodingOptions.get("emo_target_threshold", [0.1, 0.1, 0.1])
    self.beam_search.event_bg_token = tokenizer.encode(
        DecodingOptions.get("gain_tokens_bg", "<|Speech|><|BGM|><|Applause|><|Laughter|>"),
        allowed_special="all",
    )
    self.beam_search.event_ed_token = tokenizer.encode(
        DecodingOptions.get("gain_tokens_ed", "<|/Speech|><|/BGM|><|/Applause|><|/Laughter|>"),
        allowed_special="all",
    )
    self.beam_search.event_score_ga = DecodingOptions.get("gain_tokens_score", [1, 1, 1, 1])

    encoder_out, encoder_out_lens = self.encode(speech[None, :, :], speech_lengths)

    if text_token_int is not None:
        # Teacher-forced decoding against the reference tokens.
        i = 0
        results = []
        ibest_writer = None
        if kwargs.get("output_dir") is not None:
            if not hasattr(self, "writer"):
                self.writer = DatadirWriter(kwargs.get("output_dir"))
            ibest_writer = self.writer[f"1best_recog"]

        # 1. Forward decoder
        ys_pad = torch.tensor(sos_int + text_token_int, dtype=torch.int64).to(kwargs["device"])[
            None, :
        ]
        ys_pad_lens = torch.tensor([len(sos_int + text_token_int)], dtype=torch.int64).to(
            kwargs["device"]
        )[None, :]
        decoder_out = self.model.decoder(
            x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
        )

        token_int = decoder_out.argmax(-1)[0, :].tolist()
        text = tokenizer.decode(token_int)

        result_i = {"key": key[i], "text": text}
        results.append(result_i)

        if ibest_writer is not None:
            ibest_writer["text"][key[i]] = text
        return results, meta_data

    # c. Passed the encoder result and the beam search
    nbest_hyps = self.beam_search(
        x=encoder_out[0],
        maxlenratio=kwargs.get("maxlenratio", 0.0),
        minlenratio=kwargs.get("minlenratio", 0.0),
    )
    nbest_hyps = nbest_hyps[: self.nbest]

    results = []
    b = encoder_out.size(0)
    for i in range(b):
        for nbest_idx, hyp in enumerate(nbest_hyps):
            ibest_writer = None
            if kwargs.get("output_dir") is not None:
                if not hasattr(self, "writer"):
                    self.writer = DatadirWriter(kwargs.get("output_dir"))
                ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]

            # remove sos/eos (first and last entries) and get results
            last_pos = -1
            if isinstance(hyp.yseq, list):
                token_int = hyp.yseq[1:last_pos]
            else:
                token_int = hyp.yseq[1:last_pos].tolist()

            text = tokenizer.decode(token_int)

            result_i = {"key": key[i], "text": text}
            results.append(result_i)

            if ibest_writer is not None:
                ibest_writer["text"][key[i]] = text

    return results, meta_data
from funasr.models.paraformer.search import Hypothesis
from funasr.utils import postprocess_utils
@tables.register("model_classes", "SenseVoiceSANMCTC")
class SenseVoiceSANMCTC(nn.Module):
"""CTC-attention hybrid Encoder-Decoder model"""
def __init__(
    self,
    specaug: str = None,
    specaug_conf: dict = None,
    normalize: str = None,
    normalize_conf: dict = None,
    encoder: str = None,
    encoder_conf: dict = None,
    ctc_conf: dict = None,
    input_size: int = 80,
    vocab_size: int = -1,
    ignore_id: int = -1,
    blank_id: int = 0,
    sos: int = 1,
    eos: int = 2,
    length_normalized_loss: bool = False,
    **kwargs,
):
    """Build a CTC-only model: optional SpecAugment/normalization, an encoder
    looked up in ``tables.encoder_classes``, and a CTC head.

    Args:
        specaug / normalize / encoder: registry names; the matching ``*_conf``
            dicts are passed as keyword arguments (None is treated as ``{}``).
        ctc_conf: extra keyword arguments for ``CTC``.
        input_size: feature dimension fed to the encoder. It is also the width
            of the query embeddings, which inference concatenates with the
            features along the time axis, so the two must match.
        vocab_size: CTC output size.
        ignore_id, blank_id, sos, eos: special token ids. ``sos``/``eos`` fall
            back to ``vocab_size - 1`` when None.
        length_normalized_loss: if True, ``forward`` weights by token count.
    """
    super().__init__()

    if specaug is not None:
        specaug_class = tables.specaug_classes.get(specaug)
        specaug = specaug_class(**(specaug_conf or {}))
    if normalize is not None:
        normalize_class = tables.normalize_classes.get(normalize)
        normalize = normalize_class(**(normalize_conf or {}))
    encoder_class = tables.encoder_classes.get(encoder)
    encoder = encoder_class(input_size=input_size, **(encoder_conf or {}))
    encoder_output_size = encoder.output_size()

    ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size, **(ctc_conf or {}))

    self.blank_id = blank_id
    self.sos = sos if sos is not None else vocab_size - 1
    self.eos = eos if eos is not None else vocab_size - 1
    self.vocab_size = vocab_size
    self.ignore_id = ignore_id
    self.specaug = specaug
    self.normalize = normalize
    self.encoder = encoder
    self.error_calculator = None

    self.ctc = ctc

    self.length_normalized_loss = length_normalized_loss
    self.encoder_output_size = encoder_output_size

    # Embedding-row indices for the query frames that inference prepends.
    # Rows 0-7 are not listed here; inference uses row 0 for an unknown
    # language and rows 1 and 2 for the event/emotion queries.
    self.lid_dict = {"zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13}
    self.textnorm_dict = {"withtextnorm": 14, "wotextnorm": 15}
    # Width must equal the feature dimension (input_size) for the time-axis
    # concatenation in inference; this was previously hard-coded to 560.
    self.embed = torch.nn.Embedding(
        8 + len(self.lid_dict) + len(self.textnorm_dict), input_size
    )
def forward(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    **kwargs,
):
    """Encode the features and return the CTC training loss.

    Args:
        speech: input features, (Batch, Length, ...).
        speech_lengths: (Batch,) or (Batch, k); only column 0 is used if 2-D.
        text: padded label ids, (Batch, Length).
        text_lengths: (Batch,) or (Batch, k); only column 0 is used if 2-D.

    Returns:
        ``(loss, stats, weight)`` after ``force_gatherable``. ``weight`` is the
        batch size, or the total token count (+1 per utterance) when
        ``length_normalized_loss`` is set.
    """
    # Some loaders deliver lengths with an extra trailing axis.
    if text_lengths.dim() > 1:
        text_lengths = text_lengths[:, 0]
    if speech_lengths.dim() > 1:
        speech_lengths = speech_lengths[:, 0]

    n_utts = speech.shape[0]

    enc, enc_lens = self.encode(speech, speech_lengths)
    loss, _unused_cer = self._calc_ctc_loss(enc, enc_lens, text, text_lengths)

    stats = {"loss": torch.clone(loss.detach())}

    # force_gatherable: to-device and to-tensor if scalar for DataParallel
    weight = int((text_lengths + 1).sum()) if self.length_normalized_loss else n_utts
    loss, stats, weight = force_gatherable((loss, stats, weight), loss.device)
    return loss, stats, weight
def encode(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    **kwargs,
):
    """Apply optional augmentation and normalization, then run the encoder.

    SpecAugment is applied only in training mode. Normalization (e.g.
    global/utterance CMVN) is applied whenever it is configured.

    Args:
        speech: input features, presumably (Batch, Length, Dim) — TODO confirm.
        speech_lengths: per-utterance valid lengths, (Batch,).
        **kwargs: accepted for interface compatibility; ignored.

    Returns:
        A tuple ``(encoder_out, encoder_out_lens)``.
    """
    feats, feat_lens = speech, speech_lengths
    if self.specaug is not None and self.training:
        feats, feat_lens = self.specaug(feats, feat_lens)
    if self.normalize is not None:
        feats, feat_lens = self.normalize(feats, feat_lens)

    out, out_lens = self.encoder(feats, feat_lens)
    return out, out_lens
def _calc_ctc_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
# Calc CTC loss
loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
# Calc CER using CTC
cer_ctc = None
if not self.training and self.error_calculator is not None:
ys_hat = self.ctc.argmax(encoder_out).data
cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
return loss_ctc, cer_ctc
def inference(
    self,
    data_in,
    data_lengths=None,
    key: list = None,
    tokenizer=None,
    frontend=None,
    **kwargs,
):
    """Greedy CTC decoding with prepended query frames.

    Four embedding frames are prepended to the features along the time axis,
    in the order [language, event, emotion, text-norm]. The encoder output is
    then decoded greedily: per-frame argmax, repeats collapsed, and blank/sos/eos
    ids removed.

    Args:
        data_in: raw input accepted by ``load_audio_text_image_video``, or,
            when ``kwargs["data_type"] == "fbank"``, a 2-D or 3-D feature tensor
            (a batch dimension is added to 2-D input).
        data_lengths: feature lengths for the fbank path; derived from the
            tensor shape when None. The caller's tensor is not modified.
        key: utterance ids (may be nested one level); repeated if shorter than
            the batch.
        tokenizer: must provide ``decode(ids)``.
        frontend: feature extractor (required on the raw-input path).
        **kwargs: ``device`` (required). Optional: ``language`` (unknown or
            missing maps to embedding row 0), ``text_norm`` (key of
            ``self.textnorm_dict``, default ``"wotextnorm"``), ``output_dir``.

    Returns:
        ``(results, meta_data)``: a list of ``{"key", "text"}`` dicts and a
        dict of timing info.

    Raises:
        NotImplementedError: if ``batch_size`` > 1.
        KeyError: if ``text_norm`` is not in ``self.textnorm_dict``.
    """
    if kwargs.get("batch_size", 1) > 1:
        raise NotImplementedError("batch decoding is not implemented")

    meta_data = {}
    if (
        isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
    ):  # fbank
        speech, speech_lengths = data_in, data_lengths
        if len(speech.shape) < 3:
            speech = speech[None, :, :]
        # Lengths must be a tensor because they are moved to the device below.
        if speech_lengths is None:
            speech_lengths = torch.full(
                (speech.size(0),), speech.size(1), dtype=torch.int64
            )
        else:
            speech_lengths = torch.as_tensor(speech_lengths).reshape(-1)
    else:
        # extract fbank feats
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(
            data_in,
            fs=frontend.fs,
            audio_fs=kwargs.get("fs", 16000),
            data_type=kwargs.get("data_type", "sound"),
            tokenizer=tokenizer,
        )
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        speech, speech_lengths = extract_fbank(
            audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
        )
        time3 = time.perf_counter()
        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
        meta_data["batch_data_time"] = (
            speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
        )

    speech = speech.to(device=kwargs["device"])
    speech_lengths = speech_lengths.to(device=kwargs["device"])

    n_batch = speech.size(0)
    device = speech.device

    def _query(row_ids):
        # Look up embedding rows and repeat them for every batch item: (B, len(row_ids), D).
        return self.embed(torch.LongTensor([row_ids]).to(device)).repeat(n_batch, 1, 1)

    # Unknown or missing language falls back to embedding row 0.
    language_query = _query([self.lid_dict.get(kwargs.get("language", None), 0)])
    textnorm = kwargs.get("text_norm", "wotextnorm")
    textnorm_query = _query([self.textnorm_dict[textnorm]])
    event_emo_query = _query([1, 2])

    # Time-axis layout: [language, event, emotion, textnorm, features...].
    speech = torch.cat((language_query, event_emo_query, textnorm_query, speech), dim=1)
    # Out-of-place add so the caller's data_lengths tensor is never mutated.
    speech_lengths = speech_lengths + 4

    # Encoder
    encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
    if isinstance(encoder_out, tuple):
        encoder_out = encoder_out[0]

    ctc_logits = self.ctc.log_softmax(encoder_out)

    b = encoder_out.size(0)
    if isinstance(key[0], (list, tuple)):
        key = key[0]
    if len(key) < b:
        key = key * b

    ibest_writer = None
    if kwargs.get("output_dir") is not None:
        if not hasattr(self, "writer"):
            self.writer = DatadirWriter(kwargs.get("output_dir"))
        ibest_writer = self.writer["1best_recog"]

    # Ids dropped from the greedy path (blank plus sos/eos).
    dropped_ids = {self.eos, self.sos, self.blank_id}

    results = []
    for i in range(b):
        frame_ids = ctc_logits[i, : encoder_out_lens[i], :].argmax(dim=-1)
        collapsed = torch.unique_consecutive(frame_ids, dim=-1)
        token_int = [t for t in collapsed.tolist() if t not in dropped_ids]

        text = tokenizer.decode(token_int)
        results.append({"key": key[i], "text": text})

        if ibest_writer is not None:
            # Previously this wrote undefined names (token, text_postprocessed)
            # and raised NameError; write the decoded text as SenseVoice does.
            ibest_writer["text"][key[i]] = text

    return results, meta_data