bugfix python runtime

游雁 2024-07-25 11:53:34 +08:00
parent 4fe50f431b
commit 606f4faef2
2 changed files with 14 additions and 0 deletions

@@ -929,6 +929,14 @@ class LLMASR4(nn.Module):
             use_cache=None,
         )
+        if llm_conf.get("use_lora", False):
+            lora_conf = llm_conf.get("lora_conf", {})
+            from peft import get_peft_model, LoraConfig, TaskType
+            peft_config = LoraConfig(**lora_conf)
+            model = get_peft_model(model, peft_config)
+            model.print_trainable_parameters()
         if llm_conf.get("activation_checkpoint", False):
             model.gradient_checkpointing_enable()
         freeze = llm_conf.get("freeze", True)
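
For reference, LoraConfig(**lora_conf) expects standard peft fields, so the lora_conf section of llm_conf could look like the sketch below. The field names are real peft.LoraConfig parameters, but the values and target modules are illustrative assumptions, not taken from this commit.

    # Hypothetical lora_conf for LoraConfig(**lora_conf); values are assumptions.
    lora_conf = {
        "r": 8,                      # rank of the low-rank update matrices
        "lora_alpha": 16,            # scaling factor applied to the LoRA update
        "lora_dropout": 0.05,        # dropout on the LoRA branch during training
        "target_modules": ["q_proj", "v_proj"],  # attention projections to adapt
        "task_type": "CAUSAL_LM",    # matches the TaskType enum imported above
    }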

@@ -162,6 +162,7 @@ class Trainer:
         if isinstance(effective_save_name_excludes, str):
             effective_save_name_excludes = effective_save_name_excludes.split(",")
         self.effective_save_name_excludes = effective_save_name_excludes
+        self.use_lora = kwargs.get("use_lora", False)

     def save_checkpoint(
         self,
@@ -342,6 +343,11 @@ class Trainer:
             ckpt_name = f"model.pt.ep{epoch}"
         else:
             ckpt_name = f"model.pt.ep{epoch}.{step}"
+        if self.use_lora:
+            lora_outdir = f"{self.output_dir}/lora-{ckpt_name}"
+            os.makedirs(lora_outdir, exist_ok=True)
+            model.llm.save_pretrained(lora_outdir)
         filename = os.path.join(self.output_dir, ckpt_name)
         torch.save(state, filename)
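
Because save_pretrained writes only the adapter weights rather than the full merged model, a checkpoint saved this way would typically be restored by attaching the adapter back onto the base LLM with peft. A minimal sketch, assuming a hypothetical adapter directory and base model name:

    # Reload sketch: the adapter path and base model name below are assumptions.
    from transformers import AutoModelForCausalLM
    from peft import PeftModel

    base_llm = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-7B")
    model = PeftModel.from_pretrained(base_llm, "exp/lora-model.pt.ep1")
    model = model.merge_and_unload()  # optional: fold LoRA deltas into the base weights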