diff --git a/mmtrack/apis/train.py b/mmtrack/apis/train.py index b0668f88d..ad80be097 100644 --- a/mmtrack/apis/train.py +++ b/mmtrack/apis/train.py @@ -12,7 +12,7 @@ from mmtrack.core import DistEvalHook, EvalHook from mmtrack.datasets import build_dataloader -from mmtrack.utils import build_dp, gbuild_ddp, get_root_logger +from mmtrack.utils import build_dp, build_ddp, get_root_logger def init_random_seed(seed=None, device='cuda'): diff --git a/mmtrack/utils/util_distribution.py b/mmtrack/utils/util_distribution.py index 6e8de4dcd..3ca4cf751 100644 --- a/mmtrack/utils/util_distribution.py +++ b/mmtrack/utils/util_distribution.py @@ -9,6 +9,7 @@ def build_dp(model, device='cuda', dim=0, *args, **kwargs): """build DataParallel module by device type. + if device is cuda, return a MMDataParallel model; if device is npu, return a NPUDataParallel model. Args: @@ -43,7 +44,7 @@ def build_ddp(model, device='cuda', *args, **kwargs): .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel. DistributedDataParallel.html """ - assert device in ['cuda', 'mlu', 'npu'], 'Only available for cuda or npu devices.' + assert device in ['cuda', 'npu'], 'Only available for cuda or npu devices.' if device == 'npu': from mmcv.device.npu import NPUDistributedDataParallel torch.npu.set_compile_mode(jit_compile=False) diff --git a/tools/train.py b/tools/train.py index 8e18ec4a1..754139f9b 100644 --- a/tools/train.py +++ b/tools/train.py @@ -17,7 +17,7 @@ from mmtrack.apis import init_random_seed from mmtrack.core import setup_multi_processes from mmtrack.datasets import build_dataset -from mmtrack.utils import collect_env, get_root_logger, get_device +from mmtrack.utils import collect_env, get_device, get_root_logger def parse_args(): @@ -176,7 +176,7 @@ def main(): logger.info(f'Set random seed to {cfg.seed}, ' f'deterministic: {deterministic}') - cfg.device = get_device() + cfg.device = get_device() if cfg.get('device', None) is None else cfg.device set_random_seed(cfg.seed, deterministic=deterministic) meta['seed'] = cfg.seed