Skip to content

Commit

Permalink
Added use_ema option to diffusion
Browse files Browse the repository at this point in the history
  • Loading branch information
abarankab committed Jul 21, 2021
1 parent 3b26994 commit 90d0389
Showing 1 changed file with 19 additions and 11 deletions.
30 changes: 19 additions & 11 deletions ddpm/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(
model,
img_size,
img_channels,
num_classes,
betas,
loss_type="l1",
ema_decay=0.995,
Expand All @@ -50,6 +51,7 @@ def __init__(

self.img_size = img_size
self.img_channels = img_channels
self.num_classes = num_classes

if loss_type not in ["l1", "l2"]:
raise ValueError("__init__() got unknown loss type")
Expand Down Expand Up @@ -83,45 +85,51 @@ def update_ema(self):
self.step += 1

@torch.no_grad()
def remove_noise(self, x, t, y, use_ema=True):
    """Perform one reverse-diffusion denoising step on ``x`` at timestep ``t``.

    Args:
        x: noisy input batch.
        t: batch of timestep indices (one entry per sample in ``x``).
        y: conditioning labels forwarded to the noise-prediction model.
        use_ema: if True (default), use the EMA-averaged model to predict
            noise; otherwise use the raw training model.

    Returns:
        The partially denoised batch, same shape as ``x``.
    """
    # Select the noise predictor once instead of duplicating the whole
    # denoising expression in both branches.
    model = self.ema_model if use_ema else self.model
    return (
        (x - extract(self.remove_noise_coeff, t, x.shape) * model(x, t, y)) *
        extract(self.reciprocal_sqrt_alphas, t, x.shape)
    )

@torch.no_grad()
def sample(self, batch_size, device, y=None, use_ema=True):
    """Generate ``batch_size`` samples by running the full reverse process.

    Args:
        batch_size: number of images to generate.
        device: torch device to run generation on.
        y: optional conditioning labels; if None, labels are drawn
            uniformly from ``[0, self.num_classes)``.
        use_ema: forwarded to ``remove_noise`` — whether to denoise with
            the EMA model (default) or the raw model.

    Returns:
        Generated batch on CPU, detached from the graph.

    Raises:
        ValueError: if ``y`` is given and its length differs from ``batch_size``.
    """
    if y is not None and batch_size != len(y):
        raise ValueError("sample batch size different from length of given y")

    if y is None:
        # BUG FIX: torch.randint requires its `size` argument to be a
        # tuple; passing the bare int `batch_size` raises a TypeError.
        y = torch.randint(self.num_classes, (batch_size,), device=device)

    # Start from pure Gaussian noise.
    x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device)

    # Reverse diffusion: t = T-1, ..., 0.
    for t in range(self.num_timesteps - 1, -1, -1):
        t_batch = torch.tensor([t], device=device).repeat(batch_size)
        x = self.remove_noise(x, t_batch, y, use_ema)

        # NOTE(review): sample_diffusion_sequence adds noise for t > 0,
        # while this uses t > 1 (skips noise at t == 1 too) — confirm
        # which boundary is intended; behavior preserved here.
        if t > 1:
            x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x)

    return x.cpu().detach()

@torch.no_grad()
def sample_diffusion_sequence(self, batch_size, device, y=None):
def sample_diffusion_sequence(self, batch_size, device, y=None, use_ema=True):
if y is not None and batch_size != len(y):
raise ValueError("sample batch size different from length of given y")

if y is None:
y = torch.randint(num_classes, batch_size, device=device)
y = torch.randint(self.num_classes, batch_size, device=device)

x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device)
diffusion_sequence = [x.cpu().detach()]

for t in range(self.num_timesteps - 1, -1, -1):
t_batch = torch.tensor([t], device=device).repeat(batch_size)
x = self.remove_noise(x, t_batch, y)
x = self.remove_noise(x, t_batch, y, use_ema)

if t > 0:
x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x)
Expand Down

0 comments on commit 90d0389

Please sign in to comment.