mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
* sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * bugfix * update with main (#1631) * update seaco finetune * v1.0.24 --------- Co-authored-by: 维石 <shixian.shi@alibaba-inc.com> * sensevoice * sensevoice * sensevoice * update with main (#1638) * update seaco finetune * v1.0.24 * update rwkv template --------- Co-authored-by: 维石 <shixian.shi@alibaba-inc.com> * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * whisper * whisper * update style * update style --------- Co-authored-by: 维石 <shixian.shi@alibaba-inc.com>
48 lines
1.8 KiB
Python
48 lines
1.8 KiB
Python
# ------------------------------------------------------------------------------------------
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
|
# ------------------------------------------------------------------------------------------
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from typing import Dict
|
|
|
|
from .layers import LoRALayer
|
|
|
|
|
|
def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
    """Freeze all parameters of ``model`` except LoRA and CIF ones, in place.

    Every parameter whose qualified name contains neither ``"lora_"`` nor
    ``"cif"`` gets ``requires_grad = False``. Parameters matching either
    substring are left as they were; they are not forced to ``True``.
    Selected bias parameters can then be re-enabled via ``bias``.

    Args:
        model: Module whose parameters are modified in place.
        bias: Which biases to make trainable again after freezing:
            ``"none"``: none.
            ``"all"``: every parameter whose name contains ``"bias"``.
            ``"lora_only"``: the non-``None`` ``bias`` attribute of each
            ``LoRALayer`` submodule.

    Raises:
        NotImplementedError: If ``bias`` is not one of the values above. The
            check runs before any parameter is touched, so the model is left
            unmodified.
    """
    if bias not in ("none", "all", "lora_only"):
        # Validate before freezing so a bad option cannot leave the model
        # half-configured (frozen backbone, no bias re-enabled).
        raise NotImplementedError(f"Unsupported bias mode: {bias!r}")

    for n, p in model.named_parameters():
        # Plain substring match on the qualified parameter name. "cif" is a
        # project-specific exemption (presumably the CIF predictor module --
        # confirm) that stays trainable alongside the LoRA adapters.
        if "lora_" not in n and "cif" not in n:
            p.requires_grad = False

    if bias == "all":
        for n, p in model.named_parameters():
            if "bias" in n:
                p.requires_grad = True
    elif bias == "lora_only":
        for m in model.modules():
            if isinstance(m, LoRALayer) and hasattr(m, "bias") and m.bias is not None:
                m.bias.requires_grad = True
|
|
|
|
|
|
def lora_state_dict(model: nn.Module, bias: str = "none") -> Dict[str, torch.Tensor]:
    """Return the subset of ``model.state_dict()`` that holds LoRA weights.

    Every entry whose key contains ``"lora_"`` is included. Extra bias
    entries are added according to ``bias``.

    Args:
        model: Module whose state dict is filtered.
        bias: Which bias entries to include as well:
            ``"none"``: none.
            ``"all"``: every key containing ``"bias"``.
            ``"lora_only"``: for each ``"lora_"`` key, the key formed by the
            text before its first ``"lora_"`` followed by ``"bias"``, when
            that key exists in the state dict.

    Returns:
        A new dict mapping the selected keys to the tensors taken from
        ``model.state_dict()``.

    Raises:
        NotImplementedError: If ``bias`` is not one of the values above.
    """
    state = model.state_dict()

    if bias in ("none", "all"):
        with_biases = bias == "all"
        return {
            name: tensor
            for name, tensor in state.items()
            if "lora_" in name or (with_biases and "bias" in name)
        }

    if bias != "lora_only":
        raise NotImplementedError

    selected: Dict[str, torch.Tensor] = {}
    for name, tensor in state.items():
        if "lora_" not in name:
            continue
        selected[name] = tensor
        # Bias of the module that owns this LoRA tensor, e.g.
        # "layer.lora_A" -> "layer.bias".
        owner_bias = name.partition("lora_")[0] + "bias"
        if owner_bias in state:
            selected[owner_bias] = state[owner_bias]
    return selected
|