From b1dc3054a792721ea1f61a139a5befe5acf00f4a Mon Sep 17 00:00:00 2001
From: mdraw
Date: Tue, 5 Feb 2019 03:43:08 +0100
Subject: [PATCH] Make drop_prob mutable in DropBlock ScriptModules

This makes it possible to schedule drop_prob during training when
DropBlock is used in JIT mode. The LinearScheduler module can't be
used in JIT mode yet, but you can change drop_prob from outside of
the model, e.g. in your training loop.
---
 dropblock/dropblock.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/dropblock/dropblock.py b/dropblock/dropblock.py
index b5797cc..2fbeabf 100644
--- a/dropblock/dropblock.py
+++ b/dropblock/dropblock.py
@@ -33,12 +33,12 @@ class DropBlock2D(Module):
 
     """
 
-    __constants__ = ['drop_prob', 'block_size', 'gamma']
+    __constants__ = ['block_size', 'gamma']
 
     def __init__(self, drop_prob, block_size):
         super(DropBlock2D, self).__init__()
 
-        self.drop_prob = drop_prob  # type: float
+        self.register_buffer('drop_prob', torch.tensor(drop_prob, dtype=torch.float32))
         self.block_size = block_size  # type: int
 
         # get gamma value
@@ -51,7 +51,7 @@ def forward(self, x):
         assert x.dim() == 4, \
             "Expected input with 4 dimensions (bsize, channels, height, width)"
 
-        if not self.training or self.drop_prob == 0.:
+        if not self.training or bool(self.drop_prob == torch.zeros(())):
             out = x
         else:
             # sample mask
@@ -83,7 +83,7 @@ def _compute_block_mask(self, mask):
         return block_mask
 
     def _compute_gamma(self):
-        return self.drop_prob / (self.block_size ** 2)
+        return self.drop_prob.item() / (self.block_size ** 2)
 
 
 class DropBlock3D(Module):
@@ -107,12 +107,12 @@ class DropBlock3D(Module):
 
     """
 
-    __constants__ = ['drop_prob', 'block_size', 'gamma']
+    __constants__ = ['block_size', 'gamma']
 
     def __init__(self, drop_prob, block_size):
         super(DropBlock3D, self).__init__()
 
-        self.drop_prob = drop_prob  # type: float
+        self.register_buffer('drop_prob', torch.tensor(drop_prob, dtype=torch.float32))
         self.block_size = block_size  # type: int
 
         # get gamma value
@@ -125,7 +125,7 @@ def forward(self, x):
         assert x.dim() == 5, \
             "Expected input with 5 dimensions (bsize, channels, depth, height, width)"
 
-        if not self.training or self.drop_prob == 0.:
+        if not self.training or bool(self.drop_prob == torch.zeros(())):
             out = x
         else:
             # sample mask
@@ -157,4 +157,4 @@ def _compute_block_mask(self, mask):
         return block_mask
 
     def _compute_gamma(self):
-        return self.drop_prob / (self.block_size ** 3)
+        return self.drop_prob.item() / (self.block_size ** 3)
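
A minimal usage sketch of the new behavior, assuming this branch's
DropBlock2D (a ScriptModule exposing the drop_prob buffer introduced
above); the 0.25 target and the 100-step ramp below are illustrative
stand-ins for a real schedule, not part of the patch:

    import torch

    from dropblock import DropBlock2D

    dropblock = DropBlock2D(drop_prob=0.0, block_size=5)
    dropblock.train()  # drop_prob only takes effect in training mode

    x = torch.randn(8, 3, 32, 32)

    # Linearly ramp drop_prob from 0.0 to 0.25 over 100 steps,
    # approximating LinearScheduler from outside the JIT-compiled model.
    for step in range(100):
        # In-place update of the 0-dim buffer; the scripted forward()
        # reads the new value on the next call.
        dropblock.drop_prob.fill_(0.25 * step / 99)
        out = dropblock(x)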