LinearRegression
LinearRegression is a Regressor that represents the linear regression algorithm in Machine Learning.
LinearRegression belongs to the org.apache.spark.ml.regression package.
Tip: Read the scaladoc of LinearRegression.
It expects org.apache.spark.mllib.linalg.Vector as the type of the input features column in a dataset and produces a LinearRegressionModel.
```scala
import org.apache.spark.ml.regression.LinearRegression

val lr = new LinearRegression
```
The acceptable parameters:
```scala
scala> println(lr.explainParams)
elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty (default: 0.0)
featuresCol: features column name (default: features)
fitIntercept: whether to fit an intercept term (default: true)
labelCol: label column name (default: label)
maxIter: maximum number of iterations (>= 0) (default: 100)
predictionCol: prediction column name (default: prediction)
regParam: regularization parameter (>= 0) (default: 0.0)
solver: the solver algorithm for optimization. If this is not set or empty, default value is 'auto' (default: auto)
standardization: whether to standardize the training features before fitting the model (default: true)
tol: the convergence tolerance for iterative algorithms (default: 1.0E-6)
weightCol: weight column name. If this is not set or empty, we treat all instance weights as 1.0 (default: )
```
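Each of the parameters above maps to a setter on LinearRegression. A minimal sketch of configuring a few of them before fitting (the values here are illustrative, not recommendations):

```scala
import org.apache.spark.ml.regression.LinearRegression

val lr = new LinearRegression()
  .setMaxIter(50)           // cap the number of optimization iterations
  .setRegParam(0.1)         // regularization strength (>= 0)
  .setElasticNetParam(0.5)  // 0.0 = pure L2 (ridge), 1.0 = pure L1 (lasso)
  .setFitIntercept(true)    // fit an intercept term
  .setSolver("auto")        // let Spark pick the optimization algorithm
```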
LinearRegression Example
```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

val data = (0.0 to 9.0 by 1)                       // create a collection of Doubles
  .map(n => (n, n))                                // make it pairs
  .map { case (label, features) =>
    LabeledPoint(label, Vectors.dense(features)) } // create labeled points of dense vectors
  .toDF                                            // make it a DataFrame

scala> data.show
+-----+--------+
|label|features|
+-----+--------+
|  0.0|   [0.0]|
|  1.0|   [1.0]|
|  2.0|   [2.0]|
|  3.0|   [3.0]|
|  4.0|   [4.0]|
|  5.0|   [5.0]|
|  6.0|   [6.0]|
|  7.0|   [7.0]|
|  8.0|   [8.0]|
|  9.0|   [9.0]|
+-----+--------+

import org.apache.spark.ml.regression.LinearRegression
val lr = new LinearRegression

val model = lr.fit(data)

scala> model.intercept
res1: Double = 0.0

scala> model.coefficients
res2: org.apache.spark.mllib.linalg.Vector = [1.0]

// make predictions
scala> val predictions = model.transform(data)
predictions: org.apache.spark.sql.DataFrame = [label: double, features: vector ... 1 more field]

scala> predictions.show
+-----+--------+----------+
|label|features|prediction|
+-----+--------+----------+
|  0.0|   [0.0]|       0.0|
|  1.0|   [1.0]|       1.0|
|  2.0|   [2.0]|       2.0|
|  3.0|   [3.0]|       3.0|
|  4.0|   [4.0]|       4.0|
|  5.0|   [5.0]|       5.0|
|  6.0|   [6.0]|       6.0|
|  7.0|   [7.0]|       7.0|
|  8.0|   [8.0]|       8.0|
|  9.0|   [9.0]|       9.0|
+-----+--------+----------+

import org.apache.spark.ml.evaluation.RegressionEvaluator

// rmse is the default metric
// We're explicit here for learning purposes
val regEval = new RegressionEvaluator().setMetricName("rmse")
val rmse = regEval.evaluate(predictions)

scala> println(s"Root Mean Squared Error: $rmse")
Root Mean Squared Error: 0.0

import org.apache.spark.mllib.linalg.DenseVector

// NOTE Follow along to learn spark.ml-way (not RDD-way)
predictions.rdd.map { r =>
  (r(0).asInstanceOf[Double],
   r(1).asInstanceOf[DenseVector](0).toDouble,
   r(2).asInstanceOf[Double]) }
  .toDF("label", "feature0", "prediction").show

+-----+--------+----------+
|label|feature0|prediction|
+-----+--------+----------+
|  0.0|     0.0|       0.0|
|  1.0|     1.0|       1.0|
|  2.0|     2.0|       2.0|
|  3.0|     3.0|       3.0|
|  4.0|     4.0|       4.0|
|  5.0|     5.0|       5.0|
|  6.0|     6.0|       6.0|
|  7.0|     7.0|       7.0|
|  8.0|     8.0|       8.0|
|  9.0|     9.0|       9.0|
+-----+--------+----------+

// Let's make it nicer to the eyes using a Scala case class
scala> :pa
// Entering paste mode (ctrl-D to finish)

import org.apache.spark.sql.Row
import org.apache.spark.mllib.linalg.DenseVector
case class Prediction(label: Double, feature0: Double, prediction: Double)
object Prediction {
  def apply(r: Row) = new Prediction(
    label      = r(0).asInstanceOf[Double],
    feature0   = r(1).asInstanceOf[DenseVector](0).toDouble,
    prediction = r(2).asInstanceOf[Double])
}

// Exiting paste mode, now interpreting.

import org.apache.spark.sql.Row
import org.apache.spark.mllib.linalg.DenseVector
defined class Prediction
defined object Prediction

scala> predictions.rdd.map(Prediction.apply).toDF.show
+-----+--------+----------+
|label|feature0|prediction|
+-----+--------+----------+
|  0.0|     0.0|       0.0|
|  1.0|     1.0|       1.0|
|  2.0|     2.0|       2.0|
|  3.0|     3.0|       3.0|
|  4.0|     4.0|       4.0|
|  5.0|     5.0|       5.0|
|  6.0|     6.0|       6.0|
|  7.0|     7.0|       7.0|
|  8.0|     8.0|       8.0|
|  9.0|     9.0|       9.0|
+-----+--------+----------+
```
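RegressionEvaluator is not limited to rmse; it also supports mse, r2, and mae as metric names. A short sketch reusing the predictions DataFrame from the session above:

```scala
import org.apache.spark.ml.evaluation.RegressionEvaluator

// On this toy dataset the fit is exact, so expect r2 = 1.0 and mae = 0.0
val r2  = new RegressionEvaluator().setMetricName("r2").evaluate(predictions)
val mae = new RegressionEvaluator().setMetricName("mae").evaluate(predictions)

println(s"R2: $r2, MAE: $mae")
```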
train Method
```scala
train(dataset: DataFrame): LinearRegressionModel
```
train, a protected method of LinearRegression, expects a dataset DataFrame with two columns:

- label of type DoubleType
- features of type Vector

It returns a LinearRegressionModel (see the sketch below).
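Since train is protected, you exercise it through fit (per the Predictor contract). A minimal sketch of a DataFrame that satisfies the two-column schema above (the values are made up for illustration; assumes spark-shell implicits for toDF):

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.ml.regression.LinearRegression

// label: Double, features: mllib.linalg.Vector
val ds = Seq(
  (1.0, Vectors.dense(0.5, 1.5)),
  (0.0, Vectors.dense(2.0, 0.0))).toDF("label", "features")

val model = new LinearRegression().fit(ds) // fit delegates to the protected train
```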
train first counts the number of elements in the features column (usually called features). The column has to be of mllib.linalg.Vector type (and can easily be prepared using the HashingTF transformer), as the following example demonstrates.
```scala
val spam = Seq(
  (0, "Hi Jacek. Wanna more SPAM? Best!"),
  (1, "This is SPAM. This is SPAM")).toDF("id", "email")

import org.apache.spark.ml.feature.RegexTokenizer
val regexTok = new RegexTokenizer()
val spamTokens = regexTok.setInputCol("email").transform(spam)

scala> spamTokens.show(false)
+---+--------------------------------+---------------------------------------+
|id |email                           |regexTok_646b6bcc4548__output          |
+---+--------------------------------+---------------------------------------+
|0  |Hi Jacek. Wanna more SPAM? Best!|[hi, jacek., wanna, more, spam?, best!]|
|1  |This is SPAM. This is SPAM      |[this, is, spam., this, is, spam]      |
+---+--------------------------------+---------------------------------------+

import org.apache.spark.ml.feature.HashingTF
val hashTF = new HashingTF()
  .setInputCol(regexTok.getOutputCol)
  .setOutputCol("features")
  .setNumFeatures(5000)

val spamHashed = hashTF.transform(spamTokens)

scala> spamHashed.select("email", "features").show(false)
+--------------------------------+----------------------------------------------------------------+
|email                           |features                                                        |
+--------------------------------+----------------------------------------------------------------+
|Hi Jacek. Wanna more SPAM? Best!|(5000,[2525,2943,3093,3166,3329,3980],[1.0,1.0,1.0,1.0,1.0,1.0])|
|This is SPAM. This is SPAM      |(5000,[1713,3149,3370,4070],[1.0,1.0,2.0,2.0])                  |
+--------------------------------+----------------------------------------------------------------+

// Create labeled datasets for spam (1)
import org.apache.spark.sql.functions.lit
val spamLabeled = spamHashed.withColumn("label", lit(1d))

scala> spamLabeled.show
+---+--------------------+-----------------------------+--------------------+-----+
| id|               email|regexTok_646b6bcc4548__output|            features|label|
+---+--------------------+-----------------------------+--------------------+-----+
|  0|Hi Jacek. Wanna m...|         [hi, jacek., wann...|(5000,[2525,2943,...|  1.0|
|  1|This is SPAM. Thi...|         [this, is, spam.,...|(5000,[1713,3149,...|  1.0|
+---+--------------------+-----------------------------+--------------------+-----+

val regular = Seq(
  (2, "Hi Jacek. I hope this email finds you well. Spark up!"),
  (3, "Welcome to Apache Spark project")).toDF("id", "email")
val regularTokens = regexTok.setInputCol("email").transform(regular)
val regularHashed = hashTF.transform(regularTokens)
// Create labeled datasets for non-spam regular emails (0)
val regularLabeled = regularHashed.withColumn("label", lit(0d))

val training = regularLabeled.union(spamLabeled).cache

scala> training.show
+---+--------------------+-----------------------------+--------------------+-----+
| id|               email|regexTok_646b6bcc4548__output|            features|label|
+---+--------------------+-----------------------------+--------------------+-----+
|  2|Hi Jacek. I hope ...|         [hi, jacek., i, h...|(5000,[72,105,942...|  0.0|
|  3|Welcome to Apache...|         [welcome, to, apa...|(5000,[2894,3365,...|  0.0|
|  0|Hi Jacek. Wanna m...|         [hi, jacek., wann...|(5000,[2525,2943,...|  1.0|
|  1|This is SPAM. Thi...|         [this, is, spam.,...|(5000,[1713,3149,...|  1.0|
+---+--------------------+-----------------------------+--------------------+-----+

import org.apache.spark.ml.regression.LinearRegression
val lr = new LinearRegression

// the following calls train by the Predictor contract (see above)
val lrModel = lr.fit(training)

// Let's predict whether an email is a spam or not
val email = Seq("Hi Jacek. you doing well? Bye!").toDF("email")
val emailTokens = regexTok.setInputCol("email").transform(email)
val emailHashed = hashTF.transform(emailTokens)

scala> lrModel.transform(emailHashed).select("prediction").show
+-----------------+
|       prediction|
+-----------------+
|0.563603440350882|
+-----------------+
```