diff --git a/requirements.dev.txt b/requirements.dev.txt
index 2e6b258..bd93ba2 100644
--- a/requirements.dev.txt
+++ b/requirements.dev.txt
@@ -1,4 +1,4 @@
-black
+black==24.2.0
 coverage
 isort
 pytest
diff --git a/tests/test_train_gan.py b/tests/test_train_gan.py
index 06a3ec4..37c7fc0 100644
--- a/tests/test_train_gan.py
+++ b/tests/test_train_gan.py
@@ -79,8 +79,7 @@ def __init__(self):
         self.generator = Generator(latent_dim=100, img_shape=data_shape)
         self.discriminator = Discriminator(img_shape=data_shape)
 
-    def forward(self, x):
-        ...
+    def forward(self, x): ...
 
     def train_step(self, batch, criterion, optimizer_idx):
         imgs, _ = batch
@@ -174,8 +173,7 @@ def __init__(self):
         self.generator = Generator(latent_dim=100, img_shape=data_shape)
         self.discriminator = Discriminator(img_shape=data_shape)
 
-    def forward(self, x):
-        ...
+    def forward(self, x): ...
 
     def train_step(self, batch, criterion, optimizer_idx):
         imgs, _ = batch
@@ -271,8 +269,7 @@ def __init__(self):
         self.generator = Generator(latent_dim=100, img_shape=data_shape)
         self.discriminator = Discriminator(img_shape=data_shape)
 
-    def forward(self, x):
-        ...
+    def forward(self, x): ...
 
     def optimize(self, batch, trainer):
         imgs, _ = batch
@@ -385,8 +382,7 @@ def __init__(self):
         self.generator = Generator(latent_dim=100, img_shape=data_shape)
         self.discriminator = Discriminator(img_shape=data_shape)
 
-    def forward(self, x):
-        ...
+    def forward(self, x): ...
 
     def optimize(self, batch, trainer):
         imgs, _ = batch
@@ -502,11 +498,9 @@ def __init__(self):
         self.generator = Generator(latent_dim=100, img_shape=data_shape)
         self.discriminator = Discriminator(img_shape=data_shape)
 
-    def train_step():
-        ...
+    def train_step(): ...
 
-    def forward(self, x):
-        ...
+    def forward(self, x): ...
 
     def optimize(self, batch, trainer):
         imgs, _ = batch
diff --git a/trainer/VERSION b/trainer/VERSION
index b82608c..8308b63 100644
--- a/trainer/VERSION
+++ b/trainer/VERSION
@@ -1 +1 @@
-v0.1.0
+v0.1.1
diff --git a/trainer/trainer.py b/trainer/trainer.py
index 948a9dd..229c4c9 100644
--- a/trainer/trainer.py
+++ b/trainer/trainer.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 import gc
 import importlib
 import logging
@@ -11,7 +10,7 @@
 from contextlib import nullcontext
 from dataclasses import dataclass, field
 from inspect import signature
-from typing import Callable, Dict, List, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -63,6 +62,7 @@
 @dataclass
 class TrainerConfig(Coqpit):
     """Config fields tweaking the Trainer for a model.
+
     A ```ModelConfig```, by inheriting ```TrainerConfig``` must be defined for using 👟.
     Inherit this by a new model config and override the fields as needed.
     All the fields can be overridden from comman-line as ```--coqpit.arg_name=value```.
@@ -134,9 +134,7 @@ class TrainerConfig(Coqpit):
     save_all_best: bool = field(
         default=False, metadata={"help": "Save all best checkpoints and keep the older ones. Defaults to False"}
     )
-    save_best_after: int = field(
-        default=0, metadata={"help": "Wait N steps to save best checkpoints. Defaults to 0"}
-    )
+    save_best_after: int = field(default=0, metadata={"help": "Wait N steps to save best checkpoints. Defaults to 0"})
     target_loss: str = field(
         default=None, metadata={"help": "Target loss name to select the best model. Defaults to None"}
     )
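Note (illustration, not part of the patch): the test-file hunks above and the `black==24.2.0` pin in requirements.dev.txt belong together. Black's 2024 stable style collapses stub bodies onto one line, so an unpinned black would flip this formatting back and forth between contributor machines; the `save_best_after` field reflow is the same formatter at work. A minimal sketch of the stub-body change, assuming black >= 24.1:

    # Older black releases keep the ellipsis body on its own line:
    def forward(self, x):
        ...

    # Black 24.x "dummy implementations" style collapses the stub
    # (this redefinition shadows the first form; both are valid Python):
    def forward(self, x): ...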
Defaults to None"} ) @@ -306,8 +304,9 @@ def __init__( # pylint: disable=dangerous-default-value callbacks: Dict[str, Callable] = {}, gpu: int = None, ) -> None: - """Simple yet powerful πŸΈπŸ’¬ TTS trainer for PyTorch. It can train all the available `tts` and `vocoder` models - or easily be customized. + """Simple yet powerful πŸΈπŸ’¬ TTS trainer for PyTorch. + + It can train all the available `tts` and `vocoder` models or easily be customized. Notes: @@ -521,9 +520,8 @@ def __init__( # pylint: disable=dangerous-default-value for criterion in self.criterion: if isinstance(criterion, torch.nn.Module): criterion.cuda() - else: - if isinstance(self.criterion, torch.nn.Module): - self.criterion.cuda() + elif isinstance(self.criterion, torch.nn.Module): + self.criterion.cuda() # setup optimizer self.optimizer = self.get_optimizer(self.model, self.config) @@ -581,21 +579,21 @@ def __init__( # pylint: disable=dangerous-default-value self.save_training_script() @property - def use_apex(self): + def use_apex(self) -> bool: """Return True if using APEX.""" return not self.args.use_accelerate and self._is_apex_available() @property - def use_pt_ddp(self): + def use_pt_ddp(self) -> bool: """Return True if using PyTorch DDP.""" return self.num_gpus > 1 and not self.use_accelerate @property - def use_accelerate(self): + def use_accelerate(self) -> bool: """Return True if using HF Accelerate.""" return self.args.use_accelerate - def setup_accelerate(self): + def setup_accelerate(self) -> None: if self.use_accelerate: self.model, self.optimizer, self.train_loader, self.scheduler, self.accelerator = self.init_accelerate( model=self.model, @@ -657,7 +655,7 @@ def init_accelerate(model, optimizer, training_dataloader, scheduler, grad_accum return model, optimizer, training_dataloader, scheduler, accelerator - def save_training_script(self): + def save_training_script(self) -> None: """Save the training script to tracking dashboard and output path.""" file_path = sys.argv[0] if os.path.isfile(file_path): @@ -682,6 +680,7 @@ def parse_argv(args: Union[Coqpit, List]): @staticmethod def init_loggers(config: "Coqpit", output_path: str, dashboard_logger=None, c_logger=None): """Init console and dashboard loggers. + Use the given logger if passed externally else use config values to pick the right logger. Return a dashboard logger only for the rank 0 process in DDP Define a console logger for each process in DDP @@ -704,7 +703,7 @@ def init_loggers(config: "Coqpit", output_path: str, dashboard_logger=None, c_lo dashboard_logger = logger_factory(config, output_path) return dashboard_logger, c_logger - def setup_small_run(self, small_run: int = None): + def setup_small_run(self, small_run: Optional[int] = None) -> None: """Use a subset of samples for training, evaluation and testing.""" if small_run is not None: logger.info("[!] Small Run, only using %i samples.", small_run) @@ -713,7 +712,9 @@ def setup_small_run(self, small_run: int = None): self.test_samples = None if self.test_samples is None else self.test_samples[:small_run] @staticmethod - def init_training(args: TrainerArgs, coqpit_overrides: Dict, config: Coqpit = None): + def init_training( + args: TrainerArgs, coqpit_overrides: Dict, config: Coqpit = None + ) -> Tuple[Coqpit, Dict[str, str]]: """Initialize training and update model configs from command line arguments. 
@@ -751,7 +752,7 @@ def init_training(args: TrainerArgs, coqpit_overrides: Dict, config: Coqpit = No
         return config, new_fields
 
     @staticmethod
-    def setup_training_environment(args, config, gpu):
+    def setup_training_environment(args, config, gpu) -> Tuple[bool, int]:
         if platform.system() != "Windows":
             # https://github.com/pytorch/pytorch/issues/973
             import resource  # pylint: disable=import-outside-toplevel
@@ -886,7 +887,7 @@ def _get_loader(
         model: nn.Module,
         config: Coqpit,
         assets: Dict,
-        is_eval: str,
+        is_eval: bool,
         samples: List,
         verbose: bool,
         num_gpus: int,
@@ -902,11 +903,10 @@ def _get_loader(
                 num_gpus,
                 self.args.rank,
             )
-        else:
-            if isimplemented(model, "get_data_loader"):
-                loader = model.get_data_loader(
-                    config=config, assets=assets, is_eval=is_eval, samples=samples, verbose=verbose, num_gpus=num_gpus
-                )
+        elif isimplemented(model, "get_data_loader"):
+            loader = model.get_data_loader(
+                config=config, assets=assets, is_eval=is_eval, samples=samples, verbose=verbose, num_gpus=num_gpus
+            )
 
         assert (
             len(loader) > 0
@@ -915,6 +915,7 @@ def _get_loader(
 
     def get_train_dataloader(self, training_assets: Dict, samples: List, verbose: bool) -> DataLoader:
         """Initialize and return a training data loader.
+
         Call ```model.get_train_data_loader``` if it is implemented, else call ```model.get_data_loader``` and
         set ```is_eval=False```.
 
@@ -928,7 +929,7 @@ def get_train_dataloader(self, training_assets: Dict, samples: List, verbose: bo
         """
         if self.num_gpus > 1:
             if isimplemented(self.model.module, "get_train_data_loader"):
-                loader = self.model.module.get_train_data_loader(
+                return self.model.module.get_train_data_loader(
                     self.config,
                     self.training_assets,
                     samples,
@@ -936,13 +937,8 @@ def get_train_dataloader(self, training_assets: Dict, samples: List, verbose: bo
                     self.num_gpus,
                     self.args.rank,
                 )
-                return loader
-        else:
-            if isimplemented(self.model, "get_train_data_loader"):
-                loader = self.model.get_train_data_loader(
-                    self.config, self.training_assets, samples, verbose, self.num_gpus
-                )
-                return loader
+        elif isimplemented(self.model, "get_train_data_loader"):
+            return self.model.get_train_data_loader(self.config, self.training_assets, samples, verbose, self.num_gpus)
 
         return self._get_loader(
             self.model,
@@ -956,6 +952,7 @@ def get_train_dataloader(self, training_assets: Dict, samples: List, verbose: bo
 
     def get_eval_dataloader(self, training_assets: Dict, samples: List, verbose: bool) -> DataLoader:
         """Initialize and return a evaluation data loader.
+
         Call ```model.get_eval_data_loader``` if it is implemented, else call ```model.get_data_loader``` and
         set ```is_eval=True```.
 
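Note (illustration, not part of the patch): the dataloader hunks above and below replace `loader = ...` plus `return loader` with direct returns and fold `else: if ...` into `elif`, dropping one nesting level without changing behavior. The resulting control flow, sketched with hypothetical simplified signatures and a stand-in for the project's `isimplemented` helper:

    def isimplemented(obj, method_name):
        # Simplified stand-in: does obj provide a callable method_name?
        return callable(getattr(obj, method_name, None))

    def get_train_dataloader(trainer, samples):
        # Prefer a model-provided loader; otherwise fall back to the generic path.
        if trainer.num_gpus > 1:
            if isimplemented(trainer.model.module, "get_train_data_loader"):
                return trainer.model.module.get_train_data_loader(trainer.config, samples)
        elif isimplemented(trainer.model, "get_train_data_loader"):
            return trainer.model.get_train_data_loader(trainer.config, samples)
        return trainer._get_loader(trainer.model, trainer.config, samples)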
@@ -969,7 +966,7 @@ def get_eval_dataloader(self, training_assets: Dict, samples: List, verbose: boo
         """
         if self.num_gpus > 1:
             if isimplemented(self.model.module, "get_eval_data_loader"):
-                loader = self.model.module.get_eval_data_loader(
+                return self.model.module.get_eval_data_loader(
                     self.config,
                     self.training_assets,
                     samples,
@@ -977,13 +974,8 @@ def get_eval_dataloader(self, training_assets: Dict, samples: List, verbose: boo
                     self.num_gpus,
                     self.args.rank,
                 )
-                return loader
-        else:
-            if isimplemented(self.model, "get_eval_data_loader"):
-                loader = self.model.get_eval_data_loader(
-                    self.config, self.training_assets, samples, verbose, self.num_gpus
-                )
-                return loader
+        elif isimplemented(self.model, "get_eval_data_loader"):
+            return self.model.get_eval_data_loader(self.config, self.training_assets, samples, verbose, self.num_gpus)
 
         return self._get_loader(
             self.model,
@@ -997,6 +989,7 @@ def get_eval_dataloader(self, training_assets: Dict, samples: List, verbose: boo
 
     def get_test_dataloader(self, training_assets: Dict, samples: List, verbose: bool) -> DataLoader:
         """Initialize and return a evaluation data loader.
+
         Call ```model.get_test_data_loader``` if it is implemented, else call ```model.get_data_loader``` and
         set ```is_eval=True```.
 
@@ -1010,7 +1003,7 @@ def get_test_dataloader(self, training_assets: Dict, samples: List, verbose: boo
         """
         if self.num_gpus > 1:
             if isimplemented(self.model.module, "get_test_data_loader"):
-                loader = self.model.module.get_test_data_loader(
+                return self.model.module.get_test_data_loader(
                     self.config,
                     self.training_assets,
                     samples,
@@ -1018,13 +1011,8 @@ def get_test_dataloader(self, training_assets: Dict, samples: List, verbose: boo
                     self.num_gpus,
                     self.args.rank,
                 )
-                return loader
-        else:
-            if isimplemented(self.model, "get_test_data_loader"):
-                loader = self.model.get_test_data_loader(
-                    self.config, self.training_assets, samples, verbose, self.num_gpus
-                )
-                return loader
+        elif isimplemented(self.model, "get_test_data_loader"):
+            return self.model.get_test_data_loader(self.config, self.training_assets, samples, verbose, self.num_gpus)
 
         return self._get_loader(
             self.model,
@@ -1090,10 +1078,9 @@ def master_params(optimizer: torch.optim.Optimizer):
 
     @staticmethod
     def _model_train_step(
-        batch: Dict, model: nn.Module, criterion: nn.Module, optimizer_idx: int = None
+        batch: Dict, model: nn.Module, criterion: nn.Module, optimizer_idx: Optional[int] = None
    ) -> Tuple[Dict, Dict]:
-        """
-        Perform a trainig forward step. Compute model outputs and losses.
+        """Perform a trainig forward step. Compute model outputs and losses.
 
         Args:
             batch (Dict): [description]
@@ -1130,7 +1117,11 @@ def _get_autocast_args(self, mixed_precision: bool, precision: str):
         return device, dtype
 
     def detach_loss_dict(
-        self, loss_dict: Dict, step_optimizer: bool, optimizer_idx: int = None, grad_norm: float = None
+        self,
+        loss_dict: Dict,
+        step_optimizer: bool,
+        optimizer_idx: Optional[int] = None,
+        grad_norm: Optional[float] = None,
     ):
         # detach losses for logging
         loss_dict_detached = self._detach_loss_dict(loss_dict)
@@ -1191,7 +1182,7 @@ def optimize(
         criterion: nn.Module,
         scheduler: Union[torch.optim.lr_scheduler._LRScheduler, List, Dict],  # pylint: disable=protected-access
         config: Coqpit,
-        optimizer_idx: int = None,
+        optimizer_idx: Optional[int] = None,
         step_optimizer: bool = True,
         num_optimizers: int = 1,
     ) -> Tuple[Dict, Dict, int]:
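Note (illustration, not part of the patch): the docstring edits throughout this file follow a single pattern, pydocstyle's D205/D212 convention of a one-line summary, a blank line, then the body, rather than changing any wording. In miniature:

    def optimize(self, batch):
        """Perform a single optimization step.

        Run the forward pass, compute the losses, and step the optimizer.
        """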
""" - step_start_time = time.time() # forward pass and loss computation @@ -1543,10 +1533,9 @@ def train_epoch(self) -> None: ####################### def _model_eval_step( - self, batch: Dict, model: nn.Module, criterion: nn.Module, optimizer_idx: int = None + self, batch: Dict, model: nn.Module, criterion: nn.Module, optimizer_idx: Optional[int] = None ) -> Tuple[Dict, Dict]: - """ - Perform a evaluation forward pass. Compute model outputs and losses with no gradients. + """Perform a evaluation forward pass. Compute model outputs and losses with no gradients. Args: batch (Dict): IBatch of inputs. @@ -1714,9 +1703,12 @@ def test_run(self) -> None: else: self.model.test_log(test_outputs, self.dashboard_logger, self.training_assets, self.total_steps_done) - def _restore_best_loss(self): - """Restore the best loss from the args.best_path if provided else - from the model (`args.continue_path`) used for resuming the training""" + def _restore_best_loss(self) -> None: + """Restore the best loss. + + Restore from the args.best_path if provided else from the model + (`args.continue_path`) used for resuming the training. + """ if self.args.continue_path and (self.restore_step != 0 or self.args.best_path): logger.info(" > Restoring best loss from %s ...", os.path.basename(self.args.best_path)) ch = load_fsspec(self.args.restore_path, map_location="cpu") @@ -1732,8 +1724,10 @@ def _restore_best_loss(self): logger.info(" > Starting with loaded last best loss %s", self.best_loss) def test(self, model=None, test_samples=None) -> None: - """Run evaluation steps on the test data split. You can either provide the model and the test samples - explicitly or the trainer use values from the initialization. + """Run evaluation steps on the test data split. + + You can either provide the model and the test samples + explicitly or the trainer uses values from the initialization. Args: model (nn.Module, optional): Model to use for testing. If None, use the model given in the initialization. @@ -1742,7 +1736,6 @@ def test(self, model=None, test_samples=None) -> None: test_samples (List[str], optional): List of test samples to use for testing. If None, use the test samples given in the initialization. Defaults to None. """ - logger.info(" > USING TEST SET...") self.keep_avg_eval = KeepAverage() @@ -1794,7 +1787,7 @@ def _fit(self) -> None: self.callbacks.on_epoch_end(self) self.start_with_eval = False - def fit_with_largest_batch_size(self, starting_batch_size=2048) -> None: + def fit_with_largest_batch_size(self, starting_batch_size: int = 2048) -> None: cuda_meminfo() bs = starting_batch_size while True: @@ -1850,15 +1843,15 @@ def fit(self) -> None: self.dashboard_logger.finish() # stop without error signal try: - sys.exit(1) + sys.exit(130) except SystemExit: - os._exit(1) # pylint: disable=protected-access + os._exit(130) # pylint: disable=protected-access except BaseException: # pylint: disable=broad-except remove_experiment_folder(self.output_path) traceback.print_exc() sys.exit(1) - def profile_fit(self, torch_profiler, epochs=None, small_run=None): + def profile_fit(self, torch_profiler, epochs: Optional[int] = None, small_run: Optional[int] = None): """Run training under the torch profiler. 
@@ -1850,15 +1843,15 @@ def fit(self) -> None:
             self.dashboard_logger.finish()
             # stop without error signal
             try:
-                sys.exit(1)
+                sys.exit(130)
             except SystemExit:
-                os._exit(1)  # pylint: disable=protected-access
+                os._exit(130)  # pylint: disable=protected-access
         except BaseException:  # pylint: disable=broad-except
             remove_experiment_folder(self.output_path)
             traceback.print_exc()
             sys.exit(1)
 
-    def profile_fit(self, torch_profiler, epochs=None, small_run=None):
+    def profile_fit(self, torch_profiler, epochs: Optional[int] = None, small_run: Optional[int] = None):
         """Run training under the torch profiler.
 
         Example::
@@ -1881,14 +1874,13 @@ def profile_fit(self, torch_profiler, epochs=None, small_run=None):
         self.dashboard_logger = DummyLogger()
         # train the model for a custom number of epochs
         if epochs:
-            self.config.epocshs = epochs
+            self.config.epochs = epochs
         # use a smaller set of training samples for profiling
         if small_run:
             self.setup_small_run(small_run)
         # run profiler
         self.config.run_eval = False
         self.config.test_delay_epochs = 9999999
-        self.config.epochs = epochs
         # set a callback to progress the profiler
         self.callbacks_on_train_step_end = [  # pylint: disable=attribute-defined-outside-init
             lambda trainer: trainer.torch_profiler.step()
@@ -1905,7 +1897,6 @@
     @rank_zero_only
     def save_best_model(self) -> None:
         """Save the best model. It only saves if the current target loss is smaller then the previous."""
-
         eval_loss = self._pick_target_avg_loss(self.keep_avg_eval)
         train_loss = self._pick_target_avg_loss(self.keep_avg_train)
 
@@ -1945,7 +1936,7 @@ def save_checkpoint(self) -> None:
         )
 
     @rank_zero_only
-    def update_training_dashboard_logger(self, batch=None, outputs=None):
+    def update_training_dashboard_logger(self, batch=None, outputs=None) -> None:
         aliases = [
             f"epoch-{self.epochs_done}",
             f"step-{self.total_steps_done}",
@@ -2058,27 +2049,25 @@ def restore_scheduler(
         scheduler: Union["Scheduler", List, Dict], args: Coqpit, config: Coqpit, restore_epoch: int, restore_step: int
     ) -> Union["Scheduler", List]:
         """Restore scheduler wrt restored model."""
-        if scheduler is not None:  # pylint: disable=too-many-nested-blocks
-            if args.continue_path:
-                if isinstance(scheduler, list):
-                    for s in scheduler:
-                        if s is not None:
-                            if config.scheduler_after_epoch:
-                                s.last_epoch = restore_epoch
-                            else:
-                                s.last_epoch = restore_step
-                elif isinstance(scheduler, dict):
-                    for s in scheduler.values():
-                        if s is not None:
-                            if config.scheduler_after_epoch:
-                                s.last_epoch = restore_epoch
-                            else:
-                                s.last_epoch = restore_step
-                else:
-                    if config.scheduler_after_epoch:
-                        scheduler.last_epoch = restore_epoch
-                    else:
-                        scheduler.last_epoch = restore_step
+        if scheduler is not None and args.continue_path:
+            if isinstance(scheduler, list):
+                for s in scheduler:
+                    if s is not None:
+                        if config.scheduler_after_epoch:
+                            s.last_epoch = restore_epoch
+                        else:
+                            s.last_epoch = restore_step
+            elif isinstance(scheduler, dict):
+                for s in scheduler.values():
+                    if s is not None:
+                        if config.scheduler_after_epoch:
+                            s.last_epoch = restore_epoch
+                        else:
+                            s.last_epoch = restore_step
+            elif config.scheduler_after_epoch:
+                scheduler.last_epoch = restore_epoch
+            else:
+                scheduler.last_epoch = restore_step
         return scheduler
 
     @staticmethod
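Note (illustration, not part of the patch): the restore_scheduler hunk above merges the two outer guards into a single `if scheduler is not None and args.continue_path:` and turns the trailing `else: if ...` into `elif`, removing one indentation level while keeping the same branch logic. A condensed, self-contained sketch of the resulting shape (the conditional expression is a further compression, not what the patch does):

    def restore(scheduler, args, config, restore_epoch, restore_step):
        # Guards combined with `and`; trailing else-if collapsed into elif.
        if scheduler is not None and args.continue_path:
            if isinstance(scheduler, list):
                for s in scheduler:
                    if s is not None:
                        s.last_epoch = restore_epoch if config.scheduler_after_epoch else restore_step
            elif config.scheduler_after_epoch:
                scheduler.last_epoch = restore_epoch
            else:
                scheduler.last_epoch = restore_step
        return scheduler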
""" - criterion = None - criterion = model.get_criterion() - return criterion + return model.get_criterion() #################### # HELPER FUNCTIONS @@ -2148,7 +2135,6 @@ def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict: def _setup_logger_config(self, log_file: str) -> None: """Set up the logger based on the process rank in DDP.""" - logger_new = logging.getLogger("trainer") handler = logging.FileHandler(log_file, mode="a") fmt = logging.Formatter("") diff --git a/trainer/utils/cuda_memory.py b/trainer/utils/cuda_memory.py index 714795c..3eba6d5 100644 --- a/trainer/utils/cuda_memory.py +++ b/trainer/utils/cuda_memory.py @@ -4,6 +4,7 @@ Helper to free Torch cuda memory and determine when a Torch exception might be because of OOM conditions. """ + from __future__ import print_function import gc @@ -82,9 +83,7 @@ def cuda_meminfo(): if not torch.cuda.is_available(): return - print( - "Total:", torch.cuda.memory_allocated() / 2**30, " GB Cached: ", torch.cuda.memory_reserved() / 2**30, "GB" - ) + print("Total:", torch.cuda.memory_allocated() / 2**30, " GB Cached: ", torch.cuda.memory_reserved() / 2**30, "GB") print( "Max Total:", torch.cuda.max_memory_allocated() / 2**30,