update lr and bias_grad_times

shixian.shi 2023-05-05 16:14:20 +08:00
parent fa7297855d
commit 653fffdf29
3 changed files with 22 additions and 1 deletion

View File

@@ -31,6 +31,6 @@ if __name__ == '__main__':
    params.dataset_type = "large"  # finetuning the contextual paraformer model can only use the large dataset type
    params.batch_bins = 200000  # batch size; if dataset_type="small", batch_bins is counted in fbank feature frames; if dataset_type="large", it is counted in milliseconds
    params.max_epoch = 20  # maximum number of training epochs
    params.lr = 0.00005  # set the learning rate (value before this commit)
    params.lr = 0.0002  # set the learning rate (value after this commit)
    modelscope_finetune(params)
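As a rough sanity check on these settings, a small standalone sketch (illustrative variable names, not part of the diff):

old_lr, new_lr = 0.00005, 0.0002        # this commit raises the finetuning learning rate by about 4x
batch_bins_ms = 200000                  # with dataset_type="large", batch_bins is measured in milliseconds
seconds_per_batch = batch_bins_ms / 1000
print(new_lr / old_lr, seconds_per_batch)   # ~4.0, and roughly 200 seconds of audio per batch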

View File

@@ -548,6 +548,12 @@ class AbsTask(ABC):
            default=1,
            help="The number of gradient accumulation",
        )
        group.add_argument(
            "--bias_grad_times",
            type=float,
            default=1.0,
            help="To scale the gradient of contextual related params",
        )
        group.add_argument(
            "--no_forward_run",
            type=str2bool,
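For reference, a minimal standalone sketch (a throwaway parser, not the project's AbsTask parser) showing how the new flag behaves:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--bias_grad_times",
    type=float,
    default=1.0,
    help="To scale the gradient of contextual related params",
)
args = parser.parse_args(["--bias_grad_times", "2.0"])
print(args.bias_grad_times)  # 2.0; leaving the flag at its default of 1.0 disables the scaling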

View File

@ -3,6 +3,7 @@
"""Trainer module."""
import argparse
from audioop import bias  # unused here: audioop.bias() operates on raw audio fragments, unrelated to the bias modules below
from contextlib import contextmanager
import dataclasses
from dataclasses import is_dataclass
@@ -95,6 +96,7 @@ class TrainerOptions:
    use_pai: bool
    oss_bucket: Union[oss2.Bucket, None]
    batch_interval: int
    bias_grad_times: float


class Trainer:
    """Trainer having an optimizer.
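As an aside, a minimal sketch (hypothetical names, not code from this repository) of how a parsed flag typically ends up in such an options dataclass:

import dataclasses
from argparse import Namespace

@dataclasses.dataclass
class Options:                          # stand-in for TrainerOptions; only the new field is shown
    bias_grad_times: float = 1.0

args = Namespace(bias_grad_times=2.0)   # pretend result of argument parsing
options = Options(**{f.name: getattr(args, f.name) for f in dataclasses.fields(Options)})
print(options.bias_grad_times)          # 2.0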
@@ -546,8 +548,11 @@ class Trainer:
        no_forward_run = options.no_forward_run
        ngpu = options.ngpu
        use_wandb = options.use_wandb
        bias_grad_times = options.bias_grad_times
        distributed = distributed_option.distributed

        if bias_grad_times != 1.0:
            logging.warning("Using bias_grad_times: {} for gradient scaling".format(bias_grad_times))

        if log_interval is None:
            try:
                log_interval = max(len(iterator) // 20, 10)
@@ -690,6 +695,16 @@ class Trainer:
                    scale_factor=0.55,
                )

                # for contextual training: scale gradients of the bias (hotword) modules
                if bias_grad_times != 1.0:
                    # contextual related parameter names
                    cr_pnames = ["bias_encoder", "bias_embed", "decoder.bias_decoder", "decoder.bias_output"]
                    for name, param in model.named_parameters():
                        if param.grad is None:
                            continue
                        for cr_pname in cr_pnames:
                            if cr_pname in name:
                                param.grad *= bias_grad_times
                                break

                # compute the gradient norm to check if it is normal or not
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(),
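Putting the pieces together, a self-contained sketch of the step this commit adds (the helper name and max_norm value are illustrative, not taken from the diff):

import torch

def scale_contextual_grads(model: torch.nn.Module, bias_grad_times: float) -> None:
    # Scale the gradients of the contextual (hotword/bias) parameters before clipping;
    # the substrings below mirror the cr_pnames list in the diff above.
    if bias_grad_times == 1.0:
        return
    cr_pnames = ["bias_encoder", "bias_embed", "decoder.bias_decoder", "decoder.bias_output"]
    for name, param in model.named_parameters():
        if param.grad is not None and any(p in name for p in cr_pnames):
            param.grad *= bias_grad_times

# usage inside a training step, mirroring the order in the diff: scale first, then clip
# scale_contextual_grads(model, bias_grad_times=2.0)
# grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)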