http://aikorea.org/blog/rnn-tutorial-1/
RNN이란?
RNN에 대한 기본적인 아이디어는 순차적인 정보를 처리한다는 데 있다.
기존의 신경망 구조에서는 모든 입력(과 출력)이 각각 독립적이라고 가정했지만, 많은 경우에 이는 옳지 않은 방법이다.
한 예로, 문장에서 다음에 나올 단어를 추측하고 싶다면 이전에 나온 단어들을 아는 것이 큰 도움이 될 것이다.
RNN이 recurrent하다고 불리는 이유는 동일한 태스크를 한 시퀀스의 모든 요소마다 적용하고, 출력 결과는 이전의 계산 결과에 영향을 받기 때문이다. 다른 방식으로 생각해 보자면, RNN은 현재지 계산된 결과에 대한 "메모리" 정보를 갖고 있다고 볼 수도 있다. 이론적으로 RNN은 임의의 길이의 시퀀스 정보를 처리할 수 있지만, 실제로는 비교적 짧은 시퀀스만 효과적으로 처리할 수 있다
일반적인 RNN 구조는 다음과 같이 생겼다.
위 그림에서 RNN의 recurrent한 연결이 펼쳐진 것을 볼 수 있다.
RNN 네트워크를 "펼친다"는 말은 간단히 말해서 네트워크를 전체 시퀀스에 대해 그려놓았다고 보면 된다.
즉, 우리가 관심있는 시퀀스 정보가 5개의 단어로 이루어진 문장이라면, RNN 네트워크는 한 단어당 하나의 layer씩 (recurrent 연결이 없는, 또는 사이클이 없는) 5-layer 신경망 구조로 펼쳐질 것이다.
RNN 구조에서 일어나는 계산에 대한 식은 아래와 같다.
- 는 시간 스텝(time step) 에서의 입력값이다.
- 는 시간 스텝 에서의 hidden state이다.
- 는 시간 스텝 에서의 출력값이다.
- 예를 들어, 문장에서 다음 단어를 추측하고 싶다면 단어 수만큼의 차원의 확률 벡터가 될 것이다.
몇 가지 짚어두고 넘어갈 점이 있다.
- Hidden state 는 네트워크의 메모리라고 생각할 수 있다.
- 는 과거의 시간 스텝들에서 일어난 일들에 대한 정보를 전부 담고 있고,
- 출력값는 오로지 현재 시간 스텝의 메모리에만 의존한다.
- 하지만 위에서 잠깐 언급했듯이, 실제 구현에서는 너무 먼 과거에 일어난 일들은 잘 기억하지 못한다.
- 각 layer마다의 파라미터 값들이 전부 다 다른 기존의 deep한 신경망 구조와 달리, RNN은 모든 시간 스텝에 대해 파라미터 값을 전부 공유하고 있다 (위 그림의 U, V, W).
- 이는 RNN이 각 스텝마다 입력값만 다를 뿐 거의 똑같은 계산을 하고 있다는 것을 보여준다.
- 이는 학습해야 하는 파라미터 수를 많이 줄여준다.
- 위 다이어그램에서는 매 시간 스텝마다 출력값을 내지만, 문제에 따라 달라질 수도 있다.
- 예를 들어, 문장에서 긍정/부정적인 감정을 추측하고 싶다면 굳이 모든 단어 위치에 대해 추측값을 내지 않고 최종 추측값 하나만 내서 판단하는 것이 더 유용할 수도 있다.
- 마찬가지로, 입력값 역시 매 시간 스텝마다 꼭 다 필요한 것은 아니다.
- RNN에서의 핵심은 시퀀스 정보에 대해 어떠한 정보를 추출해 주는 hidden state이기 때문이다
https://ratsgo.github.io/natural%20language%20processing/2017/03/09/rnnlstm/
https://brunch.co.kr/@chris-song/9
시퀀스 길이에 관계없이 인풋과 아웃풋을 받아들일 수 있는 네트워크 구조이기 때문에 필요에 따라 다양하고 유연하게 구조를 만들 수 있다는 점이 RNN의 가장 큰 장점입니다.
RNN의 기본 구조는 위 그림과 같습니다.
녹색 박스는 히든 state를 의미합니다. 빨간 박스는 인풋x, 파란 박스는 아웃풋y입니다.
현재 상태의 히든 state ht는 직전 시점의 히든 state ht−1를 받아 갱신됩니다.
현재 상태의 아웃풋yt는ht를 전달받아 갱신되는 구조입니다.
수식에서도 알 수 있듯 히든 state의활성함수(activation function)은비선형 함수인하이퍼볼릭탄젠트(tanh)입니다.
RNN의 기본 구조
RNN의 기본 동작을 직관적으로 이해해 보기 위해 CS231n 강좌의 Kapathy갓파시가 든 예제를 가져와 봤습니다.
어떤 글자가 주어졌을 때 바로 다음 글자를 예측하는 character-level-model을 만든다고 칩시다.
예컨대 RNN 모델에 ‘hell’을 넣으면 ‘o’를 반환하게 해 결과적으로는 ‘hello’를 출력하게 만들고 싶은 겁니다.
우선 우리가 가진 학습데이터의 글자는 ‘h’, ‘e’, ‘l’, ‘o’ 네 개뿐입니다.
이를one-hot-vector로 바꾸면 각각[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]이 됩니다.
x1은[1,0,0,0]입니다.
이를 기반으로h1인[0.3,−0.1,0.9]를 만들었습니다(h0는 존재하지 않기 때문에 랜덤 값을 집어넣습니다).
이를 바탕으로y1인[1.0,2.2,−3.0,4.1]로 생성했습니다.
마찬가지로 두번째, 세번째, 네번째 단계들도 모두 갱신하게 됩니다.
이 과정을순전파(foward propagation)라고 부릅니다.
다른 인공신경망과 마찬가지로 RNN도 정답을 필요로 합니다.
모델에 정답을 알려줘야 모델이parameter를 적절히 갱신해 나가겠죠.
이 경우엔 바로 다음 글자가 정답이 되겠네요.
예컨대 ‘h’의 다음 정답은 ‘e’, ‘e’ 다음은 ‘l’, ‘l’ 다음은 ‘l’, ‘l’ 다음은 ‘o’가 정답입니다.
위의 그림을 기준으로 설명을 드리면 첫번째 정답인 ‘e’는 두번째 요소만 1이고 나머지가 0인 one-hot-vector입니다.
그림을 보면 아웃풋에 진한 녹색으로 표시된 숫자들이 있는데 정답에 해당하는 인덱스를 의미합니다.
이 정보를 바탕으로역전파(backpropagation)를 수행해 parameter값들을 갱신해 나갑니다.
그렇다면 RNN이 학습하는 parameter는 무엇일까요?
인풋x를 히든레이어h로 보내는W_xh,
이전 히든레이어hh에서 다음 히든레이어h로 보내는W_hh,
히든레이어h에서 아웃풋y로 보내는W_hy가 바로 parameter입니다.
그리고 모든 시점의 state에서 이 parameter는 동일하게 적용됩니다(shared weights).
RNN의 순전파
앞장에서 말씀드린 RNN의 기본 구조를 토대로 forward compute pass를 아래와 같이 그려봤습니다. 위에서 설명한 수식을 그래프로 옮겨놓은 것일 뿐입니다.
RNN의 역전파
자, 이제 backward pass를 볼까요? 아래 그림과 같습니다. 혹시 역전파가 생소하신 분은이곳을 참고하시기 바랍니다.
위 움짤과 아래 그림은 같은 내용입니다. 우선 forward pass를 따라 최종 출력되는 결과는yt입니다.
최종 Loss에 대한yt의 그래디언트(dL/dyt)가 RNN의 역전파 연산에서 가장 먼저 등장합니다.
이를 편의상dyt라고 표기했고, 순전파 결과yt와 대비해 붉은색으로 표시했습니다. 앞으로 이 표기를 따를 예정입니다.
dyt는 덧셈 그래프를 타고 양방향에 모두 그대로 분배가 됩니다.
dWhy= ht X dyt
dht는 흘러들어온 그래디언트dyt에Why를 곱한 값입니다.
dhraw는 흘러들어온 그래디언트인dht에 로컬 그래디언트인1−tanh2(ht)을 곱해 구합니다. 나머지도 동일한 방식으로 구합니다.
다만 아래 그림에 주의할 필요가 있습니다.
RNN은 히든 노드가 순환 구조를 띄는 신경망입니다.
즉ht를 만들 때ht−1가 반영됩니다.
바꿔 말하면 아래 그림의dht−1은 t-1 시점의 Loss에서 흘러들어온 그래디언트인Why∗dyt−1뿐 아니라 ★에 해당하는 그래디언트 또한 더해져 동시에 반영된다는 뜻입니다.
LSTM의 기본 구조
RNN은 관련 정보와 그 정보를 사용하는 지점 사이 거리가 멀 경우 역전파시 그래디언트가 점차 줄어 학습능력이 크게 저하되는 것으로 알려져 있습니다. 이를vanishing gradient problem이라고 합니다.
이 문제를 극복하기 위해서 고안된 것이 바로 LSTM입니다.
LSTM은 RNN의 히든 state에 cell-state를 추가한 구조입니다.
LSTM을 가장 쉽게 시각화한포스트를 기본으로 해서 설명을 이어나가겠습니다.
cell state는 일종의 컨베이어 벨트 역할을 합니다. 덕분에 state가 꽤 오래 경과하더라도 그래디언트가 비교적 전파가 잘 되게 됩니다. LSTM 셀의 수식은 아래와 같습니다. ⊙는 요소별 곱셈을 뜻하는 Hadamard product 연산자입니다.
$$ftitotgtctht=σ(Wxh_fxt+Whh_fht−1+bh_f)=σ(Wxh_ixt+Whh_iht−1+bh_i)=σ(Wxh_oxt+Whh_oht−1+bh_o)=tanh(Wxh_gxt+Whh_ght−1+bh_g)=ft⊙ct−1+it⊙gt=ot⊙tanh(ct)$$
forget gateftft는 ‘과거 정보를 잊기’를 위한 게이트입니다.ht−1ht−1과xtxt를 받아 시그모이드를 취해준 값이 바로 forget gate가 내보내는 값이 됩니다. 시그모이드 함수의 출력 범위는 0에서 1 사이이기 때문에 그 값이 0이라면 이전 상태의 정보는 잊고, 1이라면 이전 상태의 정보를 온전히 기억하게 됩니다.
input gateit⊙gtit⊙gt는 ‘현재 정보를 기억하기’ 위한 게이트입니다.ht−1ht−1과xtxt를 받아 시그모이드를 취하고, 또 같은 입력으로 하이퍼볼릭탄젠트를 취해준 다음 Hadamard product 연산을 한 값이 바로 input gate가 내보내는 값이 됩니다. 개인적으로itit의 범위는 0~1,gtgt의 범위는 -1~1이기 때문에 각각 강도와 방향을 나타낸다고 이해했습니다.
LSTM의 순전파
LSTM 순전파는 아래와 같습니다.
여기서 주목해야 할 점은HtHt입니다. 이 행렬을 행 기준으로 4등분해i,f,o,gi,f,o,g각각에 해당하는 활성함수를 적용하는 방식으로i,f,o,gi,f,o,g를 계산합니다. (물론 이렇게 계산하지 않고 다른 방식을 써도 관계는 없습니다) 이를 그림으로 나타내면 다음과 같습니다.
LSTM의 역전파
그럼 이제 LSTM의 역전파를 알아볼까요? 아래 움짤과 같습니다.
이제부터 나열한 그림은 위 움짤과 내용이 같습니다. 우선dft,dit,dgt,dotdft,dit,dgt,dot를 구하기까지 backward pass는 RNN과 유사합니다.
dHtdHt를 구하는 과정이 LSTM backward pass 핵심이라고 할 수 있죠.HtHt는it,ft,ot,gtit,ft,ot,gt로 구성된 행렬입니다. 바꿔 말하면 각각에 해당하는 그래디언트를 이를 합치면(merge)dHtdHt를 만들 수 있다는 뜻입니다.i,f,oi,f,o의 활성함수는 시그모이드이고,gg만 하이퍼볼릭탄젠트입니다. 각각의 활성함수에 대한 로컬 그래디언트를 구해 흘러들어온 그래디언트를 곱해주면 됩니다.
순전파 과정에서HtHt를 4등분해it,ft,ot,gtit,ft,ot,gt를 구했던 것처럼, backward pass에서는di,df,do,dgdi,df,do,dg를 다시 합쳐dHtdHt를 만듭니다. 이렇게 구한dHtdHt는 다시 RNN과 같은 방식으로 역전파가 되는 구조입니다.
LSTM은 cell state와 히든 state가 재귀적으로 구해지는 네트워크입니다. 따라서 cell state의 그래디언트와 히든 state의 그래디언트는 직전 시점의 그래디언트 값에 영향을 받습니다. 이는 RNN과 마찬가지입니다. 이를 역전파시 반영해야 합니다.