# LinearRegression

`LinearRegression` is a `Regressor` that represents the linear regression algorithm in Machine Learning. It belongs to the `org.apache.spark.ml.regression` package.
> **Tip**: Read the scaladoc of `LinearRegression`.
It expects `org.apache.spark.mllib.linalg.Vector` as the input type of the features column in a dataset and produces a `LinearRegressionModel`.
```scala
import org.apache.spark.ml.regression.LinearRegression

val lr = new LinearRegression
```
The acceptable parameters:
```scala
scala> println(lr.explainParams)
elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty (default: 0.0)
featuresCol: features column name (default: features)
fitIntercept: whether to fit an intercept term (default: true)
labelCol: label column name (default: label)
maxIter: maximum number of iterations (>= 0) (default: 100)
predictionCol: prediction column name (default: prediction)
regParam: regularization parameter (>= 0) (default: 0.0)
solver: the solver algorithm for optimization. If this is not set or empty, default value is 'auto' (default: auto)
standardization: whether to standardize the training features before fitting the model (default: true)
tol: the convergence tolerance for iterative algorithms (default: 1.0E-6)
weightCol: weight column name. If this is not set or empty, we treat all instance weights as 1.0 (default: )
```
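Each of these `Params` has a fluent setter on `LinearRegression`, so a non-default configuration is typically built up before calling `fit`. A minimal sketch (the parameter values below are arbitrary, chosen only for illustration):

```scala
import org.apache.spark.ml.regression.LinearRegression

// An elastic-net-regularized linear regression:
// regParam > 0 enables regularization, elasticNetParam mixes L1 and L2
val lr = new LinearRegression()
  .setMaxIter(50)           // cap the optimizer at 50 iterations
  .setRegParam(0.1)         // regularization strength
  .setElasticNetParam(0.5)  // 0.0 = pure L2 (ridge), 1.0 = pure L1 (lasso)

scala> println(lr.getRegParam)
0.1
```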
## LinearRegression Example
```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

val data = (0.0 to 9.0 by 1)  // create a collection of Doubles
  .map(n => (n, n))           // make it pairs
  .map { case (label, features) =>
    LabeledPoint(label, Vectors.dense(features)) }  // create labeled points of dense vectors
  .toDF                       // make it a DataFrame

scala> data.show
+-----+--------+
|label|features|
+-----+--------+
|  0.0|   [0.0]|
|  1.0|   [1.0]|
|  2.0|   [2.0]|
|  3.0|   [3.0]|
|  4.0|   [4.0]|
|  5.0|   [5.0]|
|  6.0|   [6.0]|
|  7.0|   [7.0]|
|  8.0|   [8.0]|
|  9.0|   [9.0]|
+-----+--------+

import org.apache.spark.ml.regression.LinearRegression
val lr = new LinearRegression

val model = lr.fit(data)

scala> model.intercept
res1: Double = 0.0

scala> model.coefficients
res2: org.apache.spark.mllib.linalg.Vector = [1.0]

// make predictions
scala> val predictions = model.transform(data)
predictions: org.apache.spark.sql.DataFrame = [label: double, features: vector ... 1 more field]

scala> predictions.show
+-----+--------+----------+
|label|features|prediction|
+-----+--------+----------+
|  0.0|   [0.0]|       0.0|
|  1.0|   [1.0]|       1.0|
|  2.0|   [2.0]|       2.0|
|  3.0|   [3.0]|       3.0|
|  4.0|   [4.0]|       4.0|
|  5.0|   [5.0]|       5.0|
|  6.0|   [6.0]|       6.0|
|  7.0|   [7.0]|       7.0|
|  8.0|   [8.0]|       8.0|
|  9.0|   [9.0]|       9.0|
+-----+--------+----------+

import org.apache.spark.ml.evaluation.RegressionEvaluator

// rmse is the default metric
// We're explicit here for learning purposes
val regEval = new RegressionEvaluator().setMetricName("rmse")
val rmse = regEval.evaluate(predictions)

scala> println(s"Root Mean Squared Error: $rmse")
Root Mean Squared Error: 0.0

import org.apache.spark.mllib.linalg.DenseVector

// NOTE Follow along to learn spark.ml-way (not RDD-way)
predictions.rdd.map { r =>
  (r(0).asInstanceOf[Double], r(1).asInstanceOf[DenseVector](0).toDouble, r(2).asInstanceOf[Double]) }
  .toDF("label", "feature0", "prediction").show
+-----+--------+----------+
|label|feature0|prediction|
+-----+--------+----------+
|  0.0|     0.0|       0.0|
|  1.0|     1.0|       1.0|
|  2.0|     2.0|       2.0|
|  3.0|     3.0|       3.0|
|  4.0|     4.0|       4.0|
|  5.0|     5.0|       5.0|
|  6.0|     6.0|       6.0|
|  7.0|     7.0|       7.0|
|  8.0|     8.0|       8.0|
|  9.0|     9.0|       9.0|
+-----+--------+----------+

// Let's make it nicer to the eyes using a Scala case class
scala> :pa
// Entering paste mode (ctrl-D to finish)

import org.apache.spark.sql.Row
import org.apache.spark.mllib.linalg.DenseVector
case class Prediction(label: Double, feature0: Double, prediction: Double)
object Prediction {
  def apply(r: Row) = new Prediction(
    label = r(0).asInstanceOf[Double],
    feature0 = r(1).asInstanceOf[DenseVector](0).toDouble,
    prediction = r(2).asInstanceOf[Double])
}

// Exiting paste mode, now interpreting.

import org.apache.spark.sql.Row
import org.apache.spark.mllib.linalg.DenseVector
defined class Prediction
defined object Prediction

scala> predictions.rdd.map(Prediction.apply).toDF.show
+-----+--------+----------+
|label|feature0|prediction|
+-----+--------+----------+
|  0.0|     0.0|       0.0|
|  1.0|     1.0|       1.0|
|  2.0|     2.0|       2.0|
|  3.0|     3.0|       3.0|
|  4.0|     4.0|       4.0|
|  5.0|     5.0|       5.0|
|  6.0|     6.0|       6.0|
|  7.0|     7.0|       7.0|
|  8.0|     8.0|       8.0|
|  9.0|     9.0|       9.0|
+-----+--------+----------+
```
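As a side note, a trained `LinearRegressionModel` also carries a training summary, so training-set metrics do not strictly require a `RegressionEvaluator`. A short sketch, reusing the `model` trained above (availability of `summary` depends on your Spark version; the values follow from the exact fit in this example):

```scala
// Training metrics computed as part of fit (exposed as model.summary)
scala> println(model.summary.rootMeanSquaredError)
0.0

scala> println(model.summary.r2)
1.0
```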
## train Method
```scala
train(dataset: DataFrame): LinearRegressionModel
```
`train` is a protected method of `LinearRegression` that expects a `dataset` DataFrame with two columns:

- `label` of type `DoubleType`
- `features` of type `Vector`

It returns a `LinearRegressionModel`.
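Those column names are just the defaults of the `labelCol` and `featuresCol` parameters. If a dataset uses different names, you can either rename the columns or re-point the parameters before `fit` (the public entry point that validates the schema and delegates to `train`). A minimal sketch, with the made-up column names `y` and `x`:

```scala
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.mllib.linalg.Vectors

// Hypothetical dataset whose columns are not named label/features
val input = Seq((0.0, Vectors.dense(0.0)), (1.0, Vectors.dense(1.0))).toDF("y", "x")

// Option 1: rename the columns to the expected defaults
val renamed = input
  .withColumnRenamed("y", "label")
  .withColumnRenamed("x", "features")

// Option 2: point labelCol/featuresCol at the existing columns
val lrCustom = new LinearRegression()
  .setLabelCol("y")
  .setFeaturesCol("x")
val customModel = lrCustom.fit(input)
```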
`train` first counts the number of elements in the features column (named `features` by default). The column has to be of `mllib.linalg.Vector` type (and can easily be prepared using the HashingTF transformer).
```scala
import org.apache.spark.sql.functions.lit

val spam = Seq(
  (0, "Hi Jacek. Wanna more SPAM? Best!"),
  (1, "This is SPAM. This is SPAM")).toDF("id", "email")

import org.apache.spark.ml.feature.RegexTokenizer
val regexTok = new RegexTokenizer()
val spamTokens = regexTok.setInputCol("email").transform(spam)

scala> spamTokens.show(false)
+---+--------------------------------+---------------------------------------+
|id |email                           |regexTok_646b6bcc4548__output          |
+---+--------------------------------+---------------------------------------+
|0  |Hi Jacek. Wanna more SPAM? Best!|[hi, jacek., wanna, more, spam?, best!]|
|1  |This is SPAM. This is SPAM      |[this, is, spam., this, is, spam]      |
+---+--------------------------------+---------------------------------------+

import org.apache.spark.ml.feature.HashingTF
val hashTF = new HashingTF()
  .setInputCol(regexTok.getOutputCol)
  .setOutputCol("features")
  .setNumFeatures(5000)

val spamHashed = hashTF.transform(spamTokens)

scala> spamHashed.select("email", "features").show(false)
+--------------------------------+----------------------------------------------------------------+
|email                           |features                                                        |
+--------------------------------+----------------------------------------------------------------+
|Hi Jacek. Wanna more SPAM? Best!|(5000,[2525,2943,3093,3166,3329,3980],[1.0,1.0,1.0,1.0,1.0,1.0])|
|This is SPAM. This is SPAM      |(5000,[1713,3149,3370,4070],[1.0,1.0,2.0,2.0])                  |
+--------------------------------+----------------------------------------------------------------+

// Create labeled datasets for spam (1)
val spamLabeled = spamHashed.withColumn("label", lit(1d))

scala> spamLabeled.show
+---+--------------------+-----------------------------+--------------------+-----+
| id|               email|regexTok_646b6bcc4548__output|            features|label|
+---+--------------------+-----------------------------+--------------------+-----+
|  0|Hi Jacek. Wanna m...|         [hi, jacek., wann...|(5000,[2525,2943,...|  1.0|
|  1|This is SPAM. Thi...|         [this, is, spam.,...|(5000,[1713,3149,...|  1.0|
+---+--------------------+-----------------------------+--------------------+-----+

val regular = Seq(
  (2, "Hi Jacek. I hope this email finds you well. Spark up!"),
  (3, "Welcome to Apache Spark project")).toDF("id", "email")

val regularTokens = regexTok.setInputCol("email").transform(regular)
val regularHashed = hashTF.transform(regularTokens)

// Create labeled datasets for non-spam regular emails (0)
val regularLabeled = regularHashed.withColumn("label", lit(0d))

val training = regularLabeled.union(spamLabeled).cache

scala> training.show
+---+--------------------+-----------------------------+--------------------+-----+
| id|               email|regexTok_646b6bcc4548__output|            features|label|
+---+--------------------+-----------------------------+--------------------+-----+
|  2|Hi Jacek. I hope ...|         [hi, jacek., i, h...|(5000,[72,105,942...|  0.0|
|  3|Welcome to Apache...|         [welcome, to, apa...|(5000,[2894,3365,...|  0.0|
|  0|Hi Jacek. Wanna m...|         [hi, jacek., wann...|(5000,[2525,2943,...|  1.0|
|  1|This is SPAM. Thi...|         [this, is, spam.,...|(5000,[1713,3149,...|  1.0|
+---+--------------------+-----------------------------+--------------------+-----+

import org.apache.spark.ml.regression.LinearRegression
val lr = new LinearRegression

// the following calls train by the Predictor contract (see above)
val lrModel = lr.fit(training)

// Let's predict whether an email is a spam or not
val email = Seq("Hi Jacek. you doing well? Bye!").toDF("email")
val emailTokens = regexTok.setInputCol("email").transform(email)
val emailHashed = hashTF.transform(emailTokens)

scala> lrModel.transform(emailHashed).select("prediction").show
+-----------------+
|       prediction|
+-----------------+
|0.563603440350882|
+-----------------+
```