오늘은 Python으로 RNN을 복습해볼까 합니다. 인공지능을 공부하는 사람이라면 모를 수가 없는데요. 학부 시절 풀었던 과제를 바탕으로 포스팅하겠습니다.

이 내용은 광운대학교 전기공학과 인공지능응용 수업 과제입니다. 혹시라도 후배님들이 이 글을 보게 된다면 비밀로 하고 과제를 진행해주세요.

RNN (Recurrent Neural Network) 설명

1. 핵심 개념

RNN(Recurrent Neural Network)은 순차 데이터(sequence data)를 처리하기 위해 설계된 신경망입니다.

핵심 아이디어는 다음과 같습니다.

이전 정보를 기억하면서 현재 입력을 처리한다.

일반적인 신경망은 입력이 서로 독립적이라고 가정하지만,

RNN은 데이터 간의 순서와 시간적 관계를 고려합니다.


2. 왜 RNN이 필요한가

다음과 같은 데이터는 순서 정보가 중요합니다.

  • 문장 (자연어 처리)

  • 음성 신호

  • 시계열 데이터 (주가, 센서 데이터 등)

예를 들어 문장에서:

“I am happy” 와 “I am not happy”는

단어 “not” 하나로 의미가 완전히 달라집니다.

이처럼 이전 정보가 중요하기 때문에 RNN이 사용됩니다.


3. RNN의 기본 구조

RNN은 다음과 같은 구조를 가집니다.

[ h_t = f(W_x x_t + W_h h_{t-1} + b) ]

[ y_t = W_y h_t ]

  • ( x_t ): 현재 입력
  • ( h_t ): 현재 hidden state
  • ( h_{t-1} ): 이전 hidden state
  • ( y_t ): 출력
  • ( f ): 활성화 함수 (tanh, ReLU 등)

4. 동작 원리

RNN은 시간 흐름에 따라 반복적으로 계산됩니다.

단계별 과정

  1. 초기 hidden state 설정 (보통 0)

  2. 첫 번째 입력 ( x_1 ) 처리 → ( h_1 )

  3. 두 번째 입력 ( x_2 ) 처리 → ( h_2 )

  4. 이전 hidden state를 계속 전달

  5. 마지막까지 반복

즉, hidden state가 일종의 “기억 역할”을 수행합니다.


5. Unrolling (펼쳐보기)

RNN은 내부적으로 같은 구조가 시간축 방향으로 반복됩니다.

이를 “unrolling”이라고 하며, 다음과 같이 볼 수 있습니다.

[ h_1 \rightarrow h_2 \rightarrow h_3 \rightarrow \cdots \rightarrow h_T ]

각 단계에서 동일한 가중치가 사용됩니다.


6. 주요 특징

(1) Memory (기억)

이전 정보를 hidden state에 저장하여 활용합니다.


(2) Weight Sharing

모든 시간 단계에서 동일한 가중치를 사용합니다.
→ 파라미터 수 감소


(3) Sequence Modeling

입력의 순서와 시간적 의존성을 반영할 수 있습니다.


7. 한계점

(1) Vanishing Gradient

시간이 길어질수록 gradient가 작아져서 초기 정보가 사라지는 문제가 발생합니다.


(2) Long-term Dependency 문제

멀리 떨어진 정보 간의 관계를 학습하기 어렵습니다.


(3) 느린 연산 속도

순차적으로 계산되기 때문에 병렬 처리가 어렵습니다.


8. 개선된 모델

RNN의 한계를 해결하기 위해 다음과 같은 모델이 등장했습니다.

LSTM (Long Short-Term Memory)

  • 장기 의존성 문제 해결

  • Gate 구조 (input, forget, output)


GRU (Gated Recurrent Unit)

  • LSTM보다 단순한 구조

  • 비슷한 성능, 더 빠른 학습


9. 활용 분야

  • 자연어 처리 (번역, 감정 분석)

  • 음성 인식

  • 시계열 예측

  • 음악 생성


10. 정리

RNN은 순차 데이터에서 이전 정보를 활용하여 현재를 예측하는 신경망입니다.

하지만 긴 시퀀스에서는 한계가 있으며, 이를 해결하기 위해 LSTM, GRU, Transformer 등이 사용됩니다.

실습

다음 미완성 코드를 활용해 좀 더 긴 문장을 학습해보자

  • Sample sentences

    • “if you want to build a ship, don’t drum up people together to “

    • “collect wood and don’t assign them tasks and work, but rather “

    • “teach them to long for the endless immensity of the sea.”

  • Training data sentence

    • Shape: (N, S, E)
  • Hidden

    • Shape: (N, S, E *2)
  • Output

    • Shape: (N, S, E)
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# for reproducibility
torch.manual_seed(100)

# Dictionary
sample_sentence_1 = "if you want to build a ship, don't drum up people together to "
sample_sentence_2 = "collect wood and don't assign them tasks and work, but rather "
sample_sentence_3 = "teach them to long for the endless immensity of the sea."
sample_sentence = sample_sentence_1 + sample_sentence_2 + sample_sentence_3
char_set = list(set(sample_sentence))
dic = {c: i for i, c in enumerate(char_set)}

# Parameters
dic_size = len(dic)
input_size = dic_size
hidden_size = dic_size * 2
output_size = dic_size
unit_sequence_length = 20

input_batch = []
target_batch = []

for i in range(0, len(sample_sentence) - unit_sequence_length):
    input_sequence = sample_sentence[i:i + unit_sequence_length]
    target_sequence = sample_sentence[i + 1:i + unit_sequence_length + 1]
    input_batch.append([dic[char] for char in input_sequence])
    target_batch.append([dic[char] for char in target_sequence])

# input_batch와 target_batch 크기 확인
print(f"Input Batch Size: {len(input_batch)} x {len(input_batch[0])}")
print(f"Target Batch Size: {len(target_batch)} x {len(target_batch[0])}")

batch_size = len(input_batch)
sequence_length = len(input_batch[0])  # unit_sequence_length
input_size = len(dic)

# Create one-hot encoded input
X = torch.zeros(batch_size, sequence_length, input_size)  # Initialize with zeros
for i, seq in enumerate(input_batch):
    for j, char_index in enumerate(seq):
        X[i, j, char_index] = 1  # Set the corresponding element to 1

# Convert target to LongTensor
Y = torch.LongTensor(target_batch)  # (batch_size, sequence_length)

# Train and Evaluate
rnn_types = ["RNN", "LSTM", "GRU"]
learning_rate = 0.05
training_epochs = 100
criterion = nn.CrossEntropyLoss()


class Custom_RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=1, rnn_type="RNN"):
        super(Custom_RNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn_type = rnn_type

        if rnn_type == "RNN":
            self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        elif rnn_type == "LSTM":
            self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        elif rnn_type == "GRU":
            self.rnn = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
        else:
            raise ValueError("Invalid RNN type")

        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden_state):
        if self.rnn_type == "LSTM":
            h, c = hidden_state
            out, (h, c) = self.rnn(x, (h, c))
            hidden_state = (h, c)
        else:
            out, h = self.rnn(x, hidden_state)
            hidden_state = h

        out = self.fc(out)
        return out, hidden_state


# Initialize Model, Optimizer, Loss
model = Custom_RNN(input_size, hidden_size, output_size, num_layers=1, rnn_type="LSTM")
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()
# Training Loop
for rnn_type in rnn_types:
    print(f"\nTraining with {rnn_type}...")
    model = Custom_RNN(input_size, hidden_size, output_size, num_layers=1, rnn_type=rnn_type)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    hidden_state = None  # Initialize hidden state
    for epoch in range(training_epochs):
        optimizer.zero_grad()

        # Forward Pass
        if hidden_state is None:
            if rnn_type == "LSTM":
                hidden_state = (
                    torch.zeros(1, X.size(0), hidden_size),
                    torch.zeros(1, X.size(0), hidden_size),
                )
            else:
                hidden_state = torch.zeros(1, X.size(0), hidden_size)
        else:
            if rnn_type == "LSTM":
                hidden_state = (hidden_state[0].detach(), hidden_state[1].detach())
            else:
                hidden_state = hidden_state.detach()

        outputs, hidden_state = model(X, hidden_state)  # Use hidden_state for continuity

        # Backward and Optimize
        loss = criterion(outputs.view(-1, dic_size), Y.view(-1))
        loss.backward()
        optimizer.step()

        if epoch % 10 == 9:
            print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

    # Generate Text
    print("\nGenerating text...")
    with torch.no_grad():
        input_char = sample_sentence[:unit_sequence_length]
        input_data = [dic[char] for char in input_char]
        input_tensor = torch.zeros(1, unit_sequence_length, dic_size)
        for i, char_index in enumerate(input_data):
            input_tensor[0, i, char_index] = 1

        result = input_char
        hidden_state = None  # Initialize hidden_state
        for _ in range(len(sample_sentence) - unit_sequence_length):
            if hidden_state is None:  # Initialize hidden_state for the first time
                if rnn_type == "LSTM":
                    hidden_state = (
                        torch.zeros(1, 1, hidden_size),  # h_0
                        torch.zeros(1, 1, hidden_size),  # c_0
                    )
                else:
                    hidden_state = torch.zeros(1, 1, hidden_size)

            outputs, hidden_state = model(input_tensor, hidden_state)
            pred_char_index = outputs[0, -1, :].argmax().item()
            result += char_set[pred_char_index]

            # Update input_tensor
            input_tensor = torch.zeros(1, unit_sequence_length, dic_size)
            for i, char in enumerate(result[-unit_sequence_length:]):
                input_tensor[0, i, dic[char]] = 1

        print(f"Original: {sample_sentence}")
        print(f"Generated: {result}")

        # Accuracy
        correct = sum([1 for i in range(len(result)) if result[i] == sample_sentence[i]])
        accuracy = correct / len(result) * 100
        print(f"Accuracy: {accuracy:.2f}%")