Support mmtrack with NPU backend.
luomaoling committed Apr 24, 2023
1 parent fde36ee commit dd3b605
Showing 3 changed files with 5 additions and 4 deletions.
mmtrack/apis/train.py (1 addition, 1 deletion)
@@ -12,7 +12,7 @@

from mmtrack.core import DistEvalHook, EvalHook
from mmtrack.datasets import build_dataloader
-from mmtrack.utils import build_dp, gbuild_ddp, get_root_logger
+from mmtrack.utils import build_dp, build_ddp, get_root_logger


def init_random_seed(seed=None, device='cuda'):
mmtrack/utils/util_distribution.py (2 additions, 1 deletion)
@@ -9,6 +9,7 @@

def build_dp(model, device='cuda', dim=0, *args, **kwargs):
"""build DataParallel module by device type.
if device is cuda, return a MMDataParallel model; if device is npu,
return a NPUDataParallel model.
Args:
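
Note: the docstring above describes the device dispatch that build_dp performs. A minimal sketch of such a dispatch is shown below; it assumes mmcv also ships an NPUDataParallel class in mmcv.device.npu (by analogy with the NPUDistributedDataParallel import later in this file) and that torch_npu provides model.npu(). The real build_dp in mmtrack/utils/util_distribution.py may differ in details.

    from mmcv.parallel import MMDataParallel

    def build_dp_sketch(model, device='cuda', dim=0, *args, **kwargs):
        """Wrap `model` in the DataParallel flavour matching `device`."""
        if device == 'cuda':
            # Standard CUDA path: MMDataParallel from core mmcv.
            return MMDataParallel(model.cuda(), *args, dim=dim, **kwargs)
        if device == 'npu':
            # Ascend path: NPUDataParallel is assumed to live in
            # mmcv.device.npu; model.npu() requires torch_npu.
            from mmcv.device.npu import NPUDataParallel
            return NPUDataParallel(model.npu(), *args, dim=dim, **kwargs)
        raise ValueError(f'Unsupported device type: {device}')
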
@@ -43,7 +44,7 @@ def build_ddp(model, device='cuda', *args, **kwargs):
    .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel.
       DistributedDataParallel.html
    """
-    assert device in ['cuda', 'mlu', 'npu'], 'Only available for cuda or npu devices.'
+    assert device in ['cuda', 'npu'], 'Only available for cuda or npu devices.'
    if device == 'npu':
        from mmcv.device.npu import NPUDistributedDataParallel
        torch.npu.set_compile_mode(jit_compile=False)
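
Note: the hunk above is truncated after the set_compile_mode call. A hedged sketch of how an npu branch of this kind typically completes in OpenMMLab-style build_ddp helpers is shown below; the model.npu() call and constructor arguments are assumptions by analogy with the cuda branch, not lines taken from this commit.

    import torch

    def build_ddp_npu_sketch(model, *args, **kwargs):
        # torch.npu is only present once torch_npu is installed and imported.
        from mmcv.device.npu import NPUDistributedDataParallel
        # Disable JIT compilation on the Ascend NPU before wrapping the model.
        torch.npu.set_compile_mode(jit_compile=False)
        # Move the model to the NPU and wrap it for distributed training.
        return NPUDistributedDataParallel(model.npu(), *args, **kwargs)
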
tools/train.py (2 additions, 2 deletions)
@@ -17,7 +17,7 @@
from mmtrack.apis import init_random_seed
from mmtrack.core import setup_multi_processes
from mmtrack.datasets import build_dataset
-from mmtrack.utils import collect_env, get_root_logger, get_device
+from mmtrack.utils import collect_env, get_device, get_root_logger


def parse_args():
@@ -176,7 +176,7 @@ def main():
    logger.info(f'Set random seed to {cfg.seed}, '
                f'deterministic: {deterministic}')

-    cfg.device = get_device()
+    cfg.device = get_device() if cfg.get('device', None) is None else cfg.device

    set_random_seed(cfg.seed, deterministic=deterministic)
    meta['seed'] = cfg.seed
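
Note: the new line keeps any device set explicitly in the config and only falls back to auto-detection when it is unset. A rough sketch of what a get_device helper of this kind typically checks is shown below; the actual mmtrack.utils.get_device may differ.

    import torch

    def get_device_sketch():
        # Prefer an Ascend NPU when torch_npu is installed and a device is visible.
        if hasattr(torch, 'npu') and torch.npu.is_available():
            return 'npu'
        if torch.cuda.is_available():
            return 'cuda'
        return 'cpu'

With a helper like this, cfg.device resolves to 'npu' on an Ascend machine unless the config already pins a device.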
