
Hee'World


Tensorflow_R_MNIST Example (Keras)

Jonghee Jeon 2020. 4. 13. 23:51

To try out TensorFlow in R, this post uses MNIST, the most basic example. Test environment:

R - 3.5.3

RStudio - 1.1.463

OS - Windows10

Mem - 16G


Reference - https://tensorflow.rstudio.com/tutorials/beginners/
 


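# Install the R packages once, if they are not installed yet
#install.packages("tensorflow")
#install.packages("keras")

# Load the TensorFlow library
library(tensorflow)
# Install the Python TensorFlow backend once, if needed
#install_tensorflow()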

# Load dplyr (for the %>% pipe) and keras
library(dplyr)
## 
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
library(keras)

# Load the MNIST dataset and scale pixel values from 0-255 to 0-1
mnist <- dataset_mnist()
mnist$train$x <- mnist$train$x/255
mnist$test$x <- mnist$test$x/255
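Dividing by 255 maps the 8-bit grayscale values (0-255) onto the range 0-1 while keeping each image's 28 x 28 shape. A minimal base-R sketch of the same operation on a hypothetical 2 x 2 "image" (made-up values, not MNIST data):

```r
# Hypothetical 2 x 2 image with 8-bit grayscale values (0 = black, 255 = white)
img <- matrix(c(0L, 64L, 128L, 255L), nrow = 2)

# Same operation as mnist$train$x / 255: integer values become doubles in [0, 1]
scaled <- img / 255

print(range(scaled))  # 0 1
print(dim(scaled))    # 2 2: the shape is unchanged
```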

# Build the model
# Hidden layer activation - ReLU
# Dropout rate - 0.2
# Output layer activation - softmax
model <- keras_model_sequential() %>% 
  layer_flatten(input_shape = c(28, 28)) %>% 
  layer_dense(units = 128, activation = "relu") %>% 
  layer_dropout(0.2) %>% 
  layer_dense(10, activation = "softmax")
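To show what these layers compute, here is a base-R sketch of one forward pass through the same architecture: flatten 28 x 28 to 784, dense 128 with ReLU, dropout 0.2, dense 10 with softmax. The weights are random placeholders, not trained values, and `relu`, `dropout` and `softmax` are hypothetical helper names defined here, not keras functions. The dropout helper follows the inverted-dropout rule described in the Keras `Dropout` documentation (kept units are scaled by 1 / (1 - rate) during training); at inference time it passes values through unchanged.

```r
set.seed(1)

# Hypothetical helpers, not part of keras
relu <- function(z) pmax(z, 0)

# Inverted dropout: zero each unit with probability `rate` during training
# and rescale the survivors; do nothing at inference time
dropout <- function(h, rate = 0.2, training = FALSE) {
  if (!training) return(h)
  keep <- runif(length(h)) >= rate
  h * keep / (1 - rate)
}

# Subtracting the max before exp() avoids overflow without changing the result
softmax <- function(z) {
  e <- exp(z - max(z))
  e / sum(e)
}

# Hypothetical 28 x 28 image already scaled to [0, 1]
img <- matrix(runif(28 * 28), nrow = 28, ncol = 28)

# Flatten row by row (t() makes as.vector read rows instead of columns)
x <- as.vector(t(img))  # length 784

# Random placeholder weights with the same shapes as the keras layers
W1 <- matrix(rnorm(784 * 128, sd = 0.05), nrow = 784, ncol = 128)
b1 <- rep(0, 128)
W2 <- matrix(rnorm(128 * 10, sd = 0.05), nrow = 128, ncol = 10)
b2 <- rep(0, 10)

h <- relu(as.vector(x %*% W1) + b1)     # dense (128 units) + ReLU
h <- dropout(h, rate = 0.2)             # inference mode: unchanged
p <- softmax(as.vector(h %*% W2) + b2)  # dense (10 units) + softmax

round(p, 3)  # 10 probabilities, one per digit 0-9
```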

summary(model)
## Model: "sequential"
## ________________________________________________________________________________
## Layer (type)                        Output Shape                    Param #     
## ================================================================================
## flatten (Flatten)                   (None, 784)                     0           
## ________________________________________________________________________________
## dense (Dense)                       (None, 128)                     100480      
## ________________________________________________________________________________
## dropout (Dropout)                   (None, 128)                     0           
## ________________________________________________________________________________
## dense_1 (Dense)                     (None, 10)                      1290        
## ================================================================================
## Total params: 101,770
## Trainable params: 101,770
## Non-trainable params: 0
## ________________________________________________________________________________
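The Param # column follows from the layer shapes: a dense layer has one weight per input-output pair plus one bias per output unit, while Flatten and Dropout have no trainable parameters. The arithmetic can be checked in base R (`dense_params` is a hypothetical helper):

```r
# Hypothetical helper: weights (n_in * n_out) plus one bias per output unit
dense_params <- function(n_in, n_out) n_in * n_out + n_out

n_flat    <- 28 * 28                    # Flatten output: 784 values, 0 params
p_dense   <- dense_params(n_flat, 128)  # 100480
p_dense_1 <- dense_params(128, 10)      # 1290
p_total   <- p_dense + p_dense_1        # 101770 (Dropout adds 0)

c(dense = p_dense, dense_1 = p_dense_1, total = p_total)
```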
# Loss function - sparse categorical crossentropy (integer labels 0-9)
# Optimizer - Adam
# Metric - accuracy
model %>% 
  compile(
    loss = "sparse_categorical_crossentropy",
    optimizer = "adam",
    metrics = "accuracy"
  )
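"Sparse" means the labels stay as integers 0-9 (as in `mnist$train$y`) instead of one-hot vectors; the loss for each image is -log of the probability the model assigns to the true digit, averaged over the batch. A base-R sketch with two made-up predictions (`sparse_ce` is a hypothetical helper, not a keras function) that also checks it matches the one-hot (categorical) form:

```r
# Hypothetical helper. probs: n x 10 matrix of softmax outputs; labels: digits 0-9
sparse_ce <- function(probs, labels) {
  # Pick each row's probability for its true class (+1: R indexes from 1)
  p_true <- probs[cbind(seq_len(nrow(probs)), labels + 1)]
  mean(-log(p_true))
}

# Two made-up predictions: 0.8 on digit 2, then 0.5 on digit 7
row1 <- rep(0.2 / 9, 10); row1[3] <- 0.8
row2 <- rep(0.5 / 9, 10); row2[8] <- 0.5
probs  <- rbind(row1, row2)
labels <- c(2, 7)

loss_sparse <- sparse_ce(probs, labels)

# Equivalent categorical crossentropy using one-hot labels
onehot <- diag(10)[labels + 1, ]
loss_onehot <- mean(rowSums(-onehot * log(probs)))

c(sparse = loss_sparse, onehot = loss_onehot)  # both about 0.458
```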

# Train for 5 epochs, holding out 30% of the training data for validation
model %>% 
  fit(
    x = mnist$train$x, y = mnist$train$y,
    epochs = 5,
    validation_split = 0.3,
    verbose = 2
  )
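With `validation_split = 0.3`, keras sets aside 30% of the 60,000 training images and never trains on them; the keras `fit()` documentation states that these are the last samples in `x` and `y`, selected before shuffling. The resulting sizes, and the number of steps per epoch at the default batch size of 32 (the call above does not set `batch_size`), can be worked out in base R:

```r
n_train <- 60000          # images in mnist$train$x
validation_split <- 0.3
batch_size <- 32          # keras default when batch_size is not given

n_val <- round(n_train * validation_split)  # 18000 held out for validation
n_fit <- n_train - n_val                    # 42000 used for training

# Held-out rows are the last n_val samples (selected before shuffling)
val_rows <- (n_fit + 1):n_train

steps_per_epoch <- ceiling(n_fit / batch_size)  # 1313 (last batch is partial)

c(fit = n_fit, validation = n_val, steps = steps_per_epoch)
```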

# Predict class probabilities for the test set (columns 1-10 = digits 0-9)
predictions <- predict(model, mnist$test$x)
head(predictions, 2)
##              [,1]         [,2]         [,3]         [,4]         [,5]
## [1,] 1.580725e-07 2.300006e-08 2.606068e-05 4.458812e-04 9.343597e-11
## [2,] 7.055627e-08 2.735742e-04 9.996517e-01 6.442662e-05 1.110566e-13
##              [,6]         [,7]         [,8]         [,9]        [,10]
## [1,] 8.591602e-07 5.672325e-11 9.995247e-01 3.809913e-07 1.920494e-06
## [2,] 2.437132e-06 9.999303e-08 5.518335e-13 7.729958e-06 5.565607e-13
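Each row of `predictions` holds ten softmax probabilities, and column k corresponds to digit k - 1, so the column with the highest probability gives the predicted digit. The sketch below rebuilds the two rows printed above (values as displayed) and applies base R's `max.col()`; on the full matrix the same idea would be `max.col(predictions) - 1L`.

```r
# The two rows printed by head(predictions, 2), copied as displayed
predictions_head <- rbind(
  c(1.580725e-07, 2.300006e-08, 2.606068e-05, 4.458812e-04, 9.343597e-11,
    8.591602e-07, 5.672325e-11, 9.995247e-01, 3.809913e-07, 1.920494e-06),
  c(7.055627e-08, 2.735742e-04, 9.996517e-01, 6.442662e-05, 1.110566e-13,
    2.437132e-06, 9.999303e-08, 5.518335e-13, 7.729958e-06, 5.565607e-13)
)

# Column of the largest probability in each row, shifted to digits 0-9
predicted_digits <- max.col(predictions_head, ties.method = "first") - 1L
predicted_digits  # 7 2

# Confidence of each prediction
apply(predictions_head, 1, max)
```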
# Evaluate loss and accuracy on the test set
model %>% 
  evaluate(mnist$test$x, mnist$test$y, verbose = 0)
## $loss
## [1] 0.08904964
## 
## $accuracy
## [1] 0.9734
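The accuracy metric is the fraction of test images whose highest-probability digit equals the label, so 0.9734 on the 10,000 MNIST test images corresponds to 9,734 correct predictions. A base-R sketch of the same computation, using a hypothetical `accuracy_from_probs` helper and made-up data; with the real objects it would be called as `accuracy_from_probs(predictions, mnist$test$y)`:

```r
# Hypothetical helper: share of rows whose most probable class matches the label
accuracy_from_probs <- function(probs, labels) {
  predicted <- max.col(probs, ties.method = "first") - 1L  # column -> class 0..k-1
  mean(predicted == labels)
}

# Made-up probabilities for 4 images over 3 classes, to keep it short
probs <- rbind(
  c(0.7, 0.2, 0.1),  # predicts 0
  c(0.1, 0.8, 0.1),  # predicts 1
  c(0.2, 0.3, 0.5),  # predicts 2
  c(0.6, 0.3, 0.1)   # predicts 0
)
labels <- c(0, 1, 2, 1)  # the last prediction is wrong

accuracy_from_probs(probs, labels)  # 0.75

# Converting the reported test accuracy back into a count of correct images
round(0.9734 * 10000)  # 9734
```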
# Save the trained model (TensorFlow SavedModel format)
save_model_tf(object = model, filepath = "model")

# Load the saved model and confirm it produces the same predictions
reloaded_model <- load_model_tf("model")
all.equal(predict(model, mnist$test$x), predict(reloaded_model, mnist$test$x))
## [1] TRUE
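`all.equal()` returns `TRUE` when two numeric objects agree within a tolerance (by default `sqrt(.Machine$double.eps)`, about 1.5e-8), whereas `identical()` requires exact equality, so the check above confirms the reloaded model reproduces the predictions up to floating-point noise. A small base-R illustration with made-up numbers:

```r
a <- c(0.9995247, 0.9996517)
b <- a + 1e-12  # tiny floating-point noise

identical(a, b)                 # FALSE: exact comparison
all.equal(a, b)                 # TRUE: within the default tolerance
isTRUE(all.equal(a, a + 1e-3))  # FALSE: a real difference is reported
```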

 
