
When More is Less: Incorporating Additional Datasets Can Hurt Performance By Introducing Spurious Correlations

Code for the paper "When More is Less: Incorporating Additional Datasets Can Hurt Performance By Introducing Spurious Correlations" by Rhys Compton, Lily Zhang, Aahlad Puli, and Rajesh Ranganath

arXiv preprint

High-Level Overview

Acknowledgements

The training harness is heavily based on the excellent ClinicalDG repo, which is in turn a modified version of DomainBed.

To replicate the experiments in the paper:

Step 0: Environment and Prerequisites

Run the following commands to clone this repo and create the Conda environment:
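A minimal sketch of those commands, assuming a Conda environment file named environment.yml at the repo root (the filename and environment name below are assumptions; check the repository for the actual files):

git clone https://github.com/basedrhys/ood-generalization.git
cd ood-generalization
conda env create -f environment.yml   # assumed filename
conda activate clinicaldg             # assumed environment name; use the name defined in the env file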

Step 1: Obtaining the Data

See DataSources.md for detailed instructions.

Step 2: Running Experiments

Experiments can be run using the same procedure as in the DomainBed framework, with a few additional adjustable data hyperparameters, which should be passed in as a JSON-formatted dictionary.

For example, to train a single model:

python -m clinicaldg.scripts.train \
    --algorithm ERM \
    --dataset eICUSubsampleUnobs \
    --es_method val \
    --hparams '{"eicu_architecture": "GRU", "eicu_subsample_g1_mean": 0.5, "eicu_subsample_g2_mean": 0.05}' \
    --output_dir /path/to/output

A detailed list of hparams available for each dataset can be found here.

We provide the bash scripts used for our main experiments in the bash_scripts directory. You will likely need to customize them, along with the launcher, to your compute environment.

W&B Sweeps

This codebase makes heavy use of Weights & Biases (W&B), both for tracking and recording results and for launching experiments via the Sweeps feature (alongside Slurm arrays).

The process for this is as follows:

  • Define your sweep hyperparameters in a .yaml file (e.g., this YAML file for image size experimentation)
  • Start the sweep: wandb sweep <yaml filename>
  • Create your Slurm array script (e.g., sweep.sbatch) -- the key parameter is --array=, which should be set to 0-<num parameter configs> (a minimal sketch of such a script appears after this list)
  • Start the Slurm array job: sbatch sweeps/sweep.sbatch
  • Sit back and watch GPUs go brrrr
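A minimal sketch of such a Slurm array script; the job name, resource requests, and the <entity>/<project>/<sweep_id> placeholder are assumptions to be replaced with values from your cluster and from the output of wandb sweep:

#!/bin/bash
#SBATCH --job-name=wandb-sweep
#SBATCH --array=0-31          # one task per parameter config; set to 0-<num parameter configs>
#SBATCH --gres=gpu:1          # assumed resource request; adjust to your cluster
#SBATCH --time=24:00:00       # assumed time limit

# Each array task claims one configuration from the sweep and runs it to completion.
wandb agent --count 1 <entity>/<project>/<sweep_id>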

Loading the Datasets

The following steps walk through the process of loading the datasets and applying hospital-label balancing.

  • Example wrapper script: sweeps/sweep.sbatch (this contains the paths to the dataset files and loads them via Singularity)

  • Main entrypoint: clinicaldg/scripts/train.py

  • Corresponding W&B .yaml file for hospital-label balancing: sweeps/2d_nurd_fix.yaml

    • The key parameter is --balance_resample "label_notest,under", which does label balancing with no target test hospital (i.e., it label-balances the two hospitals to each other) via undersampling. In hindsight, "hospital_label,under" would have been a more appropriate name. A sketch of this balancing logic appears after this list.

    • The --test_env flag is not used here; it is provided only as a placeholder

    • This file gives a range of example invocations of the script; e.g., the following will train a model for Pneumonia prediction on MIMIC and CXP, balancing them such that P(Y = 1 | Hospital = MIMIC) == P(Y = 1 | Hospital = CXP):

      python ./clinicaldg/scripts/train.py --max_steps 20001 --train_envs MIMIC,CXP --test_env MIMIC --balance_resample "label_notest,under" --binary_label Pneumonia
  • The codebase is very generalized (it is used for both the eICU task and CXR classification), so there is a lot of code to sift through. The path to dataset loading (and ultimately balancing) starts in train.py and then goes through:

    • dataset = ds_class(hparams, args), which invokes...
    • CXRBase __init__() (in clinicaldg/datasets.py)
    • Lines 411 to 447 in clinicaldg/datasets.py do the actual hospital-label balancing
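For intuition, here is a minimal pandas sketch of what hospital-label balancing via undersampling does; the function name and column names (hospital, Pneumonia) are illustrative assumptions, not the repo's actual code (see clinicaldg/datasets.py for that):

import pandas as pd

def undersample_to_match_prevalence(df, label_col="Pneumonia", env_col="hospital", seed=0):
    # Illustrative sketch only: equalize P(Y = 1 | hospital) across hospitals
    # by dropping positives. Target the smallest per-hospital prevalence so
    # every hospital can reach it by undersampling alone.
    target = df.groupby(env_col)[label_col].mean().min()
    balanced = []
    for _, group in df.groupby(env_col):
        pos = group[group[label_col] == 1]
        neg = group[group[label_col] == 0]
        # Solve n_pos / (n_pos + len(neg)) == target for n_pos, keeping all negatives.
        n_pos = min(len(pos), int(target * len(neg) / (1 - target)))
        balanced.append(pd.concat([pos.sample(n=n_pos, random_state=seed), neg]))
    return pd.concat(balanced).reset_index(drop=True)

# Toy example: two hospitals with prevalences 0.5 and 0.1.
df = pd.DataFrame({
    "hospital": ["MIMIC"] * 100 + ["CXP"] * 100,
    "Pneumonia": [1] * 50 + [0] * 50 + [1] * 10 + [0] * 90,
})
balanced = undersample_to_match_prevalence(df)
print(balanced.groupby("hospital")["Pneumonia"].mean())  # roughly equal, up to integer rounding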

If you want to use this data within another codebase, you could save the train/val/test DataFrames to CSV files after processing is finished, i.e., at line 496; these contain the file paths and labels and can easily be loaded into another Python project.
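A minimal sketch of that export, assuming the processed splits are held in DataFrames named train_df, val_df, and test_df at that point (the variable names are assumptions):

# Around line 496 of clinicaldg/datasets.py, once splitting and balancing are done.
# train_df / val_df / test_df are assumed names for the processed split DataFrames.
for name, split_df in [("train", train_df), ("val", val_df), ("test", test_df)]:
    split_df.to_csv(f"{name}_split.csv", index=False)  # stores file paths + labels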

Dataloader

The PyTorch dataset class can be found in clinicaldg/cxr/data.py, as can the get_dataset() function.

The other key class is the InfiniteDataLoader used in train.py (line 303). One of these is created for each hospital we're training on, and an equal-sized batch is sampled from each at every iteration.
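For intuition, here is a minimal self-contained sketch of that pattern; this is an illustrative reimplementation, not the repo's actual InfiniteDataLoader:

import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

class InfiniteLoader:
    # Wraps a DataLoader so iteration never ends: when one pass over the data
    # is exhausted, a fresh iterator is created and sampling continues.
    def __init__(self, dataset, batch_size):
        sampler = RandomSampler(dataset, replacement=True)
        self.loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
        self.iterator = iter(self.loader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.iterator)
        except StopIteration:
            self.iterator = iter(self.loader)  # restart when the pass ends
            return next(self.iterator)

# Toy example: two "hospitals", each a dataset of (features, label) tensors.
hospital_datasets = [
    TensorDataset(torch.randn(100, 8), torch.randint(0, 2, (100,)))
    for _ in range(2)
]
loaders = [InfiniteLoader(ds, batch_size=32) for ds in hospital_datasets]
for step in range(5):
    minibatches = [next(loader) for loader in loaders]  # one equal-sized batch per hospital
    # ... forward/backward pass over the concatenated minibatches ...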
