【版权声明】博客内容由厦门大学数据库实验室拥有版权,未经允许,请勿转载!
[返回Spark教程首页]
## 模型选择和超参数调整
在机器学习中非常重要的任务就是模型选择,或者使用数据来找到具体问题的最佳的模型和参数,这个过程也叫做调试(Tuning)。调试可以在独立的估计器中完成(如逻辑斯蒂回归),也可以在包含多样算法、特征工程和其他步骤的工作流中完成。用户应该一次性调优整个工作流,而不是独立的调整PipeLine中的每个组成部分。
1、 交叉验证和训练-验证切分
MLlib支持交叉验证(CrossValidator)和训练验证分割(TrainValidationSplit)两个模型选择工具。使用这些工具要求包含如下对象:
1.估计器:待调试的算法或管线。
2.一系列参数表(ParamMaps):可选参数,也叫做“参数网格”搜索空间。
3.评估器:评估模型拟合程度的准则或方法。
模型选择工具工作原理如下:
1.将输入数据划分为训练数据和测试数据。
2. 对于每个(训练,测试)对,遍历一组ParamMaps。用每一个ParamMap参数来拟合估计器,得到训练后的模型,再使用评估器来评估模型表现。
3.选择性能表现最优模型对应参数表。
更具体的,交叉验证CrossValidato将数据集切分成k折叠数据集合,并被分别用于训练和测试。例如,k=3时,CrossValidator会生成3个(训练数据,测试数据)对,每一个数据对的训练数据占2/3,测试数据占1/3。为了评估一个ParamMap,CrossValidator 会计算这3个不同的(训练,测试)数据集对在Estimator拟合出的模型上的平均评估指标。在找出最好的ParamMap后,CrossValidator 会使用这个ParamMap和整个的数据集来重新拟合Estimator。也就是说通过交叉验证找到最佳的ParamMap,利用此ParamMap在整个训练集上可以训练(fit)出一个泛化能力强,误差相对小的的最佳模型。
交叉验证的代价比较高昂,为此Spark也为超参数调优提供了训练-验证切分TrainValidationSplit。TrainValidationSplit创建单一的(训练,测试)数据集对。它使用trainRatio参数将数据集切分成两部分。例如,当设置trainRatio=0.75时,TrainValidationSplit将会将数据切分75%作为数据集,25%作为验证集,来生成训练、测试集对,并最终使用最好的ParamMap和完整的数据集来拟合评估器。相对于CrossValidator对每一个参数进行k次评估,TrainValidationSplit只对每个参数组合评估1次。因此它的评估代价没有这么高,但是当训练数据集不够大的时候其结果相对不够可信。
2、使用交叉验证进行模型选择
使用CrossValidator的代价可能会异常的高。然而,对比启发式的手动调优,这是选择参数的行之有效的方法。下面示例示范了使用CrossValidator从整个网格的参数中选择合适的参数。
首先,导入必要的包:
import org.apache.spark.ml.linalg.{Vector,Vectors}
import spark.implicits._
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.Row
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
import org.apache.spark.ml.classification.{LogisticRegression,LogisticRegressionModel}
import org.apache.spark.ml.{Pipeline,PipelineModel}
接下来,读取Irisi数据集,分别获取标签列和特征列,进行索引、重命名,并设置机器学习工作流。交叉验证在把原始数据集分割为训练集与测试集。值得注意的是,只有训练集才可以用在模型的训练过程中,测试集则必须在模型完成之后才被用来评估模型优劣的依据。此外,训练集中样本数量必须够多,一般至少大于总样本数的 50%,且两组子集必须从完整集合中均匀取样。
scala> case class Iris(features: org.apache.spark.ml.linalg.Vector, label: String)
scala> val data = spark.sparkContext.textFile("file:///usr/local/spark/iris.txt").map(\_.split(",")).map(p => Iris(Vectors.dense(p(0).toDouble,p(1).toDouble,p(2).toDouble, p(3).toDouble), p(4).toString())).toDF()
scala>val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))
scala>val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(data)
scala> val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").fit(data)
scala> val lr = new LogisticRegression().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures").setMaxIter(50)
scala>val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)
scala> val lrPipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, lr, labelConverter))
可以使用ParamGridBuilder方便构造参数网格。其中regParam参数定义规范化项的权重;elasticNetParam是Elastic net 参数,取值介于0和1之间。elasticNetParam设置2个值,regParam设置3个值。最终将有(3 * 2) = 6个不同的模型将被训练。
scala> val paramGrid = new ParamGridBuilder().
addGrid(lr.elasticNetParam, Array(0.2,0.8)).
addGrid(lr.regParam, Array(0.01, 0.1, 0.5)).
build()
paramGrid: Array[org.apache.spark.ml.param.ParamMap] =
Array({
logreg_cd4ae130834c-elasticNetParam: 0.2,
logreg_cd4ae130834c-regParam: 0.01
}, {
logreg_cd4ae130834c-elasticNetParam: 0.2,
logreg_cd4ae130834c-regParam: 0.1
}, {
logreg_cd4ae130834c-elasticNetParam: 0.2,
logreg_cd4ae130834c-regParam: 0.5
}, {
logreg_cd4ae130834c-elasticNetParam: 0.8,
logreg_cd4ae130834c-regParam: 0.01
}, {
logreg_cd4ae130834c-elasticNetParam: 0.8,
logreg_cd4ae130834c-regParam: 0.1
}, {
logreg_cd4ae130834c-elasticNetParam: 0.8,
logreg_cd4ae130834c-regParam: 0.5
})
再接下来,构建针对整个机器学习工作流的交叉验证类,定义验证模型、参数网格,以及数据集的折叠数,并调用fit方法进行模型训练。其中,对于回归问题评估器可选择RegressionEvaluator,二值数据可选择BinaryClassificationEvaluator,多分类问题可选择MulticlassClassificationEvaluator。评估器里默认的评估准则可通过setMetricName方法重写。
scala> val cv = new CrossValidator().
setEstimator(lrPipeline).
setEvaluator(new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction")).
setEstimatorParamMaps(paramGrid).
setNumFolds(3) // Use 3+ in practice
cv: org.apache.spark.ml.tuning.CrossValidator = cv_6a92251c1281
scala>val cvModel = cv.fit(trainingData)
cvModel: org.apache.spark.ml.tuning.CrossValidatorModel = cv_6a92251c1281
接下来,调动transform方法对测试数据进行预测,并打印结果及精度。
scala> val lrPredictions=cvModel.transform(testData)
lrPredictions: org.apache.spark.sql.DataFrame = [features: vector, label: string ... 6 more fields]
scala> lrPredictions.select("predictedLabel", "label", "features", "probability").
collect().
foreach{
case Row(predictedLabel: String, label:String,features:Vector, prob:Vector) =>
println(s"($label, $features) --> prob=$prob, predicted Label=$predictedLabel")
}
(Iris-setosa, [4.4,2.9,1.4,0.2]) --> prob=[0.036318343660463034,1.5386309810869498E-4,0.9635277932414283], predicted Label=Iris-setosa
(Iris-setosa, [4.4,3.0,1.3,0.2]) --> prob=[0.02309982419551813,8.133697579147449E-5,0.9768188388286903], predicted Label=Iris-setosa
(Iris-setosa, [4.5,2.3,1.3,0.3]) --> prob=[0.295171926064505,0.002812152256808639,0.7020159216786864], predicted Label=Iris-setosa
(Iris-setosa, [4.7,3.2,1.3,0.2]) --> prob=[0.01621563015044659,3.621933015640408E-5,0.983748150519397], predicted Label=Iris-setosa
(Iris-versicolor, [4.9,2.4,3.3,1.0]) --> prob=[0.7706219468339301,0.175414278336521,0.05396377482954888], predicted Label=Iris-versicolor
(Iris-setosa, [4.9,3.0,1.4,0.2]) --> prob=[0.04620242417577521,1.1860104379361763E-4,0.9536789747804313], predicted Label=Iris-setosa
(Iris-setosa, [4.9,3.1,1.5,0.1]) --> prob=[0.029622624413386225,5.4310661363472946E-5,0.9703230649252502], predicted Label=Iris-setosa
(Iris-setosa, [5.0,3.0,1.6,0.2]) --> prob=[0.06282360073107408,1.7313631637921477E-4,0.9370032629525467], predicted Label=Iris-setosa
(Iris-setosa, [5.1,3.4,1.5,0.2]) --> prob=[0.015651738921858543,2.380882798692269E-5,0.9843244522501545], predicted Label=Iris-setosa
(Iris-setosa, [5.1,3.7,1.5,0.4]) --> prob=[0.007822147912199168,1.559101076849462E-5,0.9921622610770324], predicted Label=Iris-setosa
(Iris-setosa, [5.1,3.8,1.9,0.4]) --> prob=[0.008122424007330846,1.9535702563476317E-5,0.9918580402901056], predicted Label=Iris-setosa
(Iris-versicolor, [5.2,2.7,3.9,1.4]) --> prob=[0.5850022165210945,0.39398916377407905,0.021008619704826315], predicted Label=Iris-versicolor
(Iris-versicolor, [5.6,2.9,3.6,1.3]) --> prob=[0.7639756559493451,0.17866855122913236,0.057355792821522394], predicted Label=Iris-versicolor
(Iris-versicolor, [5.6,3.0,4.5,1.5]) --> prob=[0.5627254194721568,0.4207169921382895,0.01655758838955361], predicted Label=Iris-versicolor
(Iris-versicolor, [5.7,2.8,4.5,1.3]) --> prob=[0.670629860707028,0.31690657095234,0.012463568340632149], predicted Label=Iris-versicolor
(Iris-setosa, [5.7,3.8,1.7,0.3]) --> prob=[0.011381900655107215,1.0924892262729634E-5,0.9886071744526301], predicted Label=Iris-setosa
(Iris-setosa, [5.7,4.4,1.5,0.4]) --> prob=[0.0012784168878862666,7.362791069618495E-7,0.9987208468330069], predicted Label=Iris-setosa
(Iris-versicolor, [5.9,3.2,4.8,1.8]) --> prob=[0.40332621706836463,0.5897121822337784,0.006961600697856969], predicted Label=Iris-virginica
(Iris-versicolor, [6.1,2.8,4.7,1.2]) --> prob=[0.7642917289988874,0.22707120620275661,0.008637064798356008], predicted Label=Iris-versicolor
(Iris-versicolor, [6.3,3.3,4.7,1.6]) --> prob=[0.6613043715486516,0.3222092638684948,0.016486364582853606], predicted Label=Iris-versicolor
(Iris-versicolor, [6.5,2.8,4.6,1.5]) --> prob=[0.6659921766014542,0.33119455776446366,0.002813265634082092], predicted Label=Iris-versicolor
(Iris-versicolor, [6.7,3.1,4.4,1.4]) --> prob=[0.826784337717256,0.16104413333871842,0.012171528944025676], predicted Label=Iris-versicolor
(Iris-versicolor, [6.9,3.1,4.9,1.5]) --> prob=[0.7504910784980305,0.2452274221374472,0.004281499364522243], predicted Label=Iris-versicolor
(Iris-versicolor, [7.0,3.2,4.7,1.4]) --> prob=[0.8450186951165065,0.14582937179001068,0.009151933093482754], predicted Label=Iris-versicolor
(Iris-virginica, [4.9,2.5,4.5,1.7]) --> prob=[0.19602952352157307,0.8024877940772516,0.0014826824011754125], predicted Label=Iris-virginica
(Iris-versicolor, [5.0,2.3,3.3,1.0]) --> prob=[0.7824409819608887,0.18394641034599954,0.03361260769311175], predicted Label=Iris-versicolor
(Iris-versicolor, [5.4,3.0,4.5,1.5]) --> prob=[0.5230870971614001,0.457218404527316,0.019694498311283788], predicted Label=Iris-versicolor
(Iris-versicolor, [5.5,2.5,4.0,1.3]) --> prob=[0.6480542975926028,0.34341910441153806,0.008526597995859177], predicted Label=Iris-versicolor
(Iris-versicolor, [5.7,2.9,4.2,1.3]) --> prob=[0.7274955070238317,0.24613869316174955,0.026365799814418774], predicted Label=Iris-versicolor
(Iris-versicolor, [5.8,2.7,3.9,1.2]) --> prob=[0.7959248438731429,0.1838335293489395,0.020241626777917736], predicted Label=Iris-versicolor
(Iris-virginica, [5.8,2.7,5.1,1.9]) --> prob=[0.17728792103622937,0.8223741362776862,3.379426860843727E-4], predicted Label=Iris-virginica
(Iris-virginica, [5.8,2.7,5.1,1.9]) --> prob=[0.17728792103622937,0.8223741362776862,3.379426860843727E-4], predicted Label=Iris-virginica
(Iris-virginica, [6.0,3.0,4.8,1.8]) --> prob=[0.3712388290768754,0.626033320380444,0.0027278505426806066], predicted Label=Iris-virginica
(Iris-virginica, [6.1,2.6,5.6,1.4]) --> prob=[0.4302596984529376,0.5691034839960433,6.368175510190741E-4], predicted Label=Iris-virginica
(Iris-versicolor, [6.2,2.9,4.3,1.3]) --> prob=[0.790921543793271,0.19508657422821454,0.013991881978514588], predicted Label=Iris-versicolor
(Iris-virginica, [6.3,2.9,5.6,1.8]) --> prob=[0.269783998879049,0.7297918049580144,4.2419616293655883E-4], predicted Label=Iris-virginica
(Iris-virginica, [6.3,3.3,6.0,2.5]) --> prob=[0.04940870314243357,0.9505350392733197,5.625758424661494E-5], predicted Label=Iris-virginica
(Iris-virginica, [6.4,3.2,5.3,2.3]) --> prob=[0.13399750253695122,0.8657212087015382,2.8128876151050963E-4], predicted Label=Iris-virginica
(Iris-virginica, [6.5,3.2,5.1,2.0]) --> prob=[0.3233819828627952,0.6752911209065505,0.001326896230654286], predicted Label=Iris-virginica
(Iris-virginica, [6.8,3.2,5.9,2.3]) --> prob=[0.11912311163560521,0.8807934897136858,8.339865070901092E-5], predicted Label=Iris-virginica
(Iris-virginica, [7.7,2.6,6.9,2.3]) --> prob=[0.06264205346355847,0.9373573573652129,5.89171228575748E-7], predicted Label=Iris-virginica
(Iris-virginica, [7.7,2.8,6.7,2.0]) --> prob=[0.19254820234929734,0.8074434825798231,8.315070879605122E-6], predicted Label=Iris-virginica
(Iris-virginica, [7.7,3.0,6.1,2.3]) --> prob=[0.15875628538237757,0.8412292887149714,1.4425902651179938E-5], predicted Label=Iris-virginica
(Iris-virginica, [7.7,3.8,6.7,2.2]) --> prob=[0.2826285310484042,0.7170538184653444,3.176504862515162E-4], predicted Label=Iris-virginica
scala> val evaluator = new MulticlassClassificationEvaluator().
setLabelCol("indexedLabel").
setPredictionCol("prediction")
evaluator: org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator = mcEval_d56fa096e993
scala> val lrAccuracy = evaluator.evaluate(lrPredictions)
lrAccuracy: Double = 0.9773399014778326
最后,还可以获取最优的逻辑斯蒂回归模型,并查看其具体的参数:
scala> val bestModel= cvModel.bestModel.asInstanceOf[PipelineModel]
bestModel: org.apache.spark.ml.PipelineModel = pipeline_656a51f08dc4
scala> val lrModel = bestModel.stages(2).
asInstanceOf[LogisticRegressionModel]
lrModel: org.apache.spark.ml.classification.LogisticRegressionModel = logreg_b3490372f4dd
scala> println("Coefficients: " + lrModel.coefficientMatrix + "Intercept: "+lrModel.interceptVector+ "numClasses: "+lrModel.numClasses+"numFeatures: "+lrModel.numFeatures)
Coefficients: 0.8354793833288098 -0.8387183954210289 0.1081832035508958 -0.0
0.054256275295321135 -1.9426755118201342 0.8538718077791957 3.007240263438064
-0.3972137619684629 2.8149119052091205 -0.9001646661509906 -1.9726757956849224 Intercept: [0.18591575015360287,-0.2846540294397142,0.09873827928611131]numClasses: 3numFeatures: 4
scala> lrModel.explainParam(lrModel.regParam)
res8: String = regParam: regularization parameter (>= 0) (default: 0.0, current: 0.01)
scala> lrModel.explainParam(lrModel.elasticNetParam)
res9: String = elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty (default: 0.0, current: 0.2)
可以看出,对于参数网格,其最优参数取值是regParam=0.01,elasticNetParam=0.2。