From 836d57bb6c08c76dada384d93ca0ee3cc5374f48 Mon Sep 17 00:00:00 2001
From: "shixian.shi"
Date: Wed, 20 Dec 2023 17:03:23 +0800
Subject: [PATCH] update seaco paraformer

---
 funasr/models/paraformer/decoder.py          |  58 ++-
 funasr/models/sanm/attention.py              |  10 +-
 funasr/models/seaco_paraformer/__init__.py   |   0
 funasr/models/seaco_paraformer/model.py      | 512 +++++++++++++++++++
 funasr/models/seaco_paraformer/template.yaml | 151 ++++++
 5 files changed, 705 insertions(+), 26 deletions(-)
 create mode 100644 funasr/models/seaco_paraformer/__init__.py
 create mode 100644 funasr/models/seaco_paraformer/model.py
 create mode 100644 funasr/models/seaco_paraformer/template.yaml

diff --git a/funasr/models/paraformer/decoder.py b/funasr/models/paraformer/decoder.py
index 3fe9d194c..f59ce4db8 100644
--- a/funasr/models/paraformer/decoder.py
+++ b/funasr/models/paraformer/decoder.py
@@ -68,6 +68,8 @@ class DecoderLayerSANM(nn.Module):
         if self.concat_after:
             self.concat_linear1 = nn.Linear(size + size, size)
             self.concat_linear2 = nn.Linear(size + size, size)
+        self.reserve_attn = False
+        self.attn_mat = []

     def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
         """Compute decoded features.
@@ -104,8 +106,13 @@ class DecoderLayerSANM(nn.Module):
         residual = x
         if self.normalize_before:
             x = self.norm3(x)
-
-        x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
+        if self.reserve_attn:
+            x_src_attn, attn_mat = self.src_attn(x, memory, memory_mask, ret_attn=True)
+            self.attn_mat.append(attn_mat)
+        else:
+            x_src_attn = self.src_attn(x, memory, memory_mask, ret_attn=False)
+        x = residual + self.dropout(x_src_attn)
+        # x = residual + self.dropout(self.src_attn(x, memory, memory_mask))

         return x, tgt_mask, memory, memory_mask, cache

@@ -213,6 +220,7 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
         src_attention_dropout_rate: float = 0.0,
         input_layer: str = "embed",
         use_output_layer: bool = True,
+        wo_input_layer: bool = False,
         pos_enc_class=PositionalEncoding,
         normalize_before: bool = True,
         concat_after: bool = False,
@@ -239,22 +247,24 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
         )
         attention_dim = encoder_output_size
-
-        if input_layer == "embed":
-            self.embed = torch.nn.Sequential(
-                torch.nn.Embedding(vocab_size, attention_dim),
-                # pos_enc_class(attention_dim, positional_dropout_rate),
-            )
-        elif input_layer == "linear":
-            self.embed = torch.nn.Sequential(
-                torch.nn.Linear(vocab_size, attention_dim),
-                torch.nn.LayerNorm(attention_dim),
-                torch.nn.Dropout(dropout_rate),
-                torch.nn.ReLU(),
-                pos_enc_class(attention_dim, positional_dropout_rate),
-            )
+        if wo_input_layer:
+            self.embed = None
         else:
-            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
+            if input_layer == "embed":
+                self.embed = torch.nn.Sequential(
+                    torch.nn.Embedding(vocab_size, attention_dim),
+                    # pos_enc_class(attention_dim, positional_dropout_rate),
+                )
+            elif input_layer == "linear":
+                self.embed = torch.nn.Sequential(
+                    torch.nn.Linear(vocab_size, attention_dim),
+                    torch.nn.LayerNorm(attention_dim),
+                    torch.nn.Dropout(dropout_rate),
+                    torch.nn.ReLU(),
+                    pos_enc_class(attention_dim, positional_dropout_rate),
+                )
+            else:
+                raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")

         self.normalize_before = normalize_before
         if self.normalize_before:
@@ -324,6 +334,8 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
         hlens: torch.Tensor,
         ys_in_pad: torch.Tensor,
         ys_in_lens: torch.Tensor,
+        return_hidden: bool = False,
+        return_both: bool = False,
         chunk_mask: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Forward decoder.
@@ -365,12 +377,18 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
                 x, tgt_mask, memory, memory_mask
             )
         if self.normalize_before:
-            x = self.after_norm(x)
-        if self.output_layer is not None:
-            x = self.output_layer(x)
+            hidden = self.after_norm(x)
+        else:
+            hidden = x  # keep hidden defined when normalize_before is False

         olens = tgt_mask.sum(1)
-        return x, olens
+        if self.output_layer is not None and return_hidden is False:
+            x = self.output_layer(hidden)
+            return x, olens
+        if return_both:
+            x = self.output_layer(hidden)
+            return x, hidden, olens
+        return hidden, olens

     def score(self, ys, state, x):
         """Score."""

diff --git a/funasr/models/sanm/attention.py b/funasr/models/sanm/attention.py
index f48617c93..10f0a3b23 100644
--- a/funasr/models/sanm/attention.py
+++ b/funasr/models/sanm/attention.py
@@ -449,7 +449,7 @@ class MultiHeadedAttentionCrossAtt(nn.Module):

         return q_h, k_h, v_h

-    def forward_attention(self, value, scores, mask):
+    def forward_attention(self, value, scores, mask, ret_attn=False):
         """Compute attention context vector.

         Args:
@@ -476,16 +476,16 @@ class MultiHeadedAttentionCrossAtt(nn.Module):
             )  # (batch, head, time1, time2)
         else:
             self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
-
         p_attn = self.dropout(self.attn)
         x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
         x = (
             x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
         )  # (batch, time1, d_model)
-
+        if ret_attn:
+            return self.linear_out(x), self.attn  # (batch, time1, d_model)
         return self.linear_out(x)  # (batch, time1, d_model)

-    def forward(self, x, memory, memory_mask):
+    def forward(self, x, memory, memory_mask, ret_attn=False):
         """Compute scaled dot product attention.

         Args:
@@ -502,7 +502,7 @@ class MultiHeadedAttentionCrossAtt(nn.Module):
         q_h, k_h, v_h = self.forward_qkv(x, memory)
         q_h = q_h * self.d_k ** (-0.5)
         scores = torch.matmul(q_h, k_h.transpose(-2, -1))
-        return self.forward_attention(v_h, scores, memory_mask)
+        return self.forward_attention(v_h, scores, memory_mask, ret_attn=ret_attn)

     def forward_chunk(self, x, memory, cache=None, chunk_size=None, look_back=0):
         """Compute scaled dot product attention.
diff --git a/funasr/models/seaco_paraformer/__init__.py b/funasr/models/seaco_paraformer/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/funasr/models/seaco_paraformer/model.py b/funasr/models/seaco_paraformer/model.py
new file mode 100644
index 000000000..86aa7602a
--- /dev/null
+++ b/funasr/models/seaco_paraformer/model.py
@@ -0,0 +1,512 @@
+import os
+import logging
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+import tempfile
+import codecs
+import requests
+import re
+import copy
+import torch
+import torch.nn as nn
+import random
+import numpy as np
+import time
+
+# from funasr.layers.abs_normalize import AbsNormalize
+from funasr.losses.label_smoothing_loss import (
+    LabelSmoothingLoss,  # noqa: H301
+)
+# from funasr.models.ctc import CTC
+# from funasr.models.decoder.abs_decoder import AbsDecoder
+# from funasr.models.e2e_asr_common import ErrorCalculator
+# from funasr.models.encoder.abs_encoder import AbsEncoder
+# from funasr.frontends.abs_frontend import AbsFrontend
+# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
+from funasr.models.paraformer.cif_predictor import mae_loss
+# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+# from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.train_utils.device_funcs import force_gatherable
+# from funasr.models.base_model import FunASRModel
+# from funasr.models.paraformer.cif_predictor import CifPredictorV3
+from funasr.models.paraformer.search import Hypothesis
+
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+    from torch.cuda.amp import autocast
+else:
+    # Nothing to do if torch<1.6.0
+    @contextmanager
+    def autocast(enabled=True):
+        yield
+
+from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank
+from funasr.utils import postprocess_utils
+from funasr.utils.datadir_writer import DatadirWriter
+
+from funasr.models.paraformer.model import Paraformer
+from funasr.utils.register import register_class, registry_tables
+
+
+@register_class("model_classes", "SeacoParaformer")
+class SeacoParaformer(Paraformer):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    SeACo-Paraformer: A Non-Autoregressive ASR System with Flexible and Effective Hotword Customization Ability
+    https://arxiv.org/abs/2308.03266
+    """
+
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+
+        self.inner_dim = kwargs.get("inner_dim", 256)
+        self.bias_encoder_type = kwargs.get("bias_encoder_type", "lstm")
+        bias_encoder_dropout_rate = kwargs.get("bias_encoder_dropout_rate", 0.0)
+        bias_encoder_bid = kwargs.get("bias_encoder_bid", False)
+        seaco_lsm_weight = kwargs.get("seaco_lsm_weight", 0.0)
+        seaco_length_normalized_loss = kwargs.get("seaco_length_normalized_loss", True)
+
+        # bias encoder
+        if self.bias_encoder_type == 'lstm':
+            logging.warning("enable bias encoder sampling and contextual training")
+            self.bias_encoder = torch.nn.LSTM(self.inner_dim,
+                                              self.inner_dim,
+                                              2,
+                                              batch_first=True,
+                                              dropout=bias_encoder_dropout_rate,
+                                              bidirectional=bias_encoder_bid)
+            if bias_encoder_bid:
+                self.lstm_proj = torch.nn.Linear(self.inner_dim*2, self.inner_dim)
+            else:
+                self.lstm_proj = None
+            self.bias_embed = torch.nn.Embedding(self.vocab_size, self.inner_dim)
+        elif self.bias_encoder_type == 'mean':
+            logging.warning("enable bias encoder sampling and contextual training")
+            self.bias_embed = torch.nn.Embedding(self.vocab_size, self.inner_dim)
+        else:
+            logging.error("Unsupported bias encoder type: {}".format(self.bias_encoder_type))
+
+        # seaco decoder
+        seaco_decoder = kwargs.get("seaco_decoder", None)
+        if seaco_decoder is not None:
+            seaco_decoder_conf = kwargs.get("seaco_decoder_conf")
+            seaco_decoder_class = registry_tables.decoder_classes.get(seaco_decoder.lower())
+            self.seaco_decoder = seaco_decoder_class(
+                vocab_size=self.vocab_size,
+                encoder_output_size=self.inner_dim,
+                **seaco_decoder_conf,
+            )
+        self.hotword_output_layer = torch.nn.Linear(self.inner_dim, self.vocab_size)
+        self.criterion_seaco = LabelSmoothingLoss(
+            size=self.vocab_size,
+            padding_idx=self.ignore_id,
+            smoothing=seaco_lsm_weight,
+            normalize_length=seaco_length_normalized_loss,
+        )
+        self.train_decoder = kwargs.get("train_decoder", False)
+        self.NO_BIAS = kwargs.get("NO_BIAS", 8377)
+
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Frontend + Encoder + Decoder + Calc loss
+
+        Args:
+            speech: (Batch, Length, ...)
+            speech_lengths: (Batch, )
+            text: (Batch, Length)
+            text_lengths: (Batch,)
+        """
+        assert text_lengths.dim() == 1, text_lengths.shape
+        # Check that batch_size is unified
+        assert (
+            speech.shape[0]
+            == speech_lengths.shape[0]
+            == text.shape[0]
+            == text_lengths.shape[0]
+        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
+
+        hotword_pad = kwargs.get("hotword_pad")
+        hotword_lengths = kwargs.get("hotword_lengths")
+        dha_pad = kwargs.get("dha_pad")
+
+        batch_size = speech.shape[0]
+        self.step_cur += 1
+        # for data-parallel
+        text = text[:, : text_lengths.max()]
+        speech = speech[:, :speech_lengths.max()]
+
+        # 1. Encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+        if self.predictor_bias == 1:
+            _, ys_pad = add_sos_eos(text, self.sos, self.eos, self.ignore_id)
+            ys_lengths = text_lengths + self.predictor_bias
+
+        stats = dict()
+        loss_seaco = self._calc_seaco_loss(encoder_out,
+                                           encoder_out_lens,
+                                           ys_pad,
+                                           ys_lengths,
+                                           hotword_pad,
+                                           hotword_lengths,
+                                           dha_pad,
+                                           )
+        if self.train_decoder:
+            loss_att, acc_att = self._calc_att_loss(
+                encoder_out, encoder_out_lens, text, text_lengths
+            )
+            loss = loss_seaco + loss_att
+            stats["loss_att"] = torch.clone(loss_att.detach())
+            stats["acc_att"] = acc_att
+        else:
+            loss = loss_seaco
+        stats["loss_seaco"] = torch.clone(loss_seaco.detach())
+        stats["loss"] = torch.clone(loss.detach())
+
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        if self.length_normalized_loss:
+            # note: batch_size is a python int here, so .type_as(batch_size) would fail
+            batch_size = (text_lengths + self.predictor_bias).sum()
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+
+    def _merge(self, cif_attended, dec_attended):
+        return cif_attended + dec_attended
+
+    def _calc_seaco_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_lengths: torch.Tensor,
+        hotword_pad: torch.Tensor,
+        hotword_lengths: torch.Tensor,
+        dha_pad: torch.Tensor,
+    ):
+        # predictor forward
+        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+            encoder_out.device)
+        pre_acoustic_embeds, _, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
+                                                      ignore_id=self.ignore_id)
+        # decoder forward
+        decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_lengths,
+                                      return_hidden=True)
+        selected = self._hotword_representation(hotword_pad,
+                                                hotword_lengths)
+        contextual_info = selected.squeeze(0).repeat(encoder_out.shape[0], 1, 1).to(encoder_out.device)
+        num_hot_word = contextual_info.shape[1]
+        _contextual_length = torch.Tensor([num_hot_word]).int().repeat(encoder_out.shape[0]).to(encoder_out.device)
+
+        # dha core
+        cif_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, pre_acoustic_embeds, ys_lengths)
+        dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_out, ys_lengths)
+        merged = self._merge(cif_attended, dec_attended)
+        dha_output = self.hotword_output_layer(merged[:, :-1])  # remove the last token in loss calculation
+        loss_att = self.criterion_seaco(dha_output, dha_pad)
+        return loss_att
+
+    def _seaco_decode_with_ASF(self,
+                               encoder_out,
+                               encoder_out_lens,
+                               sematic_embeds,
+                               ys_pad_lens,
+                               hw_list,
+                               nfilter=50,
+                               seaco_weight=1.0):
+        # decoder forward
+        decoder_out, decoder_hidden, _ = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens,
+                                                      return_hidden=True, return_both=True)
+        decoder_pred = torch.log_softmax(decoder_out, dim=-1)
+        if hw_list is not None:
+            hw_lengths = [len(i) for i in hw_list]
+            hw_list_ = [torch.Tensor(i).long() for i in hw_list]
+            hw_list_pad = pad_list(hw_list_, 0).to(encoder_out.device)
+            selected = self._hotword_representation(hw_list_pad,
+                                                    torch.Tensor(hw_lengths).int().to(encoder_out.device))
+            contextual_info = selected.squeeze(0).repeat(encoder_out.shape[0], 1, 1).to(encoder_out.device)
+            num_hot_word = contextual_info.shape[1]
+            _contextual_length = torch.Tensor([num_hot_word]).int().repeat(encoder_out.shape[0]).to(encoder_out.device)
+
+            # ASF Core
+            if nfilter > 0 and nfilter < num_hot_word:
+                for dec in self.seaco_decoder.decoders:
+                    dec.reserve_attn = True
+                dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_hidden, ys_pad_lens)
+                # rank hotwords by accumulated cross-attention scores from the last seaco decoder layer
+                hotword_scores = self.seaco_decoder.decoders[-1].attn_mat[0][0].sum(0).sum(0)[:-1]
+                # hotword_scores /= torch.sqrt(torch.tensor(hw_lengths)[:-1].float()).to(hotword_scores.device)
+                dec_filter = torch.topk(hotword_scores, min(nfilter, num_hot_word-1))[1].tolist()
+                add_filter = dec_filter
+                add_filter.append(len(hw_list_pad)-1)
+                # filter hotword embedding
+                selected = selected[add_filter]
+                # again
+                contextual_info = selected.squeeze(0).repeat(encoder_out.shape[0], 1, 1).to(encoder_out.device)
+                num_hot_word = contextual_info.shape[1]
+                _contextual_length = torch.Tensor([num_hot_word]).int().repeat(encoder_out.shape[0]).to(encoder_out.device)
+                for dec in self.seaco_decoder.decoders:
+                    dec.attn_mat = []
+                    dec.reserve_attn = False
+
+            # SeACo Core
+            cif_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, sematic_embeds, ys_pad_lens)
+            dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_hidden, ys_pad_lens)
+            merged = self._merge(cif_attended, dec_attended)
+
+            dha_output = self.hotword_output_layer(merged)  # unlike the loss branch, no token is dropped here
+            dha_pred = torch.log_softmax(dha_output, dim=-1)
+
+            def _merge_res(dec_output, dha_output):
+                lmbd = torch.Tensor([seaco_weight] * dha_output.shape[0])
+                dha_ids = dha_output.max(-1)[-1][0]
+                dha_mask = (dha_ids == self.NO_BIAS).int().unsqueeze(-1)  # NO_BIAS: "no hotword fired" id
+                a = (1 - lmbd) / lmbd
+                b = 1 / lmbd
+                a, b = a.to(dec_output.device), b.to(dec_output.device)
+                dha_mask = (dha_mask + a.reshape(-1, 1, 1)) / b.reshape(-1, 1, 1)
+                # logits = dec_output * dha_mask + dha_output[:,:,:-1] * (1-dha_mask)
+                logits = dec_output * dha_mask + dha_output[:,:,:] * (1-dha_mask)
+                return logits
+
+            merged_pred = _merge_res(decoder_pred, dha_pred)
+            return merged_pred
+        else:
+            return decoder_pred
+
+    def _hotword_representation(self,
+                                hotword_pad,
+                                hotword_lengths):
+        if self.bias_encoder_type != 'lstm':
+            logging.error("Unsupported bias encoder type")
+        hw_embed = self.decoder.embed(hotword_pad)
+        hw_embed, (_, _) = self.bias_encoder(hw_embed)
+        if self.lstm_proj is not None:
+            hw_embed = self.lstm_proj(hw_embed)
+        _ind = np.arange(0, hw_embed.shape[0]).tolist()
+        selected = hw_embed[_ind, [i-1 for i in hotword_lengths.detach().cpu().tolist()]]
+        return selected
+
+    def generate(self,
+                 data_in,
+                 data_lengths=None,
+                 key: list = None,
+                 tokenizer=None,
+                 frontend=None,
+                 **kwargs,
+                 ):
+
+        # init beamsearch
+        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc is not None
+        is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+        if self.beam_search is None and (is_use_lm or is_use_ctc):
+            logging.info("enable beam_search")
+            self.init_beam_search(**kwargs)
+            self.nbest = kwargs.get("nbest", 1)
+
+        meta_data = {}
+
+        # extract fbank feats
+        time1 = time.perf_counter()
+        audio_sample_list = load_audio(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
+        time2 = time.perf_counter()
+        meta_data["load_data"] = f"{time2 - time1:0.3f}"
+        speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+                                               frontend=frontend)
+        time3 = time.perf_counter()
+        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+        meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+
+        # .to() is not in-place: assign the moved tensors back
+        speech = speech.to(device=kwargs["device"])
+        speech_lengths = speech_lengths.to(device=kwargs["device"])
+
+        # hotword
+        self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer,
+                                                        frontend=frontend)
+
+        # Encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        if isinstance(encoder_out, tuple):
+            encoder_out = encoder_out[0]
+
+        # predictor
+        predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
+        pre_acoustic_embeds, pre_token_length, _, _ = predictor_outs[0], predictor_outs[1], \
+                                                      predictor_outs[2], predictor_outs[3]
+        pre_token_length = pre_token_length.round().long()
+        if torch.max(pre_token_length) < 1:
+            return []
+
+        decoder_out = self._seaco_decode_with_ASF(encoder_out, encoder_out_lens,
+                                                  pre_acoustic_embeds,
+                                                  pre_token_length,
+                                                  hw_list=self.hotword_list)
+
+        results = []
+        b, n, d = decoder_out.size()
+        for i in range(b):
+            x = encoder_out[i, :encoder_out_lens[i], :]
+            am_scores = decoder_out[i, :pre_token_length[i], :]
+            if self.beam_search is not None:
+                nbest_hyps = self.beam_search(
+                    x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
+                    minlenratio=kwargs.get("minlenratio", 0.0)
+                )
+                nbest_hyps = nbest_hyps[: self.nbest]
+            else:
+                yseq = am_scores.argmax(dim=-1)
+                score = am_scores.max(dim=-1)[0]
+                score = torch.sum(score, dim=-1)
+                # pad with mask tokens to ensure compatibility with sos/eos tokens
+                yseq = torch.tensor(
+                    [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
+                )
+                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
+            for nbest_idx, hyp in enumerate(nbest_hyps):
+                ibest_writer = None
+                if ibest_writer is None and kwargs.get("output_dir") is not None:
+                    writer = DatadirWriter(kwargs.get("output_dir"))
+                    ibest_writer = writer[f"{nbest_idx + 1}best_recog"]
+                # remove sos/eos and get results
+                last_pos = -1
+                if isinstance(hyp.yseq, list):
+                    token_int = hyp.yseq[1:last_pos]
+                else:
+                    token_int = hyp.yseq[1:last_pos].tolist()
+
+                # remove blank symbol id, which is assumed to be 0
+                token_int = list(
+                    filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
+
+                if tokenizer is not None:
+                    # Change integer-ids to tokens
+                    token = tokenizer.ids2tokens(token_int)
+                    text = tokenizer.tokens2text(token)
+
+                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+                    result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed}
+
+                    if ibest_writer is not None:
+                        ibest_writer["token"][key[i]] = " ".join(token)
+                        ibest_writer["text"][key[i]] = text
+                        ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
+                else:
+                    result_i = {"key": key[i], "token_int": token_int}
+                results.append(result_i)
+
+        return results, meta_data
+
+    def generate_hotwords_list(self, hotword_list_or_file, tokenizer=None, frontend=None):
+        def load_seg_dict(seg_dict_file):
+            seg_dict = {}
+            assert isinstance(seg_dict_file, str)
+            with open(seg_dict_file, "r", encoding="utf8") as f:
+                lines = f.readlines()
+                for line in lines:
+                    s = line.strip().split()
+                    key = s[0]
+                    value = s[1:]
+                    seg_dict[key] = " ".join(value)
+            return seg_dict
+
+        def seg_tokenize(txt, seg_dict):
+            pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
+            out_txt = ""
+            for word in txt:
+                word = word.lower()
+                if word in seg_dict:
+                    out_txt += seg_dict[word] + " "
+                else:
+                    if pattern.match(word):
+                        for char in word:
+                            if char in seg_dict:
+                                out_txt += seg_dict[char] + " "
+                            else:
+                                out_txt += "<unk>" + " "
+                    else:
+                        out_txt += "<unk>" + " "
+            return out_txt.strip().split()
+
+        seg_dict = None
+        if frontend.cmvn_file is not None:
+            model_dir = os.path.dirname(frontend.cmvn_file)
+            seg_dict_file = os.path.join(model_dir, 'seg_dict')
+            if os.path.exists(seg_dict_file):
+                seg_dict = load_seg_dict(seg_dict_file)
+            else:
+                seg_dict = None
+        # for None
+        if hotword_list_or_file is None:
+            hotword_list = None
+        # for local txt inputs
+        elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'):
+            logging.info("Attempting to parse hotwords from local txt...")
+            hotword_list = []
+            hotword_str_list = []
+            with codecs.open(hotword_list_or_file, 'r') as fin:
+                for line in fin.readlines():
+                    hw = line.strip()
+                    hw_list = hw.split()
+                    if seg_dict is not None:
+                        hw_list = seg_tokenize(hw_list, seg_dict)
+                    hotword_str_list.append(hw)
+                    hotword_list.append(tokenizer.tokens2ids(hw_list))
+                hotword_list.append([self.sos])
+                hotword_str_list.append('<s>')
+            logging.info("Initialized hotword list from file: {}, hotword list: {}."
+                         .format(hotword_list_or_file, hotword_str_list))
+        # for url, download and generate txt
+        elif hotword_list_or_file.startswith('http'):
+            logging.info("Attempting to parse hotwords from url...")
+            work_dir = tempfile.TemporaryDirectory().name
+            if not os.path.exists(work_dir):
+                os.makedirs(work_dir)
+            text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
+            local_file = requests.get(hotword_list_or_file)
+            open(text_file_path, "wb").write(local_file.content)
+            hotword_list_or_file = text_file_path
+            hotword_list = []
+            hotword_str_list = []
+            with codecs.open(hotword_list_or_file, 'r') as fin:
+                for line in fin.readlines():
+                    hw = line.strip()
+                    hw_list = hw.split()
+                    if seg_dict is not None:
+                        hw_list = seg_tokenize(hw_list, seg_dict)
+                    hotword_str_list.append(hw)
+                    hotword_list.append(tokenizer.tokens2ids(hw_list))
+                hotword_list.append([self.sos])
+                hotword_str_list.append('<s>')
+            logging.info("Initialized hotword list from file: {}, hotword list: {}."
+                         .format(hotword_list_or_file, hotword_str_list))
+        # for text str input
+        elif not hotword_list_or_file.endswith('.txt'):
+            logging.info("Attempting to parse hotwords as str...")
+            hotword_list = []
+            hotword_str_list = []
+            for hw in hotword_list_or_file.strip().split():
+                hotword_str_list.append(hw)
+                hw_list = hw.strip().split()
+                if seg_dict is not None:
+                    hw_list = seg_tokenize(hw_list, seg_dict)
+                hotword_list.append(tokenizer.tokens2ids(hw_list))
+            hotword_list.append([self.sos])
+            hotword_str_list.append('<s>')
+            logging.info("Hotword list: {}.".format(hotword_str_list))
+        else:
+            hotword_list = None
+        return hotword_list
+
diff --git a/funasr/models/seaco_paraformer/template.yaml b/funasr/models/seaco_paraformer/template.yaml
new file mode 100644
index 000000000..266386ffe
--- /dev/null
+++ b/funasr/models/seaco_paraformer/template.yaml
@@ -0,0 +1,151 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# to print the register tables:
+#   from funasr.utils.register import registry_tables
+#   registry_tables.print()
+
+# network architecture
+model: SeacoParaformer
+model_conf:
+    ctc_weight: 0.0
+    lsm_weight: 0.1
+    length_normalized_loss: true
+    predictor_weight: 1.0
+    predictor_bias: 1
+    sampling_ratio: 0.75
+    inner_dim: 512
+    bias_encoder_type: lstm
+    bias_encoder_bid: false
+    seaco_lsm_weight: 0.1
+    seaco_length_normalized_loss: true
+    train_decoder: false
+    NO_BIAS: 8377
+
+# encoder
+encoder: SANMEncoder
+encoder_conf:
+    output_size: 512
+    attention_heads: 4
+    linear_units: 2048
+    num_blocks: 50
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.1
+    input_layer: pe
+    pos_enc_class: SinusoidalPositionEncoder
+    normalize_before: true
+    kernel_size: 11
+    sanm_shfit: 0
+    selfattention_layer_type: sanm
+
+# decoder
+decoder: ParaformerSANMDecoder
+decoder_conf:
+    attention_heads: 4
+    linear_units: 2048
+    num_blocks: 16
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    self_attention_dropout_rate: 0.1
+    src_attention_dropout_rate: 0.1
+    att_layer_num: 16
+    kernel_size: 11
+    sanm_shfit: 0
+
+# seaco decoder
+seaco_decoder: ParaformerSANMDecoder
+seaco_decoder_conf:
+    attention_heads: 4
+    linear_units: 1024
+    num_blocks: 4
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    self_attention_dropout_rate: 0.1
+    src_attention_dropout_rate: 0.1
+    kernel_size: 21
+    sanm_shfit: 0
+    use_output_layer: false
+    wo_input_layer: true
+
+predictor: CifPredictorV2
+predictor_conf:
+    idim: 512
+    threshold: 1.0
+    l_order: 1
+    r_order: 1
+    tail_threshold: 0.45
+
+# frontend related
+frontend: WavFrontend
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
+    lfr_m: 7
+    lfr_n: 6
+    dither: 0.0
+
+specaug: SpecAugLFR
+specaug_conf:
+    apply_time_warp: false
+    time_warp_window: 5
+    time_warp_mode: bicubic
+    apply_freq_mask: true
+    freq_mask_width_range:
+    - 0
+    - 30
+    lfr_rate: 6
+    num_freq_mask: 1
+    apply_time_mask: true
+    time_mask_width_range:
+    - 0
+    - 12
+    num_time_mask: 1
+
+train_conf:
+    accum_grad: 1
+    grad_clip: 5
+    max_epoch: 150
+    val_scheduler_criterion:
+    - valid
+    - acc
+    best_model_criterion:
+    -   - valid
+        - acc
+        - max
+    keep_nbest_models: 10
+    log_interval: 50
+
+optim: adam
+optim_conf:
+    lr: 0.0005
+scheduler: warmuplr
+scheduler_conf:
+    warmup_steps: 30000
+
+dataset: AudioDataset
+dataset_conf:
+    index_ds: IndexDSJsonl
+    batch_sampler: DynamicBatchLocalShuffleSampler
+    batch_type: example    # example or length
+    batch_size: 1    # if batch_type is example, batch_size is the number of samples; if length, batch_size is source_token_len + target_token_len
+    max_token_length: 2048    # filter out samples if source_token_len + target_token_len > max_token_length
+    buffer_size: 500
+    shuffle: True
+    num_workers: 0
+
+tokenizer: CharTokenizer
+tokenizer_conf:
+    unk_symbol: <unk>
+    split_with_space: true
+
+ctc_conf:
+    dropout_rate: 0.0
+    ctc_type: builtin
+    reduce: true
+    ignore_nan_grad: true
+normalize: null
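
Usage note: a minimal decoding sketch with hotword customization, assuming the funasr AutoModel front end from the same release; the checkpoint id below is illustrative and not something this patch ships. The "hotword" keyword is the one consumed by SeacoParaformer.generate() via generate_hotwords_list() above.

    from funasr import AutoModel

    # Checkpoint id is an assumption for illustration; any SeACo-Paraformer checkpoint works.
    model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")

    # "hotword" accepts a space-separated string, a local .txt path, or a URL,
    # matching the three branches of generate_hotwords_list() in model.py above.
    res = model.generate(input="example.wav", hotword="达摩院 魔搭")
    print(res)

With no hotword argument, generate_hotwords_list() returns None and _seaco_decode_with_ASF() falls back to plain Paraformer decoding (the final "return decoder_pred" branch).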