mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update
This commit is contained in:
parent
007a8dc289
commit
6569b95025
@ -1150,7 +1150,6 @@ class AbsTask(ABC):
|
||||
def main_worker(cls, args: argparse.Namespace):
|
||||
assert check_argument_types()
|
||||
|
||||
args.ngpu = 0
|
||||
# 0. Init distributed process
|
||||
distributed_option = build_dataclass(DistributedOption, args)
|
||||
# Setting distributed_option.dist_rank, etc.
|
||||
@ -1253,13 +1252,9 @@ class AbsTask(ABC):
|
||||
raise RuntimeError(
|
||||
f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
|
||||
)
|
||||
#model = model.to(
|
||||
# dtype=getattr(torch, args.train_dtype),
|
||||
# device="cuda" if args.ngpu > 0 else "cpu",
|
||||
#)
|
||||
model = model.to(
|
||||
dtype=getattr(torch, args.train_dtype),
|
||||
device="cpu",
|
||||
device="cuda" if args.ngpu > 0 else "cpu",
|
||||
)
|
||||
for t in args.freeze_param:
|
||||
for k, p in model.named_parameters():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user