Remove interval arg, group_tensors fn, and on pred epoch end writing. Add decollate_batch when writing.
ibro45 committed Aug 14, 2023
1 parent 4a13721 commit 2ea66bd
Showing 5 changed files with 61 additions and 124 deletions.
53 changes: 25 additions & 28 deletions lighter/callbacks/utils.py
@@ -94,51 +94,48 @@ def parse_data(
return result


def group_tensors(
inputs: Union[List[Union[torch.Tensor, List, Tuple, Dict]], Tuple[Union[torch.Tensor, List, Tuple, Dict]]]
) -> Union[List, Dict]:
"""Recursively group tensors. Tensors can be standalone or inside of other data structures (list/tuple/dict).
An input list of tensors is returned as-is. Given an input list of data structures with tensors, this function
will group all tensors into a list and save it under a single data structure. Assumes that all elements of
the input list have the same type and structure.
def parse_format(format: str, parsed_preds: Dict[str, Any]) -> Dict[str, str]:
"""
Parse the given format and align it with the structure of the predictions.
If the format is a single string, all predictions will be saved in this format. If the format has a structure
(like a dictionary), it needs to match the structure of the predictions.
Args:
inputs (List[Union[torch.Tensor, List, Tuple, Dict]], Tuple[Union[torch.Tensor, List, Tuple, Dict]]):
They can be:
- Lists/tuples of dictionaries, each containing tensors to be grouped by their key.
- Lists/tuples of lists/tuples, each containing tensors to be grouped by their position.
- Lists/tuples of tensors, returned as-is.
- Nested versions of the above.
The input data structure must be the same for all elements of the list. They can be arbitrarily nested.
format (str): The storage format for the predictions, either as a string or a structured format.
parsed_preds (Dict[str, Any]): Dictionary of parsed prediction data.
Returns:
Union[List, Dict]: The grouped tensors.
Dict[str, str]: Dictionary of parsed format data corresponding to the prediction structure.
Raises:
ValueError: If the structure of the format does not align with the prediction structure.
"""
# List of dicts.
if isinstance(inputs[0], dict):
keys = inputs[0].keys()
return {key: group_tensors([input[key] for input in inputs]) for key in keys}
# List of lists or tuples.
elif isinstance(inputs[0], (list, tuple)):
return [group_tensors([input[idx] for input in inputs]) for idx in range(len(inputs[0]))]
# List of tensors.
elif isinstance(inputs[0], torch.Tensor):
return inputs
if isinstance(format, str):
# Assign the single format to all prediction keys.
parsed_format = {key: format for key in parsed_preds}
else:
raise TypeError(f"Type `{type(inputs[0])}` not supported.")
# Ensure the structured format corresponds with the predictions' structure.
parsed_format = parse_data(format)
if not set(parsed_format) == set(parsed_preds):
raise ValueError("`format` structure does not match the prediction's structure.")
return parsed_format


def preprocess_image(image: torch.Tensor) -> torch.Tensor:
def preprocess_image(image: torch.Tensor, add_batch_dim=False) -> torch.Tensor:
"""Preprocess the image before logging it. If it is a batch of multiple images,
it will create a grid image of them. In case of 3D, a single image is displayed
with slices stacked vertically, while a batch of 3D images as a grid where each
column is a different 3D image.
Args:
image (torch.Tensor): 2D or 3D image tensor.
add_batch_dim (bool, optional): Whether to add a batch dimension to the input image.
Use only when the input image does not have a batch dimension. Defaults to False.
Returns:
torch.Tensor: image ready for logging.
"""
image = image.detach().cpu()
if add_batch_dim:
image = image.unsqueeze(0)
# If 3D (BCDHW), concat the images vertically and horizontally.
if image.ndim == 5:
shape = image.shape
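For orientation, here is a minimal sketch of how the relocated `parse_format` helper (shown in the utils.py diff above) behaves; the prediction keys are hypothetical, and the import path assumes the function is exported from `lighter.callbacks.utils`, as the base writer's import below suggests.

```python
# Minimal sketch with hypothetical prediction keys: `parse_format` broadcasts a
# single string format across every prediction key produced by `parse_data`.
from lighter.callbacks.utils import parse_format

parsed_preds = {"seg": None, "recon": None}  # structure as returned by `parse_data`
parsed_format = parse_format("tensor", parsed_preds)
# A plain string is assigned to each key:
# parsed_format == {"seg": "tensor", "recon": "tensor"}

# A structured format (e.g. a dict) is instead run through `parse_data` and must
# produce the same set of keys as the predictions; otherwise a ValueError is raised.
```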
90 changes: 16 additions & 74 deletions lighter/callbacks/writer/base.py
@@ -8,39 +8,37 @@

import torch
from loguru import logger
from monai.data.utils import decollate_batch
from pytorch_lightning import Callback, Trainer

from lighter import LighterSystem
from lighter.callbacks.utils import group_tensors, parse_data
from lighter.callbacks.utils import parse_data, parse_format


class LighterBaseWriter(ABC, Callback):
"""
Base class for defining a custom Writer. It provides the structure for saving predictions in various formats.
Subclasses should implement:
1) `self._writers` property to specify the supported formats and their corresponding writer functions.
1) `self._writers` attribute to specify the supported formats and their corresponding writer functions.
2) `self.write()` method to specify the saving strategy for a prediction.
Args:
directory (str): Base directory for saving. A new sub-directory with the current date and time will be created inside.
format (Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]]): Desired format(s) for saving predictions.
The format will be passed to the `write` method.
interval (str, optional): Specifies when to save predictions - at every step or at the end of epoch. Defaults to "step".
additional_writers (Optional[Dict[str, Callable]]): Additional writer functions to be registered with the base writer.
"""

def __init__(
self,
directory: str,
format: Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]],
interval: str,
additional_writers: Optional[Dict[str, Callable]] = None,
) -> None:
# Create a unique directory using the current date and time
self.directory = Path(directory) / datetime.now().strftime("%Y%m%d_%H%M%S")
self.format = format
self.interval = interval

# Placeholder for processed format for quicker access during writes
self.parsed_format = None
@@ -66,12 +64,14 @@ def write(
format: Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]],
) -> None:
"""
Method to define how a tensor should be saved.
Method to define how a tensor should be saved. The input tensor will be a single tensor without
the batch dimension. If the batch dimension is needed, apply `tensor.unsqueeze(0)` before saving,
either in this method or in the particular writer function.
Depending on the specified format, this method should contain logic to handle the saving mechanism.
Args:
tensor (torch.Tensor): Tensor to be saved.
tensor (torch.Tensor): Tensor to be saved. It will be a single tensor without the batch dimension.
id (int): Identifier for the tensor, can be used for naming or indexing.
multi_pred_id (Optional[str]): Used when there are multiple predictions for a single input.
It can represent the index of a prediction, the key of a prediction in case of a dict,
@@ -85,10 +85,6 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None:
if stage != "predict":
return

# Validate the interval parameter
if self.interval not in ["step", "epoch"]:
raise ValueError("`interval` must be either 'step' or 'epoch'.")

# Ensure all distributed nodes write to the same directory
self.directory = trainer.strategy.broadcast(self.directory)
if trainer.is_global_zero:
@@ -102,43 +98,17 @@ def on_predict_batch_end(
self, trainer: Trainer, pl_module: LighterSystem, outputs: Any, batch: Any, batch_idx: int = 0
) -> None:
"""Callback method triggered at the end of each prediction batch/step."""
if self.interval != "step":
return

preds, ids = outputs["pred"], outputs["id"]

# Generate IDs if not provided
if ids is None:
# Fetch and decollate preds.
preds = decollate_batch(outputs["pred"], detach=True, pad=False)
# Fetch and decollate IDs if provided.
if outputs["id"] is not None:
ids = decollate_batch(outputs["id"], detach=True, pad=False)
# Generate IDs if not provided. An ID will be the index of the prediction.
else:
ids = list(range(self.last_index, self.last_index + len(preds)))
self.last_index += len(preds)

self._on_batch_or_epoch_end(preds, ids)

def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem, outputs: List[Any]) -> None:
"""Callback method triggered at the end of the prediction epoch."""
if self.interval != "epoch":
return
# Only one epoch when predicting, select its outputs.
preds, ids = outputs["pred"][0], outputs["id"][0]
# Remove the batch dimension since every pred is a single sample.
preds = [pred.squeeze(0) for pred in preds]
# Group predictions from all samples into a unified structure.
preds = group_tensors(preds)
# If no ids provided, assign default sequential ids based on the prediction order.
if ids[0] is None:
ids = list(range(len(preds)))
self._on_batch_or_epoch_end(preds, ids)

def _on_batch_or_epoch_end(self, preds, ids):
"""
Process each prediction at the end of either a batch or epoch and save in the defined format.
Args:
preds: Predicted tensors.
ids: Corresponding identifiers for the predictions.
"""
# Sanity check to ensure matched lengths for predictions and ids.
assert len(ids) == len(preds)
# Iterate over the predictions and save them.
for id, pred in zip(ids, preds):
# Convert predictions into a structured format suitable for writing.
parsed_pred = parse_data(pred)
@@ -148,7 +118,7 @@ def _on_batch_or_epoch_end(self, preds, ids):
# If multiple outputs, parsed_pred will contain multiple keys. For a single output, key will be None.
for multi_pred_id, tensor in parsed_pred.items():
# Save the prediction as per the designated format.
self.write(tensor.detach().cpu(), id, multi_pred_id, self.parsed_format[multi_pred_id])
self.write(tensor, id, multi_pred_id, format=self.parsed_format[multi_pred_id])

def add_writer(self, format: str, writer_function: Callable) -> None:
"""
@@ -181,31 +151,3 @@ def get_writer(self, format: str) -> Callable:
if format not in self._writers:
raise ValueError(f"Writer for format {format} not registered.")
return self._writers[format]


def parse_format(format: str, parsed_preds: Dict[str, Any]) -> Dict[str, str]:
"""
Parse the given format and align it with the structure of the predictions.
If the format is a single string, all predictions will be saved in this format. If the format has a structure
(like a dictionary), it needs to match the structure of the predictions.
Args:
format (str): The storage format for the predictions, either as a string or a structured format.
parsed_preds (Dict[str, Any]): Dictionary of parsed prediction data.
Returns:
Dict[str, str]: Dictionary of parsed format data corresponding to the prediction structure.
Raises:
ValueError: If the structure of the format does not align with the prediction structure.
"""
if isinstance(format, str):
# Assign the single format to all prediction keys.
parsed_format = {key: format for key in parsed_preds}
else:
# Ensure the structured format corresponds with the predictions' structure.
parsed_format = parse_data(format)
if not set(parsed_format) == set(parsed_preds):
raise ValueError("`format` structure does not match the prediction's structure.")
return parsed_format
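As a rough illustration of the new per-step flow in `on_predict_batch_end`, the sketch below mimics how MONAI's `decollate_batch` splits a batched prediction into per-sample tensors and how sequential IDs are generated when none are provided; the batch contents are made up for the example.

```python
# Rough illustration (made-up batch) of the per-step writing flow: decollate the
# batched predictions, then pair them with provided or generated IDs.
import torch
from monai.data.utils import decollate_batch

outputs = {"pred": torch.rand(4, 1, 8, 8), "id": None}  # one prediction batch, no IDs
last_index = 0

# A (4, 1, 8, 8) batch becomes a list of four (1, 8, 8) tensors.
preds = decollate_batch(outputs["pred"], detach=True, pad=False)
if outputs["id"] is not None:
    ids = decollate_batch(outputs["id"], detach=True, pad=False)
else:
    # IDs default to the running index of each prediction.
    ids = list(range(last_index, last_index + len(preds)))
    last_index += len(preds)

for id, pred in zip(ids, preds):
    # In the callback, each (id, pred) pair is parsed and handed to `self.write(...)`.
    print(id, pred.shape)
```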
24 changes: 12 additions & 12 deletions lighter/callbacks/writer/file.py
@@ -23,12 +23,11 @@ class LighterFileWriter(LighterBaseWriter):
Args:
directory (Union[str, Path]): The directory where the files should be written.
format (str): The format in which the files should be saved.
interval (str): Interval for writing, e.g., "epoch", "batch".
additional_writers (Optional[Dict[str, Callable]]): Additional custom writer functions.
"""

def __init__(
self, directory: Union[str, Path], format: str, interval: str, additional_writers: Optional[Dict[str, Callable]] = None
self, directory: Union[str, Path], format: str, additional_writers: Optional[Dict[str, Callable]] = None
) -> None:
# Predefined writers for different formats.
self._writers = {
@@ -40,7 +39,7 @@ def __init__(
"sitk_nifti": write_sitk_nifti,
}
# Initialize the base class.
super().__init__(directory, format, interval, additional_writers)
super().__init__(directory, format, additional_writers)

def write(self, tensor: torch.Tensor, id: Union[int, str], multi_pred_id: Optional[Union[int, str]], format: str) -> None:
"""
@@ -55,12 +54,8 @@ def write(self, tensor: torch.Tensor, id: Union[int, str], multi_pred_id: Option
multi_pred_id (Optional[Union[int, str]]): The secondary identifier, used if there are multiple predictions.
format (str): Format in which tensor should be written.
"""
# Determine the path for the file based on prediction count.
if multi_pred_id is not None:
path = self.directory / str(id) / str(multi_pred_id)
else:
path = self.directory / str(id)
# Ensure the directory exists.
# Determine the path for the file based on prediction count. The suffix must be added by the writer function.
path = self.directory / str(id) if multi_pred_id is None else self.directory / str(id) / str(multi_pred_id)
path.parent.mkdir(exist_ok=True, parents=True)
# Fetch the appropriate writer function for the format.
writer = self.get_writer(format)
@@ -73,10 +68,13 @@ def write_tensor(path, tensor):


def write_image(path, tensor):
torchvision.io.write_png(preprocess_image(tensor), path.with_suffix(".png"))
path = path.with_suffix(".png")
tensor = preprocess_image(tensor, add_batch_dim=True)
torchvision.io.write_png(tensor, path)


def write_video(path, tensor):
path = path.with_suffix(".mp4")
# Video tensor must be divisible by 2. Pad the height and width.
tensor = DivisiblePad(k=(0, 2, 2), mode="minimum")(tensor)
# Video tensor must be THWC. Permute CTHW -> THWC.
@@ -86,10 +84,12 @@ def write_video(path, tensor):
tensor = tensor.repeat(1, 1, 1, 3)
# Video tensor must be in the range [0, 1]. Scale to [0, 255].
tensor = (tensor * 255).to(torch.uint8)
torchvision.io.write_video(str(path.with_suffix(".mp4")), tensor, fps=24)
torchvision.io.write_video(str(path), tensor, fps=24)


def _write_sitk_image(path: str, tensor: torch.Tensor, suffix) -> None:
path = path.with_suffix(suffix)

if "sitk" not in OPTIONAL_IMPORTS:
OPTIONAL_IMPORTS["sitk"], sitk_available = optional_import("SimpleITK")
if not sitk_available:
@@ -98,7 +98,7 @@ def _write_sitk_image(path: str, tensor: torch.Tensor, suffix) -> None:
# Remove the channel dimension if it's equal to 1.
tensor = tensor.squeeze(0) if (tensor.dim() == 4 and tensor.shape[0] == 1) else tensor
sitk_image = OPTIONAL_IMPORTS["sitk"].GetImageFromArray(tensor.cpu().numpy())
OPTIONAL_IMPORTS["sitk"].WriteImage(sitk_image, str(path.with_suffix(".nrrd")), True)
OPTIONAL_IMPORTS["sitk"].WriteImage(sitk_image, str(path), useCompression=True)


def write_sitk_nrrd(path, tensor):
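The file writer now passes the output path without an extension and leaves the suffix to the writer function. Below is a hedged sketch of a custom writer that follows that convention and is registered through `additional_writers`; the `"numpy"` format name and the NumPy-based save are illustrative, not part of the library.

```python
# Hedged sketch: a custom writer that appends its own file suffix, registered as an
# additional writer. The "numpy" format name and save logic are illustrative only.
import numpy as np
import torch
from lighter.callbacks import LighterFileWriter

def write_numpy(path, tensor: torch.Tensor) -> None:
    # The path arrives without an extension; the writer function adds its suffix.
    np.save(path.with_suffix(".npy"), tensor.cpu().numpy())

writer = LighterFileWriter(
    directory="predictions",
    format="numpy",
    additional_writers={"numpy": write_numpy},
)
```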
13 changes: 6 additions & 7 deletions lighter/callbacks/writer/table.py
@@ -33,12 +33,14 @@ def __init__(
}

# Initialize the base class.
super().__init__(directory, format, "epoch", additional_writers)
super().__init__(directory, format, additional_writers)

# Create a dictionary to hold CSV records for each ID.
# Create a dictionary to hold CSV records for each ID. These are populated at each batch end
# by `self.on_predict_batch_end` defined in the base class using the `write` method below.
# Finally, the records are dumped to a CSV file at the end of the epoch by `self.on_predict_epoch_end`.
self.csv_records = {}

def write(self, tensor: Any, format: str, id: Union[int, str], multi_pred_id: Optional[Union[int, str]]) -> None:
def write(self, tensor: Any, id: Union[int, str], multi_pred_id: Optional[Union[int, str]], format: str) -> None:
"""
Write the tensor as a table record in the given format.
@@ -67,7 +69,7 @@ def write(self, tensor: Any, format: str, id: Union[int, str], multi_pred_id: Op
else:
self.csv_records[id][column] = record

def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem, outputs: List[Any]) -> None:
def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> None:
"""
Callback method triggered at the end of the prediction epoch to dump the CSV table.
@@ -76,9 +78,6 @@ def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem, outpu
pl_module (LighterSystem): Lighter system instance.
outputs (List[Any]): List of predictions.
"""
# Call the parent class's method to handle additional end-of-epoch logic
super().on_predict_epoch_end(trainer, pl_module, outputs)

# Set the path where the CSV will be saved
csv_path = self.directory / "predictions.csv"

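As a loose illustration of the comment added in `__init__`, the snippet below mimics how per-ID records could accumulate during prediction and be flattened into `predictions.csv` at the end of the epoch; the record values, the column name, and the use of pandas for the dump are assumptions rather than details confirmed by this diff.

```python
# Loose illustration (assumed values, column naming, and pandas usage) of how the
# table writer's csv_records accumulate per ID and are dumped at epoch end.
import pandas as pd

csv_records = {}

def add_record(value, id, multi_pred_id=None):
    # One row per ID; multiple predictions for an input would occupy separate columns.
    column = "pred" if multi_pred_id is None else str(multi_pred_id)
    csv_records.setdefault(id, {})[column] = value

add_record(0.91, id=0)
add_record(0.42, id=1)

# At the end of the prediction epoch, the accumulated records become a CSV table.
pd.DataFrame.from_dict(csv_records, orient="index").to_csv("predictions.csv")
```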
5 changes: 2 additions & 3 deletions projects/cifar10/experiments/monai_bundle_prototype.yaml
@@ -20,9 +20,8 @@ trainer:
max_samples: 10

- _target_: lighter.callbacks.LighterFileWriter
write_dir: "$@project + '/predictions' "
write_format: "tensor"
write_interval: "step" # "epoch"
directory: "$@project + '/predictions' "
format: "tensor"

system:
_target_: lighter.LighterSystem
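For reference, a rough Python equivalent of the updated YAML callback entry, reflecting the renamed arguments (`write_dir` to `directory`, `write_format` to `format`) and the removed `write_interval`; the resolved project path is a placeholder.

```python
# Rough Python equivalent of the updated YAML entry; the directory value stands in
# for the resolved "$@project + '/predictions'" expression.
from lighter.callbacks import LighterFileWriter

file_writer = LighterFileWriter(directory="projects/cifar10/predictions", format="tensor")
# The instantiated callback is then passed to the Trainer's callbacks list, as in the config.
```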
