Getting Started with Spark 2.1.0: Logistic Regression Classifier (Python Version)


Logistic Regression

Method Overview

Logistic regression is a classic classification method in statistical learning and belongs to the family of log-linear models. Its response variable can be binary or multi-class.

Basic Principles

The logistic distribution

Let X be a continuous random variable. X is said to follow a logistic distribution if X has the following distribution function and density function:

    \[F(x)=P(X \le x)=\frac 1 {1+e^{-(x-\mu)/\gamma}}\]

    \[f(x)=F^{'}(x)=\frac {e^{-(x-\mu)/\gamma}} {\gamma(1+e^{-(x-\mu)/\gamma})^2}\]

where \(\mu\) is the location parameter and \(\gamma > 0\) is the shape parameter.

The distribution function \(F(x)\) is an S-shaped curve, symmetric about the center point \((\mu, \frac 1 2)\); the smaller \(\gamma\) is, the faster the curve changes near the center. The small sketch below evaluates both \(f(x)\) and \(F(x)\) and illustrates this effect.
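A minimal NumPy sketch (illustrative only, not part of the Spark code; all names are ours) that evaluates both functions and shows how \(\gamma\) controls the steepness:

import numpy as np

def logistic_cdf(x, mu=0.0, gamma=1.0):
    # F(x) = 1 / (1 + exp(-(x - mu) / gamma))
    return 1.0 / (1.0 + np.exp(-(x - mu) / gamma))

def logistic_pdf(x, mu=0.0, gamma=1.0):
    # f(x) = F'(x) = exp(-(x - mu)/gamma) / (gamma * (1 + exp(-(x - mu)/gamma))^2)
    e = np.exp(-(x - mu) / gamma)
    return e / (gamma * (1.0 + e) ** 2)

xs = np.linspace(-6.0, 6.0, 7)
for gamma in (0.5, 1.0, 2.0):
    # smaller gamma -> F(x) rises more steeply around x = mu
    print(gamma, np.round(logistic_cdf(xs, mu=0.0, gamma=gamma), 3))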

The binomial logistic regression model

The binomial logistic regression model is as follows:

    \[P(Y=1|x)=\frac {exp(w \cdot x + b)} {1 + exp(w \cdot x + b)}\]

    \[P(Y=0|x)=\frac {1} {1 + exp(w \cdot x + b)}\]

where \(x \in R^n\) is the input, \(Y \in \{0,1\}\) is the output, \(w\) is called the weight vector, \(b\) is called the bias, and \(w \cdot x\) is the inner product of \(w\) and \(x\).
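To make the model concrete, here is a tiny NumPy sketch that computes the two class probabilities for one input; the parameter values are purely hypothetical:

import numpy as np

w = np.array([0.8, -0.4])   # hypothetical weight vector
b = 0.1                     # hypothetical bias
x = np.array([1.5, 2.0])    # hypothetical input

z = np.dot(w, x) + b
p1 = np.exp(z) / (1.0 + np.exp(z))   # P(Y=1|x)
p0 = 1.0 / (1.0 + np.exp(z))         # P(Y=0|x)
print(p1, p0, p1 + p0)               # the two probabilities sum to 1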

Parameter estimation

Assume:

    \[P(Y=1|x)=\pi (x), \quad P(Y=0|x)=1-\pi (x)\]

Then the likelihood function is:

    \[\prod_{i=1}^N [\pi (x_i)]^{y_i} [1 - \pi(x_i)]^{1-y_i}\]

Taking logarithms gives the log-likelihood function:

    \[L(w) = \sum_{i=1}^N [y_i \log{\pi(x_i)} + (1-y_i) \log{(1 - \pi(x_i))}] = \sum_{i=1}^N \left[ y_i \log{\frac {\pi (x_i)} {1 - \pi(x_i)}} + \log{(1 - \pi(x_i))} \right]\]
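Substituting \(\pi(x_i)=\frac{\exp(w \cdot x_i + b)}{1+\exp(w \cdot x_i + b)}\), so that \(\log \frac{\pi(x_i)}{1-\pi(x_i)} = w \cdot x_i + b\), makes the objective fully explicit:

    \[L(w) = \sum_{i=1}^N \left[ y_i (w \cdot x_i + b) - \log{(1 + \exp(w \cdot x_i + b))} \right]\]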

The estimate of \(w\) is then obtained by maximizing \(L(w)\). The extremum can be found with methods such as gradient descent or gradient ascent, as in the sketch below.
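As a concrete illustration of this estimation step (outside Spark), here is a minimal NumPy sketch of gradient ascent on \(L(w)\); the gradient is \(\sum_i (y_i - \pi(x_i)) x_i\), and all names are ours, not Spark's:

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def fit_logistic(X, y, step=0.1, n_iter=1000):
    """Gradient ascent on the log-likelihood L(w).
    X: (N, n) feature matrix; y: (N,) labels in {0, 1}.
    A constant column is appended so the bias b is learned as the last weight."""
    Xb = np.hstack([X, np.ones((X.shape[0], 1))])
    w = np.zeros(Xb.shape[1])
    for _ in range(n_iter):
        pi = sigmoid(Xb @ w)                     # pi(x_i) = P(Y=1 | x_i)
        w += step * (Xb.T @ (y - pi)) / len(y)   # averaged gradient of L(w)
    return w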

Example Code

We use the iris dataset as our example. It takes iris flower measurements as its data source: 150 instances in 3 classes, 50 instances per class, each with 4 attributes; it is a very commonly used training/test set in data mining and classification. For ease of understanding, we mainly rely on the last two attributes (petal length and petal width) for classification. spark.ml currently supports both binary and multinomial classification, so we analyze three cases in turn: binomial logistic regression for binary classification, multinomial logistic regression for binary classification, and multinomial logistic regression for multi-class classification.

Binary classification with binomial logistic regression

First we take only the last two classes of the data and run a binary classification with binomial logistic regression.

1. Import the required packages:

from pyspark.sql import Row, functions
from pyspark.ml.linalg import Vector, Vectors
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml import Pipeline
from pyspark.ml.feature import IndexToString, StringIndexer, VectorIndexer, HashingTF, Tokenizer
from pyspark.ml.classification import LogisticRegression, LogisticRegressionModel, BinaryLogisticRegressionSummary
2. Read the data and take a first look:

We define a function that turns one parsed line into a Row. After reading the text file, the first map splits each line on ","; in our dataset each line then has 5 parts, the first 4 being the iris features and the last the iris class. We store the features in a dense Vector, build an RDD of Rows following this Iris schema, convert it into a DataFrame, and finally call show() to inspect part of the data.

def f(x):
    rel = {}
    rel['features'] = Vectors.dense(float(x[0]), float(x[1]), float(x[2]), float(x[3]))
    rel['label'] = str(x[4])
    return rel

data = spark.sparkContext.textFile("file:///usr/local/spark/iris.txt").map(lambda line: line.split(',')).map(lambda p: Row(**f(p))).toDF()
data.show()

+-----------------+-----------+
|         features|      label|
+-----------------+-----------+
|[5.1,3.5,1.4,0.2]|Iris-setosa|
|[4.9,3.0,1.4,0.2]|Iris-setosa|
|[4.7,3.2,1.3,0.2]|Iris-setosa|
|[4.6,3.1,1.5,0.2]|Iris-setosa|
|[5.0,3.6,1.4,0.2]|Iris-setosa|
|[5.4,3.9,1.7,0.4]|Iris-setosa|
|[4.6,3.4,1.4,0.3]|Iris-setosa|
|[5.0,3.4,1.5,0.2]|Iris-setosa|
|[4.4,2.9,1.4,0.2]|Iris-setosa|
|[4.9,3.1,1.5,0.1]|Iris-setosa|
|[5.4,3.7,1.5,0.2]|Iris-setosa|
|[4.8,3.4,1.6,0.2]|Iris-setosa|
|[4.8,3.0,1.4,0.1]|Iris-setosa|
|[4.3,3.0,1.1,0.1]|Iris-setosa|
|[5.8,4.0,1.2,0.2]|Iris-setosa|
|[5.7,4.4,1.5,0.4]|Iris-setosa|
|[5.4,3.9,1.3,0.4]|Iris-setosa|
|[5.1,3.5,1.4,0.3]|Iris-setosa|
|[5.7,3.8,1.7,0.3]|Iris-setosa|
|[5.1,3.8,1.5,0.3]|Iris-setosa|
+-----------------+-----------+
only showing top 20 rows

Since we are now dealing with a binary classification problem, we do not need all 3 classes; we must select two of them. We first register the DataFrame just obtained as a temporary view iris; once registered, we can query it with SQL, and here we select all rows whose label is not 'Iris-setosa'. We then print the selected data to confirm that no 'Iris-setosa' rows remain.

data.createOrReplaceTempView("iris")
df = spark.sql("select * from iris where label != 'Iris-setosa'")
rel = df.rdd.map(lambda t: str(t[1]) + ":" + str(t[0])).collect()
for item in rel:
    print(item)

Iris-versicolor:[7.0,3.2,4.7,1.4]
Iris-versicolor:[6.4,3.2,4.5,1.5]
Iris-versicolor:[6.9,3.1,4.9,1.5]
Iris-versicolor:[5.5,2.3,4.0,1.3]
Iris-versicolor:[6.5,2.8,4.6,1.5]
Iris-versicolor:[5.7,2.8,4.5,1.3]
Iris-versicolor:[6.3,3.3,4.7,1.6]
Iris-versicolor:[4.9,2.4,3.3,1.0]
Iris-versicolor:[6.6,2.9,4.6,1.3]
Iris-versicolor:[5.2,2.7,3.9,1.4]
Iris-versicolor:[5.0,2.0,3.5,1.0]
Iris-versicolor:[5.9,3.0,4.2,1.5]
Iris-versicolor:[6.0,2.2,4.0,1.0]
Iris-versicolor:[6.1,2.9,4.7,1.4]
Iris-versicolor:[5.6,2.9,3.6,1.3]
Iris-versicolor:[6.7,3.1,4.4,1.4]
Iris-versicolor:[5.6,3.0,4.5,1.5]
Iris-versicolor:[5.8,2.7,4.1,1.0]
Iris-versicolor:[6.2,2.2,4.5,1.5]
Iris-versicolor:[5.6,2.5,3.9,1.1]
Iris-versicolor:[5.9,3.2,4.8,1.8]
Iris-versicolor:[6.1,2.8,4.0,1.3]
Iris-versicolor:[6.3,2.5,4.9,1.5]
Iris-versicolor:[6.1,2.8,4.7,1.2]
Iris-versicolor:[6.4,2.9,4.3,1.3]
Iris-versicolor:[6.6,3.0,4.4,1.4]
Iris-versicolor:[6.8,2.8,4.8,1.4]
Iris-versicolor:[6.7,3.0,5.0,1.7]
Iris-versicolor:[6.0,2.9,4.5,1.5]
Iris-versicolor:[5.7,2.6,3.5,1.0]
Iris-versicolor:[5.5,2.4,3.8,1.1]
Iris-versicolor:[5.5,2.4,3.7,1.0]
Iris-versicolor:[5.8,2.7,3.9,1.2]
Iris-versicolor:[6.0,2.7,5.1,1.6]
Iris-versicolor:[5.4,3.0,4.5,1.5]
Iris-versicolor:[6.0,3.4,4.5,1.6]
Iris-versicolor:[6.7,3.1,4.7,1.5]
Iris-versicolor:[6.3,2.3,4.4,1.3]
Iris-versicolor:[5.6,3.0,4.1,1.3]
Iris-versicolor:[5.5,2.5,4.0,1.3]
Iris-versicolor:[5.5,2.6,4.4,1.2]
Iris-versicolor:[6.1,3.0,4.6,1.4]
Iris-versicolor:[5.8,2.6,4.0,1.2]
Iris-versicolor:[5.0,2.3,3.3,1.0]
Iris-versicolor:[5.6,2.7,4.2,1.3]
Iris-versicolor:[5.7,3.0,4.2,1.2]
Iris-versicolor:[5.7,2.9,4.2,1.3]
Iris-versicolor:[6.2,2.9,4.3,1.3]
Iris-versicolor:[5.1,2.5,3.0,1.1]
Iris-versicolor:[5.7,2.8,4.1,1.3]
Iris-virginica:[6.3,3.3,6.0,2.5]
Iris-virginica:[5.8,2.7,5.1,1.9]
Iris-virginica:[7.1,3.0,5.9,2.1]
Iris-virginica:[6.3,2.9,5.6,1.8]
Iris-virginica:[6.5,3.0,5.8,2.2]
Iris-virginica:[7.6,3.0,6.6,2.1]
Iris-virginica:[4.9,2.5,4.5,1.7]
Iris-virginica:[7.3,2.9,6.3,1.8]
Iris-virginica:[6.7,2.5,5.8,1.8]
Iris-virginica:[7.2,3.6,6.1,2.5]
Iris-virginica:[6.5,3.2,5.1,2.0]
Iris-virginica:[6.4,2.7,5.3,1.9]
Iris-virginica:[6.8,3.0,5.5,2.1]
Iris-virginica:[5.7,2.5,5.0,2.0]
Iris-virginica:[5.8,2.8,5.1,2.4]
Iris-virginica:[6.4,3.2,5.3,2.3]
Iris-virginica:[6.5,3.0,5.5,1.8]
Iris-virginica:[7.7,3.8,6.7,2.2]
Iris-virginica:[7.7,2.6,6.9,2.3]
Iris-virginica:[6.0,2.2,5.0,1.5]
Iris-virginica:[6.9,3.2,5.7,2.3]
Iris-virginica:[5.6,2.8,4.9,2.0]
Iris-virginica:[7.7,2.8,6.7,2.0]
Iris-virginica:[6.3,2.7,4.9,1.8]
Iris-virginica:[6.7,3.3,5.7,2.1]
Iris-virginica:[7.2,3.2,6.0,1.8]
Iris-virginica:[6.2,2.8,4.8,1.8]
Iris-virginica:[6.1,3.0,4.9,1.8]
Iris-virginica:[6.4,2.8,5.6,2.1]
Iris-virginica:[7.2,3.0,5.8,1.6]
Iris-virginica:[7.4,2.8,6.1,1.9]
Iris-virginica:[7.9,3.8,6.4,2.0]
Iris-virginica:[6.4,2.8,5.6,2.2]
Iris-virginica:[6.3,2.8,5.1,1.5]
Iris-virginica:[6.1,2.6,5.6,1.4]
Iris-virginica:[7.7,3.0,6.1,2.3]
Iris-virginica:[6.3,3.4,5.6,2.4]
Iris-virginica:[6.4,3.1,5.5,1.8]
Iris-virginica:[6.0,3.0,4.8,1.8]
Iris-virginica:[6.9,3.1,5.4,2.1]
Iris-virginica:[6.7,3.1,5.6,2.4]
Iris-virginica:[6.9,3.1,5.1,2.3]
Iris-virginica:[5.8,2.7,5.1,1.9]
Iris-virginica:[6.8,3.2,5.9,2.3]
Iris-virginica:[6.7,3.3,5.7,2.5]
Iris-virginica:[6.7,3.0,5.2,2.3]
Iris-virginica:[6.3,2.5,5.0,1.9]
Iris-virginica:[6.5,3.0,5.2,2.0]
Iris-virginica:[6.2,3.4,5.4,2.3]
Iris-virginica:[5.9,3.0,5.1,1.8]
3. Build the ML pipeline

We obtain the label column and the feature column, index each of them, and give them new names:

labelIndexer = StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(df)
featureIndexer = VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").fit(df)
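To see concretely what the indexers produce, one can inspect the fitted label mapping and a few indexed rows; the exact index assignment depends on the label frequencies in df, so the comment below is only an example:

print(labelIndexer.labels)   # e.g. ['Iris-versicolor', 'Iris-virginica']; index 0 is the most frequent label
labelIndexer.transform(df).select("label", "indexedLabel").show(3)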

Next, we randomly split the dataset into a training set and a test set, with 70% of the data used for training.

trainingData, testData = df.randomSplit([0.7, 0.3])

Then we set the parameters of the logistic regression model. Here we use setter methods throughout; parameters can also be supplied via a ParamMap (see the spark.ml documentation, and the sketch after the output below). We set the maximum number of iterations to 10, the regularization parameter to 0.3, and so on; the full list of configurable parameters, together with the values we have already set, can be obtained with explainParams().

lr = LogisticRegression().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures").setMaxIter(10).setRegParam(0.3).setElasticNetParam(0.8)
print("LogisticRegression parameters:\n" + lr.explainParams())

LogisticRegression parameters:
aggregationDepth: suggested depth for treeAggregate (>= 2). (default: 2)
elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. (default: 0.0, current: 0.8)
family: The name of family which is a description of the label distribution to be used in the model. Supported options: auto, binomial, multinomial (default: auto)
featuresCol: features column name. (default: features, current: indexedFeatures)
fitIntercept: whether to fit an intercept term. (default: True)
labelCol: label column name. (default: label, current: indexedLabel)
maxIter: max number of iterations (>= 0). (default: 100, current: 10)
predictionCol: prediction column name. (default: prediction)
probabilityCol: Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities. (default: probability)
rawPredictionCol: raw prediction (a.k.a. confidence) column name. (default: rawPrediction)
regParam: regularization parameter (>= 0). (default: 0.0, current: 0.3)
standardization: whether to standardize the training features before fitting the model. (default: True)
threshold: Threshold in binary classification prediction, in range [0, 1]. If threshold and thresholds are both set, they must match.e.g. if threshold is p, then thresholds must be equal to [1-p, p]. (default: 0.5)
thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold. (undefined)
tol: the convergence tolerance for iterative algorithms (>= 0). (default: 1e-06)
weightCol: weight column name. If this is not set or empty, we treat all instance weights as 1.0. (undefined)
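As mentioned above, the same parameters can instead be given as constructor keyword arguments, or passed as a ParamMap-style dict to an individual fit() call. A hedged sketch (the variable names lr2, indexedTrain, and model2 are ours):

# Constructor-keyword style, equivalent to the setters above
lr2 = LogisticRegression(labelCol="indexedLabel", featuresCol="indexedFeatures",
                         maxIter=10, regParam=0.3, elasticNetParam=0.8)

# ParamMap style: override parameters for a single fit() call;
# the input must already contain the indexed columns
indexedTrain = featureIndexer.transform(labelIndexer.transform(trainingData))
model2 = lr2.fit(indexedTrain, {lr2.maxIter: 20, lr2.regParam: 0.1})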

Here we set up a labelConverter, whose purpose is to convert the predicted numeric class back into its original string label.

labelConverter = IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)

We build the pipeline, set its stages, and call fit() to train the model.

lrPipeline = Pipeline().setStages([labelIndexer, featureIndexer, lr, labelConverter])
lrPipelineModel = lrPipeline.fit(trainingData)

A pipeline is essentially an Estimator: calling fit() on it produces a PipelineModel, which is essentially a Transformer. The PipelineModel can then call transform() to make predictions, producing a new DataFrame; that is, we validate the trained model on the test set.

lrPredictions = lrPipelineModel.transform(testData)

Finally, we can output the prediction results: select picks the columns to output, collect retrieves all rows, and the for loop prints each row. The printed values are, in order, the true class and feature values of the row, the predicted probability of each class, and the predicted class.

preRel = lrPredictions.select("predictedLabel", "label", "features", "probability").collect()
for item in preRel:
    print(str(item['label']) + ',' + str(item['features']) + '-->prob=' + str(item['probability']) + ',predictedLabel=' + str(item['predictedLabel']))

Iris-versicolor,[5.2,2.7,3.9,1.4]-->prob=[0.474125433289,0.525874566711],predictedLabel=Iris-virginica
Iris-versicolor,[5.5,2.3,4.0,1.3]-->prob=[0.498724224708,0.501275775292],predictedLabel=Iris-virginica
Iris-versicolor,[5.6,3.0,4.5,1.5]-->prob=[0.456659495584,0.543340504416],predictedLabel=Iris-virginica
Iris-versicolor,[5.9,3.2,4.8,1.8]-->prob=[0.396329949568,0.603670050432],predictedLabel=Iris-virginica
Iris-versicolor,[6.3,3.3,4.7,1.6]-->prob=[0.442287021876,0.557712978124],predictedLabel=Iris-virginica
Iris-versicolor,[6.4,2.9,4.3,1.3]-->prob=[0.507813688472,0.492186311528],predictedLabel=Iris-versicolor
Iris-versicolor,[6.7,3.0,5.0,1.7]-->prob=[0.42504596458,0.57495403542],predictedLabel=Iris-virginica
Iris-virginica,[4.9,2.5,4.5,1.7]-->prob=[0.407378398796,0.592621601204],predictedLabel=Iris-virginica
Iris-versicolor,[5.5,2.4,3.8,1.1]-->prob=[0.541810138118,0.458189861882],predictedLabel=Iris-versicolor
Iris-versicolor,[5.5,2.5,4.0,1.3]-->prob=[0.498724224708,0.501275775292],predictedLabel=Iris-virginica
Iris-versicolor,[5.5,2.6,4.4,1.2]-->prob=[0.520304937567,0.479695062433],predictedLabel=Iris-versicolor
Iris-versicolor,[5.7,2.6,3.5,1.0]-->prob=[0.565147445416,0.434852554584],predictedLabel=Iris-versicolor
Iris-versicolor,[5.8,2.7,3.9,1.2]-->prob=[0.523329193496,0.476670806504],predictedLabel=Iris-versicolor
Iris-virginica,[5.8,2.7,5.1,1.9]-->prob=[0.374915000487,0.625084999513],predictedLabel=Iris-virginica
Iris-virginica,[5.8,2.8,5.1,2.4]-->prob=[0.280289494118,0.719710505882],predictedLabel=Iris-virginica
Iris-versicolor,[6.0,2.7,5.1,1.6]-->prob=[0.43929948357,0.56070051643],predictedLabel=Iris-virginica
Iris-virginica,[6.1,3.0,4.9,1.8]-->prob=[0.398264741954,0.601735258046],predictedLabel=Iris-virginica
Iris-virginica,[6.3,2.8,5.1,1.5]-->prob=[0.463684596073,0.536315403927],predictedLabel=Iris-virginica
Iris-virginica,[6.3,3.3,6.0,2.5]-->prob=[0.267137732158,0.732862267842],predictedLabel=Iris-virginica
Iris-virginica,[6.5,3.0,5.2,2.0]-->prob=[0.361404006967,0.638595993033],predictedLabel=Iris-virginica
Iris-virginica,[6.5,3.0,5.5,1.8]-->prob=[0.402143821718,0.597856178282],predictedLabel=Iris-virginica
Iris-virginica,[6.7,2.5,5.8,1.8]-->prob=[0.404087997603,0.595912002397],predictedLabel=Iris-virginica
Iris-virginica,[6.9,3.1,5.4,2.1]-->prob=[0.345363437071,0.654636562929],predictedLabel=Iris-virginica
Iris-virginica,[7.3,2.9,6.3,1.8]-->prob=[0.409938392379,0.590061607621],predictedLabel=Iris-virginica
Iris-virginica,[7.4,2.8,6.1,1.9]-->prob=[0.390181923993,0.609818076007],predictedLabel=Iris-virginica
4. Model evaluation

We create a MulticlassClassificationEvaluator instance and use setter methods to set the names of the prediction column and the true-label column; we then compute the score on the test set and the corresponding test error.

evaluator = MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction")
lrAccuracy = evaluator.evaluate(lrPredictions)
print("Test Error = " + str(1.0 - lrAccuracy))

Test Error = 0.35111111111111115
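Note that MulticlassClassificationEvaluator's default metric is F1 rather than raw accuracy, so the figure above is strictly 1 − F1; to report plain accuracy one can set the metric name explicitly (a sketch, with accuracyEvaluator a name of ours):

accuracyEvaluator = MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("accuracy")
print("Accuracy = " + str(accuracyEvaluator.evaluate(lrPredictions)))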

As seen above, the prediction accuracy is about 65%. Next we can obtain the trained logistic regression model itself. As noted earlier, lrPipelineModel is a PipelineModel, so we can retrieve the model from its stages, as follows:

lrModel = lrPipelineModel.stages[2]
print("Coefficients: " + str(lrModel.coefficients) + "Intercept: " + str(lrModel.intercept) + "numClasses: " + str(lrModel.numClasses) + "numFeatures: " + str(lrModel.numFeatures))

Coefficients: [-0.0396171957643483,0.0,0.0,0.07240315639651046]Intercept: -0.23127346342015379numClasses: 2numFeatures: 4
5. Model summary

spark.ml also provides a summary of the trained model; at present this is only supported for binomial logistic regression (in the Scala API the summary must be explicitly cast to BinaryLogisticRegressionSummary; in Python it exposes the binary metrics directly). In the code below, we first obtain the summary of the binomial logistic regression model. We then retrieve how the loss function changed over the 10 iterations and print it: the loss decreases steadily across iterations, and the smaller the loss, the better the model. Next we use the summary to obtain metrics for evaluating model performance: from the ROC we can judge the quality of the model, and the areaUnderROC reaches 0.9889758179231863, which shows that our classifier is quite good. Finally, we select the most suitable classification threshold by maximizing the F-measure, a metric that combines precision and recall.

trainingSummary = lrModel.summary
objectiveHistory = trainingSummary.objectiveHistory
for item in objectiveHistory:
    print(item)

0.6930582890371242
0.6899151958544979
0.6884360489604017
0.6866214680339037
0.6824264404293411
0.6734525297891238
0.6718869589782477
0.6700321119842002
0.6681741952485035
0.6668744860924799
0.6656055740433819

print(trainingSummary.areaUnderROC)
0.9889758179231863

fMeasure = trainingSummary.fMeasureByThreshold

maxFMeasure = fMeasure.select(functions.max("F-Measure")).head()[0]
print(maxFMeasure)
0.9599999999999999

bestThreshold = fMeasure.where(fMeasure["F-Measure"] == maxFMeasure).select("threshold").head()[0]
print(bestThreshold)
0.5487261156903904

lr.setThreshold(bestThreshold)
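setThreshold changes the estimator lr, so the new threshold only takes effect the next time the pipeline is fit. A sketch of re-training with the selected threshold (lrPipelineModel2 and lrPredictions2 are names of ours, and the error will vary with the random split):

lrPipelineModel2 = lrPipeline.fit(trainingData)   # the stages reference lr, so bestThreshold is now used
lrPredictions2 = lrPipelineModel2.transform(testData)
print("Test Error = " + str(1.0 - evaluator.evaluate(lrPredictions2)))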

Binary classification with multinomial logistic regression

For a binary classification problem, we can also perform the analysis with multinomial logistic regression. It is analogous to binomial logistic regression; the only change in the model setup is setting the family parameter to "multinomial". Here we only list the results:

mlr = LogisticRegression().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures").setMaxIter(10).setRegParam(0.3).setElasticNetParam(0.8).setFamily("multinomial")

mlrPipeline = Pipeline().setStages([labelIndexer, featureIndexer, mlr, labelConverter])

mlrPipelineModel = mlrPipeline.fit(trainingData)

mlrPredictions = mlrPipelineModel.transform(testData)   # generate predictions on the test set

mlrPreRel = mlrPredictions.select("predictedLabel", "label", "features", "probability").collect()
for item in mlrPreRel:
    print('(' + str(item['label']) + ',' + str(item['features']) + ')-->prob=' + str(item['probability']) + ',predictLabel=' + str(item['predictedLabel']))

(Iris-versicolor,[5.2,2.7,3.9,1.4])-->prob=[0.769490937876,0.230509062124],predictLabel=Iris-versicolor
(Iris-versicolor,[5.5,2.3,4.0,1.3])-->prob=[0.795868931454,0.204131068546],predictLabel=Iris-versicolor
(Iris-versicolor,[5.6,3.0,4.5,1.5])-->prob=[0.740814410092,0.259185589908],predictLabel=Iris-versicolor
(Iris-versicolor,[5.9,3.2,4.8,1.8])-->prob=[0.64210359201,0.35789640799],predictLabel=Iris-versicolor
(Iris-versicolor,[6.3,3.3,4.7,1.6])-->prob=[0.709915262079,0.290084737921],predictLabel=Iris-versicolor
(Iris-versicolor,[6.4,2.9,4.3,1.3])-->prob=[0.795868931454,0.204131068546],predictLabel=Iris-versicolor
(Iris-versicolor,[6.7,3.0,5.0,1.7])-->prob=[0.676938845836,0.323061154164],predictLabel=Iris-versicolor
(Iris-virginica,[4.9,2.5,4.5,1.7])-->prob=[0.676938845836,0.323061154164],predictLabel=Iris-versicolor
(Iris-versicolor,[5.5,2.4,3.8,1.1])-->prob=[0.841727580252,0.158272419748],predictLabel=Iris-versicolor
(Iris-versicolor,[5.5,2.5,4.0,1.3])-->prob=[0.795868931454,0.204131068546],predictLabel=Iris-versicolor
(Iris-versicolor,[5.5,2.6,4.4,1.2])-->prob=[0.819934748118,0.180065251882],predictLabel=Iris-versicolor
(Iris-versicolor,[5.7,2.6,3.5,1.0])-->prob=[0.861328953172,0.138671046828],predictLabel=Iris-versicolor
(Iris-versicolor,[5.8,2.7,3.9,1.2])-->prob=[0.819934748118,0.180065251882],predictLabel=Iris-versicolor
(Iris-virginica,[5.8,2.7,5.1,1.9])-->prob=[0.605700014665,0.394299985335],predictLabel=Iris-versicolor
(Iris-virginica,[5.8,2.8,5.1,2.4])-->prob=[0.414135982436,0.585864017564],predictLabel=Iris-virginica
(Iris-versicolor,[6.0,2.7,5.1,1.6])-->prob=[0.709915262079,0.290084737921],predictLabel=Iris-versicolor
(Iris-virginica,[6.1,3.0,4.9,1.8])-->prob=[0.64210359201,0.35789640799],predictLabel=Iris-versicolor
(Iris-virginica,[6.3,2.8,5.1,1.5])-->prob=[0.740814410092,0.259185589908],predictLabel=Iris-versicolor
(Iris-virginica,[6.3,3.3,6.0,2.5])-->prob=[0.377041048688,0.622958951312],predictLabel=Iris-virginica
(Iris-virginica,[6.5,3.0,5.2,2.0])-->prob=[0.568084351171,0.431915648829],predictLabel=Iris-versicolor
(Iris-virginica,[6.5,3.0,5.5,1.8])-->prob=[0.64210359201,0.35789640799],predictLabel=Iris-versicolor
(Iris-virginica,[6.7,2.5,5.8,1.8])-->prob=[0.64210359201,0.35789640799],predictLabel=Iris-versicolor
(Iris-virginica,[6.9,3.1,5.4,2.1])-->prob=[0.529666704954,0.470333295046],predictLabel=Iris-versicolor
(Iris-virginica,[7.3,2.9,6.3,1.8])-->prob=[0.64210359201,0.35789640799],predictLabel=Iris-versicolor
(Iris-virginica,[7.4,2.8,6.1,1.9])-->prob=[0.605700014665,0.394299985335],predictLabel=Iris-versicolor

mlrAccuracy = evaluator.evaluate(mlrPredictions)
print("Test Error = " + str(1.0 - mlrAccuracy))

Test Error = 0.48730158730158735

mlrModel = mlrPipelineModel.stages[2]
print("Multinomial coefficients: " + str(mlrModel.coefficientMatrix) + "Multinomial intercepts: " + str(mlrModel.interceptVector) + "numClasses: " + str(mlrModel.numClasses) + "numFeatures: " + str(mlrModel.numFeatures))

Multinomial coefficients: 2 X 4 CSRMatrix
(0,3) -0.0776
(1,3) 0.0776Multinomial intercepts: [0.913185966051,-0.913185966051]numClasses: 2numFeatures: 4

Multi-class classification with multinomial logistic regression

For a multi-class problem, we need multinomial logistic regression for the analysis. Here we use the full iris dataset, i.e., all three classes. The procedure is essentially the same as above; a sketch of the setup is given below, followed by the results:
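Since only the results are listed here, the following is a hedged sketch of the setup we assume for the three-class run: it reuses data (all 150 rows) and fits new indexers on it; the names with the 3 suffix are ours, not from the original:

labelIndexer3 = StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(data)
featureIndexer3 = VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").fit(data)
labelConverter3 = IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer3.labels)
trainingData3, testData3 = data.randomSplit([0.7, 0.3])
mlr3 = LogisticRegression().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures").setMaxIter(10).setRegParam(0.3).setElasticNetParam(0.8).setFamily("multinomial")
mlrPipeline3 = Pipeline().setStages([labelIndexer3, featureIndexer3, mlr3, labelConverter3])
mlrPredictions = mlrPipeline3.fit(trainingData3).transform(testData3)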

mlrPreRel = mlrPredictions.select("predictedLabel", "label", "features", "probability").collect()
for item in mlrPreRel:
    print('(' + str(item['label']) + ',' + str(item['features']) + ')-->prob=' + str(item['probability']) + ',predictLabel=' + str(item['predictedLabel']))

(Iris-versicolor,[5.2,2.7,3.9,1.4])-->prob=[0.769490937876,0.230509062124],predictLabel=Iris-versicolor
(Iris-versicolor,[5.5,2.3,4.0,1.3])-->prob=[0.795868931454,0.204131068546],predictLabel=Iris-versicolor
(Iris-versicolor,[5.6,3.0,4.5,1.5])-->prob=[0.740814410092,0.259185589908],predictLabel=Iris-versicolor
(Iris-versicolor,[5.9,3.2,4.8,1.8])-->prob=[0.64210359201,0.35789640799],predictLabel=Iris-versicolor
(Iris-versicolor,[6.3,3.3,4.7,1.6])-->prob=[0.709915262079,0.290084737921],predictLabel=Iris-versicolor
(Iris-versicolor,[6.4,2.9,4.3,1.3])-->prob=[0.795868931454,0.204131068546],predictLabel=Iris-versicolor
(Iris-versicolor,[6.7,3.0,5.0,1.7])-->prob=[0.676938845836,0.323061154164],predictLabel=Iris-versicolor
(Iris-virginica,[4.9,2.5,4.5,1.7])-->prob=[0.676938845836,0.323061154164],predictLabel=Iris-versicolor
(Iris-versicolor,[5.5,2.4,3.8,1.1])-->prob=[0.841727580252,0.158272419748],predictLabel=Iris-versicolor
(Iris-versicolor,[5.5,2.5,4.0,1.3])-->prob=[0.795868931454,0.204131068546],predictLabel=Iris-versicolor
(Iris-versicolor,[5.5,2.6,4.4,1.2])-->prob=[0.819934748118,0.180065251882],predictLabel=Iris-versicolor
(Iris-versicolor,[5.7,2.6,3.5,1.0])-->prob=[0.861328953172,0.138671046828],predictLabel=Iris-versicolor
(Iris-versicolor,[5.8,2.7,3.9,1.2])-->prob=[0.819934748118,0.180065251882],predictLabel=Iris-versicolor
(Iris-virginica,[5.8,2.7,5.1,1.9])-->prob=[0.605700014665,0.394299985335],predictLabel=Iris-versicolor
(Iris-virginica,[5.8,2.8,5.1,2.4])-->prob=[0.414135982436,0.585864017564],predictLabel=Iris-virginica
(Iris-versicolor,[6.0,2.7,5.1,1.6])-->prob=[0.709915262079,0.290084737921],predictLabel=Iris-versicolor
(Iris-virginica,[6.1,3.0,4.9,1.8])-->prob=[0.64210359201,0.35789640799],predictLabel=Iris-versicolor
(Iris-virginica,[6.3,2.8,5.1,1.5])-->prob=[0.740814410092,0.259185589908],predictLabel=Iris-versicolor
(Iris-virginica,[6.3,3.3,6.0,2.5])-->prob=[0.377041048688,0.622958951312],predictLabel=Iris-virginica
(Iris-virginica,[6.5,3.0,5.2,2.0])-->prob=[0.568084351171,0.431915648829],predictLabel=Iris-versicolor
(Iris-virginica,[6.5,3.0,5.5,1.8])-->prob=[0.64210359201,0.35789640799],predictLabel=Iris-versicolor
(Iris-virginica,[6.7,2.5,5.8,1.8])-->prob=[0.64210359201,0.35789640799],predictLabel=Iris-versicolor
(Iris-virginica,[6.9,3.1,5.4,2.1])-->prob=[0.529666704954,0.470333295046],predictLabel=Iris-versicolor
(Iris-virginica,[7.3,2.9,6.3,1.8])-->prob=[0.64210359201,0.35789640799],predictLabel=Iris-versicolor
(Iris-virginica,[7.4,2.8,6.1,1.9])-->prob=[0.605700014665,0.394299985335],predictLabel=Iris-versicolor

mlrAccuracy = evaluator.evaluate(mlrPredictions)
print("Test Error = " + str(1.0 - mlrAccuracy))

Test Error = 0.48730158730158735

mlrModel = mlrPipelineModel.stages[2]
print("Multinomial coefficients: " + str(mlrModel.coefficientMatrix) + "Multinomial intercepts: " + str(mlrModel.interceptVector) + "numClasses: " + str(mlrModel.numClasses) + "numFeatures: " + str(mlrModel.numFeatures))

Multinomial coefficients: 2 X 4 CSRMatrix
(0,3) -0.0776
(1,3) 0.0776Multinomial intercepts: [0.913185966051,-0.913185966051]numClasses: 2numFeatures: 4


