Skip to content

Commit

Permalink
Merge pull request #10 from valentingol/dev
Browse files Browse the repository at this point in the history
🆙 Release 0.2.0
  • Loading branch information
valentingol authored Jul 25, 2022
2 parents c477ece + 1383ad8 commit d3645d5
Show file tree
Hide file tree
Showing 19 changed files with 373 additions and 222 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@
configs/runs/
res/
wandb/

**tmp**
63 changes: 63 additions & 0 deletions apps/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""Test the generator."""

import os
import os.path as osp

import numpy as np
import torch
from PIL import Image

from utils.configs import ConfigType, GlobalConfig
from utils.data.process import to_img_grid
from utils.sagan.modules import SAGenerator


def test(config: ConfigType) -> None:
"""Test the generator."""
architecture = config.model.architecture
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = config.test_batch_size
dataset = np.load(config.dataset_path)
n_classes = dataset.max() + 1

model_dir = osp.join('res', config.run_name, 'models')
step = config.recover_model_step
if step <= 0:
model_path = osp.join(model_dir, 'generator_last.pth')
else:
model_path = osp.join(model_dir, f'generator_step_{step}.pth')

if architecture == 'sagan':
generator = SAGenerator(n_classes=n_classes,
data_size=config.model.data_size,
z_dim=config.model.z_dim,
conv_dim=config.model.g_conv_dim).to(device)

z_input = torch.randn(batch_size, config.model.z_dim, device=device)
generator.load_state_dict(torch.load(model_path))
generator.eval()
with torch.no_grad():
images, attn_list = generator.generate(z_input, with_attn=True)
# Save sample images in a grid
img_out_path = osp.join('res', config.run_name, 'samples',
'test_samples.png')
img_grid = to_img_grid(images)
pil_images = Image.fromarray(img_grid)
pil_images.show(title='Test Samples (run ' + config.run_name + ')')
pil_images.save(img_out_path)

if config.save_attn:
# Save attention
attn_out_path = osp.join('res', config.run_name, 'attention',
'test_gen_attn')
os.makedirs(attn_out_path, exist_ok=True)
attn_list = [attn.detach().cpu().numpy() for attn in attn_list]
for i, attn in enumerate(attn_list):
np.save(osp.join(attn_out_path, f'attn_{i}.npy'), attn)


if __name__ == '__main__':
global_config = GlobalConfig.build_from_argv(
fallback='configs/exp/base.yaml')
# NOTE: The config is not saved when testing only
test(global_config)
70 changes: 34 additions & 36 deletions apps/train.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
"""Train and test the SAGAN model."""
"""Train the GAN."""

import os.path as osp

import torch
import wandb
from torch.backends import cudnn

Expand All @@ -10,43 +13,37 @@

def train(config: ConfigType) -> None:
"""Train and test the SAGAN model."""
# For fast training
cudnn.benchmark = True
if not torch.cuda.is_available():
raise ValueError('CUDA is not available and is required for training.')

if config.train:
batch_size = config.training.batch_size
else:
# TODO: add test batch size
raise NotImplementedError('Test not implemented yet! (TODO)')
batch_size = config.training.batch_size
architecture = config.model.architecture

# For fast training
cudnn.benchmark = True
# Data loader
data_loader = DataLoader2DFacies(dataset_path=config.dataset_path,
data_size=config.model.data_size,
batch_size=batch_size, shuffle=True,
num_workers=config.num_workers).loader()

if config.train:
architecture = config.model.architecture
if architecture == 'sagan':
trainer = TrainerSAGAN(data_loader, config)
else:
raise NotImplementedError(f'Architecture "{architecture}" '
'is not implemented!')
trainer.train()

# TODO: test the model here
# Model
if architecture == 'sagan':
trainer = TrainerSAGAN(data_loader, config)
else:
raise NotImplementedError(f'Architecture "{architecture}" '
'is not implemented!')
# Train
trainer.train()


def train_wandb() -> None:
"""Run the train using wandb."""
wandb.init(
config=global_config.get_dict(),
entity=global_config.wandb.entity,
project=global_config.wandb.project,
mode=global_config.wandb.mode,
group=global_config.wandb.group,
dir='./wandb_metadata',
)
wandb.init(config=global_config.get_dict(),
entity=global_config.wandb.entity,
project=global_config.wandb.project,
mode=global_config.wandb.mode, group=global_config.wandb.group,
dir='./wandb_metadata',
)
if global_config.wandb.sweep is None:
# No sweep, run the train with global config
train(global_config)
Expand All @@ -57,9 +54,11 @@ def train_wandb() -> None:
# (under format 'sub.config.key': value)
config_updated = {**global_config.get_dict(), **dict(wandb.config)}
# Avoid re-initializing sub-configs with preprocess routines
config_updated = {key: val for key, val in config_updated.items()
if not (key.endswith('config_path')
or key == 'config_save_path')}
config_updated = {
key: val
for key, val in config_updated.items()
if not (key.endswith('config_path') or key == 'config_save_path')
}
# Apply the merge
config = GlobalConfig.load_config(config_updated,
do_not_merge_command_line=True,
Expand All @@ -71,11 +70,10 @@ def main() -> None:
"""Run the train using wandb (+sweep) or not."""
if global_config.wandb.use_wandb:
if global_config.wandb.sweep is not None:
sweep_id = wandb.sweep(
sweep=global_config.wandb.sweep,
entity=global_config.wandb.entity,
project=global_config.wandb.project,
)
sweep_id = wandb.sweep(sweep=global_config.wandb.sweep,
entity=global_config.wandb.entity,
project=global_config.wandb.project,
)
wandb.agent(sweep_id, function=train_wandb)
else:
train_wandb()
Expand All @@ -86,5 +84,5 @@ def main() -> None:
if __name__ == '__main__':
global_config = GlobalConfig.build_from_argv(
fallback='configs/exp/base.yaml')
global_config.save(global_config.config_save_path)
global_config.save(osp.join(global_config.config_save_path, 'config'))
main()
2 changes: 1 addition & 1 deletion configs/default/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
run_name: null # (should be overwritten by the experiment config)
config_save_path: null # (should be overwritten by the experiment config)
dataset_path: null # (should be overwritten by the experiment config)
train: True # if False, analyse a recovered model (see recover_model_step)
use_wandb: False
recover_model_step: 0 # the step to recover the model, 0 to not recover
save_attn: False
# num_workers: number of worker to process the data in parallel
# (0 to not apply parallelism)
num_workers: 0
test_batch_size: 64

# Additional configs
model_config_path: configs/default/model.yaml
Expand Down
9 changes: 5 additions & 4 deletions configs/unittest/data32.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@ config_save_path: configs/runs/tmp_test/config
dataset_path: null # use custom simple dataset instead
save_attn: True

wandb.use_wandb: False
model.data_size: 32

training.adv_loss: hinge
training.total_step: 2
training.batch_size: 2
training.log_step: 1
training.sample_step: 2
training.model_save_step: 2
training.sample_step: 2
training.total_step: 2

model.data_size: 32
wandb.use_wandb: False
10 changes: 6 additions & 4 deletions configs/unittest/data64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ config_save_path: configs/runs/tmp_test/config
dataset_path: null # use custom simple dataset instead
save_attn: True

wandb.use_wandb: False
model.data_size: 64

training.total_step: 2
training.adv_loss: wgan-gp
training.batch_size: 2
training.log_step: 1
training.sample_step: 2
training.model_save_step: 2
training.sample_step: 2
training.total_step: 2

model.data_size: 64
wandb.use_wandb: False
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# Installation
config = {
'name': 'sagan-facies-modeling',
'version': '0.1.0',
'version': '0.2.0',
'description': 'Facies modeling with SAGAN.',
'author': 'Valentin Goldite',
'author_email': '[email protected]',
Expand Down
1 change: 1 addition & 0 deletions tests/utils/data/test_data_loader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tests for utils/data/data_loader.py."""

import os

import numpy as np
Expand Down
3 changes: 1 addition & 2 deletions tests/utils/github_actions/test_pydocstyle_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ def test_check_output() -> None:
sys.argv = ['--n_errors=0']
check_output()
sys.argv = ['--n_errors=1']
with pytest.raises(ValueError,
match='.*found 1 error.*'):
with pytest.raises(ValueError, match='.*found 1 error.*'):
check_output()
sys.argv = old_argv

Expand Down
6 changes: 6 additions & 0 deletions tests/utils/sagan/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,9 @@ def test_sa_generator() -> None:
assert len(att_list) == 2
assert att_list[0].shape == (1, 256, 256)
assert att_list[1].shape == (1, 1024, 1024)
# Test generate method
images, _ = gen.generate(z, with_attn=True)
assert images.shape == (1, 64, 64, 3)
images, attn_list = gen.generate(z, with_attn=False)
assert images.shape == (1, 64, 64, 3)
assert not attn_list
2 changes: 1 addition & 1 deletion tests/utils/sagan/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_train(data_loaders: Tuple[DataLoader, DataLoader]) -> None:
assert osp.exists('res/tmp_test/models/generator_step_2.pth')
assert osp.exists('res/tmp_test/models/discriminator_step_2.pth')
assert osp.exists('res/tmp_test/samples/images_step_2.png')
assert osp.exists('res/tmp_test/attention/gen_attn0_step_2.npy')
assert osp.exists('res/tmp_test/attention/gen_attn_step_2/attn_0.npy')
# Remove tmp folders
shutil.rmtree('res/tmp_test')

Expand Down
1 change: 1 addition & 0 deletions tests/utils/test_configs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tests for utils/config.py."""

import sys

from utils.configs import GlobalConfig
Expand Down
8 changes: 4 additions & 4 deletions utils/github_actions/color.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""Color utils for GitHub actions."""

from colorsys import hsv_to_rgb


def score_to_hex_color(
score: float, score_min: float, score_max: float
) -> str:
"""Convert score to hex color red > brightgreen."""
def score_to_hex_color(score: float, score_min: float,
score_max: float) -> str:
"""Convert score to hex color red -> bright green."""
norm_score = max(0, (score-score_min) / (score_max-score_min))
hsv = (1 / 3 * norm_score, 1, 1)
rgb = hsv_to_rgb(*hsv)
Expand Down
6 changes: 2 additions & 4 deletions utils/github_actions/pydocstyle_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@ def check_output() -> None:
n_errors = int(arg.split('=')[1])

if n_errors > 0:
raise ValueError(
f'Pydocstyle found {n_errors} errors in python '
'docstrings. Please fix them.'
)
raise ValueError(f'Pydocstyle found {n_errors} errors in python '
'docstrings. Please fix them.')


if __name__ == '__main__':
Expand Down
7 changes: 3 additions & 4 deletions utils/github_actions/pylint_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Manage Pylint output on workflow."""

import sys
from typing import Tuple

Expand Down Expand Up @@ -26,10 +27,8 @@ def check_output() -> Tuple[float, float]:
score_min = float(arg.split('=')[1])

if score < score_min:
raise ValueError(
f'Pylint score {score} is lower than '
f'minimum ({score_min}).'
)
raise ValueError(f'Pylint score {score} is lower than '
f'minimum ({score_min}).')

return score, score_min

Expand Down
14 changes: 14 additions & 0 deletions utils/sagan/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch
from torch import nn

from utils.data.process import color_data_np
from utils.sagan.spectral import SpectralNorm

TensorWithAttn = Tuple[torch.Tensor, List[torch.Tensor]]
Expand Down Expand Up @@ -230,3 +231,16 @@ def forward(self, z: torch.Tensor) -> TensorWithAttn:
x = self.conv_last(x)
x = nn.Softmax(dim=1)(x)
return x, att_list

def generate(self, z_input: torch.Tensor,
with_attn: bool = False) -> Tuple[np.ndarray,
List[torch.Tensor]]:
"""Return generated images and eventually attention list."""
out, attn_list = self.forward(z_input)
# Quantize + color generated data
out = torch.argmax(out, dim=1)
out = out.detach().cpu().numpy()
images = color_data_np(out)
if with_attn:
return images, attn_list
return images, []
Loading

0 comments on commit d3645d5

Please sign in to comment.