This commit is contained in:
speech_asr 2023-04-11 00:20:54 +08:00
parent df662541a8
commit 7f3e5bb5fb

View File

@ -39,7 +39,7 @@ from funasr.torch_utils.add_gradient_noise import add_gradient_noise
from funasr.torch_utils.device_funcs import to_device from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.recursive_op import recursive_average from funasr.torch_utils.recursive_op import recursive_average
from funasr.torch_utils.set_all_random_seed import set_all_random_seed from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.train.abs_espnet_model import AbsESPnetModel from funasr.models.base_model import FunASRModel
from funasr.train.distributed_utils import DistributedOption from funasr.train.distributed_utils import DistributedOption
from funasr.train.reporter import Reporter from funasr.train.reporter import Reporter
from funasr.train.reporter import SubReporter from funasr.train.reporter import SubReporter
@ -165,7 +165,7 @@ class Trainer:
@classmethod @classmethod
def run( def run(
cls, cls,
model: AbsESPnetModel, model: FunASRModel,
optimizers: Sequence[torch.optim.Optimizer], optimizers: Sequence[torch.optim.Optimizer],
schedulers: Sequence[Optional[AbsScheduler]], schedulers: Sequence[Optional[AbsScheduler]],
train_iter_factory: AbsIterFactory, train_iter_factory: AbsIterFactory,