Mirror of https://github.com/modelscope/FunASR (synced 2025-09-15 14:48:36 +08:00)

Commit bb97d3ed19 (parent 546e3a432c): fix win bug
@@ -7,6 +7,6 @@ from funasr import AutoModel
 
 model = AutoModel(model="damo/emotion2vec_base", model_revision="v2.0.1")
 
-wav_file = f"{model.model_path}/example/example/test.wav"
+wav_file = f"{model.model_path}/example/test.wav"
 res = model.generate(wav_file, output_dir="./outputs", granularity="utterance")
 print(res)
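
The quick-start pointed one directory level too deep (example/example/), so the packaged test file was never found. A quick sanity check of the corrected path, assuming the ModelScope download layout used above:

import os
from funasr import AutoModel

model = AutoModel(model="damo/emotion2vec_base", model_revision="v2.0.1")
wav_file = os.path.join(model.model_path, "example", "test.wav")  # portable join, same path as the fix
assert os.path.exists(wav_file), f"missing packaged example: {wav_file}"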
@@ -183,9 +183,11 @@ class AutoModel:
             logging.info(f"Loading pretrained params from {init_param}")
             load_pretrained_model(
                 model=model,
-                init_param=init_param,
+                path=init_param,
                 ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
                 oss_bucket=kwargs.get("oss_bucket", None),
+                scope_map=kwargs.get("scope_map", None),
+                excludes=kwargs.get("excludes", None),
             )
 
         return model, kwargs
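
At the AutoModel surface nothing is renamed: init_param stays the user-facing keyword and is now forwarded as path. A minimal sketch, assuming init_param is accepted as a keyword argument of AutoModel (the checkpoint path is hypothetical):

from funasr import AutoModel

model = AutoModel(
    model="damo/emotion2vec_base",
    model_revision="v2.0.1",
    init_param="./finetuned.pt",   # hypothetical local checkpoint, forwarded as path=
    ignore_init_mismatch=True,     # tolerate name/shape mismatches, per the handling above
)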
@@ -96,9 +96,11 @@ def main(**kwargs):
         logging.info(f"Loading pretrained params from {p}")
         load_pretrained_model(
             model=model,
-            init_param=p,
+            path=p,
             ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
             oss_bucket=kwargs.get("oss_bucket", None),
+            scope_map=kwargs.get("scope_map", None),
+            excludes=kwargs.get("excludes", None),
         )
     else:
         initialize(model, kwargs.get("init", "kaiming_normal"))
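
Both call sites now pass the checkpoint location as path and forward two extra knobs from kwargs. A hedged sketch of a direct call after this change (the module path and checkpoint file are assumptions, not shown in the diff):

import torch
from funasr.train_utils.load_pretrained_model import load_pretrained_model  # assumed module path

model = torch.nn.Linear(4, 4)  # stand-in model
load_pretrained_model(
    path="exp/baseline/model.pt",         # was: init_param=
    model=model,
    ignore_init_mismatch=True,
    oss_bucket=None,
    scope_map="encoder.,model.encoder.",  # optional comma-paired prefix remap
    excludes="decoder.",                  # optional comma-separated prefixes to drop
)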
@@ -1,7 +1,7 @@
 import torch
 
 from funasr.register import tables
-from funasr.utils.load_utils import extract_fbank
+from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
 
 
 @tables.register("dataset_classes", "AudioDataset")
@@ -55,7 +55,7 @@ class AudioDataset(torch.utils.data.Dataset):
         # import pdb;
         # pdb.set_trace()
         source = item["source"]
-        data_src = load_audio(source, fs=self.fs)
+        data_src = load_audio_text_image_video(source, fs=self.fs)
         if self.preprocessor_speech:
             data_src = self.preprocessor_speech(data_src)
         speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend)  # speech: [b, T, d]
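
The dataset was still calling load_audio, a helper that is not imported (and no longer exists under that name on this branch); load_audio_text_image_video is the unified loader. A minimal sketch of the same call outside the dataset (only the fs keyword is taken from the diff; the wav path is a placeholder):

from funasr.utils.load_utils import load_audio_text_image_video

waveform = load_audio_text_image_video("example/test.wav", fs=16000)  # load, resampled to fs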
@@ -38,13 +38,51 @@ def filter_state_dict(
         )
     return match_state
 
+
+def assigment_scope_map(dst_state: dict, src_state: dict, scope_map: str = None):
+    """Compute the union of the current variables and checkpoint variables."""
+    import collections
+    import re
+
+    # current model variables
+    name_to_variable = collections.OrderedDict()
+    for name, var in dst_state.items():
+        name_to_variable[name] = var
+
+    scope_map_num = 0
+    if scope_map is not None:
+        scope_map = scope_map.split(",")
+        scope_map_num = len(scope_map) // 2
+        for scope_map_idx in range(scope_map_num):
+            scope_map_id = scope_map_idx * 2
+            logging.info('assignment_map from scope {} to {}'.format(scope_map[scope_map_id], scope_map[scope_map_id + 1]))
+
+    assignment_map = {}
+    for name, var in src_state.items():
+
+        if scope_map:
+            for scope_map_idx in range(scope_map_num):
+                scope_map_id = scope_map_idx * 2
+                try:
+                    idx = name.index(scope_map[scope_map_id])
+                    new_name = scope_map[scope_map_id + 1] + name[idx + len(scope_map[scope_map_id]):]
+                    if new_name in name_to_variable:
+                        assignment_map[name] = var
+                except:
+                    continue
+        else:
+            if name in name_to_variable:
+                assignment_map[name] = var
+
+    return assignment_map
+
+
 def load_pretrained_model(
-    init_param: str,
+    path: str,
     model: torch.nn.Module,
     ignore_init_mismatch: bool,
     map_location: str = "cpu",
     oss_bucket=None,
+    scope_map=None,
+    excludes=None,
 ):
     """Load a model state and set it to the model.
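
assigment_scope_map (spelled without the second 'n' in the source) keeps only those checkpoint entries whose names, after an optional comma-paired prefix rewrite, exist in the destination state; surviving entries are keyed by their original source names, and the bare except skips any name lacking the source prefix. A runnable toy illustration, assuming the function above is in scope (tensor values are stand-ins):

import torch

dst_state = {"encoder.w": torch.zeros(2), "head.w": torch.zeros(2)}  # model side
src_state = {"enc.w": torch.ones(2), "extra.b": torch.ones(2)}       # checkpoint side

# "enc.,encoder." probes each source key with the prefix "enc." rewritten to
# "encoder.": "enc.w" -> "encoder.w" matches dst_state, "extra.b" does not.
kept = assigment_scope_map(dst_state, src_state, scope_map="enc.,encoder.")
print(list(kept))  # ['enc.w']  -- note: kept under the original source name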
@@ -52,53 +90,10 @@ def load_pretrained_model(
     init_param: <file_path>:<src_key>:<dst_key>:<exclude_Keys>
 
     Examples:
-        >>> load_pretrained_model("somewhere/model.pb", model)
-        >>> load_pretrained_model("somewhere/model.pb:decoder:decoder", model)
-        >>> load_pretrained_model("somewhere/model.pb:decoder:decoder:", model)
-        >>> load_pretrained_model(
-        ...     "somewhere/model.pb:decoder:decoder:decoder.embed", model
-        ... )
-        >>> load_pretrained_model("somewhere/decoder.pb::decoder", model)
     """
-    sps = init_param.split(":", 4)
-    if len(sps) == 4:
-        path, src_key, dst_key, excludes = sps
-    elif len(sps) == 3:
-        path, src_key, dst_key = sps
-        excludes = None
-    elif len(sps) == 2:
-        path, src_key = sps
-        dst_key, excludes = None, None
-    else:
-        (path,) = sps
-        src_key, dst_key, excludes = None, None, None
-    if src_key == "":
-        src_key = None
-    if dst_key == "":
-        dst_key = None
 
-    if dst_key is None:
-        obj = model
-    else:
-
-        def get_attr(obj: Any, key: str):
-            """Get an nested attribute.
-
-            >>> class A(torch.nn.Module):
-            ...     def __init__(self):
-            ...         super().__init__()
-            ...         self.linear = torch.nn.Linear(10, 10)
-            >>> a = A()
-            >>> assert A.linear.weight is get_attr(A, 'linear.weight')
-
-            """
-            if key.strip() == "":
-                return obj
-            for k in key.split("."):
-                obj = getattr(obj, k)
-            return obj
-
-        obj = get_attr(model, dst_key)
+    obj = model
 
     if oss_bucket is None:
         src_state = torch.load(path, map_location=map_location)
@@ -106,23 +101,18 @@ def load_pretrained_model(
         buffer = BytesIO(oss_bucket.get_object(path).read())
         src_state = torch.load(buffer, map_location=map_location)
     src_state = src_state["model"] if "model" in src_state else src_state
 
     if excludes is not None:
         for e in excludes.split(","):
             src_state = {k: v for k, v in src_state.items() if not k.startswith(e)}
 
-    if src_key is not None:
-        src_state = {
-            k[len(src_key) + 1 :]: v
-            for k, v in src_state.items()
-            if k.startswith(src_key)
-        }
-
     dst_state = obj.state_dict()
+    src_state = assigment_scope_map(dst_state, src_state, scope_map)
+
     if ignore_init_mismatch:
         src_state = filter_state_dict(dst_state, src_state)
     logging.debug("Loaded src_state keys: {}".format(src_state.keys()))
     logging.debug("Loaded dst_state keys: {}".format(dst_state.keys()))
-    dst_state.update(src_state)
+    # dst_state.update(src_state)
     obj.load_state_dict(dst_state)
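
Taken together, the colon-packed init_param string (<path>:<src_key>:<dst_key>:<excludes>) is retired: excludes becomes its own keyword, the src_key/dst_key sub-module selection is dropped in favor of scope_map, and the whole model is always the load target. A before/after sketch (paths and prefixes are hypothetical; note also that with dst_state.update(src_state) commented out as shown, the filtered source state is computed but never merged before obj.load_state_dict(dst_state)):

import torch
model = torch.nn.Linear(8, 8)  # stand-in model

# Before this commit: one overloaded positional string.
load_pretrained_model("exp/model.pb:decoder:decoder:decoder.embed", model=model, ignore_init_mismatch=True)

# After: explicit keywords.
load_pretrained_model(
    path="exp/model.pb",
    model=model,
    ignore_init_mismatch=True,
    scope_map="decoder.,decoder.",  # optional comma-paired prefix remap
    excludes="decoder.embed",       # comma-separated key prefixes to drop
)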