Skip to content

Commit

Permalink
upgrade spark to 2.4.0
Browse files Browse the repository at this point in the history
  • Loading branch information
titicaca committed Dec 27, 2018
1 parent e4e318c commit 537aa2a
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
4 changes: 2 additions & 2 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

<groupId>org.apache.spark</groupId>
<artifactId>spark-gbtlr</artifactId>
<version>2.3.0</version>
<version>2.4.0</version>

<properties>
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
Expand All @@ -18,7 +18,7 @@
<log4j.version>1.2.17</log4j.version>
<skipTests>false</skipTests>
<maven.version>3.3.9</maven.version>
<spark.version>2.3.0</spark.version>
<spark.version>2.4.0</spark.version>
</properties>

<dependencies>
Expand Down
11 changes: 6 additions & 5 deletions src/main/scala/org/apache/spark/ml/gbtlr/GBTLRClassifier.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import org.apache.spark.ml.classification._
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.linalg.{DenseVector => OldDenseVector}
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
import org.apache.spark.mllib.tree.configuration.{FeatureType, Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
Expand Down Expand Up @@ -672,7 +673,7 @@ class GBTLRClassifier (override val uid: String)
* @param dataset Input data.
* @return GBTLRClassification model.
*/
override def train(dataset: Dataset[_]): GBTLRClassificationModel = {
override def train(dataset: Dataset[_]): GBTLRClassificationModel = instrumented { instr =>
val categoricalFeatures: Map[Int, Int] =
getCategoricalFeatures(dataset.schema($(featuresCol)))

Expand All @@ -691,10 +692,11 @@ class GBTLRClassifier (override val uid: String)
val boostingStrategy = new OldBoostingStrategy(strategy, getOldLossType,
getGBTMaxIter, getStepSize)

val instr = Instrumentation.create(this, oldDataset)
instr.logParams(params: _*)
instr.logPipelineStage(this)
instr.logNumFeatures(numFeatures)
instr.logNumClasses(2)
instr.logDataset(dataset)
instr.logParams(this)

// train a gradient boosted tree model using boostingStrategy.
val gbtModel = GradientBoostedTrees.train(oldDataset, boostingStrategy)
Expand Down Expand Up @@ -737,7 +739,6 @@ class GBTLRClassifier (override val uid: String)
val summary = new GBTLRClassifierTrainingSummary(datasetWithCombinedFeatures, lrModel.summary,
gbtModel.trees, gbtModel.treeWeights)
model.setSummary(Some(summary))
instr.logSuccess(model)
model
}

Expand Down Expand Up @@ -927,7 +928,7 @@ object GBTLRClassificationModel extends MLReadable[GBTLRClassificationModel] {
val gbtModel = GradientBoostedTreesModel.load(sc, gbtDataPath)
val lrModel = LogisticRegressionModel.load(lrDataPath)
val model = new GBTLRClassificationModel(metadata.uid, gbtModel, lrModel)
DefaultParamsReader.getAndSetParams(model, metadata)
metadata.getAndSetParams(model)
model
}
}
Expand Down

0 comments on commit 537aa2a

Please sign in to comment.