mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
bugfix python runtime
This commit is contained in:
parent
4fe50f431b
commit
606f4faef2
@ -929,6 +929,14 @@ class LLMASR4(nn.Module):
|
||||
use_cache=None,
|
||||
)
|
||||
|
||||
if llm_conf.get("use_lora", False):
|
||||
lora_conf = llm_conf.get("lora_conf", {})
|
||||
from peft import get_peft_model, LoraConfig, TaskType
|
||||
|
||||
peft_config = LoraConfig(**lora_conf)
|
||||
model = get_peft_model(model, peft_config)
|
||||
model.print_trainable_parameters()
|
||||
|
||||
if llm_conf.get("activation_checkpoint", False):
|
||||
model.gradient_checkpointing_enable()
|
||||
freeze = llm_conf.get("freeze", True)
|
||||
|
||||
@ -162,6 +162,7 @@ class Trainer:
|
||||
if isinstance(effective_save_name_excludes, str):
|
||||
effective_save_name_excludes = effective_save_name_excludes.split(",")
|
||||
self.effective_save_name_excludes = effective_save_name_excludes
|
||||
self.use_lora = kwargs.get("use_lora", False)
|
||||
|
||||
def save_checkpoint(
|
||||
self,
|
||||
@ -342,6 +343,11 @@ class Trainer:
|
||||
ckpt_name = f"model.pt.ep{epoch}"
|
||||
else:
|
||||
ckpt_name = f"model.pt.ep{epoch}.{step}"
|
||||
|
||||
if self.use_lora:
|
||||
lora_outdir = f"{self.output_dir}/lora-{ckpt_name}"
|
||||
os.makedirs(lora_outdir, exist_ok=True)
|
||||
model.llm.save_pretrained(lora_outdir)
|
||||
filename = os.path.join(self.output_dir, ckpt_name)
|
||||
torch.save(state, filename)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user