
DETR: End-to-End Object Detection with Transformers.

This is a minimal implementation of DETR using JAX and Flax.

[Figure: DETR architecture]

Updates:

  • Supports a Sinkhorn solver based on the latest OTT package (50-100% faster training for roughly the same final AP); see the sketch after this list.
  • Parallel bipartite matching for all auxiliary outputs (up to 30% faster training with the Hungarian matcher).
  • Uses the optax API.
  • Bug fixes to match the official DETR implementation.
  • Supports BigTransfer (BiT-S) ResNet-50 backbone.
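
A minimal sketch of what the Sinkhorn-based soft matching mentioned in the first bullet could look like with the OTT (ott-jax) package. Module paths follow recent OTT releases; the cost-matrix shape, values, and epsilon below are placeholders, not settings taken from this repo:

    import jax.numpy as jnp
    from ott.geometry import geometry
    from ott.problems.linear import linear_problem
    from ott.solvers.linear import sinkhorn

    # Pairwise matching cost between predictions and targets (placeholder values).
    cost = jnp.ones((100, 20))  # [num_queries, num_targets]
    geom = geometry.Geometry(cost_matrix=cost, epsilon=0.05)
    out = sinkhorn.Sinkhorn()(linear_problem.LinearProblem(geom))
    coupling = out.matrix  # soft (entropy-regularized) assignment matrix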

Getting Started

  • Setup:

    $> git clone https://github.com/MasterSkepticista/detr.git && cd detr
    # Create and activate a Python >= 3.10 virtual environment
    $> pip install -U pip setuptools wheel
    $> pip install -r requirements.txt
  • Download the MS-COCO dataset in TFDS format. Run the following to download it and create the TFRecords:

    $> python -c "import tensorflow_datasets as tfds; tfds.load('coco/2017')"
  • Download and extract instances_val2017.json from MS-COCO into the root directory of this repo (or point config.annotations_loc in the config to its location; a sketch follows below).
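
A minimal sketch of the annotation-path override (the field name comes from the bullet above; the path is a placeholder):

    # In the config (sketch; replace with the actual path)
    config.annotations_loc = '/path/to/instances_val2017.json'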

Train

Set config.pretrained_backbone_configs.checkpoint_path in the common config file; a sketch follows the table below.

Backbone                       Top-1 Acc.   Checkpoint
BiT-R50x1-i1k                  76.8%        Link
R50x1-i1k (from torchvision)   76.1%        Link (created using this gist)
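
A minimal sketch of the backbone checkpoint setting (the field name comes from this README, the ml_collections usage mirrors the Evaluate section, and the path is a placeholder):

    # In configs/common.py (sketch; replace with the actual path)
    config.pretrained_backbone_configs = ml_collections.ConfigDict()
    config.pretrained_backbone_configs.checkpoint_path = '/path/to/backbone_checkpoint'
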
# Trains the default DETR-R50-1333 model.
# Roughly 3.5 days on 8x 3090s.
$> python main.py \
   --config configs/hungarian.py --workdir artifacts/`date '+%m-%d_%H%M'`

Evaluate

Checkpoints (all non-DC5 variants) using the torchvision R50 backbone:

Checkpoint       GFLOPs   $AP$    $AP_{50}$   $AP_{75}$   $AP_S$   $AP_M$   $AP_L$
DETR-R50-1333*   174.2    40.80   61.88       42.45       19.2     44.31    60.32
DETR-R50-640     38.5     33.14   52.89       34.00       10.54    35.10    55.53

*Official DETR baseline configuration, except that these models were trained for 300 epochs instead of 500.

  1. Download one of the pretrained checkpoints and point the config to it:
    # In configs/common.py (or any config file)
    config.init_from = ml_collections.ConfigDict()
    config.init_from.checkpoint_path = '/path/to/checkpoint'
  2. Replace config.total_epochs with config.total_steps = 0 to skip training and run evaluation only.
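
With the checkpoint path set and config.total_steps = 0, evaluation can presumably be launched through the same entry point as training (this assumes main.py drops straight into the eval loop when no training steps remain; the workdir name is a placeholder):

    $> python main.py \
       --config configs/hungarian.py --workdir artifacts/eval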

Acknowledgements

Large parts of this codebase were inspired by scenic.

The authors' original PyTorch implementation: facebookresearch/detr.