From 606f4faef2a0cd97310588754b81d8655091379c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?=
Date: Thu, 25 Jul 2024 11:53:34 +0800
Subject: [PATCH] bugfix python runtime

---
 funasr/models/llm_asr/model.py   | 8 ++++++++
 funasr/train_utils/trainer_ds.py | 6 ++++++
 2 files changed, 14 insertions(+)

diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index bca73bf6a..23e869761 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -929,6 +929,14 @@ class LLMASR4(nn.Module):
             use_cache=None,
         )
 
+        if llm_conf.get("use_lora", False):
+            lora_conf = llm_conf.get("lora_conf", {})
+            from peft import get_peft_model, LoraConfig, TaskType
+
+            peft_config = LoraConfig(**lora_conf)
+            model = get_peft_model(model, peft_config)
+            model.print_trainable_parameters()
+
         if llm_conf.get("activation_checkpoint", False):
             model.gradient_checkpointing_enable()
         freeze = llm_conf.get("freeze", True)
diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
index 85513a5a7..463c64db2 100644
--- a/funasr/train_utils/trainer_ds.py
+++ b/funasr/train_utils/trainer_ds.py
@@ -162,6 +162,7 @@ class Trainer:
         if isinstance(effective_save_name_excludes, str):
             effective_save_name_excludes = effective_save_name_excludes.split(",")
         self.effective_save_name_excludes = effective_save_name_excludes
+        self.use_lora = kwargs.get("use_lora", False)
 
     def save_checkpoint(
         self,
@@ -342,6 +343,11 @@
             ckpt_name = f"model.pt.ep{epoch}"
         else:
             ckpt_name = f"model.pt.ep{epoch}.{step}"
+
+        if self.use_lora:
+            lora_outdir = f"{self.output_dir}/lora-{ckpt_name}"
+            os.makedirs(lora_outdir, exist_ok=True)
+            model.llm.save_pretrained(lora_outdir)
         filename = os.path.join(self.output_dir, ckpt_name)
         torch.save(state, filename)
 