mirror of https://github.com/modelscope/FunASR
commit be26169447
parent d43f77408b

    decoding
@@ -376,6 +376,7 @@ class OpenAIDatasetMultiTurn(torch.utils.data.Dataset):
             target_ids = self.tokenizer.encode(target_out)
             input_ids += source_ids + target_ids
             labels += source_mask + target_ids
             fbank.append(speech)
             fbank_mask += fbank_mask_i
             fbank_beg.append(fbank_beg_i)
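The per-turn accumulation in this hunk reads as: prompt and response token ids are concatenated into input_ids, the labels keep only the response tokens for the loss, and the turn's audio features are collected alongside. A minimal standalone sketch follows; it assumes source_mask holds ignore-index labels for the prompt positions and that fbank_beg records where each turn's audio begins, neither of which this excerpt shows, and every name not in the hunk is illustrative.

# Sketch of the accumulation pattern above, not the FunASR implementation.
from typing import List, Tuple

IGNORE_INDEX = -100  # assumption: label value skipped by the loss

def build_multiturn_sample(turns: List[Tuple[List[int], List[int], List[List[float]]]]):
    """turns: (source_ids, target_ids, speech_frames) for each conversation turn."""
    input_ids: List[int] = []
    labels: List[int] = []
    fbank: List[List[List[float]]] = []
    fbank_mask: List[int] = []
    fbank_beg: List[int] = []

    for source_ids, target_ids, speech in turns:
        source_mask = [IGNORE_INDEX] * len(source_ids)  # no loss on the prompt
        fbank_beg.append(len(input_ids))                # assumed: token offset of this turn's audio
        input_ids += source_ids + target_ids
        labels += source_mask + target_ids              # loss only on the response
        fbank.append(speech)
        fbank_mask += [1] * len(speech)                 # assumed: one flag per valid frame

    return input_ids, labels, fbank, fbank_mask, fbank_beg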
@@ -10,36 +10,6 @@ import torch.optim
 import pdb
 
 
-def filter_state_dict(
-    dst_state: Dict[str, Union[float, torch.Tensor]],
-    src_state: Dict[str, Union[float, torch.Tensor]],
-):
-    """Filter name, size mismatch instances between dicts.
-
-    Args:
-        dst_state: reference state dict for filtering
-        src_state: target state dict for filtering
-
-    """
-    match_state = {}
-    for key, value in src_state.items():
-        if key in dst_state and (dst_state[key].size() == src_state[key].size()):
-            match_state[key] = value
-        else:
-            if key not in dst_state:
-                logging.warning(
-                    f"Filter out {key} from pretrained dict"
-                    + " because of name not found in target dict"
-                )
-            else:
-                logging.warning(
-                    f"Filter out {key} from pretrained dict"
-                    + " because of size mismatch"
-                    + f"({dst_state[key].size()}-{src_state[key].size()})"
-                )
-    return match_state
-
-
 def load_pretrained_model(
     path: str,
     model: torch.nn.Module,
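This hunk drops the filter_state_dict helper. For reference, the sketch below exercises the helper exactly as it appears in the removed lines, so it assumes that definition is still in scope; the Linear layer and the checkpoint contents are toy values.

# Toy use of the removed filter_state_dict helper: keep only entries whose
# name and shape match the target model, then load non-strictly.
import torch

dst_model = torch.nn.Linear(4, 2)
src_state = {
    "weight": torch.zeros(2, 4),        # name and shape match -> kept
    "bias": torch.zeros(3),             # size mismatch -> filtered with a warning
    "extra.weight": torch.zeros(2, 2),  # name missing from target -> filtered
}

matched = filter_state_dict(dst_model.state_dict(), src_state)
dst_model.load_state_dict(matched, strict=False)
print(sorted(matched.keys()))  # ['weight']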
@@ -62,7 +32,7 @@ def load_pretrained_model(
     obj = model
     dst_state = obj.state_dict()
 
-    print(f"ckpt: {path}")
+    logging.info(f"ckpt: {path}")
 
     if oss_bucket is None:
         src_state = torch.load(path, map_location=map_location)
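A hedged sketch of the local-checkpoint branch above: torch.load with a caller-supplied map_location, plus unwrapping of a nested "state_dict" entry, which is an assumption about how some checkpoints are saved rather than something this hunk shows.

# Sketch only; load_ckpt_state is a hypothetical helper name.
import torch

def load_ckpt_state(path: str, map_location: str = "cpu") -> dict:
    src_state = torch.load(path, map_location=map_location)
    if isinstance(src_state, dict) and "state_dict" in src_state:
        src_state = src_state["state_dict"]  # assumed layout: weights nested under "state_dict"
    return src_state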
@@ -77,9 +47,20 @@ def load_pretrained_model(
     if isinstance(scope_map, str):
         scope_map = scope_map.split(",")
     scope_map += ["module.", "None"]
     logging.info(f"scope_map: {scope_map}")
 
+    if excludes is not None:
+        if isinstance(excludes, str):
+            excludes = excludes.split(",")
+        logging.info(f"excludes: {excludes}")
+
     for k in dst_state.keys():
+
+        for k_ex in excludes:
+            if k.startswith(k_ex):
+                logging.info(f"key: {k} matching: {k_ex}, excluded")
+                continue
 
         k_src = k
 
         if scope_map is not None:
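The new excludes handling filters destination keys by name prefix before they are considered for initialization. Below is a standalone sketch of that idea, assuming excludes defaults to an empty collection; the helper names and the prefixes in the usage comment are illustrative, and the any()-based filter is simply one way to make a match skip the key.

# Prefix-based exclusion sketch, not the in-tree implementation.
from typing import Dict, List
import torch

def split_csv(value) -> List[str]:
    return value.split(",") if isinstance(value, str) else list(value or [])

def keys_to_init(dst_state: Dict[str, torch.Tensor], excludes=None) -> List[str]:
    prefixes = split_csv(excludes)
    return [k for k in dst_state if not any(k.startswith(p) for p in prefixes)]

# e.g. keys_to_init(model.state_dict(), excludes="decoder.,ctc.") keeps only the
# keys that start with neither "decoder." nor "ctc.".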
@@ -92,25 +73,25 @@ def load_pretrained_model(
                 if dst_prefix == "" and (src_prefix + k) in src_state.keys():
                     k_src = src_prefix + k
                     if not k_src.startswith("module."):
-                        print(f"init param, map: {k} from {k_src} in ckpt")
+                        logging.info(f"init param, map: {k} from {k_src} in ckpt")
                 elif (
                     k.startswith(dst_prefix)
                     and k.replace(dst_prefix, src_prefix, 1) in src_state.keys()
                 ):
                     k_src = k.replace(dst_prefix, src_prefix, 1)
                     if not k_src.startswith("module."):
-                        print(f"init param, map: {k} from {k_src} in ckpt")
+                        logging.info(f"init param, map: {k} from {k_src} in ckpt")
 
         if k_src in src_state.keys():
             if ignore_init_mismatch and dst_state[k].shape != src_state[k_src].shape:
-                print(
+                logging.info(
                     f"ignore_init_mismatch:{ignore_init_mismatch}, dst: {k, dst_state[k].shape}, src: {k_src, src_state[k_src].shape}"
                 )
             else:
                 dst_state[k] = src_state[k_src]
 
         else:
-            print(f"Warning, miss key in ckpt: {k}, mapped: {k_src}")
+            logging.info(f"Warning, miss key in ckpt: {k}, mapped: {k_src}")
 
     flag = obj.load_state_dict(dst_state, strict=True)
     # print(flag)
     logging.info(f"Loading ckpt: {path}, status: {flag}")
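For the scope_map branch above, a hedged sketch of the prefix remapping it performs: pairs of (src_prefix, dst_prefix) describe how a destination key is looked up in the checkpoint. Treating the literal "None" as an empty prefix, and reading the default ["module.", "None"] pair appended in the earlier hunk as letting a DataParallel-style "module."-prefixed checkpoint initialize an unprefixed model, are assumptions about code outside this excerpt; the key names are illustrative.

# Sketch of scope_map-style key remapping; map_key is a hypothetical helper.
from typing import Dict, List

def map_key(k: str, src_state: Dict, scope_map: List[str]) -> str:
    for i in range(0, len(scope_map), 2):
        src_prefix = scope_map[i] if scope_map[i].lower() != "none" else ""
        dst_prefix = scope_map[i + 1] if scope_map[i + 1].lower() != "none" else ""
        if dst_prefix == "" and (src_prefix + k) in src_state:
            return src_prefix + k
        if k.startswith(dst_prefix) and k.replace(dst_prefix, src_prefix, 1) in src_state:
            return k.replace(dst_prefix, src_prefix, 1)
    return k

src_state = {"module.encoder.weight": 0}
print(map_key("encoder.weight", src_state, ["module.", "None"]))  # module.encoder.weight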