Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] KMeans clustering #29

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/main/scala/com/picnicml/doddlemodel/base/Clusterer.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package com.picnicml.doddlemodel.base

import java.io.Serializable

import com.picnicml.doddlemodel.data.{Features, Target}

abstract class Clusterer[A <: Clusterer[A]] extends Estimator {
  this: A with Serializable =>

  /** A function that creates an identical clusterer. */
  protected def copy: A

  /** Assigns a cluster label to every example in `x`; the model must already be fitted. */
  def predict(x: Features): Target = {
    require(isFitted, "Called predict on a model that is not trained yet")
    predictSafe(x)
  }

  /** A function that is guaranteed to be called on a fitted model. */
  protected def predictSafe(x: Features): Target
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package com.picnicml.doddlemodel.clustering

import java.io.Serializable

import breeze.linalg.functions.euclideanDistance
import breeze.linalg.{*, DenseMatrix, DenseVector, argmin, sum}
import breeze.stats.mean
import com.picnicml.doddlemodel.data.loadCsvDataset
import com.picnicml.doddlemodel.data.{Features, Target}

import scala.util.Random
import scala.concurrent.ExecutionContext.Implicits.global
import scala.annotation.tailrec
import scala.concurrent.duration.Duration
import scala.concurrent.{Await, Future}


@SerialVersionUID(1L)
class KMeansClustering private (val k: Int,
                                val maxIterations: Int,
                                val earlyStoppingPercentage: Double,
                                val clusterCenters: Option[DenseMatrix[Double]],
                                val labels: Option[DenseVector[Int]])
  extends RandomizableClusterer[KMeansClustering] with Serializable {

  /** Assigns each example in `x` to its nearest (euclidean) cluster center.
    * Guaranteed to be called on a fitted model, so `clusterCenters` is defined. */
  override protected def predictSafe(x: Features): Target = {
    val centers = this.clusterCenters.get
    // NOTE(review): assumes Target is DenseVector[Double] (as suggested by the default
    // `y` argument of fit) — confirm against com.picnicml.doddlemodel.data
    DenseVector(x(*, ::).toIndexedSeq.map { row =>
      argmin(centers(*, ::).map(center => euclideanDistance(row.t, center))).toDouble
    }.toArray)
  }

  override def isFitted: Boolean = this.clusterCenters.isDefined && this.labels.isDefined

  override def copy: KMeansClustering =
    new KMeansClustering(this.k, this.maxIterations, this.earlyStoppingPercentage, this.clusterCenters, this.labels)

  /** One Lloyd iteration: reassign every row to its nearest center, recompute centers,
    * then recurse until `iter` reaches 0 or the assignments converge. */
  @tailrec
  private def step(x: Features, iter: Int, clusterCenters: DenseMatrix[Double],
                   labels: DenseVector[Int]): (DenseMatrix[Double], DenseVector[Int]) = {
    if (iter == 0) {
      (clusterCenters, labels)
    } else {
      // nearest-center assignment for each row, computed in parallel
      // NOTE(review): runs on the global ExecutionContext (file-level import) — a caller-provided
      // execution context would be preferable in library code
      val newLabelsFutures = x(*, ::).toIndexedSeq map { row =>
        Future {
          argmin(clusterCenters(*, ::).map(center => euclideanDistance(row.t, center)))
        }
      }
      val newLabels: DenseVector[Int] =
        DenseVector[Int](Await.result(Future.sequence(newLabelsFutures), Duration.Inf).toArray)

      // start from the previous centers so that a cluster that lost all of its members keeps
      // its old center instead of collapsing to the zero vector (the zeros-initialized matrix
      // in the previous version caused exactly that)
      val newClusterCenters = clusterCenters.copy
      newLabels.toScalaVector.zipWithIndex.groupBy(_._1).foreach {
        case (labelIdx, groupedLabels) =>
          val rows = x(groupedLabels.map(_._2), ::)
          newClusterCenters(labelIdx, ::) := mean(rows(::, *))
      }

      val sameValuesCount = sum((newLabels :== labels).mapValues(b => if (b) 1 else 0))
      // fraction of points whose assignment changed this iteration; converged when it drops
      // below the threshold. The previous check `sameValuesCount / x.rows < threshold` used
      // integer division (always 0 unless all labels matched) and was inverted, so it stopped
      // after the first iteration.
      val changedFraction = (x.rows - sameValuesCount).toDouble / x.rows
      if (changedFraction < this.earlyStoppingPercentage) {
        (newClusterCenters, newLabels)
      } else {
        step(x, iter - 1, newClusterCenters, newLabels)
      }
    }
  }

  /** Fits k-means on `x`. The object is guaranteed not to be fitted. */
  override def fitSafe(x: Features)(implicit rand: Random): KMeansClustering = {
    // sample k *distinct* rows as initial centers; sampling with replacement (previous version)
    // could produce duplicate initial centers and hence degenerate clusters
    val initialIndices = rand.shuffle((0 until x.rows).toIndexedSeq).take(this.k)
    val initialClusterCenters = x(initialIndices, ::).toDenseMatrix
    // labels start at -1 so the first iteration can never trigger early stopping
    val (clusterCenters, labels) =
      step(x, this.maxIterations, initialClusterCenters, DenseVector.fill[Int](x.rows)(-1))
    new KMeansClustering(this.k, this.maxIterations, this.earlyStoppingPercentage,
      Some(clusterCenters), Some(labels))
  }
}

object KMeansClustering {

  /** Creates an unfitted k-means clusterer.
    *
    * @param k number of clusters, at least 2
    * @param maxIterations upper bound on Lloyd iterations, at least 1
    * @param earlyStoppingPercentage stop once the fraction of changed labels falls below this, in [0, 1]
    */
  def apply(k: Int, maxIterations: Int = 300, earlyStoppingPercentage: Double = 0.01): KMeansClustering = {
    require(k >= 2, "Number of clusters must be greater or equal to 2")
    require(maxIterations >= 1, "Number of maximum iterations must be greater or equal to 1")
    require(earlyStoppingPercentage >= 0.0 && earlyStoppingPercentage <= 1.0, "Early stopping percentage must be between 0 and 1.")
    new KMeansClustering(k, maxIterations, earlyStoppingPercentage, None, None)
  }

  /** Ad-hoc benchmark entry point (WIP): fits k-means on a csv dataset and prints cluster sizes.
    * The dataset path is taken from the first CLI argument, falling back to the previously
    * hard-coded local path for backward compatibility. */
  def main(args: Array[String]): Unit = {
    val datasetPath = args.headOption.getOrElse("/Users/rok/Downloads/mnist_train.csv")
    val data = loadCsvDataset(datasetPath, headerLine = false)
    // first column is assumed to be the class label; scale pixel values to [0, 1]
    val xTr = data(::, 1 to -1) / 255.0
    println("Starting...")
    val t0 = System.nanoTime()
    val kmeans = KMeansClustering(10).fit(xTr, DenseVector[Double]())
    val elapsedSeconds = (System.nanoTime() - t0) / 1e9
    println(elapsedSeconds, "seconds")

    // print the size of each cluster; foreach on the Option avoids an unsafe .get
    kmeans.labels.foreach { labels =>
      labels.toArray.zipWithIndex.groupBy(_._1).foreach {
        case (labelIdx, groupedLabels) => println(labelIdx, groupedLabels.length)
      }
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package com.picnicml.doddlemodel.clustering

import java.io.Serializable

import breeze.linalg.DenseVector
import com.picnicml.doddlemodel.base.Clusterer
import com.picnicml.doddlemodel.data.{Features, Target}

import scala.util.Random

trait RandomizableClusterer[A <: RandomizableClusterer[A]] extends Clusterer[A] {
  this: A with Serializable =>

  /** A function that creates an identical clusterer. */
  protected def copy: A

  /** Fits the model on `x` using the given source of randomness; `y` defaults to an empty
    * vector since clustering is unsupervised. Works on a copy so `this` stays unfitted. */
  protected def fit(x: Features, y: Target = DenseVector[Double]())(implicit rand: Random = new Random()): A = {
    require(!this.isFitted, "Called fit on a model that is already trained")
    copy.fitSafe(x)
  }

  /**
    * The object is guaranteed not to be fitted.
    */
  protected def fitSafe(x: Features)(implicit rand: Random): A
}