This commit is contained in:
语帆 2024-02-28 14:52:05 +08:00
parent ab4a31201c
commit 2bffe1d539
3 changed files with 8 additions and 9 deletions

View File

@ -6,7 +6,6 @@ python -m funasr.bin.inference \
--config-name="config.yaml" \ --config-name="config.yaml" \
++init_param=${file_dir}/model.pb \ ++init_param=${file_dir}/model.pb \
++tokenizer_conf.token_list=${file_dir}/tokens.txt \ ++tokenizer_conf.token_list=${file_dir}/tokens.txt \
++frontend_conf.cmvn_file=${file_dir}/am.mvn \
++input=[${file_dir}/wav.scp,${file_dir}/ocr_text] \ ++input=[${file_dir}/wav.scp,${file_dir}/ocr_text] \
+data_type='["kaldi_ark", "text"]' \ +data_type='["kaldi_ark", "text"]' \
++tokenizer_conf.bpemodel=${file_dir}/bpe.model \ ++tokenizer_conf.bpemodel=${file_dir}/bpe.model \

View File

@ -21,6 +21,7 @@ from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils import postprocess_utils from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter from funasr.utils.datadir_writer import DatadirWriter
from funasr.register import tables from funasr.register import tables
import pdb import pdb
@tables.register("model_classes", "LCBNet") @tables.register("model_classes", "LCBNet")
class LCBNet(nn.Module): class LCBNet(nn.Module):
@ -92,6 +93,7 @@ class LCBNet(nn.Module):
bias_predictor_class = tables.encoder_classes.get(bias_predictor) bias_predictor_class = tables.encoder_classes.get(bias_predictor)
bias_predictor = bias_predictor_class(**bias_predictor_conf) bias_predictor = bias_predictor_class(**bias_predictor_conf)
if decoder is not None: if decoder is not None:
decoder_class = tables.decoder_classes.get(decoder) decoder_class = tables.decoder_classes.get(decoder)
decoder = decoder_class( decoder = decoder_class(
@ -272,15 +274,15 @@ class LCBNet(nn.Module):
ind: int ind: int
""" """
with autocast(False): with autocast(False):
pdb.set_trace()
# Data augmentation # Data augmentation
if self.specaug is not None and self.training: if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths) speech, speech_lengths = self.specaug(speech, speech_lengths)
pdb.set_trace()
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None: if self.normalize is not None:
speech, speech_lengths = self.normalize(speech, speech_lengths) speech, speech_lengths = self.normalize(speech, speech_lengths)
pdb.set_trace()
# Forward encoder # Forward encoder
# feats: (Batch, Length, Dim) # feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2) # -> encoder_out: (Batch, Length2, Dim2)
@ -297,7 +299,7 @@ class LCBNet(nn.Module):
if intermediate_outs is not None: if intermediate_outs is not None:
return (encoder_out, intermediate_outs), encoder_out_lens return (encoder_out, intermediate_outs), encoder_out_lens
pdb.set_trace()
return encoder_out, encoder_out_lens return encoder_out, encoder_out_lens
def _calc_att_loss( def _calc_att_loss(
@ -442,6 +444,7 @@ class LCBNet(nn.Module):
speech = speech.to(device=kwargs["device"]) speech = speech.to(device=kwargs["device"])
speech_lengths = speech_lengths.to(device=kwargs["device"]) speech_lengths = speech_lengths.to(device=kwargs["device"])
pdb.set_trace()
# Encoder # Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
if isinstance(encoder_out, tuple): if isinstance(encoder_out, tuple):

View File

@ -108,10 +108,7 @@ def extract_fbank(data, data_len = None, data_type: str="sound", frontend=None,
data_list.append(data_i) data_list.append(data_i)
data_len.append(data_i.shape[0]) data_len.append(data_i.shape[0])
data = pad_sequence(data_list, batch_first=True) # data: [batch, N] data = pad_sequence(data_list, batch_first=True) # data: [batch, N]
# import pdb;
# pdb.set_trace()
# if data_type == "sound":
pdb.set_trace()
data, data_len = frontend(data, data_len, **kwargs) data, data_len = frontend(data, data_len, **kwargs)
if isinstance(data_len, (list, tuple)): if isinstance(data_len, (list, tuple)):