Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Performance] Reduce key accessing in transforms #1590

Merged
merged 3 commits into from
Oct 2, 2023

Conversation

matteobettini
Copy link
Contributor

No description provided.

Signed-off-by: Matteo Bettini <[email protected]>
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 2, 2023
@matteobettini matteobettini changed the title [Performance] Reduce key accessing in tranforms [Performance] Reduce key accessing in transforms Oct 2, 2023
Comment on lines 211 to 213
tensordict_keys = tensordict.keys(include_nested=True)
for in_key, out_key in zip(self.in_keys, self.out_keys):
if in_key in tensordict.keys(include_nested=True):
if in_key in tensordict_keys:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure, what about making it a set then?
Another easy option is

val = tensordict.get(in_key, default=None)
if val is not None:
    ...

that way you do the get and assertion in one go

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but that would imply multiple gets

my option gets the keys once and for all instead of trying to get each key

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can imagine that in nested tensordicts with many levels, doing a get for each key when the _apply_transform is not implemented is quite expensive

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why a set?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a set to avoid scanning through nested td multiple times
by using a set we just scan it once and then query have a constant time

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok we can do it this way

I am not so sure that the keys are more present than they are absent though since if a transform does not implement _call all keys are guaranteed to be absent

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not really following this comment
All transforms implement _call (since it's in the base), if they override _call to something else this loop isn't executed.
You mean _apply_transform?

Do you have examples of such transforms? How extensive is the problem?

Copy link
Contributor Author

@matteobettini matteobettini Oct 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes sorry i meant _apply_transform

15 transforms implement it and the others don't
i don't know how many transforms we have but that - 17 is the number that will always get None

i would say majority of transoforms does not implement it

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the list of transforms we have

    ActionMask => overrides _call
    BinarizeReward => reward is missing during reset
    CatFrames => no impact
    CatTensors => no impact
    CenterCrop => no impact
    ClipTransform => no impact
    Compose => no impact
    DeviceCastTransform => no impact
    DiscreteActionProjection => no impact
    DoubleToFloat => no impact
    DTypeCastTransform => no impact
    ExcludeTransform => no impact
    FiniteTensorDictCheck => no impact
    FlattenObservation => no impact
    FrameSkipTransform => no impact
    GrayScale => no impact
    gSDENoise => no impact
    InitTracker => no impact
    KLRewardTransform => reward is missing during reset
    NoopResetEnv => no impact
    ObservationNorm => no impact
    ObservationTransform => no impact
    PermuteTransform => no impact
    PinMemoryTransform => no impact
    R3MTransform => no impact
    RandomCropTensorDict => no impact
    RenameTransform => no impact
    Resize => no impact
    RewardClipping => reward is missing during reset
    RewardScaling => reward is missing during reset
    RewardSum => reward is missing during reset
    Reward2GoTransform => no impact
    SelectTransform => no impact
    SqueezeTransform => no impact
    StepCounter => no impact
    TargetReturn => reward is missing during reset
    TensorDictPrimer => no impact
    TimeMaxPool => no impact
    ToTensorImage => no impact
    UnsqueezeTransform => no impact
    VecGymEnvTransform => no impact
    VecNorm => no impact
    VC1Transform => no impact
    VIPRewardTransform => reward is missing during reset, not sure it has an impact though
    VIPTransform => no impact

Which make 6 transforms where this is needed, and 40 where it isn't.

What I understand from your PR (the problem you're trying to solve, even though this isn't stated anywhere so I'm just guessing here) is that you want to reduce the number of accesses to the keys of the td.
What you want also is that if a key is absent things stay efficient, in which case a set could be a good idea.
This only works when we allow missing keys which is only turned on during reset to this is already just a tiny fraction of the calls to _call.
During these calls, a key can be absent because it can't be found at reset but only during step, which again is something rather peculiar. So we're at the edge case of the rare case. Even in cases where the key is missing, set is only more efficient if there are multiple keys to be checked, which reduces again the subspace where this change would be beneficial.
Roughly, a wild guess would put using set at being beneficial for < 2% of the envs (assuming 6/40 transforms * 1/2 calls to reset -- assuming one reset at each step -- * 1/5 where there are many keys to be checked in the transform).
In all other cases using get(..., None) is more efficient.

Happy to be shown wrong though!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes sorry i meant _apply_transform

15 transforms implement it and the others don't
i don't know how many transforms we have but that - 17 is the number that will always get None

i would say majority of transoforms does not implement it

I could be wrong but I think all that don't implement it do override _call so this PR won't affect them anyway

Signed-off-by: Matteo Bettini <[email protected]>
@vmoens vmoens added the performance Performance issue or suggestion for improvement label Oct 2, 2023
Signed-off-by: Matteo Bettini <[email protected]>
@vmoens vmoens merged commit 3785609 into pytorch:main Oct 2, 2023
38 of 47 checks passed
vmoens pushed a commit to hyerra/rl that referenced this pull request Oct 10, 2023
@matteobettini matteobettini deleted the fix_transofrm branch December 4, 2023 11:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. performance Performance issue or suggestion for improvement
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants