Mirror of https://github.com/modelscope/FunASR (synced 2025-09-15 14:48:36 +08:00)
Merge branch 'dev_infer' of https://github.com/alibaba-damo-academy/FunASR into dev_infer
Commit 1499592e7d
@@ -7,7 +7,7 @@ We provide a recipe `egs/aishell/paraformer/run.sh` for training a paraformer mo
 - `gpu_num`: the number of GPUs used for training
 - `gpu_inference`: whether to use GPUs for decoding
 - `njob`: for CPU decoding, the total number of CPU jobs; for GPU decoding, the number of jobs on each GPU
-- `data_aishell`: the path to the raw AISHELL-1 dataset
+- `raw_data`: the path to the raw AISHELL-1 dataset
 - `feats_dir`: the path for saving processed data
 - `nj`: the number of jobs for data preparation
 - `speed_perturb`: the speed perturbation factors used for data augmentation
@@ -15,7 +15,7 @@ We provide a recipe `egs/aishell/paraformer/run.sh` for training a paraformer mo
 - `tag`: the suffix of the experiment result directory
 
 ## Stage 0: Data preparation
-This stage processes the raw AISHELL-1 dataset `$data_aishell` and generates the corresponding `wav.scp` and `text` in `$feats_dir/data/xxx`, where `xxx` is one of `train/dev/test`. Here we assume users have already downloaded the AISHELL-1 dataset. If not, users can download the data [here](https://www.openslr.org/33/) and set the path for `$data_aishell`. Examples of `wav.scp` and `text` are as follows:
+This stage processes the raw AISHELL-1 dataset `$raw_data` and generates the corresponding `wav.scp` and `text` in `$feats_dir/data/xxx`, where `xxx` is one of `train/dev/test`. Here we assume users have already downloaded the AISHELL-1 dataset. If not, users can download the data [here](https://www.openslr.org/33/) and set the path for `$raw_data`. Examples of `wav.scp` and `text` are as follows:
 * `wav.scp`
 ```
 BAC009S0002W0122 /nfs/ASR_DATA/AISHELL-1/data_aishell/wav/train/S0002/BAC009S0002W0122.wav
@@ -32,28 +32,8 @@ BAC009S0002W0124 自 六 月 底 呼 和 浩 特 市 率 先 宣 布 取 消 限
 ```
 Both files have two columns: the first column is the wav ID, and the second column is the corresponding wav path (for `wav.scp`) or the label tokens (for `text`).
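Since `wav.scp` and `text` share the same two-column layout, one small parser can read both. The sketch below is only an illustration and is not FunASR's own loader: the helper name `read_two_column` and the `utt_a`/`utt_b` entries are hypothetical.

```python
from typing import Dict, Iterable


def read_two_column(lines: Iterable[str]) -> Dict[str, str]:
    """Map each wav ID to the rest of its line (a wav path or a token string)."""
    table: Dict[str, str] = {}
    for line in lines:
        line = line.strip()
        if not line:
            continue
        # Split only on the first run of whitespace, so the
        # space-separated tokens of `text` stay in one value.
        parts = line.split(maxsplit=1)
        table[parts[0]] = parts[1] if len(parts) > 1 else ""
    return table


wav_scp = read_two_column([
    "BAC009S0002W0122 /nfs/ASR_DATA/AISHELL-1/data_aishell/wav/train/S0002/BAC009S0002W0122.wav",
])
# Hypothetical `text` entries, used only to show the token column.
text = read_two_column(["utt_a t1 t2 t3", "utt_b t4 t5"])
```

Splitting with `maxsplit=1` is the one design choice that matters here: it keeps a multi-token transcript together as a single value.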
 
-## Stage 1: Feature Generation
-This stage extracts FBank features from `wav.scp` and applies speed perturbation as data augmentation according to `speed_perturb`. Users can set `nj` to control the number of jobs for feature generation. The generated features are saved in `$feats_dir/dump/xxx/ark` and the corresponding `feats.scp` files are saved as `$feats_dir/dump/xxx/feats.scp`. An example of `feats.scp` is as follows:
-* `feats.scp`
-```
-...
-BAC009S0002W0122_sp0.9 /nfs/funasr_data/aishell-1/dump/fbank/train/ark/feats.16.ark:592751055
-...
-```
-Note that the samples in this file have already been shuffled randomly. The file contains two columns: the first column is the wav ID and the second column is the Kaldi ark feature path. In addition, `speech_shape` and `text_shape` are generated in this stage, giving the speech feature shape and the text length of each sample. Examples are shown below:
-* `speech_shape`
-```
-...
-BAC009S0002W0122_sp0.9 665,80
-...
-```
-* `text_shape`
-```
-...
-BAC009S0002W0122_sp0.9 15
-...
-```
-These two files have two columns. The first column is the wav ID and the second column is the corresponding speech feature shape or text length.
+## Stage 1: Feature and CMVN Generation
+This stage computes CMVN based on the `train` dataset, which is used in the following stages. Users can set `nj` to control the number of jobs for computing CMVN. The generated CMVN file is saved as `$feats_dir/data/train/cmvn/cmvn.mvn`.
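As a rough illustration of what CMVN statistics are, the sketch below computes a per-dimension mean and standard deviation over all training frames and applies them to one utterance. It assumes features are plain nested lists. FunASR's actual CMVN script and the layout of `cmvn.mvn` are not shown in this diff, so the function names and the toy data are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Frames = Sequence[Sequence[float]]  # one utterance: frames x feature dims


def compute_cmvn(utterances: Sequence[Frames]) -> Tuple[List[float], List[float]]:
    """Return the per-dimension mean and standard deviation over all frames."""
    dim = len(utterances[0][0])
    total = [0.0] * dim
    total_sq = [0.0] * dim
    n_frames = 0
    for feats in utterances:
        for frame in feats:
            for d, value in enumerate(frame):
                total[d] += value
                total_sq[d] += value * value
            n_frames += 1
    mean = [s / n_frames for s in total]
    # Var[x] = E[x^2] - E[x]^2, floored to avoid a zero divisor.
    std = [
        math.sqrt(max(sq / n_frames - m * m, 1e-20))
        for sq, m in zip(total_sq, mean)
    ]
    return mean, std


def apply_cmvn(feats: Frames, mean: List[float], std: List[float]) -> List[List[float]]:
    """Normalize every frame to zero mean and unit variance per dimension."""
    return [[(v - m) / s for v, m, s in zip(frame, mean, std)] for frame in feats]


# Toy data: two utterances with 2-dimensional features.
utts = [[[1.0, 10.0], [3.0, 30.0]], [[5.0, 20.0]]]
mean, std = compute_cmvn(utts)
normalized = apply_cmvn(utts[0], mean, std)
```

The statistics are accumulated over the training set only and then reused for every other split, which matches the description of CMVN being computed on `train` and used in later stages.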
 
 ## Stage 2: Dictionary Preparation
 This stage processes the dictionary, which is used as a mapping between label characters and integer indices during ASR training. The processed dictionary file is saved as `$feats_dir/data/$lang_token_list/$token_type/tokens.txt`. An example of `tokens.txt` is as follows:
@@ -117,7 +97,7 @@ We support CPU and GPU decoding in FunASR. For CPU decoding, you should set `gpu
 
 * Performance
 
-We adopt `CER` to verify the performance. The results are in `$exp_dir/exp/$model_dir/$decoding_yaml_name/$average_model_name/$dset`, namely `text.cer` and `text.cer.txt`. `text.cer` saves the comparison between the recognized text and the reference text while `text.cer.txt` saves the final `CER` result. The following is an example of `text.cer`:
+We adopt `CER` to verify the performance. The results are in `$exp_dir/exp/$model_dir/$decoding_yaml_name/$average_model_name/$dset`, namely `text.cer` and `text.cer.txt`. `text.cer` saves the comparison between the recognized text and the reference text while `text.cer.txt` saves the final `CER` results. The following is an example of `text.cer`:
 * `text.cer`
 ```
 ...
@@ -47,6 +47,7 @@ model_conf:
     length_normalized_loss: false
     predictor_weight: 1.0
     sampling_ratio: 0.4
+    use_1st_decoder_loss: true
 
 # optimization related
 accum_grad: 1
@@ -1,9 +1,9 @@
 import logging
 
-from funasr.lm.abs_model import AbsLM
-from funasr.lm.abs_model import LanguageModel
-from funasr.lm.seq_rnn_lm import SequentialRNNLM
-from funasr.lm.transformer_lm import TransformerLM
+from funasr.train.abs_model import AbsLM
+from funasr.train.abs_model import LanguageModel
+from funasr.models.seq_rnn_lm import SequentialRNNLM
+from funasr.models.transformer_lm import TransformerLM
 from funasr.torch_utils.initialize import initialize
 from funasr.train.class_choices import ClassChoices
 
@ -1,158 +0,0 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from funasr.modules.nets_utils import make_pad_mask
|
||||
from funasr.torch_utils.device_funcs import force_gatherable
|
||||
from funasr.models.base_model import FunASRModel
|
||||
|
||||
class AbsLM(torch.nn.Module, BatchScorerInterface, ABC):
|
||||
"""The abstract LM class
|
||||
|
||||
To share the loss calculation among different models,
we use the delegate pattern here:
|
||||
The instance of this class should be passed to "LanguageModel"
|
||||
|
||||
>>> from funasr.lm.abs_model import AbsLM
|
||||
>>> lm = AbsLM()
|
||||
>>> model = LanguageESPnetModel(lm=lm)
|
||||
|
||||
This "model" is one of mediator objects for "Task" class.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self, input: torch.Tensor, hidden: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class LanguageModel(FunASRModel):
|
||||
def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0):
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
self.lm = lm
|
||||
self.sos = 1
|
||||
self.eos = 2
|
||||
|
||||
# ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
|
||||
self.ignore_id = ignore_id
|
||||
|
||||
def nll(
|
||||
self,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
max_length: Optional[int] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute negative log likelihood(nll)
|
||||
|
||||
Normally, this function is called in batchify_nll.
|
||||
Args:
|
||||
text: (Batch, Length)
|
||||
text_lengths: (Batch,)
|
||||
max_length: Optional[int]
|
||||
"""
|
||||
batch_size = text.size(0)
|
||||
# For data parallel
|
||||
if max_length is None:
|
||||
text = text[:, : text_lengths.max()]
|
||||
else:
|
||||
text = text[:, :max_length]
|
||||
|
||||
# 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
|
||||
# text: (Batch, Length) -> x, t: (Batch, Length + 1)
|
||||
x = F.pad(text, [1, 0], "constant", self.sos)
|
||||
t = F.pad(text, [0, 1], "constant", self.ignore_id)
|
||||
for i, l in enumerate(text_lengths):
|
||||
t[i, l] = self.eos
|
||||
x_lengths = text_lengths + 1
|
||||
|
||||
# 2. Forward Language model
|
||||
# x: (Batch, Length) -> y: (Batch, Length, NVocab)
|
||||
y, _ = self.lm(x, None)
|
||||
|
||||
# 3. Calc negative log likelihood
|
||||
# nll: (BxL,)
|
||||
nll = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
|
||||
# nll: (BxL,) -> (BxL,)
|
||||
if max_length is None:
|
||||
nll.masked_fill_(make_pad_mask(x_lengths).to(nll.device).view(-1), 0.0)
|
||||
else:
|
||||
nll.masked_fill_(
|
||||
make_pad_mask(x_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
|
||||
0.0,
|
||||
)
|
||||
# nll: (BxL,) -> (B, L)
|
||||
nll = nll.view(batch_size, -1)
|
||||
return nll, x_lengths
|
||||
|
||||
def batchify_nll(
|
||||
self, text: torch.Tensor, text_lengths: torch.Tensor, batch_size: int = 100
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute negative log likelihood(nll) from transformer language model
|
||||
|
||||
To avoid OOM, this function separates the input into batches,
then calls nll for each batch and concatenates the results.
|
||||
Args:
|
||||
text: (Batch, Length)
|
||||
text_lengths: (Batch,)
|
||||
batch_size: int, the number of samples in each batch when computing nll;
decrease it to avoid OOM, or increase it if memory allows
|
||||
|
||||
"""
|
||||
total_num = text.size(0)
|
||||
if total_num <= batch_size:
|
||||
nll, x_lengths = self.nll(text, text_lengths)
|
||||
else:
|
||||
nlls = []
|
||||
x_lengths = []
|
||||
max_length = text_lengths.max()
|
||||
|
||||
start_idx = 0
|
||||
while True:
|
||||
end_idx = min(start_idx + batch_size, total_num)
|
||||
batch_text = text[start_idx:end_idx, :]
|
||||
batch_text_lengths = text_lengths[start_idx:end_idx]
|
||||
# batch_nll: [B * T]
|
||||
batch_nll, batch_x_lengths = self.nll(
|
||||
batch_text, batch_text_lengths, max_length=max_length
|
||||
)
|
||||
nlls.append(batch_nll)
|
||||
x_lengths.append(batch_x_lengths)
|
||||
start_idx = end_idx
|
||||
if start_idx == total_num:
|
||||
break
|
||||
nll = torch.cat(nlls)
|
||||
x_lengths = torch.cat(x_lengths)
|
||||
assert nll.size(0) == total_num
|
||||
assert x_lengths.size(0) == total_num
|
||||
return nll, x_lengths
|
||||
|
||||
def forward(
|
||||
self, text: torch.Tensor, text_lengths: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
nll, y_lengths = self.nll(text, text_lengths)
|
||||
ntokens = y_lengths.sum()
|
||||
loss = nll.sum() / ntokens
|
||||
stats = dict(loss=loss.detach())
|
||||
|
||||
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
||||
loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
|
||||
return loss, stats, weight
|
||||
|
||||
def collect_feats(
|
||||
self, text: torch.Tensor, text_lengths: torch.Tensor
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
return {}
|
||||
@@ -78,6 +78,7 @@ class Paraformer(FunASRModel):
             share_embedding: bool = False,
             preencoder: Optional[AbsPreEncoder] = None,
             postencoder: Optional[AbsPostEncoder] = None,
+            use_1st_decoder_loss: bool = False,
     ):
         assert check_argument_types()
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@@ -144,6 +145,8 @@ class Paraformer(FunASRModel):
         if self.share_embedding:
             self.decoder.embed = None
 
+        self.use_1st_decoder_loss = use_1st_decoder_loss
+
     def forward(
             self,
             speech: torch.Tensor,
@@ -179,7 +182,7 @@ class Paraformer(FunASRModel):
             intermediate_outs = encoder_out[1]
             encoder_out = encoder_out[0]
 
-        loss_att, acc_att, cer_att, wer_att = None, None, None, None
+        loss_att, pre_loss_att, acc_att, cer_att, wer_att = None, None, None, None, None
         loss_ctc, cer_ctc = None, None
         loss_pre = None
         stats = dict()
@@ -220,7 +223,7 @@ class Paraformer(FunASRModel):
 
         # 2b. Attention decoder branch
         if self.ctc_weight != 1.0:
-            loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
+            loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_loss(
                 encoder_out, encoder_out_lens, text, text_lengths
             )
 
@@ -232,8 +235,12 @@ class Paraformer(FunASRModel):
         else:
             loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
 
+        if self.use_1st_decoder_loss and pre_loss_att is not None:
+            loss = loss + pre_loss_att
+
         # Collect Attn branch stats
         stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+        stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
         stats["acc"] = acc_att
         stats["cer"] = cer_att
         stats["wer"] = wer_att
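The effect of the new flag on the total loss can be checked with plain numbers. The sketch below reproduces only the `else` branch visible in this hunk plus the added `pre_loss_att` term. `combine_losses` is a hypothetical helper, the loss values are made up, and the `ctc_weight` of 0.5 is an assumption (only `predictor_weight: 1.0` appears in the config hunk).

```python
from typing import Optional


def combine_losses(
    loss_ctc: float,
    loss_att: float,
    loss_pre: float,
    pre_loss_att: Optional[float],
    ctc_weight: float,
    predictor_weight: float,
    use_1st_decoder_loss: bool,
) -> float:
    """Weighted CTC/attention/predictor loss, plus the first-pass decoder loss if enabled."""
    loss = ctc_weight * loss_ctc + (1 - ctc_weight) * loss_att + loss_pre * predictor_weight
    # The first-pass decoder loss is added unweighted, and only when it was computed.
    if use_1st_decoder_loss and pre_loss_att is not None:
        loss = loss + pre_loss_att
    return loss


with_flag = combine_losses(2.0, 1.0, 0.5, 0.25, 0.5, 1.0, use_1st_decoder_loss=True)
without_flag = combine_losses(2.0, 1.0, 0.5, 0.25, 0.5, 1.0, use_1st_decoder_loss=False)
```

Note that `pre_loss_att` stays `None` when the sampler is disabled (`sampling_ratio` is 0), so the `is not None` guard keeps the flag from breaking that configuration.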
@@ -456,11 +463,16 @@ class Paraformer(FunASRModel):
 
         # 0. sampler
         decoder_out_1st = None
+        pre_loss_att = None
         if self.sampling_ratio > 0.0:
             if self.step_cur < 2:
                 logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
-            sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
-                pre_acoustic_embeds)
+            if self.use_1st_decoder_loss:
+                sematic_embeds, decoder_out_1st, pre_loss_att = self.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
+                    pre_acoustic_embeds)
+            else:
+                sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
+                    pre_acoustic_embeds)
         else:
             if self.step_cur < 2:
                 logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
@@ -490,7 +502,7 @@ class Paraformer(FunASRModel):
             ys_hat = decoder_out_1st.argmax(dim=-1)
             cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
 
-        return loss_att, acc_att, cer_att, wer_att, loss_pre
+        return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
 
     def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
 
@@ -523,6 +535,37 @@ class Paraformer(FunASRModel):
                 input_mask_expand_dim, 0)
         return sematic_embeds * tgt_mask, decoder_out * tgt_mask
 
+    def sampler_with_grad(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
+        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
+        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
+        if self.share_embedding:
+            ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
+        else:
+            ys_pad_embed = self.decoder.embed(ys_pad_masked)
+        decoder_outs = self.decoder(
+            encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
+        )
+        pre_loss_att = self.criterion_att(decoder_outs[0], ys_pad)
+        decoder_out, _ = decoder_outs[0], decoder_outs[1]
+        pred_tokens = decoder_out.argmax(-1)
+        nonpad_positions = ys_pad.ne(self.ignore_id)
+        seq_lens = (nonpad_positions).sum(1)
+        same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
+        input_mask = torch.ones_like(nonpad_positions)
+        bsz, seq_len = ys_pad.size()
+        for li in range(bsz):
+            target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
+            if target_num > 0:
+                input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device), value=0)
+        input_mask = input_mask.eq(1)
+        input_mask = input_mask.masked_fill(~nonpad_positions, False)
+        input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
+
+        sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
+            input_mask_expand_dim, 0)
+
+        return sematic_embeds * tgt_mask, decoder_out * tgt_mask, pre_loss_att
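The mask logic in `sampler_with_grad` can be followed on a single sequence without PyTorch. In the sketch below, `sample_mask` is a hypothetical helper and the token ids are made up. It assumes padding sits at the end of the target, so an explicit `seq_len` stands in for the `ys_pad.ne(self.ignore_id)` test. `True` means the position keeps the predictor's acoustic embedding; `False` means it is replaced by the ground-truth token embedding.

```python
import random
from typing import List, Sequence


def sample_mask(
    pred: Sequence[int],
    target: Sequence[int],
    seq_len: int,
    sampling_ratio: float,
    rng: random.Random,
) -> List[bool]:
    """Pick positions whose acoustic embedding is swapped for the target embedding."""
    # Count first-pass predictions that already match the target.
    same_num = sum(1 for p, t in zip(pred[:seq_len], target[:seq_len]) if p == t)
    # The more errors in the first pass, the more ground-truth tokens are revealed.
    target_num = int((seq_len - same_num) * sampling_ratio)
    keep_acoustic = [True] * len(target)
    for pos in rng.sample(range(seq_len), target_num):
        keep_acoustic[pos] = False
    # Padded positions are always masked out.
    for pos in range(seq_len, len(target)):
        keep_acoustic[pos] = False
    return keep_acoustic


rng = random.Random(0)
pred = [3, 4, 9, 9, 9, 0]    # first-pass argmax, hypothetical ids
target = [3, 5, 6, 7, 8, 0]  # five real tokens followed by one pad
mask = sample_mask(pred, target, seq_len=5, sampling_ratio=0.4, rng=rng)
no_sampling = sample_mask(pred, target, seq_len=5, sampling_ratio=0.0, rng=rng)
```

With four wrong tokens out of five and a ratio of 0.4, one position is revealed (`int(1.6)` is 1), which mirrors the truncation done by `.long()` in the tensor code.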
 
     def _calc_ctc_loss(
             self,
             encoder_out: torch.Tensor,
@@ -5,8 +5,7 @@ from typing import Union
 import torch
 import torch.nn as nn
 from typeguard import check_argument_types
-
-from funasr.lm.abs_model import AbsLM
+from funasr.train.abs_model import AbsLM
 
 
 class SequentialRNNLM(AbsLM):
@@ -8,7 +8,7 @@ import torch.nn as nn
 from funasr.modules.embedding import PositionalEncoding
 from funasr.models.encoder.transformer_encoder import TransformerEncoder_s0 as Encoder
 from funasr.modules.mask import subsequent_mask
-from funasr.lm.abs_model import AbsLM
+from funasr.train.abs_model import AbsLM
 
 
 class TransformerLM(AbsLM):
@@ -14,10 +14,10 @@ from typeguard import check_return_type
 
 from funasr.datasets.collate_fn import CommonCollateFn
 from funasr.datasets.preprocessor import CommonPreprocessor
-from funasr.lm.abs_model import AbsLM
-from funasr.lm.abs_model import LanguageModel
-from funasr.lm.seq_rnn_lm import SequentialRNNLM
-from funasr.lm.transformer_lm import TransformerLM
+from funasr.train.abs_model import AbsLM
+from funasr.train.abs_model import LanguageModel
+from funasr.models.seq_rnn_lm import SequentialRNNLM
+from funasr.models.transformer_lm import TransformerLM
 from funasr.tasks.abs_task import AbsTask
 from funasr.text.phoneme_tokenizer import g2p_choices
 from funasr.torch_utils.initialize import initialize
@ -1,7 +1,7 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
|
||||
|
||||
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
@ -14,6 +14,142 @@ from funasr.modules.nets_utils import make_pad_mask
|
||||
from funasr.torch_utils.device_funcs import force_gatherable
|
||||
from funasr.models.base_model import FunASRModel
|
||||
|
||||
class AbsLM(torch.nn.Module, BatchScorerInterface, ABC):
|
||||
"""The abstract LM class
|
||||
|
||||
To share the loss calculation among different models,
we use the delegate pattern here:
|
||||
The instance of this class should be passed to "LanguageModel"
|
||||
|
||||
This "model" is one of mediator objects for "Task" class.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self, input: torch.Tensor, hidden: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class LanguageModel(FunASRModel):
|
||||
def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0):
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
self.lm = lm
|
||||
self.sos = 1
|
||||
self.eos = 2
|
||||
|
||||
# ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
|
||||
self.ignore_id = ignore_id
|
||||
|
||||
def nll(
|
||||
self,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
max_length: Optional[int] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute negative log likelihood(nll)
|
||||
|
||||
Normally, this function is called in batchify_nll.
|
||||
Args:
|
||||
text: (Batch, Length)
|
||||
text_lengths: (Batch,)
|
||||
max_length: Optional[int]
|
||||
"""
|
||||
batch_size = text.size(0)
|
||||
# For data parallel
|
||||
if max_length is None:
|
||||
text = text[:, : text_lengths.max()]
|
||||
else:
|
||||
text = text[:, :max_length]
|
||||
|
||||
# 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
|
||||
# text: (Batch, Length) -> x, t: (Batch, Length + 1)
|
||||
x = F.pad(text, [1, 0], "constant", self.sos)
|
||||
t = F.pad(text, [0, 1], "constant", self.ignore_id)
|
||||
for i, l in enumerate(text_lengths):
|
||||
t[i, l] = self.eos
|
||||
x_lengths = text_lengths + 1
|
||||
|
||||
# 2. Forward Language model
|
||||
# x: (Batch, Length) -> y: (Batch, Length, NVocab)
|
||||
y, _ = self.lm(x, None)
|
||||
|
||||
# 3. Calc negative log likelihood
|
||||
# nll: (BxL,)
|
||||
nll = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
|
||||
# nll: (BxL,) -> (BxL,)
|
||||
if max_length is None:
|
||||
nll.masked_fill_(make_pad_mask(x_lengths).to(nll.device).view(-1), 0.0)
|
||||
else:
|
||||
nll.masked_fill_(
|
||||
make_pad_mask(x_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
|
||||
0.0,
|
||||
)
|
||||
# nll: (BxL,) -> (B, L)
|
||||
nll = nll.view(batch_size, -1)
|
||||
return nll, x_lengths
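Step 1 of `nll` can be traced on plain lists. The sketch below mirrors the two `F.pad` calls and the `<eos>` write without needing PyTorch. `make_lm_pair` is a hypothetical helper and the token ids 5 to 9 are made up. It assumes `text` is already padded with `ignore_id`.

```python
from typing import List, Tuple

SOS, EOS, IGNORE_ID = 1, 2, 0  # same values as in LanguageModel.__init__


def make_lm_pair(
    text: List[List[int]], lengths: List[int]
) -> Tuple[List[List[int]], List[List[int]], List[int]]:
    """Build the '<sos> w1 w2 w3' inputs and 'w1 w2 w3 <eos>' targets."""
    max_len = max(lengths)
    xs, ts = [], []
    for row, n in zip(text, lengths):
        row = list(row[:max_len])
        x = [SOS] + row        # F.pad(text, [1, 0], "constant", self.sos)
        t = row + [IGNORE_ID]  # F.pad(text, [0, 1], "constant", self.ignore_id)
        t[n] = EOS             # t[i, l] = self.eos
        xs.append(x)
        ts.append(t)
    return xs, ts, [n + 1 for n in lengths]


# Two sequences of length 3 and 2; the second one is padded with ignore_id.
xs, ts, x_lengths = make_lm_pair([[5, 6, 7], [8, 9, 0]], [3, 2])
```

Both outputs are one token longer than the input, which is why the method returns `x_lengths = text_lengths + 1` and why the padding mask is built with `max_length + 1`.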
|
||||
|
||||
def batchify_nll(
|
||||
self, text: torch.Tensor, text_lengths: torch.Tensor, batch_size: int = 100
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute negative log likelihood(nll) from transformer language model
|
||||
|
||||
To avoid OOM, this function separates the input into batches,
then calls nll for each batch and concatenates the results.
|
||||
Args:
|
||||
text: (Batch, Length)
|
||||
text_lengths: (Batch,)
|
||||
batch_size: int, the number of samples in each batch when computing nll;
decrease it to avoid OOM, or increase it if memory allows
|
||||
|
||||
"""
|
||||
total_num = text.size(0)
|
||||
if total_num <= batch_size:
|
||||
nll, x_lengths = self.nll(text, text_lengths)
|
||||
else:
|
||||
nlls = []
|
||||
x_lengths = []
|
||||
max_length = text_lengths.max()
|
||||
|
||||
start_idx = 0
|
||||
while True:
|
||||
end_idx = min(start_idx + batch_size, total_num)
|
||||
batch_text = text[start_idx:end_idx, :]
|
||||
batch_text_lengths = text_lengths[start_idx:end_idx]
|
||||
# batch_nll: [B * T]
|
||||
batch_nll, batch_x_lengths = self.nll(
|
||||
batch_text, batch_text_lengths, max_length=max_length
|
||||
)
|
||||
nlls.append(batch_nll)
|
||||
x_lengths.append(batch_x_lengths)
|
||||
start_idx = end_idx
|
||||
if start_idx == total_num:
|
||||
break
|
||||
nll = torch.cat(nlls)
|
||||
x_lengths = torch.cat(x_lengths)
|
||||
assert nll.size(0) == total_num
|
||||
assert x_lengths.size(0) == total_num
|
||||
return nll, x_lengths
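The chunking done by the `while True` loop in `batchify_nll` amounts to walking the batch dimension in fixed-size slices. The sketch below shows just that index arithmetic. `batch_ranges` is a hypothetical helper, not part of FunASR.

```python
from typing import Iterator, Tuple


def batch_ranges(total_num: int, batch_size: int) -> Iterator[Tuple[int, int]]:
    """Yield (start_idx, end_idx) slices that cover range(total_num) in order."""
    start_idx = 0
    while start_idx < total_num:
        end_idx = min(start_idx + batch_size, total_num)
        yield start_idx, end_idx
        start_idx = end_idx


ranges = list(batch_ranges(250, 100))
single = list(batch_ranges(80, 100))
```

In the real method each slice is passed to `nll` with the same `max_length`, taken from the whole input. That keeps every chunk's result the same width, so `torch.cat(nlls)` can stack them along the batch dimension.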
|
||||
|
||||
def forward(
|
||||
self, text: torch.Tensor, text_lengths: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
nll, y_lengths = self.nll(text, text_lengths)
|
||||
ntokens = y_lengths.sum()
|
||||
loss = nll.sum() / ntokens
|
||||
stats = dict(loss=loss.detach())
|
||||
|
||||
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
||||
loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
|
||||
return loss, stats, weight
|
||||
|
||||
def collect_feats(
|
||||
self, text: torch.Tensor, text_lengths: torch.Tensor
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
return {}
|
||||
|
||||
|
||||
class PunctuationModel(FunASRModel):
|
||||
|
||||
|
||||