실전 예제로 배우는 GAN

생성적 적대 신경망이란? (What is GAN?)

threegopark 2021. 5. 28. 12:37
728x90

생성적 적대 신경망(generative adversarial networks, GAN)

 

두 개 이상의 신경망이 서로를 향하게 하고, 서로 대항하듯이 훈련하게 함으로써, 결과적으로 생성 모델을 산출

 

GAN의 이점

  • 데이터가 한정된 상황에서도 일반화(Generalization) 가능
  • 작은 데이터셋을 가지고도 새로운 장면을 생각할 수 있음
  • 모조 데이터(Simulated data)를 더욱 진짜처럼 보이게 할 수 있음

 

생성 모델과 판별 모델

 

머신러닝 & 딥러닝 --> 생성 모델링 (Generative modeling) + 판별 모델링 (Discriminative modeling) 으로 설명 가능

(기존에 공부했던 분류 기법은 전형적인 판별 모델링 기술에 해당함, 단, 군집화 등의 기법은 생성 모델링에 해당)

 

기존의 판별 모델링

 

-> 그림을 살펴본 후, 해당 그림의 화풍(style)을 정하는 일 (무엇인가를 '판단'하는 일)

  1. 데이터 내의 각 부분을 이해하기 위해 합성곱 계층을 만들거나, 기타 학습된 특징들을 사용하는 머신러닝 모델 구축
  2. 훈련 집합과 검증 집합이 모두 포함된 데이터 셋 수집
  3. 모델 훈련
  4. 모델을 사용해 어떤 데이터 점이 특정 타겟에 속하는지 예측

 

판별 모델링 작동 방식

 

-> 분포에 대한 계급 간의 경계 조건을 학습

-> 데이터가 많을수록 성능이 좋아짐

-> 레이블 필요

 

 

기존의 생성 모델링

 

-> 화풍에 대한 지식을 쌓고, 다양한 화가의 화풍에 따라 그림을 '재현'해 내는 것

  1. 다양한 그림의 화풍을 '복제'하는 방법을 학습하는 머신러닝 모델부터 작성
  2. 훈련 집합과 검증 집합이 모두 포함된 데이터 셋 수집
  3. 모델 훈련
  4. 모델을 사용해 그림 작가가 그린 사례를 바탕으로 예측(추론). 즉, 유사도라는 계량기준을 사용해 모델에서 화풍을 재현하는 기능을 확인

 

생성 모델링 작동 방식

 

-> 주어진 입력의 분포에 대한 계급들의 분포를 모델링

-> 분포를 추정하기 위해 각 계급에 대한 확률 모델 생성

-> 생성 모델에서는 훈련을 하는 중에 알아서 레이블을 학습하게 되므로 레이블이 없는 데이터 사용 가능

 

+정리

생성 모델은 입력 분포를 정확하게 모델링하고 복제

판별 모델은 결정 경계들을 학습하기만 하면 됨


심층 신경망

 

https://m.blog.naver.com/tjdudwo93/221072421443

  1. 이미지 또는 그 밖의 입력 데이터 등의 입력 내용이 입력 층으로 전송
  2. 단일한 은닉 계층 또는 이어져 있는 여러 은닉 계층이 이 데이터를 바탕으로 연산한다. (역전파 알고리즘을 통해 에포크마다 각 계층의 가중치 조정)
  3. 출력 층은 모든 정보를 출력 형식으로 집계

GAN 아키텍처 주요 구성요소

  1. 케라스 & 파이토치 : 텐서플로를 벡엔드로 사용하는 프론트엔드 프레임워크 (모델 구현에 필수적인 메서드 지원)
  2. 생성기(generator) & 판별기(discriminator), 두 가지 신경망 기반 구성요소

 

 

GAN 작동 방식

 

https://wegonnamakeit.tistory.com/54

  • 위조범에 해당하는 생성기의 목표 : 경찰(판별기)이 위조지폐(Fake)와 진짜 지폐를 구별하지 못하도록 물품 생성
  • 경찰에 해당하는 판별기의 목표 : 진품과 모조품을 분류해내는 사전 경험을 바탕으로 예외적인 제품 탐지
  • 즉, 판별기가 진짜 이미지와 가짜 이미지를 구별하지 못하는 확률 목표인 0.5를 달성하도록 실제같은 fake 이미지를 생성해내는 것이다. 
  • GAN 프레임워크에서 판별기가 이미지를 분류할 수 있어야 하므로, 대체로 훈련을 시작하기 전에 몇 개의 에포크만큼 먼저 판별기부터 훈련이 시작된다. 

 

 

GAN 생성기 아키텍처 & 판별기 아키텍처

 

1. 생성기 아키텍처

  • 잠재 공간에서 표본을 추출해 잠재 공간과 출력 간의 관계를 생성하는 것이 역할이다.
  • 그 다음, 입력(잠재 공간)에서 출력(대부분 이미지)으로 향하는 신경망을 만든다.
  • 한 모델 안에서 생성기와 판별기를 서로 연결해 적대적 관계를 형성시킴으로써 생성기를 훈련한다.
  • 생성기의 훈련을 끝낸 뒤에는 생성기를 추론에 사용할 수 있다.

1-1. 코드

class Generator:
    
    def __init__(self):
        self.initVariable = 1
    
    def lossFunction(self):
        #모델 훈련 시 사용할 사용자 정의 손실 함수
        return
    
    def buildModel(self):
        #3주어진 신경망의 실제 모델 구성
        return
    
    def trainModel(self, inputX, inputY):
        
        return

 

 

2. 판별기 아키텍처

  • 진짜와 가짜를 분류(보통 이진분류에 해당)하는 데 사용할 합성곱 신경망을 만든다.
  • 진짜 데이터로만 구성된 데이터셋을 만들고, 생성기를 사용해서 가짜 데이터로만 구성된 데이터셋도 만든다.
  • 진짜 데이터와 가짜 데이터를 사용해 판별기 모델을 훈련한다.
  • 생성기를 훈련함으로써 훈련된 판별기와 서로 균형을 잡는 방법을 학습한다. (판별기가 너무 뛰어나게 되면 생성기가 발산하게 된다는 점을 이용한다.)
  • 데이터의 근원 분포에 적응할 수 있는 판별기를 훈련하는 것이 최종 목표이다.
  • 진짜 이미지는 척도가 가리키는 점수를 높이는 반면에, 가짜 이미지는 척도가 가리키는 점수를 낮추는 방향으로 평가한다.

 

(+ 근원 분포에 적응한다??

 

  • 녹색 선 - 가짜 데이터의 분포 (검정 점선과 최대한 일치하도록 조정되는 것이 목표)
  • 파란 선 - 분류 분포 (훈련을 반복하면 가장 구분하기 어려운 확률 분포인 0.5가 됨)
  • 검정 점선 - 실제 데이터의 분포

 

2-1. 코드

class Discriminator:
    def __init__(self):
        self.initVariable = 1
        
    def lossFunction(self):
        
        return
    
    def buildModel(self):
        
        return
    
    def trainModel(self, inputX, inputY):
        
        return

 

 

3. 손실 함수

  • 신경망을 훈련하는데 필요한 구조 요소
  • 훈련 과정 중 가중치를 조절하여 손실 함수가 최적화되게 한다.
  • 목적에 맞는 손실 함수가 필요하다.

 

 

3-1. 생성기용 손실 함수

 

  • 판별기가 옳게 판단하였는지를 나타내는 로그 확률을 판별기가 줄여 나가고 있다는 점을 간단히 나타낸다.

 

3-2. 판별기용 손실 함수

 

  • 표준 교차 엔트로피 식이다.
  • 두 가지 오차 함수가 각기 최소화/최대화되어야 한다. (생성기에 대한 오차 최대화, 판별기에 대한 오차 최소화)

 

3-3. 코드

class Loss:
    def __init__(self):
        self.initVariable = 1
        
    def lossBaseFunction1(self):
        
        return
    
    def lossBaseFunction2(self):
        
        return
    
    def lossBaseFunction3(self):
        
        return