Skip to content

Commit

Permalink
Add "id" support. Refactor Writers. Add Writer additional format exte…
Browse files Browse the repository at this point in the history
…nsibility. (#78)

* Remove loss logging when predicting

* Add "id" for each batch sample ID-ing purposes. Refactor Writers, add easy extensibility for new formats

* Remove interval arg, group_tensors fn, and on pred epoch end writing. Add decollate_batch when writing.

* Small fixes

* Remove multi opt and scheduler support. Replace remaininig sys.exit's.

* Update configure_optimizers docstring

* Fix index ID issue in DDP writing. Replace broadcast with gather in the TableWriter.

* Add missing if DDP check

* Update docstrings, rename and refactor parse_data

* Add freezer to init file

* Change property to attribute

* Add support for dict metrics. Refactor system.

* Fix typos

* Remove unused imports

* Update logger.py to support the temp ModuleDict fix

* Add continue to freezer and detach cpu to image logging

* Remove multi_pred, refactor Writer, Logger, and optional imports

* Bump gitpython from 3.1.32 to 3.1.35

Bumps [gitpython](https://github.com/gitpython-developers/GitPython) from 3.1.32 to 3.1.35.
- [Release notes](https://github.com/gitpython-developers/GitPython/releases)
- [Changelog](https://github.com/gitpython-developers/GitPython/blob/main/CHANGES)
- [Commits](gitpython-developers/GitPython@3.1.32...3.1.35)

---
updated-dependencies:
- dependency-name: gitpython
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <[email protected]>

* Bump certifi from 2023.5.7 to 2023.7.22

Bumps [certifi](https://github.com/certifi/python-certifi) from 2023.5.7 to 2023.7.22.
- [Commits](certifi/python-certifi@2023.05.07...2023.07.22)

---
updated-dependencies:
- dependency-name: certifi
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <[email protected]>

* Remove add_batch_dim

---------

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
  • Loading branch information
ibro45 and dependabot[bot] authored Sep 15, 2023
1 parent 7de0b58 commit a22442f
Show file tree
Hide file tree
Showing 14 changed files with 508 additions and 560 deletions.
1 change: 1 addition & 0 deletions lighter/callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .freezer import LighterFreezer
from .logger import LighterLogger
from .writer.file import LighterFileWriter
from .writer.table import LighterTableWriter
15 changes: 9 additions & 6 deletions lighter/callbacks/freezer.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def on_test_batch_start(
self._on_batch_start(trainer, pl_module)

def on_predict_batch_start(
self, trainer: Trainer, pl_module: LighterSystem, batch: Any, batch_idx: int, dataloader_idx: int
self, trainer: Trainer, pl_module: LighterSystem, batch: Any, batch_idx: int, dataloader_idx: int = 0
) -> None:
self._on_batch_start(trainer, pl_module)

Expand Down Expand Up @@ -122,20 +122,23 @@ def _set_model_requires_grad(self, model: Union[Module, LighterSystem], requires
# Leave the excluded-from-freezing parameters trainable.
if self.except_names and name in self.except_names:
param.requires_grad = True
elif self.except_name_starts_with and any(name.startswith(prefix) for prefix in self.except_name_starts_with):
continue
if self.except_name_starts_with and any(name.startswith(prefix) for prefix in self.except_name_starts_with):
param.requires_grad = True
continue

# Freeze/unfreeze the specified parameters, based on the `requires_grad` argument.
elif self.names and name in self.names:
if self.names and name in self.names:
param.requires_grad = requires_grad
frozen_layers.append(name)
elif self.name_starts_with and any(name.startswith(prefix) for prefix in self.name_starts_with):
continue
if self.name_starts_with and any(name.startswith(prefix) for prefix in self.name_starts_with):
param.requires_grad = requires_grad
frozen_layers.append(name)
continue

# Otherwise, leave the parameter trainable.
else:
param.requires_grad = True
param.requires_grad = True

self._frozen_state = not requires_grad
# Log only when freezing the parameters.
Expand Down
152 changes: 73 additions & 79 deletions lighter/callbacks/logger.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
from typing import Any, Dict, Union

import itertools
import sys
from datetime import datetime
from pathlib import Path

import torch
from loguru import logger
from monai.utils.module import optional_import
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor

from lighter import LighterSystem
from lighter.callbacks.utils import get_lighter_mode, is_data_type_supported, parse_data, preprocess_image
from lighter.callbacks.utils import get_lighter_mode, preprocess_image
from lighter.utils.dynamic_imports import OPTIONAL_IMPORTS


Expand Down Expand Up @@ -62,8 +60,7 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None:
stage (str): stage of the training process. Passed automatically by PyTorch Lightning.
"""
if trainer.logger is not None:
logger.error("When using LighterLogger, set Trainer(logger=None).")
sys.exit()
raise ValueError("When using LighterLogger, set Trainer(logger=None).")

if not trainer.is_global_zero:
return
Expand All @@ -76,8 +73,6 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None:

# Tensorboard initialization.
if self.tensorboard:
# Tensorboard is a part of PyTorch, no need to check if it is not available.
OPTIONAL_IMPORTS["tensorboard"], _ = optional_import("torch.utils.tensorboard")
tensorboard_dir = self.log_dir / "tensorboard"
tensorboard_dir.mkdir()
self.tensorboard = OPTIONAL_IMPORTS["tensorboard"].SummaryWriter(log_dir=tensorboard_dir)
Expand All @@ -86,10 +81,6 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None:

# Wandb initialization.
if self.wandb:
OPTIONAL_IMPORTS["wandb"], wandb_available = optional_import("wandb")
if not wandb_available:
logger.error("Weights & Biases not installed. To install it, run `pip install wandb`. Exiting.")
sys.exit()
wandb_dir = self.log_dir / "wandb"
wandb_dir.mkdir()
self.wandb = OPTIONAL_IMPORTS["wandb"].init(project=self.project, dir=wandb_dir, config=self.config)
Expand Down Expand Up @@ -165,6 +156,7 @@ def _log_image(self, name: str, image: torch.Tensor, global_step: int) -> None:
image (torch.Tensor): image to be logged.
global_step (int): current global step.
"""
image = image.detach().cpu()
if self.tensorboard:
self.tensorboard.add_image(name, image, global_step=global_step)
if self.wandb:
Expand All @@ -179,7 +171,6 @@ def _log_histogram(self, name: str, tensor: torch.Tensor, global_step: int) -> N
global_step (int): current global step.
"""
tensor = tensor.detach().cpu()

if self.tensorboard:
self.tensorboard.add_histogram(name, tensor, global_step=global_step)
if self.wandb:
Expand All @@ -193,49 +184,44 @@ def _on_batch_end(self, outputs: Dict, trainer: Trainer) -> None:
outputs (Dict): output dict from the model.
trainer (Trainer): Trainer, passed automatically by PyTorch Lightning.
"""
if not trainer.sanity_checking:
mode = get_lighter_mode(trainer.state.stage)
# Accumulate the loss.
if mode in ["train", "val"]:
self.loss[mode] += outputs["loss"].item()
# Logging frequency. Log only on rank 0.
if trainer.is_global_zero and self.global_step_counter[mode] % trainer.log_every_n_steps == 0:
# Get global step.
global_step = self._get_global_step(trainer)

# Log loss.
if outputs["loss"] is not None:
self._log_scalar(f"{mode}/loss/step", outputs["loss"], global_step)

# Log metrics.
if outputs["metrics"] is not None:
for name, metric in outputs["metrics"].items():
self._log_scalar(f"{mode}/metrics/{name}/step", metric, global_step)

# Log input, target, and pred.
for name in ["input", "target", "pred"]:
if self.log_types[name] is None:
continue
# Ensure data is of a valid type.
if not is_data_type_supported(outputs[name]):
raise ValueError(
f"`{name}` has to be a Tensor, List[Tensor], Tuple[Tensor], Dict[str, Tensor], "
f"Dict[str, List[Tensor]], or Dict[str, Tuple[Tensor]]. `{type(outputs[name])}` is not supported."
)
for identifier, item in parse_data(outputs[name]).items():
item_name = f"{mode}/data/{name}" if identifier is None else f"{mode}/data/{name}_{identifier}"
self._log_by_type(item_name, item, self.log_types[name], global_step)

# Log learning rate stats. Logs at step if a scheduler's interval is step-based.
if mode == "train":
lr_stats = self.lr_monitor.get_stats(trainer, "step")
for name, value in lr_stats.items():
self._log_scalar(f"{mode}/optimizer/{name}/step", value, global_step)

# Increment the step counters.
self.global_step_counter[mode] += 1
if mode in ["train", "val"]:
self.epoch_step_counter[mode] += 1
if trainer.sanity_checking:
return

mode = get_lighter_mode(trainer.state.stage)

# Accumulate the loss.
if mode in ["train", "val"]:
self.loss[mode] += outputs["loss"].item()

# Log only on rank 0 and according to the `log_every_n_steps` parameter. Otherwise, only increment the step counters.
if not trainer.is_global_zero or self.global_step_counter[mode] % trainer.log_every_n_steps != 0:
self._increment_step_counters(mode)
return

global_step = self._get_global_step(trainer)

# Loss.
if outputs["loss"] is not None:
self._log_scalar(f"{mode}/loss/step", outputs["loss"], global_step)

# Metrics.
if outputs["metrics"] is not None:
for name, metric in outputs["metrics"].items():
self._log_scalar(f"{mode}/metrics/{name}/step", metric, global_step)

# Input, target, and pred.
for name in ["input", "target", "pred"]:
if self.log_types[name] is not None:
self._log_by_type(f"{mode}/data/{name}", outputs[name], self.log_types[name], global_step)

# LR info. Logs at step if a scheduler's interval is step-based.
if mode == "train":
lr_stats = self.lr_monitor.get_stats(trainer, "step")
for name, value in lr_stats.items():
self._log_scalar(f"{mode}/optimizer/{name}/step", value, global_step)

# Increment the step counters.
self._increment_step_counters(mode)

def _on_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> None:
"""Performs logging at the end of an epoch. Logs the epoch number, the loss, and the metrics.
Expand All @@ -249,46 +235,44 @@ def _on_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> None:
mode = get_lighter_mode(trainer.state.stage)
loss, metrics = None, None

# Loss
# Get the accumulated loss over the epoch and processes.
if mode in ["train", "val"]:
# Get the accumulated loss.
loss = self.loss[mode]
# Reduce the loss and average it on each rank.
loss = trainer.strategy.reduce(loss, reduce_op="mean")
# Divide the accumulated loss by the number of steps in the epoch.
loss /= self.epoch_step_counter[mode]

# Metrics
# Get the torchmetrics.
metric_collection = pl_module.metrics[mode]
# TODO: Remove the "_" prefix when fixed https://github.com/pytorch/pytorch/issues/71203
metric_collection = pl_module.metrics["_" + mode]
if metric_collection is not None:
# Compute the epoch metrics.
metrics = metric_collection.compute()
# Reset the metrics for the next epoch.
metric_collection.reset()

# Log. Only on rank 0.
if trainer.is_global_zero:
# Get global step.
global_step = self._get_global_step(trainer)
# Log only on rank 0.
if not trainer.is_global_zero:
return

# Log epoch number.
self._log_scalar("epoch", trainer.current_epoch, global_step)
global_step = self._get_global_step(trainer)

# Log loss.
if loss is not None:
self._log_scalar(f"{mode}/loss/epoch", loss, global_step)
# Epoch number.
self._log_scalar("epoch", trainer.current_epoch, global_step)

# Log metrics.
if metrics is not None:
for name, metric in metrics.items():
self._log_scalar(f"{mode}/metrics/{name}/epoch", metric, global_step)
# Loss.
if loss is not None:
self._log_scalar(f"{mode}/loss/epoch", loss, global_step)

# Log learning rate stats. Logs at epoch if a scheduler's interval is epoch-based, or if no scheduler is used.
if mode == "train":
lr_stats = self.lr_monitor.get_stats(trainer, "epoch")
for name, value in lr_stats.items():
self._log_scalar(f"{mode}/optimizer/{name}/epoch", value, global_step)
# Metrics.
if metrics is not None:
for name, metric in metrics.items():
self._log_scalar(f"{mode}/metrics/{name}/epoch", metric, global_step)

# LR info. Logged at epoch if the scheduler's interval is epoch-based, or if no scheduler is used.
if mode == "train":
lr_stats = self.lr_monitor.get_stats(trainer, "epoch")
for name, value in lr_stats.items():
self._log_scalar(f"{mode}/optimizer/{name}/epoch", value, global_step)

def _get_global_step(self, trainer: Trainer) -> int:
"""Return the global step for the current mode. Note that when Trainer
Expand All @@ -309,6 +293,16 @@ def _get_global_step(self, trainer: Trainer) -> int:
return self.global_step_counter["train"]
return self.global_step_counter[mode]

def _increment_step_counters(self, mode: str) -> None:
"""Increment the global step and epoch step counters for the specified mode.
Args:
mode (str): mode to increment the global step counter for.
"""
self.global_step_counter[mode] += 1
if mode in ["train", "val"]:
self.epoch_step_counter[mode] += 1

def on_train_epoch_start(self, trainer: Trainer, pl_module: LighterSystem) -> None:
# Reset the loss and the epoch step counter for the next epoch.
self.loss["train"] = 0
Expand Down
Loading

0 comments on commit a22442f

Please sign in to comment.