ALSModel — Model for Predictions
Note
A Model in Spark MLlib is a Transformer that comes with a custom transform method.
When making predictions (i.e. when executed), ALSModel…FIXME
ALSModel is created when:
ALSModel is an MLWritable.
```scala
// The following spark-shell session is used to show
// how ALSModel works under the covers
// Mostly to learn how to work with the private ALSModel class

// Use paste raw mode to copy the code
// :paste -raw (or its shorter version :pa -raw)
// BEGIN :pa -raw
package org.apache.spark.ml

import org.apache.spark.sql._

class MyALS(spark: SparkSession) {
  import spark.implicits._
  // Factors are float vectors whose length equals the rank
  val userFactors = Seq((0, Seq(0.3f, 0.2f))).toDF("id", "features")
  val itemFactors = Seq((0, Seq(0.3f, 0.2f))).toDF("id", "features")
  import org.apache.spark.ml.recommendation._
  // The ALSModel constructor is private[ml], hence the package declaration above
  val alsModel = new ALSModel(uid = "uid", rank = 2, userFactors, itemFactors)
}
// END :pa -raw

// Copy the following to spark-shell directly
import org.apache.spark.ml._
val model = new MyALS(spark).
  alsModel.
  setUserCol("user").
  setItemCol("item")

import org.apache.spark.sql.types._
val mySchema = new StructType().
  add($"user".float).
  add($"item".float)

val transformedSchema = model.transformSchema(mySchema)

scala> transformedSchema.printTreeString
root
 |-- user: float (nullable = true)
 |-- item: float (nullable = true)
 |-- prediction: float (nullable = false)
```
Making Predictions — transform Method
```scala
transform(dataset: Dataset[_]): DataFrame
```
Note
transform is part of Transformer Contract.
Internally, transform validates the schema of the dataset.
transform left-joins the dataset with userFactors dataset (using userCol column of dataset and id column of userFactors).
Note
Left join takes two datasets and gives all the rows from the left side (of the join), each combined with the corresponding row from the right side if available, or with nulls otherwise.
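The left-join semantics can be made concrete without a Spark cluster. The following plain-Scala sketch is only an illustration and does not reflect how Spark executes joins. The helper `leftJoin` is hypothetical. It assumes the right side has unique keys, as factor ids do, and uses `None` where a DataFrame would hold null.

```scala
object LeftJoinSketch {
  // Hypothetical helper: keeps every left row and attaches the matching
  // right value, or None when the right side has no entry for that key
  def leftJoin[K, L, R](left: Seq[(K, L)], right: Map[K, R]): Seq[(K, L, Option[R])] =
    left.map { case (key, l) => (key, l, right.get(key)) }

  def main(args: Array[String]): Unit = {
    val rows = Seq((0, "row-a"), (1, "row-b")) // user 1 has no factors
    val userFactors = Map(0 -> Seq(0.3f, 0.2f))
    // row-a gets Some(factors), row-b gets None (null in a DataFrame)
    leftJoin(rows, userFactors).foreach(println)
  }
}
```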
transform left-joins the dataset with itemFactors dataset (using itemCol column of dataset and id column of itemFactors).
transform makes predictions using the features columns of userFactors and itemFactors datasets (per every row in the left-joined dataset).
transform selects all the columns of the dataset and adds predictionCol with the predictions.
Ultimately, transform drops rows containing null or NaN values for predictions if coldStartStrategy is drop.
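The steps above can be sketched in plain Scala with in-memory collections. This illustration rests on several assumptions and is not Spark's implementation. `Query` and `transform` are hypothetical stand-ins, factor vectors are plain `Seq[Float]`, and `Map` lookups play the role of the two left joins. A missing factor vector becomes a NaN prediction.

```scala
object TransformSketch {
  // Hypothetical input row: a (user, item) pair to score
  final case class Query(user: Int, item: Int)

  // Dot product of the user and item factors, or NaN when either is missing
  def predict(u: Option[Seq[Float]], i: Option[Seq[Float]]): Float =
    (u, i) match {
      case (Some(a), Some(b)) => a.zip(b).map { case (x, y) => x * y }.sum
      case _                  => Float.NaN
    }

  def transform(
      dataset: Seq[Query],
      userFactors: Map[Int, Seq[Float]],
      itemFactors: Map[Int, Seq[Float]],
      coldStartStrategy: String): Seq[(Query, Float)] = {
    // "Left joins": every row is kept, and missing factors become None
    val predictions = dataset.map { q =>
      (q, predict(userFactors.get(q.user), itemFactors.get(q.item)))
    }
    // "drop" removes rows with a NaN prediction, while "nan" keeps them
    if (coldStartStrategy == "drop") predictions.filterNot { case (_, p) => p.isNaN }
    else predictions
  }

  def main(args: Array[String]): Unit = {
    val users = Map(0 -> Seq(0.3f, 0.2f))
    val items = Map(0 -> Seq(0.5f, 1.0f))
    val rows  = Seq(Query(0, 0), Query(1, 0)) // user 1 is unknown (cold start)
    println(transform(rows, users, items, "nan"))
    println(transform(rows, users, items, "drop"))
  }
}
```

Dropping the NaN rows is what keeps metrics such as RMSE computable when some users or items in the evaluation data were not seen during training.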
Note
The default value of coldStartStrategy is nan, which keeps the rows with missing (NaN) values in the predictions column.
transformSchema Method
```scala
transformSchema(schema: StructType): StructType
```
Note
transformSchema is part of Transformer Contract.
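Internally, transformSchema validates that the userCol and itemCol columns of the schema are of numeric type and appends predictionCol of FloatType (non-nullable).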
Creating ALSModel Instance
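ALSModel takes the following when created:
- Unique ID (uid)
- Rank (the number of latent factors)
- User factors (userFactors DataFrame)
- Item factors (itemFactors DataFrame)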
ALSModel initializes the internal registries and counters.
Requesting sdot from BLAS — predict Internal Property
```scala
predict: UserDefinedFunction
```
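predict is a user-defined function (UDF) that takes two collections of float numbers (the user and item features) and requests BLAS for sdot, i.e. their dot product over the first rank elements. If either collection is null (no factors were found for the user or the item), predict gives NaN.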
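A plain-Scala model of what predict computes may help. This sketch makes several assumptions. The `sdot` below is a hand-written loop that follows the reference BLAS definition for non-negative strides only (negative strides are not handled) and is not the netlib-java binding. An optimized BLAS may also differ in the last bits because of a different summation order. The `predict` function mirrors the UDF's null check as described above.

```scala
object SdotSketch {
  // Plain-Scala model of BLAS sdot for non-negative strides:
  // the sum over i in [0, n) of sx(i * incx) * sy(i * incy)
  def sdot(n: Int, sx: Array[Float], incx: Int, sy: Array[Float], incy: Int): Float = {
    var acc = 0.0f
    var i = 0
    while (i < n) {
      acc += sx(i * incx) * sy(i * incy)
      i += 1
    }
    acc
  }

  // Sketch of the predict UDF's body: rank is the number of latent factors,
  // and a null factor vector (no match in the left join) yields NaN
  def predict(rank: Int, featuresA: Seq[Float], featuresB: Seq[Float]): Float =
    if (featuresA != null && featuresB != null)
      sdot(rank, featuresA.toArray, 1, featuresB.toArray, 1)
    else Float.NaN
}
```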
Caution
FIXME Read about com.github.fommil.netlib.BLAS.getInstance.sdot.
Note
predict is a mere wrapper of com.github.fommil.netlib.BLAS.
Note
predict is used exclusively when ALSModel is requested to transform.
Creating ALSModel with Extra Parameters — copy Method
```scala
copy(extra: ParamMap): ALSModel
```
Note
copy is part of Model Contract.
copy creates a new ALSModel.
copy then copies extra parameters to the new ALSModel and sets the parent.
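The copy semantics can be sketched with an immutable case class. `ModelSketch` and `copyWithExtra` are hypothetical names. Parameters are modeled as a plain `Map` rather than Spark's ParamMap, and the factor DataFrames are omitted because the new model simply reuses them. The sketch shows two things: the extra parameters override the model's own values, while uid, rank and parent carry over to the copy.

```scala
object CopySketch {
  // Hypothetical stand-in for a model: params are name -> value pairs and
  // parent identifies the estimator that produced the model
  final case class ModelSketch(
      uid: String,
      rank: Int,
      params: Map[String, Any],
      parent: Option[String])

  // New model with the same uid, rank and parent; extra params win on conflicts
  def copyWithExtra(model: ModelSketch, extra: Map[String, Any]): ModelSketch =
    model.copy(params = model.params ++ extra)
}
```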