mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
bugfix python runtime
This commit is contained in:
parent
4fe50f431b
commit
606f4faef2
@ -929,6 +929,14 @@ class LLMASR4(nn.Module):
|
|||||||
use_cache=None,
|
use_cache=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if llm_conf.get("use_lora", False):
|
||||||
|
lora_conf = llm_conf.get("lora_conf", {})
|
||||||
|
from peft import get_peft_model, LoraConfig, TaskType
|
||||||
|
|
||||||
|
peft_config = LoraConfig(**lora_conf)
|
||||||
|
model = get_peft_model(model, peft_config)
|
||||||
|
model.print_trainable_parameters()
|
||||||
|
|
||||||
if llm_conf.get("activation_checkpoint", False):
|
if llm_conf.get("activation_checkpoint", False):
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable()
|
||||||
freeze = llm_conf.get("freeze", True)
|
freeze = llm_conf.get("freeze", True)
|
||||||
|
|||||||
@ -162,6 +162,7 @@ class Trainer:
|
|||||||
if isinstance(effective_save_name_excludes, str):
|
if isinstance(effective_save_name_excludes, str):
|
||||||
effective_save_name_excludes = effective_save_name_excludes.split(",")
|
effective_save_name_excludes = effective_save_name_excludes.split(",")
|
||||||
self.effective_save_name_excludes = effective_save_name_excludes
|
self.effective_save_name_excludes = effective_save_name_excludes
|
||||||
|
self.use_lora = kwargs.get("use_lora", False)
|
||||||
|
|
||||||
def save_checkpoint(
|
def save_checkpoint(
|
||||||
self,
|
self,
|
||||||
@ -342,6 +343,11 @@ class Trainer:
|
|||||||
ckpt_name = f"model.pt.ep{epoch}"
|
ckpt_name = f"model.pt.ep{epoch}"
|
||||||
else:
|
else:
|
||||||
ckpt_name = f"model.pt.ep{epoch}.{step}"
|
ckpt_name = f"model.pt.ep{epoch}.{step}"
|
||||||
|
|
||||||
|
if self.use_lora:
|
||||||
|
lora_outdir = f"{self.output_dir}/lora-{ckpt_name}"
|
||||||
|
os.makedirs(lora_outdir, exist_ok=True)
|
||||||
|
model.llm.save_pretrained(lora_outdir)
|
||||||
filename = os.path.join(self.output_dir, ckpt_name)
|
filename = os.path.join(self.output_dir, ckpt_name)
|
||||||
torch.save(state, filename)
|
torch.save(state, filename)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user