[Performance] Reduce key accessing in transforms #1590
Signed-off-by: Matteo Bettini <[email protected]>
+ tensordict_keys = tensordict.keys(include_nested=True)
  for in_key, out_key in zip(self.in_keys, self.out_keys):
-     if in_key in tensordict.keys(include_nested=True):
+     if in_key in tensordict_keys:
not sure, what about making it a set then?
Another easy option is
    val = tensordict.get(in_key, default=None)
    if val is not None:
        ...
that way you do the get and assertion in one go
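As a minimal sketch of the "get and check in one go" suggestion, the snippet below uses plain nested dicts as a stand-in for a tensordict. The helper `nested_get` and the sample keys are hypothetical and only illustrate the pattern; they are not part of the tensordict API.

```python
from typing import Any

# Hypothetical stand-in for a nested tensordict: plain nested dicts.
data = {"observation": 1.0, "agents": {"reward": 0.5}}


def nested_get(d: dict, key, default: Any = None) -> Any:
    """Return the value at `key` (a str or a tuple of str), or `default` if absent."""
    path = (key,) if isinstance(key, str) else key
    node: Any = d
    for part in path:
        # Stop as soon as one level of the path is missing.
        if not isinstance(node, dict) or part not in node:
            return default
        node = node[part]
    return node


in_keys = ["observation", ("agents", "reward"), ("agents", "done")]
found = {}
for in_key in in_keys:
    # One lookup per key: the presence check and the retrieval are the same call.
    val = nested_get(data, in_key, default=None)
    if val is not None:
        found[in_key] = val
```

One caveat of this pattern is that a stored value of `None` cannot be told apart from a missing key.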
but that would imply multiple gets
my option gets the keys once and for all instead of trying to get each key
you can imagine that in nested tensordicts with many levels, doing a get for each key when _apply_transform is not implemented is quite expensive
why a set?
a set to avoid scanning through the nested td multiple times
by using a set we scan it just once, and each query then takes constant time
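The difference between a lazily evaluated keys view and a materialized set can be sketched as follows. The `LazyKeys` class is a hypothetical stand-in that re-walks the nested structure on every membership test; whether a real tensordict keys view behaves this way is an assumption taken from this thread, not something verified here.

```python
from typing import Iterator


def iter_nested_keys(d: dict, prefix: tuple = ()) -> Iterator[tuple]:
    """Yield every key path (as a tuple) in a nested dict, depth first."""
    for name, value in d.items():
        path = prefix + (name,)
        yield path
        if isinstance(value, dict):
            yield from iter_nested_keys(value, path)


class LazyKeys:
    """Hypothetical keys view: every `in` check starts a new traversal."""

    def __init__(self, d: dict) -> None:
        self._d = d
        self.scans = 0  # number of traversals started so far

    def __iter__(self) -> Iterator[tuple]:
        self.scans += 1
        return iter_nested_keys(self._d)

    def __contains__(self, key: tuple) -> bool:
        return any(k == key for k in self)


data = {"obs": 0, "agents": {"reward": 1, "info": {"step": 2}}}
in_keys = [("obs",), ("agents", "reward"), ("agents", "info", "step"), ("missing",)]

# Hoisting the view out of the loop still costs one traversal per key.
lazy = LazyKeys(data)
lazy_hits = [k in lazy for k in in_keys]
lazy_scans = lazy.scans

# Materializing a set costs one traversal; later checks are hash lookups.
view = LazyKeys(data)
key_set = set(view)
set_hits = [k in key_set for k in in_keys]
set_scans = view.scans
```

In this sketch both approaches return the same membership results, but the lazy view starts one traversal per queried key while the set needs a single traversal up front.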
ok we can do it this way
I am not so sure that keys are present more often than they are absent though, since if a transform does not implement _call, all keys are guaranteed to be absent
I'm not really following this comment
All transforms implement _call (since it's in the base); if they override _call with something else, this loop isn't executed.
Do you mean _apply_transform?
Do you have examples of such transforms? How extensive is the problem?
yes sorry i meant _apply_transform
15 transforms implement it and the others don't
i don't know how many transforms we have but that - 17 is the number that will always get None
i would say the majority of transforms do not implement it
This is the list of transforms we have
ActionMask => overrides _call
BinarizeReward => reward is missing during reset
CatFrames => no impact
CatTensors => no impact
CenterCrop => no impact
ClipTransform => no impact
Compose => no impact
DeviceCastTransform => no impact
DiscreteActionProjection => no impact
DoubleToFloat => no impact
DTypeCastTransform => no impact
ExcludeTransform => no impact
FiniteTensorDictCheck => no impact
FlattenObservation => no impact
FrameSkipTransform => no impact
GrayScale => no impact
gSDENoise => no impact
InitTracker => no impact
KLRewardTransform => reward is missing during reset
NoopResetEnv => no impact
ObservationNorm => no impact
ObservationTransform => no impact
PermuteTransform => no impact
PinMemoryTransform => no impact
R3MTransform => no impact
RandomCropTensorDict => no impact
RenameTransform => no impact
Resize => no impact
RewardClipping => reward is missing during reset
RewardScaling => reward is missing during reset
RewardSum => reward is missing during reset
Reward2GoTransform => no impact
SelectTransform => no impact
SqueezeTransform => no impact
StepCounter => no impact
TargetReturn => reward is missing during reset
TensorDictPrimer => no impact
TimeMaxPool => no impact
ToTensorImage => no impact
UnsqueezeTransform => no impact
VecGymEnvTransform => no impact
VecNorm => no impact
VC1Transform => no impact
VIPRewardTransform => reward is missing during reset, not sure it has an impact though
VIPTransform => no impact
Which makes 6 transforms where this is needed, and 40 where it isn't.
What I understand from your PR (the problem you're trying to solve, even though this isn't stated anywhere so I'm just guessing here) is that you want to reduce the number of accesses to the keys of the td.
What you also want is for things to stay efficient if a key is absent, in which case a set could be a good idea.
This only works when we allow missing keys, which is only turned on during reset, so this is already just a tiny fraction of the calls to _call.
During these calls, a key can be absent because it can't be found at reset but only during step, which again is something rather peculiar. So we're at the edge case of the rare case. Even in cases where the key is missing, a set is only more efficient if there are multiple keys to be checked, which again reduces the subspace where this change would be beneficial.
Roughly, a wild guess would put using a set at being beneficial for < 2% of the envs (assuming 6/40 transforms * 1/2 calls to reset -- assuming one reset at each step -- * 1/5 where there are many keys to be checked in the transform).
In all other cases, using get(..., None) is more efficient.
Happy to be shown wrong though!
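The back-of-the-envelope estimate above can be reproduced with exact fractions. The three factors are the guesses stated in the comment, not measured values, and the variable names are illustrative.

```python
from fractions import Fraction

# Guessed factors from the comment above (not measurements).
affected_transforms = Fraction(6, 40)  # transforms where a key can be missing
reset_calls = Fraction(1, 2)           # share of calls that happen at reset
many_keys = Fraction(1, 5)             # share of cases with many keys to check

benefit = affected_transforms * reset_calls * many_keys
percent = float(benefit * 100)
print(f"estimated share of cases where a set helps: {percent}%")
```

Under these assumptions the product is 3/200, that is 1.5%, which is consistent with the "< 2%" figure.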
> yes sorry i meant _apply_transform
> 15 transforms implement it and the others don't
> i don't know how many transforms we have but that - 17 is the number that will always get None
> i would say the majority of transforms do not implement it
I could be wrong but I think all that don't implement it do override _call so this PR won't affect them anyway
No description provided.