Spark 2.1.0入门:决策树分类器

大数据学习路线图

【版权声明】博客内容由厦门大学数据库实验室拥有版权,未经允许,请勿转载!
[返回Spark教程首页]

一、方法简介

​ 决策树(decision tree)是一种基本的分类与回归方法,这里主要介绍用于分类的决策树。决策树模式呈树形结构,其中每个内部节点表示一个属性上的测试,每个分支代表一个测试输出,每个叶节点代表一种类别。学习时利用训练数据,根据损失函数最小化的原则建立决策树模型;预测时,对新的数据,利用决策树模型进行分类。

二、基本原理

​ 决策树学习通常包括3个步骤:特征选择、决策树的生成和决策树的剪枝。

(一)特征选择

​ 特征选择在于选取对训练数据具有分类能力的特征,这样可以提高决策树学习的效率。通常特征选择的准则是信息增益(或信息增益比、基尼指数等),每次计算每个特征的信息增益,并比较它们的大小,选择信息增益最大(信息增益比最大、基尼指数最小)的特征。下面我们重点介绍一下特征选择的准则:信息增益。

​ 首先定义信息论中广泛使用的一个度量标准——熵(entropy),它是表示随机变量不确定性的度量。熵越大,随机变量的不确定性就越大。而信息增益(informational entropy)表示得知某一特征后使得信息的不确定性减少的程度。简单的说,一个属性的信息增益就是由于使用这个属性分割样例而导致的期望熵降低。信息增益、信息增益比和基尼指数的具体定义如下:

信息增益:特征A对训练数据集D的信息增益

    \[g(D,A)\]

,定义为集合D的经验熵

    \[H(D)\]

与特征A给定条件下D的经验条件熵

    \[H(D|A)\]

之差,即

    \[g(D,A)=H(D)-H(D|A)\]

信息增益比:特征A对训练数据集D的信息增益比

    \[g_R(D,A)\]

定义为其信息增益

    \[g(D,A)\]

与训练数据集D关于特征A的值的熵

    \[H_A(D)\]

之比,即

    \[g_R(D,A)=\frac{g(D,A)}{H_A(D)}\]

其中,

    \[H_A(D)=-\sum_{i=1}^{n}\frac{\left|D_i\right|}{\left|D\right|}log_2\frac{\left|D_i\right|}{\left|D\right|}\]

,n是特征A取值的个数。

基尼指数:分类问题中,假设有K个类,样本点属于第K类的概率为

    \[p_k\]

,则概率分布的基尼指数定义为

    \[Gini(p)=\sum_{k=1}^{K}p_k(1-p_k)=1-\sum_{k=1}^{K}p_k^2\]

(二)决策树的生成

​ 从根结点开始,对结点计算所有可能的特征的信息增益,选择信息增益最大的特征作为结点的特征,由该特征的不同取值建立子结点,再对子结点递归地调用以上方法,构建决策树;直到所有特征的信息增均很小或没有特征可以选择为止,最后得到一个决策树。

​ 决策树需要有停止条件来终止其生长的过程。一般来说最低的条件是:当该节点下面的所有记录都属于同一类,或者当所有的记录属性都具有相同的值时。这两种条件是停止决策树的必要条件,也是最低的条件。在实际运用中一般希望决策树提前停止生长,限定叶节点包含的最低数据量,以防止由于过度生长造成的过拟合问题。

(三)决策树的剪枝

​ 决策树生成算法递归地产生决策树,直到不能继续下去为止。这样产生的树往往对训练数据的分类很准确,但对未知的测试数据的分类却没有那么准确,即出现过拟合现象。解决这个问题的办法是考虑决策树的复杂度,对已生成的决策树进行简化,这个过程称为剪枝。

​ 决策树的剪枝往往通过极小化决策树整体的损失函数来实现。一般来说,损失函数可以进行如下的定义:

    \[C_a(T)=C(T)+a\left|T\right|\]

​ 其中,T为任意子树,

    \[C(T)\]

为对训练数据的预测误差(如基尼指数),

    \[\left|T\right|\]

为子树的叶结点个数,

    \[a\ge0\]

为参数,

    \[C_a(T)\]

为参数是

    \[a\]

时的子树T的整体损失,参数

    \[a\]

权衡训练数据的拟合程度与模型的复杂度。对于固定的

    \[a\]

,一定存在使损失函数

    \[C_a(T)\]

最小的子树,将其表示为

    \[T_a\]

。当

    \[a\]

大的时候,最优子树

    \[T_a\]

偏小;当

    \[a\]

小的时候,最优子树

    \[T_a\]

偏大。

示例代码

​ 我们以iris数据集(iris)为例进行分析。iris以鸢尾花的特征作为数据来源,数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性,是在数据挖掘、数据分类中非常常用的测试集、训练集。决策树可以用于分类和回归,接下来我们将在代码中分别进行介绍。

1. 导入需要的包:
  1. import org.apache.spark.sql.SparkSession
  2. import org.apache.spark.ml.linalg.{Vector,Vectors}
  3. import org.apache.spark.ml.Pipeline
  4. import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
scala
2. 读取数据,简要分析:

​ 导入spark.implicits._,使其支持把一个RDD隐式转换为一个DataFrame。我们用case class定义一个schema:Iris,Iris就是我们需要的数据的结构;然后读取文本文件,第一个map把每行的数据用“,”隔开,比如在我们的数据集中,每行被分成了5部分,前4部分是鸢尾花的4个特征,最后一部分是鸢尾花的分类;我们这里把特征存储在Vector中,创建一个Iris模式的RDD,然后转化成dataframe;然后把刚刚得到的数据注册成一个表iris,注册成这个表之后,我们就可以通过sql语句进行数据查询;选出我们需要的数据后,我们可以把结果打印出来查看一下数据。

  1. scala> import spark.implicits._
  2. import spark.implicits._
  3.  
  4. scala> case class Iris(features: org.apache.spark.ml.linalg.Vector, label: String)
  5. defined class Iris
  6.  
  7. scala> val data = spark.sparkContext.textFile("file:///usr/local/spark/iris.txt").map(_.split(",")).map(p => Iris(Vectors.dense(p(0).toDouble,p(1).toDouble,p(2).toDouble, p(3).toDouble),p(4).toString())).toDF()
  8. data: org.apache.spark.sql.DataFrame = [features: vector, label: string]
  9.  
  10. scala> data.createOrReplaceTempView("iris")
  11.  
  12. scala> val df = spark.sql("select * from iris")
  13. df: org.apache.spark.sql.DataFrame = [features: vector, label: string]
  14.  
  15. scala> df.map(t => t(1)+":"+t(0)).collect().foreach(println)
  16. Iris-setosa:[5.1,3.5,1.4,0.2]
  17. Iris-setosa:[4.9,3.0,1.4,0.2]
  18. Iris-setosa:[4.7,3.2,1.3,0.2]
  19. Iris-setosa:[4.6,3.1,1.5,0.2]
  20. Iris-setosa:[5.0,3.6,1.4,0.2]
  21. Iris-setosa:[5.4,3.9,1.7,0.4]
  22. Iris-setosa:[4.6,3.4,1.4,0.3]
  23.  
  24. ... ...
scala

3. 进一步处理特征和标签,以及数据分组:

  1. //分别获取标签列和特征列,进行索引,并进行了重命名。
  2. scala> val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol(
  3. "indexedLabel").fit(df)
  4. labelIndexer: org.apache.spark.ml.feature.StringIndexerModel = strIdx_107f7e530fa7
  5.  
  6. scala> val featureIndexer = new VectorIndexer().setInputCol("features").setOutpu
  7. tCol("indexedFeatures").setMaxCategories(4).fit(df)
  8. featureIndexer: org.apache.spark.ml.feature.VectorIndexerModel = vecIdx_0649803dfa70
  9. //这里我们设置一个labelConverter,目的是把预测的类别重新转化成字符型的。
  10. scala> val labelConverter = new IndexToString().setInputCol("prediction").setOut
  11. putCol("predictedLabel").setLabels(labelIndexer.labels)
  12. labelConverter: org.apache.spark.ml.feature.IndexToString = idxToStr_046182b2e571
  13. //接下来,我们把数据集随机分成训练集和测试集,其中训练集占70%。
  14. scala> val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))
  15. trainingData: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [features: vector, label: string]
  16. testData: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [features: vector, label: string]
scala

4. 构建决策树分类模型:

  1. //导入所需要的包
  2. import org.apache.spark.ml.classification.DecisionTreeClassificationModel
  3. import org.apache.spark.ml.classification.DecisionTreeClassifier
  4. import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
  5. //训练决策树模型,这里我们可以通过setter的方法来设置决策树的参数,也可以用ParamMap来设置(具体的可以查看spark mllib的官网)。具体的可以设置的参数可以通过explainParams()来获取。
  6. scala> val dtClassifier = new DecisionTreeClassifier().setLabelCol("indexedLabel
  7. ").setFeaturesCol("indexedFeatures")
  8. dtClassifier: org.apache.spark.ml.classification.DecisionTreeClassifier = dtc_029ea28aceb1
  9. //在pipeline中进行设置
  10. scala> val pipelinedClassifier = new Pipeline().setStages(Array(labelIndexer, featureIndexer, dtClassifier, labelConverter))
  11. pipelinedClassifier: org.apache.spark.ml.Pipeline = pipeline_a254dfd6dfb9
  12. //训练决策树模型
  13. scala> val modelClassifier = pipelinedClassifier.fit(trainingData)
  14. modelClassifier: org.apache.spark.ml.PipelineModel = pipeline_a254dfd6dfb9
  15. //进行预测
  16. scala> val predictionsClassifier = modelClassifier.transform(testData)
  17. predictionsClassifier: org.apache.spark.sql.DataFrame = [features: vector, label: string ... 6 more fields]
  18. //查看部分预测的结果
  19. scala> predictionsClassifier.select("predictedLabel", "label", "features").show(20)
  20. +---------------+---------------+-----------------+
  21. | predictedLabel| label| features|
  22. +---------------+---------------+-----------------+
  23. | Iris-setosa| Iris-setosa|[4.4,2.9,1.4,0.2]|
  24. | Iris-setosa| Iris-setosa|[4.6,3.6,1.0,0.2]|
  25. | Iris-virginica|Iris-versicolor|[4.9,2.4,3.3,1.0]|
  26. | Iris-setosa| Iris-setosa|[4.9,3.1,1.5,0.1]|
  27. | Iris-setosa| Iris-setosa|[4.9,3.1,1.5,0.1]|
  28. | Iris-setosa| Iris-setosa|[5.0,3.5,1.6,0.6]|
  29. | Iris-setosa| Iris-setosa|[5.2,3.5,1.5,0.2]|
  30. | Iris-setosa| Iris-setosa|[5.2,4.1,1.5,0.1]|
  31. | Iris-setosa| Iris-setosa|[5.4,3.4,1.7,0.2]|
  32. | Iris-setosa| Iris-setosa|[5.4,3.7,1.5,0.2]|
  33. | Iris-setosa| Iris-setosa|[5.4,3.9,1.7,0.4]|
  34. |Iris-versicolor|Iris-versicolor|[5.5,2.3,4.0,1.3]|
  35. | Iris-setosa| Iris-setosa|[5.7,4.4,1.5,0.4]|
  36. | Iris-virginica|Iris-versicolor|[5.9,3.2,4.8,1.8]|
  37. |Iris-versicolor|Iris-versicolor|[6.1,2.8,4.0,1.3]|
  38. |Iris-versicolor|Iris-versicolor|[6.2,2.2,4.5,1.5]|
  39. | Iris-virginica|Iris-versicolor|[6.3,2.5,4.9,1.5]|
  40. |Iris-versicolor|Iris-versicolor|[6.3,3.3,4.7,1.6]|
  41. |Iris-versicolor|Iris-versicolor|[6.4,2.9,4.3,1.3]|
  42. |Iris-versicolor|Iris-versicolor|[6.5,2.8,4.6,1.5]|
  43. +---------------+---------------+-----------------+
  44. only showing top 20 rows
scala

5. 评估决策树分类模型:

  1. scala> val evaluatorClassifier = new MulticlassClassificationEvaluator().s
  2. etLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("accuracy")
  3. evaluatorClassifier: org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator = mcEval_4abc19f3a54d
  4.  
  5. scala> val accuracy = evaluatorClassifier.evaluate(predictionsClassifier)
  6. accuracy: Double = 0.8648648648648649
  7.  
  8. scala> println("Test Error = " + (1.0 - accuracy))
  9. Test Error = 0.1351351351351351
  10.  
  11. scala> val treeModelClassifier = modelClassifier.stages(2).asInstanceOf[De
  12. cisionTreeClassificationModel]
  13. treeModelClassifier: org.apache.spark.ml.classification.DecisionTreeClassificati
  14. onModel = DecisionTreeClassificationModel (uid=dtc_029ea28aceb1) of depth 5 with 13 nodes
  15.  
  16. scala> println("Learned classification tree model:\n" + treeModelClassifier.toDebugString)
  17. Learned classification tree model:
  18. DecisionTreeClassificationModel (uid=dtc_029ea28aceb1) of depth 5 with 13 nodes
  19. If (feature 2 <= 1.9)
  20. Predict: 2.0
  21. Else (feature 2 > 1.9)
  22. If (feature 2 <= 4.7)
  23. If (feature 0 <= 4.9)
  24. Predict: 1.0
  25. Else (feature 0 > 4.9)
  26. Predict: 0.0
  27. Else (feature 2 > 4.7)
  28. If (feature 3 <= 1.6)
  29. If (feature 2 <= 4.8)
  30. Predict: 0.0
  31. Else (feature 2 > 4.8)
  32. If (feature 0 <= 6.0)
  33. Predict: 0.0
  34. Else (feature 0 > 6.0)
  35. Predict: 1.0
  36. Else (feature 3 > 1.6)
  37. Predict: 1.0
scala

​ 从上述结果可以看到模型的预测准确率为 0.8648648648648649以及训练的决策树模型结构。

6. 构建决策树回归模型:

  1. //导入所需要的包
  2. import org.apache.spark.ml.evaluation.RegressionEvaluator
  3. import org.apache.spark.ml.regression.DecisionTreeRegressionModel
  4. import org.apache.spark.ml.regression.DecisionTreeRegressor
  5. //训练决策树模型
  6. scala> val dtRegressor = new DecisionTreeRegressor().setLabelCol("indexedLabel")
  7. .setFeaturesCol("indexedFeatures")
  8. dtRegressor: org.apache.spark.ml.regression.DecisionTreeRegressor = dtr_358e08c37f0c
  9. //在pipeline中进行设置
  10. scala> val pipelineRegressor = new Pipeline().setStages(Array(labelIndexer, featureIndexer, dtRegressor, labelConverter))
  11. pipelineRegressor: org.apache.spark.ml.Pipeline = pipeline_ae699675d015
  12. //训练决策树模型
  13. scala> val modelRegressor = pipelineRegressor.fit(trainingData)
  14. modelRegressor: org.apache.spark.ml.PipelineModel = pipeline_ae699675d015
  15. //进行预测
  16. scala> val predictionsRegressor = modelRegressor.transform(testData)
  17. predictionsRegressor: org.apache.spark.sql.DataFrame = [features: vector, label: string ... 4 more fields]
  18. //查看部分预测结果
  19. scala> predictionsRegressor.select("predictedLabel", "label", "features").show(20)
  20. +---------------+---------------+-----------------+
  21. | predictedLabel| label| features|
  22. +---------------+---------------+-----------------+
  23. | Iris-setosa| Iris-setosa|[4.4,2.9,1.4,0.2]|
  24. | Iris-setosa| Iris-setosa|[4.6,3.6,1.0,0.2]|
  25. | Iris-virginica|Iris-versicolor|[4.9,2.4,3.3,1.0]|
  26. | Iris-setosa| Iris-setosa|[4.9,3.1,1.5,0.1]|
  27. | Iris-setosa| Iris-setosa|[4.9,3.1,1.5,0.1]|
  28. | Iris-setosa| Iris-setosa|[5.0,3.5,1.6,0.6]|
  29. | Iris-setosa| Iris-setosa|[5.2,3.5,1.5,0.2]|
  30. | Iris-setosa| Iris-setosa|[5.2,4.1,1.5,0.1]|
  31. | Iris-setosa| Iris-setosa|[5.4,3.4,1.7,0.2]|
  32. | Iris-setosa| Iris-setosa|[5.4,3.7,1.5,0.2]|
  33. | Iris-setosa| Iris-setosa|[5.4,3.9,1.7,0.4]|
  34. |Iris-versicolor|Iris-versicolor|[5.5,2.3,4.0,1.3]|
  35. | Iris-setosa| Iris-setosa|[5.7,4.4,1.5,0.4]|
  36. | Iris-virginica|Iris-versicolor|[5.9,3.2,4.8,1.8]|
  37. |Iris-versicolor|Iris-versicolor|[6.1,2.8,4.0,1.3]|
  38. |Iris-versicolor|Iris-versicolor|[6.2,2.2,4.5,1.5]|
  39. | Iris-virginica|Iris-versicolor|[6.3,2.5,4.9,1.5]|
  40. |Iris-versicolor|Iris-versicolor|[6.3,3.3,4.7,1.6]|
  41. |Iris-versicolor|Iris-versicolor|[6.4,2.9,4.3,1.3]|
  42. |Iris-versicolor|Iris-versicolor|[6.5,2.8,4.6,1.5]|
  43. +---------------+---------------+-----------------+
  44. only showing top 20 rows
scala

7. 评估决策树回归模型:

  1. scala> val evaluatorRegressor = new RegressionEvaluator().setLabelCol("ind
  2. exedLabel").setPredictionCol("prediction").setMetricName("rmse")
  3. evaluatorRegressor: org.apache.spark.ml.evaluation.RegressionEvaluator = regEval_425d2aeea2dd
  4.  
  5. scala> val rmse = evaluatorRegressor.evaluate(predictionsRegressor)
  6. rmse: Double = 0.3676073110469039
  7.  
  8. scala> println("Root Mean Squared Error (RMSE) on test data = " + rmse)
  9. Root Mean Squared Error (RMSE) on test data = 0.3676073110469039
  10.  
  11. scala> val treeModelRegressor = modelRegressor.stages(2).asInstanceOf[Deci
  12. sionTreeRegressionModel]
  13. treeModelRegressor: org.apache.spark.ml.regression.DecisionTreeRegressionModel =
  14. DecisionTreeRegressionModel (uid=dtr_358e08c37f0c) of depth 5 with 13 nodes
  15.  
  16. scala> println("Learned regression tree model:\n" + treeModelRegressor.toDebugString)
  17. Learned regression tree model:
  18. DecisionTreeRegressionModel (uid=dtr_358e08c37f0c) of depth 5 with 13 nodes
  19. If (feature 2 <= 1.9)
  20. Predict: 2.0
  21. Else (feature 2 > 1.9)
  22. If (feature 2 <= 4.7)
  23. If (feature 0 <= 4.9)
  24. Predict: 1.0
  25. Else (feature 0 > 4.9)
  26. Predict: 0.0
  27. Else (feature 2 > 4.7)
  28. If (feature 3 <= 1.6)
  29. If (feature 2 <= 4.8)
  30. Predict: 0.0
  31. Else (feature 2 > 4.8)
  32. If (feature 0 <= 6.0)
  33. Predict: 0.5
  34. Else (feature 0 > 6.0)
  35. Predict: 1.0
  36. Else (feature 3 > 1.6)
  37. Predict: 1.0
scala

从上述结果可以看到模型的标准误差为 0.3676073110469039以及训练的决策树模型结构。

子雨大数据之Spark入门
扫一扫访问本博客