From 4b30f336ee7e3ca405cfa6ff96d9b3c3e936f767 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Thu, 15 Jun 2023 15:03:21 +0800 Subject: [PATCH] update repo --- funasr/bin/diar_infer.py | 49 +++++--------- funasr/bin/diar_inference_launch.py | 67 +++++++------------ funasr/build_utils/build_model_from_file.py | 39 ++++++++++- .../build_utils/build_streaming_iterator.py | 5 +- 4 files changed, 82 insertions(+), 78 deletions(-) diff --git a/funasr/bin/diar_infer.py b/funasr/bin/diar_infer.py index 4460e3dd1..7c41b6031 100755 --- a/funasr/bin/diar_infer.py +++ b/funasr/bin/diar_infer.py @@ -1,41 +1,28 @@ -# -*- encoding: utf-8 -*- #!/usr/bin/env python3 +# -*- encoding: utf-8 -*- # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. # MIT License (https://opensource.org/licenses/MIT) -import argparse import logging import os -import sys +from collections import OrderedDict from pathlib import Path from typing import Any -from typing import List from typing import Optional -from typing import Sequence -from typing import Tuple from typing import Union -from collections import OrderedDict import numpy as np -import soundfile import torch +from scipy.ndimage import median_filter from torch.nn import functional as F from typeguard import check_argument_types -from typeguard import check_return_type -from funasr.utils.cli_utils import get_commandline_args -from funasr.tasks.diar import DiarTask -from funasr.tasks.diar import EENDOLADiarTask -from funasr.torch_utils.device_funcs import to_device -from funasr.torch_utils.set_all_random_seed import set_all_random_seed -from funasr.utils import config_argparse -from funasr.utils.types import str2bool -from funasr.utils.types import str2triple_str -from funasr.utils.types import str_or_none -from scipy.ndimage import median_filter -from funasr.utils.misc import statistic_model_parameters -from funasr.datasets.iterable_dataset import load_bytes from funasr.models.frontend.wav_frontend import WavFrontendMel23 +from funasr.tasks.diar import DiarTask +from funasr.build_utils.build_model_from_file import build_model_from_file +from funasr.torch_utils.device_funcs import to_device +from funasr.utils.misc import statistic_model_parameters + class Speech2DiarizationEEND: """Speech2Diarlization class @@ -61,10 +48,12 @@ class Speech2DiarizationEEND: assert check_argument_types() # 1. Build Diarization model - diar_model, diar_train_args = EENDOLADiarTask.build_model_from_file( + diar_model, diar_train_args = build_model_from_file( config_file=diar_train_config, model_file=diar_model_file, - device=device + device=device, + task_name="diar", + mode="eend-ola", ) frontend = None if diar_train_args.frontend is not None and diar_train_args.frontend_conf is not None: @@ -177,10 +166,12 @@ class Speech2DiarizationSOND: assert check_argument_types() # TODO: 1. Build Diarization model - diar_model, diar_train_args = DiarTask.build_model_from_file( + diar_model, diar_train_args = build_model_from_file( config_file=diar_train_config, model_file=diar_model_file, - device=device + device=device, + task_name="diar", + mode="sond", ) logging.info("diar_model: {}".format(diar_model)) logging.info("model parameter number: {}".format(statistic_model_parameters(diar_model))) @@ -248,7 +239,7 @@ class Speech2DiarizationSOND: ut = logits_idx.shape[1] * self.diar_model.encoder.time_ds_ratio logits_idx = F.upsample( logits_idx.unsqueeze(1).float(), - size=(ut, ), + size=(ut,), mode="nearest", ).squeeze(1).long() logits_idx = logits_idx[0].tolist() @@ -268,7 +259,7 @@ class Speech2DiarizationSOND: if spk not in results: results[spk] = [] if dur > self.dur_threshold: - results[spk].append((st, st+dur)) + results[spk].append((st, st + dur)) # sort segments in start time ascending for spk in results: @@ -344,7 +335,3 @@ class Speech2DiarizationSOND: kwargs.update(**d.download_and_unpack(model_tag)) return Speech2DiarizationSOND(**kwargs) - - - - diff --git a/funasr/bin/diar_inference_launch.py b/funasr/bin/diar_inference_launch.py index e0d900e76..820217b1a 100755 --- a/funasr/bin/diar_inference_launch.py +++ b/funasr/bin/diar_inference_launch.py @@ -1,5 +1,5 @@ +# !/usr/bin/env python3 # -*- encoding: utf-8 -*- -#!/usr/bin/env python3 # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. # MIT License (https://opensource.org/licenses/MIT) @@ -8,47 +8,28 @@ import argparse import logging import os import sys -from typing import Union, Dict, Any - -from funasr.utils import config_argparse -from funasr.utils.cli_utils import get_commandline_args -from funasr.utils.types import str2bool -from funasr.utils.types import str2triple_str -from funasr.utils.types import str_or_none - -import argparse -import logging -import os -import sys -from pathlib import Path -from typing import Any from typing import List from typing import Optional from typing import Sequence from typing import Tuple from typing import Union -from collections import OrderedDict import numpy as np import soundfile import torch -from torch.nn import functional as F -from typeguard import check_argument_types -from typeguard import check_return_type from scipy.signal import medfilt -from funasr.utils.cli_utils import get_commandline_args -from funasr.tasks.diar import DiarTask -from funasr.tasks.diar import EENDOLADiarTask -from funasr.torch_utils.device_funcs import to_device +from typeguard import check_argument_types + +from funasr.bin.diar_infer import Speech2DiarizationSOND, Speech2DiarizationEEND +from funasr.datasets.iterable_dataset import load_bytes +from funasr.build_utils.build_streaming_iterator import build_streaming_iterator from funasr.torch_utils.set_all_random_seed import set_all_random_seed from funasr.utils import config_argparse +from funasr.utils.cli_utils import get_commandline_args from funasr.utils.types import str2bool from funasr.utils.types import str2triple_str from funasr.utils.types import str_or_none -from scipy.ndimage import median_filter -from funasr.utils.misc import statistic_model_parameters -from funasr.datasets.iterable_dataset import load_bytes -from funasr.bin.diar_infer import Speech2DiarizationSOND, Speech2DiarizationEEND + def inference_sond( diar_train_config: str, @@ -94,7 +75,8 @@ def inference_sond( set_all_random_seed(seed) # 2a. Build speech2xvec [Optional] - if mode == "sond_demo" and param_dict is not None and "extract_profile" in param_dict and param_dict["extract_profile"]: + if mode == "sond_demo" and param_dict is not None and "extract_profile" in param_dict and param_dict[ + "extract_profile"]: assert "sv_train_config" in param_dict, "sv_train_config must be provided param_dict." assert "sv_model_file" in param_dict, "sv_model_file must be provided in param_dict." sv_train_config = param_dict["sv_train_config"] @@ -139,7 +121,7 @@ def inference_sond( rst = [] mid = uttid.rsplit("-", 1)[0] for key in results: - results[key] = [(x[0]/100, x[1]/100) for x in results[key]] + results[key] = [(x[0] / 100, x[1] / 100) for x in results[key]] if out_format == "vad": for spk, segs in results.items(): rst.append("{} {}".format(spk, segs)) @@ -176,7 +158,7 @@ def inference_sond( example = [x.numpy() if isinstance(example[0], torch.Tensor) else x for x in example] speech = example[0] - logging.info("Extracting profiles for {} waveforms".format(len(example)-1)) + logging.info("Extracting profiles for {} waveforms".format(len(example) - 1)) profile = [speech2xvector.calculate_embedding(x) for x in example[1:]] profile = torch.cat(profile, dim=0) yield ["test{}".format(idx)], {"speech": [speech], "profile": [profile]} @@ -186,16 +168,15 @@ def inference_sond( raise TypeError("raw_inputs must be a list or tuple in [speech, profile1, profile2, ...] ") else: # 3. Build data-iterator - loader = DiarTask.build_streaming_iterator( - data_path_and_name_and_type, + loader = build_streaming_iterator( + task_name="diar", + preprocess_args=None, + data_path_and_name_and_type=data_path_and_name_and_type, dtype=dtype, batch_size=batch_size, key_file=key_file, num_workers=num_workers, - preprocess_fn=None, - collate_fn=None, - allow_variable_data_keys=allow_variable_data_keys, - inference=True, + use_collate_fn=False, ) # 7. Start for-loop @@ -235,6 +216,7 @@ def inference_sond( return _forward + def inference_eend( diar_train_config: str, diar_model_file: str, @@ -306,16 +288,14 @@ def inference_eend( if isinstance(raw_inputs, torch.Tensor): raw_inputs = raw_inputs.numpy() data_path_and_name_and_type = [raw_inputs[0], "speech", "sound"] - loader = EENDOLADiarTask.build_streaming_iterator( - data_path_and_name_and_type, + loader = build_streaming_iterator( + task_name="diar", + preprocess_args=None, + data_path_and_name_and_type=data_path_and_name_and_type, dtype=dtype, batch_size=batch_size, key_file=key_file, num_workers=num_workers, - preprocess_fn=EENDOLADiarTask.build_preprocess_fn(speech2diar.diar_train_args, False), - collate_fn=EENDOLADiarTask.build_collate_fn(speech2diar.diar_train_args, False), - allow_variable_data_keys=allow_variable_data_keys, - inference=True, ) # 3. Start for-loop @@ -362,8 +342,6 @@ def inference_eend( return _forward - - def inference_launch(mode, **kwargs): if mode == "sond": return inference_sond(mode=mode, **kwargs) @@ -386,6 +364,7 @@ def inference_launch(mode, **kwargs): logging.info("Unknown decoding mode: {}".format(mode)) return None + def get_parser(): parser = config_argparse.ArgumentParser( description="Speaker Verification", diff --git a/funasr/build_utils/build_model_from_file.py b/funasr/build_utils/build_model_from_file.py index 5488c1029..2eadae4e6 100644 --- a/funasr/build_utils/build_model_from_file.py +++ b/funasr/build_utils/build_model_from_file.py @@ -72,6 +72,8 @@ def build_model_from_file( model.load_state_dict(model_dict) else: model_dict = torch.load(model_file, map_location=device) + if task_name == "diar" and mode == "sond": + model_dict = fileter_model_dict(model_dict, model.state_dict()) model.load_state_dict(model_dict) if model_name_pth is not None and not os.path.exists(model_name_pth): torch.save(model_dict, model_name_pth) @@ -85,7 +87,7 @@ def convert_tf2torch( ckpt, mode, ): - assert mode == "paraformer" or mode == "uniasr" + assert mode == "paraformer" or mode == "uniasr" or mode == "sond" logging.info("start convert tf model to torch model") from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict var_dict_tf = load_tf_dict(ckpt) @@ -113,7 +115,7 @@ def convert_tf2torch( # stride_conv var_dict_torch_update_local = model.stride_conv.convert_tf2torch(var_dict_tf, var_dict_torch) var_dict_torch_update.update(var_dict_torch_update_local) - else: + elif mode == "paraformer": # encoder var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch) var_dict_torch_update.update(var_dict_torch_update_local) @@ -126,5 +128,38 @@ def convert_tf2torch( # bias_encoder var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch) var_dict_torch_update.update(var_dict_torch_update_local) + else: + if model.encoder is not None: + var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch) + var_dict_torch_update.update(var_dict_torch_update_local) + # speaker encoder + if model.speaker_encoder is not None: + var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch) + var_dict_torch_update.update(var_dict_torch_update_local) + # cd scorer + if model.cd_scorer is not None: + var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch) + var_dict_torch_update.update(var_dict_torch_update_local) + # ci scorer + if model.ci_scorer is not None: + var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch) + var_dict_torch_update.update(var_dict_torch_update_local) + # decoder + if model.decoder is not None: + var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch) + var_dict_torch_update.update(var_dict_torch_update_local) return var_dict_torch_update + +def fileter_model_dict(src_dict: dict, dest_dict: dict): + from collections import OrderedDict + new_dict = OrderedDict() + for key, value in src_dict.items(): + if key in dest_dict: + new_dict[key] = value + else: + logging.info("{} is no longer needed in this model.".format(key)) + for key, value in dest_dict.items(): + if key not in new_dict: + logging.warning("{} is missed in checkpoint.".format(key)) + return new_dict \ No newline at end of file diff --git a/funasr/build_utils/build_streaming_iterator.py b/funasr/build_utils/build_streaming_iterator.py index 732fe097d..da42929f1 100644 --- a/funasr/build_utils/build_streaming_iterator.py +++ b/funasr/build_utils/build_streaming_iterator.py @@ -17,6 +17,7 @@ def build_streaming_iterator( mc: bool = False, dtype: str = np.float32, num_workers: int = 1, + use_collate_fn: bool = True, ngpu: int = 0, train: bool=False, ) -> DataLoader: @@ -30,7 +31,9 @@ def build_streaming_iterator( preprocess_fn = None # collate - if task_name in ["punc", "lm"]: + if not use_collate_fn: + collate_fn = None + elif task_name in ["punc", "lm"]: collate_fn = CommonCollateFn(int_pad_value=0) else: collate_fn = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)