TigerCow.Door


안녕하세요. 문범우입니다.

이번 포스팅에서는 RNN을 이용하여 hihello를 학습시켜 보도록 하겠습니다.



1. 'hihello' 학습시키기

이번에는 위와 같이 우리가 hihello 라는 문자열을 주었을 때, 각 문자에 대해 다음 문자를 예측해보도록 학습시킬 것 입니다.

이 문제가 간단해보일 수 있지만, 좀 더 자세히 살펴보면 h를 입력했을 때, 어쩔때는 i를, 어쩔 때는 e를 반환해야 합니다. 이는 RNN의 특성인, 이전 문자가 무엇이 나왔는지 알아야 값을 제대로 출력할 수 있습니다.



2. RNN basic 정리


그럼 먼저 간단히 RNN의 기본적인 내용에 대해 정리해보도록 하겠습니다.

우리가 그동한 RNN을 공부하면서 그 자체에 대해서는 간단히 여겨질 수 있으나 실제로 RNN에 어떤 값을 어떻게 넣어줘야 할지 복잡할 수 있습니다.

하나씩 살펴보겠습니다.

우리가 넣고자 하는 text는 'hihello'입니다. 그리고 그 문자에는 5개의 유니크한 문자로 존재합니다. 그리고 이 문자열을 인덱스로 표현하면,

h:0 , i:1 , e:2 , l:3 , o:4 이고 이를 one-hot 인코딩으로 나타내면 위 그림과 같습니다.


우리가 전체적으로 구성할 모델은 위와 같을 것입니다.

그럼, 이제 입력 dimension, 출력 dimension, batch size 등에 대해서 생각해보도록 하겠습니다.

그림에서 보듯이 input dimension 은 5입니다.

그리고 총 6개로 구성되기 때문에 sequence = 6 입니다.

또한 hidden 또한 5 입니다.

마지막으로 batch size는 하나의 문자이므로 1 입니다.



3. RNN 만들기


먼저 rnn cell을 만듭니다.

이때 기본적으로 BasicRNNCell 을 이용할 수도 있겠지만, 많이 사용되는 LSTM 이나 GRU를 이용할 수도 있습니다. 이때 중요한 것은 rnn_size입니다.

rnn size는 출력값으로써 5로 정해집니다.


두번째로 중요한 것은 입력입니다.

이는 위에서 알아봤던 것처럼, input dimension = 5, sequence = 6 으로 넣어주면 될 것입니다.


그럼 이를 바탕으로 데이터를 만들어 보도록 하겠습니다.


먼저 입력, x_data 로 hihell을 인덱스로써 넣어주었고 이를 one_hot encoding으로 변환하여, x_one_hot 으로 구성하였습니다.

그리고 우리가 학습하고자 하는 결과값, y_data 또한 위와 같이 구성하였습니다.


그리고 X 와 Y에 우리가 설정한 값들을 넣어줍니다.

위에서 알아본 것과 같이 sequence_length = 6 이며, input_dim = 5 입니다.

이후, cell을 만들고 이때 결과의 크기, hidden_size를 입력합니다.

또한 위에서는 initial_state를 만드는데 전부 0으로 만들었습니다.

그리고 우리가 만든 cell과 state를 이용하여 결과를 내도록 하였습니다.


이렇게 만든 모델이 얼마나 잘 맞는지 알기 위해서 loss를 구해야 합니다.

즉, cost를 구해야하는데 텐서플로우에서는 이를 쉽게할 수 있도록 sequence_loss 라는 함수를 제공합니다.


sequence_loss 함수를 활용하여 위와 같이 구성합니다.

우리의 모델을 통해 나오는 outputs를 sequence_loss의 logits으로 넣어줍니다. 이는 우리의 예측값을 넣어주는 것 입니다.

그리고 targets는 우리의 훈련값을 넣어주는 것으로써 Y를 넣어줍니다. 그리고 여기서 weights는 단순히 1인 값으로 넣어주도록 합니다.


그리고 이것을 평균을 내서 AdamOptimize로 넣어줌으로써 loss를 최소화 시켜줍니다.



그리고 학습과정은 그 동안했던 것과 유사하게 진행해줍니다.



이를 통한 결과는 위와 같습니다.

결과의 초반에 보면 예측이 아주 잘못되고 있지만 시간이 지나면서 loss가 점점떨어지고, 예측또한 잘 되는 것을 볼 수 있습니다.


블로그 이미지

Tigercow.Door

Web Programming / Back-end / Database / AI / Algorithm / DeepLearning / etc