Fixing PolynomialBS policy implementation (#12)
* The new `PolynomialBS` implementation increases the batch size linearly when `power=1.0`, instead of exponentially (illustrated in the sketch below).
* Bumping the version.
ancestor-mithril authored Jun 5, 2024
1 parent 776eaf2 commit 369f4e5
Showing 3 changed files with 11 additions and 10 deletions.
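To make the behavioral change concrete, here is a minimal sketch, not part of the commit, that replays both formulas for `power=1.0`. The base batch size of 10 and `total_iters=5` are assumptions inferred from the updated test expectations below, and Python's built-in `round` stands in for the library's `rint`:

```python
# Standalone sketch: old vs. new PolynomialBS factor for power=1.0.
# Assumed values: base batch size 10, total_iters=5 (matching the updated tests).

def old_factor(last_epoch: int, total_iters: int, power: float) -> float:
    # Per-step factor used before this commit.
    return ((1.0 - (last_epoch - 1) / total_iters) /
            (1.0 - last_epoch / total_iters)) ** power

def new_factor(last_epoch: int, total_iters: int, power: float) -> float:
    # Per-step factor introduced by this commit.
    remaining_steps = total_iters - last_epoch
    return 2.0 - ((1.0 - remaining_steps / total_iters) /
                  (1.0 - (remaining_steps - 1) / total_iters)) ** power

def schedule(factor_fn, base_bs=10, total_iters=5, power=1.0, n_epochs=8):
    # Apply the factor once per epoch while the schedule is active,
    # rounding each step (round() stands in for the library's rint).
    bs, sizes = base_bs, [base_bs]
    for epoch in range(1, n_epochs):
        if epoch < total_iters:
            bs = round(bs * factor_fn(epoch, total_iters, power))
        sizes.append(bs)
    return sizes

print(schedule(old_factor))  # [10, 12, 16, 24, 48, 48, 48, 48] -> accelerating growth
print(schedule(new_factor))  # [10, 15, 20, 25, 30, 30, 30, 30] -> +5 per step (linear)
```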
11 changes: 6 additions & 5 deletions bs_scheduler/batch_size_schedulers.py
@@ -810,8 +810,8 @@ def load_state_dict(self, state_dict: dict):

class PolynomialBS(BSScheduler):
""" Increases the batch size using a polynomial function in the given total_iters. Unlike
-    torch.optim.lr_scheduler.PolynomialLR whose polynomial factor decays from 1.0 to 0.0, in this case the polynomial
-    factor increases from 1.0 to 2.0 ** power.
+    torch.optim.lr_scheduler.PolynomialLR whose polynomial factor decays from 1.0 to 0.5 ** power, in this case the
+    polynomial factor decays from 1.5 ** power to 1.0.
Args:
dataloader (DataLoader): Wrapped dataloader.
@@ -858,8 +858,9 @@ def get_new_bs(self) -> int:
self._finished = self.last_epoch >= self.total_iters
return self.batch_size

-        factor = ((1.0 - (self.last_epoch - 1) / self.total_iters) / (
-                1.0 - self.last_epoch / self.total_iters)) ** self.power
+        remaining_steps = self.total_iters - self.last_epoch
+        factor = 2.0 - ((1.0 - remaining_steps / self.total_iters) / (
+                1.0 - (remaining_steps - 1) / self.total_iters)) ** self.power
return rint(self.batch_size * factor)
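Why `power=1.0` now gives a linear increase: at step `k` the new factor simplifies to `(k + 2) / (k + 1)` regardless of `total_iters`, and those factors telescope to `(k + 2) / 2`, i.e. a constant increase of half the base batch size per step. A small exact-arithmetic check of that claim, not part of the commit (the helper name and `total_iters=5` are illustrative):

```python
from fractions import Fraction

def exact_new_factor(last_epoch: int, total_iters: int) -> Fraction:
    # Exact value of the new factor for power=1.0.
    remaining = total_iters - last_epoch
    return 2 - (Fraction(total_iters - remaining, total_iters) /
                Fraction(total_iters - remaining + 1, total_iters))

total_iters = 5
for k in range(1, total_iters):
    # The per-step factor is (k + 2) / (k + 1), independent of total_iters.
    assert exact_new_factor(k, total_iters) == Fraction(k + 2, k + 1)

# The factors telescope: prod_{j=1..k} (j+2)/(j+1) = (k+2)/2,
# so batch_size_k = base * (k + 2) / 2, a linear function of k.
base, bs = Fraction(10), Fraction(10)
for k in range(1, total_iters):
    bs *= exact_new_factor(k, total_iters)
    assert bs == base * Fraction(k + 2, 2)

print([int(base * Fraction(k + 2, 2)) for k in range(total_iters)])  # [10, 15, 20, 25, 30]
```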


@@ -933,7 +934,7 @@ def get_new_bs(self) -> int:
else:
new_bs = (1 + math.cos(math.pi * self.last_epoch / self.total_iters)) / (
1 + math.cos(math.pi * (self.last_epoch - 1) / self.total_iters)) * (
-                          self._float_batch_size - self.max_batch_size) + self.max_batch_size
+                      self._float_batch_size - self.max_batch_size) + self.max_batch_size

self._float_batch_size = new_bs
return clip(rint(new_bs), min=self.base_batch_size, max=self.max_batch_size)
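The `CosineAnnealingBS` hunk above appears to be an indentation-only fix to the continuation line; for context, the recursive update in the `else:` branch is equivalent to a closed-form cosine interpolation from the base batch size (epoch 0) to the maximum batch size (epoch `total_iters`). A quick numeric check of that equivalence, not part of the commit and with illustrative values:

```python
import math

base_bs, max_bs, total_iters = 10, 100, 5

def closed_form(epoch: int) -> float:
    # Cosine interpolation: base_bs at epoch 0, max_bs at epoch total_iters.
    return max_bs + (base_bs - max_bs) * (1 + math.cos(math.pi * epoch / total_iters)) / 2

# Recursive form, mirroring the ratio used in get_new_bs above.
float_bs = float(base_bs)
for epoch in range(1, total_iters + 1):
    ratio = (1 + math.cos(math.pi * epoch / total_iters)) / (
        1 + math.cos(math.pi * (epoch - 1) / total_iters))
    float_bs = ratio * (float_bs - max_bs) + max_bs
    assert abs(float_bs - closed_form(epoch)) < 1e-9
print("recursive update matches the closed form")
```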
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "bs_scheduler"
version = "0.4.0"
version = "0.4.1"
requires-python = ">=3.9"
description = "A PyTorch Dataloader compatible batch size scheduler library."
readme = "README.md"
8 changes: 4 additions & 4 deletions tests/test_PolynomialBS.py
@@ -23,7 +23,7 @@ def test_dataloader_lengths(self):
scheduler = PolynomialBS(dataloader, total_iters=total_iters, power=power, verbose=False)

epoch_lengths = simulate_n_epochs(dataloader, scheduler, n_epochs)
-        expected_batch_sizes = [10, 12, 16, 24] + [48] * 16
+        expected_batch_sizes = [10, 15, 20, 25] + [30] * 16
expected_lengths = self.compute_epoch_lengths(expected_batch_sizes, len(self.dataset), drop_last=False)
self.assertEqual(epoch_lengths, expected_lengths)

@@ -32,10 +32,10 @@ def test_dataloader_batch_size(self):
power = 1.0
n_epochs = 10
dataloader = create_dataloader(self.dataset, batch_size=self.base_batch_size)
-        scheduler = PolynomialBS(dataloader, total_iters=total_iters, power=power, max_batch_size=100, verbose=False)
+        scheduler = PolynomialBS(dataloader, total_iters=total_iters, power=power, max_batch_size=200, verbose=False)

batch_sizes = get_batch_sizes_across_epochs(dataloader, scheduler, n_epochs)
-        expected_batch_sizes = [64, 71, 80, 91, 100, 100, 100, 100, 100, 100]
+        expected_batch_sizes = [64, 96, 128, 160, 192, 200, 200, 200, 200, 200]

self.assertEqual(batch_sizes, expected_batch_sizes)
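A quick re-derivation of the updated expectations, not part of the commit: with a base batch size of 64 and `power=1.0`, the batch size grows by `base / 2 = 32` per active step (64, 96, 128, 160, 192, 224, ...), and values are clipped to `max_batch_size=200`, which is where the sequence flattens out. The `total_iters` value is set earlier in this test and is not visible in the hunk, so the sketch simply assumes enough active steps remain:

```python
# Hypothetical re-derivation of the updated expected_batch_sizes.
# Assumes power=1.0 (a +base/2 increase per active step) and clipping to max_batch_size;
# once the clip engages, the reported value stays at max_batch_size either way.
base, max_bs, n_epochs = 64, 200, 10
expected, bs = [], base
for epoch in range(n_epochs):
    expected.append(min(bs, max_bs))
    bs += base // 2  # +32 per step
print(expected)  # [64, 96, 128, 160, 192, 200, 200, 200, 200, 200]
```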

@@ -70,7 +70,7 @@ def test_graphic(self):

model = torch.nn.Linear(10, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
-        scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, total_iters=total_iters, power=0.1)
+        scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, total_iters=total_iters, power=1.0)
learning_rates = []

def get_lr(optimizer):
