Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion core/src/main/scala/com/salesforce/op/ModelInsights.scala
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,9 @@ case object ModelInsights {
classOf[DataBalancerSummary], classOf[DataCutterSummary], classOf[DataSplitterSummary],
classOf[SingleMetric], classOf[MultiMetrics], classOf[BinaryClassificationMetrics],
classOf[BinaryClassificationBinMetrics], classOf[MulticlassThresholdMetrics],
classOf[BinaryThresholdMetrics], classOf[MultiClassificationMetrics], classOf[RegressionMetrics]
classOf[BinaryThresholdMetrics], classOf[MultiClassificationMetrics], classOf[RegressionMetrics],
classOf[MultiClassificationMetricsTopK],
classOf[MulticlassConfMatrixMetricsByThreshold], classOf[MisClassificationMetrics]
))
val evalMetricsSerializer = new CustomSerializer[EvalMetric](_ =>
( { case JString(s) => EvalMetric.withNameInsensitive(s) },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,34 @@ private[op] class OpMultiClassificationEvaluator
)
}

/**
 * Converts a sequence of ClassCount into a MisClassificationsPerCategory instance
 *
 * @param category      index of a labelled or predicted class
 * @param allClassCtSeq sequence of ClassCount containing labels or predictions and their counts
 *                      for the given category
 * @return a MisClassificationsPerCategory instance holding the total count, the count of entries
 *         whose ClassIndex equals `category`, and the other classes ordered by descending count,
 *         truncated to `$(confMatrixMinSupport)` entries
 */
private def getMisclassificationsPerCategory(
  category: Double, allClassCtSeq: Seq[ClassCount]): MisClassificationsPerCategory = {

  // Entries whose class index matches the category are correct; everything else is a mis-classification
  val (correct, incorrect) = allClassCtSeq.partition(_.ClassIndex == category)

  // Keep only the most frequent mis-classified classes
  val topMisClassifications = incorrect
    .sortBy(-_.Count)
    .take($(confMatrixMinSupport))

  // `sum` yields 0 for an empty sequence, whereas `reduce` would throw
  val totalCount = allClassCtSeq.map(_.Count).sum
  val correctCount = correct.map(_.Count).sum

  MisClassificationsPerCategory(
    Category = category,
    TotalCount = totalCount,
    CorrectCount = correctCount,
    MisClassifications = topMisClassifications
  )
}

/**
* function to calculate the most frequently mis-classified classes for each label/prediction category
*
Expand All @@ -291,49 +319,15 @@ private[op] class OpMultiClassificationEvaluator
.reduceByKey(_ + _)

val misClassificationsByLabel = labelPredictionCountRDD.map {
case ((label, prediction), count) => (label, Seq((prediction, count)))
case ((label, prediction), count) => (label, Seq(ClassCount(prediction, count)))
}.reduceByKey(_ ++ _)
.map { case (label, predictionCountsIter) => {
val misClassificationCtMap = predictionCountsIter
.filter { case (pred, _) => pred != label }
.sortBy(-_._2)
.take($(confMatrixMinSupport)).toMap

val labelCount = predictionCountsIter.map(_._2).reduce(_ + _)
val correctCount = predictionCountsIter
.collect { case (pred, count) if pred == label => count }
.reduceOption(_ + _).getOrElse(0L)

MisClassificationsPerCategory(
Category = label,
TotalCount = labelCount,
CorrectCount = correctCount,
MisClassifications = misClassificationCtMap
)
}
}.sortBy(-_.TotalCount).collect()
.map { case (label, predictionCountsSeq) => getMisclassificationsPerCategory(label, predictionCountsSeq)}
.sortBy(-_.TotalCount).collect()

val misClassificationsByPrediction = labelPredictionCountRDD.map {
case ((label, prediction), count) => (prediction, Seq((label, count)))
case ((label, prediction), count) => (prediction, Seq(ClassCount(label, count)))
}.reduceByKey(_ ++ _)
.map { case (prediction, labelCountsIter) => {
val sortedMisclassificationCt = labelCountsIter
.filter { case (label, _) => label != prediction }
.sortBy(-_._2)
.take($(confMatrixMinSupport)).toMap

val predictionCount = labelCountsIter.map(_._2).reduce(_ + _)
val correctCount = labelCountsIter
.collect { case (label, count) if label == prediction => count }
.reduceOption(_ + _).getOrElse(0L)

MisClassificationsPerCategory(
Category = prediction,
TotalCount = predictionCount,
CorrectCount = correctCount,
MisClassifications = sortedMisclassificationCt
)
}
.map { case (prediction, labelCountsSeq) => getMisclassificationsPerCategory(prediction, labelCountsSeq)
}.sortBy(-_.TotalCount).collect()

MisClassificationMetrics(
Expand Down Expand Up @@ -541,10 +535,15 @@ case class MultiClassificationMetrics
*/
case class MultiClassificationMetricsTopK
(
  // Each @JsonDeserialize(contentAs = ...) names the boxed element type Jackson should use for the
  // annotated Seq. NOTE(review): presumably required because the element type is erased at
  // runtime - confirm against the JSON round-trip tests.
  @JsonDeserialize(contentAs = classOf[java.lang.Integer]) topKs: Seq[Int],
  @JsonDeserialize(contentAs = classOf[java.lang.Double]) Precision: Seq[Double],
  @JsonDeserialize(contentAs = classOf[java.lang.Double]) Recall: Seq[Double],
  @JsonDeserialize(contentAs = classOf[java.lang.Double]) F1: Seq[Double],
  @JsonDeserialize(contentAs = classOf[java.lang.Double]) Error: Seq[Double]
) extends EvaluationMetrics

Expand Down Expand Up @@ -594,8 +593,19 @@ case class MisClassificationsPerCategory
Category: Double,
TotalCount: Long,
CorrectCount: Long,
@JsonDeserialize(keyAs = classOf[java.lang.Double])
MisClassifications: Map[Double, Long]
MisClassifications: Seq[ClassCount]
)

/**
 * Pairs a class index with the number of times that class was counted
 *
 * @param ClassIndex index of the class (a label or a prediction) being counted
 * @param Count      number of occurrences recorded for that class
 */
case class ClassCount(ClassIndex: Double, Count: Long)

/**
Expand Down
59 changes: 57 additions & 2 deletions core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ import com.salesforce.op.stages.impl.preparators._
import com.salesforce.op.stages.impl.regression.{OpLinearRegression, OpXGBoostRegressor, RegressionModelSelector}
import com.salesforce.op.stages.impl.selector.ModelSelectorNames.EstimatorType
import com.salesforce.op.stages.impl.selector.ValidationType._
import com.salesforce.op.stages.impl.selector.{SelectedCombinerModel, SelectedModel, SelectedModelCombiner}
import com.salesforce.op.stages.impl.tuning.{DataCutter, DataSplitter}
import com.salesforce.op.stages.impl.selector.{ModelEvaluation, ModelSelectorSummary, ProblemType, SelectedCombinerModel, SelectedModel, SelectedModelCombiner, ValidationType}
import com.salesforce.op.stages.impl.tuning.{DataBalancerSummary, DataCutter, DataSplitter}
import com.salesforce.op.test.{PassengerSparkFixtureTest, TestFeatureBuilder}
import com.salesforce.op.testkit.RandomReal
import com.salesforce.op.utils.spark.{OpVectorColumnMetadata, OpVectorMetadata}
Expand Down Expand Up @@ -406,6 +406,61 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest with Dou
pretty should include("Top Contributions")
}

it should "correctly serialize and deserialize from json with MulticlassificationMetrics" in {
  // Train metrics with non-empty TopK, confusion matrix and mis-classification entries
  val trainMetrics = MultiClassificationMetrics(
    Precision = 0.1,
    Recall = 0.2,
    F1 = 0.3,
    Error = 0.4,
    ThresholdMetrics = MulticlassThresholdMetrics(topNs = Seq(1, 2), thresholds = Seq(1.1, 1.2),
      correctCounts = Map(1 -> Seq(100L)), incorrectCounts = Map(2 -> Seq(200L)),
      noPredictionCounts = Map(3 -> Seq(300L))),
    TopKMetrics = MultiClassificationMetricsTopK(Seq(1), Seq(0.1), Seq(0.1), Seq(0.1), Seq(0.1)),
    ConfusionMatrixMetrics = MulticlassConfMatrixMetricsByThreshold( 2, Seq(0.1), Seq(0.1), Seq(Seq(1L))),
    MisClassificationMetrics = MisClassificationMetrics(1, Seq.empty,
      Seq(MisClassificationsPerCategory(0.0, 5L, 5L, Seq(ClassCount(1.0, 3L)))))
  )

  // Holdout metrics use empty sequences for TopK, the confusion matrix and the mis-classifications
  val holdoutMetrics = MultiClassificationMetrics(
    Precision = 0.1,
    Recall = 0.2,
    F1 = 0.3,
    Error = 0.4,
    ThresholdMetrics = MulticlassThresholdMetrics(topNs = Seq(1, 2), thresholds = Seq(1.1, 1.2),
      correctCounts = Map(1 -> Seq(100L)), incorrectCounts = Map(2 -> Seq(200L)),
      noPredictionCounts = Map(3 -> Seq(300L))),
    TopKMetrics = MultiClassificationMetricsTopK(Seq.empty, Seq.empty, Seq.empty, Seq.empty, Seq.empty),
    ConfusionMatrixMetrics = MulticlassConfMatrixMetricsByThreshold(2, Seq(0.1), Seq(0.1), Seq.empty),
    MisClassificationMetrics = MisClassificationMetrics(1, Seq.empty,
      Seq(MisClassificationsPerCategory(0.0, 5L, 5L, Seq.empty)))
  )

  val summary = ModelSelectorSummary(
    validationType = ValidationType.TrainValidationSplit,
    validationParameters = Map.empty,
    dataPrepParameters = Map.empty,
    dataPrepResults = None,
    evaluationMetric = MultiClassEvalMetrics.Error,
    problemType = ProblemType.MultiClassification,
    bestModelUID = "test1",
    bestModelName = "test2",
    bestModelType = "test3",
    validationResults = Seq.empty,
    trainEvaluation = trainMetrics,
    holdoutEvaluation = Some(holdoutMetrics)
  )

  val insights = workflowModel.modelInsights(pred).copy(selectedModelInfo = Some(summary))
  ModelInsights.fromJson(insights.toJson()) match {
    case Failure(e) => fail(e)
    case Success(deser) =>
      // Fail explicitly when the summary is lost: zipping two Options would silently
      // compare nothing if the deserialized side were None
      val roundTripped = deser.selectedModelInfo.getOrElse(
        fail("selectedModelInfo is missing after the JSON round trip"))
      roundTripped.trainEvaluation shouldEqual summary.trainEvaluation
      roundTripped.holdoutEvaluation shouldEqual summary.holdoutEvaluation
  }
}

it should "correctly serialize and deserialize from json when raw feature filter is not used" in {
val insights = workflowModel.modelInsights(pred)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -398,10 +398,10 @@ class OpMultiClassificationEvaluatorTest extends FlatSpec with TestSparkContext

// create a test 2D array where 1st dimension is the label and 2nd dimension is the prediction,
// and the count of each (label, prediction) pair equals the value of the label
// _| 1 2 3
// 1| 1L 1L 1L
// 2| 2L 2L 2L
// 3| 3L 3L 3L
// ___| 1.0 2.0 3.0
// 1.0| 1L 1L 1L
// 2.0| 2L 2L 2L
// 3.0| 3L 3L 3L
val testLabels = Array(1.0, 2.0, 3.0)
val labelAndPrediction = testLabels.flatMap(label => {
testLabels.flatMap(pred => Seq.fill(label.toInt)((label, pred)))
Expand Down Expand Up @@ -437,10 +437,10 @@ class OpMultiClassificationEvaluatorTest extends FlatSpec with TestSparkContext

// create a test 2D array with the count of each label & prediction combination as:
// row is label and column is prediction
// _| 1 2 3
// 1| 2L 3L 4L
// 2| 3L 4L 5L
// 3| 4L 5L 6L
// ___| 1.0 2.0 3.0
// 1.0| 2L 3L 4L
// 2.0| 3L 4L 5L
// 3.0| 4L 5L 6L
val testLabels = List(1.0, 2.0, 3.0)
val labelAndPrediction = testLabels.flatMap(label => {
testLabels.flatMap(pred => Seq.fill(label.toInt + pred.toInt)((label, pred)))
Expand All @@ -452,21 +452,21 @@ class OpMultiClassificationEvaluatorTest extends FlatSpec with TestSparkContext
outputMetrics.MisClassificationsByLabel shouldEqual
Seq(
MisClassificationsPerCategory(Category = 3.0, TotalCount = 15L, CorrectCount = 6L,
MisClassifications = Map(2.0 -> 5L, 1.0 -> 4L)),
MisClassifications = Seq(ClassCount(2.0, 5L), ClassCount(1.0, 4L))),
MisClassificationsPerCategory(Category = 2.0, TotalCount = 12L, CorrectCount = 4L,
MisClassifications = Map(3.0 -> 5L, 1.0 -> 3L)),
MisClassifications = Seq(ClassCount(3.0, 5L), ClassCount(1.0, 3L))),
MisClassificationsPerCategory(Category = 1.0, TotalCount = 9L, CorrectCount = 2L,
MisClassifications = Map(3.0 -> 4L, 2.0 -> 3L))
MisClassifications = Seq(ClassCount(3.0, 4L), ClassCount(2.0, 3L)))
)

outputMetrics.MisClassificationsByPrediction shouldEqual
Seq(
MisClassificationsPerCategory(Category = 3.0, TotalCount = 15L, CorrectCount = 6L,
MisClassifications = Map(2.0 -> 5L, 1.0 -> 4L)),
MisClassifications = Seq(ClassCount(2.0, 5L), ClassCount(1.0, 4L))),
MisClassificationsPerCategory(Category = 2.0, TotalCount = 12L, CorrectCount = 4L,
MisClassifications = Map(3.0 -> 5L, 1.0 -> 3L)),
MisClassifications = Seq(ClassCount(3.0, 5L), ClassCount(1.0, 3L))),
MisClassificationsPerCategory(Category = 1.0, TotalCount = 9L, CorrectCount = 2L,
MisClassifications = Map(3.0 -> 4L, 2.0 -> 3L))
MisClassifications = Seq(ClassCount(3.0, 4L), ClassCount(2.0, 3L)))
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,8 @@ class RecordInsightsLOCOTest extends FunSpec with TestSparkContext with RecordIn
info("Each feature vector should only have either three or four non-zero entries. One each from country and " +
"picklist, while currency can have either two (if it's null the currency column will be filled with the mean)" +
" or just one if it's not null.")
it("should pick between 1 and 4 of the features") {
all(parsed.map(_.size)) should (be >= 1 and be <= 4)
it("should pick between 0 and 4 of the features") {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why did this change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tovbinm this test is quite flaky. In one of the failed runs, the value was 0 and not between 1 and 4.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@leahmcguire @Jauntbox do you have any ideas?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's still an open PR to address this: #423

We could use a better test, but for now we can start by getting rid of the flakiness.

all(parsed.map(_.size)) should (be >= 0 and be <= 4)
}

// Grab the feature vector metadata for comparison against the LOCO record insights
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,10 @@ class ModelSelectorSummaryTest extends FlatSpec with TestSparkContext {
correctCounts = Map(1 -> Seq(100L)), incorrectCounts = Map(2 -> Seq(200L)),
noPredictionCounts = Map(3 -> Seq(300L))),
TopKMetrics = MultiClassificationMetricsTopK(Seq.empty, Seq.empty, Seq.empty, Seq.empty, Seq.empty),
ConfusionMatrixMetrics = MulticlassConfMatrixMetricsByThreshold(1, Seq(1.0), Seq(0.0, 0.5), Seq.empty),
ConfusionMatrixMetrics = MulticlassConfMatrixMetricsByThreshold(1, Seq(1.0), Seq(0.0, 0.5), Seq(Seq(1L))),
MisClassificationMetrics = MisClassificationMetrics(1, Seq.empty,
Seq(MisClassificationsPerCategory(TotalCount = 5L, CorrectCount = 3L, Category = 1.0,
MisClassifications = Map(1.0 -> 2L))))),
MisClassifications = Seq(ClassCount(1.0, 2L)))))),
holdoutEvaluation = None
)

Expand All @@ -121,19 +121,23 @@ class ModelSelectorSummaryTest extends FlatSpec with TestSparkContext {
}

it should "not hide the root cause of JSON parsing errors" in {
val evalMetrics = MultiClassificationMetrics(Precision = 0.1, Recall = 0.2, F1 = 0.3, Error = 0.4,
val evalMetrics = MultiClassificationMetrics(
Precision = 0.1,
Recall = 0.2,
F1 = 0.3,
Error = 0.4,
ThresholdMetrics = MulticlassThresholdMetrics(topNs = Seq(1, 2), thresholds = Seq(1.1, 1.2),
correctCounts = Map(1 -> Seq(100L)), incorrectCounts = Map(2 -> Seq(200L)),
noPredictionCounts = Map(3 -> Seq(300L))),
TopKMetrics = MultiClassificationMetricsTopK(Seq.empty, Seq.empty, Seq.empty, Seq.empty, Seq.empty),
ConfusionMatrixMetrics = MulticlassConfMatrixMetricsByThreshold( 2, Seq(0.1), Seq(0.1),
Seq.empty),
ConfusionMatrixMetrics = MulticlassConfMatrixMetricsByThreshold( 2, Seq(0.1), Seq(0.1), Seq(Seq(1L))),
MisClassificationMetrics = MisClassificationMetrics(1,
Seq(MisClassificationsPerCategory(0.0, 5L, 5L, Map(1.0 -> 3L))),
Seq(MisClassificationsPerCategory(0.0, 5L, 5L, Map(1.0 -> 3L))))
Seq(MisClassificationsPerCategory(0.0, 5L, 5L, Seq(ClassCount(1.0, 3L)))),
Seq(MisClassificationsPerCategory(0.0, 5L, 5L, Seq(ClassCount(1.0, 3L)))))
)

val evalMetricsJson = evalMetrics.toJson()
println(1)
val roundTripEvalMetrics = ModelSelectorSummary.evalMetFromJson(
classOf[MultiClassificationMetrics].getName, evalMetricsJson).get
roundTripEvalMetrics shouldBe evalMetrics
Expand Down