diff --git a/lighter/system.py b/lighter/system.py index 94404e3..c1092aa 100644 --- a/lighter/system.py +++ b/lighter/system.py @@ -361,6 +361,20 @@ def setup(self, stage: str) -> None: self.predict_dataloader = partial(self._base_dataloader, mode="predict") self.predict_step = partial(self._base_step, mode="predict") + @property + def learning_rate(self) -> float: + """Get the learning rate of the optimizer. Ensures compatibility with the Tuner's 'lr_find()' method.""" + if len(self.optimizer.param_groups) > 1: + raise ValueError("The learning rate is not available when there are multiple optimizer parameter groups.") + return self.optimizer.param_groups[0]["lr"] + + @learning_rate.setter + def learning_rate(self, value) -> None: + """Set the learning rate of the optimizer. Ensures compatibility with the Tuner's 'lr_find()' method.""" + if len(self.optimizer.param_groups) > 1: + raise ValueError("The learning rate is not available when there are multiple optimizer parameter groups.") + self.optimizer.param_groups[0]["lr"] = value + def _init_placeholders_for_dataloader_and_step_methods(self) -> None: """ Initializes placeholders for dataloader and step methods. diff --git a/lighter/utils/misc.py b/lighter/utils/misc.py index abd1f05..839bac5 100644 --- a/lighter/utils/misc.py +++ b/lighter/utils/misc.py @@ -119,13 +119,23 @@ def apply_fns(data: Any, fns: Union[Callable, List[Callable]]) -> Any: def get_optimizer_stats(optimizer: Optimizer) -> Dict[str, float]: """ - Extract learning rates and momentum values from each parameter group of the optimizer. + Extract learning rates and momentum values from an optimizer into a dictionary. + + This function iterates over the parameter groups of the given optimizer and collects + the learning rate and momentum (or beta values) for each group. The collected values + are stored in a dictionary with keys formatted to indicate the optimizer type and + parameter group index (if multiple groups are present). Args: - optimizer (Optimizer): A PyTorch optimizer. + optimizer (Optimizer): A PyTorch optimizer instance. Returns: - Dictionary with formatted keys and values for learning rates and momentum. + Dict[str, float]: A dictionary containing the learning rates and momentum values + for each parameter group in the optimizer. The keys are formatted as: + - "optimizer/{optimizer_class_name}/lr" for learning rates + - "optimizer/{optimizer_class_name}/momentum" for momentum values + If there are multiple parameter groups, the keys will include the group index, e.g., + "optimizer/{optimizer_class_name}/lr/group1". """ stats_dict = {} for group_idx, group in enumerate(optimizer.param_groups): diff --git a/lighter/utils/runner.py b/lighter/utils/runner.py index 130e28b..6786bd8 100644 --- a/lighter/utils/runner.py +++ b/lighter/utils/runner.py @@ -1,22 +1,45 @@ from typing import Any +import copy from functools import partial import fire from monai.bundle.config_parser import ConfigParser -from pytorch_lightning import seed_everything +from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning.tuner import Tuner from lighter.system import LighterSystem from lighter.utils.dynamic_imports import import_module_from_path -CONFIG_STRUCTURE = {"project": None, "system": {}, "trainer": {}, "args": {}, "vars": {}} -TRAINER_METHOD_NAMES = ["fit", "validate", "test", "predict", "lr_find", "scale_batch_size"] +CONFIG_STRUCTURE = { + "project": None, + "vars": {}, + "args": { + # Keys - names of the methods; values - arguments passed to them. + "fit": {}, + "validate": {}, + "test": {}, + "predict": {}, + "lr_find": {}, + "scale_batch_size": {}, + }, + "system": {}, + "trainer": {}, +} def cli() -> None: """Defines the command line interface for running lightning trainer's methods.""" - commands = {method: partial(run, method) for method in TRAINER_METHOD_NAMES} - fire.Fire(commands) + commands = {method: partial(run, method) for method in CONFIG_STRUCTURE["args"]} + try: + fire.Fire(commands) + except TypeError as e: + if "run() takes 1 positional argument but" in str(e): + raise ValueError( + "Ensure that only one command is run at a time (e.g., 'lighter fit') and that " + "other command line arguments start with '--' (e.g., '--config', '--system#batch_size=1')." + ) from e + raise def parse_config(**kwargs) -> ConfigParser: @@ -29,25 +52,24 @@ def parse_config(**kwargs) -> ConfigParser: Returns: An instance of ConfigParser with configuration and overrides merged and parsed. """ - # Ensure a config file is specified. config = kwargs.pop("config", None) if config is None: raise ValueError("'--config' not specified. Please provide a valid configuration file.") - # Read the config file and update it with overrides. - parser = ConfigParser(CONFIG_STRUCTURE, globals=False) - parser.read_config(config) + # Create a deep copy to ensure the original structure remains unaltered by ConfigParser. + structure = copy.deepcopy(CONFIG_STRUCTURE) + # Initialize the parser with the predefined structure. + parser = ConfigParser(structure, globals=False) + # Update the parser with the configuration file. + parser.update(parser.load_config_files(config)) + # Update the parser with the provided cli arguments. parser.update(kwargs) return parser def validate_config(parser: ConfigParser) -> None: """ - Validates the configuration parser against predefined structures and allowed method names. - - This function checks if the keys in the top-level of the configuration parser are valid according to the - CONFIG_STRUCTURE. It also verifies that the 'args' section of the configuration only contains keys that - correspond to valid trainer method names as defined in TRAINER_METHOD_NAMES. + Validates the configuration parser against predefined structure. Args: parser (ConfigParser): The configuration parser instance to validate. @@ -56,20 +78,28 @@ def validate_config(parser: ConfigParser) -> None: ValueError: If there are invalid keys in the top-level configuration. ValueError: If there are invalid method names specified in the 'args' section. """ - # Validate parser keys against structure - root_keys = parser.get().keys() - invalid_root_keys = set(root_keys) - set(CONFIG_STRUCTURE.keys()) - {"_meta_", "_requires_"} + invalid_root_keys = set(parser.get()) - set(CONFIG_STRUCTURE) if invalid_root_keys: - raise ValueError(f"Invalid top-level config keys: {invalid_root_keys}. Allowed keys: {CONFIG_STRUCTURE.keys()}") + raise ValueError(f"Invalid top-level config keys: {invalid_root_keys}. Allowed keys: {list(CONFIG_STRUCTURE)}.") - # Validate that 'args' contains only valid trainer method names. - args_keys = parser.get("args", {}).keys() - invalid_args_keys = set(args_keys) - set(TRAINER_METHOD_NAMES) + invalid_args_keys = set(parser.get("args")) - set(CONFIG_STRUCTURE["args"]) if invalid_args_keys: - raise ValueError(f"Invalid trainer method in 'args': {invalid_args_keys}. Allowed methods are: {TRAINER_METHOD_NAMES}") - - -def run(method: str, **kwargs: Any): + raise ValueError(f"Invalid key in 'args': {invalid_args_keys}. Allowed keys: {list(CONFIG_STRUCTURE['args'])}.") + + typechecks = { + "project": (str, type(None)), + "vars": dict, + "system": dict, + "trainer": dict, + "args": dict, + **{f"args#{k}": dict for k in CONFIG_STRUCTURE["args"]}, + } + for key, dtype in typechecks.items(): + if not isinstance(parser.get(key), dtype): + raise ValueError(f"Invalid value for key '{key}'. Expected a {dtype}.") + + +def run(method: str, **kwargs: Any) -> None: """Run the trainer method. Args: @@ -82,30 +112,36 @@ def run(method: str, **kwargs: Any): parser = parse_config(**kwargs) validate_config(parser) - # Import the project folder as a module, if specified. + # Project. If specified, the give path is imported as a module. project = parser.get_parsed_content("project") if project is not None: import_module_from_path("project", project) - # Get the main components from the parsed config. + # System system = parser.get_parsed_content("system") + if not isinstance(system, LighterSystem): + raise ValueError("Expected 'system' to be an instance of 'LighterSystem'") + + # Trainer trainer = parser.get_parsed_content("trainer") - trainer_method_args = parser.get_parsed_content(f"args#{method}", default={}) + if not isinstance(trainer, Trainer): + raise ValueError("Expected 'trainer' to be an instance of PyTorch Lightning 'Trainer'") - # Checks - if not isinstance(system, LighterSystem): - raise ValueError(f"Expected 'system' to be an instance of LighterSystem, got {system.__class__.__name__}.") - if not hasattr(trainer, method): - raise ValueError(f"{trainer.__class__.__name__} has no method named '{method}'.") - if any("dataloaders" in key or "datamodule" in key for key in trainer_method_args): - raise ValueError("All dataloaders should be defined as part of the LighterSystem, not passed as method arguments.") - - # Save the config to checkpoints under "hyper_parameters" and log it if a logger is defined. - config = parser.get() - config.pop("_meta_") # MONAI Bundle adds this automatically, remove it. - system.save_hyperparameters(config) - if trainer.logger is not None: - trainer.logger.log_hyperparams(config) + # Trainer/Tuner method arguments. + method_args = parser.get_parsed_content(f"args#{method}") + if any("dataloaders" in key or "datamodule" in key for key in method_args): + raise ValueError("Datasets are defined within the 'system', not passed in `args`.") - # Run the trainer method. - getattr(trainer, method)(system, **trainer_method_args) + # Save the config to checkpoints under "hyper_parameters". Log it if a logger is defined. + system.save_hyperparameters(parser.get()) + if trainer.logger is not None: + trainer.logger.log_hyperparams(parser.get()) + + # Run the trainer/tuner method. + if hasattr(trainer, method): + getattr(trainer, method)(system, **method_args) + elif hasattr(Tuner, method): + tuner = Tuner(trainer) + getattr(tuner, method)(system, **method_args) + else: + raise ValueError(f"Method '{method}' is not a valid Trainer or Tuner method [{list(CONFIG_STRUCTURE['args'])}].")