Commit 408001c: Update normal version

Warvito committed Mar 10, 2023
1 parent 15e109b
Showing 5 changed files with 14 additions and 38 deletions.
README.md: 1 addition & 0 deletions

@@ -19,6 +19,7 @@ TODO LIST:
 - [ ] Add synthetic sentences based on other source of information
 - [ ] Maybe use LLM to augment the reports
 - [ ] Add warmup time for the diffusion model
+- [ ] Include images from ChestX-ray14 https://nihcc.app.box.com/v/ChestXray-NIHCC/folder/36938765345


 ## C1

configs/stage1/aekl_v0.yaml: 1 addition & 6 deletions

@@ -8,7 +8,7 @@ stage1:
     spatial_dims: 2
     in_channels: 1
     out_channels: 1
-    num_channels: [64, 128, 128, 128]
+    num_channels: [64, 128, 128, 256]
     latent_channels: 3
     num_res_blocks: 2
     attention_levels: [False, False, False, False]
@@ -22,11 +22,6 @@ discriminator:
     num_layers_d: 3
     in_channels: 1
     out_channels: 1
-    kernel_size: 4
-    activation: "LEAKYRELU"
-    norm: "BATCH"
-    bias: False
-    padding: 1
 
 perceptual_network:
   params:

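For context, a minimal sketch of how this config is presumably consumed; the loading code is not part of this diff, so the OmegaConf usage and the MONAI Generative Models constructors (AutoencoderKL, PatchDiscriminator) are assumptions. If the params blocks map directly onto those constructors, dropping kernel_size, activation, norm, bias, and padding from the discriminator block would simply fall back to PatchDiscriminator's defaults, which appear to match the removed values:

# Hypothetical sketch: assumes the params blocks feed the MONAI
# Generative Models constructors directly via OmegaConf.
from generative.networks.nets import AutoencoderKL, PatchDiscriminator
from omegaconf import OmegaConf

config = OmegaConf.load("configs/stage1/aekl_v0.yaml")
stage1 = AutoencoderKL(**config["stage1"]["params"])
discriminator = PatchDiscriminator(**config["discriminator"]["params"])
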
src/python/testing/generate_sample_local.py: 10 additions & 0 deletions

@@ -103,3 +103,13 @@
 
 plt.imshow(sample.cpu()[0, 0, :, :], cmap="gray", vmin=0, vmax=1)
 plt.show()
+
+
+torch.save(
+    diffusion.state_dict(),
+    "/media/walter/Storage/Projects/GenerativeModels/model-zoo/models/cxr_image_synthesis_latent_diffusion_model/models/diffusion_model.pth",
+)
+torch.save(
+    stage1.state_dict(),
+    "/media/walter/Storage/Projects/GenerativeModels/model-zoo/models/cxr_image_synthesis_latent_diffusion_model/models/autoencoder.pth",
+)
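A short sketch of how these exported checkpoints could be restored elsewhere, assuming `diffusion` and `stage1` are rebuilt with the same constructors used earlier in this script (paths shortened here for illustration):

import torch

# Assumption: diffusion and stage1 are already instantiated with the
# same architectures used when the state dicts were saved.
diffusion.load_state_dict(torch.load("models/diffusion_model.pth", map_location="cpu"))
stage1.load_state_dict(torch.load("models/autoencoder.pth", map_location="cpu"))
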
src/python/training/training_functions_old_disc.py: 1 addition & 16 deletions

@@ -4,28 +4,13 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from pynvml.smi import nvidia_smi
 from tensorboardX import SummaryWriter
 from torch.cuda.amp import GradScaler, autocast
 from tqdm import tqdm
+from training_functions import get_lr, print_gpu_memory_report
 from util import log_reconstructions
 
 
-def get_lr(optimizer):
-    for param_group in optimizer.param_groups:
-        return param_group["lr"]
-
-
-def print_gpu_memory_report():
-    if torch.cuda.is_available():
-        nvsmi = nvidia_smi.getInstance()
-        data = nvsmi.DeviceQuery("memory.used, memory.total, utilization.gpu")["gpu"]
-        print("Memory report")
-        for i, data_by_rank in enumerate(data):
-            mem_report = data_by_rank["fb_memory_usage"]
-            print(f"gpu:{i} mem(%) {int(mem_report['used'] * 100.0 / mem_report['total'])}")
-
-
 # ----------------------------------------------------------------------------------------------------------------------
 # AUTOENCODER KL
 # ----------------------------------------------------------------------------------------------------------------------

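The deleted helpers match the names now imported from training_functions, so that module presumably defines them along these lines (a sketch mirroring the removed code; training_functions itself is not shown in this diff, so treat its contents as an assumption):

import torch
from pynvml.smi import nvidia_smi


def get_lr(optimizer):
    # Report the learning rate of the first parameter group.
    for param_group in optimizer.param_groups:
        return param_group["lr"]


def print_gpu_memory_report():
    # Query per-GPU memory usage via pynvml and print used memory as a percentage.
    if torch.cuda.is_available():
        nvsmi = nvidia_smi.getInstance()
        data = nvsmi.DeviceQuery("memory.used, memory.total, utilization.gpu")["gpu"]
        print("Memory report")
        for i, data_by_rank in enumerate(data):
            mem_report = data_by_rank["fb_memory_usage"]
            print(f"gpu:{i} mem(%) {int(mem_report['used'] * 100.0 / mem_report['total'])}")
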
src/python/training/training_functions_original_disc.py: 1 addition & 16 deletions

@@ -4,28 +4,13 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from pynvml.smi import nvidia_smi
 from tensorboardX import SummaryWriter
 from torch.cuda.amp import GradScaler, autocast
 from tqdm import tqdm
+from training_functions import get_lr, print_gpu_memory_report
 from util import log_reconstructions
 
 
-def get_lr(optimizer):
-    for param_group in optimizer.param_groups:
-        return param_group["lr"]
-
-
-def print_gpu_memory_report():
-    if torch.cuda.is_available():
-        nvsmi = nvidia_smi.getInstance()
-        data = nvsmi.DeviceQuery("memory.used, memory.total, utilization.gpu")["gpu"]
-        print("Memory report")
-        for i, data_by_rank in enumerate(data):
-            mem_report = data_by_rank["fb_memory_usage"]
-            print(f"gpu:{i} mem(%) {int(mem_report['used'] * 100.0 / mem_report['total'])}")
-
-
 def hinge_d_loss(logits_real, logits_fake):
     loss_real = torch.mean(F.relu(1.0 - logits_real))
     loss_fake = torch.mean(F.relu(1.0 + logits_fake))

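The hunk cuts hinge_d_loss off after the two partial losses; the conventional completion averages them, as sketched below (an assumption, since the remaining lines are collapsed in this diff):

def hinge_d_loss(logits_real, logits_fake):
    # Hinge loss for the discriminator: real logits are pushed above +1,
    # fake logits below -1, and the two terms are averaged.
    loss_real = torch.mean(F.relu(1.0 - logits_real))
    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
    return 0.5 * (loss_real + loss_fake)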
