숫자인식 코드 분석해보기
지난 포스팅에서 인공지능 이론중 가장 간단한 Linear Regression과 Logistic Regression에 대해서 설명했습니다. 이번 포스팅에선 tensorflow 튜토리얼에 있는 숫자 인식 코드를 차근차근 분석해가며 어떻게 우리가 배운 이론이 접목되었는지 공부해보는 시간을 가지려고 합니다.
먼저 숫자 인식 코드는 여기를 공식 튜토리얼 페이지는 여기를 클릭하면 볼 수 있습니다.
튜토리얼 페이지를 클릭하시면 MNIST라는 단어를 보실 수 있을겁니다. 이건 컴퓨터 비전에 사용되는 데이터 셋을 말하는 용어니 너무 주의깊게 보시지 않아도 괜찮아요. 그냥 이런 이미지를 가지고 있다는 것만 기억하시면 됩니다!
위의 이미지들은 우리가 숫자를 직접 손으로 쓸 때의 이미지들입니다. 우리가 만든 숫자 인식 인공지능은 고딕으로된 숫자 이미지 뿐만 아니라 여러가지 폰트로 된 숫자도 인식해야 하니까 이렇게 만들어두면 학습에 도움이 되겠죠? 우리는 이런 이미지들을 잘 학습시켜서 우리가 어떤 이미지를 넣더라도 정확한 숫자를 출력하게 하고 싶습니다.
이미지를 학습 시키는 것에 앞서 이미지를 학습에 사용될 데이터로 전환하는 작업을 해야합니다. 아마 비전공자라 하더라도 컴퓨터 이미지들은 픽셀의 형태로 이뤄져 있다는 것을 아실겁니다. 수십에서 수만개의 촘촘한 점들에 숫자 값을 대입해서 이 위치에는 검은 색을 또는 다른 위치에는 빨간 색을 표현하는것이 컴퓨터가 이미지를 보여주는 방식입니다.
왼쪽 '1' 그림을 픽셀 단위로 표현 한 것입니다.
이렇게 표현하니 거대한 행렬이라고 볼 수 있을 것 같습니다. MNIST에 있는 모든 데이터는 28x28 단위를 따르고 있습니다. 784개의 Element들을 하나로 쭉 나열하면 각 이미지가 서로 독립적이게 만들어 줄 뿐만 아니라 보기에도 훨씬 편할 것 같네요(Flattening 한다고 합니다). 이렇게 변환한 행렬은 학습 데이터로 사용하기에 매우 편리합니다.
실제 코드 상에서는 아예 784 배열로 변환한 데이터 값 자체를 한번에 받습니다.
# Train for _ in range(1000): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
학습에 사용할 input data를 만들어뒀으니 이제 학습 모델을 만들어 봅시다. 숫자 인식은 분류(Classification)작업으로 볼 수 있습니다. 그래서 여기선 Logistic Regression을 사용해야 합니다. 784개의 Element들(숫자 이미지에 있는 모든 픽셀)을 보고 특성을 분석해 10개의 분류(0~9 까지 인식해야 합니다)로 만들어주는 인공지능을 만드는 것이 목표입니다. 먼저 Logistic Regression 에 수식을 다시 살펴봅시다.
여기서 우리는 z에 해당하는 수식을 만들어야 합니다. 전에는 쎄타0, 쎄타1 만 만들어서 간단히 했었죠? 그런데 이번에는 쎄타0, 쎄타1 뿐만 아니라 쎄타784 까지 만들어야 합니다. 학습에 사용하는 input 데이터의 element가 총 784개이기 때문이지요. 계산량이 무척 많아지겠지만 그래도 컴퓨터가 대신 해줄 것이니 너무 염려하지 않도록 합시다.
output data는 크기가 10인 배열로 둘겁니다. 결과값이 3이라면 (0, 0, 0, 1, 0, 0, 0, 0, 0, 0) 요렇게 쉽게 표시 할 수 있습니다. 물론 two's complement로 더 공간 효율적으로 할 수 있긴 하지만 그렇게는 안해요. 이렇게하면 벡터로 표시 할 때 훨씬 보기 편하거든요.
입력값(X)이 1x784 로 들어 온다면 우리는 출력값(Y)을 1x10으로 내야 스펙에 맞습니다. 그러면 입력값을 처리하는 행렬(W)의 크기는 784x10이 되어야 합니다. 수학 수식으로 표현해보면 Y(출력값) = X(입력값) * W(쎄타들의 모음) + B(바이어스, 일반 상수에 해당하는 값) 로 볼 수 있습니다.
실제 코드에서도 이렇게 표현 합니다.
# Create the model x = tf.placeholder(tf.float32, [None, 784]) W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros([10])) # y = xW + b y = tf.matmul(x, W) + b
- x는 인풋 데이터 형식입니다. C언어로 생각하면 int, char을 설정해주는 것과 비슷해요.
- W는 쎄타들의 모음입니다. 실제로는 Weight라고 불러요. Input이 784개이고 10개의 output이 있으니 총 7840개의 Weight가 존재합니다. 지금은 모두 0으로 세팅했는데 여러번 최적화 과정을 통해서 적절한 값을 찾아가게 될겁니다.
- b는 Bias 값입니다. 일반상수에 해당하는 값이에요
- y를 matmul 명령어를 이용해 정의합니다. matmul은 벡터 값의 곱을 의미합니다. 결과적으로 y = xW + b 로 표현이 되겠네요.
학습 모델까지 훌륭하게 만들었습니다. 이제 Cost function을 구하고 최적화 작업만 거치면 됩니다. 이 작업은 softmax와 cross entropy로 한방에 해결 할 수 있습니다. softmax 포스팅에서 소개한 코드를 그대로 사용하겠습니다.
cross_entropy = tf.reduce_mean( tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)) train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
tf.reduce_mean은 argument의 평균을 구해주는 작업입니다. 이렇게 cost function 구하는 모델을 만들어 두고 바로 밑에 GradientDescentOptimizer 의 minimize argument안에 cross_entropy 모델까지 넣으면 Gradient Decent 방식으로 최적화 할 수 있는 모델 까지 만들어 집니다.
하지만 이렇게 그냥 둔다고 바로 알아서 학습을 하진 않습니다. 지금까지 우리가 한 것 학습을 위한 모델을 만든 것에 불과하니까요. 함수를 정의한 후에는 호출을 해야하는 것처럼 여기도 우리가 만든 학습 모델을 실행하는 코드가 필요합니다.
# Train for _ in range(1000): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) # Test trained model correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
주석으로 추천하건데 # Train 아래에 있는 부분이 학습하는 부분 인것 같지요? batch_xs, batch_ys는 학습에 쓰일 이미지 벡터 값과 실제 숫자 벡터 값을 가지고 있습니다. 그리고 바로 아래 sess.run 안의 feed_dict 안에 넣어놓지요. 첫번째 인자인 train_step은 우리가 최적화 작업까지 선언한 학습 모델입니다. 이 두 인자만 넣어주면 바로 tensorflow에 있는 함수들을 이용해서 학습이 진행됩니다.
바로 밑에의 코드는 학습한 데이터를 테스트해보는 작업이네요. correct_prediction은 예측한 값(y)과 실제 값(y_)이 동일한지를 보는 것이고 accuracy는 예측 한 값들에서 평균을 내는 작업입니다. 이것도 어떻게보면 검증하는 '모델'로 볼 수 있겠네요. 첫번째 인자로 예측 모델(accuracy)를 두고 두번째 인자에는 테스트할 이미지와(mnist.test.image)과 결과값(mnist.test.labels)을 대입해서 정확도를 측정 할 수 있습니다.
위 학습 모델로 정확도가 92% 정도 나옵니다. 생각보다 우수하죠? 그런데 딥러닝을 적용하면 정확도가 97%까지 상승한다고 합니다!
참고자료
- https://www.tensorflow.org/get_started/mnist/beginners