CrossValidator — Model Tuning / Finding The Best Model
CrossValidator is an Estimator for model tuning, i.e. finding the best model for given parameters and a dataset.

CrossValidator splits the dataset into numFolds randomly-partitioned pairs of training and validation datasets. The validation datasets do not overlap, and each training dataset is the rest of the data.

CrossValidator generates a CrossValidatorModel to hold the best model and the average cross-validation metrics.
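To make the splitting concrete, here is a minimal pure-Scala sketch of k-fold splitting over a local collection. It is not Spark's own splitting code (Spark works on RDDs with a seeded sampler). The names KFoldSketch and kFold are hypothetical, and giving each element one random fold index is an assumed stand-in technique. The sketch shows the property the text relies on: validation sets are disjoint and together cover the data, and each training set is the complement of its validation set.

```scala
import scala.util.Random

object KFoldSketch {
  /** Splits `data` into `numFolds` (training, validation) pairs.
    * Every element gets exactly one random fold index (seeded), so the
    * validation sets are disjoint and together cover all of `data`.
    */
  def kFold[T](data: Seq[T], numFolds: Int, seed: Long): Seq[(Seq[T], Seq[T])] = {
    require(numFolds >= 2, s"numFolds must be at least 2, got $numFolds")
    val rng = new Random(seed)
    // Pair every element with a random fold index in [0, numFolds)
    val assigned = data.toVector.map(x => (x, rng.nextInt(numFolds)))
    (0 until numFolds).map { fold =>
      val validation = assigned.collect { case (x, f) if f == fold => x }
      val training = assigned.collect { case (x, f) if f != fold => x }
      (training, validation)
    }
  }
}
```

With numFolds = 3, every element lands in exactly one validation set and in the training sets of the other two folds.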
Note: CrossValidator takes any Estimator for model selection, including the Pipeline that is used to transform raw datasets and generate a Model.
Note: Use ParamGridBuilder for the parameter grid, i.e. the collection of ParamMaps for model tuning.
```scala
import org.apache.spark.ml.Pipeline
val pipeline: Pipeline = ...

import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.ParamGridBuilder
val paramGrid: Array[ParamMap] = new ParamGridBuilder().
  addGrid(...).
  addGrid(...).
  build

import org.apache.spark.ml.tuning.CrossValidator
val cv = new CrossValidator().
  setEstimator(pipeline).
  setEvaluator(...).
  setEstimatorParamMaps(paramGrid).
  setNumFolds(...).
  setParallelism(...)

import org.apache.spark.ml.tuning.CrossValidatorModel
val bestModel: CrossValidatorModel = cv.fit(training)
```
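Each addGrid call contributes one parameter with a list of candidate values, and the resulting grid is the Cartesian product of those lists. The sketch below mimics that with plain Scala maps. GridSketch and buildGrid are hypothetical names, and representing a ParamMap as Map[String, Double] is an assumption for illustration only, not Spark's ParamMap type.

```scala
object GridSketch {
  /** Cartesian product of per-parameter value lists, one Map per combination. */
  def buildGrid(grid: Seq[(String, Seq[Double])]): Seq[Map[String, Double]] =
    grid.foldLeft(Seq(Map.empty[String, Double])) { case (acc, (name, values)) =>
      // Extend every partial combination with every value of the next parameter
      for {
        partial <- acc
        value <- values
      } yield partial + (name -> value)
    }
}
```

A 2 x 3 grid gives 6 parameter maps, so with numFolds = 3 the cross validation fits 6 x 3 = 18 models, plus one final model on the whole dataset. An empty grid yields a single empty map.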
CrossValidator is an MLWritable.
Parameter | Default Value | Description |
---|---|---|
estimator | (undefined) | Estimator for best model selection. |
estimatorParamMaps | (undefined) | Param maps for the estimator. |
evaluator | (undefined) | Evaluator to select hyper-parameters that maximize the validated metric. |
numFolds | 3 | The number of folds for cross validation. Must be at least 2. |
parallelism | 1 | The number of threads to use while fitting models. Must be at least 1. |
seed | Hash code of the class name | Random seed. |
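The defaults and lower bounds in the table can be captured in a plain case class. CVSettings is a hypothetical name, not a Spark type; Spark declares these bounds as validators on the params themselves. The seed default follows the table: the hash code of the fully-qualified class name.

```scala
/** Hypothetical settings holder mirroring the defaults and bounds in the table. */
final case class CVSettings(
    numFolds: Int = 3,    // default number of folds
    parallelism: Int = 1, // default: fit one model at a time
    seed: Long = "org.apache.spark.ml.tuning.CrossValidator".hashCode.toLong) {
  require(numFolds >= 2, s"numFolds must be at least 2, got $numFolds")
  require(parallelism >= 1, s"parallelism must be at least 1, got $parallelism")
}
```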
Tip: Enable logging for CrossValidator to see the DEBUG and INFO messages described below. Refer to Logging.
Finding The Best Model: fit Method

```scala
fit(dataset: Dataset[_]): CrossValidatorModel
```
Note: fit is part of the Estimator Contract to fit a model (i.e. produce a model).
fit validates the schema (with logging turned on). You should see the following DEBUG message in the logs:

```
DEBUG CrossValidator: Input schema: [json]
```
fit makes sure that the estimator, evaluator, estimatorParamMaps and parallelism parameters are defined, and throws a NoSuchElementException otherwise:

```
java.util.NoSuchElementException: Failed to find a default value for [name]
```
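The failure above is what a lookup produces when a parameter has neither a user-supplied value nor a default. The sketch below imitates that lookup with two plain maps. ParamStore and getOrDefault are hypothetical stand-ins; only the error message format is taken from the log line above.

```scala
/** Hypothetical parameter store: user-set values win over defaults. */
final class ParamStore(defaults: Map[String, Any], userSet: Map[String, Any]) {
  /** Returns the user-set value, else the default, else fails like the message above. */
  def getOrDefault(name: String): Any =
    userSet.get(name)
      .orElse(defaults.get(name))
      .getOrElse(throw new NoSuchElementException(s"Failed to find a default value for $name"))
}
```

In Scala, NoSuchElementException is an alias for java.util.NoSuchElementException, so no import is needed.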
fit creates an ExecutionContext whose number of threads is given by the parallelism parameter.
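A fixed-size thread pool is one simple way to turn a parallelism value into an ExecutionContext using only the Scala and Java standard libraries. This is an assumed stand-in: PoolSketch, executionContext and runAll are hypothetical names, and Spark builds its context with internal thread-pool utilities rather than this exact code.

```scala
import java.util.concurrent.Executors
import scala.concurrent.{Await, ExecutionContext, ExecutionContextExecutorService, Future}
import scala.concurrent.duration.Duration

object PoolSketch {
  /** An ExecutionContext backed by a fixed pool of `parallelism` threads. */
  def executionContext(parallelism: Int): ExecutionContextExecutorService = {
    require(parallelism >= 1, s"parallelism must be at least 1, got $parallelism")
    ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(parallelism))
  }

  /** Runs `tasks` with at most `parallelism` of them at a time; results keep the input order. */
  def runAll[T](parallelism: Int)(tasks: Seq[() => T]): Seq[T] = {
    val ec = executionContext(parallelism)
    try {
      // Materialize all futures eagerly before waiting on them
      val futures = tasks.toVector.map(task => Future(task())(ec))
      futures.map(f => Await.result(f, Duration.Inf))
    } finally ec.shutdown()
  }
}
```

With parallelism = 1 the pool has a single thread, so tasks run one after another.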
fit creates an Instrumentation and requests it to print out the numFolds, seed and parallelism parameters to the logs.

```
INFO ...FIXME
```
fit requests the Instrumentation to print out the tuning parameters to the logs.

```
INFO ...FIXME
```
Note: fit passes the underlying RDD of the dataset to kFold (MLUtils.kFold) to split it into numFolds pairs of training and validation RDDs.
fit computes metrics for every pair of training and validation RDDs and then calculates, for every ParamMap in estimatorParamMaps, the average metric over all folds. You should see the following INFO message in the logs:

```
INFO Average cross-validation metrics: [metrics]
```
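Averaging happens per parameter map, not per fold: the metrics a ParamMap scored on the numFolds validation sets are summed and divided by numFolds, giving one average per entry of estimatorParamMaps. A small sketch over a hypothetical fold-by-ParamMap metrics matrix (AverageSketch and averageMetrics are illustrative names, not Spark API):

```scala
object AverageSketch {
  /** foldMetrics(fold)(paramIndex) is the metric of ParamMap `paramIndex` on fold `fold`.
    * Returns one average per ParamMap (the column means of the matrix).
    */
  def averageMetrics(foldMetrics: Seq[Seq[Double]]): Seq[Double] = {
    val numFolds = foldMetrics.size
    require(numFolds > 0, "need at least one fold")
    foldMetrics.transpose.map(_.sum / numFolds)
  }
}
```

For three folds and two ParamMaps the input is a 3 x 2 matrix and the result has two entries, one per ParamMap.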
fit picks the best cross-validation metric among the average metrics: the maximum if the Evaluator says a larger metric is better (isLargerBetter), and the minimum otherwise. You should see the following INFO messages in the logs:

```
INFO Best set of parameters: [estimatorParamMap]
INFO Best cross-validation metric: [bestMetric].
```
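Which metric is "best" depends on its direction: Spark's Evaluator exposes isLargerBetter (for example, true for areaUnderROC and false for rmse). A minimal sketch of the selection; BestSketch and best are hypothetical names:

```scala
object BestSketch {
  /** Returns (bestMetric, bestIndex) over the average metrics, honoring the metric direction. */
  def best(avgMetrics: Seq[Double], isLargerBetter: Boolean): (Double, Int) = {
    require(avgMetrics.nonEmpty, "need at least one ParamMap")
    val indexed = avgMetrics.zipWithIndex
    if (isLargerBetter) indexed.maxBy(_._1) else indexed.minBy(_._1)
  }
}
```

The returned index points into estimatorParamMaps, which is how the best set of parameters is found.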
fit requests the Estimator to fit the best model, using the whole dataset and the best ParamMap from estimatorParamMaps. You should see the following INFO message in the logs:

```
INFO training finished
```
In the end, fit creates a CrossValidatorModel (with the ID, the best model and the average metrics, one for every ParamMap in estimatorParamMaps) and copies the parameters to it.
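The final step can be pictured with a hypothetical container type: the best model is refit on the whole dataset with the winning parameters, and the average metrics are kept in estimatorParamMaps order, so avgMetrics(i) belongs to the i-th ParamMap. CVModelSketch, finish and the toy constant-prediction estimator are assumptions, not Spark's CrossValidatorModel.

```scala
object FinalModelSketch {
  /** Hypothetical toy model: predicts the same value for every row. */
  final case class ConstModel(prediction: Double)

  /** Hypothetical toy estimator: the data mean shrunk toward 0 by `lambda`. */
  def fit(data: Seq[Double], lambda: Double): ConstModel =
    ConstModel(data.sum / (data.size + lambda))

  /** Hypothetical stand-in for CrossValidatorModel. */
  final case class CVModelSketch(uid: String, bestModel: ConstModel, avgMetrics: Seq[Double])

  def finish(
      uid: String,
      dataset: Seq[Double],
      lambdas: Seq[Double],
      avgMetrics: Seq[Double],
      bestIndex: Int): CVModelSketch = {
    require(lambdas.size == avgMetrics.size, "one average metric per ParamMap")
    // Refit on the whole dataset (not a single fold) with the winning parameters
    CVModelSketch(uid, fit(dataset, lambdas(bestIndex)), avgMetrics)
  }
}
```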
fit and Computing Metrics for Training and Validation RDDs

fit computes metrics for every pair of training and validation RDDs (from kFold). For every pair, fit creates and persists the training and validation datasets.
Tip: You can monitor the storage used for persisting the datasets in web UI’s Storage tab.
fit prints out the following DEBUG message to the logs:

```
DEBUG Train split [index] with multiple sets of parameters.
```
For every ParamMap in the estimatorParamMaps parameter, fit fits a model on the training dataset using the Estimator. The models are fitted in parallel, with as many concurrent fits as the parallelism parameter allows.
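Fitting one model per ParamMap on the same training split, with bounded concurrency, can be sketched with Futures on a fixed thread pool. The toy estimator (a mean shrunk toward 0 by a hypothetical lambda parameter), ParallelFitSketch and fitAll are assumptions for illustration, not Spark's code.

```scala
import java.util.concurrent.Executors
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

object ParallelFitSketch {
  /** Hypothetical toy model: predicts a constant. */
  final case class ConstModel(prediction: Double)

  /** Hypothetical toy estimator: the training mean shrunk toward 0 by `lambda`. */
  def fit(training: Seq[Double], lambda: Double): ConstModel =
    ConstModel(training.sum / (training.size + lambda))

  /** Fits one model per lambda on the same training split, at most `parallelism`
    * at a time. Results keep the order of `lambdas`.
    */
  def fitAll(training: Seq[Double], lambdas: Seq[Double], parallelism: Int): Seq[ConstModel] = {
    val pool = Executors.newFixedThreadPool(parallelism)
    implicit val ec: ExecutionContext = ExecutionContext.fromExecutorService(pool)
    try {
      val futures = lambdas.toVector.map(lambda => Future(fit(training, lambda)))
      futures.map(Await.result(_, Duration.Inf))
    } finally pool.shutdown()
  }
}
```

Keeping results in lambdas order matters because the metrics are later matched to estimatorParamMaps by index.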
Note: The parallelism parameter defaults to 1, i.e. no parallelism for fitting models.
Note: fit unpersists the training data (per pair of training and validation RDDs) when all models have been trained.
fit requests the models to transform their respective validation datasets (with the corresponding parameters from estimatorParamMaps) and then requests the Evaluator to evaluate the transformed datasets. fit prints out the following DEBUG message to the logs:

```
DEBUG Got metric [metric] for model trained with [paramMap].
```
fit waits until all metrics are available and unpersists the validation dataset.
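Evaluating the models of one fold concurrently and then waiting for every metric can be sketched as below. Mean squared error as the evaluator and the constant-prediction model are hypothetical stand-ins for a real Evaluator and Model; EvaluateFoldSketch and evaluateAll are illustrative names.

```scala
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

object EvaluateFoldSketch {
  /** Hypothetical toy model: predicts the same value for every row. */
  final case class ConstModel(prediction: Double)

  /** Hypothetical evaluator: mean squared error on the validation rows (smaller is better). */
  def mse(model: ConstModel, validation: Seq[Double]): Double =
    validation.map(y => math.pow(y - model.prediction, 2)).sum / validation.size

  /** Evaluates every model on the same validation split concurrently,
    * then blocks until all metrics are available (in model order).
    */
  def evaluateAll(models: Seq[ConstModel], validation: Seq[Double])(implicit ec: ExecutionContext): Seq[Double] = {
    val metricFutures = models.toVector.map(model => Future(mse(model, validation)))
    metricFutures.map(Await.result(_, Duration.Inf))
  }
}
```

Only after every metric is back is the validation data no longer needed, which is when it is safe to unpersist it.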
Validating and Transforming Schema: transformSchema Method

```scala
transformSchema(schema: StructType): StructType
```
Note: transformSchema is part of the PipelineStage Contract.
transformSchema simply passes the call to transformSchemaImpl (which is shared between CrossValidator and TrainValidationSplit).