개발자의시작

[Pytorch] 11-0 RNN intro 본문

머신러닝(machine learning)

[Pytorch] 11-0 RNN intro

LNLP 2022. 3. 3. 19:11

이 글은 모두를위한딥러닝 시즌2 https://github.com/deeplearningzerotoall/PyTorch 을 정리한 글입니다.

 

GitHub - deeplearningzerotoall/PyTorch: Deep Learning Zero to All - Pytorch

Deep Learning Zero to All - Pytorch. Contribute to deeplearningzerotoall/PyTorch development by creating an account on GitHub.

github.com

 

RNN의 구조를 살펴보고 RNN으로 해결할 수 있는 문제들에 대해 소개한다.

 

RNN은 sequential data를 잘 다루기 위해 고안되었다. sequential data는 데이터의 값뿐만 아니라 데이터의 순서도 중요한 의미를 가지는 데이터를 말한다. "데이터의 순서가 중요하다"라는 것은 예를 들어 "hello"라는 단어가 있을 때 각 알파벳의 순서는 단어의 의미를 형성하는데 굉장히 중요한 역할을 한다. 이와 같이 데이터의 순서도 그 데이터의 일부인 것을 sequential data라고 한다. 

neural network에서도 sequential data를 다룰 수 있는 방법이 있다. 입력되는 데이터 벡터가 있다고 할 때 메트릭스가 주어지고 연산이 반복되면서 학습이 이루어진다. 이때 데이터에 추가로 position index라는 것을 추가해 줄 수 있는데, position index는 이 벡터가 몇 번째 벡터인지 정보를 담고 있는 index가 된다. 그렇다면 메트릭스에도 position index를 처리하기 위한 dimension이 추가될 것이고, nn 입장에서는 학습하는 layer가 position index 차원으로 입력되는 위치 정보를 받아서 학습을 할 수 있게 된다. 하지만 이런 neural network의 연산 만으로는 사람의 언어와 같이 순서가 중요하면서도 복잡한 구조를 모델이 파악하는 것이 쉽지 않다. 그래서 RNN에서는 position index 대신에 "입력하는 데이터의 순서를 어떻게 하면 모델이 잘 이해할 수 있을까"를 중심으로 설계된다. 보통은 오른쪽 그림과 같이 나타내는데, A는 cell을 나타내며 이 cell에 t번째 입력값이 들어가게 되면 t번째 출력 값이 나오게 되고, 동시에 다른 출력 값이 나오게 되어 다시 cell A로 들어가는 구조이다. 

오른쪽 그림의 roop를 펼친 것이 왼쪽 그림이다. 첫 번째 입력값이 cell A에서 처리가 되고 한 가지가 h0로 출력되고 출력되지 않고 다른 가지로 다음 cell로 전달된다. 이렇게 출력되지 않고 다음 cell로 전달되는 가지를 hidden state라고 한다. hidden state라고 부르는 이유는 RNN 바깥에서는 보이지 않기 때문이다. 이렇게 설계하면 두 번째 세 번째 출력 값들은 이전의 처리 결과들을 반영할 수 있다. 그래서 모델이 데이터의 순서를 이해할 수 있게 된다. 예를 들어 "hello"를 입력받고 각 단어를 입력할 때마다 이 단어가 어떤 문자로 구성되는지 다음 문자를 예측하는 문제를 가정한다. "he"까지 입력된다면 그 시점에서 다음에 올 문자인 "l"을 출력하는 것이다. 각 문자가 들어올 시점부터 각 cell의 출력은 "e", "l", "l", "o" 그리고 "o"가 나오면 출력이 끝났다는 것 또한 알려줄 수 있다. 이 문제의 어려운 점 중 하나는 같은 입력 "l"을 받았을 때 어떤 때는 "l"을 출력하고 어떤 때는 "o"를 출력한다는 점이다. 근데 RNN에서는 이전 어떤 문자가 나왔다는 사실을 hidden state를 통해 전달받기 때문에 모델이 마치 순서를 이해한 것처럼 정확하게 다음 문자를 출력할 수 있다.

 

또한 RNN은 모든 cell이 파라미터를 공유한다. A가 하나라는 뜻으로, 긴 sequence가 들어와도 이를 처리하기 위한 cell은 A 하나이다. 즉, 입력되는 단어가 "hello"가 되던 또는 엄청 긴 단어가 들어온다 하더라도 cell A에 들어가는 파라미터만 알고 있으면 언제든 다음 단어를 예측하는 모델을 정상적으로 작동할 수 있다. 

cell A에서 일어나는 일은 cell을 어떻게 설계하냐에 따라 복잡도가 완전히 달라지는데, 기본적으로는 함수 연산이다. 이 함수는 기본적으로 이전 단계의 hidden state와 지금 단계에서의 입력값을 가지고 함수 연산을 통해 어떤 출력 값 ht를 만드는 것이다. 이러한 연산은 굉장히 많지만 예를 하나 들면 이전 스텝에서의 hidden state ht-1에 어떤 매트릭스 weight Wh를 곱해주고, 이번 스텝에서의 입력값 xt에 마찬가지로 매트릭스 weight Wx를 곱해주고 그 결과를 더한 다음 tanh이라는 activation funciton을 통해 출력한다. 이 외에도 설계할 수 있는 방법은 매우 다양하다. 일반적인 방법으로는 LSTM이나 GRU 등이 있다. cell A의 파라미터들은 학습의 대상이기 때문에 복잡해지면 복잡해질수록 이 셀이 학습되는 정도 trainability는 감소하는 것으로 알려져 있다. 다시 말해 복잡한 cell을 쓰면 같은 수준의 학습에서는 좋은 성능을 낼 수 있지만, 그 학습 수준에 도달하기까지에는 더 많은 자원이 필요하다는 것을 의미한다. 그리고 이 cell의 복잡도는 일반적인 RNN이 가장 낮고 LSTM이 가장 높고 GRU가 중간 정도이다. 

 

지금까지는 하나의 cell의 설계를 살펴봤다. 여기서부터는 RNN에 데이터를 어떤 형식으로 입력하고, 나온 출력 값들 중에서 무엇을 취하냐에 따라 RNN이 굉장히 다양한 task에 적용할 수 있다는 것을 알 수 있다. 먼저 가장 왼쪽에 있는 것은 RNN이 아닌, 일반적인 neural network이다. 이 외에 나머지가 RNN이다. 

 

1) one to many는 입력값은 하나이고 출력 값은 여러 개다. 하나의 입력에 대해서 여러 개가 출력되는 것인데, 예를 들면 하나의 이미지가 들어가고 여러 단어들의 sequence, 즉 문장이 나온다라고 할 수 있다. "이미지에서 어떤 일이 일어나고 있는지 자막 같은 것이 달린다"라고 볼 수 있고, image captioning task에 이런 형태로 RNN이 적용될 수 있다.

 

2) many to one은 여러 개의 입력값이 있고 출력 값이 하나인 형태이다. 위의 예를 이어 보면, 여러 개의 단어들이 입력되고, 즉 문장이 입력된다라고 가정할 수 있다. 문장이 입력되고 하나의 값이 나오는데, 이 값은 무엇이든 될 수 있는데 예를 들어 감정에 대한 label을 얘기할 수 있다. 감정 label을 출력하면 이 문장이 어떤 감정을 가지고 있다는 것을 분석해주는 sentiment analysis 같은 task에 이런 형태의 RNN을 적용할 수 있다.

 

3-1) many to many는 문장이 들어오고 문장이 출력되는 형태이다. many to many 두 개의 형태가 있다. 다른 지점이 출력되는 지점이 다르다. 그림의 왼쪽에서 앞의 두 칸과 뒤의 두 칸이 비어있는데, 사실 출력 값은 존재하지만, 쓰지 않는 것이다. 이 경우 문장이 입력으로 사용되는데 이 문장이 다 끝난 시점부터 다음 문장이 시작된다. 이런 경우 보통 번역 task에 잘 사용된다. 한 문장이 다 들어오고 이 문장을 다 본 다음에 새롭게 문장을 쓰는 시작하는 것이다. "한국말은 끝까지 들어봐야 안다"라는 말처럼 문장을 끝까지 봐야지만 제대로 된 번역이 이루어질 수 있기 때문에 이런 구조로 설계가 되어있는 RNN도 있다.

 

3-2) many to many는 그림의 마지막처럼 여러 개의 input이 있고, 입력이 들어갈 때마다 새로 output들이 나오는 형태 또한 가능하다. 이런 형태는 문장이 입력으로 주어지고 문장을 구성하는 각 단어에 대한 tagging 작업을 수행할 때 주로 사용하는 형태이다. 

'머신러닝(machine learning)' 카테고리의 다른 글

[Pytorch] 11-2 RNN hihello and charseq  (0) 2022.03.14
[Pytorch] 11-1 RNN basics  (0) 2022.03.03
[Pytorch] 09-4 Batch Normalization  (0) 2022.03.03
[Pytorch] 09-3 Dropout  (0) 2022.01.17
[Pytorch] 09-2 Weight initialization  (0) 2022.01.17
Comments