2541 adds whats new 0.6 (#2550)
* adds whats new 0.6

Signed-off-by: Wenqi Li <[email protected]>

* update docs

Signed-off-by: Wenqi Li <[email protected]>

* add a list

Signed-off-by: Wenqi Li <[email protected]>

* update according to comments

Signed-off-by: Wenqi Li <[email protected]>

* fixes typos

Signed-off-by: Wenqi Li <[email protected]>

* Revert "update lmdbdataset (#2531)"

This reverts commit a980ae4.

Signed-off-by: Wenqi Li <[email protected]>
wyli committed Jul 8, 2021
1 parent 9a83660 commit 0ad9e73
Showing 6 changed files with 103 additions and 54 deletions.
4 changes: 2 additions & 2 deletions docs/source/highlights.md
@@ -390,8 +390,8 @@ Distributed data parallel is an important feature of PyTorch to connect multiple

### 3. C++/CUDA optimized modules
To further accelerate the domain-specific routines in the workflows, MONAI C++/CUDA implementations are introduced as extensions of the PyTorch native implementations.
MONAI provides the modules using [the two ways of building C++ extensions from PyTorch](https://pytorch.org/tutorials/advanced/cpp_extension.html#jit-compiling-extensions):
- via `setuptools`, for modules such as `Resampler`, `Conditional random field (CRF)`, `Fast bilateral filtering using the permutohedral lattice`,
MONAI provides the modules using [the two ways of building C++ extensions from PyTorch](https://pytorch.org/tutorials/advanced/cpp_extension.html#custom-c-and-cuda-extensions):
- via `setuptools`, for modules including `Resampler`, `Conditional random field (CRF)`, `Fast bilateral filtering using the permutohedral lattice`.
- via just-in-time (JIT) compilation, for the `Gaussian mixtures` module. This approach allows for dynamic optimisation according to the user-specified parameters and local system environments.
The following figure shows results of MONAI's Gaussian mixture models applied to tissue and surgical tools segmentation:
![Gaussian mixture models as a postprocessing step](../images/gmm_feature_set_comparison_s.png)
3 changes: 2 additions & 1 deletion docs/source/index.rst
@@ -45,7 +45,8 @@ Technical documentation is available at `docs.monai.io <https://docs.monai.io>`_
:maxdepth: 1
:caption: Feature highlights

whatsnew.md
whatsnew_0_6.md
whatsnew_0_5.md
highlights.md

.. toctree::
7 changes: 6 additions & 1 deletion docs/source/whatsnew.md → docs/source/whatsnew_0_5.md
@@ -1,4 +1,9 @@
# What's new in 0.5 🎉
# What's new in 0.5

- Invert spatial transforms and test-time augmentations
- Lesion detection in digital pathology
- DeepGrow modules for interactive segmentation
- Various usability improvements

## Invert spatial transforms and test-time augmentations
It is often desirable to invert the previously applied spatial transforms (resize, flip, rotate, zoom, crop, pad, etc.) within the deep learning workflows, for example, to return to the original imaging space after processing the image data in a normalized data space. In v0.5 we enhance almost all the spatial transforms with an `inverse` operation and release this experimental feature. Users can easily invert all the spatial transforms for one transformed data item or a batch of data items. This can also be achieved within the workflows by using the `TransformInverter` handler.
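
For a single dictionary item, the pattern looks roughly like the following minimal sketch (an in-memory array with illustrative keys and shapes, assuming v0.5+ invertible dictionary transforms; it is not the full workflow from the tutorials):

```python
import numpy as np
from monai.transforms import Compose, Flipd, RandRotate90d

# a channel-first 3D volume wrapped in a dictionary item
data = {"image": np.random.rand(1, 64, 64, 32).astype(np.float32)}

transform = Compose([
    Flipd(keys="image", spatial_axis=0),
    RandRotate90d(keys="image", prob=1.0, spatial_axes=(0, 1)),
])

item = transform(data)               # forward: flip, then rotate
restored = transform.inverse(item)   # inverse: undo the rotation, then undo the flip
```

The applied operations are tracked alongside the data item, which is what allows `inverse` to replay them in reverse order.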
62 changes: 62 additions & 0 deletions docs/source/whatsnew_0_6.md
@@ -0,0 +1,62 @@
# What's new in 0.6 🎉🎉

- Decollating mini-batches as an essential post-processing step
- Pythonic APIs to load the pretrained models from Clara Train MMARs
- Enhancements of the base metric interfaces
- C++/CUDA extension modules via PyTorch JIT compilation
- Backward compatibility and enhanced continuous integration/continuous delivery


## Decollating mini-batches as an essential post-processing step
`decollate batch` is introduced in MONAI v0.6 to simplify the post-processing transforms and enable flexible operations on a batch of model outputs.
It decollates batched data (e.g. model inference predictions) into a list of tensors -- as an 'inverse' of the PyTorch data loader's `collate_fn` -- with benefits such as:
- enabling post-processing transforms to run on each item independently; for example, randomised transforms can be applied differently to each predicted item in a batch.
- simplifying the transform APIs and reducing the input validation burden, because both the pre-processing and post-processing transforms now only support the "channel-first" input format.
- enabling the transform inverse operation for data items of different original shapes, as the inverted items are kept in a list instead of being stacked in a single tensor.
- allowing metrics to be computed flexibly from either a "batch-first" tensor or a list of "channel-first" tensors.

A typical `decollate batch` process is illustrated as follows (using model predictions and labels with `batch_size=N` as an example):
![decollate_batch](../images/decollate_batch.png)

The [decollate batch tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/decollate_batch.ipynb) shows a detailed usage example based on a PyTorch-native workflow.
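
A minimal sketch of the pattern is shown below; `net`, `val_loader`, and the `"image"`/`"label"` keys are placeholders for a trained segmentation network and a dictionary-based data loader, not part of the MONAI API:

```python
import torch
from monai.data import decollate_batch
from monai.transforms import Activations, AsDiscrete, Compose

# post-processing applied to one "channel-first" prediction at a time
post_pred = Compose([Activations(softmax=True), AsDiscrete(argmax=True)])

net.eval()
with torch.no_grad():
    for batch in val_loader:                           # batch["image"]: (N, C, H, W[, D])
        preds = net(batch["image"])                    # batch-first prediction tensor
        pred_list = decollate_batch(preds)             # -> list of N channel-first tensors
        pred_list = [post_pred(p) for p in pred_list]  # per-item post-processing
        label_list = decollate_batch(batch["label"])   # labels can be decollated the same way
```

The decollated lists can then be passed directly to the metrics APIs or to the transform inverse operations described below.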


## Pythonic APIs to load the pretrained models from Clara Train MMARs
[The MMAR (Medical Model ARchive)](https://docs.nvidia.com/clara/clara-train-sdk/pt/mmar.html)
defines a data structure for organizing all artifacts produced during the model development life cycle.
NVIDIA Clara provides [various MMARs of medical domain-specific models](https://ngc.nvidia.com/catalog/models?orderBy=scoreDESC&pageNumber=0&query=clara_pt&quickFilter=&filters=).
Each MMAR includes all the information about the model, such as the configurations and scripts that provide a workspace for model development tasks. To better leverage the trained MMARs released on NVIDIA GPU Cloud, MONAI provides Pythonic APIs to access them.
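
As a rough sketch, loading a pretrained network from an MMAR can look like the following (the model identifier is only an example -- check the NGC catalog for the exact names, and `monai.apps.mmars` for the full argument list):

```python
import torch
from monai.apps.mmars import load_from_mmar

# example identifier -- replace with the MMAR you need from the NGC catalog
mmar_name = "clara_pt_spleen_ct_segmentation_1"

# fetch and parse the MMAR if needed, build the network described in its
# configuration, and load the pretrained weights
model = load_from_mmar(
    item=mmar_name,
    mmar_dir="./mmars",
    map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    pretrained=True,
)
```

The returned `model` is a regular `torch.nn.Module`, so it can be fine-tuned or used for inference like any other MONAI or PyTorch network.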

To demonstrate this new feature, a medical image segmentation tutorial is available in
[`project-monai/tutorials`](https://github.com/Project-MONAI/tutorials/blob/master/modules/transfer_mmar.ipynb).
It produces the following figure, comparing the loss curves and validation scores over the training epochs for:
- training from scratch (the green line),
- applying the pretrained MMAR weights without training (the magenta line),
- training from the MMAR model weights (the blue line).

![transfer_mmar](../images/transfer_mmar.png)

The tutorial shows how these APIs encapsulate the details of MMAR parsing, as well as the potential of pretrained MMARs for transfer learning.
These APIs are also being integrated into AI-assisted interactive workflows to accelerate the manual annotation process (e.g. via [project-MONAI/MONAILabel](https://github.com/Project-MONAI/MONAILabel)).

## Enhancements of the base metric interfaces
The base API for metrics is now enhanced to support the essential computation logic for both iteration-based and epoch-based metrics.
With this update, the MONAI metrics module becomes more extensible, and thus a good starting point for customised metrics.
The APIs also support data-parallel computation by default and are designed with computational efficiency in mind: with the `Cumulative` base class, intermediate metric outcomes can be automatically buffered, accumulated, synced across distributed processes, and aggregated for the final results. The [multi-processing computation example](https://github.com/Project-MONAI/tutorials/blob/master/modules/compute_metric.py) shows how to compute metrics from saved predictions and labels in a multi-processing environment.
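
The buffered call / aggregate / reset pattern is roughly as follows (a minimal sketch using dummy binarised tensors in place of real decollated predictions and labels):

```python
import torch
from monai.metrics import DiceMetric

dice_metric = DiceMetric(include_background=True, reduction="mean")

# two dummy iterations, each providing a list of "channel-first" tensors
for _ in range(2):
    pred_list = [torch.randint(0, 2, (1, 64, 64)).float() for _ in range(4)]
    label_list = [torch.randint(0, 2, (1, 64, 64)).float() for _ in range(4)]
    dice_metric(y_pred=pred_list, y=label_list)  # per-iteration results buffered via `Cumulative`

mean_dice = dice_metric.aggregate().item()       # sync across processes (if any) and reduce
dice_metric.reset()                              # clear the buffers before the next evaluation
```

In a real workflow, the per-iteration lists would come from the `decollate batch` step described above.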

## C++/CUDA extension modules via PyTorch JIT compilation
To further accelerate the domain-specific routines in the workflows, MONAI C++/CUDA modules are introduced as extensions of the PyTorch native implementations.
MONAI now provides these modules using [the two ways of building C++ extensions from PyTorch](https://pytorch.org/tutorials/advanced/cpp_extension.html#custom-c-and-cuda-extensions):
- via `setuptools` (since MONAI v0.5), for modules including `Resampler`, `Conditional random field (CRF)`, and `Fast bilateral filtering using the permutohedral lattice`.
- via just-in-time (JIT) compilation (since MONAI v0.6), for the `Gaussian mixtures` module. This approach allows for dynamic optimisation according to the user-specified parameters and local system environments (a generic sketch of the JIT approach follows the figure below).
The following figure shows the results of MONAI's Gaussian mixture models applied to a tissue and surgical tool segmentation task:
![Gaussian mixture models as a postprocessing step](../images/gmm_feature_set_comparison_s.png)
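
The JIT route itself is a standard PyTorch mechanism. The toy extension below is not MONAI's actual build code -- just an illustration of the general shape, assuming a local C++ toolchain compatible with the installed PyTorch:

```python
import torch
from torch.utils.cpp_extension import load_inline

cpp_source = r"""
#include <torch/extension.h>

torch::Tensor scale(torch::Tensor x, double factor) {
  return x * factor;
}
"""

# compiled with the local toolchain on first use, then cached -- no setup.py step required
ext = load_inline(name="demo_jit_ext", cpp_sources=cpp_source, functions=["scale"])
print(ext.scale(torch.ones(3), 2.0))  # tensor([2., 2., 2.])
```

The trade-off is a one-off compilation at first use in exchange for kernels that can adapt to the runtime environment and user-specified parameters.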

## Backward compatibility and enhanced continuous integration/continuous delivery
Starting from this version, we experiment with basic backward-compatibility policies.
New utilities are introduced on top of the existing semantic versioning modules and the git branching model.

At the same time, we actively analyze efficient, scalable, and secure CI/CD solutions to accommodate fast and collaborative codebase development.

Although a complete mechanism is still under development, these changes are another essential step towards API-stable versions of MONAI, sustainable release cycles, and efficient open-source collaboration.
76 changes: 30 additions & 46 deletions monai/data/dataset.py
@@ -413,10 +413,7 @@ def __init__(
self.lmdb_kwargs = lmdb_kwargs or {}
if not self.lmdb_kwargs.get("map_size", 0):
self.lmdb_kwargs["map_size"] = 1024 ** 4 # default map_size
self._env = None
# lmdb is single-writer multi-reader by default
# the cache is created without multi-threading
self._read_env = self._fill_cache_start_reader(show_progress=self.progress)
self._read_env = None
print(f"Accessing lmdb file: {self.db_file.absolute()}.")

def set_data(self, data: Sequence):
@@ -425,56 +422,43 @@ def set_data(self, data: Sequence):
"""
super().set_data(data=data)
self._env = None
self._read_env = self._fill_cache_start_reader(show_progress=self.progress)
self._read_env = None

def _fill_cache_start_reader(self, show_progress=True):
"""
Check the LMDB cache and write the cache if needed. py-lmdb doesn't have a good support for concurrent write.
This method can be used with multiple processes, but it may have a negative impact on the performance.
Args:
show_progress: whether to show the progress bar if possible.
"""
def _fill_cache_start_reader(self):
# create cache
self.lmdb_kwargs["readonly"] = False
if self._env is None:
self._env = lmdb.open(path=f"{self.db_file}", subdir=False, **self.lmdb_kwargs)
env = self._env
if show_progress and not has_tqdm:
env = lmdb.open(path=f"{self.db_file}", subdir=False, **self.lmdb_kwargs)
if self.progress and not has_tqdm:
warnings.warn("LMDBDataset: tqdm is not installed. not displaying the caching progress.")
with env.begin(write=False) as search_txn:
for item in tqdm(self.data) if has_tqdm and show_progress else self.data:
key = self.hash_func(item)
done, retry, val = False, 5, None
while not done and retry > 0:
try:
with search_txn.cursor() as cursor:
for item in tqdm(self.data) if has_tqdm and self.progress else self.data:
key = self.hash_func(item)
done, retry, val = False, 5, None
while not done and retry > 0:
try:
with env.begin(write=True) as txn:
with txn.cursor() as cursor:
done = cursor.set_key(key)
if done:
continue
if done:
continue
if val is None:
val = self._pre_transform(deepcopy(item)) # keep the original hashed
val = pickle.dumps(val, protocol=self.pickle_protocol)
with env.begin(write=True) as txn:
txn.put(key, val)
done = True
except lmdb.MapFullError:
done, retry = False, retry - 1
size = env.info()["map_size"]
new_size = size * 2
warnings.warn(
f"Resizing the cache database from {int(size) >> 20}MB" f" to {int(new_size) >> 20}MB."
)
env.set_mapsize(new_size)
except lmdb.MapResizedError:
# the mapsize is increased by another process
# set_mapsize with a size of 0 to adopt the new size
env.set_mapsize(0)
if not done: # still has the map full error
txn.put(key, val)
done = True
except lmdb.MapFullError:
done, retry = False, retry - 1
size = env.info()["map_size"]
env.close()
raise ValueError(f"LMDB map size reached, increase size above current size of {size}.")
new_size = size * 2
warnings.warn(f"Resizing the cache database from {int(size) >> 20}MB to {int(new_size) >> 20}MB.")
env.set_mapsize(new_size)
except lmdb.MapResizedError:
# the mapsize is increased by another process
# set_mapsize with a size of 0 to adopt the new size,
env.set_mapsize(0)
if not done: # still has the map full error
size = env.info()["map_size"]
env.close()
raise ValueError(f"LMDB map size reached, increase size above current size of {size}.")
size = env.info()["map_size"]
env.close()
# read-only database env
@@ -492,7 +476,7 @@ def _cachecheck(self, item_transformed):
"""
if self._read_env is None:
self._read_env = self._fill_cache_start_reader(show_progress=False)
self._read_env = self._fill_cache_start_reader()
with self._read_env.begin(write=False) as txn:
data = txn.get(self.hash_func(item_transformed))
if data is None:
5 changes: 1 addition & 4 deletions tests/test_lmdbdataset.py
@@ -191,15 +191,12 @@ def test_shape(self, transform, expected_shape, kwargs=None):
"extra": os.path.join(tempdir, "test_extra2_new.nii.gz"),
},
]
dataset_postcached.set_data(data=test_data_new)
# test new exchanged cache content
if transform is None:
dataset_postcached.set_data(data=test_data_new)
self.assertEqual(dataset_postcached[0]["image"], os.path.join(tempdir, "test_image1_new.nii.gz"))
self.assertEqual(dataset_postcached[0]["label"], os.path.join(tempdir, "test_label1_new.nii.gz"))
self.assertEqual(dataset_postcached[1]["extra"], os.path.join(tempdir, "test_extra2_new.nii.gz"))
else:
with self.assertRaises(RuntimeError):
dataset_postcached.set_data(data=test_data_new) # filename list updated, files do not exist


@skip_if_windows
