mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update lr and bias_grad_times
This commit is contained in:
parent
fa7297855d
commit
653fffdf29
@ -31,6 +31,6 @@ if __name__ == '__main__':
|
||||
params.dataset_type = "large" # finetune contextual paraformer模型只能使用large dataset
|
||||
params.batch_bins = 200000 # batch size,如果dataset_type="small",batch_bins单位为fbank特征帧数,如果dataset_type="large",batch_bins单位为毫秒,
|
||||
params.max_epoch = 20 # 最大训练轮数
|
||||
params.lr = 0.00005 # 设置学习率
|
||||
params.lr = 0.0002 # 设置学习率
|
||||
|
||||
modelscope_finetune(params)
|
||||
@ -548,6 +548,12 @@ class AbsTask(ABC):
|
||||
default=1,
|
||||
help="The number of gradient accumulation",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bias_grad_times",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="To scale the gradient of contextual related params",
|
||||
)
|
||||
group.add_argument(
|
||||
"--no_forward_run",
|
||||
type=str2bool,
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
|
||||
"""Trainer module."""
|
||||
import argparse
|
||||
from audioop import bias
|
||||
from contextlib import contextmanager
|
||||
import dataclasses
|
||||
from dataclasses import is_dataclass
|
||||
@ -95,6 +96,7 @@ class TrainerOptions:
|
||||
use_pai: bool
|
||||
oss_bucket: Union[oss2.Bucket, None]
|
||||
batch_interval: int
|
||||
bias_grad_times: float
|
||||
|
||||
class Trainer:
|
||||
"""Trainer having a optimizer.
|
||||
@ -546,8 +548,11 @@ class Trainer:
|
||||
no_forward_run = options.no_forward_run
|
||||
ngpu = options.ngpu
|
||||
use_wandb = options.use_wandb
|
||||
bias_grad_times = options.bias_grad_times
|
||||
distributed = distributed_option.distributed
|
||||
|
||||
if bias_grad_times != 1.0:
|
||||
logging.warning("Using bias_grad_times: {} for gradient scaling".format(bias_grad_times))
|
||||
if log_interval is None:
|
||||
try:
|
||||
log_interval = max(len(iterator) // 20, 10)
|
||||
@ -690,6 +695,16 @@ class Trainer:
|
||||
scale_factor=0.55,
|
||||
)
|
||||
|
||||
# for contextual training
|
||||
if bias_grad_times != 1.0:
|
||||
# contextual related parameter names
|
||||
cr_pnames = ["bias_encoder", "bias_embed", "decoder.bias_decoder", "decoder.bias_output"]
|
||||
for name, param in model.named_parameters():
|
||||
for cr_pname in cr_pnames:
|
||||
if cr_pname in name:
|
||||
param.grad *= bias_grad_times
|
||||
continue
|
||||
|
||||
# compute the gradient norm to check if it is normal or not
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
model.parameters(),
|
||||
|
||||
Loading…
Reference in New Issue
Block a user