[pix2pix 모델] 코드 리뷰
딥러닝(Deep-Learning)

[pix2pix 모델] 코드 리뷰

반응형

[pix2pix] 텐서플로우 코드 리뷰



pix2pix github 코드 : https://github.com/yenchenlin/pix2pix-tensorflow/blob/



GAN을 이용한 이미지 변환 모델인 pix2pix를 텐서플로우로 구현한 코드 리뷰 포스팅입니다.


main.py와 model.py에 해당하는 코드의 내용을 이해하기 위해 한줄씩 리뷰해보고자 합니다.


우선 main.py부터 알아보고 가겠습니다.



1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
parser = argparse.ArgumentParser(description='')
parser.add_argument('--dataset_name', dest='dataset_name', default='facades', help='name of the dataset')
parser.add_argument('--epoch', dest='epoch', type=int, default=200, help='# of epoch')
parser.add_argument('--batch_size', dest='batch_size', type=int, default=1, help='# images in batch')
parser.add_argument('--train_size', dest='train_size', type=int, default=1e8, help='# images used to train')
parser.add_argument('--load_size', dest='load_size', type=int, default=286, help='scale images to this size')
parser.add_argument('--fine_size', dest='fine_size', type=int, default=256, help='then crop to this size')
parser.add_argument('--ngf', dest='ngf', type=int, default=64, help='# of gen filters in first conv layer')
parser.add_argument('--ndf', dest='ndf', type=int, default=64, help='# of discri filters in first conv layer')
parser.add_argument('--input_nc', dest='input_nc', type=int, default=3, help='# of input image channels')
parser.add_argument('--output_nc', dest='output_nc', type=int, default=3, help='# of output image channels')
parser.add_argument('--niter', dest='niter', type=int, default=200, help='# of iter at starting learning rate')
parser.add_argument('--lr', dest='lr', type=float, default=0.0002, help='initial learning rate for adam')
parser.add_argument('--beta1', dest='beta1', type=float, default=0.5, help='momentum term of adam')
parser.add_argument('--flip', dest='flip', type=bool, default=True, help='if flip the images for data argumentation')
parser.add_argument('--which_direction', dest='which_direction', default='AtoB', help='AtoB or BtoA')
parser.add_argument('--phase', dest='phase', default='train', help='train, test')
parser.add_argument('--save_epoch_freq', dest='save_epoch_freq', type=int, default=50, help='save a model every save_epoch_freq epochs (does not overwrite previously saved models)')
parser.add_argument('--save_latest_freq', dest='save_latest_freq', type=int, default=5000, help='save the latest model every latest_freq sgd iterations (overwrites the previous latest model)')
parser.add_argument('--print_freq', dest='print_freq', type=int, default=50, help='print the debug information every print_freq iterations')
parser.add_argument('--continue_train', dest='continue_train', type=bool, default=False, help='if continue training, load the latest model: 1: true, 0: false')
parser.add_argument('--serial_batches', dest='serial_batches', type=bool, default=False, help='f 1, takes images in order to make batches, otherwise takes them randomly')
parser.add_argument('--serial_batch_iter', dest='serial_batch_iter', type=bool, default=True, help='iter into serial image list')
parser.add_argument('--checkpoint_dir', dest='checkpoint_dir', default='./checkpoint', help='models are saved here')
parser.add_argument('--sample_dir', dest='sample_dir', default='./sample', help='sample are saved here')
parser.add_argument('--test_dir', dest='test_dir', default='./test', help='test sample are saved here')
parser.add_argument('--L1_lambda', dest='L1_lambda', type=float, default=100.0, help='weight on L1 term in objective')
 
args = parser.parse_args()
cs



시작부터 엄청 복잡해보입니다. 이는 커맨드라인 인자를 다뤄주는 파이썬의 argparse를 활용한 코드인데요.

main.py의 시작부분에 이를 이용할 argparse가 import되는 모습을 확인할 수 있습니다.


parser = argparse.ArgumentParser(description='')


이를 통해 parser를 생성시켜주고요.

parser.add_argument를 이용해 입력받고자 하는 인자의 조건들을 설정하고 있는 모습입니다.



args = parser.parse_args()


parse_args 메소드를 통해 인자들을 파싱하여 args에 저장했습니다. 이제 각 인자들은 add_argumnet의 type에 지정된 형식으로 저장이 되겠습니다.



이어서 main을 정의하고 있습니다.


1
2
3
4
5
6
7
def main(_):
    if not os.path.exists(args.checkpoint_dir):
        os.makedirs(args.checkpoint_dir)
    if not os.path.exists(args.sample_dir):
        os.makedirs(args.sample_dir)
    if not os.path.exists(args.test_dir):
        os.makedirs(args.test_dir)
cs


os 모듈이 사용되는 모습입니다. os는 파일이나 디렉토리와 관련된 기능을 제공합니다.

즉, 조건문을 통해 경로에 checkpoint_dir, sample_dir, test_dir이 존재하지 않다면, 해당 디렉토리를 makedirs로 생성해주는 모습을 볼 수 있습니다.


1
2
3
4
5
6
7
8
9
10
11
 
    with tf.Session() as sess:
        model = pix2pix(sess, image_size=args.fine_size, batch_size=args.batch_size,
                        output_size=args.fine_size, dataset_name=args.dataset_name,
                        checkpoint_dir=args.checkpoint_dir, sample_dir=args.sample_dir)
 
        if args.phase == 'train':
            model.train(args)
        else:
            model.test(args)
 
cs


이어서 Session이 진행되는 모습을 볼 수 있구요. model이라는 변수에 pix2pix 클래스를 저장하고 있습니다.

phase가 train이면 train을 진행하고, 아니면 test를 진행하게 됩니다. 


즉 train이라는 phase가 존재할 때는 pix2pix의 학습을 진행하고, 없을 때는 test 과정을 진행하는 것으로 이해할 수 있습니다.


1
2
if __name__ == '__main__':
    tf.app.run()
cs


main.py의 마지막 부분입니다.


if __name__ == '__main__':


이게 무슨 의미일까요? 이 코드는 현재 스크립트 파일이 실행되는 상태를 파악하기 위해 사용하게 됩니다.

__name__은 모듈의 이름이 저장되는 변수로, import로 모듈을 가져오면 모듈의 이름이 들어가게 됩니다. 하지만 직접 파일을 실행하면 __main__이 모듈의 이름으로 들어가게 되는 걸 뜻합니다.


즉, 이 main.py를 import한 것이 아닌, 직접 실행했을 때 이 조건문이 만족하겠죠?


tf.app.run()

간단히 말해서, 주요 기능을 호출해 모든 인수를 전달하는 역할을 하는 메소드입니다. 이를 통해 pix2pix 모델을 만들어 main.py에서 학습 및 테스트를 진행할 수 있습니다.




이제 model.py에 대해 알아보도록 하겠습니다.


model 구현에 필요한 import문을 시작으로, pix2pix의 클래스 정의가 시작됩니다.


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class pix2pix(object):
    def __init__(self, sess, image_size=256,
                 batch_size=1, sample_size=1, output_size=256,
                 gf_dim=64, df_dim=64, L1_lambda=100,
                 input_c_dim=3, output_c_dim=3, dataset_name='facades',
                 checkpoint_dir=None, sample_dir=None):
 
        self.sess = sess
        self.is_grayscale = (input_c_dim == 1)
        self.batch_size = batch_size
        self.image_size = image_size
        self.sample_size = sample_size
        self.output_size = output_size
 
        self.gf_dim = gf_dim
        self.df_dim = df_dim
 
        self.input_c_dim = input_c_dim
        self.output_c_dim = output_c_dim
 
        self.L1_lambda = L1_lambda
cs


변수를 간단히 정리해보면 다음과 같습니다.


sess : 텐서플로우 세션 담당

batch_size : 학습할 때 사이즈 지정

output_size : output으로 나오는 이미지의 픽셀 지정

gf_dim : 첫번째 conv 레이어 생성자 필터의 dimension

df_dim : 첫번째 conv 레이어 구분자 필터의 dimension

input_c_dim : input 이미지 color의 dimention

output_c_dim : output 이미지 color의 dimention



1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
        self.d_bn1 = batch_norm(name='d_bn1')
        self.d_bn2 = batch_norm(name='d_bn2')
        self.d_bn3 = batch_norm(name='d_bn3')
 
        self.g_bn_e2 = batch_norm(name='g_bn_e2')
        self.g_bn_e3 = batch_norm(name='g_bn_e3')
        self.g_bn_e4 = batch_norm(name='g_bn_e4')
        self.g_bn_e5 = batch_norm(name='g_bn_e5')
        self.g_bn_e6 = batch_norm(name='g_bn_e6')
        self.g_bn_e7 = batch_norm(name='g_bn_e7')
        self.g_bn_e8 = batch_norm(name='g_bn_e8')
 
        self.g_bn_d1 = batch_norm(name='g_bn_d1')
        self.g_bn_d2 = batch_norm(name='g_bn_d2')
        self.g_bn_d3 = batch_norm(name='g_bn_d3')
        self.g_bn_d4 = batch_norm(name='g_bn_d4')
        self.g_bn_d5 = batch_norm(name='g_bn_d5')
        self.g_bn_d6 = batch_norm(name='g_bn_d6')
        self.g_bn_d7 = batch_norm(name='g_bn_d7')
 
        self.dataset_name = dataset_name
        self.checkpoint_dir = checkpoint_dir
        self.build_model()
cs


batch normalization을 진행하고 있는 모습입니다. 여기서 생성된 변수들은 생성자나 구분자에서 사용될 것입니다.


batch_norm 메소드는 ops.py에 정의되어 있습니다.

1
2
3
4
5
6
7
8
9
10
11
class batch_norm(object):
def __init__(self, epsilon=1e-5, momentum = 0.9, name="batch_norm"):
        with tf.variable_scope(name):
            self.epsilon = epsilon
            self.momentum = momentum
            self.name = name
 
    def __call__(self, x, train=True):
        return tf.contrib.layers.batch_norm(x, decay=self.momentum, 
updates_collections=None, epsilon=self.epsilon, scale=True, scope=self.name)
 
cs


이는 학습 속도를 개선시켜주고, overfitting 위험을 줄여주거나 gradient가 소실되는 문제를 해결해주는 도움을 줍니다.


여기까지가 pix2pix 클래스의 생성자에 해당하는 __init__을 정의한 코드입니다.


이 밖에 model.py에 정의된 메소드는 모델 로직을 만드는 build_model(), 학습을 진행하는 train(), GAN에 필요한 generator()와 discriminator(), 실제 학습된 pix2pix를 실험해 볼 test() 등으로 구성되어 있습니다. 하나씩 알아보도록 하겠습니다.



build_model 메소드


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
    def build_model(self):
        self.real_data = tf.placeholder(tf.float32,
                                        [self.batch_size, self.image_size, self.image_size,
                                         self.input_c_dim + self.output_c_dim],
                                        name='real_A_and_B_images')
 
        self.real_B = self.real_data[:, :, :, :self.input_c_dim]
        self.real_A = self.real_data[:, :, :, self.input_c_dim:self.input_c_dim + self.output_c_dim]
 
        self.fake_B = self.generator(self.real_A)
 
        self.real_AB = tf.concat([self.real_A, self.real_B], 3)
        self.fake_AB = tf.concat([self.real_A, self.fake_B], 3)
        self.D, self.D_logits = self.discriminator(self.real_AB, reuse=False)
        self.D_, self.D_logits_ = self.discriminator(self.fake_AB, reuse=True)
 
        self.fake_B_sample = self.sampler(self.real_A)
 
        self.d_sum = tf.summary.histogram("d", self.D) //가중치
        self.d__sum = tf.summary.histogram("d_", self.D_)
        self.fake_B_sum = tf.summary.image("fake_B", self.fake_B)
 
        self.d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits, labels=tf.ones_like(self.D)))
        self.d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits_, labels=tf.zeros_like(self.D_)))
        self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits_, labels=tf.ones_like(self.D_))) \
                        + self.L1_lambda * tf.reduce_mean(tf.abs(self.real_B - self.fake_B))
 
        self.d_loss_real_sum = tf.summary.scalar("d_loss_real", self.d_loss_real)
        self.d_loss_fake_sum = tf.summary.scalar("d_loss_fake", self.d_loss_fake)
 
        self.d_loss = self.d_loss_real + self.d_loss_fake
 
        self.g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)
        self.d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
 
        t_vars = tf.trainable_variables()
 
        self.d_vars = [var for var in t_vars if 'd_' in var.name]
        self.g_vars = [var for var in t_vars if 'g_' in var.name]
 
        self.saver = tf.train.Saver()
cs


real_data를 'real_A_and_B_images'이름을 가진 float 타입으로 placeholder로 지정했습니다.


shape은 [self.batch_size, self.image_size, self.image_size, self.input_c_dim + self.output_c_dim]인데요.

처음 init에서 초기화한 값을 넣어보면 아래와 같습니다.


real_data = [1, 256, 256, 6]


1
2
self.real_B = self.real_data[:, :, :, :self.input_c_dim]
self.real_A = self.real_data[:, :, :, self.input_c_dim:self.input_c_dim + self.output_c_dim]
cs


real_data의 shape을 이용하고 있는 모습입니다. 그냥 :로 나와있는 배열은 real_data의 값을 사용하면 됩니다.


real_B = [1, 256, 256, 3]

real_A = [1, 256, 256, 3:6]


이어서 fake_B를 선언하는데, generator 메소드가 실행됩니다.


self.fake_B = self.generator(self.real_A)


real_A를 매개변수로 가진 generator를 fake_B에 저장하고 있는데요. generator 메소드를 통해 진행해보도록 하겠습니다.


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def generator(self, image, y=None):
        with tf.variable_scope("generator") as scope:
 
            s = self.output_size
            s2, s4, s8, s16, s32, s64, s128 = int(s/2), int(s/4), int(s/8), int(s/16), int(s/32), int(s/64), int(s/128)
 
            # image is (256 x 256 x input_c_dim)
            e1 = conv2d(image, self.gf_dim, name='g_e1_conv')
            # e1 is (128 x 128 x self.gf_dim)
            e2 = self.g_bn_e2(conv2d(lrelu(e1), self.gf_dim*2, name='g_e2_conv'))
            # e2 is (64 x 64 x self.gf_dim*2)
            e3 = self.g_bn_e3(conv2d(lrelu(e2), self.gf_dim*4, name='g_e3_conv'))
            # e3 is (32 x 32 x self.gf_dim*4)
            e4 = self.g_bn_e4(conv2d(lrelu(e3), self.gf_dim*8, name='g_e4_conv'))
            # e4 is (16 x 16 x self.gf_dim*8)
            e5 = self.g_bn_e5(conv2d(lrelu(e4), self.gf_dim*8, name='g_e5_conv'))
            # e5 is (8 x 8 x self.gf_dim*8)
            e6 = self.g_bn_e6(conv2d(lrelu(e5), self.gf_dim*8, name='g_e6_conv'))
            # e6 is (4 x 4 x self.gf_dim*8)
            e7 = self.g_bn_e7(conv2d(lrelu(e6), self.gf_dim*8, name='g_e7_conv'))
            # e7 is (2 x 2 x self.gf_dim*8)
            e8 = self.g_bn_e8(conv2d(lrelu(e7), self.gf_dim*8, name='g_e8_conv'))
            # e8 is (1 x 1 x self.gf_dim*8)
 
            self.d1, self.d1_w, self.d1_b = deconv2d(tf.nn.relu(e8),
                [self.batch_size, s128, s128, self.gf_dim*8], name='g_d1', with_w=True)
            d1 = tf.nn.dropout(self.g_bn_d1(self.d1), 0.5)
            d1 = tf.concat([d1, e7], 3)
            # d1 is (2 x 2 x self.gf_dim*8*2)
 
            self.d2, self.d2_w, self.d2_b = deconv2d(tf.nn.relu(d1),
                [self.batch_size, s64, s64, self.gf_dim*8], name='g_d2', with_w=True)
            d2 = tf.nn.dropout(self.g_bn_d2(self.d2), 0.5)
            d2 = tf.concat([d2, e6], 3)
            # d2 is (4 x 4 x self.gf_dim*8*2)
 
            self.d3, self.d3_w, self.d3_b = deconv2d(tf.nn.relu(d2),
                [self.batch_size, s32, s32, self.gf_dim*8], name='g_d3', with_w=True)
            d3 = tf.nn.dropout(self.g_bn_d3(self.d3), 0.5)
            d3 = tf.concat([d3, e5], 3)
            # d3 is (8 x 8 x self.gf_dim*8*2)
 
            self.d4, self.d4_w, self.d4_b = deconv2d(tf.nn.relu(d3),
                [self.batch_size, s16, s16, self.gf_dim*8], name='g_d4', with_w=True)
            d4 = self.g_bn_d4(self.d4)
            d4 = tf.concat([d4, e4], 3)
            # d4 is (16 x 16 x self.gf_dim*8*2)
 
            self.d5, self.d5_w, self.d5_b = deconv2d(tf.nn.relu(d4),
                [self.batch_size, s8, s8, self.gf_dim*4], name='g_d5', with_w=True)
            d5 = self.g_bn_d5(self.d5)
            d5 = tf.concat([d5, e3], 3)
            # d5 is (32 x 32 x self.gf_dim*4*2)
 
            self.d6, self.d6_w, self.d6_b = deconv2d(tf.nn.relu(d5),
                [self.batch_size, s4, s4, self.gf_dim*2], name='g_d6', with_w=True)
            d6 = self.g_bn_d6(self.d6)
            d6 = tf.concat([d6, e2], 3)
            # d6 is (64 x 64 x self.gf_dim*2*2)
 
            self.d7, self.d7_w, self.d7_b = deconv2d(tf.nn.relu(d6),
                [self.batch_size, s2, s2, self.gf_dim], name='g_d7', with_w=True)
            d7 = self.g_bn_d7(self.d7)
            d7 = tf.concat([d7, e1], 3)
            # d7 is (128 x 128 x self.gf_dim*1*2)
 
            self.d8, self.d8_w, self.d8_b = deconv2d(tf.nn.relu(d7),
                [self.batch_size, s, s, self.output_c_dim], name='g_d8', with_w=True)
            # d8 is (256 x 256 x output_c_dim)
 
            return tf.nn.tanh(self.d8)
cs



image를 convolution하고, 이를 다시 deconvolution해서 return하는 과정



self.fake_B = self.generator(self.real_A) : real_A를 이용한 생성자로 가짜 fake_B 이미지 생성



self.real_AB = tf.concat([self.real_A, self.real_B], 3) : 실제 이미지인 real_A와 real_B를 concat해서 저장한 변수


self.fake_AB = tf.concat([self.real_A, self.fake_B], 3) : real_A에 fake_B를 concat해서 저장한 변수



self.D, self.D_logits = self.discriminator(self.real_AB, reuse=False) : 실제 이미지에 대한 구분자 실행해서 저장


self.D_, self.D_logits_ = self.discriminator(self.fake_AB, reuse=True) : 가짜 이미지에 대한 구분자 실행해서 저장


reuse는 변수를 다시 사용할 것인지 아닌지를 나타냄




self.d_sum = tf.summary.histogram("d", self.D) : 실제 이미지 구분자 실행한 값을 저장





1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
        self.d_sum = tf.summary.histogram("d", self.D)
        self.d__sum = tf.summary.histogram("d_", self.D_)
        self.fake_B_sum = tf.summary.image("fake_B", self.fake_B)
 
 
        #구분자의 실제 이미지 loss와 가짜 이미지 loss 구하는 식
        self.d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits, labels=tf.ones_like(self.D)))
        self.d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits_, labels=tf.zeros_like(self.D_)))
        self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits_, labels=tf.ones_like(self.D_))) \
                        + self.L1_lambda * tf.reduce_mean(tf.abs(self.real_B - self.fake_B))
 
        #실제와 가짜 이미지 loss의 각각의 합
        self.d_loss_real_sum = tf.summary.scalar("d_loss_real", self.d_loss_real)
        self.d_loss_fake_sum = tf.summary.scalar("d_loss_fake", self.d_loss_fake)
 
        #loss 총합
        self.d_loss = self.d_loss_real + self.d_loss_fake
 
        self.g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)
        self.d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
 
        t_vars = tf.trainable_variables()
 
        self.d_vars = [var for var in t_vars if 'd_' in var.name]
        self.g_vars = [var for var in t_vars if 'g_' in var.name]
 
        self.saver = tf.train.Saver()
cs


반응형