Skip to content

Commit

Permalink
add patch to avoid omni-us/jsonargparse#337
Browse files Browse the repository at this point in the history
  • Loading branch information
speediedan committed Aug 6, 2023
1 parent 6488afc commit e9edc88
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 8 deletions.
13 changes: 13 additions & 0 deletions src/fts_examples/stable/cli_experiment_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from lightning.fabric.accelerators.cuda import is_cuda_available
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_13
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.strategies import FSDPStrategy
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning_utilities.core.imports import compare_version
from torch.utils import collect_env
Expand All @@ -26,6 +27,18 @@ def add_arguments_to_parser(self, parser):
parser.link_arguments("data.init_args.task_name", "model.init_args.task_name")


class CLIpatched_FSDPStrategy(FSDPStrategy):
    """Temporary ``FSDPStrategy`` subclass working around omni-us/jsonargparse#337.

    The upstream ``FSDPStrategy`` constructor annotations trip a jsonargparse
    parsing bug (https://github.com/omni-us/jsonargparse/issues/337), so the
    affected parameters are re-declared here with loose ``Optional[Any]``
    annotations that jsonargparse can introspect safely.

    NOTE(review): the previous implementation captured these three arguments and
    silently dropped them, so values set in the example YAML configs (e.g.
    ``cpu_offload: true``, ``activation_checkpointing_policy: ...``) never
    reached ``FSDPStrategy``. They are now forwarded when explicitly provided.
    """

    def __init__(
        self,
        activation_checkpointing_policy: Optional[Any] = None,
        auto_wrap_policy: Optional[Any] = None,
        cpu_offload: Optional[Any] = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        # Forward only explicitly-provided values so that ``FSDPStrategy``'s own
        # defaults still apply whenever an argument is omitted from the config.
        if activation_checkpointing_policy is not None:
            kwargs["activation_checkpointing_policy"] = activation_checkpointing_policy
        if auto_wrap_policy is not None:
            kwargs["auto_wrap_policy"] = auto_wrap_policy
        if cpu_offload is not None:
            kwargs["cpu_offload"] = cpu_offload
        super().__init__(*args, **kwargs)


def instantiate_class(init: Dict[str, Any], args: Optional[Union[Any, Tuple[Any, ...]]] = None) -> Any:
"""Instantiates a class with the given args and init. Accepts class definitions with a "class_path".
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ trainer:
verbose: false
mode: min
strategy:
class_path: lightning.pytorch.strategies.FSDPStrategy
# temporary patch to avoid jsonargparse issue https://github.com/omni-us/jsonargparse/issues/337
class_path: fts_examples.stable.cli_experiment_utils.CLIpatched_FSDPStrategy
# class_path: lightning.pytorch.strategies.FSDPStrategy
init_args:
cpu_offload: true
# TODO: this currently may encounter jsonargparse bug https://github.com/omni-us/jsonargparse/issues/337
activation_checkpointing_policy:
class_path: torch.distributed.fsdp.wrap.ModuleWrapPolicy
init_args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ trainer:
verbose: false
mode: min
strategy:
class_path: lightning.pytorch.strategies.FSDPStrategy
# temporary patch to avoid jsonargparse issue https://github.com/omni-us/jsonargparse/issues/337
class_path: fts_examples.stable.cli_experiment_utils.CLIpatched_FSDPStrategy
# class_path: lightning.pytorch.strategies.FSDPStrategy
init_args:
cpu_offload: false
# TODO: this currently may encounter jsonargparse bug https://github.com/omni-us/jsonargparse/issues/337
activation_checkpointing_policy:
class_path: torch.distributed.fsdp.wrap.ModuleWrapPolicy
init_args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@ trainer:
verbose: false
mode: min
strategy:
class_path: lightning.pytorch.strategies.FSDPStrategy
# temporary patch to avoid jsonargparse issue https://github.com/omni-us/jsonargparse/issues/337
class_path: fts_examples.stable.cli_experiment_utils.CLIpatched_FSDPStrategy
# class_path: lightning.pytorch.strategies.FSDPStrategy
init_args:
cpu_offload: false
# activation_checkpointing_policy: !!set
# ? transformers.models.deberta_v2.modeling_deberta_v2.DebertaV2Layer
# TODO: this currently may encounter jsonargparse bug https://github.com/omni-us/jsonargparse/issues/337
activation_checkpointing_policy:
class_path: torch.distributed.fsdp.wrap.ModuleWrapPolicy
init_args:
Expand Down

0 comments on commit e9edc88

Please sign in to comment.