
Spark ML 05 (PySpark)

Jonghee Jeon 2020. 5. 4. 14:57

This is an example of running a RandomForest classifier using Spark ML (PySpark).


RandomForest?

A random forest is an ensemble learning method used for classification, regression, and similar tasks. It works by building many decision trees during training and then outputting the majority class (for classification) or the mean prediction (for regression) of the individual trees.
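As a toy illustration of that voting idea (plain Python, not Spark; the per-tree outputs below are made up):

from collections import Counter

# Hypothetical predictions from five individual decision trees for one sample.
tree_votes = [1, 0, 1, 1, 0]

# Classification: the forest outputs the majority class among the trees.
majority_class = Counter(tree_votes).most_common(1)[0][0]   # -> 1

# Regression: the forest outputs the mean of the trees' predictions.
tree_values = [3.2, 2.9, 3.5, 3.1, 3.0]
mean_prediction = sum(tree_values) / len(tree_values)       # -> 3.14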

 

Source: 결정 트리 (Decision tree) - Wikipedia, ko.wikipedia.org

Spark ML
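The cells below assume an active SparkSession already bound to the name spark, as you get in the pyspark shell or a Spark-enabled notebook. If you run the code as a standalone script, a minimal setup sketch would be (the app name is arbitrary):

from pyspark.sql import SparkSession

# Create (or reuse) a local SparkSession so that spark.read.csv(...) below works.
spark = SparkSession.builder.appName("rf-affairs-example").getOrCreate()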

 

In [1]:
df = spark.read.csv("data/affairs.csv", inferSchema=True, header=True)
In [2]:
df.printSchema()
 
root
 |-- rate_marriage: integer (nullable = true)
 |-- age: double (nullable = true)
 |-- yrs_married: double (nullable = true)
 |-- children: double (nullable = true)
 |-- religious: integer (nullable = true)
 |-- affairs: integer (nullable = true)

In [7]:
df.show(5)
 
+-------------+----+-----------+--------+---------+-------+
|rate_marriage| age|yrs_married|children|religious|affairs|
+-------------+----+-----------+--------+---------+-------+
|            5|32.0|        6.0|     1.0|        3|      0|
|            4|22.0|        2.5|     0.0|        2|      0|
|            3|32.0|        9.0|     3.0|        3|      1|
|            3|27.0|       13.0|     3.0|        1|      1|
|            4|22.0|        2.5|     0.0|        1|      1|
+-------------+----+-----------+--------+---------+-------+
only showing top 5 rows

In [3]:
df.summary().show()
 
+-------+------------------+------------------+-----------------+------------------+------------------+------------------+
|summary|     rate_marriage|               age|      yrs_married|          children|         religious|           affairs|
+-------+------------------+------------------+-----------------+------------------+------------------+------------------+
|  count|              6366|              6366|             6366|              6366|              6366|              6366|
|   mean| 4.109644989004084|29.082862079798932| 9.00942507068803|1.3968740182218033|2.4261702796104303|0.3224945020420987|
| stddev|0.9614295945655025| 6.847881883668817|7.280119972766412| 1.433470828560344|0.8783688402641785| 0.467467779921086|
|    min|                 1|              17.5|              0.5|               0.0|                 1|                 0|
|    25%|                 4|              22.0|              2.5|               0.0|                 2|                 0|
|    50%|                 4|              27.0|              6.0|               1.0|                 2|                 0|
|    75%|                 5|              32.0|             16.5|               2.0|                 3|                 1|
|    max|                 5|              42.0|             23.0|               5.5|                 4|                 1|
+-------+------------------+------------------+-----------------+------------------+------------------+------------------+

In [4]:
df.groupBy('affairs').count().show()
 
+-------+-----+
|affairs|count|
+-------+-----+
|      1| 2053|
|      0| 4313|
+-------+-----+

In [5]:
df.groupBy('rate_marriage').count().show()
 
+-------------+-----+
|rate_marriage|count|
+-------------+-----+
|            1|   99|
|            3|  993|
|            5| 2684|
|            4| 2242|
|            2|  348|
+-------------+-----+

In [6]:
df.groupBy('children', 'affairs').count().orderBy('children', 'affairs', 'count', ascending=True).show()
 
+--------+-------+-----+
|children|affairs|count|
+--------+-------+-----+
|     0.0|      0| 1912|
|     0.0|      1|  502|
|     1.0|      0|  747|
|     1.0|      1|  412|
|     2.0|      0|  873|
|     2.0|      1|  608|
|     3.0|      0|  460|
|     3.0|      1|  321|
|     4.0|      0|  197|
|     4.0|      1|  131|
|     5.5|      0|  124|
|     5.5|      1|   79|
+--------+-------+-----+

In [8]:
from pyspark.ml.feature import VectorAssembler
In [12]:
df_assembler = VectorAssembler(inputCols=['rate_marriage', 'age', 'yrs_married',
                                          'children', 'religious'],
                               outputCol="features")
In [13]:
df = df_assembler.transform(df)
In [14]:
df.printSchema()
 
root
 |-- rate_marriage: integer (nullable = true)
 |-- age: double (nullable = true)
 |-- yrs_married: double (nullable = true)
 |-- children: double (nullable = true)
 |-- religious: integer (nullable = true)
 |-- affairs: integer (nullable = true)
 |-- features: vector (nullable = true)

In [15]:
df.select(['features', 'affairs']).show(5)
 
+--------------------+-------+
|            features|affairs|
+--------------------+-------+
|[5.0,32.0,6.0,1.0...|      0|
|[4.0,22.0,2.5,0.0...|      0|
|[3.0,32.0,9.0,3.0...|      1|
|[3.0,27.0,13.0,3....|      1|
|[4.0,22.0,2.5,0.0...|      1|
+--------------------+-------+
only showing top 5 rows

In [16]:
model_df = df.select(['features', 'affairs'])
In [57]:
trainDF, testDF = model_df.randomSplit([0.8, 0.2])
In [58]:
from pyspark.ml.classification import RandomForestClassifier
In [65]:
rf_model = RandomForestClassifier(labelCol='affairs', numTrees=30, maxDepth=3, impurity='entropy').fit(trainDF)
In [66]:
rf_predictions = rf_model.transform(testDF)
In [67]:
rf_predictions.show(5)
 
+--------------------+-------+--------------------+--------------------+----------+
|            features|affairs|       rawPrediction|         probability|prediction|
+--------------------+-------+--------------------+--------------------+----------+
|[1.0,22.0,2.5,1.0...|      0|[13.6946596452935...|[0.45648865484311...|       1.0|
|[1.0,22.0,2.5,1.0...|      0|[13.6946596452935...|[0.45648865484311...|       1.0|
|[1.0,27.0,2.5,0.0...|      1|[14.4414315098290...|[0.48138105032763...|       1.0|
|[1.0,27.0,6.0,0.0...|      0|[12.2603006986020...|[0.40867668995340...|       1.0|
|[1.0,27.0,6.0,1.0...|      1|[11.2445159991833...|[0.37481719997277...|       1.0|
+--------------------+-------+--------------------+--------------------+----------+
only showing top 5 rows

In [68]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
In [69]:
rf_accuracy = MulticlassClassificationEvaluator(labelCol="affairs", metricName="accuracy") \
                                                .evaluate(rf_predictions)
In [70]:
rf_accuracy
Out[70]:
0.7132644956314536
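Accuracy alone can be hard to interpret on an imbalanced label (here about 4313 zeros vs. 2053 ones), so a complementary check is the area under the ROC curve. A sketch using the predictions already computed above (BinaryClassificationEvaluator defaults to the areaUnderROC metric):

from pyspark.ml.evaluation import BinaryClassificationEvaluator

# Uses the rawPrediction column produced by the RandomForest model.
rf_auc = BinaryClassificationEvaluator(labelCol="affairs",
                                       rawPredictionCol="rawPrediction").evaluate(rf_predictions)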
In [71]:
rf_model.save("/home/carbig/RandomForest_Model")
In [72]:
from pyspark.ml.classification import RandomForestClassificationModel
rf_model2 = RandomForestClassificationModel.load("/home/carbig/RandomForest_Model")