Skip to content

Commit

Permalink
reset notensordict test
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Oct 3, 2023
1 parent 8d36787 commit a24ab8d
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -2419,8 +2419,10 @@ def test_td3_notensordict(
loss_val = loss(**kwargs)
for i in loss_val:
assert i in loss_val_td.values(), f"{i} not in {loss_val_td.values()}"
# for i, key in enumerate(loss_val_td.keys()):
# torch.testing.assert_close(loss_val_td.get(key), loss_val[i])

for i, key in enumerate(loss.out_keys):
torch.testing.assert_close(loss_val_td.get(key), loss_val[i])

# test select
loss.select_out_keys("loss_actor", "loss_qvalue")
torch.manual_seed(0)
Expand Down

0 comments on commit a24ab8d

Please sign in to comment.