CrossValidator with Pipeline Example
Caution: FIXME The example below does NOT work. It is being investigated.
Caution: FIXME Can k-means be cross-validated? Does it make any sense? Does cross-validation only apply to supervised learning?
```scala
// Let's create a pipeline with transformers and estimator
import org.apache.spark.ml.feature._
val tok = new Tokenizer().setInputCol("text")

val hashTF = new HashingTF()
  .setInputCol(tok.getOutputCol)
  .setOutputCol("features")
  .setNumFeatures(10)

import org.apache.spark.ml.classification.RandomForestClassifier
val rfc = new RandomForestClassifier

import org.apache.spark.ml.Pipeline
val pipeline = new Pipeline()
  .setStages(Array(tok, hashTF, rfc))

// CAUTION: label must be double
// 0 = scientific text
// 1 = non-scientific text
val trainDS = Seq(
  (0L, "[science] hello world", 0d),
  (1L, "long text", 1d),
  (2L, "[science] hello all people", 0d),
  (3L, "[science] hello hello", 0d)).toDF("id", "text", "label").cache

// Check out the train dataset
// Values in label and prediction columns should be alike
val sampleModel = pipeline.fit(trainDS)
sampleModel
  .transform(trainDS)
  .select('text, 'label, 'features, 'prediction)
  .show(truncate = false)
+--------------------------+-----+--------------------------+----------+
|text                      |label|features                  |prediction|
+--------------------------+-----+--------------------------+----------+
|[science] hello world     |0.0  |(10,[0,8],[2.0,1.0])      |0.0       |
|long text                 |1.0  |(10,[4,9],[1.0,1.0])      |1.0       |
|[science] hello all people|0.0  |(10,[0,6,8],[1.0,1.0,2.0])|0.0       |
|[science] hello hello     |0.0  |(10,[0,8],[1.0,2.0])      |0.0       |
+--------------------------+-----+--------------------------+----------+

val input = Seq("Hello ScienCE").toDF("text")
sampleModel
  .transform(input)
  .select('text, 'rawPrediction, 'prediction)
  .show(truncate = false)
+-------------+--------------------------------------+----------+
|text         |rawPrediction                         |prediction|
+-------------+--------------------------------------+----------+
|Hello ScienCE|[12.666666666666668,7.333333333333333]|0.0       |
+-------------+--------------------------------------+----------+

import org.apache.spark.ml.tuning.ParamGridBuilder
val paramGrid = new ParamGridBuilder().build

import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
val binEval = new BinaryClassificationEvaluator

import org.apache.spark.ml.tuning.CrossValidator
val cv = new CrossValidator()
  .setEstimator(pipeline) // <-- pipeline is the estimator
  .setEvaluator(binEval)  // has to match the estimator
  .setEstimatorParamMaps(paramGrid)

// WARNING: It does not work!!!
val cvModel = cv.fit(trainDS)
```
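CrossValidator fits the estimator (here the whole pipeline) on k-1 folds and evaluates it on the remaining fold, once per fold and param map, and then averages the metric per param map. What follows is a minimal plain-Scala sketch of that loop, not Spark's implementation. It assumes a few simplifications: no Spark at all, folds assigned round-robin by position (Spark assigns rows to folds randomly), and the `Example`, `kFold` and `crossValidate` names as well as the majority-label "model" are hypothetical, made up for illustration.

```scala
// Plain-Scala sketch of k-fold cross-validation (no Spark needed).
// Simplification: rows are assigned to folds round-robin by position so the result
// is deterministic; Spark's CrossValidator assigns rows to folds randomly instead.

// Hypothetical row type with the same shape as the (id, text, label) rows of trainDS
final case class Example(id: Long, text: String, label: Double)

// Split rows into k (training, validation) pairs; every row is validated exactly once
def kFold[A](rows: Seq[A], k: Int): Seq[(Seq[A], Seq[A])] =
  (0 until k).map { fold =>
    val (validation, training) =
      rows.zipWithIndex.partition { case (_, i) => i % k == fold }
    (training.map(_._1), validation.map(_._1))
  }

// Fit on each training split, evaluate on the matching validation split, average the metric
def crossValidate[A, M](rows: Seq[A], k: Int)(fit: Seq[A] => M)(evaluate: (M, Seq[A]) => Double): Double = {
  val metrics = kFold(rows, k).map { case (train, valid) => evaluate(fit(train), valid) }
  metrics.sum / metrics.size
}

val trainRows = Seq(
  Example(0L, "[science] hello world", 0d),
  Example(1L, "long text", 1d),
  Example(2L, "[science] hello all people", 0d),
  Example(3L, "[science] hello hello", 0d))

// Labels present in each training split, using 3 folds (CrossValidator's default numFolds)
val labelsPerTrainingFold = kFold(trainRows, 3).map { case (train, _) => train.map(_.label).toSet }

// Hypothetical "model": predict 1.0 only if it is the strict majority label (ties go to 0.0);
// the metric is plain accuracy on the validation split
val meanAccuracy = crossValidate(trainRows, 3) { (train: Seq[Example]) =>
  if (train.count(_.label == 1d) * 2 > train.size) 1d else 0d
} { (prediction: Double, valid: Seq[Example]) =>
  valid.count(_.label == prediction).toDouble / valid.size
}
```

Here `labelsPerTrainingFold` comes out as `Set(0.0, 1.0)`, `Set(0.0)`, `Set(0.0, 1.0)`. Since every row lands in exactly one validation split, whichever split holds the single `1.0` row leaves a training split with label `0.0` only, regardless of how rows are assigned. With a dataset this small, that may be worth keeping in mind while the failure above is investigated.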
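BinaryClassificationEvaluator defaults to the `areaUnderROC` metric and reads the `rawPrediction` and `label` columns that the estimator produces, which is why it has to match the estimator. The sketch below illustrates that metric in plain Scala and is not Spark's implementation. It uses the pairwise (Mann-Whitney) formulation and treats each input pair as `(score, label)`, with the score standing in for the label-1.0 entry of `rawPrediction`. The `areaUnderROC` function name and the scores in the examples are made up.

```scala
// Plain-Scala sketch of the area under the ROC curve (no Spark needed).
// Each pair is (score, label): the score stands in for the label-1.0 entry of
// rawPrediction, the label is 0.0 or 1.0. Pairwise (Mann-Whitney) formulation:
// the AUC is the fraction of (positive, negative) pairs ranked correctly, ties count half.
def areaUnderROC(scoreAndLabels: Seq[(Double, Double)]): Double = {
  val positives = scoreAndLabels.collect { case (score, 1.0) => score }
  val negatives = scoreAndLabels.collect { case (score, 0.0) => score }
  // The metric is undefined when one of the two labels is missing
  if (positives.isEmpty || negatives.isEmpty) Double.NaN
  else {
    val credit = for (p <- positives; n <- negatives)
      yield if (p > n) 1.0 else if (p == n) 0.5 else 0.0
    credit.sum / credit.size
  }
}

// Made-up scores, not outputs of the model above
println(areaUnderROC(Seq((0.9, 1.0), (0.2, 0.0), (0.8, 1.0), (0.1, 0.0)))) // 1.0
println(areaUnderROC(Seq((0.8, 1.0), (0.4, 1.0), (0.6, 0.0), (0.2, 0.0)))) // 0.75
println(areaUnderROC(Seq((0.3, 0.0), (0.7, 0.0))))                         // NaN
```

The first input ranks every positive above every negative, the second gets 3 of its 4 pairs right, and the third has no positive labels at all, so there is nothing to rank.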