ALSModel — Model for Predictions
Note: A Model in Spark MLlib is a Transformer that comes with a custom transform method.
When making predictions (i.e. when executed), ALSModel …FIXME
ALSModel is created when:
ALSModel is an MLWritable, so it can be saved (using write or save).
```scala
// The following spark-shell session is used to show
// how ALSModel works under the covers
// Mostly to learn how to work with the private ALSModel class

// Use paste raw mode to copy the code
// :paste -raw (or its shorter version :pa -raw)
// BEGIN :pa -raw
package org.apache.spark.ml

import org.apache.spark.sql._

class MyALS(spark: SparkSession) {
  import spark.implicits._
  // Factors are arrays of floats (as ALS produces them), one row per id.
  // rank matches the number of factors per id (2 here).
  val userFactors = Seq((0, Seq(0.3f, 0.2f))).toDF("id", "features")
  val itemFactors = Seq((0, Seq(0.3f, 0.2f))).toDF("id", "features")
  import org.apache.spark.ml.recommendation._
  // The ALSModel constructor is private[ml], hence the package declaration above
  val alsModel = new ALSModel(
    uid = "uid", rank = 2, userFactors = userFactors, itemFactors = itemFactors)
}
// END :pa -raw

// Copy the following to spark-shell directly
import org.apache.spark.ml._
val model = new MyALS(spark).
  alsModel.
  setUserCol("user").
  setItemCol("item")

import org.apache.spark.sql.types._
val mySchema = new StructType().
  add($"user".float).
  add($"item".float)

val transformedSchema = model.transformSchema(mySchema)

scala> transformedSchema.printTreeString
root
 |-- user: float (nullable = true)
 |-- item: float (nullable = true)
 |-- prediction: float (nullable = false)
```
Making Predictions: transform Method
```scala
transform(dataset: Dataset[_]): DataFrame
```
Note: transform is part of the Transformer Contract.
Internally, transform first validates the schema of the dataset (using transformSchema).
transform left-joins the dataset with the userFactors dataset (on the userCol column of the dataset and the id column of userFactors).
Note: A left join takes two datasets and gives all the rows from the left side (of the join), each combined with the corresponding row from the right side if available or with null otherwise.
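The left-join semantics can be made concrete with a minimal plain-Scala sketch on ordinary collections. This is not the Spark API. Option stands in for SQL null, and the Rating class and the userFactors map are hypothetical. The sketch assumes each id appears at most once on the right side (one factor vector per id), so every left-side row produces exactly one output row.

```scala
// Plain-Scala model of a left (outer) join; NOT the Spark API.
// Option plays the role of SQL null: None means "no matching right-side row".

// Hypothetical left-side rows (the dataset with userCol and itemCol)
final case class Rating(user: Int, item: Int)

// Hypothetical user factors keyed by id (at most one factor vector per id)
val userFactors: Map[Int, Seq[Float]] = Map(
  0 -> Seq(0.3f, 0.2f),
  1 -> Seq(0.1f, 0.5f)
)

val dataset: Seq[Rating] = Seq(Rating(0, 10), Rating(1, 10), Rating(2, 10))

// Left join on dataset.user == userFactors.id: every left-side row is kept,
// paired with Some(features) when a match exists and None otherwise
val joined: Seq[(Rating, Option[Seq[Float]])] =
  dataset.map(row => (row, userFactors.get(row.user)))

joined.foreach(println)
```

User 2 has no factors, so its row is kept with None instead of being dropped. A plain inner join would drop it.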
transform then left-joins the result with the itemFactors dataset (on the itemCol column of the dataset and the id column of itemFactors).
transform makes predictions (using the predict UDF) from the features columns of the userFactors and itemFactors datasets, for every row of the left-joined dataset.
transform selects all the columns of the dataset and adds the predictionCol column with the predictions.
Ultimately, transform drops the rows with null or NaN values in the predictionCol column if coldStartStrategy is drop.
Note: The default value of coldStartStrategy is nan, which does not drop missing values from the prediction column.
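The cold-start step can be sketched in plain Scala as follows. This is a hypothetical model of the behaviour, not Spark's API: the Prediction class and the applyColdStartStrategy function are made up for illustration. A Scala Float cannot be null, so only NaN is modeled as a missing prediction.

```scala
// Plain-Scala sketch of the coldStartStrategy step; NOT the Spark API.
// Float.NaN stands in for a missing prediction.
final case class Prediction(user: Int, item: Int, prediction: Float)

def applyColdStartStrategy(rows: Seq[Prediction], strategy: String): Seq[Prediction] =
  strategy match {
    case "drop" => rows.filterNot(_.prediction.isNaN) // remove rows without a prediction
    case "nan"  => rows                               // keep every row (the default)
    case other  => throw new IllegalArgumentException(s"Unsupported coldStartStrategy: $other")
  }

val predictions = Seq(
  Prediction(0, 10, 0.13f),
  Prediction(2, 10, Float.NaN) // hypothetical unknown user 2: no factors, so NaN
)

println(applyColdStartStrategy(predictions, "drop"))
println(applyColdStartStrategy(predictions, "nan"))
```

The drop strategy is useful when evaluating a model on held-out data: NaN predictions for unseen users or items would otherwise turn metrics such as RMSE into NaN.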
transformSchema Method
```scala
transformSchema(schema: StructType): StructType
```
Note: transformSchema is part of the Transformer Contract.
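Internally, transformSchema validates that the userCol and itemCol columns of the schema are of a numeric type and appends the predictionCol column of FloatType (the prediction: float field in the spark-shell session output).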
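A minimal plain-Scala sketch of that schema contract follows. Field and the lowercase type names are hypothetical stand-ins for Spark's StructField and DataType. The check that the prediction column does not already exist is an assumption about how the column is appended.

```scala
// Plain-Scala model of the transformSchema contract; NOT Spark's StructType API.
final case class Field(name: String, dataType: String, nullable: Boolean)

// Hypothetical names for numeric column types
val numericTypes = Set("byte", "short", "integer", "long", "float", "double", "decimal")

def transformSchema(
    schema: Seq[Field],
    userCol: String,
    itemCol: String,
    predictionCol: String): Seq[Field] = {
  // Validate that the user and item columns exist and are numeric
  for (col <- Seq(userCol, itemCol)) {
    val field = schema.find(_.name == col).getOrElse(
      throw new IllegalArgumentException(s"Column $col does not exist"))
    require(numericTypes.contains(field.dataType),
      s"Column $col must be numeric, but was ${field.dataType}")
  }
  // Append a non-nullable float prediction column (assumed not to exist yet)
  require(!schema.exists(_.name == predictionCol), s"Column $predictionCol already exists")
  schema :+ Field(predictionCol, "float", nullable = false)
}

val mySchema = Seq(Field("user", "float", nullable = true), Field("item", "float", nullable = true))
val out = transformSchema(mySchema, "user", "item", "prediction")
out.foreach(println)
```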
Creating ALSModel Instance
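ALSModel takes the following when created: a unique identifier (uid), the rank (the number of latent factors), the userFactors DataFrame and the itemFactors DataFrame.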
ALSModel initializes the internal registries and counters.
Requesting sdot from BLAS: predict Internal Property
```scala
predict: UserDefinedFunction
```
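predict is a user-defined function (UDF) that takes two collections of float numbers (the user and item factors) and requests BLAS for sdot, i.e. their single-precision dot product. When either collection is null (there was no matching row in a left join), predict gives NaN.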
Caution: FIXME Read about com.github.fommil.netlib.BLAS.getInstance.sdot.
Note: predict is a thin wrapper around com.github.fommil.netlib.BLAS.
Note: predict is used exclusively when ALSModel is requested to transform a dataset.
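What predict computes can be sketched in plain Scala without BLAS. This is an illustrative model, not the Spark or netlib-java API. Option stands in for the null factors produced by the left join, and the number of elements in the dot product is assumed to be the model's rank. Results may differ from BLAS sdot in the last bits because the summation order can differ.

```scala
// Plain-Scala model of the predict UDF; NOT the Spark or netlib-java API.
// Option stands in for the null factors that a left join yields for unknown ids.
def predict(rank: Int, userFeatures: Option[Seq[Float]], itemFeatures: Option[Seq[Float]]): Float =
  (userFeatures, itemFeatures) match {
    case (Some(u), Some(v)) =>
      require(u.length >= rank && v.length >= rank,
        s"factor vectors must have at least $rank elements")
      // Single-precision dot product over the first `rank` elements (what sdot computes)
      (0 until rank).foldLeft(0.0f)((acc, i) => acc + u(i) * v(i))
    case _ =>
      // A missing user or item factor vector gives NaN (the cold-start case)
      Float.NaN
  }

val user = Some(Seq(0.3f, 0.2f))
val item = Some(Seq(0.3f, 0.2f))
println(predict(2, user, item)) // roughly 0.13
println(predict(2, None, item)) // NaN
```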
Creating ALSModel with Extra Parameters: copy Method
```scala
copy(extra: ParamMap): ALSModel
```
Note: copy is part of the Model Contract.
copy creates a new ALSModel with the same uid, rank, userFactors and itemFactors.
copy then copies the parameter values, with the extra parameters applied on top, to the new ALSModel and sets its parent to the parent of the current ALSModel.
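The copy pattern can be sketched with a hypothetical SketchModel class. This is not Spark's Params API. Parameters are modeled as an immutable Map from name to value, and the parent is modeled as an optional estimator name. The new instance keeps the uid and rank, extra values override the copied ones, and the parent is carried over.

```scala
// Hypothetical stand-in for a fitted model; NOT Spark's Params/Model API.
final class SketchModel(val uid: String, val rank: Int, val params: Map[String, String]) {
  // Mirrors the idea of Model.parent: the estimator that produced this model (a name here)
  var parent: Option[String] = None

  def setParent(p: Option[String]): SketchModel = { parent = p; this }

  // New instance with the same uid and rank; extra values override the current ones
  // (Map ++ keeps the right-hand value for duplicate keys), and the parent is carried over
  def copy(extra: Map[String, String]): SketchModel =
    new SketchModel(uid, rank, params ++ extra).setParent(parent)
}

val original = new SketchModel("als_1234", 2, Map("userCol" -> "user", "coldStartStrategy" -> "nan"))
original.setParent(Some("ALS"))

val copied = original.copy(Map("coldStartStrategy" -> "drop"))
println(copied.params)
```

The original model is left unchanged. Only the copy sees the overridden coldStartStrategy.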