Merge pull request #14 from sony/example/20180318-progressive-growing-of-gans
Example of the Progressive Growing of GANs.
Showing 12 changed files with 1,778 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
# Progressive Growing of GANs | ||
|
||
## Overview | ||
|
||
A reproduction of "Progressive Growing of GANs for Improved Quality, Stability, and Variation" using NNabla. | ||
|
||
### Datasets | ||
|
||
For the training, the following dataset(s) need to be available: | ||
|
||
- [CelebA dataset](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) | ||
- Download `img_align_celeba_png`. | ||
- Decompress via `7za e img_align_celeba_png.7z.001`. | ||
- ([LSUN](http://www.yf.io/p/lsun) and [LSUN Challenge](http://lsun.cs.princeton.edu/2016/)) | ||
- (CelebA-HQ) | ||
|
||
### Configuration | ||
|
||
`args.py` defines the options for training PGGANs, | ||
generating images, and validating trained models with a chosen metric. | ||
|
||
|
||
### Training | ||
|
||
Train the progressive growing of GANs with the following command: | ||
|
||
```bash | ||
python train.py --device-id 0 \ | ||
--img-path <path to images> \ | ||
--monitor-path <monitor path> | ||
``` | ||
|
||
Training takes about one day on a single Tesla V100. | ||
After training finishes, the `<monitor path>` directory contains the parameters of the trained model, | ||
the images generated during training, the training configuration, | ||
the loss logs, and other outputs. | ||
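
The `--epoch-per-resolution` and `--imsize` options in `args.py` suggest that training steps through a sequence of resolutions. The following is only a rough sketch of such a schedule, assuming it starts at 4x4 and doubles until it reaches `--imsize`, as in the original paper. The actual loop lives in `train.py`, which is not shown here, and `resolution_schedule` is a hypothetical helper introduced for illustration.

```python
def resolution_schedule(imsize=128, epoch_per_resolution=4, start=4):
    """Return (resolution, epochs) pairs, doubling from `start` up to `imsize`.

    Hypothetical helper for illustration; not part of this example's code.
    The starting resolution of 4 is an assumption taken from the paper.
    """
    schedule = []
    resolution = start
    while resolution <= imsize:
        schedule.append((resolution, epoch_per_resolution))
        resolution *= 2
    return schedule


schedule = resolution_schedule()
for resolution, epochs in schedule:
    print("{0}x{0}: {1} epochs".format(resolution, epochs))
```

Under these assumptions, the default settings would give six resolution phases, from 4x4 up to 128x128.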
|
||
### Generation | ||
|
||
To generate images, run | ||
|
||
```bash | ||
python generate.py --device-id 0 \ | ||
--model-load-path <path to model> \ | ||
--monitor-path <monitor path> | ||
``` | ||
|
||
The generated images are located in the `<monitor path>`. | ||
|
||
### Validation | ||
|
||
Validate a trained model using one of the supported metrics, SWD or MS-SSIM. | ||
|
||
```bash | ||
python validate.py --device-id 0 \ | ||
--img-path <path to images> \ | ||
--validation-metric <swd or ms-ssim> \ | ||
--monitor-path <monitor path> | ||
``` | ||
|
||
The log of the validation metric is located in the `<monitor path>`. | ||
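
For readers unfamiliar with the `swd` metric, the core idea of a sliced Wasserstein distance can be sketched in a few lines of NumPy. This is a simplified illustration under stated assumptions, not the implementation in `validate.py`: the paper computes the distance on patches drawn from a Laplacian pyramid, whereas this sketch works on generic descriptor arrays, and the function name and parameters are hypothetical.

```python
import numpy as np


def sliced_wasserstein(a, b, n_projections=128, rng=None):
    """Approximate sliced Wasserstein distance between two descriptor sets.

    a, b: arrays of shape (n_samples, n_features) with equal n_samples.
    Hypothetical helper for illustration only.
    """
    rng = np.random.RandomState(0) if rng is None else rng
    # Random unit directions, one per column.
    directions = rng.randn(a.shape[1], n_projections)
    directions /= np.linalg.norm(directions, axis=0, keepdims=True)
    # Project both sets, then sort each 1-D projection independently.
    proj_a = np.sort(a @ directions, axis=0)
    proj_b = np.sort(b @ directions, axis=0)
    # Mean absolute difference between the sorted projections.
    return float(np.mean(np.abs(proj_a - proj_b)))


rng = np.random.RandomState(1)
real = rng.randn(256, 16)
fake = rng.randn(256, 16) + 0.5  # shifted distribution
print(sliced_wasserstein(real, real))  # identical sets give 0.0
print(sliced_wasserstein(real, fake))
```

Lower values indicate that the two sets of descriptors are distributed more similarly.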
|
||
## NOTE | ||
- Currently, we are using LSGAN. | ||
- [TODO] Support for the LSUN dataset | ||
- [TODO] CelebA-HQ | ||
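
As a reminder of what the LSGAN objective looks like, here is a minimal NumPy sketch. It assumes the common least-squares targets (1 for real, 0 for fake) and assumes that `--l2-fake-weight` scales the fake term of the discriminator loss, as its help text in `args.py` suggests. The function names are hypothetical, and the actual loss is defined in the training code, which is not shown here.

```python
import numpy as np


def lsgan_d_loss(d_real, d_fake, l2_fake_weight=0.1):
    """Least-squares discriminator loss: real outputs -> 1, fake outputs -> 0."""
    real_term = np.mean((d_real - 1.0) ** 2)
    fake_term = np.mean(d_fake ** 2)
    # Assumption: the weight applies only to the fake term.
    return real_term + l2_fake_weight * fake_term


def lsgan_g_loss(d_fake):
    """Least-squares generator loss: push discriminator outputs on fakes toward 1."""
    return np.mean((d_fake - 1.0) ** 2)


d_real = np.array([0.9, 1.1, 1.0])   # discriminator outputs on real images
d_fake = np.array([0.2, -0.1, 0.3])  # discriminator outputs on generated images
print(lsgan_d_loss(d_real, d_fake))
print(lsgan_g_loss(d_fake))
```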
|
||
## References | ||
|
||
- Tero Karras, Timo Aila, Samuli Laine, Jaakko Lehtinen, "Progressive Growing of GANs for Improved Quality, Stability, and Variation", arXiv:1710.10196. | ||
- https://github.com/tkarras/progressive_growing_of_gans | ||
- https://github.com/tkarras/progressive_growing_of_gans/tree/original-theano-version | ||
|
||
## Acknowledgement | ||
|
||
This work was mostly done by an intern. | ||
|
@@ -0,0 +1,104 @@ | ||
# Copyright (c) 2017 Sony Corporation. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
def get_args(batch_size=16): | ||
""" | ||
Get command line arguments. | ||
Function arguments set the default values of the command-line arguments. | ||
""" | ||
import argparse | ||
import os | ||
|
||
description = "Example of Progressive Growing of GANs." | ||
parser = argparse.ArgumentParser(description=description) | ||
|
||
parser.add_argument("-d", "--device-id", type=int, default=0, | ||
help="Device id.") | ||
parser.add_argument("-c", "--context", type=str, default="cudnn", | ||
help="Context.") | ||
parser.add_argument("--type-config", "-t", type=str, default='float', | ||
help='Type of computation. e.g. "float", "half".') | ||
parser.add_argument("--batch-size", "-b", type=int, default=batch_size, | ||
help="Batch size.") | ||
parser.add_argument("--img-path", type=str, | ||
default=os.path.expanduser("~/img_align_celeba_png"), | ||
help="Image path.") | ||
parser.add_argument("--dataset-name", type=str, default="CelebA", | ||
choices=["CelebA"], | ||
help="Dataset name used.") | ||
parser.add_argument("--save-image-interval", type=int, default=1, | ||
help="Interval for saving images.") | ||
parser.add_argument("--epoch-per-resolution", type=int, default=4, | ||
help="Number of epochs per resolution.") | ||
parser.add_argument("--imsize", type=int, default=128, | ||
help="Input image size.") | ||
parser.add_argument("--train-samples", type=int, default=-1, | ||
help="Number of samples to use. If -1, all data is used.") | ||
parser.add_argument("--valid-samples", type=int, default=16384, | ||
help="Number of samples used in validation.") | ||
parser.add_argument("--latent", type=int, default=512, | ||
help="Number of latent variables.") | ||
parser.add_argument("--critic", type=int, default=1, | ||
help="Number of critics.") | ||
parser.add_argument("--monitor-path", type=str, default="./result/example_0", | ||
help="Monitor path.") | ||
parser.add_argument("--model-load-path", type=str, | ||
default="./result/example_0/Gen_phase_128_epoch_4.h5", | ||
help="Model load path used in generation and validation.") | ||
parser.add_argument("--use-bn", action='store_true', | ||
help="Use batch normalization.") | ||
parser.add_argument("--use-ln", action='store_true', | ||
help="Use layer normalization.") | ||
parser.add_argument("--not-use-wscale", action='store_false', | ||
help="Do not use the equalized learning rate.") | ||
parser.add_argument("--use-he-backward", action='store_true', | ||
help="Use the He initialization using the so-called `fan_in`. Default is the backward.") | ||
parser.add_argument("--leaky-alpha", type=float, default=0.2, | ||
help="Leaky alpha value.") | ||
parser.add_argument("--learning-rate", type=float, default=0.001, | ||
help="Learning rate.") | ||
parser.add_argument("--beta1", type=float, default=0.0, | ||
help="Beta1 of Adam solver.") | ||
parser.add_argument("--beta2", type=float, default=0.99, | ||
help="Beta2 of Adam solver.") | ||
parser.add_argument("--l2-fake-weight", type=float, default=0.1, | ||
help="Weight for the fake term in the discriminator loss in LSGAN.") | ||
parser.add_argument("--hyper-sphere", action='store_true', | ||
help="Latent vectors lie on the hypersphere.") | ||
parser.add_argument("--last-act", type=str, default="tanh", | ||
choices=["tanh"], | ||
help="Last activation of the generator.") | ||
parser.add_argument("--validation-metric", type=str, default="swd", | ||
choices=["swd", "ms-ssim"], | ||
help="Validation metric for PGGAN.") | ||
|
||
args = parser.parse_args() | ||
|
||
return args | ||
|
||
|
||
def save_args(args): | ||
from nnabla import logger | ||
import os | ||
if not os.path.exists(args.monitor_path): | ||
os.makedirs(args.monitor_path) | ||
|
||
path = "{}/Arguments.txt".format(args.monitor_path) | ||
logger.info("Arguments are saved to {}.".format(path)) | ||
with open(path, "w") as fp: | ||
for k, v in sorted(vars(args).items()): | ||
logger.info("{}={}".format(k, v)) | ||
fp.write("{}={}\n".format(k, v)) |
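
One detail of `get_args` that is easy to misread is the `--not-use-wscale` flag: because it uses `action='store_false'`, the parsed attribute `not_use_wscale` is `True` by default and becomes `False` only when the flag is passed. The standalone sketch below reproduces just that behavior with `argparse`, along with the sorted `key=value` layout that `save_args` writes. The parser here is a cut-down stand-in, not the full parser above.

```python
import argparse

# Cut-down stand-in parser with two of the flags defined in get_args.
parser = argparse.ArgumentParser(description="store_false demo")
parser.add_argument("--not-use-wscale", action="store_false",
                    help="Do not use the equalized learning rate.")
parser.add_argument("--use-bn", action="store_true",
                    help="Use batch normalization.")

defaults = parser.parse_args([])
disabled = parser.parse_args(["--not-use-wscale"])

print(defaults.not_use_wscale)  # True
print(disabled.not_use_wscale)  # False
print(defaults.use_bn)          # False

# Same "key=value" layout that save_args writes, built in memory here.
lines = ["{}={}".format(k, v) for k, v in sorted(vars(defaults).items())]
print("\n".join(lines))
```

In other words, code that reads `args.not_use_wscale` is effectively reading "use the equalized learning rate", despite the attribute's name.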
@@ -0,0 +1,53 @@ | ||
# Copyright (c) 2017 Sony Corporation. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
import glob | ||
from nnabla import logger | ||
from nnabla.utils.data_iterator import data_iterator_simple | ||
import os | ||
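# Note: scipy.misc.imread and imresize were removed in later SciPy releases; this import requires an older SciPy. | ||
from scipy.misc import imread, imresize | ||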
import sys | ||
|
||
import numpy as np | ||
|
||
|
||
def data_iterator(img_path, batch_size, | ||
imsize=(128, 128), num_samples=100, shuffle=True, rng=None, dataset_name="CelebA"): | ||
if dataset_name == "CelebA": | ||
di = data_iterator_celeba(img_path, batch_size, | ||
imsize=imsize, num_samples=num_samples, shuffle=shuffle, rng=rng) | ||
else: | ||
logger.error("Currently, only CelebA is supported.") | ||
sys.exit(1) | ||
return di | ||
|
||
|
||
def data_iterator_celeba(img_path, batch_size, imsize=(128, 128), num_samples=100, shuffle=True, rng=None): | ||
imgs = glob.glob("{}/*.png".format(img_path)) | ||
if num_samples == -1: | ||
num_samples = len(imgs) | ||
else: | ||
logger.info( | ||
"Num. of data ({}) is used for debugging".format(num_samples)) | ||
|
||
def load_func(i): | ||
cx = 89 | ||
cy = 121 | ||
img = imread(imgs[i]) | ||
img = img[cy - 64:cy + 64, cx - 64:cx + 64, :].transpose(2, 0, 1) / 255. | ||
img = img * 2. - 1. | ||
return img, None | ||
return data_iterator_simple(load_func, num_samples, batch_size, shuffle=shuffle, rng=rng, with_file_cache=False) |
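
The preprocessing inside `load_func` (a 128x128 crop centered at `cx=89`, `cy=121`, a channel-first transpose, and scaling to `[-1, 1]`) can be checked in isolation. The sketch below applies the same arithmetic to a synthetic array rather than a real file. The 218x178 input size is an assumption about the aligned CelebA images, and `crop_and_scale` is a hypothetical helper, not part of the code above.

```python
import numpy as np


def crop_and_scale(img, cx=89, cy=121, half=64):
    """Crop a (H, W, C) uint8 image around (cx, cy) and map it to CHW in [-1, 1]."""
    patch = img[cy - half:cy + half, cx - half:cx + half, :]
    chw = patch.transpose(2, 0, 1) / 255.0  # HWC -> CHW, values in [0, 1]
    return chw * 2.0 - 1.0                  # rescale to [-1, 1]


# Synthetic stand-in for one aligned CelebA image (assumed 218x178 RGB).
rng = np.random.RandomState(0)
img = rng.randint(0, 256, size=(218, 178, 3)).astype(np.uint8)
out = crop_and_scale(img)
print(out.shape, out.min(), out.max())
```

The `[-1, 1]` range matches the `tanh` output activation selected by `--last-act` in `args.py`.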