Estimator
Estimator is the contract in Spark MLlib for estimators that fit models to a dataset.
Estimator accepts parameters that you can set through dedicated setter methods when creating an Estimator. You can also pass extra parameters when fitting a model.
```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.DataFrame

// Define parameters upon creating an Estimator
val lr = new LogisticRegression()
  .setMaxIter(5)
  .setRegParam(0.01)

val training: DataFrame = ... // your training dataset
val model1 = lr.fit(training)

// Define parameters through fit (they override the ones set above for this fit only)
val customParams = ParamMap(
  lr.maxIter -> 10,
  lr.featuresCol -> "custom_features"
)
val model2 = lr.fit(training, customParams)
```
Estimator is a PipelineStage and so can be part of a Pipeline.
Estimator Contract
```scala
package org.apache.spark.ml

import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.Dataset

abstract class Estimator[M <: Model[M]] extends PipelineStage {
  // only required methods that have no implementation
  def fit(dataset: Dataset[_]): M
  def copy(extra: ParamMap): Estimator[M]
}
```
Method | Description |
---|---|
fit | Used when… |
copy | Used when… |
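To make the two abstract methods of the contract concrete, here is a minimal sketch that does not depend on Spark. Everything in it is hypothetical: ToyModel, ToyEstimator, ConstantModel, MeanEstimator and the "shrinkage" parameter are illustrative stand-ins, a dataset is simplified to a Seq[Double], and a ParamMap to a Map[String, Double]. It only mirrors the shape of fit (learn a model from data) and copy (produce a new estimator with extra params applied).

```scala
// Hypothetical, Spark-free stand-ins that mirror the shape of the Estimator contract.
// They are NOT Spark classes: a "dataset" is a Seq[Double], a "ParamMap" a Map[String, Double].
abstract class ToyModel[M <: ToyModel[M]] {
  def predict(x: Double): Double
}

abstract class ToyEstimator[M <: ToyModel[M]] {
  // The two abstract methods correspond to Estimator.fit and Estimator.copy
  def fit(dataset: Seq[Double]): M
  def copy(extra: Map[String, Double]): ToyEstimator[M]
}

// A model that predicts the constant learned during fitting, whatever the input
final case class ConstantModel(value: Double) extends ToyModel[ConstantModel] {
  def predict(x: Double): Double = value
}

// Learns the mean of the data, shrunk toward 0 by the hypothetical "shrinkage" param
final class MeanEstimator extends ToyEstimator[ConstantModel] {
  private var params: Map[String, Double] = Map("shrinkage" -> 0.0)

  // Setter in the Spark style: mutates this estimator and returns it for chaining
  def setShrinkage(value: Double): this.type = {
    params = params.updated("shrinkage", value)
    this
  }
  def getShrinkage: Double = params("shrinkage")

  def fit(dataset: Seq[Double]): ConstantModel = {
    val mean = dataset.sum / dataset.size
    ConstantModel(mean * (1.0 - getShrinkage))
  }

  // Returns a new estimator whose params are this one's, overridden by `extra`
  def copy(extra: Map[String, Double]): MeanEstimator = {
    val that = new MeanEstimator
    that.params = params ++ extra
    that
  }
}

val data = Seq(1.0, 2.0, 3.0, 6.0)
val est = new MeanEstimator().setShrinkage(0.5)
println(est.fit(data).predict(10.0)) // 1.5
```

The setter mutates and returns the estimator, as Spark setters do, while copy never touches the original; that separation is what lets an estimator be reused with different parameters.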
Fitting Model with Extra Parameters: fit Method
```scala
def fit(dataset: Dataset[_], paramMap: ParamMap): M
```
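fit creates a copy of the Estimator with the extra paramMap applied (using copy) and fits a model (of type M) with that copy, i.e. it executes copy(paramMap).fit(dataset). The original Estimator is left unchanged.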
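The copy-then-fit pattern can be sketched without Spark as follows. All names are hypothetical (ConstantModel, ToyEstimator, MeanEstimator, "shrinkage"), a dataset is simplified to a Seq[Double] and a ParamMap to a Map[String, Double]. The point is that the two-argument fit is implemented once in the base class on top of the two abstract methods.

```scala
// Hypothetical, Spark-free stand-ins (NOT Spark's API): a "dataset" is a Seq[Double]
// and a "ParamMap" is a Map[String, Double].
final case class ConstantModel(value: Double)

abstract class ToyEstimator[M] {
  def fit(dataset: Seq[Double]): M
  def copy(extra: Map[String, Double]): ToyEstimator[M]

  // Same pattern as Estimator.fit(dataset, paramMap): fit a copy that carries the
  // extra params, so this estimator's own params stay untouched
  def fit(dataset: Seq[Double], paramMap: Map[String, Double]): M =
    copy(paramMap).fit(dataset)
}

// Learns the mean of the data, shrunk toward 0 by the hypothetical "shrinkage" param
final class MeanEstimator(params: Map[String, Double] = Map("shrinkage" -> 0.0))
    extends ToyEstimator[ConstantModel] {
  def getShrinkage: Double = params("shrinkage")

  def fit(dataset: Seq[Double]): ConstantModel =
    ConstantModel(dataset.sum / dataset.size * (1.0 - getShrinkage))

  // Extra params override the existing ones in the new estimator
  def copy(extra: Map[String, Double]): MeanEstimator = new MeanEstimator(params ++ extra)
}

val data = Seq(1.0, 2.0, 3.0, 6.0)
val est = new MeanEstimator(Map("shrinkage" -> 0.5))

println(est.fit(data).value)                          // 1.5 (uses shrinkage = 0.5)
println(est.fit(data, Map("shrinkage" -> 0.0)).value) // 3.0 (extra param wins)
println(est.getShrinkage)                             // 0.5 (original unchanged)
```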
Note: fit is used mainly for model tuning to find the best model (using CrossValidator and TrainValidationSplit).
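The tuning idea can be illustrated with a deliberately simplified, Spark-free sketch: fit one model per candidate parameter map, score each model on held-out data, and keep the best. This is an assumption-laden toy, not how CrossValidator or TrainValidationSplit are implemented; MeanEstimator, ConstantModel, mse and the "shrinkage" parameter are hypothetical, and the grid is a plain Seq of Map[String, Double].

```scala
// Hypothetical, Spark-free sketch of the idea behind model tuning (NOT how
// CrossValidator or TrainValidationSplit are implemented): fit one model per
// candidate ParamMap, score each on held-out data, keep the best.
final case class ConstantModel(value: Double)

// Toy estimator: learns the mean of the data, shrunk toward 0 by "shrinkage"
final class MeanEstimator(params: Map[String, Double] = Map("shrinkage" -> 0.0)) {
  def fit(dataset: Seq[Double]): ConstantModel =
    ConstantModel(dataset.sum / dataset.size * (1.0 - params("shrinkage")))
  def copy(extra: Map[String, Double]): MeanEstimator = new MeanEstimator(params ++ extra)
  // fit with extra params = fit a copy carrying those params
  def fit(dataset: Seq[Double], paramMap: Map[String, Double]): ConstantModel =
    copy(paramMap).fit(dataset)
}

// Mean squared error of a constant prediction against held-out values
def mse(model: ConstantModel, validation: Seq[Double]): Double =
  validation.map { y => val d = model.value - y; d * d }.sum / validation.size

val train = Seq(1.0, 2.0, 3.0, 6.0)
val validation = Seq(2.0, 2.0)
val paramGrid = Seq(0.0, 0.25, 0.5).map(s => Map("shrinkage" -> s))

// The same estimator is fitted once per candidate ParamMap
val (bestParams, bestError) = paramGrid
  .map(pm => (pm, mse(new MeanEstimator().fit(train, pm), validation)))
  .minBy(_._2)

println(bestParams) // Map(shrinkage -> 0.25)
println(bestError)  // 0.0625
```

Because fit(dataset, paramMap) never mutates the estimator, every candidate is fitted from the same starting configuration, which is the property a tuning loop relies on.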