-
Notifications
You must be signed in to change notification settings - Fork 23
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] KMeans clustering #29
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
package com.picnicml.doddlemodel.base | ||
|
||
import java.io.Serializable | ||
|
||
import com.picnicml.doddlemodel.data.{Features, Target} | ||
|
||
/** Base type for clustering estimators.
  *
  * The type parameter `A` is the concrete clusterer type, so that members such as
  * [[copy]] can return that concrete type.
  *
  * @tparam A the concrete clusterer type extending this class
  */
abstract class Clusterer[A <: Clusterer[A]] extends Estimator {
  this: A with Serializable =>

  /** Returns a new clusterer identical to this one. */
  protected def copy: A

  /** Prediction hook for subclasses; [[predict]] only invokes it after the fitted check passed. */
  protected def predictSafe(x: Features): Target

  /** Predicts a target for the given features.
    *
    * @param x the features to predict for
    * @return the result of [[predictSafe]]
    * @throws IllegalArgumentException if the model is not fitted
    */
  def predict(x: Features): Target = {
    require(isFitted, "Called predict on a model that is not trained yet")
    predictSafe(x)
  }
}
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
package com.picnicml.doddlemodel.clustering | ||
|
||
import java.io.Serializable | ||
|
||
import breeze.linalg.functions.euclideanDistance | ||
import breeze.linalg.{*, DenseMatrix, DenseVector, argmin, sum} | ||
import breeze.stats.mean | ||
import com.picnicml.doddlemodel.data.loadCsvDataset | ||
import com.picnicml.doddlemodel.data.{Features, Target} | ||
|
||
import scala.util.Random | ||
import scala.concurrent.ExecutionContext.Implicits.global | ||
import scala.annotation.tailrec | ||
import scala.concurrent.duration.Duration | ||
import scala.concurrent.{Await, Future} | ||
|
||
|
||
/** K-means clustering: partitions the rows of a feature matrix into `k` clusters.
  *
  * Instances are immutable; fitting returns a new instance that carries the learned
  * cluster centers and the cluster index assigned to every training row.
  *
  * @param k number of clusters
  * @param maxIterations upper bound on the number of assignment/update steps
  * @param earlyStoppingPercentage fitting stops as soon as the fraction of rows whose label
  *                                changed during a step is at most this value
  * @param clusterCenters learned centers, one center per matrix row; None until fitted
  * @param labels cluster index of each training row; None until fitted
  */
@SerialVersionUID(1L)
class KMeansClustering private (val k: Int,
                                val maxIterations: Int,
                                val earlyStoppingPercentage: Double,
                                val clusterCenters: Option[DenseMatrix[Double]],
                                val labels: Option[DenseVector[Int]])
  extends RandomizableClusterer[KMeansClustering] with Serializable {

  /** Assigns every row of `x` to the index of its nearest learned cluster center.
    *
    * NOTE(review): assumes Target is a DenseVector[Double] (as the default argument of `fit`
    * suggests), so cluster indices are returned as doubles -- confirm.
    */
  override protected def predictSafe(x: Features): Target = {
    val centers = this.clusterCenters.getOrElse(
      throw new IllegalStateException("Cluster centers are not available on a model that is not trained yet"))
    val assigned = x(*, ::).toIndexedSeq.map(row => nearestCenter(row.t, centers).toDouble)
    DenseVector[Double](assigned.toArray)
  }

  override def isFitted: Boolean = this.clusterCenters.isDefined && this.labels.isDefined

  override def copy: KMeansClustering =
    new KMeansClustering(this.k, this.maxIterations, this.earlyStoppingPercentage, this.clusterCenters, this.labels)

  /** Index of the row of `centers` that is closest (euclidean distance) to `point`. */
  private def nearestCenter(point: DenseVector[Double], centers: DenseMatrix[Double]): Int =
    argmin(centers(*, ::).map(center => euclideanDistance(point, center)))

  /** Runs assignment/update steps until `iter` steps were made or the labels stabilised.
    *
    * @param x training features, one example per row
    * @param iter number of steps that may still be performed
    * @param clusterCenters centers produced by the previous step
    * @param labels labels produced by the previous step
    * @return the final cluster centers and the label of every row of `x`
    */
  @tailrec
  private def step(x: Features, iter: Int, clusterCenters: DenseMatrix[Double], labels: DenseVector[Int]): (DenseMatrix[Double], DenseVector[Int]) = {
    if (iter == 0) {
      (clusterCenters, labels)
    } else {
      // start from the previous centers so that a cluster which receives no rows in this step
      // keeps its position instead of collapsing to the origin
      val newClusterCenters = clusterCenters.copy

      // every row is assigned in its own Future; the results are awaited right below
      val newLabelsFutures = x(*, ::).toIndexedSeq map { row =>
        Future {
          nearestCenter(row.t, clusterCenters)
        }
      }

      val newLabels: DenseVector[Int] = DenseVector[Int](Await.result(Future.sequence(newLabelsFutures), Duration.Inf).toArray)
      newLabels.toScalaVector.zipWithIndex.groupBy(_._1).foreach {
        case (labelIdx, groupedLabels) =>
          val rows = x(groupedLabels.map(_._2), ::)
          newClusterCenters(labelIdx, ::) := mean(rows(::, *))
      }

      val sameValuesCount = sum((newLabels :== labels).mapValues(b => if (b) 1 else 0))
      // floating point division on purpose: Int / Int would truncate the fraction to 0 or 1
      val changedFraction = (x.rows - sameValuesCount).toDouble / x.rows
      if (changedFraction <= this.earlyStoppingPercentage) {
        (newClusterCenters, newLabels)
      } else {
        step(x, iter - 1, newClusterCenters, newLabels)
      }
    }
  }

  /** Fits the clusterer on `x` and returns a new, fitted instance.
    *
    * The initial centers are `k` distinct rows of `x` chosen with `rand`.
    *
    * @throws IllegalArgumentException if `x` has fewer than `k` rows
    */
  override def fitSafe(x: Features)(implicit rand: Random): KMeansClustering = {
    require(x.rows >= this.k, "Number of examples must be greater or equal to the number of clusters")
    // sample without replacement, otherwise two centers could start on the very same row
    val seedRows = rand.shuffle((0 until x.rows).toVector).take(this.k)
    val initialClusterCenters = x(seedRows, ::).toDenseMatrix
    // -1 never equals a real cluster index, so every row counts as changed in the first step
    val unassigned = DenseVector.fill[Int](x.rows){-1}
    val (clusterCenters, labels) = step(x, this.maxIterations, initialClusterCenters, unassigned)
    new KMeansClustering(this.k, this.maxIterations, this.earlyStoppingPercentage, Some(clusterCenters), Some(labels))
  }
}
|
||
object KMeansClustering {

  /** Creates an unfitted k-means clusterer.
    *
    * @param k number of clusters, at least 2
    * @param maxIterations maximum number of iterations, at least 1
    * @param earlyStoppingPercentage early stopping threshold, between 0 and 1 (inclusive)
    * @throws IllegalArgumentException if any argument is outside of its valid range
    */
  def apply(k: Int, maxIterations: Int = 300, earlyStoppingPercentage: Double = 0.01): KMeansClustering = {
    require(k >= 2, "Number of clusters must be greater or equal to 2")
    require(maxIterations >= 1, "Number of maximum iterations must be greater or equal to 1")
    require(earlyStoppingPercentage >= 0.0 && earlyStoppingPercentage <= 1.0, "Early stopping percentage must be between 0 and 1.")
    new KMeansClustering(k, maxIterations, earlyStoppingPercentage, None, None)
  }

  /** Ad-hoc timing run: fits 10 clusters on a csv dataset and prints the size of every cluster.
    *
    * @param args optional first argument: path of the csv dataset; when absent the original
    *             hard-coded path is used
    */
  def main(args: Array[String]): Unit = {
    val datasetPath = args.headOption.getOrElse("/Users/rok/Downloads/mnist_train.csv")
    val data = loadCsvDataset(datasetPath, headerLine = false)
    // the first column is dropped (presumably the label column -- confirm) and values are scaled by 255
    val xTr = data(::, 1 to -1) / 255.0
    println("Starting...")
    val t0 = System.nanoTime()
    val kmeans = KMeansClustering(10).fit(xTr, DenseVector[Double]())
    val elapsedSeconds = (System.nanoTime() - t0) / Math.pow(10, 9)
    println(s"$elapsedSeconds seconds")

    kmeans.labels.foreach { fittedLabels =>
      fittedLabels.toArray.zipWithIndex.groupBy(_._1).foreach {
        case (labelIdx, groupedLabels) =>
          println(s"$labelIdx: ${groupedLabels.length}")
      }
    }
  }
}
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
package com.picnicml.doddlemodel.clustering | ||
|
||
import java.io.Serializable | ||
|
||
import breeze.linalg.DenseVector | ||
import com.picnicml.doddlemodel.base.Clusterer | ||
import com.picnicml.doddlemodel.data.{Features, Target} | ||
|
||
import scala.util.Random | ||
|
||
/** A clusterer whose fitting procedure depends on a source of randomness.
  *
  * @tparam A the concrete clusterer type extending this trait
  */
trait RandomizableClusterer[A <: RandomizableClusterer[A]] extends Clusterer[A] {
  this: A with Serializable =>

  /** A function that creates an identical clusterer. */
  protected def copy: A

  /** Fits a copy of this clusterer on `x` and returns it; this instance is left untouched.
    *
    * Public so that users can actually train a clusterer.
    *
    * @param x training features
    * @param y not used by this method; kept so that existing call sites passing a target still compile
    * @param rand source of randomness handed to [[fitSafe]]
    * @throws IllegalArgumentException if this clusterer is already fitted
    */
  def fit(x: Features, y: Target = DenseVector[Double]())(implicit rand: Random = new Random()): A = {
    require(!this.isFitted, "Called fit on a model that is already trained")
    this.copy.fitSafe(x)
  }

  /**
    * The object is guaranteed not to be fitted.
    */
  protected def fitSafe(x: Features)(implicit rand: Random): A
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`Clusterer` currently only has a single publicly exposed function: `def predict(x: Features): Target`. Additionally, I don't think all clustering algorithms will expose `predict`, e.g. DBSCAN doesn't. On the other hand, all of them probably need `fit` exposed?