Skip to content

Commit

Permalink
Fix mistake in schema
Browse files Browse the repository at this point in the history
  • Loading branch information
ibro45 committed Aug 19, 2024
1 parent dafd225 commit 01d39ca
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions lighter/utils/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class ArgsConfigSchema(BaseModel):
scale_batch_size: Dict[str, Any] = {}

@model_validator(mode="after")
def check_prohibited_args(self): # pylint: disable=no-self-argument
def check_prohibited_args(self):
prohibited_keys = ["model", "train_loaders", "validation_loaders", "dataloaders", "datamodule"]
for field in self.model_fields:
found_keys = [key for key in prohibited_keys if key in getattr(self, field)]
Expand Down Expand Up @@ -71,11 +71,10 @@ class MetricsSchema(BaseModel):
test: Optional[Union[Any, List[Any], Dict[str, Any]]] = None

@model_validator(mode="after")
def setup_metrics(self): # pylint: disable=no-self-argument
def setup_metrics(self):
for field in self.model_fields:
mode_metrics = getattr(self, field)
if field is not None:
setattr(self, field, MetricCollection(mode_metrics))
if getattr(self, field) is not None:
setattr(self, field, MetricCollection(getattr(self, field)))
return self


Expand Down

0 comments on commit 01d39ca

Please sign in to comment.