mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update ola
This commit is contained in:
parent
b6126fd539
commit
229efa6250
@ -26,6 +26,13 @@ else:
|
|||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
def pad_attractor(att, max_n_speakers):
|
||||||
|
C, D = att.shape
|
||||||
|
if C < max_n_speakers:
|
||||||
|
att = torch.cat([att, torch.zeros(max_n_speakers - C, D).to(torch.float32).to(att.device)], dim=0)
|
||||||
|
return att
|
||||||
|
|
||||||
|
|
||||||
class DiarEENDOLAModel(AbsESPnetModel):
|
class DiarEENDOLAModel(AbsESPnetModel):
|
||||||
"""CTC-attention hybrid Encoder-Decoder model"""
|
"""CTC-attention hybrid Encoder-Decoder model"""
|
||||||
|
|
||||||
@ -53,6 +60,26 @@ class DiarEENDOLAModel(AbsESPnetModel):
|
|||||||
self.PostNet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
|
self.PostNet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
|
||||||
self.output_layer = nn.Linear(n_units, mapping_dict['oov'] + 1)
|
self.output_layer = nn.Linear(n_units, mapping_dict['oov'] + 1)
|
||||||
|
|
||||||
|
def forward_encoder(self, xs, ilens):
|
||||||
|
xs = nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=-1)
|
||||||
|
pad_shape = xs.shape
|
||||||
|
xs_mask = [torch.ones(ilen).to(xs.device) for ilen in ilens]
|
||||||
|
xs_mask = torch.nn.utils.rnn.pad_sequence(xs_mask, batch_first=True, padding_value=0).unsqueeze(-2)
|
||||||
|
emb = self.encoder(xs, xs_mask)
|
||||||
|
emb = torch.split(emb.view(pad_shape[0], pad_shape[1], -1), 1, dim=0)
|
||||||
|
emb = [e[0][:ilen] for e, ilen in zip(emb, ilens)]
|
||||||
|
return emb
|
||||||
|
|
||||||
|
def forward_post_net(self, logits, ilens):
|
||||||
|
maxlen = torch.max(ilens).to(torch.int).item()
|
||||||
|
logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
|
||||||
|
logits = nn.utils.rnn.pack_padded_sequence(logits, ilens, batch_first=True, enforce_sorted=False)
|
||||||
|
outputs, (_, _) = self.PostNet(logits)
|
||||||
|
outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0]
|
||||||
|
outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)]
|
||||||
|
outputs = [self.output_layer(output) for output in outputs]
|
||||||
|
return outputs
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
speech: torch.Tensor,
|
speech: torch.Tensor,
|
||||||
@ -156,51 +183,45 @@ class DiarEENDOLAModel(AbsESPnetModel):
|
|||||||
def estimate_sequential(self,
|
def estimate_sequential(self,
|
||||||
speech: torch.Tensor,
|
speech: torch.Tensor,
|
||||||
speech_lengths: torch.Tensor,
|
speech_lengths: torch.Tensor,
|
||||||
n_speakers: int,
|
n_speakers: int = None,
|
||||||
shuffle: bool,
|
shuffle: bool = True,
|
||||||
threshold: float,
|
threshold: float = 0.5,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
|
speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
|
||||||
emb = self.forward_core(speech) # list, [(T1, C1), ..., (T1, C1)]
|
emb = self.forward_encoder(speech, speech_lengths)
|
||||||
if shuffle:
|
if shuffle:
|
||||||
orders = [np.arange(e.shape[0]) for e in emb]
|
orders = [np.arange(e.shape[0]) for e in emb]
|
||||||
for order in orders:
|
for order in orders:
|
||||||
np.random.shuffle(order)
|
np.random.shuffle(order)
|
||||||
# e[order]: shuffle后的embeddings, list, [(T1, C1), ..., (T1, C1)] 每个sample的T维度已进行随机顺序交换
|
|
||||||
# attractors, list, hts(论文里的as), [(max_n_speakers, n_units), ..., (max_n_speakers, n_units)]
|
|
||||||
# probs, list, [(max_n_speakers, ), ..., (max_n_speakers, ]
|
|
||||||
attractors, probs = self.eda.estimate(
|
attractors, probs = self.eda.estimate(
|
||||||
[e[torch.from_numpy(order).to(torch.long).to(xs[0].device)] for e, order in zip(emb, orders)])
|
[e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)])
|
||||||
else:
|
else:
|
||||||
attractors, probs = self.eda.estimate(emb)
|
attractors, probs = self.eda.estimate(emb)
|
||||||
attractors_active = []
|
attractors_active = []
|
||||||
for p, att, e in zip(probs, attractors, emb):
|
for p, att, e in zip(probs, attractors, emb):
|
||||||
if n_speakers and n_speakers >= 0: # 根据指定说话人数, 选择对应数量的ys
|
if n_speakers and n_speakers >= 0:
|
||||||
# TODO:在测试有不同数量speaker数的数据集时,考虑改成根据sample来确定具体的speaker数,而不是直接指定
|
|
||||||
# raise NotImplementedError
|
|
||||||
att = att[:n_speakers, ]
|
att = att[:n_speakers, ]
|
||||||
attractors_active.append(att)
|
attractors_active.append(att)
|
||||||
elif threshold is not None:
|
elif threshold is not None:
|
||||||
silence = torch.nonzero(p < threshold)[0] # 找到第一个输出概率小于阈值的索引, 作为结束, 且值刚好等于说话人数
|
silence = torch.nonzero(p < threshold)[0]
|
||||||
n_spk = silence[0] if silence.size else None
|
n_spk = silence[0] if silence.size else None
|
||||||
att = att[:n_spk, ]
|
att = att[:n_spk, ]
|
||||||
attractors_active.append(att)
|
attractors_active.append(att)
|
||||||
else:
|
else:
|
||||||
NotImplementedError('n_speakers or th has to be given.')
|
NotImplementedError('n_speakers or threshold has to be given.')
|
||||||
raw_n_speakers = [att.shape[0] for att in attractors_active] # [C1, C2, ..., CB]
|
raw_n_speakers = [att.shape[0] for att in attractors_active]
|
||||||
attractors = [
|
attractors = [
|
||||||
pad_attractor(att, self.max_n_speaker) if att.shape[0] <= self.max_n_speaker else att[:self.max_n_speaker]
|
pad_attractor(att, self.max_n_speaker) if att.shape[0] <= self.max_n_speaker else att[:self.max_n_speaker]
|
||||||
for att in attractors_active]
|
for att in attractors_active]
|
||||||
ys = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(emb, attractors)]
|
ys = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(emb, attractors)]
|
||||||
# ys_eda = [torch.sigmoid(y[:, :n_spk]) for y,n_spk in zip(ys, raw_n_speakers)]
|
logits = self.forward_post_net(ys, speech_lengths)
|
||||||
logits = self.cal_postnet(ys, self.max_n_speaker)
|
|
||||||
ys = [self.recover_y_from_powerlabel(logit, raw_n_speaker) for logit, raw_n_speaker in
|
ys = [self.recover_y_from_powerlabel(logit, raw_n_speaker) for logit, raw_n_speaker in
|
||||||
zip(logits, raw_n_speakers)]
|
zip(logits, raw_n_speakers)]
|
||||||
|
|
||||||
return ys, emb, attractors, raw_n_speakers
|
return ys, emb, attractors, raw_n_speakers
|
||||||
|
|
||||||
def recover_y_from_powerlabel(self, logit, n_speaker):
|
def recover_y_from_powerlabel(self, logit, n_speaker):
|
||||||
pred = torch.argmax(torch.softmax(logit, dim=-1), dim=-1) # (T, )
|
pred = torch.argmax(torch.softmax(logit, dim=-1), dim=-1)
|
||||||
oov_index = torch.where(pred == self.mapping_dict['oov'])[0]
|
oov_index = torch.where(pred == self.mapping_dict['oov'])[0]
|
||||||
for i in oov_index:
|
for i in oov_index:
|
||||||
if i > 0:
|
if i > 0:
|
||||||
@ -208,7 +229,6 @@ class DiarEENDOLAModel(AbsESPnetModel):
|
|||||||
else:
|
else:
|
||||||
pred[i] = 0
|
pred[i] = 0
|
||||||
pred = [self.reporter.inv_mapping_func(i, self.mapping_dict) for i in pred]
|
pred = [self.reporter.inv_mapping_func(i, self.mapping_dict) for i in pred]
|
||||||
# print(pred)
|
|
||||||
decisions = [bin(num)[2:].zfill(self.max_n_speaker)[::-1] for num in pred]
|
decisions = [bin(num)[2:].zfill(self.max_n_speaker)[::-1] for num in pred]
|
||||||
decisions = torch.from_numpy(
|
decisions = torch.from_numpy(
|
||||||
np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)).to(logit.device).to(
|
np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)).to(logit.device).to(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user