diff --git a/funasr/models/e2e_diar_eend_ola.py b/funasr/models/e2e_diar_eend_ola.py
index 79cb61496..097b23a57 100644
--- a/funasr/models/e2e_diar_eend_ola.py
+++ b/funasr/models/e2e_diar_eend_ola.py
@@ -76,7 +76,7 @@ class DiarEENDOLAModel(AbsESPnetModel):
     def forward_post_net(self, logits, ilens):
         maxlen = torch.max(ilens).to(torch.int).item()
         logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
-        logits = nn.utils.rnn.pack_padded_sequence(logits, ilens, batch_first=True, enforce_sorted=False)
+        logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu().to(torch.int64), batch_first=True, enforce_sorted=False)
         outputs, (_, _) = self.postnet(logits)
         outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0]
         outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)]
@@ -231,7 +231,7 @@ class DiarEENDOLAModel(AbsESPnetModel):
                 pred[i] = pred[i - 1]
             else:
                 pred[i] = 0
-        pred = [self.reporter.inv_mapping_func(i, self.mapping_dict) for i in pred]
+        pred = [self.inv_mapping_func(i) for i in pred]
         decisions = [bin(num)[2:].zfill(self.max_n_speaker)[::-1] for num in pred]
         decisions = torch.from_numpy(
             np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)).to(logit.device).to(
@@ -239,5 +239,15 @@ class DiarEENDOLAModel(AbsESPnetModel):
         decisions = decisions[:, :n_speaker]
         return decisions
 
+    def inv_mapping_func(self, label):
+
+        if not isinstance(label, int):
+            label = int(label)
+        if label in self.mapping_dict['label2dec'].keys():
+            num = self.mapping_dict['label2dec'][label]
+        else:
+            num = -1
+        return num
+
     def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
         pass
\ No newline at end of file
diff --git a/funasr/modules/eend_ola/encoder_decoder_attractor.py b/funasr/modules/eend_ola/encoder_decoder_attractor.py
index 4e599ab31..45ac98219 100644
--- a/funasr/modules/eend_ola/encoder_decoder_attractor.py
+++ b/funasr/modules/eend_ola/encoder_decoder_attractor.py
@@ -2,8 +2,7 @@ import numpy as np
 import torch
 import torch.nn.functional as F
 from torch import nn
-from modelscope.utils.logger import get_logger
-logger = get_logger()
+
 
 
 class EncoderDecoderAttractor(nn.Module):
@@ -17,14 +16,12 @@ class EncoderDecoderAttractor(nn.Module):
         self.n_units = n_units
 
     def forward_core(self, xs, zeros):
-        logger.info("xs: ".format(xs))
-        ilens = torch.from_numpy(np.array([x.shape[0] for x in xs])).to(torch.float32).to(xs[0].device)
-        logger.info("ilens: ".format(ilens))
+        ilens = torch.from_numpy(np.array([x.shape[0] for x in xs])).to(torch.int64)
         xs = [self.enc0_dropout(x) for x in xs]
         xs = nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=-1)
         xs = nn.utils.rnn.pack_padded_sequence(xs, ilens, batch_first=True, enforce_sorted=False)
         _, (hx, cx) = self.encoder(xs)
-        zlens = torch.from_numpy(np.array([z.shape[0] for z in zeros])).to(torch.float32).to(zeros[0].device)
+        zlens = torch.from_numpy(np.array([z.shape[0] for z in zeros])).to(torch.int64)
         max_zlen = torch.max(zlens).to(torch.int).item()
         zeros = [self.enc0_dropout(z) for z in zeros]
         zeros = nn.utils.rnn.pad_sequence(zeros, batch_first=True, padding_value=-1)
@@ -50,4 +47,4 @@ class EncoderDecoderAttractor(nn.Module):
         zeros = [torch.zeros(max_n_speakers, self.n_units).to(torch.float32).to(xs[0].device) for _ in xs]
         attractors = self.forward_core(xs, zeros)
         probs = [torch.sigmoid(torch.flatten(self.counter(att))) for att in attractors]
-        return attractors, probs
\ No newline at end of file
+        return attractors, probs
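
Both files apply the same underlying fix: when `pack_padded_sequence` receives its lengths as a tensor, recent PyTorch releases require a 1D CPU int64 tensor, and the old float32, on-device lengths trip that check at inference time. A minimal standalone sketch of the pattern (sequence shapes invented for illustration, not taken from the patch):

```python
# Sketch of the lengths fix above; the sequences and sizes are made up.
import torch
from torch import nn

seqs = [torch.randn(5, 8), torch.randn(3, 8)]  # variable-length feature sequences
# The old code built lengths as float32 and moved them to the input device;
# on GPU runs, newer PyTorch rejects that with:
#   RuntimeError: 'lengths' argument should be a 1D CPU int64 tensor
ilens = torch.tensor([s.shape[0] for s in seqs], dtype=torch.float32)

padded = nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=-1)
packed = nn.utils.rnn.pack_padded_sequence(
    padded,
    ilens.cpu().to(torch.int64),  # the fix: lengths as CPU int64, wherever `padded` lives
    batch_first=True,
    enforce_sorted=False,
)
```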
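The relocated `inv_mapping_func` feeds the power-set decoding a few lines above it: each decoded integer's binary representation marks which speakers are active, least-significant bit first, and labels missing from `label2dec` fall back to -1. A toy walk-through with invented values:

```python
# Invented label2dec contents and speaker cap, purely to trace the decoding.
mapping_dict = {'label2dec': {0: 0, 1: 1, 2: 3, 3: 5}}
max_n_speaker = 4

num = mapping_dict['label2dec'][3]         # 5, i.e. binary 101
bits = bin(num)[2:].zfill(max_n_speaker)   # '0101'
decision = [int(b) for b in bits[::-1]]    # [1, 0, 1, 0] -> speakers 0 and 2 active
```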