GAN은 Generative Adversarial Networks라는 의미로서, 한국어로 번역하면 '생성적 적대 신경망' 정도로 번역할 수 있겠다.
현재로서 활발히 많이 이용되는 알고리즘은 대부분이 '지도학습' 부류로 CNN을 활용한 Object Detection, LSTM 및 Transformer 를 활용한 번역문제가 있고, 강화학습의 경우 자율주행 분야 및 로봇 제어 분야에서 사용되고 있으나, 유독 '비지도 학습' 의 경우 마땅히 시각적으로 보여줄만한 결과가 없었다. 하지만 2014년 GAN 모델이 출시되고 나서는 얘기가 좀 달라졌는데, Stable Diffusion 을 비롯한 GAN 의 자식 모델들이 '이미지 생성' 영역에서 큰 역할을 하고 있기 때문이다.
위 이미지들은 모두 stability.ai 에서 만든 Stable Diffusion 모델로 제작한 그림이다.
각각 이미지들을 살펴보면 몇몇 AI그림임을 추측할 수 있는 부분이 존재한다.
- stable difusion의 고질적 문제로 손가락이 6개이며, 엄지손가락의 길이도 너무 길다.
- 햇빛에 비친 옷의 질감이 이질적이며, 제복임에도 불구하고 단추마다 모양이 다르다.
- 금속에서 반사되는 빛이 부자연스러우며, 허벅지 흰색부분이 대칭이 아닌점이 부자연스럽다.
- 오른쪽 귀가 부자연스럽게 잘린것처럼 보이며, 피부 질감이 묘하게 이질적이다.
이러한 AI 그림들에 관해서 여러가지 할 이야기가 있지만, 해당 포스팅의 주제를 벗어나므로 추후 기회가 된다면 이야기해 보고자 한다.
중요한 점은 이러한 생성 모델들의 중심에 GAN이 존재한다는 점이다.
이번 포스팅은 책 GAN 첫걸음을 읽으며 정리한 GAN 모델의 원리와, 손글씨 숫자 데이터를 생성해 보고자 한다.
해당 포스팅에서 수식이나, 어려운 약어의 사용은 지양하고자 한다.
1. GAN 모델의 시작
아마 모든 AI 연구자, 학생들은 MNIST 분류기를 통해서 28 x 28 픽셀의 손글씨를 인식하는 CNN 모델을 제작해 본 경험이 있을것이다.
여기서 우리는 CNN 모델을 활용하여 숫자가 쓰여진 28 x 28 픽셀의 형상을 보곤 어떤 숫자인지 맞추는 인공지능을 제작하였다. 위 그림의 정답은 순서대로 "5,0,4,1,9,2,1,3,1" 이 될 것이다.
여기서 재미있는 발상이 떠오른다. 그걸 '반대로' 할 순 없을까?
이미지를 토대로 → 숫자를 추출하는게 가능하다면
숫자를 토대로 → 이미지를 생성할 수는 없을까?
GAN은 이 발상부터 시작한다.
2. GAN 모델이란?
위에서 숫자를 → 이미지로 생성하기 위해서는 그냥 일반적인 CNN 학습을 수행해도 어느정도 가능할 테지만, 여러 문제가 존재한다.
- 위 그림과 몇 픽셀이 다른 이미지가 생성되면 해당 이미지는 '4' 가 아닌가?
- '4'가 맞다면 어떤 근거로 해당 이미지를 '4'로 판단할 수 있을것인가?
- '4'가 아니라면 어떤 근거로 해당 이미지를 '4'가 아니라고 판단할 수 있을것인가?
- 위의 모든 과정, 생성-판단에서 사람이 개입해서는 안된다.
위와같은 문제들을 해결하고자 2014년 이언 굿펠로는 GAN 을 제안한다. 해당 논문에서는 '범죄자'와 '범죄를 막는 경찰' 에 빗대어 GAN을 설명한다.
해당 포스팅은 논문 리뷰가 아니기에 짧게 짚고 넘어가자면, "생성적 적대 신경망" 이라는 말을 다시 상기해 보자. 해당 모델에서는 '판별기', '생성기' 두 가지 모델이 서로 적대적으로 싸운다. '생성기' 는 임의로 데이터를 생성한다, 그리고 '판별기' 는 해당 모델이 실제로 제공된 데이터인지, 아니면 '생성기' 가 임의적으로 생성한 데이터인지를 판별한다. 따라서 아래와 같은 성격을 갖는다.
- 생성기(Generator) : 판별기를 속여 생성기가 생성한 데이터가 '참' 이라는 판정을 얻어야 한다.
- 판별기(Discriminator) : 생성기가 생성한 데이터를 걸러내야 한다.
여기서 당연하게도 생성기는 판별기로부터 '참' 이라는 결과를 얻으려면 올바른 데이터를 생성하여야 한다. 즉, 진짜 사람이 쓴것과 유사한 이미지 데이터를 생성해 내야만 한다.
판별기는 생성기가 생성한 데이터를 걸러내어 'False' 판정을 내려주고, 원본 데이터에는 'True' 판정을 내려주어야만 한다.여기서 매력적인 부분은, 판별기가 훈련을 거치며 점점 성능이 좋아질수록, 생성기 또한 보상과 벌을 통해 훈련이 되어 성능이 증가할 것이라는 점이다. 궁극적으로 생성기는 진짜 이미지와 분간이 가지 않는 이미지들을 만들기 시작할 것이다.
판별기와 생성기는 서로 적대적 관계로 경쟁을 하며 서로를 뛰어 넘으려고 노력하기에, 결과적으로는 둘 다 성능이 좋아지게 되고, 이것이 바로 생성적 적대 신경망(GAN) 이다.
3. 코드를 통해 GAN의 동작원리를 되새겨 보자.
Source Code - Import
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import pandas, numpy, random
import matplotlib.pyplot as plt
Source Code - Load MNIST Data
class MnistDataset(Dataset):
def __init__(self, csv_file):
self.data_df = pandas.read_csv(csv_file, header=None)
def __len__(self):
return len(self.data_df)
def __getitem__(self, index):
# image target (label)
label = self.data_df.iloc[index,0]
target = torch.zeros((10))
target[label] = 1.0
# image data, normalised from 0-255 to 0-1
image_values = torch.FloatTensor(self.data_df.iloc[index,1:].values) / 255.0
# return label, image data tensor and target tensor
return label, image_values, target
def plot_image(self, index):
img = self.data_df.iloc[index,1:].values.reshape(28,28)
plt.title("label = " + str(self.data_df.iloc[index,0]))
plt.imshow(img, interpolation='none', cmap='Blues')
mnist_dataset = MnistDataset('mount/mnist_train.csv')
mnist_dataset.plot_image(17)
https://pjreddie.com/media/files/mnist_train.csv - 훈련용 데이터
https://pjreddie.com/media/files/mnist_test.csv - 테스트용 데이터
간단히 MNIST 손글씨 데이터를 불러오는 코드이다. 생성자로 csv_file 의 Path 를 전달받아 읽어낸뒤 mnist_dataset 인스턴스로 지정한다.
그 후 해당 데이터셋에서 17번행의 데이터를 불러와 가시화한다.
Source Code - Data Function
def generate_random_image(size):
random_data = torch.rand(size)
return random_data
def generate_random_seed(size):
random_data = torch.randn(size)
return random_data
torch.rand() = 그냥 고르게 랜덤.
torch.randn() = 정규분포, 분산제한을 걸어둔 random function
Source Code - Discriminator
class Discriminator(nn.Module):
def __init__(self):
# initialise parent pytorch class
super().__init__()
# define neural network layers
self.model = nn.Sequential(
nn.Linear(784, 200),
nn.LeakyReLU(0.02),
nn.LayerNorm(200),
nn.Linear(200, 1),
nn.Sigmoid()
)
# create loss function
self.loss_function = nn.BCELoss()
# create optimiser, simple stochastic gradient descent
self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)
# counter and accumulator for progress
self.counter = 0;
self.progress = []
def forward(self, inputs):
# simply run model
return self.model(inputs)
def train(self, inputs, targets):
# calculate the output of the network
outputs = self.forward(inputs)
# calculate loss
loss = self.loss_function(outputs, targets)
# increase counter and accumulate error every 10
self.counter += 1;
if (self.counter % 10 == 0):
self.progress.append(loss.item())
pass
if (self.counter % 10000 == 0):
print("counter = ", self.counter)
# zero gradients, perform a backward pass, update weights
self.optimiser.zero_grad()
loss.backward()
self.optimiser.step()
def plot_progress(self):
df = pandas.DataFrame(self.progress, columns=['loss'])
df.plot(ylim=(0), figsize=(16,8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5, 1.0, 5.0))
판별기의 코드는 비교적 간단하다. 28 x 28 = 784 개의 입력을 받아 최종적으로 단일 출력을 내보낸다.(여기서 단일 출력은 0 ~ 1 사이의 소수이다.)
Optimizer 로는 Adam 을 사용하고, Loss Function 은 이진 교차 엔트로피를 사용한다(BCELoss)
Source Code - Generator
class Generator(nn.Module):
def __init__(self):
# initialise parent pytorch class
super().__init__()
# define neural network layers
self.model = nn.Sequential(
nn.Linear(100, 200),
nn.LeakyReLU(0.02),
nn.LayerNorm(200),
nn.Linear(200, 784),
nn.Sigmoid()
)
# create optimiser, simple stochastic gradient descent
self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)
# counter and accumulator for progress
self.counter = 0;
self.progress = []
def forward(self, inputs):
# simply run model
return self.model(inputs)
def train(self, D, inputs, targets):
# calculate the output of the network
g_output = self.forward(inputs)
# pass onto Discriminator
d_output = D.forward(g_output)
# calculate error
loss = D.loss_function(d_output, targets)
# increase counter and accumulate error every 10
self.counter += 1;
if (self.counter % 10 == 0):
self.progress.append(loss.item())
pass
# zero gradients, perform a backward pass, update weights
self.optimiser.zero_grad()
loss.backward()
self.optimiser.step()
def plot_progress(self):
df = pandas.DataFrame(self.progress, columns=['loss'])
df.plot(ylim=(0), figsize=(16,8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5, 1.0, 5.0))
생성기의 코드는 판별기에 비해서 조금 복잡하다. 100개의 입력을 받아 784개의 입력으로 출력한다 (이미지로 변환)
더하여 생성자 부분을 살펴보면 loss function 를 설정하지 않는데, 이유는 당연히 '판별기' 로부터 loss 를 전달받아야 하기 때문이다. 때문에 Train 부분에 판별기 D 를 입력받아 해당 판별기로부터, 직접 loss 를 전달받는다. 이 때 targets 변수는 1.0 으로, 반드시 참이 나와야 올바른 생성기 이기에 생성기가 출력한 loss와, 1.0 값간의 상호 비교를 통한 loss 값을 전달받아 생성기를 갱신하는 원리이다.
참고로 왜 단일 입력 (0~9) 가 아닌 100개의 입력인가 하면 이는 모드붕괴에 따른 문제로 단 하나의 입력값 만으로 신경망이 784개의 픽셀을 온전히 만드는 것이 힘들기 때문이다. 때문에 100개의 입력으로 그 입력을 증가시키고, 그 값들은 generate_ramdom_seed() 를 통해 생성된다.
Source Code - Test Generator
G = Generator()
output = G.forward(generate_random_seed(100))
img = output.detach().numpy().reshape(28,28)
plt.imshow(img, interpolation='none', cmap='Blues')
Source Code - TRAIN GAN
%%time
# create Discriminator and Generator
D = Discriminator()
G = Generator()
epochs = 4
for epoch in range(epochs):
print ("epoch = ", epoch + 1)
# train Discriminator and Generator
for label, image_data_tensor, target_tensor in mnist_dataset:
# train discriminator on true
D.train(image_data_tensor, torch.FloatTensor([1.0]))
# train discriminator on false
# use detach() so gradients in G are not calculated
D.train(G.forward(generate_random_seed(100)).detach(), torch.FloatTensor([0.0]))
# train generator
G.train(D, generate_random_seed(100), torch.FloatTensor([1.0]))
드디어 Generator, Discriminator 를 학습시킨다.
첫 D.train() 를 호출할 때는 실제 이미지 데이터와, 해당 이미지 데이터가 올바르다는 의미로 [1.0] 을 target 으로 삽입한다.
두번째 D.train() 을 호출할 때는 Generator 에서 생성된 데이터를 삽입하고, 해당 데이터가 잘못되었다는 의미로 [0.0] 을 삽입한다.
세번째 G.Train은 D(loss를 추가하기 위해) Random Seed 값 [1.0] 을 target 으로 넣어준다. 이는 Generator 가 판별기로부터 평가받아야 하는 값이 1.0 이기 때문이다.
실제 generator 를 보면 1.0 값과 판별기에서 도출된 d_output 을 가지고 loss function 을 계산한다.
Output - TRAIN GAN
epoch = 1
counter = 10000
counter = 20000
counter = 30000
counter = 40000
counter = 50000
counter = 60000
counter = 70000
counter = 80000
counter = 90000
counter = 100000
counter = 110000
counter = 120000
epoch = 2
counter = 130000
counter = 140000
counter = 150000
counter = 160000
counter = 170000
counter = 180000
counter = 190000
counter = 200000
counter = 210000
counter = 220000
counter = 230000
...
counter = 470000
counter = 480000
CPU times: user 1h 22min 15s, sys: 5.23 s, total: 1h 22min 20s
Wall time: 13min 43s
Source Code - Discriminator Loss
D.plot_progress()
판별기의 loss 값은 빠르게 0으로 수렴, 유지되지만 때때로 점프가 발생하는것이 눈의 띄인다. 이는 판별,생성기 간의 균형이 맞지 않는다는 의미이다.
Source Code - Generator Loss
G.plot_progress()
생성기의 loss 값은 처음에 튀어오르는 것을 볼 수 있는데, 이는 생성기가 판별기에 비해 초반부 성능이 뒤져진다는 의미이다. 손실이 떨어진 이후에는 3 근처로 머무는 모습을 볼 수 있다.
참고로 BCELoss() 의 최댓값은 1.0 으로 제한되어 있지 않으며, BCELoss() 를 기준으로 생성기와 판별기가 모두 0.69 값에 수렴하는 값을 나타내야만 어느 하나가 적을 압살하지 않고, 올바르게 겨루며 성장하고 있다는 의미이다. (이러한 수렴값은 loss 함수마다 다르다.)
Source Code - Run Generator
f, axarr = plt.subplots(2,3, figsize=(16,8))
for i in range(2):
for j in range(3):
output = G.forward(generate_random_seed(100))
img = output.detach().numpy().reshape(28,28)
axarr[i,j].imshow(img, interpolation='none', cmap='Blues')
pass
pass
위와 같이 학습을 진행하고 나서 이미지를 살펴보면 나름 괜찮은 이미지들을 생성해 낸 것을 확인할 수 있다.
마치며
최근에 들어서 생성형 인공지능 AI가 대두되고 있고 특히 이미지 제작 같은 영역에서 대부분의 모델들은 GAN 을 뿌리로 두고 있다. 이러한 GAN은 사실 기본 원리가 LSTM이나, Transformer 와 같은 Recurrency 계열 모델들보다 쉽다고 생각하기에 쉽게 배울 수 있을것이다. (다만 기본 원리만 쉽지 학습시키기엔 훨신 빡세다)
그리고 지금까지 맛본 봐와 같이 순수 GAN은 성능이 꽤 떨어진다. 이러한 성능 문제를 개선하기 위해서 책에서도 여러가지 기술들에 대해서 언급 해 주고있으며, 웹상을 뒤져보면 여러가지 최적화 기법들이 많이 있다. 다만 이번 포스팅의 목적이 GAN이 무엇인지 소개하는데에 있으므로 여기까지 작성하도록 하겠다.
'Artificial Intelligence > Basic' 카테고리의 다른 글
GAN의 최적 손실값 (1) | 2023.06.19 |
---|---|
CGAN(Conditional GAN) (0) | 2023.06.18 |
벨만 방정식(Bellman Equation) (0) | 2023.05.29 |
마르코프 결정 프로세스(Markov Decision Process) (0) | 2023.05.28 |
어텐션이란? (0) | 2023.04.14 |