#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import logging
import os
from collections import OrderedDict
from pathlib import Path
from typing import Any
from typing import Optional
from typing import Union

import numpy as np
import torch
from scipy.ndimage import median_filter
from torch.nn import functional as F

from funasr.models.frontend.wav_frontend import WavFrontendMel23
from funasr.tasks.diar import DiarTask
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.torch_utils.device_funcs import to_device
from funasr.utils.misc import statistic_model_parameters


class Speech2DiarizationEEND:
    """Speech2Diarization class (EEND-OLA)

    Examples:
        >>> import soundfile
        >>> speech2diar = Speech2DiarizationEEND("diar_eend_config.yml", "diar_eend.pb")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> speech2diar(audio)

    """

    def __init__(
        self,
        diar_train_config: Union[Path, str] = None,
        diar_model_file: Union[Path, str] = None,
        device: str = "cpu",
        dtype: str = "float32",
    ):
        # 1. Build the diarization model
        diar_model, diar_train_args = build_model_from_file(
            config_file=diar_train_config,
            model_file=diar_model_file,
            device=device,
            task_name="diar",
            mode="eend-ola",
        )
        frontend = None
        if diar_train_args.frontend is not None and diar_train_args.frontend_conf is not None:
            frontend = WavFrontendMel23(**diar_train_args.frontend_conf)

        # Fix the random seeds for EDA (encoder-decoder based attractors)
        # so that attractor estimation is reproducible across runs
        np.random.seed(diar_train_args.seed)
        torch.manual_seed(diar_train_args.seed)
        torch.cuda.manual_seed(diar_train_args.seed)
        os.environ['PYTORCH_SEED'] = str(diar_train_args.seed)
        logging.info("diar_model: {}".format(diar_model))
        logging.info("diar_train_args: {}".format(diar_train_args))
        diar_model.to(dtype=getattr(torch, dtype)).eval()

        self.diar_model = diar_model
        self.diar_train_args = diar_train_args
        self.device = device
        self.dtype = dtype
        self.frontend = frontend

    @torch.no_grad()
    def __call__(
        self,
        speech: Union[torch.Tensor, np.ndarray],
        speech_lengths: Union[torch.Tensor, np.ndarray] = None,
    ):
        """Inference

        Args:
            speech: Input speech data
            speech_lengths: Length of each input utterance
        Returns:
            Diarization results
        """
        # Accept a raw audio signal as a numpy array
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)

        if self.frontend is not None:
            # Extract features outside the model and disable the model's
            # internal frontend so features are not computed twice
            feats, feats_len = self.frontend.forward(speech, speech_lengths)
            feats = to_device(feats, device=self.device)
            feats_len = feats_len.int()
            self.diar_model.frontend = None
        else:
            feats = speech
            feats_len = speech_lengths
        batch = {"speech": feats, "speech_lengths": feats_len}
        batch = to_device(batch, device=self.device)
        results = self.diar_model.estimate_sequential(**batch)

        return results

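# A minimal usage sketch for Speech2DiarizationEEND (the file names below are
# hypothetical placeholders). When the model config defines a frontend,
# ``speech_lengths`` must be provided so per-utterance feature lengths can be
# computed:
#
#     import soundfile
#     speech2diar = Speech2DiarizationEEND("diar_eend_config.yml", "diar_eend.pb")
#     audio, rate = soundfile.read("speech.wav")
#     speech = torch.tensor(audio, dtype=torch.float32).unsqueeze(0)
#     speech_lengths = torch.tensor([speech.shape[1]], dtype=torch.long)
#     results = speech2diar(speech, speech_lengths)

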
class Speech2DiarizationSOND:
    """Speech2DiarizationSOND class

    Examples:
        >>> import soundfile
        >>> import numpy as np
        >>> speech2diar = Speech2DiarizationSOND("diar_sond_config.yml", "diar_sond.pb")
        >>> profile = np.load("profiles.npy")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> speech2diar(audio, profile)
        ({"spk1": [(int, int), ...], ...}, pse_labels)

    """

    def __init__(
        self,
        diar_train_config: Union[Path, str] = None,
        diar_model_file: Union[Path, str] = None,
        device: Union[str, torch.device] = "cpu",
        batch_size: int = 1,
        dtype: str = "float32",
        streaming: bool = False,
        smooth_size: int = 83,
        dur_threshold: float = 10,
    ):
        # 1. Build the diarization model
        diar_model, diar_train_args = build_model_from_file(
            config_file=diar_train_config,
            model_file=diar_model_file,
            device=device,
            task_name="diar",
            mode="sond",
        )
        logging.info("diar_model: {}".format(diar_model))
        logging.info("model parameter number: {}".format(statistic_model_parameters(diar_model)))
        logging.info("diar_train_args: {}".format(diar_train_args))
        diar_model.to(dtype=getattr(torch, dtype)).eval()

        self.diar_model = diar_model
        self.diar_train_args = diar_train_args
        self.token_list = diar_train_args.token_list
        self.smooth_size = smooth_size
        self.dur_threshold = dur_threshold
        self.device = device
        self.dtype = dtype

    def smooth_multi_labels(self, multi_label):
        # Median-filter each speaker's frame-level 0/1 activity along time
        # to suppress spurious single-frame flips
        multi_label = median_filter(multi_label, (self.smooth_size, 1), mode="constant", cval=0.0).astype(int)
        return multi_label

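    # Note: smooth_size is measured in frames. Assuming the common 10 ms frame
    # shift of the mel frontend, the default of 83 frames corresponds to a
    # median-filter window of roughly 0.83 seconds.
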
    @staticmethod
    def calc_spk_turns(label_arr, spk_list):
        # Convert a (frames, n_spk) binary activity matrix into a list of
        # [speaker, start frame, duration in frames] turns
        turn_list = []
        length = label_arr.shape[0]
        n_spk = label_arr.shape[1]
        for k in range(n_spk):
            if spk_list[k] == "None":
                continue
            in_utt = False
            start = 0
            for i in range(length):
                if label_arr[i, k] == 1 and in_utt is False:
                    start = i
                    in_utt = True
                if label_arr[i, k] == 0 and in_utt is True:
                    turn_list.append([spk_list[k], start, i - start])
                    in_utt = False
            if in_utt:
                turn_list.append([spk_list[k], start, length - start])
        return turn_list

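    # Worked example for calc_spk_turns: with spk_list = ["spk1"] and
    # label_arr = [[1], [1], [0], [1]], the method returns
    # [["spk1", 0, 2], ["spk1", 3, 1]], i.e. (speaker, start frame, duration).
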
    @staticmethod
    def seq2arr(seq, vec_dim=8):
        def int2vec(x, vec_dim=8, dtype=np.int32):
            b = ('{:0' + str(vec_dim) + 'b}').format(x)
            # little-endian order: lower bit first
            return (np.array(list(b)[::-1]) == '1').astype(dtype)

        # Map out-of-vocabulary tokens (>= 2 ** vec_dim) to the nearest
        # in-vocabulary frame, or to silence (0) if none exists
        seq = np.array([int(x) for x in seq])
        new_seq = []
        for i, x in enumerate(seq):
            if x < 2 ** vec_dim:
                new_seq.append(x)
            else:
                idx_list = np.where(seq < 2 ** vec_dim)[0]
                if len(idx_list) > 0:
                    idx = np.abs(idx_list - i).argmin()
                    new_seq.append(seq[idx_list[idx]])
                else:
                    new_seq.append(0)
        # np.vstack replaces the deprecated np.row_stack (removed in NumPy 2.0)
        return np.vstack([int2vec(x, vec_dim) for x in new_seq])

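    # Worked example for the power-set decoding above: token "6" with
    # vec_dim=3 formats to the bit string '110'; reversed to little-endian
    # (lowest bit first) it becomes [0, 1, 1], meaning spk2 and spk3 are
    # active in that frame while spk1 is silent.
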
    def post_processing(self, raw_logits: torch.Tensor, spk_num: int, output_format: str = "speaker_turn"):
        logits_idx = raw_logits.argmax(-1)  # B, T, vocab_size -> B, T
        # Upsample the outputs to match the frame rate of the inputs
        ut = logits_idx.shape[1] * self.diar_model.encoder.time_ds_ratio
        # F.interpolate replaces the deprecated F.upsample
        logits_idx = F.interpolate(
            logits_idx.unsqueeze(1).float(),
            size=(ut,),
            mode="nearest",
        ).squeeze(1).long()
        logits_idx = logits_idx[0].tolist()
        pse_labels = [self.token_list[x] for x in logits_idx]
        if output_format == "pse_labels":
            return pse_labels, None

        multi_labels = self.seq2arr(pse_labels, spk_num)[:, :spk_num]  # remove padding speakers
        multi_labels = self.smooth_multi_labels(multi_labels)
        if output_format == "binary_labels":
            return multi_labels, None

        spk_list = ["spk{}".format(i + 1) for i in range(spk_num)]
        spk_turns = self.calc_spk_turns(multi_labels, spk_list)
        results = OrderedDict()
        for spk, st, dur in spk_turns:
            if spk not in results:
                results[spk] = []
            if dur > self.dur_threshold:
                results[spk].append((st, st + dur))

        # sort segments by ascending start time
        for spk in results:
            results[spk] = sorted(results[spk], key=lambda x: x[0])

        return results, pse_labels

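    # post_processing supports three output formats:
    #   "pse_labels"    -> the upsampled power-set-encoded token sequence
    #   "binary_labels" -> a (frames, spk_num) 0/1 speaker-activity matrix
    #   "speaker_turn"  -> {"spk1": [(start, end), ...], ...} in frame indices
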
    @torch.no_grad()
    def __call__(
        self,
        speech: Union[torch.Tensor, np.ndarray],
        profile: Union[torch.Tensor, np.ndarray],
        output_format: str = "speaker_turn",
    ):
        """Inference

        Args:
            speech: Input speech data
            profile: Speaker profiles (speaker embeddings)
        Returns:
            Diarization results for each speaker
        """
        # Accept raw audio and profiles as numpy arrays
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        if isinstance(profile, np.ndarray):
            profile = torch.tensor(profile)

        # data: (Nsamples,) -> (1, Nsamples)
        speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
        profile = profile.unsqueeze(0).to(getattr(torch, self.dtype))
        # lengths: (1,)
        speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
        profile_lengths = profile.new_full([1], dtype=torch.long, fill_value=profile.size(1))
        batch = {"speech": speech, "speech_lengths": speech_lengths,
                 "profile": profile, "profile_lengths": profile_lengths}
        # a. To device
        batch = to_device(batch, device=self.device)

        logits = self.diar_model.prediction_forward(**batch)
        results, pse_labels = self.post_processing(logits, profile.shape[1], output_format)

        return results, pse_labels
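

# A minimal SOND usage sketch, mirroring the class docstring above; the file
# names are hypothetical placeholders for a trained SOND model directory.
if __name__ == "__main__":
    import soundfile

    speech2diar = Speech2DiarizationSOND(
        diar_train_config="diar_sond_config.yml",
        diar_model_file="diar_sond.pb",
        device="cpu",
    )
    profile = np.load("profiles.npy")  # (num_speakers, embedding_dim)
    audio, rate = soundfile.read("speech.wav")
    results, pse_labels = speech2diar(audio, profile)
    print(results)  # {"spk1": [(start_frame, end_frame), ...], ...}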