diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 5404beb0ec0..b701b2f6bf7 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -4311,6 +4311,34 @@ def ndim(self): def ndimension(self): return len(self.shape) + def pop(self, key: NestedKey, default: Any = NO_DEFAULT) -> Any: + """Removes and returns the value associated with the specified key from the composite spec. + + This method searches for the given key in the composite spec, removes it, and returns its associated value. + If the key is not found, it returns the provided default value if specified, otherwise raises a `KeyError`. + + Args: + key (NestedKey): + The key to be removed from the composite spec. It can be a single key or a nested key. + default (Any, optional): + The value to return if the specified key is not found in the composite spec. + If not provided and the key is not found, a `KeyError` is raised. + + Returns: + Any: The value associated with the specified key that was removed from the composite spec. + + Raises: + KeyError: If the specified key is not found in the composite spec and no default value is provided. + """ + key = unravel_key(key) + if key in self.keys(True, True): + result = self[key] + del self[key] + return result + elif default is not NO_DEFAULT: + return default + raise KeyError(f"{key} not found in composite spec.") + def set(self, name, spec): if self.locked: raise RuntimeError("Cannot modify a locked Composite.")