FunASR/funasr/models/lora/utils.py

# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import torch
import torch.nn as nn

from typing import Dict

from .layers import LoRALayer


def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
    # Freeze every parameter except the LoRA factors (and FunASR's CIF predictor).
    for n, p in model.named_parameters():
        if "lora_" not in n and "cif" not in n:
            p.requires_grad = False
    if bias == "none":
        return
    elif bias == "all":
        # Additionally keep every bias parameter trainable.
        for n, p in model.named_parameters():
            if "bias" in n:
                p.requires_grad = True
    elif bias == "lora_only":
        # Only biases that belong to LoRA-adapted modules stay trainable.
        for m in model.modules():
            if isinstance(m, LoRALayer) and hasattr(m, "bias") and m.bias is not None:
                m.bias.requires_grad = True
    else:
        raise NotImplementedError


def lora_state_dict(model: nn.Module, bias: str = "none") -> Dict[str, torch.Tensor]:
    # Return only the parameters that need to be stored in a LoRA checkpoint.
    my_state_dict = model.state_dict()
    if bias == "none":
        return {k: my_state_dict[k] for k in my_state_dict if "lora_" in k}
    elif bias == "all":
        return {k: my_state_dict[k] for k in my_state_dict if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        for k in my_state_dict:
            if "lora_" in k:
                to_return[k] = my_state_dict[k]
                bias_name = k.split("lora_")[0] + "bias"
                if bias_name in my_state_dict:
                    to_return[bias_name] = my_state_dict[bias_name]
        return to_return
    else:
        raise NotImplementedError
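
A minimal usage sketch follows. It assumes that funasr.models.lora.layers exposes a loralib-style Linear layer (the module this file imports LoRALayer from); the toy model, hyperparameters, and checkpoint file name are illustrative only, not part of the library.

# Minimal sketch: fine-tune only the LoRA weights of a toy model, then save a
# small LoRA-only checkpoint. Assumes funasr.models.lora.layers provides a
# loralib-style Linear; the toy model and file name below are illustrative.
import torch
import torch.nn as nn

from funasr.models.lora import layers as lora
from funasr.models.lora.utils import mark_only_lora_as_trainable, lora_state_dict

model = nn.Sequential(
    lora.Linear(80, 256, r=8),   # LoRA-adapted projection (assumed API)
    nn.ReLU(),
    nn.Linear(256, 10),          # ordinary layer: stays frozen
)

# Freeze everything except lora_A / lora_B (and any "cif" parameters).
mark_only_lora_as_trainable(model, bias="lora_only")
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(trainable)  # only the LoRA factors (plus LoRA-module biases, if present)

# ... run the fine-tuning loop on `model` here ...

# Persist only the LoRA parameters instead of the full state dict.
torch.save(lora_state_dict(model, bias="lora_only"), "lora_only_ckpt.pt")

# To resume: load the base model weights first, then apply the small LoRA delta.
model.load_state_dict(torch.load("lora_only_ckpt.pt"), strict=False)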