Make checkpoint loading more informative. Remove incorrect Metric type check. Make TableWriter expect a path instead of dir. #126

Merged · 3 commits · Jun 14, 2024
39 changes: 18 additions & 21 deletions lighter/callbacks/writer/base.py
@@ -21,19 +21,12 @@ class LighterBaseWriter(ABC, Callback):
         2) `self.write()` method to specify the saving strategy for a prediction.
 
     Args:
-        directory (str): Base directory for saving. A new sub-directory with current date and time will be created inside.
+        path (Union[str, Path]): Path for saving. It can be a directory or a specific file.
         writer (Union[str, Callable]): Name of the writer function registered in `self.writers`, or a custom writer function.
     """
 
-    def __init__(self, directory: str, writer: Union[str, Callable]) -> None:
-        """
-        Initialize the LighterBaseWriter.
-
-        Args:
-            directory (str): Base directory for saving. A new sub-directory with current date and time will be created inside.
-            writer (Union[str, Callable]): Name of the writer function registered in `self.writers`, or a custom writer function.
-        """
-        self.directory = Path(directory)
+    def __init__(self, path: Union[str, Path], writer: Union[str, Callable]) -> None:
+        self.path = Path(path)
 
         # Check if the writer is a string and if it exists in the writers dictionary
         if isinstance(writer, str):
@@ -70,30 +63,34 @@ def write(self, tensor: torch.Tensor, id: int) -> None:
 
     def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None:
         """
-        Callback function to set up necessary prerequisites: prediction count and prediction directory.
+        Callback function to set up necessary prerequisites: prediction count and prediction file or directory.
         When executing in a distributed environment, it ensures that:
         1. Each distributed node initializes a prediction count based on its rank.
-        2. All distributed nodes write predictions to the same directory.
-        3. The directory is accessible to all nodes, i.e., all nodes share the same storage.
+        2. All distributed nodes write predictions to the same path.
+        3. The path is accessible to all nodes, i.e., all nodes share the same storage.
         """
         if stage != "predict":
             return
 
         # Initialize the prediction count with the rank of the current process
         self._pred_counter = torch.distributed.get_rank() if trainer.world_size > 1 else 0
 
-        # Ensure all distributed nodes write to the same directory
-        self.directory = trainer.strategy.broadcast(self.directory, src=0)
-        # Warn if the directory already exists
-        if self.directory.exists():
-            logger.warning(f"{self.directory} already exists, existing predictions will be overwritten.")
+        # Ensure all distributed nodes write to the same path
+        self.path = trainer.strategy.broadcast(self.path, src=0)
+        directory = self.path.parent if self.path.suffix else self.path
+
+        # Warn if the path already exists
+        if self.path.exists():
+            logger.warning(f"{self.path} already exists, existing predictions will be overwritten.")
 
         if trainer.is_global_zero:
-            self.directory.mkdir(parents=True, exist_ok=True)
+            directory.mkdir(parents=True, exist_ok=True)
 
         # Wait for rank 0 to create the directory
         trainer.strategy.barrier()
 
-        # Ensure all distributed nodes have access to the directory
-        if not self.directory.exists():
+        # Ensure all distributed nodes have access to the path
+        if not directory.exists():
             raise RuntimeError(
                 f"Rank {trainer.global_rank} does not share storage with rank 0. Ensure nodes have common storage access."
             )
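For readers of this diff: since `path` may now be either a directory or a file, the setup logic derives the directory to create from whether the path has a suffix. A minimal standalone sketch of that rule (the `resolve_directory` helper below is illustrative, not part of this PR):

```python
from pathlib import Path
from typing import Union

def resolve_directory(path: Union[str, Path]) -> Path:
    # Same rule as `directory = self.path.parent if self.path.suffix else self.path`:
    # a path with a file suffix is treated as a file, so its parent directory is created;
    # a suffix-less path is treated as the output directory itself.
    path = Path(path)
    return path.parent if path.suffix else path

assert resolve_directory("outputs/predictions.csv") == Path("outputs")
assert resolve_directory("outputs/run_1") == Path("outputs/run_1")
```

One consequence of this rule: a directory name that contains a dot (e.g. `outputs/v1.2`) has a non-empty suffix and would be treated as a file.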
10 changes: 4 additions & 6 deletions lighter/callbacks/writer/table.py
@@ -16,14 +16,14 @@ class LighterTableWriter(LighterBaseWriter):
     Writer for saving predictions in a table format.
 
     Args:
-        directory (Path): Directory where the CSV will be saved.
+        path (Path): CSV filepath.
         writer (Union[str, Callable]): Name of the writer function registered in `self.writers` or a custom writer function.
             Available writers: "tensor". A custom writer function must take a single argument: `tensor`, and return the record
             to be saved in the CSV file. The tensor will be a single tensor without the batch dimension.
     """
 
-    def __init__(self, directory: Union[str, Path], writer: Union[str, Callable]) -> None:
-        super().__init__(directory, writer)
+    def __init__(self, path: Union[str, Path], writer: Union[str, Callable]) -> None:
+        super().__init__(path, writer)
         self.csv_records = {}
 
     @property
@@ -52,8 +52,6 @@ def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> None:
         If training was done in a distributed setting, it gathers predictions from all processes
         and then saves them from the rank 0 process.
         """
-        csv_path = self.directory / "predictions.csv"
-
         # Sort the records by ID and convert the dictionary to a list
         self.csv_records = [self.csv_records[id] for id in sorted(self.csv_records)]
 
@@ -69,4 +67,4 @@ def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> None:
 
         # Save the records to a CSV file
         if trainer.is_global_zero:
-            pd.DataFrame(self.csv_records).to_csv(csv_path)
+            pd.DataFrame(self.csv_records).to_csv(self.path)
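To make the new contract concrete, a hedged usage sketch; the output paths and the custom writer below are made up, only the `path`/`writer` parameters and the built-in `"tensor"` writer come from the docstring above:

```python
from lighter.callbacks.writer.table import LighterTableWriter

# `path` is now the CSV file itself, not a directory in which a hard-coded
# "predictions.csv" used to be created.
writer = LighterTableWriter(path="outputs/predictions.csv", writer="tensor")

# A custom writer receives a single un-batched tensor and returns the record
# to store in the CSV (per the docstring above).
mean_writer = LighterTableWriter(
    path="outputs/means.csv",
    writer=lambda tensor: tensor.mean().item(),
)
```

Either instance is then registered as a Trainer callback, as with any other `Callback`.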
2 changes: 0 additions & 2 deletions lighter/system.py
@@ -254,8 +254,6 @@ def _log_stats(
         # Metrics
         if metrics is not None:
             for name, metric in metrics.items():
-                if not isinstance(metric, Metric):
-                    raise TypeError(f"Expected type for metric is 'Metric', got '{type(metric).__name__}' instead.")
                 on_step_log(f"{mode}/metrics/{name}/step", metric)
                 on_epoch_log(f"{mode}/metrics/{name}/epoch", metric)
         # Optimizer's lr, momentum, beta. Logged in train mode and once per epoch.
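For context on the removed guard: the surviving `on_step_log`/`on_epoch_log` calls pass the metric object itself to Lightning's logging, which supports `torchmetrics` objects natively. A hedged, self-contained sketch of that pattern (illustrative module, not lighter code):

```python
import torch
from pytorch_lightning import LightningModule
from torchmetrics.classification import BinaryAccuracy

class SketchModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.metric = BinaryAccuracy()

    def training_step(self, batch, batch_idx):
        preds, target = batch
        loss = torch.nn.functional.binary_cross_entropy(preds, target.float())
        self.metric(preds, target)
        # Logging the Metric object (not a computed value) lets Lightning produce
        # the step value and the epoch aggregate, mirroring the two calls above.
        self.log("train/metrics/accuracy/step", self.metric, on_step=True, on_epoch=False)
        self.log("train/metrics/accuracy/epoch", self.metric, on_step=False, on_epoch=True)
        return loss
```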
11 changes: 7 additions & 4 deletions lighter/utils/model.py
@@ -98,11 +98,14 @@ def adjust_prefix_and_load_state_dict(
     # Add the model_prefix before the current key name if there's no specific ckpt_prefix
     ckpt = {f"{model_prefix}{key}": value for key, value in ckpt.items() if ckpt_prefix in key}
 
-    # Check if there is no overlap between the checkpoint's and model's state_dict.
-    if not set(ckpt.keys()) & set(model.state_dict().keys()):
+    # Check if the checkpoint's and model's state_dicts have no overlap.
+    model_keys = list(model.state_dict().keys())
+    ckpt_keys = list(ckpt.keys())
+    if not set(ckpt_keys) & set(model_keys):
         raise ValueError(
-            "There is no overlap between checkpoint's and model's state_dict. Check their "
-            "`state_dict` keys and adjust accordingly using `ckpt_prefix` and `model_prefix`."
+            "There is no overlap between checkpoint's and model's state_dict."
+            f"\nModel keys: '{model_keys[0]}', ..., '{model_keys[-1]}', "
+            f"\nCheckpoint keys: '{ckpt_keys[0]}', ..., '{ckpt_keys[-1]}'"
         )
 
     # Remove the layers that are not to be loaded.
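A hedged illustration of the failure mode the richer message targets; the key names below are made up, and the exact remapping semantics of `ckpt_prefix`/`model_prefix` are defined by the function above:

```python
# A checkpoint saved from a wrapped model carries a "model." prefix that the bare model lacks.
ckpt_keys = ["model.backbone.conv1.weight", "model.backbone.conv1.bias"]
model_keys = ["backbone.conv1.weight", "backbone.conv1.bias"]

# No overlap, so the ValueError above fires. Its message now ends with something like:
#   Model keys: 'backbone.conv1.weight', ..., 'backbone.conv1.bias',
#   Checkpoint keys: 'model.backbone.conv1.weight', ..., 'model.backbone.conv1.bias'
# making the stray "model." prefix visible at a glance and pointing to the
# ckpt_prefix/model_prefix arguments as the fix.
assert not set(ckpt_keys) & set(model_keys)
```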