ML Persistence — Saving and Loading Models and Pipelines
`MLWriter` and `MLReader` allow you to save and load models and pipelines regardless of the language — Scala, Java, Python or R — they were created in, so a component saved in one language can be loaded later in another.
MLWriter
`MLWriter` abstract class comes with the `save(path: String)` method to save an ML component to a given path.

```scala
save(path: String): Unit
```
It comes with another (chainable) method, `overwrite`, to overwrite the output path if it already exists.

```scala
overwrite(): this.type
```
The component's metadata is saved into a JSON file (see the MLWriter Example section below).
Tip: Enable logging for `MLWriter` to see what happens inside. Refer to Logging.
Caution: FIXME The logging doesn't work and overwriting does not print out an INFO message to the logs 🙁
MLWriter Example
```scala
import org.apache.spark.ml._

val pipeline = new Pipeline().setStages(Array.empty[PipelineStage])
pipeline.write.overwrite.save("sample-pipeline")
```
The result of `save` for an "unfitted" pipeline is a single JSON file with metadata (as shown below).
```
$ cat sample-pipeline/metadata/part-00000 | jq
{
  "class": "org.apache.spark.ml.Pipeline",
  "timestamp": 1472747720477,
  "sparkVersion": "2.1.0-SNAPSHOT",
  "uid": "pipeline_181c90b15d65",
  "paramMap": {
    "stageUids": []
  }
}
```
The result of `save` for a pipeline model is a JSON file for metadata plus Parquet files for model data, e.g. coefficients.
```scala
val model = pipeline.fit(training)
model.write.save("sample-model")
```
```
$ cat sample-model/metadata/part-00000 | jq
{
  "class": "org.apache.spark.ml.PipelineModel",
  "timestamp": 1472748168005,
  "sparkVersion": "2.1.0-SNAPSHOT",
  "uid": "pipeline_3ed598da1c4b",
  "paramMap": {
    "stageUids": [
      "regexTok_bf73e7c36e22",
      "hashingTF_ebece38da130",
      "logreg_819864aa7120"
    ]
  }
}

$ tree sample-model/stages/
sample-model/stages/
|-- 0_regexTok_bf73e7c36e22
|   `-- metadata
|       |-- _SUCCESS
|       `-- part-00000
|-- 1_hashingTF_ebece38da130
|   `-- metadata
|       |-- _SUCCESS
|       `-- part-00000
`-- 2_logreg_819864aa7120
    |-- data
    |   |-- _SUCCESS
    |   `-- part-r-00000-56423674-0208-4768-9d83-2e356ac6a8d2.snappy.parquet
    `-- metadata
        |-- _SUCCESS
        `-- part-00000

7 directories, 8 files
```
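The `pipeline` and `training` used in the snippet above are not constructed in this section. A minimal, self-contained sketch that could produce a comparable `sample-model` — the training data, column names, and stage parameters below are hypothetical, chosen only to match the RegexTokenizer/HashingTF/LogisticRegression stage uids seen in the output above — might look like:

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, RegexTokenizer}
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder.master("local[*]").appName("ml-persistence-demo").getOrCreate()
import spark.implicits._

// Hypothetical training data: (id, text, label)
val training = Seq(
  (0L, "spark is great", 1.0),
  (1L, "who needs mapreduce", 0.0)
).toDF("id", "text", "label")

// Three stages matching the kinds shown in sample-model/stages/
val tokenizer = new RegexTokenizer().setInputCol("text").setOutputCol("words")
val hashingTF = new HashingTF().setInputCol("words").setOutputCol("features")
val logreg = new LogisticRegression().setMaxIter(10)

val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, logreg))

// Fitting produces a PipelineModel whose stages carry model data (e.g. coefficients)
val model = pipeline.fit(training)
model.write.overwrite.save("sample-model")
```

Saving the fitted model writes metadata JSON for the `PipelineModel` and each stage, plus Parquet data for the logistic regression coefficients, yielding a layout like the `tree` listing above.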
MLReader
`MLReader` abstract class comes with the `load(path: String)` method to load an ML component from a given path.
```scala
import org.apache.spark.ml._

val pipeline = Pipeline.read.load("sample-pipeline")

scala> val stageCount = pipeline.getStages.size
stageCount: Int = 0

val pipelineModel = PipelineModel.read.load("sample-model")

scala> pipelineModel.stages
res1: Array[org.apache.spark.ml.Transformer] = Array(regexTok_bf73e7c36e22, hashingTF_ebece38da130, logreg_819864aa7120)
```
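A reloaded `PipelineModel` is immediately usable for scoring. A minimal sketch, assuming a `sample-model` saved as in the MLWriter example with a `text` input column (the input rows here are hypothetical):

```scala
import org.apache.spark.ml.PipelineModel
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder.master("local[*]").appName("ml-reader-demo").getOrCreate()
import spark.implicits._

// PipelineModel.load is shorthand for PipelineModel.read.load
val reloaded = PipelineModel.load("sample-model")

// Hypothetical new rows with the same schema the model was trained on
val test = Seq((2L, "spark streaming")).toDF("id", "text")

// The reloaded model transforms new data just like the original one
reloaded.transform(test).select("id", "prediction").show()
```

Because the saved metadata records each stage's uid and parameters, the reloaded pipeline reproduces the exact transformations of the original, so both produce identical predictions on the same input.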