Follow spark技术分享,
dig into the Spark source code and play with Spark best practices

CrossValidator — Model Tuning / Finding The Best Model

CrossValidator is an Estimator for model tuning, i.e. finding the best model for given parameters and a dataset.

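CrossValidator randomly splits the dataset into numFolds non-overlapping folds and builds numFolds pairs of training and validation datasets: every fold serves once as the validation dataset, while the remaining data is used as the training dataset.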

CrossValidator generates a CrossValidatorModel that holds the best model and the average cross-validation metrics.

Note
CrossValidator accepts any Estimator for model selection, including a Pipeline that transforms raw datasets and generates a Model.
Note
Use ParamGridBuilder to build the parameter grid, i.e. the collection of ParamMaps for model tuning.
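A parameter grid is the Cartesian product of the candidate values for each parameter. The pure-Python sketch below models that idea only; it is not the ParamGridBuilder API. Plain dicts stand in for ParamMaps, and build_param_grid is a hypothetical helper.

```python
from itertools import product

def build_param_grid(grid):
    """Return one dict (a stand-in for a ParamMap) per combination of candidate values."""
    names = list(grid)
    # product(*values) yields every combination, taking one value per parameter
    return [dict(zip(names, combo)) for combo in product(*(grid[n] for n in names))]

param_maps = build_param_grid({"regParam": [0.1, 0.01], "maxIter": [10, 50, 100]})
print(len(param_maps))  # 6 = 2 values x 3 values
print(param_maps[0])    # {'regParam': 0.1, 'maxIter': 10}
```

With the default numFolds of 3, a 6-map grid like this one means 6 × 3 = 18 fits during cross-validation, plus one final fit of the best ParamMap on the whole dataset.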

CrossValidator is an MLWritable.

Table 1. CrossValidator's Parameters

| Parameter          | Default Value | Description                                                              |
|--------------------|---------------|--------------------------------------------------------------------------|
| estimator          | (undefined)   | Estimator for best model selection                                       |
| estimatorParamMaps | (undefined)   | Param maps for the estimator                                             |
| evaluator          | (undefined)   | Evaluator to select hyper-parameters that maximize the validated metric |
| numFolds           | 3             | Number of folds for cross-validation. Must be at least 2.                |
| parallelism        | 1             | Number of threads to use when fitting models. Must be at least 1.        |
| seed               |               | Random seed                                                              |

Tip

Enable INFO or DEBUG logging level for the org.apache.spark.ml.tuning.CrossValidator logger to see what happens inside.

Add the following line to conf/log4j.properties (log4j 1.x properties syntax):

log4j.logger.org.apache.spark.ml.tuning.CrossValidator=DEBUG

Refer to Logging.

Finding The Best Model — fit Method

Note
fit is part of the Estimator Contract to fit a model (i.e. produce a model).

fit validates the schema (with logging turned on).

You should see the following DEBUG message in the logs:

fit makes sure that the estimator, evaluator, estimatorParamMaps and parallelism parameters are defined, and throws a NoSuchElementException otherwise.

fit creates an ExecutionContext (per the parallelism parameter), i.e. one that runs up to parallelism tasks at the same time.

fit creates an Instrumentation and requests it to print out the numFolds, seed and parallelism parameters to the logs.

fit requests Instrumentation to print out the tuning parameters to the logs.

fit calls kFold on the RDD of the dataset with the numFolds and seed parameters to get numFolds pairs of training and validation RDDs.

Note
fit passes the underlying RDD of the dataset to kFold.
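The kFold step can be pictured with the pure-Python sketch below. It is a simplified model, similar in spirit to Spark's sampling-based kFold but not its code: each row gets one uniform random number from a generator seeded with seed, and fold i uses the rows whose number falls in [i/numFolds, (i+1)/numFolds) as its validation set and all other rows as its training set. k_fold is a hypothetical helper, and it works on a Python list rather than an RDD.

```python
import random

def k_fold(rows, num_folds, seed):
    """Return num_folds (training, validation) pairs with non-overlapping validation sets."""
    if num_folds < 2:
        raise ValueError("numFolds must be at least 2")
    rng = random.Random(seed)
    # One uniform draw in [0, 1) per row decides which fold validates that row
    draws = [rng.random() for _ in rows]
    pairs = []
    for i in range(num_folds):
        lower, upper = i / num_folds, (i + 1) / num_folds
        in_fold = [lower <= u < upper for u in draws]
        validation = [r for r, hit in zip(rows, in_fold) if hit]
        training = [r for r, hit in zip(rows, in_fold) if not hit]
        pairs.append((training, validation))
    return pairs

pairs = k_fold(list(range(20)), num_folds=3, seed=42)
for training, validation in pairs:
    print(len(training), len(validation))  # sizes vary with the seed; each pair sums to 20
```

Because every row is drawn exactly once, the validation sets partition the data and the same seed always reproduces the same split. The fold sizes are only approximately equal (about len(rows) / numFolds each).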

fit computes metrics for every pair of training and validation RDDs.

fit averages the metrics over all folds, which gives one average metric per ParamMap in estimatorParamMaps.
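Averaging happens per ParamMap: each ParamMap gets one metric per fold, and its average is the sum over the folds divided by numFolds. A minimal pure-Python sketch, with average_metrics as a hypothetical helper and made-up metric values:

```python
def average_metrics(fold_metrics, num_folds):
    """Average each ParamMap's metric over all folds.

    fold_metrics[f][m] is the metric of the model fitted with ParamMap m on fold f.
    """
    # zip(*fold_metrics) turns the folds-by-ParamMaps table into ParamMaps-by-folds
    return [sum(per_param_map) / num_folds for per_param_map in zip(*fold_metrics)]

# Made-up metric values, for illustration only
fold_metrics = [
    [0.80, 0.70],  # fold 0: ParamMap 0, ParamMap 1
    [0.90, 0.60],  # fold 1
    [0.70, 0.80],  # fold 2
]
avg = average_metrics(fold_metrics, num_folds=3)
print(avg)  # roughly [0.8, 0.7], up to floating-point rounding
```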

You should see the following INFO message in the logs:

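fit uses the Evaluator to pick the best average cross-validation metric (the largest or the smallest, depending on whether a larger metric is better for the Evaluator) and the corresponding ParamMap.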
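Whether the best average metric is the largest or the smallest depends on the Evaluator (Spark's Evaluator exposes this as isLargerBetter). A pure-Python sketch of that choice, with best_param_map as a hypothetical helper:

```python
def best_param_map(avg_metrics, is_larger_better):
    """Return (best_metric, best_index) over the per-ParamMap average metrics."""
    candidates = [(metric, index) for index, metric in enumerate(avg_metrics)]
    # max for metrics like areaUnderROC, min for error metrics like rmse
    pick = max if is_larger_better else min
    return pick(candidates, key=lambda c: c[0])

print(best_param_map([0.8, 0.7, 0.75], is_larger_better=True))   # (0.8, 0)
print(best_param_map([0.8, 0.7, 0.75], is_larger_better=False))  # (0.7, 1)
```

The returned index points back into estimatorParamMaps, which is how the best ParamMap is found for the final fit on the whole dataset.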

You should see the following INFO message in the logs:

fit requests the Estimator to fit the best model on the whole dataset with the best ParamMap from estimatorParamMaps.

You should see the following INFO message in the logs:

In the end, fit creates a CrossValidatorModel (with the ID, the best model and the average metrics, one per ParamMap) and copies the parameters to it.

fit and Computing Metrics for Training and Validation RDDs

fit computes metrics for every pair of training and validation RDDs (from kFold).

For every pair, fit creates the training and validation datasets from the RDDs and persists them.

Tip
You can monitor the persisted datasets in the Storage tab of the web UI.

fit prints out the following DEBUG message to the logs:

For every ParamMap in the estimatorParamMaps parameter, fit fits a model on the training dataset using the Estimator.

fit fits these models in parallel, running up to parallelism fits at the same time.

Note
The parallelism parameter defaults to 1, i.e. models are fitted one at a time (no parallelism).
Note
fit unpersists the training data (per pair of training and validation RDDs) when all models have been trained.

fit requests the models to transform their respective validation datasets (with the corresponding parameters from estimatorParamMaps) and then requests the Evaluator to evaluate the transformed datasets.

fit prints out the following DEBUG message to the logs:

fit waits until all metrics are available and unpersists the validation dataset.
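The per-split work described in this section (fit one model per ParamMap with up to parallelism threads, then evaluate every model on the validation data and wait for all metrics) can be modelled with a thread pool whose size plays the role of parallelism. This is plain Python, not Spark: fit, evaluate and metrics_for_split are hypothetical helpers, the toy "model" only predicts a scaled mean, and no data is persisted or unpersisted.

```python
from concurrent.futures import ThreadPoolExecutor
from statistics import mean

def fit(training, param_map):
    """Hypothetical toy estimator: always predicts mean(training) * scale."""
    prediction = mean(training) * param_map["scale"]
    return lambda rows: [prediction] * len(rows)

def evaluate(rows, predictions):
    """Hypothetical toy evaluator: mean squared error (smaller is better)."""
    return mean((r - p) ** 2 for r, p in zip(rows, predictions))

def metrics_for_split(training, validation, param_maps, parallelism=1):
    """Fit one model per ParamMap (at most `parallelism` at a time), then evaluate each."""
    if parallelism < 1:
        raise ValueError("parallelism must be at least 1")
    with ThreadPoolExecutor(max_workers=parallelism) as pool:
        # Submit one fit per ParamMap; the pool caps concurrency at `parallelism`
        fit_futures = [pool.submit(fit, training, pm) for pm in param_maps]
        models = [f.result() for f in fit_futures]  # all models are trained at this point
        # Evaluate every model on the validation rows; map keeps the ParamMap order
        return list(pool.map(lambda model: evaluate(validation, model(validation)), models))

print(metrics_for_split([1.0, 2.0, 3.0], [2.0, 2.0], [{"scale": 1.0}, {"scale": 0.5}], parallelism=2))
# [0.0, 1.0]
```

In CPython, threads give little speed-up for pure-Python CPU work because of the global interpreter lock. In Spark, the parallelism threads mainly submit Spark jobs and wait for them, so the actual work still runs on the cluster.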

Creating CrossValidator Instance

CrossValidator takes the following when created:

  • Unique ID

Validating and Transforming Schema — transformSchema Method

Note
transformSchema is part of PipelineStage Contract.

transformSchema simply passes the call to transformSchemaImpl (that is shared between CrossValidator and TrainValidationSplit).
