머신러닝은 크게 3가지 분류로 이루어진다. (지도학습, 강화학습, 비지도학습)
GAN은 '비지도 학습(Unsupervised Learning)'에 해당
Yann Lecun 교수 "근 10년 안에 나온 딥러닝 아이디어 중 최고의 생성 모델"
개념
Adversarial : 대립하는 (즉, 크게 두 부분으로 나누어져 있다는 것을 알 수 있음)
Image를 만들어내는 것(Generator - 생성자)
만들어진걸 평가하는 것(Discriminator - 구분자)
이 두개가 서로 대립하며, 서로의 성능을 점차 개선해 나가는 것이 주요 개념!
개념을 쉽게 예시로 들면?
지폐위조범(Generator)은 경찰을 최대한 열심히 속이려고 한다.
경찰(Discriminator)은 이렇게 위조된 지폐를 진짜와 감별하려고(Classify) 노력한다.이러한 경쟁 속에서 두 그룹 모두 속이고 구별하는 서로의 능력이 발전하고,결과적으로는 진짜 지폐와 위조 지폐를 구별할 수 없을 정도(구별할 확률 pd=0.5)에 이르게 된다는 것!
예시를 수학적 용어로 풀면?
Generative model 'G'는 우리가 갖고 있는 data x의 distribution을 알려고 노력함
만약, G가 정확히 data distribution을 묘사할 수 있다면, 거기서 뽑은 샘플은 data와 구별이 불가능할 것이다.
(진짜인지 가짜인지 구별할 수 없다.)
Discriminator model 'D'는 현재 자기가 보고 있는 샘플이 training data에서 온 것(진짜 데이터)인지,
아니면 'G'로부터 만들어진 것인지 구별해서 각각의 경우에 대한 확률을 평가한다.
D : 실제 데이터와 G가 생성한 가짜 데이터를 구별하도록 학습
궁극적인 목적
실제 데이터의 분포에 가까운 데이터를 생성하는 것!
생성자(G)는 구분자(D)가 거짓으로 판별하지 못하도록 가짜 데이터를 생성
이 과정을 통해 두 모델의 성능이 올라가 최종적으로 구분자(D)가 실제 데이터와 가짜 데이터를 구분하지 못하게 만드는 것이 목표다.
생성자(G)는 원 데이터의 확률 distribution(분포)를 알아내려고 노력하며 이 분포를 재현하여 차이가 없도록 하고, 구분자(D)는 구분 대상인 데이터가 실 데이터인지, G가 만들어낸 데이터인지 구별해서 각각에 대한 확률을 추정한다.
- GAN 수식
엄청나게 복잡해 보인다ㅠ 말로 풀어서 이해해보자.
x~p data(x)는 실제 데이터의 확률분포이고 x는 그 중 샘플링한 데이터다.구분자인 D는 출력이 실제 데이터가 들어오면 1에 가깝게 확률을 추정하고, G가 만들어낸 가짜 데이터가 들어오면 0에 가깝게 한다.Log를 사용했기 때문에 실제 데이터면 log 1, 즉 최댓값인 0에 가까운 값이 나오며 가짜 데이터면 -∞ 로 발산해서 V(D,G)를 최대화 하는 방향으로 학습하게 된다.
z~pz(z)는 보통 정규분포로 사용하는 임의의 노이즈 분포이고, z는 노이즈 분포에서 샘플링한 임의의 코드다.이 입력을 생성자 G에 넣어 만든 가짜 데이터를 구분자 D가 속아서 진짜로 판별된다면, log(1-D(G(z))) 식에서 D(G(z)) = 1 값이 들어가 log 0이 되어 -∞로 발산한다.이에 반해 구분자 D를 속이지 못하면, D(G(z)) = 0 값이 들어가 log 1이 되어 0에 가까운 최댓값이 나오게 된다. 따라서 G는 V(D,G)를 최소화하는 방향으로 학습하게 된다.
G는 이를 '최소화'하는 방향으로 가고 D는 '최대화'하는 방향으로 가게하는 minimax Problem을 보여주고 있다.
- GAN 학습
GAN의 G와 D는 어떻게 학습을 진행할까?
'확률분포'에 집중해야 한다. D는 G와 기존 확률분포가 얼마나 다른지 판별한다.
그리고 G는 기존 확률분포에 맞춰 D를 소깅기 위해 생성 모델을 수정해나간다.
이 과정은, V(D,G)를 minimax하는 과정이 G가 만드는 확률분포와 원 데이터의 확률분포 차이를 줄여나가는 과정을 보여준다.
검은 점선 : data generating distribution
파란 점선 : discriminator distribution
녹색 선 : generative distribution
이 상태에서 구분자(D)를 통해 두 distribution을 구별하기 위한 학습을 시키면, (b)와 같이 더 부드러운 distribution이 만들어진다. 이후에 G가 현재 D가 구별하기 어려운 방향으로 학습을 하게 되면, (c)처럼 더 가까워지고 이를 계속 반복 학습하면 결국 검은 점선(data)와 녹색 선(g)이 같게 되어 구분자(D)로 구별하지 못하는 D(x) = 0.5 상태(파란 점선)가 되는 것이다.
여기서, 확률분포간의 차이를 계산하기 위해서 'JSD'를 사용한다. JSD는 두 확률 분포사이의 차이를 계산하는 것이다. (이 식을 사용하기 위해선 KLD를 알고 있어야 함)
JSD는 두개의 KLD를 통해 이루어짐
P : 원 확률분포
Q : G 확률분포
M : 원 확률분포와 G 확률분포의 평균
P와 M과 Q와 M을 각각 KLD하고 평균 값을 구해서 두 확률분포 간의 차이를 구하는 것
이러한 발산 과정을 통해서, 원 데이터의 확률분포와 G가 생성해낸 확률분포 간의 JSD가 0이 되면 두 분포간의 차이가 없다는 것으로 학습이 완료되었다는 것을 나타낸다.
즉, 생성자 G로 뽑아낸 샘플을 구분자 D가 실제와 구별할 수 없게 된다는 뜻!
그렇다면, 우리는 어떤 모델과 알고리즘을 사용해야 이 문제를 잘 풀 수 있을까?
'딥러닝(Deep-Learning)' 카테고리의 다른 글
모델에서 이루어지는 '딥러닝' (0) | 2018.07.03 |
---|---|
모델에 데이터를 학습시키는 과정(MNIST) (0) | 2018.07.03 |
[딥러닝] Backpropagation & L2 Regularization 정리 (0) | 2018.05.31 |
[딥러닝] 머신러닝 & Regression 정리 (0) | 2018.05.28 |
[딥러닝] Linear Regression 코드 실행 및 코드 분석 (0) | 2018.05.14 |