Mirror of https://github.com/modelscope/FunASR (synced 2025-09-15 14:48:36 +08:00), commit 7acfa5efd9
@@ -54,7 +54,7 @@ class Speech2Diarization:
         self,
         diar_train_config: Union[Path, str] = None,
         diar_model_file: Union[Path, str] = None,
-        device: str = "cpu",
+        device: Union[str, torch.device] = "cpu",
         batch_size: int = 1,
         dtype: str = "float32",
         streaming: bool = False,
@@ -114,9 +114,19 @@ class Speech2Diarization:
             # little-endian order: lower bit first
             return (np.array(list(b)[::-1]) == '1').astype(dtype)
 
-        return np.row_stack([int2vec(int(x), vec_dim) for x in seq])
+        # process oov
+        seq = np.array([int(x) for x in seq])
+        new_seq = []
+        for i, x in enumerate(seq):
+            if x < 2 ** vec_dim:
+                new_seq.append(x)
+            else:
+                idx_list = np.where(seq < 2 ** vec_dim)[0]
+                idx = np.abs(idx_list - i).argmin()
+                new_seq.append(seq[idx_list[idx]])
+        return np.row_stack([int2vec(x, vec_dim) for x in new_seq])
 
-    def post_processing(self, raw_logits: torch.Tensor, spk_num: int):
+    def post_processing(self, raw_logits: torch.Tensor, spk_num: int, output_format: str = "speaker_turn"):
         logits_idx = raw_logits.argmax(-1) # B, T, vocab_size -> B, T
         # upsampling outputs to match inputs
         ut = logits_idx.shape[1] * self.diar_model.encoder.time_ds_ratio
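
Note: a standalone sketch of the power-set decoding this hunk extends. Each integer label is read as a little-endian bit mask of active speakers, and the new branch maps out-of-vocabulary labels to the nearest in-range frame. Function names, the vec_dim default and the sample values are illustrative assumptions, not part of the diff.

    import numpy as np

    def int2vec(x, vec_dim=8, dtype=np.int32):
        # little-endian: bit k of x marks speaker k+1 as active
        bits = format(int(x), "0{}b".format(vec_dim))
        return (np.array(list(bits)[::-1]) == "1").astype(dtype)

    def seq2arr(seq, vec_dim=8):
        seq = np.array([int(x) for x in seq])
        new_seq = []
        for i, x in enumerate(seq):
            if x < 2 ** vec_dim:
                new_seq.append(x)
            else:
                # OOV label: copy the closest frame whose label is in range
                idx_list = np.where(seq < 2 ** vec_dim)[0]
                idx = np.abs(idx_list - i).argmin()
                new_seq.append(seq[idx_list[idx]])
        return np.vstack([int2vec(x, vec_dim) for x in new_seq])

    print(int2vec(5, vec_dim=4))                  # [1 0 1 0] -> speakers 1 and 3 active
    print(seq2arr(["3", "999", "1"], vec_dim=4))  # 999 is OOV, borrows frame 0's label
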
@@ -127,8 +137,14 @@ class Speech2Diarization:
         ).squeeze(1).long()
         logits_idx = logits_idx[0].tolist()
         pse_labels = [self.token_list[x] for x in logits_idx]
+        if output_format == "pse_labels":
+            return pse_labels, None
+
         multi_labels = self.seq2arr(pse_labels, spk_num)[:, :spk_num] # remove padding speakers
         multi_labels = self.smooth_multi_labels(multi_labels)
+        if output_format == "binary_labels":
+            return multi_labels, None
+
         spk_list = ["spk{}".format(i + 1) for i in range(spk_num)]
         spk_turns = self.calc_spk_turns(multi_labels, spk_list)
         results = OrderedDict()
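
Note: calc_spk_turns itself is not shown in this diff. For intuition, a rough standalone sketch of turning the per-frame binary matrix into (speaker, start, end) segments; the function name and shapes are assumptions.

    import numpy as np

    def binary_to_turns(multi_labels, spk_list):
        # multi_labels: (frames, n_spk) 0/1 activity matrix
        turns = []
        for k, spk in enumerate(spk_list):
            active = np.concatenate([[0], multi_labels[:, k], [0]])
            change = np.diff(active)
            starts = np.where(change == 1)[0]
            ends = np.where(change == -1)[0]
            for s, e in zip(starts, ends):
                turns.append((spk, int(s), int(e)))  # start/end frame indices
        return turns

    labels = np.array([[1, 0], [1, 0], [1, 1], [0, 1]])
    print(binary_to_turns(labels, ["spk1", "spk2"]))
    # [('spk1', 0, 3), ('spk2', 2, 4)]
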
@@ -149,6 +165,7 @@ class Speech2Diarization:
         self,
         speech: Union[torch.Tensor, np.ndarray],
         profile: Union[torch.Tensor, np.ndarray],
+        output_format: str = "speaker_turn"
     ):
         """Inference
 
@@ -178,7 +195,7 @@ class Speech2Diarization:
         batch = to_device(batch, device=self.device)
 
         logits = self.diar_model.prediction_forward(**batch)
-        results, pse_labels = self.post_processing(logits, profile.shape[1])
+        results, pse_labels = self.post_processing(logits, profile.shape[1], output_format)
 
         return results, pse_labels
 
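
Note: taken together, the hunks above thread output_format from __call__ down to post_processing. A hedged usage sketch; the import path, model paths and tensor shapes are assumptions and are not shown in this diff.

    # from funasr.bin.diar_inference import Speech2Diarization   # import path assumed
    # p = Speech2Diarization(diar_train_config="diar.yaml",
    #                        diar_model_file="diar.pb", device="cuda:0")

    def run_all_formats(p, speech, profile):
        # speech: waveform tensor, profile: (1, n_spk, dim) speaker embeddings (assumed shapes)
        turns, _ = p(speech, profile)                                   # default "speaker_turn"
        pse, _ = p(speech, profile, output_format="pse_labels")         # raw power-set tokens
        binary, _ = p(speech, profile, output_format="binary_labels")   # (frames, n_spk) 0/1 matrix
        return turns, pse, binary
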
@@ -367,7 +384,7 @@ def inference_modelscope(
             pse_label_writer = open("{}/labels.txt".format(output_path), "w")
         logging.info("Start to diarize...")
         result_list = []
-        for keys, batch in loader:
+        for idx, (keys, batch) in enumerate(loader):
             assert isinstance(batch, dict), type(batch)
             assert all(isinstance(s, str) for s in keys), keys
             _bs = len(next(iter(batch.values())))
@@ -385,6 +402,9 @@ def inference_modelscope(
                 pse_label_writer.write("{} {}\n".format(key, " ".join(pse_labels)))
                 pse_label_writer.flush()
 
+            if idx % 100 == 0:
+                logging.info("Processing {:5d}: {}".format(idx, key))
+
         if output_path is not None:
             output_writer.close()
             pse_label_writer.close()

@@ -8,6 +8,7 @@ from typing import Dict
 from typing import Iterator
 from typing import Tuple
 from typing import Union
+from typing import List
 
 import kaldiio
 import numpy as np
@@ -129,7 +130,7 @@ class IterableESPnetDataset(IterableDataset):
         non_iterable_list = []
         self.path_name_type_list = []
 
-        if not isinstance(path_name_type_list[0], Tuple):
+        if not isinstance(path_name_type_list[0], (Tuple, List)):
             path = path_name_type_list[0]
             name = path_name_type_list[1]
             _type = path_name_type_list[2]

@@ -59,7 +59,8 @@ class DiarSondModel(AbsESPnetModel):
         normalize_speech_speaker: bool = False,
         ignore_id: int = -1,
         speaker_discrimination_loss_weight: float = 1.0,
-        inter_score_loss_weight: float = 0.0
+        inter_score_loss_weight: float = 0.0,
+        inputs_type: str = "raw",
     ):
         assert check_argument_types()
 
@@ -86,14 +87,12 @@ class DiarSondModel(AbsESPnetModel):
         )
         self.criterion_bce = SequenceBinaryCrossEntropy(normalize_length=length_normalized_loss)
         self.pse_embedding = self.generate_pse_embedding()
-        # self.register_buffer("pse_embedding", pse_embedding)
         self.power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float()
-        # self.register_buffer("power_weight", power_weight)
         self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int()
-        # self.register_buffer("int_token_arr", int_token_arr)
         self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight
         self.inter_score_loss_weight = inter_score_loss_weight
         self.forward_steps = 0
+        self.inputs_type = inputs_type
 
     def generate_pse_embedding(self):
         embedding = np.zeros((len(self.token_list), self.max_spk_num), dtype=np.float)
@@ -125,9 +124,14 @@ class DiarSondModel(AbsESPnetModel):
             binary_labels: (Batch, frames, max_spk_num)
             binary_labels_lengths: (Batch,)
         """
-        assert speech.shape[0] == binary_labels.shape[0], (speech.shape, binary_labels.shape)
+        assert speech.shape[0] <= binary_labels.shape[0], (speech.shape, binary_labels.shape)
         batch_size = speech.shape[0]
         self.forward_steps = self.forward_steps + 1
+        if self.pse_embedding.device != speech.device:
+            self.pse_embedding = self.pse_embedding.to(speech.device)
+            self.power_weight = self.power_weight.to(speech.device)
+            self.int_token_arr = self.int_token_arr.to(speech.device)
 
         # 1. Network forward
         pred, inter_outputs = self.prediction_forward(
             speech, speech_lengths,
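
Note: with the register_buffer calls removed in the __init__ hunk above, these constants are plain tensors, so the forward pass now moves them to the input's device on first use. A toy module illustrating the pattern (not the FunASR model itself):

    import torch

    class TinyModel(torch.nn.Module):
        def __init__(self, max_spk_num=4):
            super().__init__()
            # plain attribute: not a registered buffer, so model.to(device) will not move it
            self.power_weight = 2 ** torch.arange(max_spk_num).float()

        def forward(self, x):
            if self.power_weight.device != x.device:
                # lazily follow the input's device, as the hunk above does
                self.power_weight = self.power_weight.to(x.device)
            return (x * self.power_weight).sum(-1)

    m = TinyModel()
    print(m(torch.ones(2, 4)))  # tensor([15., 15.])

The tradeoff versus register_buffer is that such tensors are excluded from state_dict and from automatic .to()/.cuda() moves, which is why the explicit device check is needed.
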
@@ -149,9 +153,13 @@ class DiarSondModel(AbsESPnetModel):
         # the sequence length of 'pred' might be slightly less than the
         # length of 'spk_labels'. Here we force them to be equal.
         length_diff_tolerance = 2
-        length_diff = pse_labels.shape[1] - pred.shape[1]
-        if 0 < length_diff <= length_diff_tolerance:
-            pse_labels = pse_labels[:, 0: pred.shape[1]]
+        length_diff = abs(pse_labels.shape[1] - pred.shape[1])
+        if length_diff <= length_diff_tolerance:
+            min_len = min(pred.shape[1], pse_labels.shape[1])
+            pse_labels = pse_labels[:, :min_len]
+            pred = pred[:, :min_len]
+            cd_score = cd_score[:, :min_len]
+            ci_score = ci_score[:, :min_len]
 
         loss_diar = self.classification_loss(pred, pse_labels, binary_labels_lengths)
         loss_spk_dis = self.speaker_discrimination_loss(profile, profile_lengths)
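
Note: the reworked length reconciliation trims every tensor to the shortest time axis instead of only cutting the labels, keeping pred, pse_labels and both score tensors aligned. A minimal illustration with toy shapes (helper name is an assumption):

    import torch

    def align_time_axis(tensors, tolerance=2):
        # trim each tensor on dim=1 to the common minimum length, if the gap is small
        lengths = [t.shape[1] for t in tensors]
        if max(lengths) - min(lengths) > tolerance:
            raise ValueError("length mismatch {} exceeds tolerance".format(max(lengths) - min(lengths)))
        min_len = min(lengths)
        return [t[:, :min_len] for t in tensors]

    pred = torch.zeros(2, 99, 8)
    pse_labels = torch.zeros(2, 100, 8)
    pred, pse_labels = align_time_axis([pred, pse_labels])
    print(pred.shape, pse_labels.shape)  # torch.Size([2, 99, 8]) torch.Size([2, 99, 8])
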
@@ -299,7 +307,7 @@ class DiarSondModel(AbsESPnetModel):
         speech: torch.Tensor,
         speech_lengths: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if self.encoder is not None:
+        if self.encoder is not None and self.inputs_type == "raw":
             speech, speech_lengths = self.encode(speech, speech_lengths)
         speech_mask = ~make_pad_mask(speech_lengths, maxlen=speech.shape[1])
         speech_mask = speech_mask.to(speech.device).unsqueeze(-1).float()

@@ -507,7 +507,7 @@ class DiarTask(AbsTask):
         config_file: Union[Path, str] = None,
         model_file: Union[Path, str] = None,
         cmvn_file: Union[Path, str] = None,
-        device: str = "cpu",
+        device: Union[str, torch.device] = "cpu",
     ):
         """Build model from the files.
 
@@ -562,6 +562,7 @@ class DiarTask(AbsTask):
             model.load_state_dict(model_dict)
         else:
             model_dict = torch.load(model_file, map_location=device)
+            model_dict = cls.fileter_model_dict(model_dict, model.state_dict())
             model.load_state_dict(model_dict)
         if model_name_pth is not None and not os.path.exists(model_name_pth):
             torch.save(model_dict, model_name_pth)
@@ -569,6 +570,20 @@ class DiarTask(AbsTask):
 
         return model, args
 
+    @classmethod
+    def fileter_model_dict(cls, src_dict: dict, dest_dict: dict):
+        from collections import OrderedDict
+        new_dict = OrderedDict()
+        for key, value in src_dict.items():
+            if key in dest_dict:
+                new_dict[key] = value
+            else:
+                logging.info("{} is no longer needed in this model.".format(key))
+        for key, value in dest_dict.items():
+            if key not in new_dict:
+                logging.warning("{} is missed in checkpoint.".format(key))
+        return new_dict
+
     @classmethod
     def convert_tf2torch(
         cls,
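
Note: the new fileter_model_dict keeps only checkpoint entries the current model still declares and logs everything else, which makes older checkpoints loadable after architecture changes. A self-contained sketch of the same filtering idea (toy model, renamed helper):

    import logging
    from collections import OrderedDict

    import torch

    def filter_state_dict(src_dict, dest_dict):
        new_dict = OrderedDict()
        for key, value in src_dict.items():
            if key in dest_dict:
                new_dict[key] = value
            else:
                logging.info("%s is no longer needed in this model.", key)
        for key in dest_dict:
            if key not in new_dict:
                logging.warning("%s is missing in the checkpoint.", key)
        return new_dict

    model = torch.nn.Linear(2, 2)                        # state_dict keys: "weight", "bias"
    ckpt = {"weight": torch.zeros(2, 2), "obsolete.w": torch.zeros(1)}
    filtered = filter_state_dict(ckpt, model.state_dict())
    model.load_state_dict(filtered, strict=False)        # "bias" keeps its initialised value
    print(list(filtered))                                # ['weight']
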