diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/finetune.py b/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/finetune.py
index e4d6682b0..34c7cf949 100644
--- a/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/finetune.py
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/finetune.py
@@ -31,6 +31,6 @@ if __name__ == '__main__':
     params.dataset_type = "large"  # the contextual paraformer model can only be finetuned with the large dataset type
     params.batch_bins = 200000  # batch size; if dataset_type="small", batch_bins is in fbank feature frames; if dataset_type="large", it is in milliseconds
     params.max_epoch = 20  # maximum number of training epochs
-    params.lr = 0.00005  # learning rate
+    params.lr = 0.0002  # learning rate
 
     modelscope_finetune(params)
\ No newline at end of file
diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 3d2004c2d..31057f93d 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -548,6 +548,12 @@ class AbsTask(ABC):
             default=1,
             help="The number of gradient accumulation",
         )
+        group.add_argument(
+            "--bias_grad_times",
+            type=float,
+            default=1.0,
+            help="Scale factor for the gradients of contextual-related parameters",
+        )
         group.add_argument(
             "--no_forward_run",
             type=str2bool,
diff --git a/funasr/train/trainer.py b/funasr/train/trainer.py
index 7c187e999..405268a28 100644
--- a/funasr/train/trainer.py
+++ b/funasr/train/trainer.py
@@ -95,6 +95,7 @@ class TrainerOptions:
     use_pai: bool
     oss_bucket: Union[oss2.Bucket, None]
     batch_interval: int
+    bias_grad_times: float
 
 class Trainer:
     """Trainer having a optimizer.
@@ -546,8 +547,11 @@ class Trainer:
         no_forward_run = options.no_forward_run
         ngpu = options.ngpu
         use_wandb = options.use_wandb
+        bias_grad_times = options.bias_grad_times
         distributed = distributed_option.distributed
 
+        if bias_grad_times != 1.0:
+            logging.warning("Using bias_grad_times: {} for gradient scaling".format(bias_grad_times))
         if log_interval is None:
             try:
                 log_interval = max(len(iterator) // 20, 10)
@@ -690,6 +694,16 @@
                         scale_factor=0.55,
                     )
 
+                # for contextual training
+                if bias_grad_times != 1.0:
+                    # contextual-related parameter names
+                    cr_pnames = ["bias_encoder", "bias_embed", "decoder.bias_decoder", "decoder.bias_output"]
+                    for name, param in model.named_parameters():
+                        for cr_pname in cr_pnames:
+                            if cr_pname in name:
+                                param.grad *= bias_grad_times
+                                break
+
                 # compute the gradient norm to check if it is normal or not
                 grad_norm = torch.nn.utils.clip_grad_norm_(
                     model.parameters(),
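
Note on the gradient scaling added in funasr/train/trainer.py above: after the backward pass and before gradient clipping, the gradients of the contextual (hotword) modules (bias_encoder, bias_embed, decoder.bias_decoder, decoder.bias_output) are multiplied by options.bias_grad_times, which defaults to 1.0 so existing recipes behave as before. The standalone sketch below illustrates the same idea outside of FunASR; the helper name, the toy model, the pattern list and the factor are illustrative only, and the param.grad is None check is an extra safety guard for frozen or unused parameters that the patch itself does not include.

import torch


def scale_selected_grads(model: torch.nn.Module, patterns, factor: float) -> None:
    """Multiply the gradients of parameters whose names match any pattern by factor.

    Call it after loss.backward() and before clip_grad_norm_ / optimizer.step(),
    which is where the patch hooks into Trainer.train_one_epoch.
    """
    if factor == 1.0:
        return  # nothing to scale
    for name, param in model.named_parameters():
        if param.grad is None:
            continue  # frozen or unused parameter in this step
        if any(p in name for p in patterns):
            param.grad *= factor


# Toy usage with a placeholder pattern and factor; in the patch the patterns are
# the contextual module names and the factor comes from --bias_grad_times.
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
loss = model(torch.randn(3, 4)).sum()
loss.backward()
scale_selected_grads(model, patterns=["0."], factor=2.0)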