GAN (Generative Adversarial Network) 정리
딥러닝(Deep-Learning)

GAN (Generative Adversarial Network) 정리

반응형

GAN (Generative Adversarial Network)

머신러닝은 크게 3가지 분류로 이루어진다. (지도학습, 강화학습, 비지도학습)

GAN은 '비지도 학습(Unsupervised Learning)'에 해당

Yann Lecun 교수 "근 10년 안에 나온 딥러닝 아이디어 중 최고의 생성 모델"



개념

Adversarial : 대립하는 (즉, 크게 두 부분으로 나누어져 있다는 것을 알 수 있음)

Image를 만들어내는 것(Generator - 생성자)

만들어진걸 평가하는 것(Discriminator - 구분자)

이 두개가 서로 대립하며, 서로의 성능을 점차 개선해 나가는 것이 주요 개념!




개념을 쉽게 예시로 들면?

지폐위조범(Generator)은 경찰을 최대한 열심히 속이려고 한다.

경찰(Discriminator)은 이렇게 위조된 지폐를 진짜와 감별하려고(Classify) 노력한다.

이러한 경쟁 속에서 두 그룹 모두 속이고 구별하는 서로의 능력이 발전하고, 
결과적으로는 진짜 지폐와 위조 지폐를 구별할 수 없을 정도(구별할 확률 pd=0.5)에 이르게 된다는 것!



예시를 수학적 용어로 풀면?

Generative model 'G'는 우리가 갖고 있는 data x의 distribution을 알려고 노력함


만약, G가 정확히 data distribution을 묘사할 수 있다면, 거기서 뽑은 샘플은 data와 구별이 불가능할 것이다. 

(진짜인지 가짜인지 구별할 수 없다.)


Discriminator model 'D'는 현재 자기가 보고 있는 샘플이 training data에서 온 것(진짜 데이터)인지, 

아니면 'G'로부터 만들어진 것인지 구별해서 각각의 경우에 대한 확률을 평가한다.




G : 생성된 z를 받아서 실제 데이터와 비슷한 데이터를 만들어내도록 학습

D : 실제 데이터와 G가 생성한 가짜 데이터를 구별하도록 학습



궁극적인 목적

실제 데이터의 분포에 가까운 데이터를 생성하는 것!

생성자(G)는 구분자(D)가 거짓으로 판별하지 못하도록 가짜 데이터를 생성

이 과정을 통해 두 모델의 성능이 올라가 최종적으로 구분자(D)가 실제 데이터와 가짜 데이터를 구분하지 못하게 만드는 것이 목표다.



생성자(G)는 원 데이터의 확률 distribution(분포)를 알아내려고 노력하며 이 분포를 재현하여 차이가 없도록 하고, 구분자(D)는 구분 대상인 데이터가 실 데이터인지, G가 만들어낸 데이터인지 구별해서 각각에 대한 확률을 추정한다.




  • GAN 수식



엄청나게 복잡해 보인다ㅠ 말로 풀어서 이해해보자.

  • D가 V(D,G)를 최대화 하는 관점

x~p data(x)는 실제 데이터의 확률분포이고 x는 그 중 샘플링한 데이터다.
구분자인 D는 출력이 실제 데이터가 들어오면 1에 가깝게 확률을 추정하고, G가 만들어낸 가짜 데이터가 들어오면 0에 가깝게 한다.
Log를 사용했기 때문에 실제 데이터면 log 1, 즉 최댓값인 0에 가까운 값이 나오며 가짜 데이터면 -∞ 로 발산해서 V(D,G)를 최대화 하는 방향으로 학습하게 된다.


  • G가 V(D,G)를 최소화 하는 관점
z~pz(z)는 보통 정규분포로 사용하는 임의의 노이즈 분포이고, z는 노이즈 분포에서 샘플링한 임의의 코드다.
이 입력을 생성자 G에 넣어 만든 가짜 데이터를 구분자 D가 속아서 진짜로 판별된다면, log(1-D(G(z))) 식에서 D(G(z)) = 1 값이 들어가 log 0이 되어 -∞로 발산한다.
이에 반해 구분자 D를 속이지 못하면, D(G(z)) = 0 값이 들어가 log 1이 되어 0에 가까운 최댓값이 나오게 된다. 따라서 G는 V(D,G)를 최소화하는 방향으로 학습하게 된다.


즉, V(D,G)에 있어서 G는 이를 '최소화'하는 방향으로 가고 D는 '최대화'하는 방향으로 가게하는 minimax Problem을 보여주고 있다.





  • GAN 학습

GAN의 G와 D는 어떻게 학습을 진행할까?

'확률분포'에 집중해야 한다. D는 G와 기존 확률분포가 얼마나 다른지 판별한다.

그리고 G는 기존 확률분포에 맞춰 D를 소깅기 위해 생성 모델을 수정해나간다.




이 과정은, V(D,G)를 minimax하는 과정이 G가 만드는 확률분포와 원 데이터의 확률분포 차이를 줄여나가는 과정을 보여준다.


검은 점선 : data generating distribution

파란 점선 : discriminator distribution

녹색 선 : generative distribution


(a)에서는 녹색 선(g)와 검은 점선(data)가 전혀 다르게 생긴 것을 볼 수 있다.

이 상태에서 구분자(D)를 통해 두 distribution을 구별하기 위한 학습을 시키면, (b)와 같이 더 부드러운 distribution이 만들어진다. 이후에 G가 현재 D가 구별하기 어려운 방향으로 학습을 하게 되면, (c)처럼 더 가까워지고 이를 계속 반복 학습하면 결국 검은 점선(data)와 녹색 선(g)이 같게 되어 구분자(D)로 구별하지 못하는 D(x) = 0.5 상태(파란 점선)가 되는 것이다.

여기서, 확률분포간의 차이를 계산하기 위해서 'JSD'를 사용한다. JSD는 두 확률 분포사이의 차이를 계산하는 것이다. (이 식을 사용하기 위해선 KLD를 알고 있어야 함)

JSD는 두개의 KLD를 통해 이루어짐

P : 원 확률분포

Q : G 확률분포

M : 원 확률분포와 G 확률분포의 평균


P와 M과 Q와 M을 각각 KLD하고 평균 값을 구해서 두 확률분포 간의 차이를 구하는 것

이러한 발산 과정을 통해서, 원 데이터의 확률분포와 G가 생성해낸 확률분포 간의 JSD가 0이 되면 두 분포간의 차이가 없다는 것으로 학습이 완료되었다는 것을 나타낸다.


  • KLD 공식


이를 활용한 minimax problem을 잘 푼다면, G가 만든 probability distribution이 data distribution과 정확히 일치하도록 할 수 있는 걸 알 수 있다.

즉, 생성자 G로 뽑아낸 샘플을 구분자 D가 실제와 구별할 수 없게 된다는 뜻!

그렇다면, 우리는 어떤 모델과 알고리즘을 사용해야 이 문제를 잘 풀 수 있을까?



반응형