Fix config structure enforcing and typechecking. Add full Tuner support. #133

Merged · 4 commits · Aug 1, 2024
14 changes: 14 additions & 0 deletions lighter/system.py
@@ -361,6 +361,20 @@ def setup(self, stage: str) -> None:
self.predict_dataloader = partial(self._base_dataloader, mode="predict")
self.predict_step = partial(self._base_step, mode="predict")

@property
def learning_rate(self) -> float:
"""Get the learning rate of the optimizer. Ensures compatibility with the Tuner's 'lr_find()' method."""
if len(self.optimizer.param_groups) > 1:
raise ValueError("The learning rate is not available when there are multiple optimizer parameter groups.")
return self.optimizer.param_groups[0]["lr"]

@learning_rate.setter
def learning_rate(self, value: float) -> None:
"""Set the learning rate of the optimizer. Ensures compatibility with the Tuner's 'lr_find()' method."""
if len(self.optimizer.param_groups) > 1:
raise ValueError("The learning rate is not available when there are multiple optimizer parameter groups.")
self.optimizer.param_groups[0]["lr"] = value

def _init_placeholders_for_dataloader_and_step_methods(self) -> None:
"""
Initializes placeholders for dataloader and step methods.
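The property/setter pair above is what lets Lightning's Tuner drive the learning-rate search: lr_find() reads and writes learning_rate on the module it tunes. Below is a minimal, self-contained sketch of the same pattern; the System stand-in and the toy model are illustrative, not part of this diff.

import torch

class System:
    """Stand-in for LighterSystem, reproducing the learning_rate property pair."""

    def __init__(self, optimizer: torch.optim.Optimizer) -> None:
        self.optimizer = optimizer

    @property
    def learning_rate(self) -> float:
        # A single learning rate is only well-defined for one parameter group.
        if len(self.optimizer.param_groups) > 1:
            raise ValueError("Multiple parameter groups; no single learning rate.")
        return self.optimizer.param_groups[0]["lr"]

    @learning_rate.setter
    def learning_rate(self, value: float) -> None:
        if len(self.optimizer.param_groups) > 1:
            raise ValueError("Multiple parameter groups; no single learning rate.")
        self.optimizer.param_groups[0]["lr"] = value

model = torch.nn.Linear(4, 2)
system = System(torch.optim.Adam(model.parameters(), lr=1e-3))
system.learning_rate = 3e-4  # Tuner.lr_find() applies candidate rates exactly like this
assert system.learning_rate == 3e-4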
16 changes: 13 additions & 3 deletions lighter/utils/misc.py
@@ -119,13 +119,23 @@ def apply_fns(data: Any, fns: Union[Callable, List[Callable]]) -> Any:

def get_optimizer_stats(optimizer: Optimizer) -> Dict[str, float]:
"""
Extract learning rates and momentum values from each parameter group of the optimizer.
Extract learning rates and momentum values from an optimizer into a dictionary.

This function iterates over the parameter groups of the given optimizer and collects
the learning rate and momentum (or beta values) for each group. The collected values
are stored in a dictionary with keys formatted to indicate the optimizer type and
parameter group index (if multiple groups are present).

Args:
optimizer (Optimizer): A PyTorch optimizer.
optimizer (Optimizer): A PyTorch optimizer instance.

Returns:
Dictionary with formatted keys and values for learning rates and momentum.
Dict[str, float]: A dictionary containing the learning rates and momentum values
for each parameter group in the optimizer. The keys are formatted as:
- "optimizer/{optimizer_class_name}/lr" for learning rates
- "optimizer/{optimizer_class_name}/momentum" for momentum values
If there are multiple parameter groups, the keys will include the group index, e.g.,
"optimizer/{optimizer_class_name}/lr/group1".
"""
stats_dict = {}
for group_idx, group in enumerate(optimizer.param_groups):
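The function body is truncated above, but the documented key format is enough to sketch the behavior. The following is a hypothetical reconstruction consistent with the docstring, not the code from this PR; note that Adam-style optimizers expose beta1 where SGD exposes momentum.

from typing import Dict

from torch.optim import Optimizer

def optimizer_stats_sketch(optimizer: Optimizer) -> Dict[str, float]:
    """Hypothetical reconstruction matching the documented key format."""
    stats: Dict[str, float] = {}
    name = optimizer.__class__.__name__
    multi_group = len(optimizer.param_groups) > 1
    for idx, group in enumerate(optimizer.param_groups):
        suffix = f"/group{idx + 1}" if multi_group else ""
        stats[f"optimizer/{name}/lr{suffix}"] = group["lr"]
        if "momentum" in group:  # SGD-style optimizers
            stats[f"optimizer/{name}/momentum{suffix}"] = group["momentum"]
        elif "betas" in group:  # Adam-style optimizers report beta1 as momentum
            stats[f"optimizer/{name}/momentum{suffix}"] = group["betas"][0]
    return stats

For a single-group Adam(lr=1e-3), this yields {'optimizer/Adam/lr': 0.001, 'optimizer/Adam/momentum': 0.9}.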
124 changes: 80 additions & 44 deletions lighter/utils/runner.py
@@ -1,22 +1,45 @@
from typing import Any

import copy
from functools import partial

import fire
from monai.bundle.config_parser import ConfigParser
from pytorch_lightning import seed_everything
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.tuner import Tuner

from lighter.system import LighterSystem
from lighter.utils.dynamic_imports import import_module_from_path

CONFIG_STRUCTURE = {"project": None, "system": {}, "trainer": {}, "args": {}, "vars": {}}
TRAINER_METHOD_NAMES = ["fit", "validate", "test", "predict", "lr_find", "scale_batch_size"]
CONFIG_STRUCTURE = {
"project": None,
"vars": {},
"args": {
# Keys are the method names; values are the arguments passed to them.
"fit": {},
"validate": {},
"test": {},
"predict": {},
"lr_find": {},
"scale_batch_size": {},
},
"system": {},
"trainer": {},
}
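For orientation, a YAML config that fills this structure would parse into a dict shaped like the following; every value here is made up for illustration:

example_config = {
    "project": "./my_project",  # optional path imported as a module, or None
    "vars": {"batch_size": 32},  # free-form values other config sections can reference
    "args": {"fit": {"ckpt_path": None}},  # per-method arguments, keyed by method name
    "system": {"_target_": "project.MySystem"},  # instantiated by MONAI's ConfigParser
    "trainer": {"_target_": "pytorch_lightning.Trainer", "max_epochs": 10},
}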


def cli() -> None:
"""Defines the command line interface for running lightning trainer's methods."""
commands = {method: partial(run, method) for method in TRAINER_METHOD_NAMES}
fire.Fire(commands)
commands = {method: partial(run, method) for method in CONFIG_STRUCTURE["args"]}
try:
fire.Fire(commands)
except TypeError as e:
if "run() takes 1 positional argument but" in str(e):
raise ValueError(
"Ensure that only one command is run at a time (e.g., 'lighter fit') and that "
"other command line arguments start with '--' (e.g., '--config', '--system#batch_size=1')."
) from e
raise
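Since commands maps each method name to a partial of run, fire exposes one subcommand per method and forwards --flags as keyword arguments. A standalone toy version of the same dispatch pattern (not part of the diff):

import fire
from functools import partial

def run(method, **kwargs):
    print(f"would run {method!r} with {kwargs}")

commands = {name: partial(run, name) for name in ("fit", "validate", "test")}

if __name__ == "__main__":
    # `python cli.py fit --config=experiment.yaml` prints:
    #   would run 'fit' with {'config': 'experiment.yaml'}
    fire.Fire(commands)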


def parse_config(**kwargs) -> ConfigParser:
@@ -29,25 +52,24 @@ def parse_config(**kwargs) -> ConfigParser:
Returns:
An instance of ConfigParser with configuration and overrides merged and parsed.
"""
# Ensure a config file is specified.
config = kwargs.pop("config", None)
if config is None:
raise ValueError("'--config' not specified. Please provide a valid configuration file.")

# Read the config file and update it with overrides.
parser = ConfigParser(CONFIG_STRUCTURE, globals=False)
parser.read_config(config)
# Create a deep copy to ensure the original structure remains unaltered by ConfigParser.
structure = copy.deepcopy(CONFIG_STRUCTURE)
# Initialize the parser with the predefined structure.
parser = ConfigParser(structure, globals=False)
# Update the parser with the configuration file.
parser.update(parser.load_config_files(config))
# Update the parser with the provided cli arguments.
parser.update(kwargs)
return parser
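Because both updates go through ConfigParser.update, CLI overrides address nested keys with MONAI's '#' separator. An illustrative call, assuming a YAML file named experiment.yaml exists:

# Equivalent of `lighter fit --config=experiment.yaml --system#batch_size=8`:
parser = parse_config(config="experiment.yaml", **{"system#batch_size": 8})
print(parser.get("system")["batch_size"])  # 8, overriding whatever the YAML set

(CPython accepts non-identifier string keys when unpacking with **, which is how such flags can reach parse_config as kwargs.)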


def validate_config(parser: ConfigParser) -> None:
"""
Validates the configuration parser against predefined structures and allowed method names.

This function checks if the keys in the top-level of the configuration parser are valid according to the
CONFIG_STRUCTURE. It also verifies that the 'args' section of the configuration only contains keys that
correspond to valid trainer method names as defined in TRAINER_METHOD_NAMES.
Validates the configuration parser against the predefined structure.

Args:
parser (ConfigParser): The configuration parser instance to validate.
@@ -56,20 +78,28 @@ def validate_config(parser: ConfigParser) -> None:
ValueError: If there are invalid keys in the top-level configuration.
ValueError: If there are invalid keys in the 'args' section.
ValueError: If a config section has a value of the wrong type.
"""
# Validate parser keys against structure
root_keys = parser.get().keys()
invalid_root_keys = set(root_keys) - set(CONFIG_STRUCTURE.keys()) - {"_meta_", "_requires_"}
invalid_root_keys = set(parser.get()) - set(CONFIG_STRUCTURE)
if invalid_root_keys:
raise ValueError(f"Invalid top-level config keys: {invalid_root_keys}. Allowed keys: {CONFIG_STRUCTURE.keys()}")
raise ValueError(f"Invalid top-level config keys: {invalid_root_keys}. Allowed keys: {list(CONFIG_STRUCTURE)}.")

# Validate that 'args' contains only valid trainer method names.
args_keys = parser.get("args", {}).keys()
invalid_args_keys = set(args_keys) - set(TRAINER_METHOD_NAMES)
invalid_args_keys = set(parser.get("args")) - set(CONFIG_STRUCTURE["args"])
if invalid_args_keys:
raise ValueError(f"Invalid trainer method in 'args': {invalid_args_keys}. Allowed methods are: {TRAINER_METHOD_NAMES}")


def run(method: str, **kwargs: Any):
raise ValueError(f"Invalid key in 'args': {invalid_args_keys}. Allowed keys: {list(CONFIG_STRUCTURE['args'])}.")

typechecks = {
"project": (str, type(None)),
"vars": dict,
"system": dict,
"trainer": dict,
"args": dict,
**{f"args#{k}": dict for k in CONFIG_STRUCTURE["args"]},
}
for key, dtype in typechecks.items():
if not isinstance(parser.get(key), dtype):
raise ValueError(f"Invalid value for key '{key}'. Expected a {dtype}.")


def run(method: str, **kwargs: Any) -> None:
"""Run the trainer method.

Args:
@@ -82,30 +112,36 @@ def run(method: str, **kwargs: Any):
parser = parse_config(**kwargs)
validate_config(parser)

# Import the project folder as a module, if specified.
# Project. If specified, the given path is imported as a module.
project = parser.get_parsed_content("project")
if project is not None:
import_module_from_path("project", project)

# Get the main components from the parsed config.
# System
system = parser.get_parsed_content("system")
if not isinstance(system, LighterSystem):
raise ValueError("Expected 'system' to be an instance of 'LighterSystem'")

# Trainer
trainer = parser.get_parsed_content("trainer")
trainer_method_args = parser.get_parsed_content(f"args#{method}", default={})
if not isinstance(trainer, Trainer):
raise ValueError("Expected 'trainer' to be an instance of PyTorch Lightning 'Trainer'")

# Checks
if not isinstance(system, LighterSystem):
raise ValueError(f"Expected 'system' to be an instance of LighterSystem, got {system.__class__.__name__}.")
if not hasattr(trainer, method):
raise ValueError(f"{trainer.__class__.__name__} has no method named '{method}'.")
if any("dataloaders" in key or "datamodule" in key for key in trainer_method_args):
raise ValueError("All dataloaders should be defined as part of the LighterSystem, not passed as method arguments.")

# Save the config to checkpoints under "hyper_parameters" and log it if a logger is defined.
config = parser.get()
config.pop("_meta_") # MONAI Bundle adds this automatically, remove it.
system.save_hyperparameters(config)
if trainer.logger is not None:
trainer.logger.log_hyperparams(config)
# Trainer/Tuner method arguments.
method_args = parser.get_parsed_content(f"args#{method}")
if any("dataloaders" in key or "datamodule" in key for key in method_args):
raise ValueError("Datasets are defined within the 'system', not passed in `args`.")

# Run the trainer method.
getattr(trainer, method)(system, **trainer_method_args)
# Save the config to checkpoints under "hyper_parameters". Log it if a logger is defined.
system.save_hyperparameters(parser.get())
if trainer.logger is not None:
trainer.logger.log_hyperparams(parser.get())

# Run the trainer/tuner method.
if hasattr(trainer, method):
getattr(trainer, method)(system, **method_args)
elif hasattr(Tuner, method):
tuner = Tuner(trainer)
getattr(tuner, method)(system, **method_args)
else:
raise ValueError(f"Method '{method}' is not a valid Trainer or Tuner method [{list(CONFIG_STRUCTURE['args'])}].")