Skip to content

Commit

Permalink
Fixed failing tests on pytorch nightly using torch.load
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 committed Nov 7, 2024
1 parent 3c5e213 commit c50b777
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
7 changes: 6 additions & 1 deletion tests/ignite/engine/test_deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch.nn as nn
from torch.optim import SGD
from torch.utils.data import BatchSampler, DataLoader, RandomSampler
from packaging.version import Version

import ignite.distributed as idist
from ignite.engine import Events
Expand Down Expand Up @@ -737,7 +738,11 @@ def write_data_grads_weights(e):
grad_norms.append([i, total[1]] + out2)

if sd is not None:
sd = torch.load(sd)
if Version(torch.__version__) >= Version("1.13.0"):
kwargs = {"weights_only": False}
else:
kwargs = {}
sd = torch.load(sd, **kwargs)
model.load_state_dict(sd[0])
opt.load_state_dict(sd[1])
from ignite.engine.deterministic import _repr_rng_state
Expand Down
6 changes: 5 additions & 1 deletion tests/ignite/handlers/test_state_param_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,11 @@ def test_torch_save_load(dirname):

filepath = Path(dirname) / "dummy_lambda_state_parameter_scheduler.pt"
torch.save(lambda_state_parameter_scheduler, filepath)
loaded_lambda_state_parameter_scheduler = torch.load(filepath)
if Version(torch.__version__) >= Version("1.13.0"):
kwargs = {"weights_only": False}
else:
kwargs = {}
loaded_lambda_state_parameter_scheduler = torch.load(filepath, **kwargs)

engine1 = Engine(lambda e, b: None)
lambda_state_parameter_scheduler.attach(engine1, Events.EPOCH_COMPLETED)
Expand Down

0 comments on commit c50b777

Please sign in to comment.