Example — Text Classification
The example was inspired by the video Building, Debugging, and Tuning Spark Machine Learning Pipelines – Joseph Bradley (Databricks).
Problem: Given a text document, classify it as a scientific or non-scientific one.
The example uses a case class, LabeledText, so that the schema is described by the type itself.
import spark.implicits._

sealed trait Category
case object Scientific extends Category
case object NonScientific extends Category

// FIXME: Define schema for Category. Spark has no built-in encoder for this type,
// so toDF is expected to fail as written. The output below assumes the category
// is encoded as a numeric label column.

case class LabeledText(id: Long, category: Category, text: String)

val data = Seq(
  LabeledText(0, Scientific, "hello world"),
  LabeledText(1, NonScientific, "witaj swiecie")).toDF

scala> data.show
+-----+-------------+
|label|         text|
+-----+-------------+
|    0|  hello world|
|    1|witaj swiecie|
+-----+-------------+
The text is then tokenized and transformed into another DataFrame with an additional column, features, that holds a Vector of numerical values.
Paste the code below into Spark Shell using :paste mode.
import spark.implicits._

case class Article(id: Long, topic: String, text: String)

val articles = Seq(
  Article(0, "sci.math", "Hello, Math!"),
  Article(1, "alt.religion", "Hello, Religion!"),
  Article(2, "sci.physics", "Hello, Physics!"),
  Article(3, "sci.math", "Hello, Math Revised!"),
  Article(4, "sci.math", "Better Math"),
  Article(5, "alt.religion", "TGIF")).toDS
Next comes the tokenization part, which maps the input text of each document into tokens (a Seq[String]) and then into a Vector of numerical values. Only then can the text be understood by a machine learning algorithm (which operates on Vector instances). Before that, each article gets a numeric label and the dataset is split into training and test sets.
scala> articles.show
+---+------------+--------------------+
| id|       topic|                text|
+---+------------+--------------------+
|  0|    sci.math|        Hello, Math!|
|  1|alt.religion|    Hello, Religion!|
|  2| sci.physics|     Hello, Physics!|
|  3|    sci.math|Hello, Math Revised!|
|  4|    sci.math|         Better Math|
|  5|alt.religion|                TGIF|
+---+------------+--------------------+

import org.apache.spark.sql.functions.udf

// Articles whose topic starts with "sci" are labelled 1.0, all others 0.0
val topic2Label: Boolean => Double = isSci => if (isSci) 1 else 0
val toLabel = udf(topic2Label)

val labelled = articles.withColumn("label", toLabel($"topic".like("sci%"))).cache

val Array(trainDF, testDF) = labelled.randomSplit(Array(0.75, 0.25))

scala> trainDF.show
+---+------------+--------------------+-----+
| id|       topic|                text|label|
+---+------------+--------------------+-----+
|  1|alt.religion|    Hello, Religion!|  0.0|
|  3|    sci.math|Hello, Math Revised!|  1.0|
+---+------------+--------------------+-----+

scala> testDF.show
+---+------------+---------------+-----+
| id|       topic|           text|label|
+---+------------+---------------+-----+
|  0|    sci.math|   Hello, Math!|  1.0|
|  2| sci.physics|Hello, Physics!|  1.0|
|  4|    sci.math|    Better Math|  1.0|
|  5|alt.religion|           TGIF|  0.0|
+---+------------+---------------+-----+
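The split above gives two training rows and four test rows even though the weights are 0.75 and 0.25. A rough way to see why is the plain-Scala sketch below, which assumes each row is assigned independently by a random draw (approximately how a weighted random split behaves). SplitSketch and its parameter names are hypothetical and not part of Spark.

```scala
import scala.util.Random

// Plain-Scala sketch of a weighted random split; no Spark required.
object SplitSketch {
  // Draw one random number per row: the row goes to the training set when the
  // draw falls below trainWeight, otherwise to the test set.
  // A fixed seed makes the split repeatable.
  def randomSplit[A](rows: Seq[A], trainWeight: Double, seed: Long): (Seq[A], Seq[A]) = {
    val rng = new Random(seed)
    val tagged = rows.map(row => (row, rng.nextDouble() < trainWeight))
    val train = tagged.collect { case (row, true) => row }
    val test = tagged.collect { case (row, false) => row }
    (train, test)
  }
}

val ids = Seq(0L, 1L, 2L, 3L, 4L, 5L)
val (trainIds, testIds) = SplitSketch.randomSplit(ids, trainWeight = 0.75, seed = 42L)
```

With only six rows, the sizes of the two parts can be far from the requested proportions. In Spark, randomSplit also accepts a seed as a second argument, which makes the output of the example reproducible between runs.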
The model-training phase uses the logistic regression algorithm to build a model that predicts the label of future input text documents (and hence classifies them as scientific or non-scientific).
import org.apache.spark.ml.feature.RegexTokenizer
val tokenizer = new RegexTokenizer()
  .setInputCol("text")
  .setOutputCol("words")

import org.apache.spark.ml.feature.HashingTF
val hashingTF = new HashingTF()
  .setInputCol(tokenizer.getOutputCol)  // it does not wire transformers -- it's just a column name
  .setOutputCol("features")
  .setNumFeatures(5000)

import org.apache.spark.ml.classification.LogisticRegression
val lr = new LogisticRegression().setMaxIter(20).setRegParam(0.01)

import org.apache.spark.ml.Pipeline
val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, lr))
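To make the first two pipeline stages less abstract, here is a minimal plain-Scala sketch of what they do conceptually: split lowercased text on whitespace (which is RegexTokenizer's default behaviour) and then count tokens per hash bucket (the idea behind HashingTF). HashingSketch is a hypothetical helper, and String.hashCode is only a stand-in for the hash function Spark uses, so the bucket indices will not match Spark's.

```scala
// Plain-Scala sketch of tokenization plus the hashing trick; no Spark required.
object HashingSketch {
  // Lowercase the text and split it on runs of whitespace.
  def tokenize(text: String): Seq[String] =
    text.toLowerCase.split("\\s+").filter(_.nonEmpty).toSeq

  // Map a token to a bucket in [0, numFeatures).
  // String.hashCode is a stand-in: Spark's HashingTF uses a different hash.
  def bucket(token: String, numFeatures: Int): Int =
    ((token.hashCode % numFeatures) + numFeatures) % numFeatures

  // Sparse term-frequency vector: bucket index -> number of tokens hashed into it.
  def termFrequencies(tokens: Seq[String], numFeatures: Int): Map[Int, Double] =
    tokens.groupBy(bucket(_, numFeatures)).map { case (i, ts) => i -> ts.size.toDouble }
}

val tokens = HashingSketch.tokenize("Hello, Math Revised!")
val features = HashingSketch.termFrequencies(tokens, 5000)
```

Note that punctuation stays attached to the tokens ("hello," and "revised!"), because the text is split on whitespace only. A smaller numFeatures saves memory but makes it more likely that two different tokens share a bucket.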
The logistic regression stage uses two columns, label and features (a vector), to build a model that makes predictions.
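How a fitted binary logistic regression model turns a feature vector into a prediction can be sketched in a few lines of plain Scala. The weights and intercept below are made up for illustration; a real model learns them from the label and features columns. LogisticSketch is a hypothetical helper, not Spark's API.

```scala
// Plain-Scala sketch of binary logistic regression scoring; no Spark required.
object LogisticSketch {
  // Logistic (sigmoid) function: squashes any real-valued score into (0, 1).
  def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

  // Weighted sum of a sparse feature vector (index -> value) plus the intercept.
  def margin(features: Map[Int, Double], weights: Array[Double], intercept: Double): Double =
    features.foldLeft(intercept) { case (acc, (i, v)) => acc + weights(i) * v }

  // Predict 1.0 when the estimated probability of the positive class exceeds the threshold.
  def predict(features: Map[Int, Double], weights: Array[Double], intercept: Double,
              threshold: Double = 0.5): Double =
    if (sigmoid(margin(features, weights, intercept)) > threshold) 1.0 else 0.0
}

// Hypothetical model parameters, chosen only for illustration
val weights = Array(2.0, -1.0, 0.5)
val intercept = -0.5

val sciLike = Map(0 -> 1.0, 2 -> 1.0) // margin = -0.5 + 2.0 + 0.5 = 2.0
val nonSciLike = Map(1 -> 2.0)        // margin = -0.5 - 2.0 = -2.5
val sciPrediction = LogisticSketch.predict(sciLike, weights, intercept)
val nonSciPrediction = LogisticSketch.predict(nonSciLike, weights, intercept)
```

The margin, the probability and the thresholded class correspond roughly to what the rawPrediction, probability and prediction columns hold after a fitted model transforms a dataset.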
val model = pipeline.fit(trainDF)

val trainPredictions = model.transform(trainDF)
val testPredictions = model.transform(testDF)

scala> trainPredictions.select('id, 'topic, 'text, 'label, 'prediction).show
+---+------------+--------------------+-----+----------+
| id|       topic|                text|label|prediction|
+---+------------+--------------------+-----+----------+
|  1|alt.religion|    Hello, Religion!|  0.0|       0.0|
|  3|    sci.math|Hello, Math Revised!|  1.0|       1.0|
+---+------------+--------------------+-----+----------+

// Notice that the computations add new columns
scala> trainPredictions.printSchema
root
 |-- id: long (nullable = false)
 |-- topic: string (nullable = true)
 |-- text: string (nullable = true)
 |-- label: double (nullable = true)
 |-- words: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- features: vector (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = true)

import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
val evaluator = new BinaryClassificationEvaluator().setMetricName("areaUnderROC")

import org.apache.spark.ml.param.ParamMap
val evaluatorParams = ParamMap(evaluator.metricName -> "areaUnderROC")

scala> val areaTrain = evaluator.evaluate(trainPredictions, evaluatorParams)
areaTrain: Double = 1.0

scala> val areaTest = evaluator.evaluate(testPredictions, evaluatorParams)
areaTest: Double = 0.6666666666666666
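The area under the ROC curve can be read as the fraction of (positive, negative) pairs in which the positive example receives the higher score, with ties counting as half. The plain-Scala sketch below computes it that way. AucSketch and the scores are hypothetical and chosen only to mirror the shape of the test set (three scientific rows, one non-scientific row); they are not the scores the model above produced.

```scala
// Plain-Scala sketch of area under ROC as a pairwise ranking statistic; no Spark required.
object AucSketch {
  // scored: (score for the positive class, true label) pairs, labels being 1.0 or 0.0.
  // Returns NaN when either class is missing, since no pairs can be formed.
  def areaUnderROC(scored: Seq[(Double, Double)]): Double = {
    val positives = scored.filter(_._2 == 1.0).map(_._1)
    val negatives = scored.filter(_._2 == 0.0).map(_._1)
    val pairScores = for {
      p <- positives
      n <- negatives
    } yield if (p > n) 1.0 else if (p == n) 0.5 else 0.0
    pairScores.sum / (positives.size * negatives.size)
  }
}

// Hypothetical scores: the negative row (0.6) outranks one of the three positives (0.4)
val scored = Seq((0.9, 1.0), (0.4, 1.0), (0.7, 1.0), (0.6, 0.0))
val auc = AucSketch.areaUnderROC(scored)
```

With three positives and one negative there are only three pairs, so a test value of 0.666... would be consistent with, for example, the non-scientific row being scored above exactly one scientific row. On a dataset this small the metric is very coarse.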
Let’s tune the model’s hyperparameters (using the “tools” from the org.apache.spark.ml.tuning package).
FIXME Review the available classes in the org.apache.spark.ml.tuning package.
import org.apache.spark.ml.tuning.ParamGridBuilder
val paramGrid = new ParamGridBuilder()
  .addGrid(hashingTF.numFeatures, Array(100, 1000))
  .addGrid(lr.regParam, Array(0.05, 0.2))
  .addGrid(lr.maxIter, Array(5, 10, 15))
  .build

// That gives all the combinations of the parameters

paramGrid: Array[org.apache.spark.ml.param.ParamMap] =
Array({
	logreg_cdb8970c1f11-maxIter: 5,
	hashingTF_8d7033d05904-numFeatures: 100,
	logreg_cdb8970c1f11-regParam: 0.05
}, {
	logreg_cdb8970c1f11-maxIter: 5,
	hashingTF_8d7033d05904-numFeatures: 1000,
	logreg_cdb8970c1f11-regParam: 0.05
}, {
	logreg_cdb8970c1f11-maxIter: 10,
	hashingTF_8d7033d05904-numFeatures: 100,
	logreg_cdb8970c1f11-regParam: 0.05
}, {
	logreg_cdb8970c1f11-maxIter: 10,
	hashingTF_8d7033d05904-numFeatures: 1000,
	logreg_cdb8970c1f11-regParam: 0.05
}, {
	logreg_cdb8970c1f11-maxIter: 15,
	hashingTF_8d7033d05904-numFeatures: 100,
	logreg_cdb8970c1f11-regParam: 0.05
}, {
	logreg_cdb8970c1f11-maxIter: 15,
	hashingTF_8d7033d05904-numFeatures: 1000,
	logreg_cdb8970c1f11-...

import org.apache.spark.ml.tuning.CrossValidator
import org.apache.spark.ml.param._
val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEstimatorParamMaps(paramGrid)
  .setEvaluator(evaluator)
  .setNumFolds(10)

val cvModel = cv.fit(trainDF)
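The grid above is the Cartesian product of the value lists, so 2 x 2 x 3 = 12 parameter settings are tried. The plain-Scala sketch below reproduces that counting with a hypothetical GridSketch helper that uses plain parameter names instead of Spark's Param objects.

```scala
// Plain-Scala sketch of building a parameter grid as a Cartesian product; no Spark required.
object GridSketch {
  // Combine every value of each parameter with every combination built so far.
  def build(grid: Seq[(String, Seq[Any])]): Seq[Map[String, Any]] =
    grid.foldLeft(Seq(Map.empty[String, Any])) { case (acc, (name, values)) =>
      for (combo <- acc; value <- values) yield combo + (name -> value)
    }
}

val grid = GridSketch.build(Seq(
  "numFeatures" -> Seq(100, 1000),
  "regParam"    -> Seq(0.05, 0.2),
  "maxIter"     -> Seq(5, 10, 15)))

// Every setting is fitted once per fold during the search
val numFolds = 10
val fitsDuringSearch = grid.size * numFolds
```

With 12 settings and 10 folds the search fits the pipeline 120 times, and the best setting is then refitted once on the whole training set. The cost grows multiplicatively with every value added to the grid, so it pays to keep the grid small.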
Let’s use the cross-validated model to calculate predictions and evaluate them with the same metric as before (area under ROC).
val cvPredictions = cvModel.transform(testDF)

scala> cvPredictions.select('topic, 'text, 'prediction).show
+------------+---------------+----------+
|       topic|           text|prediction|
+------------+---------------+----------+
|    sci.math|   Hello, Math!|       0.0|
| sci.physics|Hello, Physics!|       0.0|
|    sci.math|    Better Math|       1.0|
|alt.religion|           TGIF|       0.0|
+------------+---------------+----------+

scala> evaluator.evaluate(cvPredictions, evaluatorParams)
res26: Double = 0.6666666666666666

scala> val bestModel = cvModel.bestModel
bestModel: org.apache.spark.ml.Model[_] = pipeline_8873b744aac7
FIXME Review https://github.com/apache/spark/blob/master/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
Finally, you can save the model for later use.
cvModel.write.overwrite.save("model")
Congratulations! You’re done.