This is a minimal implementation of DETR using JAX and Flax.
Updates:
- Supports a Sinkhorn solver based on the latest OTT package (50-100% faster training for roughly the same final AP).
- Parallel bipartite matching for all auxiliary outputs (up to 30% faster training with the Hungarian matcher).
- Uses the `optax` API.
- Bug fixes to match the official DETR implementation.
- Supports the BigTransfer (BiT-S) ResNet-50 backbone.
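The Sinkhorn and parallel-matching updates above can be illustrated with a small sketch. This is a plain NumPy, log-domain Sinkhorn solver for entropic optimal transport. It is not the OTT package's API and not this repo's matcher; the function names `sinkhorn` and `logsumexp`, the default `epsilon`, and the iteration count are illustrative assumptions. Cost matrices for the final and auxiliary decoder outputs are stacked along a leading batch axis, so every layer is matched in a single call, which mirrors (under these assumptions) the idea of matching all auxiliary outputs in parallel. Both marginals are uniform, and a hard match is read off the plan with a per-row argmax, a simple rounding that is not guaranteed to be a permutation.

```python
import numpy as np


def logsumexp(x, axis):
    """Numerically stable log(sum(exp(x))) along `axis` (the axis is dropped)."""
    x_max = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(x_max, axis=axis) + np.log(np.sum(np.exp(x - x_max), axis=axis))


def sinkhorn(cost, epsilon=0.05, num_iters=200):
    """Entropic-OT transport plans for a batch of cost matrices (hypothetical helper).

    cost: array of shape [..., n, m]; leading axes are batch axes, e.g. one
      query-to-target cost matrix per decoder layer (final + auxiliary).
    Returns plans of the same shape whose rows sum to 1/n and columns to 1/m
    (columns exactly after the last update, rows up to convergence).
    """
    n, m = cost.shape[-2:]
    log_a = np.full(cost.shape[:-1], -np.log(n))         # uniform row marginals
    log_b = np.full(cost.shape[:-2] + (m,), -np.log(m))  # uniform column marginals
    f = np.zeros_like(log_a)  # dual potential for queries (rows)
    g = np.zeros_like(log_b)  # dual potential for targets (columns)
    for _ in range(num_iters):
        # Alternate log-domain updates: fit row sums, then column sums.
        f = epsilon * (log_a - logsumexp((g[..., None, :] - cost) / epsilon, axis=-1))
        g = epsilon * (log_b - logsumexp((f[..., :, None] - cost) / epsilon, axis=-2))
    # Transport plan P_ij = exp((f_i + g_j - C_ij) / epsilon).
    return np.exp((f[..., :, None] + g[..., None, :] - cost) / epsilon)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Hypothetical example: 6 decoder layers, 8 queries, 8 (padded) targets.
    costs = rng.uniform(size=(6, 8, 8))
    plans = sinkhorn(costs)
    # Hard match per query: the target receiving the most mass in the plan.
    matches = np.argmax(plans, axis=-1)
    print(matches.shape)  # -> (6, 8)
```

A smaller `epsilon` makes the plan closer to a hard assignment but typically needs more iterations to converge. Porting the sketch to `jax.numpy` (for example, replacing the Python loop with `jax.lax.fori_loop`) should be straightforward, but is not shown here.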
Setup:
- Clone the repo and install the requirements:
  ```shell
  $> git clone https://github.com/MasterSkepticista/detr.git && cd detr
  # Create a python>=3.10 venv
  $> pip install -U pip setuptools wheel
  $> pip install -r requirements.txt
  ```
- You may need to download the MS-COCO dataset via TFDS. Run the following to download it and create the TFRecords:
  ```shell
  $> python -c "import tensorflow_datasets as tfds; tfds.load('coco/2017')"
  ```
- Download and extract `instances_val2017.json` from MS-COCO into the root directory of this repo (or update `config.annotations_loc` in the config).
Set `config.pretrained_backbone_configs.checkpoint_path` in the common config file (`configs/common.py`) to one of the backbone checkpoints below:
Backbone | Top-1 Acc. | Checkpoint |
---|---|---|
BiT-R50x1-i1k | 76.8% | Link |
R50x1-i1k (from torchvision) | 76.1% | Link (created using this gist) |
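The two paths mentioned above, `config.annotations_loc` and `config.pretrained_backbone_configs.checkpoint_path`, can be set roughly as follows. This is a minimal sketch: the paths are placeholders, and it assumes both fields are ordinary `ml_collections.ConfigDict` attributes. The standalone `ConfigDict()` objects are created here only so the fragment runs on its own.

```python
import ml_collections

# Stand-in for the ConfigDict built in configs/common.py (assumption: in the
# real file `config` already exists, so only the assignments below are needed).
config = ml_collections.ConfigDict()

# Backbone weights, e.g. one of the checkpoints in the table above (placeholder
# path). If `pretrained_backbone_configs` already exists in the file, set only
# `checkpoint_path` rather than re-creating the sub-config.
config.pretrained_backbone_configs = ml_collections.ConfigDict()
config.pretrained_backbone_configs.checkpoint_path = '/path/to/backbone_checkpoint'

# MS-COCO val2017 annotation file (placeholder path; by default it is expected
# in the repo root).
config.annotations_loc = './instances_val2017.json'
```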
Training:
```shell
# Trains the default DETR-R50-1333 model.
# Roughly 3.5 days on 8x 3090s.
$> python main.py \
  --config configs/hungarian.py --workdir artifacts/`date '+%m-%d_%H%M'`
```
Checkpoints (all non-DC5 variants) using the torchvision R50 backbone:
Checkpoint | GFLOPs | AP | AP50 | AP75 | APs | APm | APl |
---|---|---|---|---|---|---|---|
DETR-R50-1333* | 174.2 | 40.80 | 61.88 | 42.45 | 19.2 | 44.31 | 60.32 |
DETR-R50-640 | 38.5 | 33.14 | 52.89 | 34.00 | 10.54 | 35.10 | 55.53 |
*Official DETR baseline, except that these models were trained for 300 epochs instead of 500.
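Evaluation:
- Download one of the pretrained checkpoints.
- Point the config at it:
  ```python
  # In configs/common.py (or any other config), where `config` is the
  # ml_collections.ConfigDict built by that file.
  import ml_collections

  config.init_from = ml_collections.ConfigDict()
  config.init_from.checkpoint_path = '/path/to/checkpoint'
  ```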
- Replace `config.total_epochs` with `config.total_steps = 0` to skip training and go straight to evaluation.
Large parts of this codebase were inspired by Scenic.
Official PyTorch implementation by the DETR authors: facebookresearch/detr.