diff --git a/lighter/callbacks/__init__.py b/lighter/callbacks/__init__.py index 78d3fcf..360ba88 100644 --- a/lighter/callbacks/__init__.py +++ b/lighter/callbacks/__init__.py @@ -1,3 +1,4 @@ +from .freezer import LighterFreezer from .logger import LighterLogger from .writer.file import LighterFileWriter from .writer.table import LighterTableWriter diff --git a/lighter/callbacks/freezer.py b/lighter/callbacks/freezer.py index e4b54d5..b2aba88 100644 --- a/lighter/callbacks/freezer.py +++ b/lighter/callbacks/freezer.py @@ -69,7 +69,7 @@ def on_test_batch_start( self._on_batch_start(trainer, pl_module) def on_predict_batch_start( - self, trainer: Trainer, pl_module: LighterSystem, batch: Any, batch_idx: int, dataloader_idx: int + self, trainer: Trainer, pl_module: LighterSystem, batch: Any, batch_idx: int, dataloader_idx: int = 0 ) -> None: self._on_batch_start(trainer, pl_module) @@ -122,20 +122,23 @@ def _set_model_requires_grad(self, model: Union[Module, LighterSystem], requires # Leave the excluded-from-freezing parameters trainable. if self.except_names and name in self.except_names: param.requires_grad = True - elif self.except_name_starts_with and any(name.startswith(prefix) for prefix in self.except_name_starts_with): + continue + if self.except_name_starts_with and any(name.startswith(prefix) for prefix in self.except_name_starts_with): param.requires_grad = True + continue # Freeze/unfreeze the specified parameters, based on the `requires_grad` argument. - elif self.names and name in self.names: + if self.names and name in self.names: param.requires_grad = requires_grad frozen_layers.append(name) - elif self.name_starts_with and any(name.startswith(prefix) for prefix in self.name_starts_with): + continue + if self.name_starts_with and any(name.startswith(prefix) for prefix in self.name_starts_with): param.requires_grad = requires_grad frozen_layers.append(name) + continue # Otherwise, leave the parameter trainable. - else: - param.requires_grad = True + param.requires_grad = True self._frozen_state = not requires_grad # Log only when freezing the parameters. diff --git a/lighter/callbacks/logger.py b/lighter/callbacks/logger.py index 6c57ed2..1c34c55 100644 --- a/lighter/callbacks/logger.py +++ b/lighter/callbacks/logger.py @@ -1,18 +1,16 @@ from typing import Any, Dict, Union import itertools -import sys from datetime import datetime from pathlib import Path import torch from loguru import logger -from monai.utils.module import optional_import from pytorch_lightning import Callback, Trainer from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor from lighter import LighterSystem -from lighter.callbacks.utils import get_lighter_mode, is_data_type_supported, parse_data, preprocess_image +from lighter.callbacks.utils import get_lighter_mode, preprocess_image from lighter.utils.dynamic_imports import OPTIONAL_IMPORTS @@ -62,8 +60,7 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None: stage (str): stage of the training process. Passed automatically by PyTorch Lightning. """ if trainer.logger is not None: - logger.error("When using LighterLogger, set Trainer(logger=None).") - sys.exit() + raise ValueError("When using LighterLogger, set Trainer(logger=None).") if not trainer.is_global_zero: return @@ -76,8 +73,6 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None: # Tensorboard initialization. if self.tensorboard: - # Tensorboard is a part of PyTorch, no need to check if it is not available. 
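For illustration, a minimal sketch of configuring the `LighterFreezer` matching rules shown in the freezer hunk above, assuming the constructor arguments mirror the attributes used in `_set_model_requires_grad` (the layer-name prefixes are hypothetical):

from pytorch_lightning import Trainer
from lighter.callbacks import LighterFreezer

# Freeze every parameter whose name starts with "backbone",
# but leave the (hypothetical) "backbone.head" trainable.
freezer = LighterFreezer(
    name_starts_with=["backbone"],
    except_name_starts_with=["backbone.head"],
)
trainer = Trainer(callbacks=[freezer])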
- OPTIONAL_IMPORTS["tensorboard"], _ = optional_import("torch.utils.tensorboard") tensorboard_dir = self.log_dir / "tensorboard" tensorboard_dir.mkdir() self.tensorboard = OPTIONAL_IMPORTS["tensorboard"].SummaryWriter(log_dir=tensorboard_dir) @@ -86,10 +81,6 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None: # Wandb initialization. if self.wandb: - OPTIONAL_IMPORTS["wandb"], wandb_available = optional_import("wandb") - if not wandb_available: - logger.error("Weights & Biases not installed. To install it, run `pip install wandb`. Exiting.") - sys.exit() wandb_dir = self.log_dir / "wandb" wandb_dir.mkdir() self.wandb = OPTIONAL_IMPORTS["wandb"].init(project=self.project, dir=wandb_dir, config=self.config) @@ -165,6 +156,7 @@ def _log_image(self, name: str, image: torch.Tensor, global_step: int) -> None: image (torch.Tensor): image to be logged. global_step (int): current global step. """ + image = image.detach().cpu() if self.tensorboard: self.tensorboard.add_image(name, image, global_step=global_step) if self.wandb: @@ -179,7 +171,6 @@ def _log_histogram(self, name: str, tensor: torch.Tensor, global_step: int) -> N global_step (int): current global step. """ tensor = tensor.detach().cpu() - if self.tensorboard: self.tensorboard.add_histogram(name, tensor, global_step=global_step) if self.wandb: @@ -193,49 +184,44 @@ def _on_batch_end(self, outputs: Dict, trainer: Trainer) -> None: outputs (Dict): output dict from the model. trainer (Trainer): Trainer, passed automatically by PyTorch Lightning. """ - if not trainer.sanity_checking: - mode = get_lighter_mode(trainer.state.stage) - # Accumulate the loss. - if mode in ["train", "val"]: - self.loss[mode] += outputs["loss"].item() - # Logging frequency. Log only on rank 0. - if trainer.is_global_zero and self.global_step_counter[mode] % trainer.log_every_n_steps == 0: - # Get global step. - global_step = self._get_global_step(trainer) - - # Log loss. - if outputs["loss"] is not None: - self._log_scalar(f"{mode}/loss/step", outputs["loss"], global_step) - - # Log metrics. - if outputs["metrics"] is not None: - for name, metric in outputs["metrics"].items(): - self._log_scalar(f"{mode}/metrics/{name}/step", metric, global_step) - - # Log input, target, and pred. - for name in ["input", "target", "pred"]: - if self.log_types[name] is None: - continue - # Ensure data is of a valid type. - if not is_data_type_supported(outputs[name]): - raise ValueError( - f"`{name}` has to be a Tensor, List[Tensor], Tuple[Tensor], Dict[str, Tensor], " - f"Dict[str, List[Tensor]], or Dict[str, Tuple[Tensor]]. `{type(outputs[name])}` is not supported." - ) - for identifier, item in parse_data(outputs[name]).items(): - item_name = f"{mode}/data/{name}" if identifier is None else f"{mode}/data/{name}_{identifier}" - self._log_by_type(item_name, item, self.log_types[name], global_step) - - # Log learning rate stats. Logs at step if a scheduler's interval is step-based. - if mode == "train": - lr_stats = self.lr_monitor.get_stats(trainer, "step") - for name, value in lr_stats.items(): - self._log_scalar(f"{mode}/optimizer/{name}/step", value, global_step) - - # Increment the step counters. - self.global_step_counter[mode] += 1 - if mode in ["train", "val"]: - self.epoch_step_counter[mode] += 1 + if trainer.sanity_checking: + return + + mode = get_lighter_mode(trainer.state.stage) + + # Accumulate the loss. 
+ if mode in ["train", "val"]: + self.loss[mode] += outputs["loss"].item() + + # Log only on rank 0 and according to the `log_every_n_steps` parameter. Otherwise, only increment the step counters. + if not trainer.is_global_zero or self.global_step_counter[mode] % trainer.log_every_n_steps != 0: + self._increment_step_counters(mode) + return + + global_step = self._get_global_step(trainer) + + # Loss. + if outputs["loss"] is not None: + self._log_scalar(f"{mode}/loss/step", outputs["loss"], global_step) + + # Metrics. + if outputs["metrics"] is not None: + for name, metric in outputs["metrics"].items(): + self._log_scalar(f"{mode}/metrics/{name}/step", metric, global_step) + + # Input, target, and pred. + for name in ["input", "target", "pred"]: + if self.log_types[name] is not None: + self._log_by_type(f"{mode}/data/{name}", outputs[name], self.log_types[name], global_step) + + # LR info. Logs at step if a scheduler's interval is step-based. + if mode == "train": + lr_stats = self.lr_monitor.get_stats(trainer, "step") + for name, value in lr_stats.items(): + self._log_scalar(f"{mode}/optimizer/{name}/step", value, global_step) + + # Increment the step counters. + self._increment_step_counters(mode) def _on_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> None: """Performs logging at the end of an epoch. Logs the epoch number, the loss, and the metrics. @@ -249,46 +235,44 @@ def _on_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> None: mode = get_lighter_mode(trainer.state.stage) loss, metrics = None, None - # Loss + # Get the accumulated loss over the epoch and processes. if mode in ["train", "val"]: - # Get the accumulated loss. loss = self.loss[mode] - # Reduce the loss and average it on each rank. loss = trainer.strategy.reduce(loss, reduce_op="mean") - # Divide the accumulated loss by the number of steps in the epoch. loss /= self.epoch_step_counter[mode] - # Metrics # Get the torchmetrics. - metric_collection = pl_module.metrics[mode] + # TODO: Remove the "_" prefix when fixed https://github.com/pytorch/pytorch/issues/71203 + metric_collection = pl_module.metrics["_" + mode] if metric_collection is not None: # Compute the epoch metrics. metrics = metric_collection.compute() # Reset the metrics for the next epoch. metric_collection.reset() - # Log. Only on rank 0. - if trainer.is_global_zero: - # Get global step. - global_step = self._get_global_step(trainer) + # Log only on rank 0. + if not trainer.is_global_zero: + return - # Log epoch number. - self._log_scalar("epoch", trainer.current_epoch, global_step) + global_step = self._get_global_step(trainer) - # Log loss. - if loss is not None: - self._log_scalar(f"{mode}/loss/epoch", loss, global_step) + # Epoch number. + self._log_scalar("epoch", trainer.current_epoch, global_step) - # Log metrics. - if metrics is not None: - for name, metric in metrics.items(): - self._log_scalar(f"{mode}/metrics/{name}/epoch", metric, global_step) + # Loss. + if loss is not None: + self._log_scalar(f"{mode}/loss/epoch", loss, global_step) - # Log learning rate stats. Logs at epoch if a scheduler's interval is epoch-based, or if no scheduler is used. - if mode == "train": - lr_stats = self.lr_monitor.get_stats(trainer, "epoch") - for name, value in lr_stats.items(): - self._log_scalar(f"{mode}/optimizer/{name}/epoch", value, global_step) + # Metrics. + if metrics is not None: + for name, metric in metrics.items(): + self._log_scalar(f"{mode}/metrics/{name}/epoch", metric, global_step) + + # LR info. 
Logged at epoch if the scheduler's interval is epoch-based, or if no scheduler is used. + if mode == "train": + lr_stats = self.lr_monitor.get_stats(trainer, "epoch") + for name, value in lr_stats.items(): + self._log_scalar(f"{mode}/optimizer/{name}/epoch", value, global_step) def _get_global_step(self, trainer: Trainer) -> int: """Return the global step for the current mode. Note that when Trainer @@ -309,6 +293,16 @@ def _get_global_step(self, trainer: Trainer) -> int: return self.global_step_counter["train"] return self.global_step_counter[mode] + def _increment_step_counters(self, mode: str) -> None: + """Increment the global step and epoch step counters for the specified mode. + + Args: + mode (str): mode to increment the global step counter for. + """ + self.global_step_counter[mode] += 1 + if mode in ["train", "val"]: + self.epoch_step_counter[mode] += 1 + def on_train_epoch_start(self, trainer: Trainer, pl_module: LighterSystem) -> None: # Reset the loss and the epoch step counter for the next epoch. self.loss["train"] = 0 diff --git a/lighter/callbacks/utils.py b/lighter/callbacks/utils.py index 757bce6..bdadb93 100644 --- a/lighter/callbacks/utils.py +++ b/lighter/callbacks/utils.py @@ -1,5 +1,3 @@ -from typing import Any, Dict, List, Optional, Tuple, Union - import torch import torchvision @@ -17,117 +15,6 @@ def get_lighter_mode(lightning_stage: str) -> str: return lightning_to_lighter[lightning_stage] -def is_data_type_supported(data: Union[Any, List[Any], Dict[str, Union[Any, List[Any], Tuple[Any]]]]) -> bool: - """ - Check the input data recursively for its type. Valid data types are: - - torch.Tensor - - List[torch.Tensor] - - Tuple[torch.Tensor] - - Dict[str, torch.Tensor] - - Dict[str, List[torch.Tensor]] - - Dict[str, Tuple[torch.Tensor]] - - Nested combinations of the above - - Args: - data (Union[Any, List[Any], Dict[str, Union[Any, List[Any], Tuple[Any]]]]): Input data to check. - - Returns: - bool: True if the data type is supported, False otherwise. - """ - if isinstance(data, dict): - is_valid = all(is_data_type_supported(elem) for elem in data.values()) - elif isinstance(data, (list, tuple)): - is_valid = all(is_data_type_supported(elem) for elem in data) - elif isinstance(data, torch.Tensor): - is_valid = True - else: - is_valid = False - return is_valid - - -def parse_data( - data: Union[Any, List[Any], Dict[str, Union[Any, List[Any], Tuple[Any]]]], prefix: Optional[str] = None -) -> Dict[Optional[str], Any]: - """ - Parse the input data recursively, handling nested dictionaries, lists, and tuples. - - This function will recursively parse the input data, unpacking nested dictionaries, lists, and tuples. The result - will be a dictionary where each key is a unique identifier reflecting the data's original structure (dict keys - or list/tuple positions) and each value is a non-container data type from the input data. - - Args: - data (Union[Any, List[Any], Dict[str, Union[Any, List[Any], Tuple[Any]]]]): Input data to parse. - prefix (Optional[str]): Current prefix for keys in the result dictionary. Defaults to None. - - Returns: - Dict[Optional[str], Any]: A dictionary where key is either a string identifier or `None`, and value is the parsed output. 
- - Example: - input_data = { - "a": [1, 2], - "b": {"c": (3, 4), "d": 5} - } - output_data = parse_data(input_data) - # Output: - # { - # 'a_0': 1, - # 'a_1': 2, - # 'b_c_0': 3, - # 'b_c_1': 4, - # 'b_d': 5 - # } - """ - result = {} - if isinstance(data, dict): - for key, value in data.items(): - # Recursively parse the value with an updated prefix - sub_result = parse_data(value, prefix=f"{prefix}_{key}" if prefix else key) - result.update(sub_result) - elif isinstance(data, (list, tuple)): - for idx, element in enumerate(data): - # Recursively parse the element with an updated prefix - sub_result = parse_data(element, prefix=f"{prefix}_{idx}" if prefix else str(idx)) - result.update(sub_result) - else: - # Assign the value to the result dictionary using the current prefix as its key - result[prefix] = data - return result - - -def gather_tensors( - inputs: Union[List[Union[torch.Tensor, List, Tuple, Dict]], Tuple[Union[torch.Tensor, List, Tuple, Dict]]] -) -> Union[List, Dict]: - """Recursively gather tensors. Tensors can be standalone or inside of other data structures (list/tuple/dict). - An input list of tensors is returned as-is. Given an input list of data structures with tensors, this function - will gather all tensors into a list and save it under a single data structure. Assumes that all elements of - the input list have the same type and structure. - - Args: - inputs (List[Union[torch.Tensor, List, Tuple, Dict]], Tuple[Union[torch.Tensor, List, Tuple, Dict]]): - They can be: - - List/Tuples of Dictionaries, each containing tensors to be gathered by their key. - - List/Tuples of Lists/tuples, each containing tensors to be gathered by their position. - - List/Tuples of Tensors, returned as-is. - - Nested versions of the above. - The input data structure must be the same for all elements of the list. They can be arbitrarily nested. - - Returns: - Union[List, Dict]: The gathered tensors. - """ - # List of dicts. - if isinstance(inputs[0], dict): - keys = inputs[0].keys() - return {key: gather_tensors([input[key] for input in inputs]) for key in keys} - # List of lists or tuples. - elif isinstance(inputs[0], (list, tuple)): - return [gather_tensors([input[idx] for input in inputs]) for idx in range(len(inputs[0]))] - # List of tensors. - elif isinstance(inputs[0], torch.Tensor): - return inputs - else: - raise TypeError(f"Type `{type(inputs[0])}` not supported.") - - def preprocess_image(image: torch.Tensor) -> torch.Tensor: """Preprocess the image before logging it. If it is a batch of multiple images, it will create a grid image of them. In case of 3D, a single image is displayed @@ -138,7 +25,6 @@ def preprocess_image(image: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: image ready for logging. """ - image = image.detach().cpu() # If 3D (BCDHW), concat the images vertically and horizontally. 
if image.ndim == 5: shape = image.shape diff --git a/lighter/callbacks/writer/base.py b/lighter/callbacks/writer/base.py index ab0ee79..410fb87 100644 --- a/lighter/callbacks/writer/base.py +++ b/lighter/callbacks/writer/base.py @@ -1,146 +1,115 @@ -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, Union -import itertools -import sys from abc import ABC, abstractmethod from datetime import datetime from pathlib import Path import torch -from loguru import logger from pytorch_lightning import Callback, Trainer from lighter import LighterSystem -from lighter.callbacks.utils import gather_tensors, parse_data class LighterBaseWriter(ABC, Callback): - """Base class for a Writer. Override `self.write()` to define how a prediction should be saved. - `LighterBaseWriter` sets up the write directory, and defines `on_predict_batch_end` and - `on_predict_epoch_end`. `write_interval` specifies which of the two should the writer call. + """ + Base class for defining custom Writer. It provides the structure to save predictions in various formats. + + Subclasses should implement: + 1) `self.writers` attribute to specify the supported formats and their corresponding writer functions. + 2) `self.write()` method to specify the saving strategy for a prediction. Args: - write_dir (str): the Writer will create a directory inside of `write_dir` with date - and time as its name and store the predictions there. - write_format (Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]]): - type in which the predictions will be stored. Passed automatically to the `write()` - abstract method and can be used to support writing different types. Should the Writer - support only one type, this argument can be removed from the overriden `__init__()`'s - arguments and set `self.write_format = None`. - write_interval (str, optional): whether to write on each step or at the end of the prediction epoch. - Defaults to "step". + directory (str): Base directory for saving. A new sub-directory with current date and time will be created inside. + writer (Union[str, Callable]): Name of the writer function registered in `self.writers`, or a custom writer function. """ - def __init__( - self, - write_dir: str, - write_format: Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]], - write_interval: str = "step", - ) -> None: - self.write_dir = Path(write_dir) / datetime.now().strftime("%Y%m%d_%H%M%S") - self.write_format = write_format - self.write_interval = write_interval + def __init__(self, directory: str, writer: Union[str, Callable]) -> None: + """ + Initialize the LighterBaseWriter. - self.parsed_write_format = None + Args: + directory (str): Base directory for saving. A new sub-directory with current date and time will be created inside. + writer (Union[str, Callable]): Name of the writer function registered in `self.writers`, or a custom writer function. + """ + # Create a unique directory using the current date and time + self.directory = Path(directory) / datetime.now().strftime("%Y%m%d_%H%M%S") + + # Check if the writer is a string and if it exists in the writers dictionary + if isinstance(writer, str): + if writer not in self.writers: + raise ValueError(f"Writer for format {writer} does not exist. Available writers: {self.writers.keys()}.") + self.writer = self.writers[writer] + else: + # If the writer is not a string, it is assumed to be a callable function + self.writer = writer + + # Prediction counter. Used when IDs are not provided. 
Initialized in `self.setup()` based on the DDP rank. + self._pred_counter = None + + @property + @abstractmethod + def writers(self) -> Dict[str, Callable]: + """ + Property to define the default writer functions. + """ @abstractmethod - def write( - self, - idx: int, - identifier: Optional[str], - tensor: torch.Tensor, - write_format: Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]], - ): - """This method must be overridden to specify how a tensor should be saved. If the Writer - supports multiple types of saving, handle the `write_format` argument with an if-else statement. - - If the Writer only supports one type, remove `write_format` from the overridden - `__init__()` method and set `self.write_format=None`. - - The `idx` and `identifier` arguments can be used to specify the name of the file - or the row and column of a table for the prediction. - - Parameters: - idx (int): The index of the prediction. - identifier (Optional[str]): The identifier of the prediction. It will be `None` if there's - only one prediction, an index if the prediction is a list of predictions, a key if it's - a dict of predictions, and a key_index if it's a dict of list of predictions. - tensor (torch.Tensor): The predicted tensor. - write_format (Optional[Union[str, List[str], Dict[str, str], Dict[str, List[str]]]]): - Specifies how to write the predictions. If it's a single string value, the predictions - will be saved under that type regardless of whether they are single- or multi-output - predictions. To write different outputs in the multi-output predictions using different - methods, use the appropriate format for `write_format`. + """ + Method to define how a tensor should be saved. The input tensor will be a single tensor without + the batch dimension. If the batch dimension is needed, apply `tensor.unsqueeze(0)` before saving, + either in this method or in the particular writer function. + + For each supported format, there should be a corresponding writer function registered in `self.writers`. + The writer function for the chosen format is retrieved from `self.writers` at initialization. + + Args: + tensor (torch.Tensor): Tensor to be saved. It will be a single tensor without the batch dimension. + id (int): Identifier for the tensor, can be used for naming or indexing. + """ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None: + """ + Callback function to set up necessary prerequisites: prediction count and prediction directory. + When executing in a distributed environment, it ensures that: + 1. Each distributed node initializes a prediction count based on its rank. + 2. All distributed nodes write predictions to the same directory. + 3. The directory is accessible to all nodes, i.e., all nodes share the same storage. + """ if stage != "predict": return - if self.write_interval not in ["step", "epoch"]: - logger.error("`write_interval` must be either 'step' or 'epoch'.") - sys.exit() + # Initialize the prediction count with the rank of the current process. + self._pred_counter = torch.distributed.get_rank() if trainer.world_size > 1 else 0 - # Broadcast the `write_dir` so that all ranks write their predictions there. - self.write_dir = trainer.strategy.broadcast(self.write_dir) - # Let rank 0 create the `write_dir`. 
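A minimal subclassing sketch of the contract documented above: a `writers` property mapping format names to saver functions, plus a `write()` that names files by ID. The "numpy" format and its saver are hypothetical, not part of this diff:

from typing import Callable, Dict

import numpy as np
import torch

from lighter.callbacks.writer.base import LighterBaseWriter


class NumpyWriter(LighterBaseWriter):
    @property
    def writers(self) -> Dict[str, Callable]:
        # Each writer function receives the suffix-less path and a single prediction tensor.
        return {"numpy": lambda path, tensor: np.save(path.with_suffix(".npy"), tensor.cpu().numpy())}

    def write(self, tensor: torch.Tensor, id) -> None:
        # One file per prediction, named by its ID; the writer function appends the suffix.
        path = self.directory / str(id)
        path.parent.mkdir(parents=True, exist_ok=True)
        self.writer(path, tensor)

# Usage: NumpyWriter("predictions/", writer="numpy")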
+ # Ensure all distributed nodes write to the same directory + self.directory = trainer.strategy.broadcast(self.directory, src=0) if trainer.is_global_zero: - self.write_dir.mkdir(parents=True) - # If `write_dir` does not exist, the ranks are not on the same storage. - if not self.write_dir.exists(): - logger.error( - f"Rank {trainer.global_rank} is not on the same storage as rank 0." - "Please run the prediction only on nodes that are on the same storage." + self.directory.mkdir(parents=True) + # Wait for rank 0 to create the directory + trainer.strategy.barrier() + + # Ensure all distributed nodes have access to the directory + if not self.directory.exists(): + raise RuntimeError( + f"Rank {trainer.global_rank} does not share storage with rank 0. Ensure nodes have common storage access." ) - sys.exit() def on_predict_batch_end( - self, trainer: Trainer, pl_module: LighterSystem, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int + self, trainer: Trainer, pl_module: LighterSystem, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0 ) -> None: - if self.write_interval != "step": - return - indices = trainer.predict_loop.epoch_loop.current_batch_indices - self._on_batch_or_epoch_end(outputs, indices) + """ + Callback method executed at the end of each prediction batch/step. + If the IDs are not provided, it generates global unique IDs based on the prediction count. + Finally, it writes the predictions using the specified writer. + """ - def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem, outputs: List[Any]) -> None: - if self.write_interval != "epoch": - return - # Only one epoch when predicting, index the lists of outputs and batch indices accordingly. - indices = trainer.predict_loop.epoch_batch_indices[0] - outputs = outputs[0] - # Flatten so that each output sample corresponds to its index. - indices = list(itertools.chain(*indices)) - # Remove the batch dimension since every output is a single sample. - outputs = [output.squeeze(0) for output in outputs] - # Gather the output tensors from all samples into a single structure rather than having one structures for each sample. - outputs = gather_tensors(outputs) - self._on_batch_or_epoch_end(outputs, indices) - - def _on_batch_or_epoch_end(self, outputs, indices): - """Iterate through each output and save it in the specified format. The outputs and indices are automatically - split individually by PyTorch Lightning.""" - # Sanity check. Should never happen. If it does, something is wrong with the Trainer. - assert len(indices) == len(outputs) - # `idx` is the index of the input sample, `output` is the output of the model for that sample. - for idx, output in zip(indices, outputs): - # Parse the outputs into a structure ready for writing. - parsed_output = parse_data(output) - # Parse `self.write_format`. If multi-value, check if its structure matches `parsed_output`'s structure. - if self.parsed_write_format is None: - self.parsed_write_format = self._parse_write_format(self.write_format, parsed_output) - # Iterate through each prediction for the `idx`-th input sample. - for identifier, tensor in parsed_output.items(): - # Save the prediction in the specified format. - self.write(idx, identifier, tensor.detach().cpu(), self.parsed_write_format[identifier]) - - def _parse_write_format(self, write_format, parsed_outputs: Dict[str, Any]): - # If `write_format` is a string (single value), all outputs will be saved in that specified format. 
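Worked in isolation, the ID scheme described in the `on_predict_batch_end` docstring above: each rank seeds a counter with its rank (see `setup()`) and strides by the world size, so IDs stay globally unique with no cross-rank communication. A standalone sketch with assumed sizes:

world_size, batch_size = 2, 3
for rank in range(world_size):
    counter = rank  # set from torch.distributed.get_rank() in `setup()`
    for _ in range(2):  # two prediction batches
        ids = list(range(counter, counter + batch_size * world_size, world_size))
        counter += batch_size * world_size
        print(f"rank {rank}: {ids}")
# rank 0 yields [0, 2, 4] then [6, 8, 10]; rank 1 yields [1, 3, 5] then [7, 9, 11].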
- if isinstance(write_format, str): - parsed_write_format = {key: write_format for key in parsed_outputs} - # Otherwise, `write_format` needs to match the structure of the outputs in order to assign each tensor its type. - else: - parsed_write_format = parse_data(write_format) - if not set(parsed_write_format) == set(parsed_outputs): - logger.error("`write_format` structure does not match the prediction's structure.") - sys.exit() - return parsed_write_format + # If the IDs are not provided, generate global unique IDs based on the prediction count. DDP supported. + if outputs["id"] is None: + batch_size = len(outputs["pred"]) + world_size = trainer.world_size + outputs["id"] = list(range(self._pred_counter, self._pred_counter + batch_size * world_size, world_size)) + self._pred_counter += batch_size * world_size + + for id, pred in zip(outputs["id"], outputs["pred"]): + self.write(tensor=pred, id=id) diff --git a/lighter/callbacks/writer/file.py b/lighter/callbacks/writer/file.py index 7536a0e..c2a0d57 100644 --- a/lighter/callbacks/writer/file.py +++ b/lighter/callbacks/writer/file.py @@ -1,10 +1,13 @@ -import sys +from typing import Callable, Dict, Union +from functools import partial +from pathlib import Path + +import monai import torch import torchvision -from loguru import logger +from monai.data import metatensor_to_itk_image from monai.transforms import DivisiblePad -from monai.utils.module import optional_import from lighter.callbacks.utils import preprocess_image from lighter.callbacks.writer.base import LighterBaseWriter @@ -12,61 +15,77 @@ class LighterFileWriter(LighterBaseWriter): - def write(self, idx, identifier, tensor, write_format): - filename = f"{write_format}" if identifier is None else f"{identifier}_{write_format}" - write_dir = self.write_dir / str(idx) - write_dir.mkdir() - - if write_format is None: - pass - - # Tensor - elif write_format == "tensor": - path = write_dir / f"{filename}.pt" - torch.save(tensor, path) - - # Image - elif write_format == "image": - path = write_dir / f"{filename}.png" - torchvision.io.write_png(preprocess_image(tensor), path) - - # Video - elif write_format == "video": - path = write_dir / f"{filename}.mp4" - # Video tensor must be divisible by 2. Pad the height and width. - tensor = DivisiblePad(k=(0, 2, 2), mode="minimum")(tensor) - # Video tensor must be THWC. Permute CTHW -> THWC. - tensor = tensor.permute(1, 2, 3, 0) - # Video tensor must have 3 channels (RGB). Repeat the channel dim to convert grayscale to RGB. - if tensor.shape[-1] == 1: - tensor = tensor.repeat(1, 1, 1, 3) - # Video tensor must be in the range [0, 1]. Scale to [0, 255]. - tensor = (tensor * 255).to(torch.uint8) - torchvision.io.write_video(str(path), tensor, fps=24) - - # Scalar - elif write_format == "scalar": - raise NotImplementedError - - # Audio - elif write_format == "audio": - raise NotImplementedError - - else: - logger.error(f"`write_format` '{write_format}' not supported.") - sys.exit() - - -def write_sitk_image(path: str, tensor: torch.Tensor) -> None: - """Write a SimpleITK image to disk. + """ + Writer for writing predictions to files. Supports multiple formats, and + additional custom formats can be added either through `additional_writers` + argument at initialization, or by calling `add_writer` method after initialization. Args: - path (str): path to write the image. - tensor (torch.Tensor): tensor to write. + directory (Union[str, Path]): The directory where the files should be written. 
+ writer (Union[str, Callable]): Name of the writer function registered in `self.writers`, or a custom writer function. + Available writers: "tensor", "image", "video", "itk_nrrd", "itk_seg_nrrd", "itk_nifti". """ + + def __init__(self, directory: Union[str, Path], writer: Union[str, Callable]) -> None: + super().__init__(directory, writer) + + @property + def writers(self) -> Dict[str, Callable]: + return { + "tensor": write_tensor, + "image": write_image, + "video": write_video, + "itk_nrrd": partial(write_itk_image, suffix=".nrrd"), + "itk_seg_nrrd": partial(write_itk_image, suffix=".seg.nrrd"), + "itk_nifti": partial(write_itk_image, suffix=".nii.gz"), + } + + def write(self, tensor: torch.Tensor, id: Union[int, str]) -> None: + """ + Write the tensor to a file under `self.directory`, named by `id`; the writer function appends the format's suffix. + + Args: + tensor (Tensor): The tensor to be written. + id (Union[int, str]): The identifier for naming. + """ + # Determine the file path based on the prediction's ID. The suffix must be added by the writer function. + path = self.directory / str(id) + path.parent.mkdir(exist_ok=True, parents=True) + # Write the tensor to the file. + self.writer(path, tensor) + + +def write_tensor(path, tensor): + torch.save(tensor, path.with_suffix(".pt")) + + +def write_image(path, tensor): + path = path.with_suffix(".png") + tensor = preprocess_image(tensor) + torchvision.io.write_png(tensor, str(path)) + + +def write_video(path, tensor): + path = path.with_suffix(".mp4") + # Video tensor must be divisible by 2. Pad the height and width. + tensor = DivisiblePad(k=(0, 2, 2), mode="minimum")(tensor) + # Video tensor must be THWC. Permute CTHW -> THWC. + tensor = tensor.permute(1, 2, 3, 0) + # Video tensor must have 3 channels (RGB). Repeat the channel dim to convert grayscale to RGB. + if tensor.shape[-1] == 1: + tensor = tensor.repeat(1, 1, 1, 3) + # Video tensor must be in the range [0, 1]. Scale to [0, 255]. 
+ tensor = (tensor * 255).to(torch.uint8) + torchvision.io.write_video(str(path), tensor, fps=24) + + +def write_itk_image(path: Path, tensor: torch.Tensor, suffix: str) -> None: + path = path.with_suffix(suffix) + + # TODO: Remove this code when fixed https://github.com/Project-MONAI/MONAI/issues/6985 + if tensor.meta["space"] == "RAS": + tensor.affine = monai.data.utils.orientation_ras_lps(tensor.affine) + + itk_image = metatensor_to_itk_image(tensor, channel_dim=0, dtype=tensor.dtype) + OPTIONAL_IMPORTS["itk"].imwrite(itk_image, str(path), True) diff --git a/lighter/callbacks/writer/table.py b/lighter/callbacks/writer/table.py index 5e490a6..1b7efcc 100644 --- a/lighter/callbacks/writer/table.py +++ b/lighter/callbacks/writer/table.py @@ -1,10 +1,10 @@ -from typing import Any, Dict, List, Union +from typing import Any, Callable, Dict, Union import itertools -import sys +from pathlib import Path import pandas as pd -from loguru import logger +import torch from pytorch_lightning import Trainer from lighter import LighterSystem @@ -12,46 +12,61 @@ class LighterTableWriter(LighterBaseWriter): - def __init__(self, write_dir: str, write_format: Union[str, List[str], Dict[str, str], Dict[str, List[str]]]) -> None: - super().__init__(write_dir, write_format, write_interval="epoch") + """ + Writer for saving predictions in a table format. + + Args: + directory (Union[str, Path]): The directory where the CSV will be saved. + writer (Union[str, Callable]): Name of the writer function registered in `self.writers`, or a custom writer function. + Available writers: "tensor". + """ + + def __init__(self, directory: Union[str, Path], writer: Union[str, Callable]) -> None: + super().__init__(directory, writer) self.csv_records = {} - def write(self, idx, identifier, tensor, write_format): - # Column name will be set to 'pred' if the identifier is None. - column = "pred" if identifier is None else identifier - - if write_format is None: - record = None - elif write_format == "tensor": - record = tensor.tolist() - elif write_format == "scalar": - raise NotImplementedError - else: - logger.error(f"`write_format` '{write_format}' not supported.") - sys.exit() - - if idx not in self.csv_records: - self.csv_records[idx] = {column: record} - else: - self.csv_records[idx][column] = record - - def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem, outputs: List[Any]) -> None: - super().on_predict_epoch_end(trainer, pl_module, outputs) - - csv_path = self.write_dir / "predictions.csv" - logger.info(f"Saving the predictions to {csv_path}") - - # Sort the dict of dicts by key and turn it into a list of dicts. - self.csv_records = [self.csv_records[key] for key in sorted(self.csv_records)] - # Gather the records from all ranks when in DDP. + @property + def writers(self) -> Dict[str, Callable]: + return { + "tensor": lambda tensor: tensor.tolist(), + } + + def write(self, tensor: Any, id: Union[int, str]) -> None: + """ + Write the tensor as a table record using the specified writer. + + Args: + tensor (Any): The tensor to be written. + id (Union[int, str]): The identifier used as the key for the record. + """ + column = "pred" + record = self.writer(tensor) + + self.csv_records.setdefault(id, {})[column] = record + + def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> None: + """ + Callback invoked at the end of the prediction epoch to save predictions to a CSV file. + + This method is responsible for organizing prediction records and saving them as a CSV file. 
+ If training was done in a distributed setting, it gathers predictions from all processes + and then saves them from the rank 0 process. + """ + csv_path = self.directory / "predictions.csv" + + # Sort the records by ID and convert the dictionary to a list + self.csv_records = [self.csv_records[id] for id in sorted(self.csv_records)] + + # If in distributed data parallel mode, gather records from all processes to rank 0. if trainer.world_size > 1: - # Since `all_gather` supports tensors only, mimic the behavior using `broadcast`. - ddp_csv_records = [self.csv_records] * trainer.world_size - for rank in range(trainer.world_size): - # Broadcast the records from the current rank and save it at its designated position. - ddp_csv_records[rank] = trainer.strategy.broadcast(ddp_csv_records[rank], src=rank) - # Combine the records from all ranks. List of lists of dicts -> list of dicts. - self.csv_records = list(itertools.chain(*ddp_csv_records)) - - # Create a dataframe and save it. - pd.DataFrame(self.csv_records).to_csv(csv_path) + # Create a list to hold the records from each process. Used on rank 0 only. + gather_csv_records = [None] * trainer.world_size if trainer.is_global_zero else None + # Each process sends its records to rank 0, which stores them in the `gather_csv_records`. + torch.distributed.gather_object(self.csv_records, gather_csv_records, dst=0) + # Concatenate the gathered records + if trainer.is_global_zero: + self.csv_records = list(itertools.chain(*gather_csv_records)) + + # Save the records to a CSV file + if trainer.is_global_zero: + pd.DataFrame(self.csv_records).to_csv(csv_path) diff --git a/lighter/system.py b/lighter/system.py index e5f6676..4373302 100644 --- a/lighter/system.py +++ b/lighter/system.py @@ -1,57 +1,51 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import sys from functools import partial import pytorch_lightning as pl import torch from loguru import logger -from torch.nn import Module +from torch.nn import Module, ModuleDict from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler from torch.utils.data import DataLoader, Dataset, Sampler from torchmetrics import Metric, MetricCollection from lighter.utils.collate import collate_replace_corrupted -from lighter.utils.misc import apply_fns, ensure_dict_schema, ensure_list, get_name, hasarg +from lighter.utils.misc import apply_fns, ensure_dict_schema, get_name, hasarg class LighterSystem(pl.LightningModule): """_summary_ Args: - model (Module): the model. - batch_size (int): batch size. - drop_last_batch (bool, optional): whether the last batch in the dataloader - should be dropped. Defaults to False. - num_workers (int, optional): number of dataloader workers. Defaults to 0. - pin_memory (bool, optional): whether to pin the dataloaders memory. Defaults to True. - optimizer (Optional[Union[Optimizer, List[Optimizer]]], optional): - a single or a list of optimizers. Defaults to None. - scheduler (Optional[Union[Callable, List[Callable]]], optional): - a single or a list of schedulers. Defaults to None. - criterion (Optional[Callable], optional): - criterion/loss function. Defaults to None. - datasets (Optional[Dict[str, Optional[Dataset]]], optional): - datasets for train, val, test, and predict. Supports Defaults to None. - samplers (Optional[Dict[str, Optional[Sampler]]], optional): - samplers for train, val, test, and predict. Defaults to None. 
- collate_fns (Optional[Dict[str, Optional[Callable]]], optional): - collate functions for train, val, test, and predict. Defaults to None. - metrics (Optional[Dict[str, Optional[Union[Metric, List[Metric]]]]], optional): - metrics for train, val, and test. Supports a single metric or a list of metrics, - implemented using `torchmetrics`. Defaults to None. - postprocessing (Optional[Dict[str, Optional[Callable]]], optional): + model (Module): The model. + batch_size (int): Batch size. + drop_last_batch (bool, optional): Whether the last batch in the dataloader should be dropped. Defaults to False. + num_workers (int, optional): Number of dataloader workers. Defaults to 0. + pin_memory (bool, optional): Whether to pin the dataloaders memory. Defaults to True. + optimizer (Optimizer, optional): Optimizers. Defaults to None. + scheduler (LRScheduler, optional): Learning rate scheduler. Defaults to None. + criterion (Callable, optional): Criterion/loss function. Defaults to None. + datasets (Dict[str, Dataset], optional): Datasets for train, val, test, and predict. Defaults to None. + samplers (Dict[str, Sampler], optional): Samplers for train, val, test, and predict. Defaults to None. + collate_fns (Dict[str, Union[Callable, List[Callable]]], optional): + Collate functions for train, val, test, and predict. Defaults to None. + metrics (Dict[str, Union[Metric, List[Metric], Dict[str, Metric]]], optional): + Metrics for train, val, and test. Supports a single metric or a list/dict of `torchmetrics` metrics. + Defaults to None. + postprocessing (Dict[str, Union[Callable, List[Callable]]], optional): Postprocessing functions for input, target, and pred, for three stages - criterion, metrics, and logging. The postprocessing is done before each stage - for example, criterion postprocessing will be done prior to loss calculation. Note that the postprocessing of a latter stage stacks on top of the previous one(s) - for example, the logging postprocessing will be done on the data that has been postprocessed for the criterion and metrics earlier. Defaults to None. - inferer (Optional[Callable], optional): the inferer must be a class with a `__call__` - method that accepts two arguments - the input to infer over, and the model itself. - Used in 'val', 'test', and 'predict' mode, but not in 'train'. Typically, an inferer - is a sliding window or a patch-based inferer that will infer over the smaller parts of - the input, combine them, and return a single output. The inferers provided by MONAI - cover most of such cases (https://docs.monai.io/en/stable/inferers.html). Defaults to None. + inferer (Callable, optional): The inferer must be a class with a `__call__` method that accepts two + arguments - the input to infer over, and the model itself. Used in 'val', 'test', and 'predict' + mode, but not in 'train'. Typically, an inferer is a sliding window or a patch-based inferer + that will infer over the smaller parts of the input, combine them, and return a single output. + The inferers provided by MONAI cover most of such cases (https://docs.monai.io/en/stable/inferers.html). + Defaults to None. 
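A minimal construction sketch for the argument schema documented above; the model, dataset, metric, and postprocessing entries are toy placeholders, not part of this diff:

import torch
from torch.utils.data import TensorDataset
from torchmetrics import Accuracy

from lighter import LighterSystem

model = torch.nn.Linear(16, 2)
system = LighterSystem(
    model=model,
    batch_size=8,
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    criterion=torch.nn.CrossEntropyLoss(),
    datasets={"train": TensorDataset(torch.randn(32, 16), torch.randint(0, 2, (32,)))},
    metrics={"train": Accuracy(task="multiclass", num_classes=2)},
    # Postprocessing is keyed by stage (criterion/metrics/logging), then by data name.
    postprocessing={"metrics": {"pred": lambda pred: pred.argmax(dim=1)}},
)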
""" def __init__( @@ -61,14 +55,14 @@ def __init__( drop_last_batch: bool = False, num_workers: int = 0, pin_memory: bool = True, - optimizer: Optional[Union[Optimizer, List[Optimizer]]] = None, - scheduler: Optional[Union[Callable, List[Callable]]] = None, + optimizer: Optional[Optimizer] = None, + scheduler: Optional[LRScheduler] = None, criterion: Optional[Callable] = None, - datasets: Optional[Dict[str, Optional[Dataset]]] = None, - samplers: Optional[Dict[str, Optional[Sampler]]] = None, - collate_fns: Optional[Dict[str, Optional[Callable]]] = None, - metrics: Optional[Dict[str, Optional[Union[Metric, List[Metric]]]]] = None, - postprocessing: Optional[Dict[str, Optional[Callable]]] = None, + datasets: Dict[str, Dataset] = None, + samplers: Dict[str, Sampler] = None, + collate_fns: Dict[str, Union[Callable, List[Callable]]] = None, + metrics: Dict[str, Union[Metric, List[Metric], Dict[str, Metric]]] = None, + postprocessing: Dict[str, Union[Callable, List[Callable]]] = None, inferer: Optional[Callable] = None, ) -> None: super().__init__() @@ -82,33 +76,23 @@ def __init__( # Criterion, optimizer, and scheduler self.criterion = criterion - self.optimizer = ensure_list(optimizer) - self.scheduler = ensure_list(scheduler) + self.optimizer = optimizer + self.scheduler = scheduler # DataLoader specifics self.num_workers = num_workers self.pin_memory = pin_memory # Datasets, samplers, and collate functions - schema = {"train": None, "val": None, "test": None, "predict": None} - self.datasets = ensure_dict_schema(datasets, schema) - self.samplers = ensure_dict_schema(samplers, schema) - self.collate_fns = ensure_dict_schema(collate_fns, schema) + self.datasets = self._init_datasets(datasets) + self.samplers = self._init_samplers(samplers) + self.collate_fns = self._init_collate_fns(collate_fns) # Metrics - self.metrics = ensure_dict_schema(metrics, schema={"train": None, "val": None, "test": None}) - self.metrics = {mode: MetricCollection(ensure_list(metric)) for mode, metric in self.metrics.items()} - # Register the metrics to allow the LightningModule to automatically move them to the correct device. - # Currently, a workaround is needed because of https://github.com/pytorch/pytorch/issues/71203. - # Once it's fixed, we can set `self.metrics = ModuleDict(self.metrics)` directly. - for mode, mode_metrics in self.metrics.items(): - setattr(self, f"{mode}_metric", mode_metrics) - self.metrics[mode] = getattr(self, f"{mode}_metric") + self.metrics = self._init_metrics(metrics) # Postprocessing - schema = {"input": None, "target": None, "pred": None} - schema = {"criterion": schema, "metrics": schema, "logging": schema} - self.postprocessing = ensure_dict_schema(postprocessing, schema) + self.postprocessing = self._init_postprocessing(postprocessing) # Inferer for val, test, and predict self.inferer = inferer @@ -155,27 +139,34 @@ def _base_step(self, batch: Union[List, Tuple], batch_idx: int, mode: str) -> Un For predict step, it returns pred only. """ - # Ensure that the batch is a list, a tuple, or a dict. - if not isinstance(batch, (list, tuple, dict)): + # Batch type check: + # - Dict: must contain "input" and "target" keys, and optionally "id" key. 
+ if isinstance(batch, dict): + if set(batch.keys()) not in [{"input", "target"}, {"input", "target", "id"}]: + raise ValueError( + "A batch dict must have 'input', 'target', and, " + f"optionally 'id', as keys, but found {list(batch.keys())}" + ) + batch["id"] = None if "id" not in batch else batch["id"] + # - List/tuple: must contain two elements - input and target. After the check, convert it to dict. + elif isinstance(batch, (list, tuple)): + if len(batch) != 2: + raise ValueError( + f"A batch must consist of 2 elements - input and target. However, {len(batch)} " + "elements were found. Note: if target does not exist, return `None` as target." + ) + batch = {"input": batch[0], "target": batch[1], "id": None} + # - Other types are not allowed. + else: raise TypeError( "A batch must be a list, a tuple, or a dict. " - "A batch dict must have 'input' and 'target' as keys." + "A batch dict must have 'input' and 'target' keys, and optionally 'id'. " "A batch list or a tuple must have 2 elements - input and target. " "If target does not exist, return `None` as target." ) - # Ensure that a dict batch has input and target keys exclusively. - if isinstance(batch, dict) and set(batch.keys()) != {"input", "target"}: - raise ValueError("A batch must be a dict with 'input' and 'target' as keys.") - # Ensure that a list/tuple batch has 2 elements (input and target). - if len(batch) == 1: - raise ValueError( - "A batch must consist of 2 elements - input and target. If target does not exist, return `None` as target." - ) - if len(batch) > 2: - raise ValueError(f"A batch must consist of 2 elements - input and target, but found {len(batch)} elements.") - # Split the batch into input and target. - input, target = batch if not isinstance(batch, dict) else (batch["input"], batch["target"]) + # Split the batch into input, target, and id. + input, target, id = batch["input"], batch["target"], batch["id"] # Forward if self.inferer and mode in ["val", "test", "predict"]: @@ -191,29 +182,34 @@ def _base_step(self, batch: Union[List, Tuple], batch_idx: int, mode: str) -> Un # Calculate the loss. loss = self._calculate_loss(pred, target) if mode in ["train", "val"] else None # Log the loss for monitoring purposes. - self.log( - "loss" if mode == "train" else f"{mode}_loss", - loss, - on_step=True, - on_epoch=True, - sync_dist=True, - logger=False, - batch_size=self.batch_size, - ) + if loss is not None: + self.log( + "loss" if mode == "train" else f"{mode}_loss", + loss, + on_step=True, + on_epoch=True, + sync_dist=True, + logger=False, + batch_size=self.batch_size, + ) # Log and return the results. if mode == "predict": - return pred + # Pred postprocessing for logging or writing. + pred = apply_fns(pred, self.postprocessing["logging"]["pred"]) + return {"pred": pred, "id": id} else: # Data postprocessing for metrics input = apply_fns(input, self.postprocessing["metrics"]["input"]) target = apply_fns(target, self.postprocessing["metrics"]["target"]) pred = apply_fns(pred, self.postprocessing["metrics"]["pred"]) - # Calculate the metrics for the step. - metrics = self.metrics[mode](pred, target) - # Log the metrics for monitoring purposes. - self.log_dict(metrics, on_step=True, on_epoch=True, sync_dist=True, logger=False, batch_size=self.batch_size) + # Calculate the step metrics. + # TODO: Remove the "_" prefix when fixed https://github.com/pytorch/pytorch/issues/71203 + metrics = self.metrics["_" + mode](pred, target) if self.metrics["_" + mode] is not None else None + # Log the metrics. 
+ if metrics is not None: + self.log_dict(metrics, on_step=True, on_epoch=True, sync_dist=True, logger=False, batch_size=self.batch_size) # Data postprocessing for logging. input = apply_fns(input, self.postprocessing["logging"]["input"]) @@ -221,7 +217,7 @@ def _base_step(self, batch: Union[List, Tuple], batch_idx: int, mode: str) -> Un pred = apply_fns(pred, self.postprocessing["logging"]["pred"]) # Return the loss, metrics, input, target, and pred. - return {"loss": loss, "metrics": metrics, "input": input, "target": target, "pred": pred} + return {"loss": loss, "metrics": metrics, "input": input, "target": target, "pred": pred, "id": id} def _calculate_loss( self, pred: Union[torch.Tensor, List, Tuple, Dict], target: Union[torch.Tensor, List, Tuple, Dict, None] @@ -297,24 +293,18 @@ def _base_dataloader(self, mode: str) -> DataLoader: collate_fn=collate_fn, ) - def configure_optimizers(self) -> Union[Optimizer, List[Dict[str, Union[Optimizer, "Scheduler"]]]]: + def configure_optimizers(self) -> Dict: """LightningModule method. Returns optimizers and, if defined, schedulers. Returns: - Optimizer or a List of Dict of paired Optimizers and Schedulers: instantiated - optimizers and/or schedulers. + Dict: optimizer and, if defined, scheduler. """ - if not self.optimizer: - logger.error("Please specify 'system.optimizer' in the config. Exiting.") - sys.exit() - if not self.scheduler: - return self.optimizer - - if len(self.optimizer) != len(self.scheduler): - logger.error("Each optimizer must have its own scheduler.") - sys.exit() - - return [{"optimizer": opt, "lr_scheduler": sched} for opt, sched in zip(self.optimizer, self.scheduler)] + if self.optimizer is None: + raise ValueError("Please specify 'system.optimizer' in the config.") + if self.scheduler is None: + return {"optimizer": self.optimizer} + else: + return {"optimizer": self.optimizer, "lr_scheduler": self.scheduler} def setup(self, stage: str) -> None: """Automatically called by the LightningModule after the initialization. @@ -362,7 +352,10 @@ def setup(self, stage: str) -> None: self.predict_step = partial(self._base_step, mode="predict") def _init_placeholders_for_dataloader_and_step_methods(self) -> None: - """`LighterSystem` dynamically defines the `..._dataloader()`and `..._step()` methods + """ + Initializes placeholders for dataloader and step methods. + + `LighterSystem` dynamically defines the `..._dataloader()` and `..._step()` methods in the `self.setup()` method. However, `LightningModule` expects them to be defined at init. To prevent it from throwing an error, the `..._dataloader()` and `..._step()` are initially defined as `lambda: None`, before `self.setup()` is called. 
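The placeholder-then-bind pattern this docstring describes, reduced to a standalone sketch (illustrative class, not the actual implementation):

from functools import partial


class Example:
    def __init__(self):
        # Placeholder so the attribute exists when the framework inspects it at init.
        self.train_step = lambda: None

    def setup(self, stage: str) -> None:
        # Replace the placeholder with the real, mode-bound implementation.
        self.train_step = partial(self._base_step, mode="train")

    def _base_step(self, batch, batch_idx, mode):
        return f"{mode} step on batch {batch_idx}"


ex = Example()
ex.setup("fit")
print(ex.train_step(batch=None, batch_idx=0))  # prints: train step on batch 0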
@@ -371,3 +364,32 @@ def _init_placeholders_for_dataloader_and_step_methods(self) -> None: self.val_dataloader = self.validation_step = lambda: None self.test_dataloader = self.test_step = lambda: None self.predict_dataloader = self.predict_step = lambda: None + + def _init_datasets(self, datasets: Dict[str, Optional[Dataset]]): + """Ensures that the datasets have the predefined schema.""" + return ensure_dict_schema(datasets, {"train": None, "val": None, "test": None, "predict": None}) + + def _init_samplers(self, samplers: Dict[str, Optional[Sampler]]): + """Ensures that the samplers have the predefined schema""" + return ensure_dict_schema(samplers, {"train": None, "val": None, "test": None, "predict": None}) + + def _init_collate_fns(self, collate_fns: Dict[str, Optional[Callable]]): + """Ensures that the collate functions have the predefined schema.""" + return ensure_dict_schema(collate_fns, {"train": None, "val": None, "test": None, "predict": None}) + + def _init_metrics(self, metrics: Dict[str, Optional[Union[Metric, List[Metric], Dict[str, Metric]]]]): + """Ensures that the metrics have the predefined schema. Wraps each mode's metrics in + a MetricCollection, and finally registers them with PyTorch using a ModuleDict. + """ + metrics = ensure_dict_schema(metrics, {"train": None, "val": None, "test": None}) + for mode, metric in metrics.items(): + metrics[mode] = MetricCollection(metric) if metric is not None else None + # TODO: Remove the prefix addition line below when fixed https://github.com/pytorch/pytorch/issues/71203 + metrics = {f"_{k}": v for k, v in metrics.items()} + return ModuleDict(metrics) + + def _init_postprocessing(self, postprocessing: Dict[str, Optional[Union[Callable, List[Callable]]]]): + """Ensures that the postprocessing functions have the predefined schema.""" + subschema = {"input": None, "target": None, "pred": None} + schema = {"criterion": subschema, "metrics": subschema, "logging": subschema} + return ensure_dict_schema(postprocessing, schema) diff --git a/lighter/utils/cli.py b/lighter/utils/cli.py index 0d94276..c74b31e 100644 --- a/lighter/utils/cli.py +++ b/lighter/utils/cli.py @@ -1,4 +1,3 @@ -import sys from functools import partial import fire diff --git a/lighter/utils/collate.py b/lighter/utils/collate.py index b0bd823..5dff93a 100644 --- a/lighter/utils/collate.py +++ b/lighter/utils/collate.py @@ -1,8 +1,7 @@ -from typing import Any, Callable, List +from typing import Any, Callable import random -import torch from torch.utils.data import DataLoader from torch.utils.data._utils.collate import collate_str_fn, default_collate_fn_map from torch.utils.data.dataloader import default_collate diff --git a/lighter/utils/dynamic_imports.py b/lighter/utils/dynamic_imports.py index 155decf..bf8e57a 100644 --- a/lighter/utils/dynamic_imports.py +++ b/lighter/utils/dynamic_imports.py @@ -1,50 +1,76 @@ -from typing import Any +from typing import Dict import importlib import sys +from dataclasses import dataclass, field from pathlib import Path from loguru import logger +from monai.utils.module import optional_import -OPTIONAL_IMPORTS = {} + +@dataclass +class OptionalImports: + """Dataclass for handling optional imports. + + This class provides a way to handle optional imports in a convenient manner. + It allows importing modules that may or may not be available, and raises an ImportError if the module is not available. 
+ + Example: :: + from lighter.utils.dynamic_imports import OPTIONAL_IMPORTS + writer = OPTIONAL_IMPORTS["tensorboard"].SummaryWriter() + + Attributes: + imports (Dict[str, object]): A dictionary to store the imported modules. + """ + + imports: Dict[str, object] = field(default_factory=dict) + + def __getitem__(self, module_name: str): + """Get the imported module by name. + + Args: + module_name (str): The name of the module to import. + + Raises: + ImportError: If the module is not available. + + Returns: + object: The imported module. + """ + if module_name not in self.imports: + self.imports[module_name], module_available = optional_import(module_name) + if not module_available: + raise ImportError(f"'{module_name}' is not available. Make sure that it is installed and spelled correctly.") + return self.imports[module_name] + + +OPTIONAL_IMPORTS = OptionalImports() def import_module_from_path(module_name: str, module_path: str) -> None: - """Given the path to a module, import it, and name it as specified. + """Import a module from a given path and assign it a specified name. + + This function imports a module from the specified path and assigns it the specified name. Args: - module_name (str): what to name the imported module. - module_path (str): path to the module to load. + module_name (str): The name to assign to the imported module. + module_path (str): The path to the module to import. + + Raises: + ValueError: If the module has already been imported. + FileNotFoundError: If the `__init__.py` file is not found in the module path. """ # Based on https://stackoverflow.com/a/41595552. if module_name in sys.modules: - logger.error(f"{module_path} has already been imported as module: {module_name}") - sys.exit() + raise ValueError(f"{module_name} has already been imported as module.") module_path = Path(module_path).resolve() / "__init__.py" if not module_path.is_file(): - logger.error(f"No `__init__.py` in `{module_path}`. Exiting.") - sys.exit() + raise FileNotFoundError(f"No `__init__.py` in `{module_path}`.") spec = importlib.util.spec_from_file_location(module_name, str(module_path)) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) sys.modules[module_name] = module logger.info(f"{module_path.parent} imported as '{module_name}' module.") - - -def import_attr(module_attr: str) -> Any: - """Import using dot-notation string, e.g., 'torch.nn.Module'. - - Args: - module_attr (str): dot-notation path to the attribute. - - Returns: - Any: imported attribute. - """ - # Split module from attribute name - module, attr = module_attr.rsplit(".", 1) - # Import the module - module = __import__(module, fromlist=[attr]) - # Get the attribute from the module - return getattr(module, attr) diff --git a/lighter/utils/misc.py b/lighter/utils/misc.py index 1eae7cb..ae27e2b 100644 --- a/lighter/utils/misc.py +++ b/lighter/utils/misc.py @@ -1,9 +1,6 @@ -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Union import inspect -import sys - -from loguru import logger def ensure_list(vals: Any) -> List: @@ -63,8 +60,7 @@ def setattr_dot_notation(obj: Callable, attr: str, value: Any): """ if "." not in attr: if not hasattr(obj, attr): - logger.info(f"`{get_name(obj, True)}` has no attribute `{attr}`. 
Exiting.") - sys.exit() + raise AttributeError(f"`{get_name(obj, True)}` has no attribute `{attr}`.") setattr(obj, attr, value) # Solve recursively if the attribute is defined in dot-notation else: diff --git a/poetry.lock b/poetry.lock index 091fd1e..6fbacb4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. [[package]] name = "absl-py" @@ -285,13 +285,13 @@ files = [ [[package]] name = "certifi" -version = "2023.5.7" +version = "2023.7.22" description = "Python package for providing Mozilla's CA Bundle." optional = false python-versions = ">=3.6" files = [ - {file = "certifi-2023.5.7-py3-none-any.whl", hash = "sha256:c6c2e98f5c7869efca1f8916fed228dd91539f9f1b444c314c06eef02980c716"}, - {file = "certifi-2023.5.7.tar.gz", hash = "sha256:0f0d56dc5a6ad56fd4ba36484d6cc34451e1c6548c61daad8c320169f91eddc7"}, + {file = "certifi-2023.7.22-py3-none-any.whl", hash = "sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9"}, + {file = "certifi-2023.7.22.tar.gz", hash = "sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082"}, ] [[package]] @@ -744,13 +744,13 @@ smmap = ">=3.0.1,<6" [[package]] name = "gitpython" -version = "3.1.32" +version = "3.1.35" description = "GitPython is a Python library used to interact with Git repositories" optional = false python-versions = ">=3.7" files = [ - {file = "GitPython-3.1.32-py3-none-any.whl", hash = "sha256:e3d59b1c2c6ebb9dfa7a184daf3b6dd4914237e7488a1730a6d8f6f5d0b4187f"}, - {file = "GitPython-3.1.32.tar.gz", hash = "sha256:8d9b8cb1e80b9735e8717c9362079d3ce4c6e5ddeebedd0361b228c3a67a62f6"}, + {file = "GitPython-3.1.35-py3-none-any.whl", hash = "sha256:c19b4292d7a1d3c0f653858db273ff8a6614100d1eb1528b014ec97286193c09"}, + {file = "GitPython-3.1.35.tar.gz", hash = "sha256:9cbefbd1789a5fe9bcf621bb34d3f441f3a90c8461d377f84eda73e721d9b06b"}, ] [package.dependencies] @@ -1190,6 +1190,16 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, + {file = 
"MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -1782,7 +1792,7 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.20.3", markers = "python_version < \"3.10\""}, - {version = ">=1.21.0", markers = "python_version >= \"3.10\""}, + {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, ] python-dateutil = ">=2.8.1" @@ -2255,6 +2265,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -2262,8 +2273,15 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = 
"sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -2280,6 +2298,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -2287,6 +2306,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, diff --git a/projects/cifar10/experiments/monai_bundle_prototype.yaml b/projects/cifar10/experiments/monai_bundle_prototype.yaml index 9c97613..63f25a7 100644 --- 
a/projects/cifar10/experiments/monai_bundle_prototype.yaml +++ b/projects/cifar10/experiments/monai_bundle_prototype.yaml @@ -20,9 +20,8 @@ trainer: max_samples: 10 - _target_: lighter.callbacks.LighterFileWriter - write_dir: "$@project + '/predictions' " - write_format: "tensor" - write_interval: "step" # "epoch" + directory: "$@project + '/predictions'" + format: "tensor" system: _target_: lighter.LighterSystem
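Since the keys under a `_target_` entry are passed to the target's constructor, the updated YAML above corresponds to roughly the following Python (a sketch with an illustrative output path; any other `LighterFileWriter` parameters are not shown in this diff): ::

    from lighter.callbacks import LighterFileWriter

    # `directory` and `format` replace the former `write_dir` and
    # `write_format` arguments; `write_interval` is no longer passed.
    writer = LighterFileWriter(directory="./predictions", format="tensor")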