diff --git a/funasr/train/trainer.py b/funasr/train/trainer.py index b12bdeda0..4f2e28cac 100644 --- a/funasr/train/trainer.py +++ b/funasr/train/trainer.py @@ -39,7 +39,7 @@ from funasr.torch_utils.add_gradient_noise import add_gradient_noise from funasr.torch_utils.device_funcs import to_device from funasr.torch_utils.recursive_op import recursive_average from funasr.torch_utils.set_all_random_seed import set_all_random_seed -from funasr.train.abs_espnet_model import AbsESPnetModel +from funasr.models.base_model import FunASRModel from funasr.train.distributed_utils import DistributedOption from funasr.train.reporter import Reporter from funasr.train.reporter import SubReporter @@ -165,7 +165,7 @@ class Trainer: @classmethod def run( cls, - model: AbsESPnetModel, + model: FunASRModel, optimizers: Sequence[torch.optim.Optimizer], schedulers: Sequence[Optional[AbsScheduler]], train_iter_factory: AbsIterFactory,