Merge pull request #14 from sony/example/20180318-progressive-growing-of-gans

Example of the Progressive Growing of GANs.
StefanUhlich-sony authored May 14, 2018
2 parents 1fc01b9 + b07ec7f commit 245c679
Showing 12 changed files with 1,778 additions and 0 deletions.
77 changes: 77 additions & 0 deletions GANs/pggan/README.md
@@ -0,0 +1,77 @@
# Progressive Growing of GANs

## Overview

Reproduction of the work "Progressive Growing of GANs for Improved Quality, Stability, and Variation" (Karras et al.) in NNabla.

### Datasets

For training, the following dataset(s) need to be available:

- [CelebA dataset](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)
  - Download `img_align_celeba_png`.
  - Decompress via `7za e img_align_celeba_png.7z.001` (a quick sanity check of the extracted files is sketched below).
- [LSUN](http://www.yf.io/p/lsun) and the [LSUN Challenge](http://lsun.cs.princeton.edu/2016/) (not yet supported; see the NOTE section below)
- CelebA-HQ (not yet supported; see the NOTE section below)
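
Optionally, you can sanity-check the extracted files with a short snippet. `datasets.py` (included in this commit) simply globs `*.png` under `--img-path`, so the directory should contain the aligned CelebA PNGs; the path below is a placeholder.

```python
import glob
import os

img_path = "~/img_align_celeba_png"  # placeholder; use the same value as --img-path
imgs = glob.glob(os.path.join(os.path.expanduser(img_path), "*.png"))
print("{} PNG files found".format(len(imgs)))
```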

### Configuration

`args.py` defines the command-line options for training PGGANs, generating images,
and validating trained models with a chosen metric; a usage sketch is shown below.
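
For reference, here is a minimal sketch of how the helpers in `args.py` (included later in this commit) can be used; the actual wiring inside `train.py`, `generate.py`, and `validate.py` is not part of this excerpt.

```python
from args import get_args, save_args

args = get_args(batch_size=16)   # parse the options defined in args.py (all have defaults)
save_args(args)                  # record them to <monitor path>/Arguments.txt
print(args.imsize, args.learning_rate, args.validation_metric)
```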


### Training

Train the progressive growing of GANs with the following command:

```shell
python train.py --device-id 0 \
--img-path <path to images> \
--monitor-path <monitor path>
```

Training takes about one day on a single Tesla V100.
After training finishes, you can find the parameters of the trained model,
the images generated during training, the training configuration,
the loss logs, and so on in `<monitor path>`.
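
For intuition, the growth schedule implied by `--imsize` and `--epoch-per-resolution` looks roughly as follows. This is a schematic only (the actual loop lives in `train.py`, which is not shown in this excerpt), assuming the 4x4 starting resolution of the original paper.

```python
imsize = 128               # --imsize
epoch_per_resolution = 4   # --epoch-per-resolution

resolution = 4             # assumed starting resolution (4x4), as in the paper
while resolution <= imsize:
    for epoch in range(epoch_per_resolution):
        pass               # fade in the new layers, then stabilize at this resolution
    resolution *= 2        # double the resolution for the next phase
```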

### Generation

For generating images, run

```shell
python generate.py --device-id 0 \
--model-load-path <path to model> \
--monitor-path <monitor path>
```

The generated images are located in `<monitor path>`.
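
To inspect a saved checkpoint independently of `generate.py`, the following minimal NNabla sketch loads the parameter file and lists its contents; the generator's network definition itself lives elsewhere in this example and is not part of this excerpt.

```python
import nnabla as nn

# Default checkpoint path taken from --model-load-path in args.py.
nn.load_parameters("./result/example_0/Gen_phase_128_epoch_4.h5")
for name, param in nn.get_parameters().items():
    print(name, param.shape)
```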

### Validation

Validate trained models with one of the supported metrics (SWD or MS-SSIM):

```shell
python validate.py --device-id 0 \
--img-path <path to images> \
--validation-metric <swd or ms-ssim> \
--monitor-path <monitor path>
```

The log of the validation metric is written to `<monitor path>`.
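
For reference, the sliced Wasserstein distance (SWD) boils down to projecting two sets of image patches onto random directions, sorting the projections, and averaging the differences. The self-contained NumPy sketch below illustrates the idea; it is not the implementation used by `validate.py`.

```python
import numpy as np


def sliced_wasserstein(real, fake, n_projections=512, rng=None):
    """Approximate SWD between two equally sized sets of flattened patches (N, D)."""
    rng = np.random.RandomState(0) if rng is None else rng
    # Random unit-length projection directions.
    dirs = rng.normal(size=(real.shape[1], n_projections))
    dirs /= np.linalg.norm(dirs, axis=0, keepdims=True)
    # Project, sort along the sample axis, and compare the 1-D distributions.
    proj_real = np.sort(real.dot(dirs), axis=0)
    proj_fake = np.sort(fake.dot(dirs), axis=0)
    return np.mean(np.abs(proj_real - proj_fake))
```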

## NOTE
- Currently, the LSGAN objective is used.
- [TODO] Add support for the LSUN dataset.
- [TODO] Add support for CelebA-HQ.

## References

- Tero Karras, Timo Aila, Samuli Laine, Jaakko Lehtinen, "Progressive Growing of GANs for Improved Quality, Stability, and Variation", arXiv:1710.10196.
- https://github.com/tkarras/progressive_growing_of_gans
- https://github.com/tkarras/progressive_growing_of_gans/tree/original-theano-version

## Acknowledgement

This work was mostly done by an intern.

104 changes: 104 additions & 0 deletions GANs/pggan/args.py
@@ -0,0 +1,104 @@
# Copyright (c) 2017 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def get_args(batch_size=16):
"""
Get command line arguments.
Arguments set the default values of command line arguments.
"""
import argparse
import os

description = "Example of Progressive Growing of GANs."
    parser = argparse.ArgumentParser(description=description)

parser.add_argument("-d", "--device-id", type=int, default=0,
help="Device id.")
parser.add_argument("-c", "--context", type=str, default="cudnn",
help="Context.")
parser.add_argument("--type-config", "-t", type=str, default='float',
help='Type of computation. e.g. "float", "half".')
parser.add_argument("--batch-size", "-b", type=int, default=batch_size,
help="Batch size.")
parser.add_argument("--img-path", type=str,
default="~/img_align_celeba_png",
help="Image path.")
parser.add_argument("--dataset-name", type=str, default="CelebA",
choices=["CelebA"],
help="Dataset name used.")
parser.add_argument("--save-image-interval", type=int, default=1,
help="Interval for saving images.")
parser.add_argument("--epoch-per-resolution", type=int, default=4,
help="Number of epochs per resolution.")
parser.add_argument("--imsize", type=int, default=128,
help="Input image size.")
parser.add_argument("--train-samples", type=int, default=-1,
help="Number of data to be used. When -1 is set all data is used.")
parser.add_argument("--valid-samples", type=int, default=16384,
help="Number of data used in validation.")
parser.add_argument("--latent", type=int, default=512,
help="Number of latent variables.")
parser.add_argument("--critic", type=int, default=1,
help="Number of critics.")
parser.add_argument("--monitor-path", type=str, default="./result/example_0",
help="Monitor path.")
parser.add_argument("--model-load-path", type=str,
default="./result/example_0/Gen_phase_128_epoch_4.h5",
help="Model load path used in generation and validation.")
parser.add_argument("--use-bn", action='store_true',
help="Use batch normalization.")
parser.add_argument("--use-ln", action='store_true',
help="Use layer normalization.")
parser.add_argument("--not-use-wscale", action='store_false',
help="Not use the equalized learning rate.")
parser.add_argument("--use-he-backward", action='store_true',
help="Use the He initializaiton using the so-caled `fan_in`. Default is the backward.")
parser.add_argument("--leaky-alpha", type=float, default=0.2,
help="Leaky alpha value.")
parser.add_argument("--learning-rate", type=float, default=0.001,
help="Learning rate.")
parser.add_argument("--beta1", type=float, default=0.0,
help="Beta1 of Adam solver.")
parser.add_argument("--beta2", type=float, default=0.99,
help="Beta2 of Adam solver.")
parser.add_argument("--l2-fake-weight", type=float, default=0.1,
help="Weight for the fake term in the discriminator loss in LSGAN.")
parser.add_argument("--hyper-sphere", action='store_true',
help="Latent vector lie in the hyper sphere.")
parser.add_argument("--last-act", type=str, default="tanh",
choices=["tanh"],
help="Last activation of the generator.")
parser.add_argument("--validation-metric", type=str, default="swd",
choices=["swd", "ms-ssim"],
help="Validation metric for PGGAN.")

args = parser.parse_args()

return args


def save_args(args):
from nnabla import logger
import os
if not os.path.exists(args.monitor_path):
os.makedirs(args.monitor_path)

path = "{}/Arguments.txt".format(args.monitor_path)
logger.info("Arguments are saved to {}.".format(path))
with open(path, "w") as fp:
for k, v in sorted(vars(args).items()):
logger.info("{}={}".format(k, v))
fp.write("{}={}\n".format(k, v))
53 changes: 53 additions & 0 deletions GANs/pggan/datasets.py
@@ -0,0 +1,53 @@
# Copyright (c) 2017 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import glob
from nnabla import logger
from nnabla.utils.data_iterator import data_iterator_simple
import os
from scipy.misc import imread, imresize
import sys

import numpy as np


def data_iterator(img_path, batch_size,
imsize=(128, 128), num_samples=100, shuffle=True, rng=None, dataset_name="CelebA"):
if dataset_name == "CelebA":
di = data_iterator_celeba(img_path, batch_size,
imsize=imsize, num_samples=num_samples, shuffle=shuffle, rng=rng)
else:
        logger.info("Currently, only CelebA is supported.")
sys.exit(0)
return di


def data_iterator_celeba(img_path, batch_size, imsize=(128, 128), num_samples=100, shuffle=True, rng=None):
imgs = glob.glob("{}/*.png".format(img_path))
if num_samples == -1:
num_samples = len(imgs)
else:
logger.info(
"Num. of data ({}) is used for debugging".format(num_samples))

    def load_func(i):
        # Crop center for the aligned 178x218 CelebA images; (89, 121)
        # roughly centers the face region.
        cx = 89
        cy = 121
        # Load, crop to 128x128, convert HWC -> CHW, and scale to [0, 1].
        img = imread(imgs[i])
        img = img[cy - 64:cy + 64, cx - 64:cx + 64, :].transpose(2, 0, 1) / 255.
        # Rescale to [-1, 1] to match the generator's tanh output.
        img = img * 2. - 1.
        return img, None
return data_iterator_simple(load_func, num_samples, batch_size, shuffle=shuffle, rng=rng, with_file_cache=False)
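
A minimal usage sketch of the iterator above (the path is a placeholder):

```python
di = data_iterator("/path/to/img_align_celeba_png", batch_size=16,
                   imsize=(128, 128), num_samples=-1)
imgs = di.next()[0]   # (16, 3, 128, 128) float array scaled to [-1, 1]
```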