Merge branch 'dev_gzf_funasr2' into main

This commit is contained in:
zhifu gao 2023-12-11 10:10:40 +08:00 committed by GitHub
commit c0008fd461
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 1581 additions and 27 deletions

View File

@ -0,0 +1,298 @@
import argparse
import logging
import os
from pathlib import Path
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import numpy as np
import torch
import yaml
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.decoder.rnn_decoder import RNNDecoder
from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt
from funasr.models.decoder.transformer_decoder import (
DynamicConvolution2DTransformerDecoder, # noqa: H301
)
from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
from funasr.models.decoder.transformer_decoder import (
LightweightConvolution2DTransformerDecoder, # noqa: H301
)
from funasr.models.decoder.transformer_decoder import (
LightweightConvolutionTransformerDecoder, # noqa: H301
)
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
from funasr.models.e2e_asr import ASRModel
from funasr.models.decoder.rnnt_decoder import RNNTDecoder
from funasr.models.joint_net.joint_network import JointNetwork
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_sa_asr import SAASRModel
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
from funasr.models.e2e_asr_bat import BATModel
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.transformer_encoder import TransformerEncoder
from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
from funasr.models.encoder.resnet34_encoder import ResNet34Diar
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.default import MultiChannelFrontend
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.postencoder.hugging_face_transformers_postencoder import (
HuggingFaceTransformersPostEncoder, # noqa: H301
)
from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3, BATPredictor
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.preencoder.linear import LinearProjection
from funasr.models.preencoder.sinc import LightweightSincConvs
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
from funasr.modules.subsampling import Conv1dSubsampling
from funasr.tasks.abs_task import AbsTask
from funasr.tokenizer.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
from funasr.models.base_model import FunASRModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.get_default_kwargs import get_default_kwargs
from funasr.utils.nested_dict_action import NestedDictAction
from funasr.utils.types import float_or_none
from funasr.utils.types import int_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none
# from funasr.models.paraformer import Paraformer
frontend_choices = ClassChoices(
name="frontend",
classes=dict(
default=DefaultFrontend,
sliding_window=SlidingWindow,
s3prl=S3prlFrontend,
fused=FusedFrontends,
wav_frontend=WavFrontend,
multichannelfrontend=MultiChannelFrontend,
),
type_check=AbsFrontend,
default="default",
)
specaug_choices = ClassChoices(
name="specaug",
classes=dict(
specaug=SpecAug,
specaug_lfr=SpecAugLFR,
),
type_check=AbsSpecAug,
default=None,
optional=True,
)
# specaug_choices = {"specaug":SpecAug}
normalize_choices = ClassChoices(
"normalize",
classes=dict(
global_mvn=GlobalMVN,
utterance_mvn=UtteranceMVN,
),
type_check=AbsNormalize,
default=None,
optional=True,
)
# model_choices = ClassChoices(
# "model",
# classes=dict(
# asr=ASRModel,
# uniasr=UniASR,
# paraformer=Paraformer,
# paraformer_online=ParaformerOnline,
# paraformer_bert=ParaformerBert,
# bicif_paraformer=BiCifParaformer,
# contextual_paraformer=ContextualParaformer,
# neatcontextual_paraformer=NeatContextualParaformer,
# mfcca=MFCCA,
# timestamp_prediction=TimestampPredictor,
# rnnt=TransducerModel,
# rnnt_unified=UnifiedTransducerModel,
# bat=BATModel,
# sa_asr=SAASRModel,
# ),
# type_check=None,
# default="asr",
# )
preencoder_choices = ClassChoices(
name="preencoder",
classes=dict(
sinc=LightweightSincConvs,
linear=LinearProjection,
),
type_check=AbsPreEncoder,
default=None,
optional=True,
)
encoder_choices = ClassChoices(
"encoder",
classes=dict(
conformer=ConformerEncoder,
transformer=TransformerEncoder,
rnn=RNNEncoder,
sanm=SANMEncoder,
sanm_chunk_opt=SANMEncoderChunkOpt,
data2vec_encoder=Data2VecEncoder,
mfcca_enc=MFCCAEncoder,
chunk_conformer=ConformerChunkEncoder,
),
type_check=AbsEncoder,
default="rnn",
)
encoder_choices2 = ClassChoices(
"encoder2",
classes=dict(
conformer=ConformerEncoder,
transformer=TransformerEncoder,
rnn=RNNEncoder,
sanm=SANMEncoder,
sanm_chunk_opt=SANMEncoderChunkOpt,
),
type_check=AbsEncoder,
default="rnn",
)
asr_encoder_choices = ClassChoices(
"asr_encoder",
classes=dict(
conformer=ConformerEncoder,
transformer=TransformerEncoder,
rnn=RNNEncoder,
sanm=SANMEncoder,
sanm_chunk_opt=SANMEncoderChunkOpt,
data2vec_encoder=Data2VecEncoder,
mfcca_enc=MFCCAEncoder,
),
type_check=AbsEncoder,
default="rnn",
)
spk_encoder_choices = ClassChoices(
"spk_encoder",
classes=dict(
resnet34_diar=ResNet34Diar,
),
default="resnet34_diar",
)
postencoder_choices = ClassChoices(
name="postencoder",
classes=dict(
hugging_face_transformers=HuggingFaceTransformersPostEncoder,
),
type_check=AbsPostEncoder,
default=None,
optional=True,
)
decoder_choices = ClassChoices(
"decoder",
classes=dict(
transformer=TransformerDecoder,
lightweight_conv=LightweightConvolutionTransformerDecoder,
lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
dynamic_conv=DynamicConvolutionTransformerDecoder,
dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
rnn=RNNDecoder,
fsmn_scama_opt=FsmnDecoderSCAMAOpt,
paraformer_decoder_sanm=ParaformerSANMDecoder,
paraformer_decoder_san=ParaformerDecoderSAN,
contextual_paraformer_decoder=ContextualParaformerDecoder,
sa_decoder=SAAsrTransformerDecoder,
),
type_check=AbsDecoder,
default="rnn",
)
decoder_choices2 = ClassChoices(
"decoder2",
classes=dict(
transformer=TransformerDecoder,
lightweight_conv=LightweightConvolutionTransformerDecoder,
lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
dynamic_conv=DynamicConvolutionTransformerDecoder,
dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
rnn=RNNDecoder,
fsmn_scama_opt=FsmnDecoderSCAMAOpt,
paraformer_decoder_sanm=ParaformerSANMDecoder,
),
type_check=AbsDecoder,
default="rnn",
)
rnnt_decoder_choices = ClassChoices(
"rnnt_decoder",
classes=dict(
rnnt=RNNTDecoder,
),
type_check=RNNTDecoder,
default="rnnt",
)
joint_network_choices = ClassChoices(
name="joint_network",
classes=dict(
joint_network=JointNetwork,
),
default="joint_network",
optional=True,
)
predictor_choices = ClassChoices(
name="predictor",
classes=dict(
cif_predictor=CifPredictor,
ctc_predictor=None,
cif_predictor_v2=CifPredictorV2,
cif_predictor_v3=CifPredictorV3,
bat_predictor=BATPredictor,
),
type_check=None,
default="cif_predictor",
optional=True,
)
predictor_choices2 = ClassChoices(
name="predictor2",
classes=dict(
cif_predictor=CifPredictor,
ctc_predictor=None,
cif_predictor_v2=CifPredictorV2,
),
type_check=None,
default="cif_predictor",
optional=True,
)
stride_conv_choices = ClassChoices(
name="stride_conv",
classes=dict(
stride_conv1d=Conv1dSubsampling
),
type_check=None,
default="stride_conv1d",
optional=True,
)

View File

View File

@ -0,0 +1,652 @@
import logging
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import torch
import torch.nn as nn
import random
import numpy as np
# from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
LabelSmoothingLoss, # noqa: H301
)
# from funasr.models.ctc import CTC
# from funasr.models.decoder.abs_decoder import AbsDecoder
# from funasr.models.e2e_asr_common import ErrorCalculator
# from funasr.models.encoder.abs_encoder import AbsEncoder
# from funasr.models.frontend.abs_frontend import AbsFrontend
# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.predictor.cif import mae_loss
# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
# from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.modules.add_sos_eos import add_sos_eos
from funasr.modules.nets_utils import make_pad_mask, pad_list
from funasr.modules.nets_utils import th_accuracy
from funasr.torch_utils.device_funcs import force_gatherable
# from funasr.models.base_model import FunASRModel
# from funasr.models.predictor.cif import CifPredictorV3
from funasr.cli.model_class_factory import *
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
else:
# Nothing to do if torch<1.6.0
@contextmanager
def autocast(enabled=True):
yield
class Paraformer(nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2206.08317
"""
def __init__(
self,
# token_list: Union[Tuple[str, ...], List[str]],
frontend: Optional[str] = None,
frontend_conf: Optional[Dict] = None,
specaug: Optional[str] = None,
specaug_conf: Optional[Dict] = None,
normalize: str = None,
normalize_conf: Optional[Dict] = None,
encoder: str = None,
encoder_conf: Optional[Dict] = None,
decoder: str = None,
decoder_conf: Optional[Dict] = None,
ctc: str = None,
ctc_conf: Optional[Dict] = None,
predictor: str = None,
predictor_conf: Optional[Dict] = None,
ctc_weight: float = 0.5,
interctc_weight: float = 0.0,
input_size: int = 80,
vocab_size: int = -1,
ignore_id: int = -1,
blank_id: int = 0,
sos: int = 1,
eos: int = 2,
lsm_weight: float = 0.0,
length_normalized_loss: bool = False,
# report_cer: bool = True,
# report_wer: bool = True,
# sym_space: str = "<space>",
# sym_blank: str = "<blank>",
# extract_feats_in_collect_stats: bool = True,
# predictor=None,
predictor_weight: float = 0.0,
predictor_bias: int = 0,
sampling_ratio: float = 0.2,
share_embedding: bool = False,
# preencoder: Optional[AbsPreEncoder] = None,
# postencoder: Optional[AbsPostEncoder] = None,
use_1st_decoder_loss: bool = False,
**kwargs,
):
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight
super().__init__()
# import pdb;
# pdb.set_trace()
if frontend is not None:
frontend_class = frontend_choices.get_class(frontend)
frontend = frontend_class(**frontend_conf)
if specaug is not None:
specaug_class = specaug_choices.get_class(specaug)
specaug = specaug_class(**specaug_conf)
if normalize is not None:
normalize_class = normalize_choices.get_class(normalize)
normalize = normalize_class(**normalize_conf)
encoder_class = encoder_choices.get_class(encoder)
encoder = encoder_class(input_size=input_size, **encoder_conf)
encoder_output_size = encoder.output_size()
if decoder is not None:
decoder_class = decoder_choices.get_class(decoder)
decoder = decoder_class(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
**decoder_conf,
)
if ctc_weight > 0.0:
if ctc_conf is None:
ctc_conf = {}
ctc = CTC(
odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
)
if predictor is not None:
predictor_class = predictor_choices.get_class(predictor)
predictor = predictor_class(**predictor_conf)
# note that eos is the same as sos (equivalent ID)
self.blank_id = blank_id
self.sos = sos if sos is not None else vocab_size - 1
self.eos = eos if eos is not None else vocab_size - 1
self.vocab_size = vocab_size
self.ignore_id = ignore_id
self.ctc_weight = ctc_weight
self.interctc_weight = interctc_weight
# self.token_list = token_list.copy()
#
self.frontend = frontend
self.specaug = specaug
self.normalize = normalize
# self.preencoder = preencoder
# self.postencoder = postencoder
self.encoder = encoder
#
# if not hasattr(self.encoder, "interctc_use_conditioning"):
# self.encoder.interctc_use_conditioning = False
# if self.encoder.interctc_use_conditioning:
# self.encoder.conditioning_layer = torch.nn.Linear(
# vocab_size, self.encoder.output_size()
# )
#
# self.error_calculator = None
#
if ctc_weight == 1.0:
self.decoder = None
else:
self.decoder = decoder
self.criterion_att = LabelSmoothingLoss(
size=vocab_size,
padding_idx=ignore_id,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
#
# if report_cer or report_wer:
# self.error_calculator = ErrorCalculator(
# token_list, sym_space, sym_blank, report_cer, report_wer
# )
#
if ctc_weight == 0.0:
self.ctc = None
else:
self.ctc = ctc
#
# self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
self.predictor = predictor
self.predictor_weight = predictor_weight
self.predictor_bias = predictor_bias
self.sampling_ratio = sampling_ratio
self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
# self.step_cur = 0
#
self.share_embedding = share_embedding
if self.share_embedding:
self.decoder.embed = None
self.use_1st_decoder_loss = use_1st_decoder_loss
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
decoding_ind: int
"""
decoding_ind = kwargs.get("kwargs", None)
# import pdb;
# pdb.set_trace()
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# # for data-parallel
# text = text[:, : text_lengths.max()]
# speech = speech[:, :speech_lengths.max()]
# 1. Encoder
if hasattr(self.encoder, "overlap_chunk_cls"):
ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
else:
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
loss_att, pre_loss_att, acc_att, cer_att, wer_att = None, None, None, None, None
loss_ctc, cer_ctc = None, None
loss_pre = None
stats = dict()
# 1. CTC branch
if self.ctc_weight != 0.0:
loss_ctc, cer_ctc = self._calc_ctc_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
# Collect CTC branch stats
stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
stats["cer_ctc"] = cer_ctc
# Intermediate CTC (optional)
loss_interctc = 0.0
if self.interctc_weight != 0.0 and intermediate_outs is not None:
for layer_idx, intermediate_out in intermediate_outs:
# we assume intermediate_out has the same length & padding
# as those of encoder_out
loss_ic, cer_ic = self._calc_ctc_loss(
intermediate_out, encoder_out_lens, text, text_lengths
)
loss_interctc = loss_interctc + loss_ic
# Collect Intermedaite CTC stats
stats["loss_interctc_layer{}".format(layer_idx)] = (
loss_ic.detach() if loss_ic is not None else None
)
stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
loss_interctc = loss_interctc / len(intermediate_outs)
# calculate whole encoder loss
loss_ctc = (
1 - self.interctc_weight
) * loss_ctc + self.interctc_weight * loss_interctc
# 2b. Attention decoder branch
if self.ctc_weight != 1.0:
loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
# 3. CTC-Att loss definition
if self.ctc_weight == 0.0:
loss = loss_att + loss_pre * self.predictor_weight
elif self.ctc_weight == 1.0:
loss = loss_ctc
else:
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
if self.use_1st_decoder_loss and pre_loss_att is not None:
loss = loss + (1 - self.ctc_weight) * pre_loss_att
# Collect Attn branch stats
stats["loss_att"] = loss_att.detach() if loss_att is not None else None
stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
stats["acc"] = acc_att
stats["cer"] = cer_att
stats["wer"] = wer_att
stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
stats["loss"] = torch.clone(loss.detach())
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
def collect_feats(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
) -> Dict[str, torch.Tensor]:
if self.extract_feats_in_collect_stats:
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
else:
# Generate dummy stats if extract_feats_in_collect_stats is False
logging.warning(
"Generating dummy stats for feats and feats_lengths, "
"because encoder_conf.extract_feats_in_collect_stats is "
f"{self.extract_feats_in_collect_stats}"
)
feats, feats_lengths = speech, speech_lengths
return {"feats": feats, "feats_lengths": feats_lengths}
def encode(
self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
ind: int
"""
with autocast(False):
# # 1. Extract feats
# feats, feats_lengths = self._extract_feats(speech, speech_lengths)
# 2. Data augmentation
if self.specaug is not None and self.training:
feats, feats_lengths = self.specaug(speech, speech_lengths)
# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
feats, feats_lengths = self.normalize(feats, feats_lengths)
# # Pre-encoder, e.g. used for raw input data
# if self.preencoder is not None:
# feats, feats_lengths = self.preencoder(feats, feats_lengths)
# 4. Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
if self.encoder.interctc_use_conditioning:
if hasattr(self.encoder, "overlap_chunk_cls"):
encoder_out, encoder_out_lens, _ = self.encoder(
feats, feats_lengths, ctc=self.ctc, ind=ind
)
encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
encoder_out_lens,
chunk_outs=None)
else:
encoder_out, encoder_out_lens, _ = self.encoder(
feats, feats_lengths, ctc=self.ctc
)
else:
if hasattr(self.encoder, "overlap_chunk_cls"):
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, ind=ind)
encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
encoder_out_lens,
chunk_outs=None)
else:
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
# # Post-encoder, e.g. NLU
# if self.postencoder is not None:
# encoder_out, encoder_out_lens = self.postencoder(
# encoder_out, encoder_out_lens
# )
assert encoder_out.size(0) == speech.size(0), (
encoder_out.size(),
speech.size(0),
)
assert encoder_out.size(1) <= encoder_out_lens.max(), (
encoder_out.size(),
encoder_out_lens.max(),
)
if intermediate_outs is not None:
return (encoder_out, intermediate_outs), encoder_out_lens
return encoder_out, encoder_out_lens
def calc_predictor(self, encoder_out, encoder_out_lens):
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None, encoder_out_mask,
ignore_id=self.ignore_id)
return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
decoder_outs = self.decoder(
encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
)
decoder_out = decoder_outs[0]
decoder_out = torch.log_softmax(decoder_out, dim=-1)
return decoder_out, ys_pad_lens
def _extract_feats(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
assert speech_lengths.dim() == 1, speech_lengths.shape
# for data-parallel
speech = speech[:, : speech_lengths.max()]
if self.frontend is not None:
# Frontend
# e.g. STFT and Feature extract
# data_loader may send time-domain signal in this case
# speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
feats, feats_lengths = self.frontend(speech, speech_lengths)
else:
# No frontend and no feature extract
feats, feats_lengths = speech, speech_lengths
return feats, feats_lengths
def nll(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
) -> torch.Tensor:
"""Compute negative log likelihood(nll) from transformer-decoder
Normally, this function is called in batchify_nll.
Args:
encoder_out: (Batch, Length, Dim)
encoder_out_lens: (Batch,)
ys_pad: (Batch, Length)
ys_pad_lens: (Batch,)
"""
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_in_lens = ys_pad_lens + 1
# 1. Forward decoder
decoder_out, _ = self.decoder(
encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
) # [batch, seqlen, dim]
batch_size = decoder_out.size(0)
decoder_num_class = decoder_out.size(2)
# nll: negative log-likelihood
nll = torch.nn.functional.cross_entropy(
decoder_out.view(-1, decoder_num_class),
ys_out_pad.view(-1),
ignore_index=self.ignore_id,
reduction="none",
)
nll = nll.view(batch_size, -1)
nll = nll.sum(dim=1)
assert nll.size(0) == batch_size
return nll
def batchify_nll(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
batch_size: int = 100,
):
"""Compute negative log likelihood(nll) from transformer-decoder
To avoid OOM, this fuction seperate the input into batches.
Then call nll for each batch and combine and return results.
Args:
encoder_out: (Batch, Length, Dim)
encoder_out_lens: (Batch,)
ys_pad: (Batch, Length)
ys_pad_lens: (Batch,)
batch_size: int, samples each batch contain when computing nll,
you may change this to avoid OOM or increase
GPU memory usage
"""
total_num = encoder_out.size(0)
if total_num <= batch_size:
nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
else:
nll = []
start_idx = 0
while True:
end_idx = min(start_idx + batch_size, total_num)
batch_encoder_out = encoder_out[start_idx:end_idx, :, :]
batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx]
batch_ys_pad = ys_pad[start_idx:end_idx, :]
batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx]
batch_nll = self.nll(
batch_encoder_out,
batch_encoder_out_lens,
batch_ys_pad,
batch_ys_pad_lens,
)
nll.append(batch_nll)
start_idx = end_idx
if start_idx == total_num:
break
nll = torch.cat(nll)
assert nll.size(0) == total_num
return nll
def _calc_att_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
if self.predictor_bias == 1:
_, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_pad_lens = ys_pad_lens + self.predictor_bias
pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(encoder_out, ys_pad, encoder_out_mask,
ignore_id=self.ignore_id)
# 0. sampler
decoder_out_1st = None
pre_loss_att = None
if self.sampling_ratio > 0.0:
if self.use_1st_decoder_loss:
sematic_embeds, decoder_out_1st, pre_loss_att = self.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
pre_acoustic_embeds)
else:
sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
pre_acoustic_embeds)
else:
if self.step_cur < 2:
logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
sematic_embeds = pre_acoustic_embeds
# 1. Forward decoder
decoder_outs = self.decoder(
encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
)
decoder_out, _ = decoder_outs[0], decoder_outs[1]
if decoder_out_1st is None:
decoder_out_1st = decoder_out
# 2. Compute attention loss
loss_att = self.criterion_att(decoder_out, ys_pad)
acc_att = th_accuracy(
decoder_out_1st.view(-1, self.vocab_size),
ys_pad,
ignore_label=self.ignore_id,
)
loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
# Compute cer/wer using attention-decoder
if self.training or self.error_calculator is None:
cer_att, wer_att = None, None
else:
ys_hat = decoder_out_1st.argmax(dim=-1)
cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
if self.share_embedding:
ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
else:
ys_pad_embed = self.decoder.embed(ys_pad_masked)
with torch.no_grad():
decoder_outs = self.decoder(
encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
)
decoder_out, _ = decoder_outs[0], decoder_outs[1]
pred_tokens = decoder_out.argmax(-1)
nonpad_positions = ys_pad.ne(self.ignore_id)
seq_lens = (nonpad_positions).sum(1)
same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
input_mask = torch.ones_like(nonpad_positions)
bsz, seq_len = ys_pad.size()
for li in range(bsz):
target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
if target_num > 0:
input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device), value=0)
input_mask = input_mask.eq(1)
input_mask = input_mask.masked_fill(~nonpad_positions, False)
input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
input_mask_expand_dim, 0)
return sematic_embeds * tgt_mask, decoder_out * tgt_mask
def sampler_with_grad(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
if self.share_embedding:
ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
else:
ys_pad_embed = self.decoder.embed(ys_pad_masked)
decoder_outs = self.decoder(
encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
)
pre_loss_att = self.criterion_att(decoder_outs[0], ys_pad)
decoder_out, _ = decoder_outs[0], decoder_outs[1]
pred_tokens = decoder_out.argmax(-1)
nonpad_positions = ys_pad.ne(self.ignore_id)
seq_lens = (nonpad_positions).sum(1)
same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
input_mask = torch.ones_like(nonpad_positions)
bsz, seq_len = ys_pad.size()
for li in range(bsz):
target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
if target_num > 0:
input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device), value=0)
input_mask = input_mask.eq(1)
input_mask = input_mask.masked_fill(~nonpad_positions, False)
input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
input_mask_expand_dim, 0)
return sematic_embeds * tgt_mask, decoder_out * tgt_mask, pre_loss_att
def _calc_ctc_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
# Calc CTC loss
loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
# Calc CER using CTC
cer_ctc = None
if not self.training and self.error_calculator is not None:
ys_hat = self.ctc.argmax(encoder_out).data
cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
return loss_ctc, cer_ctc

163
funasr/cli/train_cli.py Normal file
View File

@ -0,0 +1,163 @@
import argparse
import logging
import os
import sys
from io import BytesIO
from collections.abc import Sequence
import torch
import hydra
from omegaconf import DictConfig, OmegaConf
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
# from funasr.model_class_factory1 import model_choices
from funasr.modules.lora.utils import mark_only_lora_as_trainable
from funasr.optimizers import optim_choices
from funasr.schedulers import scheduler_choices
from funasr.torch_utils.load_pretrained_model import load_pretrained_model
from funasr.torch_utils.initialize import initialize
from funasr.datasets.data_sampler import BatchSampler
# from funasr.tokenizer.build_tokenizer import build_tokenizer
# from funasr.tokenizer.token_id_converter import TokenIDConverter
from funasr.tokenizer.funtoken import build_tokenizer
from funasr.datasets.dataset_jsonl import AudioDataset
from funasr.cli.trainer import Trainer
# from funasr.utils.load_fr_py import load_class_from_path
from funasr.utils.dynamic_import import dynamic_import
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
def preprocess_config(cfg: DictConfig):
for key, value in cfg.items():
if value == 'None':
cfg[key] = None
@hydra.main()
def main(kwargs: DictConfig):
# preprocess_config(kwargs)
# import pdb; pdb.set_trace()
# set random seed
set_all_random_seed(kwargs.get("seed", 0))
torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
local_rank = int(os.environ.get('LOCAL_RANK', 0))
# Check if we are using DDP or FSDP
use_ddp = 'WORLD_SIZE' in os.environ and int(os.environ["WORLD_SIZE"]) > 1
use_fsdp = kwargs.get("use_fsdp", None)
if use_ddp or use_fsdp:
dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
torch.cuda.set_device(local_rank)
# build_tokenizer
tokenizer = build_tokenizer(
token_type=kwargs.get("token_type", "char"),
bpemodel=kwargs.get("bpemodel", None),
delimiter=kwargs.get("delimiter", None),
space_symbol=kwargs.get("space_symbol", "<space>"),
non_linguistic_symbols=kwargs.get("non_linguistic_symbols", None),
g2p_type=kwargs.get("g2p_type", None),
token_list=kwargs.get("token_list", None),
unk_symbol=kwargs.get("unk_symbol", "<unk>"),
)
# import pdb;
# pdb.set_trace()
# build model
# model_class = model_choices.get_class(kwargs.get("model", "asr"))
# model_class = load_class_from_path(kwargs.get("model").split(":"))
model_class = dynamic_import(kwargs.get("model"))
model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
frontend = model.frontend
# init_param
init_param = kwargs.get("init_param", None)
if init_param is not None:
init_param = eval(init_param)
if isinstance(init_param, Sequence):
init_param = (init_param,)
logging.info("init_param is not None: ", init_param)
for p in init_param:
logging.info(f"Loading pretrained params from {p}")
load_pretrained_model(
model=model,
init_param=p,
ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
oss_bucket=kwargs.get("oss_bucket", None),
)
else:
initialize(model, kwargs.get("init", "kaiming_normal"))
# import pdb;
# pdb.set_trace()
# freeze_param
freeze_param = kwargs.get("freeze_param", None)
if freeze_param is not None:
freeze_param = eval(freeze_param)
if isinstance(freeze_param, Sequence):
freeze_param = (freeze_param,)
logging.info("freeze_param is not None: ", freeze_param)
for t in freeze_param:
for k, p in model.named_parameters():
if k.startswith(t + ".") or k == t:
logging.info(f"Setting {k}.requires_grad = False")
p.requires_grad = False
if use_ddp:
model = model.cuda(local_rank)
model = DDP(model, device_ids=[local_rank],
find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
elif use_fsdp:
model = FSDP(model).cuda(local_rank)
else:
model = model.to(device=kwargs.get("device", "cuda"))
# optim
optim = kwargs.get("optim", "adam")
assert optim in optim_choices
optim_class = optim_choices.get(optim)
optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
# scheduler
scheduler = kwargs.get("scheduler", "warmuplr")
assert scheduler in scheduler_choices
scheduler_class = scheduler_choices.get(scheduler)
scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
# dataset
dataset_tr = AudioDataset(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))
# dataloader
batch_sampler = BatchSampler(dataset_tr, **kwargs.get("dataset_conf"), **kwargs.get("dataset_conf").get("batch_conf"))
dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
collate_fn=dataset_tr.collator,
batch_sampler=batch_sampler,
num_workers=kwargs.get("num_workers", 0),
pin_memory=True)
trainer = Trainer(
model=model,
optim=optim,
scheduler=scheduler,
dataloader_train=dataloader_tr,
dataloader_val=None,
local_rank=local_rank,
use_ddp=use_ddp,
use_fsdp=use_fsdp,
**kwargs.get("train_conf"),
)
trainer.run()
if use_ddp or use_fsdp:
torch.distributed.destroy_process_group()
if __name__ == "__main__":
main()

199
funasr/cli/trainer.py Normal file
View File

@ -0,0 +1,199 @@
import torch
import os
from funasr.torch_utils.device_funcs import to_device
import logging
from tqdm import tqdm
from contextlib import nullcontext
import torch.distributed as dist
from funasr.torch_utils.recursive_op import recursive_average
class Trainer:
"""
A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
and optionally resuming from a saved checkpoint.
Attributes:
max_epoch (int): Maximum number of epochs for training.
model (torch.nn.Module): The model to be trained.
optim (torch.optim.Optimizer): The optimizer to use for training.
scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
dataloader_train (torch.utils.data.DataLoader): DataLoader for the training dataset.
dataloader_val (torch.utils.data.DataLoader): DataLoader for the validation dataset.
output_dir (str): Directory where model checkpoints will be saved.
resume (str, optional): Path to a checkpoint to resume training from.
"""
def __init__(self, model,
optim,
scheduler,
dataloader_train,
dataloader_val,
local_rank,
use_ddp=False,
use_fsdp=False,
**kwargs):
"""
Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings.
Args:
model (torch.nn.Module): The model to be trained.
optim (torch.optim.Optimizer): The optimizer to use for training.
scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
dataloader_train (torch.utils.data.DataLoader): The DataLoader for the training dataset.
dataloader_val (torch.utils.data.DataLoader): The DataLoader for the validation dataset.
**kwargs: Additional keyword arguments:
max_epoch (int): The maximum number of epochs for training.
output_dir (str): The directory where model checkpoints will be saved. Default is './'.
resume (str, optional): The file path to a checkpoint to resume training from.
"""
self.model = model
self.optim = optim
self.scheduler = scheduler
self.dataloader_train = dataloader_train
self.dataloader_val = dataloader_val
self.output_dir = kwargs.get('output_dir', './')
self.resume = kwargs.get('resume', None)
self.start_epoch = 1
self.max_epoch = kwargs.get('max_epoch', 100)
self.local_rank = local_rank
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.use_ddp = use_ddp
self.use_fsdp = use_fsdp
self.device = torch.device("cuda", local_rank)
self.kwargs = kwargs
if self.resume:
self._resume_checkpoint(self.resume)
def _save_checkpoint(self, epoch):
"""
Saves a checkpoint containing the model's state, the optimizer's state,
and the scheduler's state at the end of the given epoch. This method is
intended to be called at the end of each epoch to save the training progress.
Args:
epoch (int): The epoch number at which the checkpoint is being saved.
"""
state = {
'epoch': epoch,
'state_dict': self.model.state_dict(),
'optimizer': self.optim.state_dict(),
'scheduler': self.scheduler.state_dict(),
}
# Create output directory if it does not exist
os.makedirs(self.output_dir, exist_ok=True)
filename = os.path.join(self.output_dir, f'model.e{epoch}.pb')
torch.save(state, filename)
print(f'Checkpoint saved to {filename}')
def _resume_checkpoint(self, resume_path):
"""
Resumes training from a checkpoint at the given file path.
Loads the model's state, the optimizer's state, and the scheduler's state.
Args:
resume_path (str): The file path to the checkpoint to resume from.
"""
if os.path.isfile(resume_path):
checkpoint = torch.load(resume_path)
self.start_epoch = checkpoint['epoch'] + 1
self.model.load_state_dict(checkpoint['state_dict'])
self.optim.load_state_dict(checkpoint['optimizer'])
self.scheduler.load_state_dict(checkpoint['scheduler'])
print(f"Checkpoint loaded successfully from '{resume_path}' at (epoch {checkpoint['epoch']})")
else:
print(f"No checkpoint found at '{resume_path}', starting from scratch")
def run(self):
"""
Starts the training process, iterating over epochs, training the model,
and saving checkpoints at the end of each epoch.
"""
for epoch in range(self.start_epoch, self.max_epoch + 1):
self._train_epoch(epoch)
# self._validate_epoch(epoch)
if dist.get_rank() == 0:
self._save_checkpoint(epoch)
self.scheduler.step()
def _train_epoch(self, epoch):
"""
Defines the training process for a single epoch with gradient accumulation.
Args:
epoch (int): The current epoch number.
"""
self.model.train()
pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch + 1}", total=len(self.dataloader_train),
dynamic_ncols=True)
# Set the number of steps for gradient accumulation
accum_grad = self.kwargs.get("accum_grad", 1)
# Initialize the gradient accumulation
self.optim.zero_grad()
for batch_idx, batch in enumerate(self.dataloader_train):
batch = to_device(batch, self.device)
my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
with my_context():
retval = self.model(**batch)
loss, stats, weight = retval
stats = {k: v for k, v in stats.items() if v is not None}
if self.use_ddp or self.use_fsdp:
# Apply weighted averaging for loss and stats
loss = (loss * weight.type(loss.dtype)).sum()
# if distributed, this method can also apply all_reduce()
stats, weight = recursive_average(stats, weight, distributed=True)
# Now weight is summation over all workers
loss /= weight
# Multiply world_size because DistributedDataParallel
# automatically normalizes the gradient by world_size.
loss *= self.world_size
# Scale the loss since we're not updating for every mini-batch
loss = loss / accum_grad
loss.backward()
# Perform an optimizer step only after accumulating enough gradients
if (batch_idx + 1) % accum_grad == 0 or (batch_idx + 1) == len(self.dataloader_train):
# Perform gradient clipping if it is set
if self.kwargs.get("grad_clip", None) is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
max_norm=self.kwargs.get("grad_clip", 10.0),
norm_type=self.kwargs.get("grad_clip_type", 2.0),
)
if not torch.isfinite(grad_norm):
logging.warning(
f"The grad norm is {grad_norm}. Skipping updating the model."
)
self.optim.zero_grad() # Reset gradients
continue
# Execute an optimization step (update model parameters)
self.optim.step()
self.scheduler.step()
# Clear gradients for the next accumulation stage
self.optim.zero_grad()
pbar.update(1)
if self.local_rank == 0:
pbar.set_description(
f"Training Epoch: {epoch + 1}/{self.max_epoch}, step {batch_idx}/{len(self.dataloader_train)} (loss: {loss.detach().float():.3f}, {[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]})")
pbar.close()
def _validate_epoch(self, epoch):
"""
Defines the validation process for a single epoch.
Should be implemented with the actual model validation steps.
Args:
epoch (int): The current epoch number.
"""
self.model.eval()
with torch.no_grad():
for data, target in self.dataloader_val:
# Implement the model validation steps here
pass

View File

@ -4,17 +4,17 @@ import numpy as np
class BatchSampler(torch.utils.data.BatchSampler):
def __init__(self, dataset, batch_size_type: str="example", batch_size: int=100, sort_size: int=30, drop_last: bool=False, shuffle: bool=True, **kwargs):
def __init__(self, dataset, batch_type: str="example", batch_size: int=100, sort_size: int=30, drop_last: bool=False, shuffle: bool=True, **kwargs):
self.drop_last = drop_last
self.pre_idx = -1
self.dataset = dataset
self.total_samples = len(dataset)
# self.batch_size_type = args.batch_size_type
# self.batch_type = args.batch_type
# self.batch_size = args.batch_size
# self.sort_size = args.sort_size
# self.max_length_token = args.max_length_token
self.batch_size_type = batch_size_type
self.batch_type = batch_type
self.batch_size = batch_size
self.sort_size = sort_size
self.max_length_token = kwargs.get("max_length_token", 5000)
@ -26,7 +26,7 @@ class BatchSampler(torch.utils.data.BatchSampler):
return self.total_samples
def __iter__(self):
print("in sampler")
# print("in sampler")
if self.shuffle:
np.random.shuffle(self.shuffle_idx)
@ -36,7 +36,7 @@ class BatchSampler(torch.utils.data.BatchSampler):
num_sample = 0
iter_num = (self.total_samples-1) // self.sort_size + 1
print("iter_num: ", iter_num)
# print("iter_num: ", iter_num)
for iter in range(self.pre_idx + 1, iter_num):
datalen_with_index = []
for i in range(self.sort_size):
@ -46,8 +46,8 @@ class BatchSampler(torch.utils.data.BatchSampler):
idx_map = self.shuffle_idx[idx]
# prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
sample_len_cur = self.dataset.indexed_dataset[idx_map]["source_len"] + \
self.dataset.indexed_dataset[idx_map]["target_len"]
sample_len_cur = self.dataset.indexed_dataset.get_source_len(self.dataset.indexed_dataset[idx_map]) + \
self.dataset.indexed_dataset.get_target_len(self.dataset.indexed_dataset[idx_map])
datalen_with_index.append([idx, sample_len_cur])
@ -59,7 +59,7 @@ class BatchSampler(torch.utils.data.BatchSampler):
max_token_cur = max(max_token, sample_len_cur_raw)
max_token_padding = 1 + num_sample
if self.batch_size_type == 'token':
if self.batch_type == 'token':
max_token_padding *= max_token_cur
if max_token_padding <= self.batch_size:
batch.append(idx)

View File

@ -38,16 +38,13 @@ dataset = AudioDataset(jsonl, frontend=frontend, tokenizer=tokenizer, token_id_c
batch_sampler = BatchSampler(dataset)
def collator(samples: list = None):
return samples
if __name__ == "__main__":
dataloader_tr = torch.utils.data.DataLoader(dataset,
collate_fn=dataset.collator,
batch_sampler=batch_sampler,
shuffle=False,
num_workers=8,
num_workers=0,
pin_memory=True)
print(len(dataset))

View File

@ -78,21 +78,26 @@ class IndexedDatasetJsonl(torch.utils.data.Dataset):
def __getitem__(self, index):
return self.contents[index]
def get_source_len(self, data_dict):
return data_dict["source_len"]
def get_target_len(self, data_dict):
return data_dict["target_len"] if "target_len" in data_dict else 0
class AudioDataset(torch.utils.data.Dataset):
def __init__(self, path, frontend=None, tokenizer=None, token_id_converter=None):
def __init__(self, path, frontend=None, tokenizer=None, int_pad_value: int = -1, float_pad_value: float = 0.0, **kwargs):
super().__init__()
self.indexed_dataset = IndexedDatasetJsonl(path)
self.frontend = frontend.forward
self.fs = 16000 if frontend is None else frontend.fs
self.data_type = "sound"
self.tokenizer = tokenizer
self.token_id_converter = token_id_converter
self.int_pad_value = -1
self.float_pad_value = 0.0
self.int_pad_value = int_pad_value
self.float_pad_value = float_pad_value
@ -108,8 +113,7 @@ class AudioDataset(torch.utils.data.Dataset):
data_src = load_audio(source, fs=self.fs)
speech, speech_lengths = extract_features(data_src, self.data_type, self.frontend)
target = item["target"]
text = self.tokenizer.text2tokens(target)
ids = self.token_id_converter.tokens2ids(text)
ids = self.tokenizer.encode(target)
ids_lengths = len(ids)
text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32)

View File

@ -361,6 +361,7 @@ class CommonPreprocessor(AbsPreprocessor):
tokens = seg_tokenize(tokens, self.seg_dict)
else:
tokens = self.tokenizer.text2tokens(text)
text_ints = self.token_id_converter.tokens2ids(tokens)
data[self.text_name] = np.array(text_ints, dtype=np.int64)
return data

View File

@ -223,6 +223,7 @@ class ASRModel(FunASRModel):
# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss:
batch_size = int((text_lengths + 1).sum())
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight

View File

@ -234,6 +234,7 @@ class NeatContextualParaformer(Paraformer):
# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss:
batch_size = int((text_lengths + self.predictor_bias).sum())
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight

View File

@ -256,6 +256,7 @@ class Paraformer(FunASRModel):
# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss:
batch_size = int((text_lengths + self.predictor_bias).sum())
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
@ -868,6 +869,7 @@ class ParaformerOnline(Paraformer):
# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss:
batch_size = int((text_lengths + self.predictor_bias).sum())
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
@ -1495,6 +1497,7 @@ class ParaformerBert(Paraformer):
# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss:
batch_size = int((text_lengths + self.predictor_bias).sum())
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
@ -1766,6 +1769,7 @@ class BiCifParaformer(Paraformer):
# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss:
batch_size = int((text_lengths + self.predictor_bias).sum())
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
@ -1968,6 +1972,7 @@ class ContextualParaformer(Paraformer):
# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss:
batch_size = int((text_lengths + self.predictor_bias).sum())
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
@ -2262,4 +2267,4 @@ class ContextualParaformer(Paraformer):
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
var_dict_tf[name_tf].shape))
return var_dict_torch_update
return var_dict_torch_update

View File

@ -443,6 +443,7 @@ class UniASR(FunASRModel):
# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss:
batch_size = int((text_lengths + 1).sum())
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight

View File

@ -347,7 +347,7 @@ def th_accuracy(pad_outputs, pad_targets, ignore_label):
Args:
pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
pad_targets (LongTensor): Target label tensors (B, Lmax, D).
pad_targets (LongTensor): Target label tensors (B, Lmax).
ignore_label (int): Ignore label id.
Returns:

View File

@ -0,0 +1,17 @@
import torch
from funasr.optimizers.fairseq_adam import FairseqAdam
from funasr.optimizers.sgd import SGD
optim_choices = dict(
adam=torch.optim.Adam,
fairseq_adam=FairseqAdam,
adamw=torch.optim.AdamW,
sgd=SGD,
adadelta=torch.optim.Adadelta,
adagrad=torch.optim.Adagrad,
adamax=torch.optim.Adamax,
asgd=torch.optim.ASGD,
lbfgs=torch.optim.LBFGS,
rmsprop=torch.optim.RMSprop,
rprop=torch.optim.Rprop,
)

View File

@ -0,0 +1,23 @@
import torch
import torch.multiprocessing
import torch.nn
import torch.optim
from funasr.schedulers.noam_lr import NoamLR
from funasr.schedulers.tri_stage_scheduler import TriStageLR
from funasr.schedulers.warmup_lr import WarmupLR
scheduler_choices = dict(
ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
lambdalr=torch.optim.lr_scheduler.LambdaLR,
steplr=torch.optim.lr_scheduler.StepLR,
multisteplr=torch.optim.lr_scheduler.MultiStepLR,
exponentiallr=torch.optim.lr_scheduler.ExponentialLR,
CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
noamlr=NoamLR,
warmuplr=WarmupLR,
tri_stage=TriStageLR,
cycliclr=torch.optim.lr_scheduler.CyclicLR,
onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
)

View File

@ -2,7 +2,13 @@ from abc import ABC
from abc import abstractmethod
from typing import Iterable
from typing import List
from pathlib import Path
from typing import Dict
from typing import Iterable
from typing import List
from typing import Union
import numpy as np
class AbsTokenizer(ABC):
@abstractmethod
@ -12,3 +18,71 @@ class AbsTokenizer(ABC):
@abstractmethod
def tokens2text(self, tokens: Iterable[str]) -> str:
raise NotImplementedError
class BaseTokenizer(ABC):
def __init__(self, token_list: Union[Path, str, Iterable[str]]=None,
unk_symbol: str = "<unk>",
**kwargs,
):
if token_list is not None:
if isinstance(token_list, (Path, str)):
token_list = Path(token_list)
self.token_list_repr = str(token_list)
self.token_list: List[str] = []
with token_list.open("r", encoding="utf-8") as f:
for idx, line in enumerate(f):
line = line.rstrip()
self.token_list.append(line)
else:
self.token_list: List[str] = list(token_list)
self.token_list_repr = ""
for i, t in enumerate(self.token_list):
if i == 3:
break
self.token_list_repr += f"{t}, "
self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
self.token2id: Dict[str, int] = {}
for i, t in enumerate(self.token_list):
if t in self.token2id:
raise RuntimeError(f'Symbol "{t}" is duplicated')
self.token2id[t] = i
self.unk_symbol = unk_symbol
if self.unk_symbol not in self.token2id:
raise RuntimeError(
f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
)
self.unk_id = self.token2id[self.unk_symbol]
def encode(self, text):
tokens = self.text2tokens(text)
text_ints = self.tokens2ids(tokens)
return text_ints
def decode(self, text_ints):
return self.ids2tokens(text_ints)
def get_num_vocabulary_size(self) -> int:
return len(self.token_list)
def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
if isinstance(integers, np.ndarray) and integers.ndim != 1:
raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
return [self.token_list[i] for i in integers]
def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
return [self.token2id.get(i, self.unk_id) for i in tokens]
@abstractmethod
def text2tokens(self, line: str) -> List[str]:
raise NotImplementedError
@abstractmethod
def tokens2text(self, tokens: Iterable[str]) -> str:
raise NotImplementedError

View File

@ -1,7 +1,17 @@
from pathlib import Path
from typing import Iterable
from typing import Union
from abc import ABC
from abc import abstractmethod
from typing import Iterable
from typing import List
from pathlib import Path
from typing import Dict
from typing import Iterable
from typing import List
from typing import Union
import numpy as np
from funasr.tokenizer.abs_tokenizer import AbsTokenizer
from funasr.tokenizer.char_tokenizer import CharTokenizer
@ -18,7 +28,8 @@ def build_tokenizer(
space_symbol: str = "<space>",
delimiter: str = None,
g2p_type: str = None,
) -> AbsTokenizer:
**kwargs,
):
"""A helper function to instantiate Tokenizer"""
if token_type == "bpe":
if bpemodel is None:
@ -28,7 +39,7 @@ def build_tokenizer(
raise RuntimeError(
"remove_non_linguistic_symbols is not implemented for token_type=bpe"
)
return SentencepiecesTokenizer(bpemodel)
return SentencepiecesTokenizer(bpemodel, **kwargs)
elif token_type == "word":
if remove_non_linguistic_symbols and non_linguistic_symbols is not None:
@ -38,13 +49,14 @@ def build_tokenizer(
remove_non_linguistic_symbols=True,
)
else:
return WordTokenizer(delimiter=delimiter)
return WordTokenizer(delimiter=delimiter, **kwargs)
elif token_type == "char":
return CharTokenizer(
non_linguistic_symbols=non_linguistic_symbols,
space_symbol=space_symbol,
remove_non_linguistic_symbols=remove_non_linguistic_symbols,
**kwargs
)
elif token_type == "phn":
@ -53,6 +65,7 @@ def build_tokenizer(
non_linguistic_symbols=non_linguistic_symbols,
space_symbol=space_symbol,
remove_non_linguistic_symbols=remove_non_linguistic_symbols,
**kwargs
)
else:

View File

@ -6,15 +6,17 @@ import warnings
from funasr.tokenizer.abs_tokenizer import AbsTokenizer
from funasr.tokenizer.abs_tokenizer import BaseTokenizer
class CharTokenizer(AbsTokenizer):
class CharTokenizer(BaseTokenizer):
def __init__(
self,
non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
space_symbol: str = "<space>",
remove_non_linguistic_symbols: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.space_symbol = space_symbol
if non_linguistic_symbols is None:
self.non_linguistic_symbols = set()

View File

@ -0,0 +1,75 @@
from pathlib import Path
from typing import Iterable
from typing import Union
from abc import ABC
from abc import abstractmethod
from typing import Iterable
from typing import List
from pathlib import Path
from typing import Dict
from typing import Iterable
from typing import List
from typing import Union
import numpy as np
from funasr.tokenizer.abs_tokenizer import AbsTokenizer
from funasr.tokenizer.char_tokenizer import CharTokenizer
from funasr.tokenizer.phoneme_tokenizer import PhonemeTokenizer
from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer
from funasr.tokenizer.word_tokenizer import WordTokenizer
def build_tokenizer(
token_type: str,
bpemodel: Union[Path, str, Iterable[str]] = None,
non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
remove_non_linguistic_symbols: bool = False,
space_symbol: str = "<space>",
delimiter: str = None,
g2p_type: str = None,
**kwargs,
):
"""A helper function to instantiate Tokenizer"""
# import pdb;
# pdb.set_trace()
if token_type == "bpe":
if bpemodel is None:
raise ValueError('bpemodel is required if token_type = "bpe"')
if remove_non_linguistic_symbols:
raise RuntimeError(
"remove_non_linguistic_symbols is not implemented for token_type=bpe"
)
return SentencepiecesTokenizer(bpemodel, **kwargs)
elif token_type == "word":
if remove_non_linguistic_symbols and non_linguistic_symbols is not None:
return WordTokenizer(
delimiter=delimiter,
non_linguistic_symbols=non_linguistic_symbols,
remove_non_linguistic_symbols=True,
)
else:
return WordTokenizer(delimiter=delimiter, **kwargs)
elif token_type == "char":
return CharTokenizer(
non_linguistic_symbols=non_linguistic_symbols,
space_symbol=space_symbol,
remove_non_linguistic_symbols=remove_non_linguistic_symbols,
**kwargs
)
elif token_type == "phn":
return PhonemeTokenizer(
g2p_type=g2p_type,
non_linguistic_symbols=non_linguistic_symbols,
space_symbol=space_symbol,
remove_non_linguistic_symbols=remove_non_linguistic_symbols,
**kwargs
)
else:
raise ValueError(
f"token_mode must be one of bpe, word, char or phn: " f"{token_type}"
)

View File

@ -363,6 +363,7 @@ class PhonemeTokenizer(AbsTokenizer):
non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
space_symbol: str = "<space>",
remove_non_linguistic_symbols: bool = False,
**kwargs,
):
if g2p_type is None:
self.g2p = split_by_space

View File

@ -9,7 +9,7 @@ from funasr.tokenizer.abs_tokenizer import AbsTokenizer
class SentencepiecesTokenizer(AbsTokenizer):
def __init__(self, model: Union[Path, str]):
def __init__(self, model: Union[Path, str], **kwargs):
self.model = str(model)
# NOTE(kamo):
# Don't build SentencePieceProcessor in __init__()

View File

@ -14,6 +14,7 @@ class WordTokenizer(AbsTokenizer):
delimiter: str = None,
non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
remove_non_linguistic_symbols: bool = False,
**kwargs,
):
self.delimiter = delimiter

View File

@ -0,0 +1,13 @@
import importlib
def dynamic_import(import_path):
"""dynamic import module and class
:param str import_path: syntax 'module_name:class_name'
:return: imported class
"""
module_name, objname = import_path.split(":")
m = importlib.import_module(module_name)
return getattr(m, objname)

View File

@ -0,0 +1,13 @@
import importlib.util
import sys
def load_class_from_path(model_path):
path, class_name = model_path
# import pdb;
# pdb.set_trace()
spec = importlib.util.spec_from_file_location("module.name", path)
module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module)
return getattr(module, class_name)