Skip to content

Commit

Permalink
Merge pull request #1 from qupath/dnn-model
Browse files Browse the repository at this point in the history
Use DnnModel as optional input rather than ImageOp
  • Loading branch information
petebankhead committed Aug 7, 2021
2 parents a0470e1 + b09a16e commit fd7a9ec
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 40 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ This adds support for running the 2D version of StarDist nucleus detection devel

It is intended for the (at the time of writing) not-yet-released QuPath v0.3, and remains in a not-quite-complete state.

> **Note:** The implementation has changed from QuPath v0.2.
> Nucleus classifications are now supported, but tile padding is not.
> See the documentation for more details.
## Installing

To install the StarDist extension, download the latest `qupath-extension-stardist-[version].jar` file and drag it onto QuPath when it is running.
Expand Down
142 changes: 102 additions & 40 deletions src/main/java/qupath/ext/stardist/StarDist2D.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import java.util.Objects;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -68,7 +69,6 @@
import qupath.lib.objects.classes.PathClass;
import qupath.lib.objects.classes.PathClassFactory;
import qupath.lib.regions.ImagePlane;
import qupath.lib.regions.Padding;
import qupath.lib.regions.RegionRequest;
import qupath.lib.roi.GeometryTools;
import qupath.lib.roi.interfaces.ROI;
Expand Down Expand Up @@ -108,7 +108,7 @@ public static class Builder {
private int nThreads = -1;

private String modelPath = null;
private ImageOp predictionOp = null;
private DnnModel<?> dnn = null;
private ColorTransform[] channels = new ColorTransform[0];

private double threshold = 0.5;
Expand Down Expand Up @@ -136,8 +136,6 @@ public static class Builder {

private boolean constrainToParent = true;

private int pad = 32;

private List<ImageOp> ops = new ArrayList<>();

private boolean includeProbability = false;
Expand All @@ -147,8 +145,8 @@ private Builder(String modelPath) {
this.ops.add(ImageOps.Core.ensureType(PixelType.FLOAT32));
}

private Builder(ImageOp predictionOp) {
this.predictionOp = predictionOp;
private Builder(DnnModel<?> dnn) {
this.dnn = dnn;
this.ops.add(ImageOps.Core.ensureType(PixelType.FLOAT32));
}

Expand Down Expand Up @@ -476,9 +474,12 @@ public Builder tileSize(int tileWidth, int tileHeight) {
* Amount to pad tiles to reduce boundary artifacts.
* @param pad padding in pixels; width and height of tiles will be increased by pad x 2.
* @return this builder
* @deprecated padding is no longer supported and this method has no effect
*/
@Deprecated
public Builder padding(int pad) {
this.pad = pad;
logger.warn("Tile padding is no longer supported - this method has no effect");
// this.pad = pad;
return this;
}

Expand Down Expand Up @@ -542,20 +543,18 @@ public Builder inputScale(double... values) {
public StarDist2D build() {
var stardist = new StarDist2D();

var padding = pad > 0 ? Padding.symmetric(pad) : Padding.empty();
// var padding = pad > 0 ? Padding.symmetric(pad) : Padding.empty();
var mergedOps = new ArrayList<>(ops);
var mlOp = this.predictionOp;
if (mlOp == null) {
var dnn = this.dnn;
if (dnn == null) {
var file = new File(modelPath);
if (!file.exists()) {
throw new IllegalArgumentException("I couldn't find the model file " + file.getAbsolutePath());
// TODO: In the future, search within the user directory
}
if (file.isFile()) {
try {
DnnModel<?> dnn = DnnTools.builder(modelPath)
dnn = DnnTools.builder(modelPath)
.build();
mlOp = ImageOps.ML.dnn(dnn, tileWidth, tileHeight, padding);
logger.debug("Loaded model {} with OpenCV DNN", modelPath);
} catch (Exception e) {
logger.error("Unable to load model file with OpenCV. If you intended to use TensorFlow, you need to have it on the classpath & provide the "
Expand All @@ -568,31 +567,29 @@ public StarDist2D build() {
// For backwards compatibility, we try to support TensorFlow if the extension is installed
var clsTF = Class.forName("qupath.ext.tensorflow.TensorFlowTools");
var method = clsTF.getMethod("createDnnModel", String.class);
var dnn = (DnnModel<?>)method.invoke(null, modelPath);
mlOp = ImageOps.ML.dnn(dnn, tileWidth, tileHeight, padding);
// var clsTF = Class.forName("qupath.tensorflow.TensorFlowTools");
// var method = clsTF.getMethod("createOp", String.class, int.class, int.class, Padding.class);
// mlOp = (ImageOp)method.invoke(null, modelPath, tileWidth, tileHeight, padding);
dnn = (DnnModel<?>)method.invoke(null, modelPath);
logger.debug("Loaded model {} with TensorFlow", modelPath);
// mlOp = TensorFlowTools.createOp(modelPath, tileWidth, tileHeight, padding);
} catch (Exception e) {
logger.error("Unable to load TensorFlow with reflection - are you sure it is available and on the classpath?");
logger.error(e.getLocalizedMessage(), e);
throw new RuntimeException("Unable to load StarDist model from " + modelPath, e);
}
}
}
mergedOps.add(mlOp);

// var mlOp = ImageOps.ML.dnn(dnn, tileWidth, tileHeight, padding);
// mergedOps.add(mlOp);
mergedOps.add(ImageOps.Core.ensureType(PixelType.FLOAT32));

stardist.op = ImageOps.buildImageDataOp(channels)
.appendOps(mergedOps.toArray(ImageOp[]::new));
stardist.dnn = dnn;
stardist.threshold = threshold;
stardist.pixelSize = pixelSize;
stardist.cellConstrainScale = cellConstrainScale;
stardist.cellExpansion = cellExpansion;
stardist.tileWidth = tileWidth-pad*2;
stardist.tileHeight = tileHeight-pad*2;
stardist.tileWidth = tileWidth;
stardist.tileHeight = tileHeight;
stardist.includeProbability = includeProbability;
stardist.ignoreCellOverlaps = ignoreCellOverlaps;
stardist.measureShape = measureShape;
Expand Down Expand Up @@ -624,6 +621,7 @@ public StarDist2D build() {
private double threshold;

private ImageDataOp op;
private DnnModel<?> dnn;
private double pixelSize;
private double cellExpansion;
private double cellConstrainScale;
Expand All @@ -648,7 +646,8 @@ public StarDist2D build() {
private Collection<ObjectMeasurements.Compartments> compartments;
private Collection<ObjectMeasurements.Measurements> measurements;


private AtomicBoolean firstRun = new AtomicBoolean(true);
private boolean cancelRuns = false;


/**
Expand Down Expand Up @@ -713,7 +712,8 @@ private void detectObjectsImpl(ImageData<BufferedImage> imageData, Collection<?
parents.stream().forEach(p -> detectObjects(imageData, p, false));
else
parents.parallelStream().forEach(p -> detectObjects(imageData, p, false));
// Fire a globel update event

// Fire a global update event
imageData.getHierarchy().fireHierarchyChangedEvent(imageData.getHierarchy());
}

Expand All @@ -728,9 +728,17 @@ private void detectObjectsImpl(ImageData<BufferedImage> imageData, Collection<?
private void detectObjectsImpl(ImageData<BufferedImage> imageData, PathObject parent, boolean fireUpdate) {
Objects.nonNull(parent);
// Lock early, so the user doesn't make modifications
boolean wasLocked = parent.isLocked();
parent.setLocked(true);

List<PathObject> detections = detectObjects(imageData, parent.getROI());
List<PathObject> detections = detectObjects(imageData, parent.getROI());

if (cancelRuns) {
logger.warn("StarDist detection cancelled for {}", parent);
if (!wasLocked)
parent.setLocked(false);
return;
}

parent.clearPathObjects();
parent.addPathObjects(detections);
Expand All @@ -754,6 +762,7 @@ public List<PathObject> detectObjects(ImageData<BufferedImage> imageData, ROI ro
resolution = resolution.createScaledInstance(downsample, downsample);
}
var opServer = ImageOps.buildServer(imageData, op, resolution, tileWidth, tileHeight);
// var opServer = ImageOps.buildServer(imageData, op, resolution, tileWidth-pad*2, tileHeight-pad*2);

RegionRequest request;
if (roi == null)
Expand All @@ -766,6 +775,8 @@ public List<PathObject> detectObjects(ImageData<BufferedImage> imageData, ROI ro

var tiles = opServer.getTileRequestManager().getTileRequests(request);
var mask = roi == null ? null : roi.getGeometry();



// Detect all potential nuclei
var server = imageData.getServer();
Expand All @@ -782,9 +793,12 @@ public List<PathObject> detectObjects(ImageData<BufferedImage> imageData, ROI ro
else
log("Detecting nuclei");
var nuclei = tiles.parallelStream()
.flatMap(t -> detectObjectsForTile(op, imageData, t.getRegionRequest(), tiles.size() > 1, mask).stream())
.flatMap(t -> detectObjectsForTile(op, dnn, imageData, t.getRegionRequest(), tiles.size() > 1, mask).stream())
.collect(Collectors.toList());

if (cancelRuns)
return Collections.emptyList();

// Filter nuclei again if we need to for resolving tile overlaps
if (tiles.size() > 1) {
log("Resolving nucleus overlaps");
Expand Down Expand Up @@ -981,10 +995,16 @@ private static Mat extractChannels(Mat mat, int... channels) {



private List<PotentialNucleus> detectObjectsForTile(ImageDataOp op, ImageData<BufferedImage> imageData, RegionRequest request, boolean excludeOnBounds, Geometry mask) {
private List<PotentialNucleus> detectObjectsForTile(ImageDataOp op, DnnModel<?> dnn, ImageData<BufferedImage> imageData, RegionRequest request, boolean excludeOnBounds, Geometry mask) {

List<PotentialNucleus> nuclei;

if (Thread.currentThread().isInterrupted())
cancelRuns = true;

if (cancelRuns)
Collections.emptyList();

try (@SuppressWarnings("unchecked")
var scope = new PointerScope()) {
Mat mat;
Expand All @@ -995,21 +1015,63 @@ private List<PotentialNucleus> detectObjectsForTile(ImageDataOp op, ImageData<Bu
return Collections.emptyList();
}

boolean isFirstRun = firstRun.getAndSet(false);

var output = dnn.convertAndPredict(Map.of(DnnModel.DEFAULT_INPUT_NAME, mat));
Mat matProb = null;
Mat matRays = null;
Mat matClassifications = null;
if (output.size() == 1) {
// Split channels to extract probability, ray and (possibly) classification images
var matOutput = output.values().iterator().next();
int nChannels = matOutput.channels();
int nClassifications = classifications == null ? 0 : classifications.size();
int nRays = nChannels - 1 - nClassifications;
matProb = extractChannels(matOutput, 0);
matRays = extractChannels(matOutput, range(1, nRays+1));
matClassifications = nClassifications == 0 ? null : extractChannels(matOutput, range(nRays+1, nChannels));
} else {
// Split output as needed
// We require that probabilities are single-channel, and there are more rays than classifications
for (var entry : output.entrySet()) {
var temp = entry.getValue();
if (temp.channels() == 1)
matProb = temp;
else if (matRays == null)
matRays = temp;
else {
if (temp.channels() > matRays.channels()) {
matClassifications = matRays;
matRays = temp;
} else
matClassifications = temp;
}
}
}

// Warn if we have weird dimensions on the first run
if (isFirstRun) {
if (classifications != null && !classifications.isEmpty()) {
int nClassifications = classifications.size();
int nChannels = matClassifications == null ? 0 : matClassifications.channels();
// We might not specify a background classification, but if we have very different numbers from the prediction we should report that
if (nClassifications > nChannels || nClassifications < nChannels-1)
logger.warn("{} classifications provided, {} available in the prediction", nClassifications, nChannels);
else
logger.debug("{} classifications provided, {} available in the prediction", nClassifications, nChannels);
}
}

// TODO: May need to consider padding!

// OpenCVTools.matToImagePlus(mat, "Prediction " + request).show();

// Split channels to extract probability, ray and (possibly) classification images
int nChannels = mat.channels();
int nClassifications = classifications == null ? 0 : classifications.size();
int nRays = nChannels - 1 - nClassifications;
Mat matProb = extractChannels(mat, 0);
Mat matRays = extractChannels(mat, range(1, nRays+1));
Mat matClassifications = nClassifications == 0 ? null : extractChannels(mat, range(nRays+1, nChannels));

// Depending upon model export, we might have a half resolution prediction that needs to be rescaled
long inputWidth = Math.round(request.getWidth() / request.getDownsample());
long inputHeight = Math.round(request.getHeight() / request.getDownsample());
double scaleX = Math.round(inputWidth / mat.cols());
double scaleY = Math.round(inputHeight / mat.rows());
double scaleX = Math.round((double)inputWidth / matProb.cols());
double scaleY = Math.round((double)inputHeight / matProb.rows());
if (scaleX != 1.0 || scaleY != 1.0) {
if (scaleX != 2.0 || scaleY != 2.0)
logger.warn("Unexpected StarDist rescaling x={}, y={}", scaleX, scaleY);
Expand Down Expand Up @@ -1051,14 +1113,14 @@ public static Builder builder(String modelPath) {


/**
* Create a builder to customize detection parameters, using a provided op for prediction.
* Create a builder to customize detection parameters, using a provided {@link DnnModel} for prediction.
* This provides a way to use an alternative machine learning library and model file, rather than the default
* (OpenCV or TensorFlow).
* @param predictionOp the op to use for prediction
* @param dnn the model to use for prediction
* @return
*/
public static Builder builder(ImageOp predictionOp) {
return new Builder(predictionOp);
public static Builder builder(DnnModel<?> dnn) {
return new Builder(dnn);
}


Expand Down

0 comments on commit fd7a9ec

Please sign in to comment.