diff --git a/examples/industrial_data_pretraining/emotion2vec/demo.py b/examples/industrial_data_pretraining/emotion2vec/demo.py
index 91d00aabc..a41641e5a 100644
--- a/examples/industrial_data_pretraining/emotion2vec/demo.py
+++ b/examples/industrial_data_pretraining/emotion2vec/demo.py
@@ -7,6 +7,6 @@ from funasr import AutoModel
 
 model = AutoModel(model="damo/emotion2vec_base", model_revision="v2.0.1")
 
-wav_file = f"{model.model_path}/example/example/test.wav"
+wav_file = f"{model.model_path}/example/test.wav"
 res = model.generate(wav_file, output_dir="./outputs", granularity="utterance")
 print(res)
\ No newline at end of file
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index 580cca8d4..ffb56a50e 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -183,9 +183,11 @@ class AutoModel:
             logging.info(f"Loading pretrained params from {init_param}")
             load_pretrained_model(
                 model=model,
-                init_param=init_param,
+                path=init_param,
                 ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
                 oss_bucket=kwargs.get("oss_bucket", None),
+                scope_map=kwargs.get("scope_map", None),
+                excludes=kwargs.get("excludes", None),
             )
 
         return model, kwargs
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 0881cb2da..ef0d205b5 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -96,9 +96,11 @@ def main(**kwargs):
             logging.info(f"Loading pretrained params from {p}")
             load_pretrained_model(
                 model=model,
-                init_param=p,
+                path=p,
                 ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
                 oss_bucket=kwargs.get("oss_bucket", None),
+                scope_map=kwargs.get("scope_map", None),
+                excludes=kwargs.get("excludes", None),
             )
     else:
         initialize(model, kwargs.get("init", "kaiming_normal"))
diff --git a/funasr/datasets/audio_datasets/datasets.py b/funasr/datasets/audio_datasets/datasets.py
index edf127f57..5af33fca4 100644
--- a/funasr/datasets/audio_datasets/datasets.py
+++ b/funasr/datasets/audio_datasets/datasets.py
@@ -1,7 +1,7 @@
 import torch
 
 from funasr.register import tables
-from funasr.utils.load_utils import extract_fbank
+from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
 
 
 @tables.register("dataset_classes", "AudioDataset")
@@ -55,7 +55,7 @@ class AudioDataset(torch.utils.data.Dataset):
         # import pdb;
         # pdb.set_trace()
         source = item["source"]
-        data_src = load_audio(source, fs=self.fs)
+        data_src = load_audio_text_image_video(source, fs=self.fs)
         if self.preprocessor_speech:
             data_src = self.preprocessor_speech(data_src)
         speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend)  # speech: [b, T, d]
diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py
index ef9d93a14..16feabd70 100644
--- a/funasr/train_utils/load_pretrained_model.py
+++ b/funasr/train_utils/load_pretrained_model.py
@@ -10,119 +10,109 @@ import torch.optim
 
 
 def filter_state_dict(
-    dst_state: Dict[str, Union[float, torch.Tensor]],
-    src_state: Dict[str, Union[float, torch.Tensor]],
+    dst_state: Dict[str, Union[float, torch.Tensor]],
+    src_state: Dict[str, Union[float, torch.Tensor]],
 ):
-    """Filter name, size mismatch instances between dicts.
+    """Filter name, size mismatch instances between dicts.
 
-    Args:
-        dst_state: reference state dict for filtering
-        src_state: target state dict for filtering
+    Args:
+        dst_state: reference state dict for filtering
+        src_state: target state dict for filtering
 
-    """
-    match_state = {}
-    for key, value in src_state.items():
-        if key in dst_state and (dst_state[key].size() == src_state[key].size()):
-            match_state[key] = value
-        else:
-            if key not in dst_state:
-                logging.warning(
-                    f"Filter out {key} from pretrained dict"
-                    + " because of name not found in target dict"
-                )
-            else:
-                logging.warning(
-                    f"Filter out {key} from pretrained dict"
-                    + " because of size mismatch"
-                    + f"({dst_state[key].size()}-{src_state[key].size()})"
-                )
-    return match_state
+    """
+    match_state = {}
+    for key, value in src_state.items():
+        if key in dst_state and (dst_state[key].size() == src_state[key].size()):
+            match_state[key] = value
+        else:
+            if key not in dst_state:
+                logging.warning(
+                    f"Filter out {key} from pretrained dict"
+                    + " because of name not found in target dict"
+                )
+            else:
+                logging.warning(
+                    f"Filter out {key} from pretrained dict"
+                    + " because of size mismatch"
+                    + f" ({dst_state[key].size()}-{src_state[key].size()})"
+                )
+    return match_state
+
+
+def assignment_scope_map(dst_state: dict, src_state: dict, scope_map: str = None):
+    """Map checkpoint (src) keys onto current model (dst) keys.
+
+    scope_map is a flat comma-separated list of src/dst prefix pairs, e.g.
+    "encoder.,model.encoder." renames "encoder.*" keys to "model.encoder.*".
+    """
+    import collections
+
+    # current model variables
+    name_to_variable = collections.OrderedDict()
+    for name, var in dst_state.items():
+        name_to_variable[name] = var
+
+    scope_map_num = 0
+    if scope_map is not None:
+        scope_map = scope_map.split(",")
+        scope_map_num = len(scope_map) // 2
+        for scope_map_idx in range(scope_map_num):
+            scope_map_id = scope_map_idx * 2
+            logging.info("assignment_map from scope {} to {}".format(scope_map[scope_map_id], scope_map[scope_map_id + 1]))
+
+    assignment_map = {}
+    for name, var in src_state.items():
+        if scope_map:
+            for scope_map_idx in range(scope_map_num):
+                scope_map_id = scope_map_idx * 2
+                try:
+                    idx = name.index(scope_map[scope_map_id])
+                    new_name = scope_map[scope_map_id + 1] + name[idx + len(scope_map[scope_map_id]):]
+                    # store the variable under the renamed key so it can
+                    # match the target state dict
+                    if new_name in name_to_variable:
+                        assignment_map[new_name] = var
+                except ValueError:  # str.index: scope prefix not found in this key
+                    continue
+        else:
+            if name in name_to_variable:
+                assignment_map[name] = var
+
+    return assignment_map
 
 
 def load_pretrained_model(
-    init_param: str,
-    model: torch.nn.Module,
-    ignore_init_mismatch: bool,
-    map_location: str = "cpu",
-    oss_bucket=None,
+    path: str,
+    model: torch.nn.Module,
+    ignore_init_mismatch: bool,
+    map_location: str = "cpu",
+    oss_bucket=None,
+    scope_map=None,
+    excludes=None,
 ):
-    """Load a model state and set it to the model.
+    """Load a model state and set it to the model.
 
-    Args:
-        init_param: <file_path>:<src_key>:<dst_key>:<exclude_keys>
+    Args:
+        path: checkpoint file path, or an object key when oss_bucket is given
+        scope_map: comma-separated src/dst prefix pairs used to rename checkpoint keys
+        excludes: comma-separated key prefixes to drop from the checkpoint
 
-    Examples:
-        >>> load_pretrained_model("somewhere/model.pb", model)
-        >>> load_pretrained_model("somewhere/model.pb:decoder:decoder", model)
-        >>> load_pretrained_model("somewhere/model.pb:decoder:decoder:", model)
-        >>> load_pretrained_model(
-        ...     "somewhere/model.pb:decoder:decoder:decoder.embed", model
-        ... )
-        >>> load_pretrained_model("somewhere/decoder.pb::decoder", model)
-    """
-    sps = init_param.split(":", 4)
-    if len(sps) == 4:
-        path, src_key, dst_key, excludes = sps
-    elif len(sps) == 3:
-        path, src_key, dst_key = sps
-        excludes = None
-    elif len(sps) == 2:
-        path, src_key = sps
-        dst_key, excludes = None, None
-    else:
-        (path,) = sps
-        src_key, dst_key, excludes = None, None, None
-    if src_key == "":
-        src_key = None
-    if dst_key == "":
-        dst_key = None
-
-    if dst_key is None:
-        obj = model
-    else:
-
-        def get_attr(obj: Any, key: str):
-            """Get an nested attribute.
-
-            >>> class A(torch.nn.Module):
-            ...     def __init__(self):
-            ...         super().__init__()
-            ...         self.linear = torch.nn.Linear(10, 10)
-            >>> a = A()
-            >>> assert A.linear.weight is get_attr(A, 'linear.weight')
-
-            """
-            if key.strip() == "":
-                return obj
-            for k in key.split("."):
-                obj = getattr(obj, k)
-            return obj
-
-        obj = get_attr(model, dst_key)
-
-    if oss_bucket is None:
-        src_state = torch.load(path, map_location=map_location)
-    else:
-        buffer = BytesIO(oss_bucket.get_object(path).read())
-        src_state = torch.load(buffer, map_location=map_location)
-    src_state = src_state["model"] if "model" in src_state else src_state
-    if excludes is not None:
-        for e in excludes.split(","):
-            src_state = {k: v for k, v in src_state.items() if not k.startswith(e)}
-
-    if src_key is not None:
-        src_state = {
-            k[len(src_key) + 1 :]: v
-            for k, v in src_state.items()
-            if k.startswith(src_key)
-        }
-
-    dst_state = obj.state_dict()
-    if ignore_init_mismatch:
-        src_state = filter_state_dict(dst_state, src_state)
-
-    logging.debug("Loaded src_state keys: {}".format(src_state.keys()))
-    logging.debug("Loaded dst_state keys: {}".format(dst_state.keys()))
-    dst_state.update(src_state)
-    obj.load_state_dict(dst_state)
\ No newline at end of file
+    Examples:
+        >>> load_pretrained_model("somewhere/model.pb", model, ignore_init_mismatch=True)
+
+    """
+
+    obj = model
+
+    if oss_bucket is None:
+        src_state = torch.load(path, map_location=map_location)
+    else:
+        buffer = BytesIO(oss_bucket.get_object(path).read())
+        src_state = torch.load(buffer, map_location=map_location)
+    src_state = src_state["model"] if "model" in src_state else src_state
+
+    if excludes is not None:
+        for e in excludes.split(","):
+            src_state = {k: v for k, v in src_state.items() if not k.startswith(e)}
+
+    dst_state = obj.state_dict()
+    src_state = assignment_scope_map(dst_state, src_state, scope_map)
+
+    if ignore_init_mismatch:
+        src_state = filter_state_dict(dst_state, src_state)
+
+    logging.debug("Loaded src_state keys: {}".format(src_state.keys()))
+    logging.debug("Loaded dst_state keys: {}".format(dst_state.keys()))
+    # merge the renamed, filtered checkpoint weights into the model's full
+    # state dict so that load_state_dict() receives every key it expects
+    dst_state.update(src_state)
+    obj.load_state_dict(dst_state)
\ No newline at end of file
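
Usage note: with the argument renamed to path, scope_map and excludes now flow straight through the AutoModel/train.py kwargs into load_pretrained_model. A minimal sketch of the new call surface (the module and checkpoint below are stand-ins invented for illustration, not part of the patch):

import torch
from funasr.train_utils.load_pretrained_model import load_pretrained_model

# Stand-in module and fabricated checkpoint so the sketch is self-contained.
model = torch.nn.Linear(4, 4)
torch.save({"model": model.state_dict()}, "/tmp/demo_ckpt.pt")

load_pretrained_model(
    path="/tmp/demo_ckpt.pt",      # was init_param=... before this patch
    model=model,
    ignore_init_mismatch=True,     # mismatched names/shapes are filtered out, not fatal
    scope_map=None,                # or "src_prefix,dst_prefix" pairs to rename keys
    excludes=None,                 # or "prefix1,prefix2" to drop checkpoint keys
)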
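And a small sketch of the scope_map renaming semantics in assignment_scope_map (key names invented; this assumes the assignment_map[new_name] = var behavior shown above):

import torch
from funasr.train_utils.load_pretrained_model import assignment_scope_map

src_state = {"encoder.layer.weight": torch.zeros(2), "decoder.layer.weight": torch.zeros(2)}
dst_state = {"model.encoder.layer.weight": torch.zeros(2)}

# Prefix pairs are flattened into one comma-separated string: src, dst, src, dst, ...
# "encoder.*" checkpoint keys are renamed to "model.encoder.*"; keys that match no
# pair, or whose renamed form is absent from dst_state, are dropped.
mapped = assignment_scope_map(dst_state, src_state, scope_map="encoder.,model.encoder.")
assert set(mapped) == {"model.encoder.layer.weight"}

The flat string form, rather than a dict, is presumably what lets the option ride through CLI/YAML-style kwargs unchanged, as in the kwargs.get("scope_map", None) call sites above.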