From bab00341a9c4072015dde0d94d301bf60cd16f3b Mon Sep 17 00:00:00 2001 From: Alex Herbert Date: Mon, 11 Mar 2024 21:55:54 +0000 Subject: [PATCH] Allow fitting single component data --- ...eNormalMixtureExpectationMaximization.java | 6 +- ...malMixtureExpectationMaximizationTest.java | 139 ++++++++++-------- 2 files changed, 84 insertions(+), 61 deletions(-) diff --git a/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java b/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java index 8b51195ab..a3c7397d2 100644 --- a/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java +++ b/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java @@ -294,7 +294,7 @@ public void fit(MixtureMultivariateNormalDistribution initialMixture) * @return Multivariate normal mixture model estimated from the data * @throws NumberIsTooLargeException if {@code numComponents} is greater * than the number of data rows. - * @throws NumberIsTooSmallException if {@code numComponents < 2}. + * @throws NumberIsTooSmallException if {@code numComponents < 1}. * @throws NotStrictlyPositiveException if data has less than 2 rows * @throws DimensionMismatchException if rows of data have different numbers * of columns @@ -306,8 +306,8 @@ public static MixtureMultivariateNormalDistribution estimate(final double[][] da if (data.length < 2) { throw new NotStrictlyPositiveException(data.length); } - if (numComponents < 2) { - throw new NumberIsTooSmallException(numComponents, 2, true); + if (numComponents < 1) { + throw new NumberIsTooSmallException(numComponents, 1, true); } if (numComponents > data.length) { throw new NumberIsTooLargeException(numComponents, data.length, true); diff --git a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java index 801cff467..281036456 100644 --- a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java +++ b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java @@ -191,104 +191,127 @@ public void testInitialMixture() { } @Test - public void testFit() { - // Test that the loglikelihood, weights, and models are determined and - // fitted correctly + public void testFit2Dimensions2Components() { final double[][] data = getTestSamples(); - final double correctLogLikelihood = -4.292431006791994; - final double[] correctWeights = new double[] { 0.2962324189652912, 0.7037675810347089 }; - final double[][] correctMeans = new double[][]{ - {-1.4213112715121132, 1.6924690505757753}, - {4.213612224374709, 7.975621325853645} - }; + // Fit using the test samples using Matlab R2023b (Update 6): + // GMModel = fitgmdist(X,2); - final RealMatrix[] correctCovMats = new Array2DRowRealMatrix[2]; - correctCovMats[0] = new Array2DRowRealMatrix(new double[][] { - { 1.739356907285747, -0.5867644251487614 }, - { -0.5867644251487614, 1.0232932029324642 } } - ); - correctCovMats[1] = new Array2DRowRealMatrix(new double[][] { - { 4.245384898007161, 2.5797798966382155 }, - { 2.5797798966382155, 3.9200272522448367 } }); + // Expected results use the component order generated by the CM code for convenience + // i.e. ComponentProportion from matlab is reversed: [0.703722, 0.296278] - final MultivariateNormalDistribution[] correctMVNs = new MultivariateNormalDistribution[2]; - correctMVNs[0] = new MultivariateNormalDistribution(correctMeans[0], correctCovMats[0].getData()); - correctMVNs[1] = new MultivariateNormalDistribution(correctMeans[1], correctCovMats[1].getData()); + // NegativeLogLikelihood (CM code use the positive log-likehood divided by the number of observations) + final double logLikelihood = -4.292430883324220e+02 / data.length; + // ComponentProportion + final double[] weights = new double[] {0.2962324189652912, 0.7037675810347089}; + // mu + final double[][] means = new double[][]{ + {-1.421239458366293, 1.692604555824222}, + {4.213949861591596, 7.975974466776790} + }; + // Sigma + final double[][][] covar = new double[][][] { + {{1.739441346307267, -0.586740858187563}, + {-0.586740858187563, 1.023420964341543}}, + {{4.243780645051973, 2.578176622652551}, + {2.578176622652551, 3.918302056479298}} + }; - MultivariateNormalMixtureExpectationMaximization fitter - = new MultivariateNormalMixtureExpectationMaximization(data); + assertFit(data, 2, logLikelihood, weights, means, covar, 1e-3); + } - MixtureMultivariateNormalDistribution initialMix - = MultivariateNormalMixtureExpectationMaximization.estimate(data, 2); - fitter.fit(initialMix); - MixtureMultivariateNormalDistribution fittedMix = fitter.getFittedModel(); - List> components = fittedMix.getComponents(); + @Test + public void testFit1Dimension2Components() { + // Use only the first column of the test data + final double[][] data = Arrays.stream(getTestSamples()) + .map(x -> new double[] {x[0]}).toArray(double[][]::new); + + // Fit the first column of test samples using Matlab R2023b (Update 6): + // GMModel = fitgmdist(X,2); - Assert.assertEquals(correctLogLikelihood, - fitter.getLogLikelihood(), - Math.ulp(1d)); + // NegativeLogLikelihood (CM code use the positive log-likehood divided by the number of observations) + final double logLikelihood = -2.512197016873482e+02 / data.length; + // ComponentProportion + final double[] weights = new double[] {0.240510201974078, 0.759489798025922}; + // Since data has 1 dimension the means and covariances are single values + // mu + final double[][] means = new double[][]{ + {-1.736139126623031}, + {3.899886984922886} + }; + // Sigma + final double[][][] covar = new double[][][] { + {{1.371327786710623}}, + {{5.254286022455004}} + }; - int i = 0; - for (Pair component : components) { - final double weight = component.getFirst(); - final MultivariateNormalDistribution mvn = component.getSecond(); - final double[] mean = mvn.getMeans(); - final RealMatrix covMat = mvn.getCovariances(); - Assert.assertEquals(correctWeights[i], weight, Math.ulp(1d)); - Assert.assertArrayEquals(correctMeans[i], mean, 0.0); - Assert.assertEquals(correctCovMats[i], covMat); - i++; - } + assertFit(data, 2, logLikelihood, weights, means, covar, 0.05); } @Test - public void testFit1() { - // Test that the fit can be performed on data with a single dimension + public void testFit1Dimension1Component() { // Use only the first column of the test data final double[][] data = Arrays.stream(getTestSamples()) .map(x -> new double[] {x[0]}).toArray(double[][]::new); // Fit the first column of test samples using Matlab R2023b (Update 6): - // GMModel = fitgmdist(X,2); + // GMModel = fitgmdist(X,1); // NegativeLogLikelihood (CM code use the positive log-likehood divided by the number of observations) - final double correctLogLikelihood = -2.512197016873482e+02 / data.length; + final double logLikelihood = -2.576329329354790e+02 / data.length; // ComponentProportion - final double[] correctWeights = new double[] {0.240510201974078, 0.759489798025922}; + final double[] weights = new double[] {1.0}; // Since data has 1 dimension the means and covariances are single values // mu - final double[] correctMeans = new double[] {-1.736139126623031, 3.899886984922886}; + final double[][] means = new double[][]{ + {2.544365206503801}, + }; // Sigma - final double[] correctCov = new double[] {1.371327786710623, 5.254286022455004}; + final double[][][] covar = new double[][][] { + {{10.122711799089901}}, + }; + assertFit(data, 1, logLikelihood, weights, means, covar, 1e-3); + } + + private static void assertFit(double[][] data, int numComponents, + double logLikelihood, double[] weights, + double[][] means, double[][][] covar, double relError) { MultivariateNormalMixtureExpectationMaximization fitter = new MultivariateNormalMixtureExpectationMaximization(data); MixtureMultivariateNormalDistribution initialMix - = MultivariateNormalMixtureExpectationMaximization.estimate(data, 2); + = MultivariateNormalMixtureExpectationMaximization.estimate(data, numComponents); fitter.fit(initialMix); MixtureMultivariateNormalDistribution fittedMix = fitter.getFittedModel(); List> components = fittedMix.getComponents(); - final double relError = 0.05; - Assert.assertEquals(correctLogLikelihood, - fitter.getLogLikelihood(), - Math.abs(correctLogLikelihood) * relError); + Assert.assertEquals(logLikelihood, + fitter.getLogLikelihood(), + Math.abs(logLikelihood) * relError); int i = 0; for (Pair component : components) { final double weight = component.getFirst(); final MultivariateNormalDistribution mvn = component.getSecond(); - final double[] mean = mvn.getMeans(); - final RealMatrix covMat = mvn.getCovariances(); - Assert.assertEquals(correctWeights[i], weight, correctWeights[i] * relError); - Assert.assertEquals(correctMeans[i], mean[0], Math.abs(correctMeans[i]) * relError); - Assert.assertEquals(correctCov[i], covMat.getEntry(0, 0), correctCov[i] * relError); + Assert.assertEquals(weights[i], weight, weights[i] * relError); + assertArrayEquals(means[i], mvn.getMeans(), relError); + final double[][] c = mvn.getCovariances().getData(); + Assert.assertEquals(covar[i].length, c.length); + for (int j = 0; j < covar[i].length; j++) { + assertArrayEquals(covar[i][j], c[j], relError); + } i++; } } + private static void assertArrayEquals(double[] e, double[] a, double relError) { + Assert.assertEquals("length", e.length, a.length); + for (int i = 0; i < e.length; i++) { + Assert.assertEquals(e[i], a[i], Math.abs(e[i]) * relError); + } + } + private double[][] getTestSamples() { // generated using R Mixtools rmvnorm with mean vectors [-1.5, 2] and // [4, 8.2]