Estimator
Estimator is the contract in Spark MLlib for algorithms that fit a model to a dataset. For example, LogisticRegression is an Estimator that fits a LogisticRegressionModel.
An Estimator accepts parameters that you can set through dedicated setter methods when you create it. You can also pass extra parameters when you fit a model.
```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.DataFrame

// Define parameters upon creating an Estimator
val lr = new LogisticRegression()
  .setMaxIter(5)
  .setRegParam(0.01)

val training: DataFrame = ...
val model1 = lr.fit(training)

// Define parameters through fit
// The extra params override the ones set on lr for this fit only
// (maxIter = 10, features read from "custom_features"); lr itself is unchanged
val customParams = ParamMap(
  lr.maxIter -> 10,
  lr.featuresCol -> "custom_features"
)
val model2 = lr.fit(training, customParams)
```
Estimator is a PipelineStage and so can be a part of a Pipeline.
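To make the PipelineStage relationship concrete, here is a minimal, Spark-free Scala sketch of how a pipeline treats its stages. A Transformer is applied as-is. An Estimator is first fitted on the data produced by the earlier stages and is then replaced by the model it returns. The types `Stage`, `Transformer`, `Estimator`, `Scale` and `Center` and the function `fitPipeline` are hypothetical simplifications, not Spark's API, and datasets are reduced to `Seq[Double]`. The sketch is also simplified in one respect: unlike Spark's Pipeline.fit, it transforms the data after every stage, including the last Estimator.

```scala
// Hypothetical, Spark-free stand-ins (not Spark's API); a dataset is a Seq[Double].
trait Stage
trait Transformer extends Stage {
  def transform(data: Seq[Double]): Seq[Double]
}
trait Estimator extends Stage {
  def fit(data: Seq[Double]): Transformer
}

// A fixed transformer: multiplies every value by a factor.
final class Scale(factor: Double) extends Transformer {
  def transform(data: Seq[Double]): Seq[Double] = data.map(_ * factor)
}

// An estimator: learns the mean and returns a transformer that subtracts it.
final class Center extends Estimator {
  def fit(data: Seq[Double]): Transformer = {
    val mean = data.sum / data.size
    new Transformer {
      def transform(d: Seq[Double]): Seq[Double] = d.map(_ - mean)
    }
  }
}

// Walks the stages in order, fits each Estimator on the data as transformed
// so far, and collects the resulting transformers (the "fitted pipeline").
def fitPipeline(stages: Seq[Stage], data: Seq[Double]): Seq[Transformer] = {
  val (_, fitted) = stages.foldLeft((data, Vector.empty[Transformer])) {
    case ((current, acc), stage) =>
      val t = stage match {
        case tr: Transformer => tr
        case est: Estimator  => est.fit(current)
      }
      (t.transform(current), acc :+ t)
  }
  fitted
}
```

For example, fitting `Seq(new Scale(2.0), new Center)` on `Seq(1.0, 2.0, 3.0)` learns a mean of 4.0 from the scaled data. Applying the fitted stages to `Seq(5.0)` then gives `Seq(6.0)`.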
Estimator Contract
```scala
package org.apache.spark.ml

import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.Dataset

abstract class Estimator[M <: Model[M]] extends PipelineStage {
  // only required methods that have no implementation
  def fit(dataset: Dataset[_]): M
  def copy(extra: ParamMap): Estimator[M]
}
```
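You can mirror the contract outside Spark to see what a concrete estimator has to supply: only fit and copy. In the sketch below, `Model`, `Estimator`, `ConstantModel` and `MeanEstimator` (with its `shrinkage` parameter) are hypothetical stand-ins. Params are simplified to a `Map[String, Double]` instead of Spark's typed Param objects, and a dataset is simply a `Seq[Double]`.

```scala
// Hypothetical, Spark-free stand-ins that only mirror the shape of the contract.
abstract class Model[M <: Model[M]] {
  def predict(x: Double): Double
}

abstract class Estimator[M <: Model[M]] {
  // The two methods a concrete estimator must implement
  def fit(dataset: Seq[Double]): M
  def copy(extra: Map[String, Double]): Estimator[M]
}

// A model that predicts the constant learned during fitting.
final class ConstantModel(val value: Double) extends Model[ConstantModel] {
  def predict(x: Double): Double = value
}

// An estimator whose single parameter "shrinkage" scales the fitted mean.
final class MeanEstimator(val params: Map[String, Double] = Map("shrinkage" -> 1.0)) extends Estimator[ConstantModel] {
  def fit(dataset: Seq[Double]): ConstantModel =
    new ConstantModel(params("shrinkage") * dataset.sum / dataset.size)

  // Returns a new estimator; entries in `extra` override the current ones.
  def copy(extra: Map[String, Double]): MeanEstimator =
    new MeanEstimator(params ++ extra)
}
```

Here `new MeanEstimator().fit(Seq(2.0, 4.0))` yields a model with value 3.0. By contrast, `copy(Map("shrinkage" -> 0.5))` returns a new estimator whose fitted value is 1.5, and the original estimator keeps a shrinkage of 1.0.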
| Method | Description |
|---|---|
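| fit | Fits a model (of type M) to the input dataset. Used when… |
| copy | Creates a copy of the Estimator with the extra parameters applied. Used when… |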
Fitting Model with Extra Parameters — fit Method
```scala
def fit(dataset: Dataset[_], paramMap: ParamMap): M
```
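fit creates a copy of the Estimator with the extra paramMap applied (using copy) and then fits a model (of type M) with that copy. The parameters of the original Estimator stay unchanged.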
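Because the extra parameters are applied to a copy, the Estimator you call fit on keeps its own parameters. Here is a minimal sketch of that delegation. It assumes hypothetical `Estimator`, `CountModel` and `CountAbove` types (with a made-up `threshold` parameter), uses `Map[String, Double]` in place of ParamMap, and drops the `M <: Model[M]` bound for brevity.

```scala
// Hypothetical, Spark-free sketch of the extra-params overload.
abstract class Estimator[M] {
  def fit(dataset: Seq[Double]): M
  def copy(extra: Map[String, Double]): Estimator[M]

  // Concrete overload provided by the base class: fit a copy, never `this`
  def fit(dataset: Seq[Double], extra: Map[String, Double]): M =
    copy(extra).fit(dataset)
}

final case class CountModel(count: Int)

// Counts how many values exceed the "threshold" parameter.
final class CountAbove(val params: Map[String, Double]) extends Estimator[CountModel] {
  def fit(dataset: Seq[Double]): CountModel =
    CountModel(dataset.count(_ > params("threshold")))

  // Entries in `extra` override the current parameters in the new instance
  def copy(extra: Map[String, Double]): CountAbove = new CountAbove(params ++ extra)
}
```

With a threshold of 1.0, fitting `Seq(0.5, 1.5, 2.5)` counts two values. Passing `Map("threshold" -> 2.0)` as extra parameters counts one value, and the estimator's own threshold is still 1.0 afterwards.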
Note: fit with extra parameters is used mainly for model tuning, i.e. to find the best model (using CrossValidator and TrainValidationSplit).
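At a high level, tuning fits one model per candidate parameter map with this fit variant, scores each model on held-out data, and keeps the best one. The Spark-free sketch below shows only that loop. `Estimator`, `ConstantModel`, `ShrunkMean`, `mse` and `selectBest` are hypothetical. The real CrossValidator and TrainValidationSplit also take care of splitting the data (k folds or a single train/validation split) and use an Evaluator to score the models.

```scala
// Hypothetical, Spark-free sketch of a tuning loop over candidate param maps.
abstract class Estimator[M] {
  def fit(dataset: Seq[Double]): M
  def copy(extra: Map[String, Double]): Estimator[M]
  def fit(dataset: Seq[Double], extra: Map[String, Double]): M = copy(extra).fit(dataset)
}

final case class ConstantModel(value: Double)

// Predicts "shrinkage" times the mean of the training data.
final class ShrunkMean(val params: Map[String, Double]) extends Estimator[ConstantModel] {
  def fit(dataset: Seq[Double]): ConstantModel =
    ConstantModel(params("shrinkage") * dataset.sum / dataset.size)
  def copy(extra: Map[String, Double]): ShrunkMean = new ShrunkMean(params ++ extra)
}

// Mean squared error of the model's constant prediction on validation data.
def mse(model: ConstantModel, validation: Seq[Double]): Double =
  validation.map(y => math.pow(y - model.value, 2)).sum / validation.size

// Fits one model per param map and returns the pair with the lowest error.
def selectBest(
    est: Estimator[ConstantModel],
    grid: Seq[Map[String, Double]],
    train: Seq[Double],
    validation: Seq[Double]): (Map[String, Double], ConstantModel) =
  grid.map(pm => (pm, est.fit(train, pm))).minBy { case (_, m) => mse(m, validation) }
```

For example, with training data `Seq(2.0, 4.0)`, validation data `Seq(3.0)` and shrinkage candidates 0.5, 1.0 and 2.0, the models predict 1.5, 3.0 and 6.0. The loop therefore keeps shrinkage 1.0.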