관리 메뉴

Hee'World

Spark ML 03 (Pyspark) 본문

BigData/Spark

Spark ML 03 (Pyspark)

Jonghee Jeon 2020. 5. 1. 13:21

Spark ML의 로지스틱 회귀를 수행하는 예제


  로지스틱 회귀(영어: logistic regression)는 영국의 통계학자인 D. R. Cox가 1958년[1] 에 제안한 확률 모델로서 독립 변수의 선형 결합을 이용하여 사건의 발생 가능성을 예측하는데 사용되는 통계 기법이다.
로지스틱 회귀의 목적은 일반적인 회귀 분석의 목표와 동일하게 종속 변수와 독립 변수간의 관계를 구체적인 함수로 나타내어 향후 예측 모델에 사용하는 것이다. 이는 독립 변수의 선형 결합으로 종속 변수를 설명한다는 관점에서는 선형 회귀 분석과 유사하다. 하지만 로지스틱 회귀는 선형 회귀 분석과는 다르게 종속 변수가 범주형 데이터를 대상으로 하며 입력 데이터가 주어졌을 때 해당 데이터의 결과가 특정 분류로 나뉘기 때문에 일종의 분류 (classification) 기법으로도 볼 수 있다.
  흔히 로지스틱 회귀는 종속변수가 이항형 문제(즉, 유효한 범주의 개수가 두개인 경우)를 지칭할 때 사용된다. 이외에, 두 개 이상의 범주를 가지는 문제가 대상인 경우엔 다항 로지스틱 회귀 (multinomial logistic regression) 또는 분화 로지스틱 회귀 (polytomous logistic regression)라고 하고 복수의 범주이면서 순서가 존재하면 서수 로지스틱 회귀 (ordinal logistic regression) 라고 한다.[2] 로지스틱 회귀 분석은 의료, 통신, 데이터마이닝과 같은 다양한 분야에서 분류 및 예측을 위한 모델로서 폭넓게 사용되고 있다.

https://ko.wikipedia.org/wiki/%EB%A1%9C%EC%A7%80%EC%8A%A4%ED%8B%B1_%ED%9A%8C%EA%B7%80

 

로지스틱 회귀 - 위키백과, 우리 모두의 백과사전

위키백과, 우리 모두의 백과사전. 둘러보기로 가기 검색하러 가기 로지스틱 회귀(영어: logistic regression)는 영국의 통계학자인 D. R. Cox가 1958년[1] 에 제안한 확률 모델로서 독립 변수의 선형 결합��

ko.wikipedia.org

 

 

In [1]:
df = spark.read.csv("data/Log_Reg_dataset.csv", inferSchema=True, header=True)
In [2]:
df.count()
Out[2]:
20000
In [3]:
len(df.columns)
Out[3]:
6
In [4]:
df.printSchema()
 
root
 |-- Country: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Repeat_Visitor: integer (nullable = true)
 |-- Platform: string (nullable = true)
 |-- Web_pages_viewed: integer (nullable = true)
 |-- Status: integer (nullable = true)

In [5]:
df.show(5)
 
+---------+---+--------------+--------+----------------+------+
|  Country|Age|Repeat_Visitor|Platform|Web_pages_viewed|Status|
+---------+---+--------------+--------+----------------+------+
|    India| 41|             1|   Yahoo|              21|     1|
|   Brazil| 28|             1|   Yahoo|               5|     0|
|   Brazil| 40|             0|  Google|               3|     0|
|Indonesia| 31|             1|    Bing|              15|     1|
| Malaysia| 32|             0|  Google|              15|     1|
+---------+---+--------------+--------+----------------+------+
only showing top 5 rows

In [6]:
df.summary().show()
 
+-------+--------+-----------------+-----------------+--------+-----------------+------------------+
|summary| Country|              Age|   Repeat_Visitor|Platform| Web_pages_viewed|            Status|
+-------+--------+-----------------+-----------------+--------+-----------------+------------------+
|  count|   20000|            20000|            20000|   20000|            20000|             20000|
|   mean|    null|         28.53955|           0.5029|    null|           9.5533|               0.5|
| stddev|    null|7.888912950773227|0.500004090187782|    null|6.073903499824976|0.5000125004687693|
|    min|  Brazil|               17|                0|    Bing|                1|                 0|
|    25%|    null|               22|                0|    null|                4|                 0|
|    50%|    null|               27|                1|    null|                9|                 0|
|    75%|    null|               34|                1|    null|               14|                 1|
|    max|Malaysia|              111|                1|   Yahoo|               29|                 1|
+-------+--------+-----------------+-----------------+--------+-----------------+------------------+

In [7]:
df.groupBy('Platform').count().show()
 
+--------+-----+
|Platform|count|
+--------+-----+
|   Yahoo| 9859|
|    Bing| 4360|
|  Google| 5781|
+--------+-----+

In [8]:
df.groupBy('Country').count().show()
 
+---------+-----+
|  Country|count|
+---------+-----+
| Malaysia| 1218|
|    India| 4018|
|Indonesia|12178|
|   Brazil| 2586|
+---------+-----+

In [9]:
df.groupBy('Status').count().show()
 
+------+-----+
|Status|count|
+------+-----+
|     1|10000|
|     0|10000|
+------+-----+

In [10]:
df.groupBy('Country').mean().show()
 
+---------+------------------+-------------------+---------------------+--------------------+
|  Country|          avg(Age)|avg(Repeat_Visitor)|avg(Web_pages_viewed)|         avg(Status)|
+---------+------------------+-------------------+---------------------+--------------------+
| Malaysia|27.792282430213465| 0.5730706075533661|   11.192118226600986|  0.6568144499178982|
|    India|27.976854156296664| 0.5433051269288203|   10.727227476356397|  0.6212045793927327|
|Indonesia| 28.43159796354081| 0.5207751683363442|    9.985711939563148|  0.5422893742814913|
|   Brazil|30.274168600154677|  0.322892498066512|    4.921113689095128|0.038669760247486466|
+---------+------------------+-------------------+---------------------+--------------------+

In [11]:
df.groupBy('Platform').mean().show()
 
+--------+------------------+-------------------+---------------------+------------------+
|Platform|          avg(Age)|avg(Repeat_Visitor)|avg(Web_pages_viewed)|       avg(Status)|
+--------+------------------+-------------------+---------------------+------------------+
|   Yahoo|28.569226087838523| 0.5094837204584644|    9.599655137437875|0.5071508266558474|
|    Bing| 28.68394495412844| 0.4720183486238532|    9.114908256880733|0.4559633027522936|
|  Google|28.380038055699707| 0.5149628092025601|    9.804878048780488|0.5210171250648676|
+--------+------------------+-------------------+---------------------+------------------+

In [12]:
df.groupBy('Status').mean().show()
 
+------+--------+-------------------+---------------------+-----------+
|Status|avg(Age)|avg(Repeat_Visitor)|avg(Web_pages_viewed)|avg(Status)|
+------+--------+-------------------+---------------------+-----------+
|     1| 26.5435|             0.7019|              14.5617|        1.0|
|     0| 30.5356|             0.3039|               4.5449|        0.0|
+------+--------+-------------------+---------------------+-----------+

In [13]:
from pyspark.ml.feature import StringIndexer, OneHotEncoder
In [14]:
platform_indexer = StringIndexer(inputCol='Platform', \
                                 outputCol='Platform_Num').fit(df)
In [15]:
df = platform_indexer.transform(df)
df.show(5)
 
+---------+---+--------------+--------+----------------+------+------------+
|  Country|Age|Repeat_Visitor|Platform|Web_pages_viewed|Status|Platform_Num|
+---------+---+--------------+--------+----------------+------+------------+
|    India| 41|             1|   Yahoo|              21|     1|         0.0|
|   Brazil| 28|             1|   Yahoo|               5|     0|         0.0|
|   Brazil| 40|             0|  Google|               3|     0|         1.0|
|Indonesia| 31|             1|    Bing|              15|     1|         2.0|
| Malaysia| 32|             0|  Google|              15|     1|         1.0|
+---------+---+--------------+--------+----------------+------+------------+
only showing top 5 rows

In [16]:
platform_encoder = OneHotEncoder(inputCol='Platform_Num', outputCol='Platform_Vector')
In [18]:
df = platform_encoder.transform(df)
df.show(5)
 
+---------+---+--------------+--------+----------------+------+------------+---------------+
|  Country|Age|Repeat_Visitor|Platform|Web_pages_viewed|Status|Platform_Num|Platform_Vector|
+---------+---+--------------+--------+----------------+------+------------+---------------+
|    India| 41|             1|   Yahoo|              21|     1|         0.0|  (2,[0],[1.0])|
|   Brazil| 28|             1|   Yahoo|               5|     0|         0.0|  (2,[0],[1.0])|
|   Brazil| 40|             0|  Google|               3|     0|         1.0|  (2,[1],[1.0])|
|Indonesia| 31|             1|    Bing|              15|     1|         2.0|      (2,[],[])|
| Malaysia| 32|             0|  Google|              15|     1|         1.0|  (2,[1],[1.0])|
+---------+---+--------------+--------+----------------+------+------------+---------------+
only showing top 5 rows

In [19]:
df.groupBy('Platform').count().orderBy('count', ascending=False).show(5)
 
+--------+-----+
|Platform|count|
+--------+-----+
|   Yahoo| 9859|
|  Google| 5781|
|    Bing| 4360|
+--------+-----+

In [20]:
df.groupBy('Platform_Num').count().orderBy('count', ascending=False).show(5)
 
+------------+-----+
|Platform_Num|count|
+------------+-----+
|         0.0| 9859|
|         1.0| 5781|
|         2.0| 4360|
+------------+-----+

In [21]:
df.groupBy('Platform_Vector').count().orderBy('count', ascending=False).show(5)
 
+---------------+-----+
|Platform_Vector|count|
+---------------+-----+
|  (2,[0],[1.0])| 9859|
|  (2,[1],[1.0])| 5781|
|      (2,[],[])| 4360|
+---------------+-----+

In [23]:
country_indexer = StringIndexer(inputCol='Country', \
                                outputCol='Country_Num').fit(df)
df = country_indexer.transform(df)
In [26]:
country_encoder =  OneHotEncoder(inputCol='Country_Num', \
                                 outputCol='Country_Vector')
df = country_encoder.transform(df)
In [27]:
df.select(['Country', 'Country_Num', 'Country_Vector']).show(5)
 
+---------+-----------+--------------+
|  Country|Country_Num|Country_Vector|
+---------+-----------+--------------+
|    India|        1.0| (3,[1],[1.0])|
|   Brazil|        2.0| (3,[2],[1.0])|
|   Brazil|        2.0| (3,[2],[1.0])|
|Indonesia|        0.0| (3,[0],[1.0])|
| Malaysia|        3.0|     (3,[],[])|
+---------+-----------+--------------+
only showing top 5 rows

In [28]:
from pyspark.ml.feature import VectorAssembler
In [30]:
df_assembler = VectorAssembler(inputCols=['Platform_Vector', 'Country_Vector', 'Age', \
                                          'Repeat_Visitor','Web_pages_viewed'], \
                               outputCol='features')
df = df_assembler.transform(df)
In [31]:
df.printSchema()
 
root
 |-- Country: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Repeat_Visitor: integer (nullable = true)
 |-- Platform: string (nullable = true)
 |-- Web_pages_viewed: integer (nullable = true)
 |-- Status: integer (nullable = true)
 |-- Platform_Num: double (nullable = false)
 |-- Platform_Vector: vector (nullable = true)
 |-- Country_Num: double (nullable = false)
 |-- Country_Vector: vector (nullable = true)
 |-- features: vector (nullable = true)

In [32]:
df.select(['features', 'Status']).show(5)
 
+--------------------+------+
|            features|Status|
+--------------------+------+
|[1.0,0.0,0.0,1.0,...|     1|
|[1.0,0.0,0.0,0.0,...|     0|
|(8,[1,4,5,7],[1.0...|     0|
|(8,[2,5,6,7],[1.0...|     1|
|(8,[1,5,7],[1.0,3...|     1|
+--------------------+------+
only showing top 5 rows

In [33]:
model_df = df.select(['features', 'Status'])
In [34]:
from pyspark.ml.classification import LogisticRegression
In [35]:
train_df, test_df = model_df.randomSplit([0.7, 0.3])
In [37]:
train_df.count()
Out[37]:
13842
In [38]:
test_df.count()
Out[38]:
6158
In [40]:
log_reg_model = LogisticRegression(labelCol='Status')
log_reg_model_fit = log_reg_model.fit(train_df)
In [41]:
log_reg_model_fit.coefficients
Out[41]:
DenseVector([0.1943, 0.2312, -0.4285, -0.1042, -3.7871, -0.0679, 1.7362, 0.7386])
In [42]:
log_reg_model_fit.intercept
Out[42]:
-5.19530354724406
In [43]:
train_result = log_reg_model_fit.evaluate(train_df).predictions
In [44]:
correct_preds = train_result.filter(train_result['Status']==1) \
                            .filter(train_result['prediction']==1).count()
In [45]:
train_df.filter(train_df['Status']==1).count()
Out[45]:
6905
In [46]:
correct_preds
Out[46]:
6476
In [47]:
correct_preds/train_df.filter(train_df['Status']==1).count()
Out[47]:
0.9378711078928312
In [48]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator
In [49]:
results = log_reg_model_fit.evaluate(test_df).predictions
In [50]:
results.select(['Status', 'prediction']).show(20)
 
+------+----------+
|Status|prediction|
+------+----------+
|     0|       0.0|
|     0|       0.0|
|     0|       0.0|
|     0|       0.0|
|     0|       0.0|
|     0|       0.0|
|     0|       0.0|
|     0|       0.0|
|     1|       0.0|
|     1|       0.0|
|     0|       1.0|
|     1|       1.0|
|     1|       1.0|
|     1|       1.0|
|     1|       1.0|
|     1|       1.0|
|     1|       1.0|
|     1|       1.0|
|     1|       1.0|
|     1|       1.0|
+------+----------+
only showing top 20 rows

In [51]:
results[(results.Status == 1) & (results.prediction == 1)].count()
Out[51]:
2906
In [52]:
results[(results.Status == 1) & (results.prediction == 0)].count()
Out[52]:
189
In [55]:
results[(results.Status == 0) & (results.prediction == 0)].count()
Out[55]:
2881
In [54]:
results[(results.Status == 0) & (results.prediction == 1)].count()
Out[54]:
182
In [ ]:
 

'BigData > Spark' 카테고리의 다른 글

Spark ML 05(Pyspark)  (0) 2020.05.04
Spark ML 04(Pyspark)  (0) 2020.05.03
Spark ML 02 (Pyspark)  (0) 2020.04.26
Spark ML (Pyspark)  (0) 2020.04.25
Spark Streaming (PySpark)  (0) 2020.04.21
Comments