add DatasetList dataclass to extract only needed fields from metastore (
mattseddon authored Nov 27, 2024
1 parent f759eff commit 3bd22ad
Showing 5 changed files with 237 additions and 31 deletions.
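In outline, the commit swaps the heavyweight DatasetRecord / DatasetVersion objects used for dataset listings for slimmer DatasetListRecord / DatasetListVersion dataclasses, so listing code only parses the handful of columns it actually needs. A minimal, hypothetical sketch of the shape of such a class (illustrative fields only, not the project's real ones):

from dataclasses import dataclass
from typing import Optional

@dataclass
class ListVersionSketch:
    # only the columns a listing displays; no schema, preview or sources
    id: int
    version: int
    num_objects: Optional[int]
    size: Optional[int]

    @classmethod
    def parse(cls, id, version, num_objects, size):
        # positional, so it can be built straight from a database row
        return cls(id, version, num_objects, size)

row = (1, 2, 1000, 4096)  # row as returned by the metastore query
print(ListVersionSketch.parse(*row))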
7 changes: 4 additions & 3 deletions src/datachain/catalog/catalog.py
@@ -38,6 +38,7 @@
DATASET_PREFIX,
QUERY_DATASET_PREFIX,
DatasetDependency,
DatasetListRecord,
DatasetRecord,
DatasetStats,
DatasetStatus,
@@ -72,7 +73,7 @@
AbstractMetastore,
AbstractWarehouse,
)
from datachain.dataset import DatasetVersion
from datachain.dataset import DatasetListVersion
from datachain.job import Job
from datachain.lib.file import File
from datachain.listing import Listing
@@ -1135,7 +1136,7 @@ def get_dataset_dependencies(

return direct_dependencies

def ls_datasets(self, include_listing: bool = False) -> Iterator[DatasetRecord]:
def ls_datasets(self, include_listing: bool = False) -> Iterator[DatasetListRecord]:
datasets = self.metastore.list_datasets()
for d in datasets:
if not d.is_bucket_listing or include_listing:
@@ -1144,7 +1145,7 @@ def ls_datasets(self, include_listing: bool = False) -> Iterator[DatasetRecord]:
def list_datasets_versions(
self,
include_listing: bool = False,
) -> Iterator[tuple[DatasetRecord, "DatasetVersion", Optional["Job"]]]:
) -> Iterator[tuple[DatasetListRecord, "DatasetListVersion", Optional["Job"]]]:
"""Iterate over all dataset versions with related jobs."""
datasets = list(self.ls_datasets(include_listing=include_listing))

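At the catalog level the listing API keeps its shape; only the yielded types narrow. A hedged usage sketch (the catalog instance is assumed to come from elsewhere and is not part of this diff):

def print_datasets(catalog) -> None:
    # `catalog` is a datachain Catalog instance obtained elsewhere.
    # ls_datasets() now yields lightweight DatasetListRecord objects.
    for dataset in catalog.ls_datasets(include_listing=False):
        print(dataset.name, [v.version for v in dataset.versions])

    # list_datasets_versions() yields (record, version, job) tuples;
    # job may be None when no job is associated with the version.
    for dataset, version, job in catalog.list_datasets_versions():
        print(dataset.name, version.version, "job attached" if job else "no job")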
74 changes: 63 additions & 11 deletions src/datachain/data_storage/metastore.py
@@ -27,6 +27,8 @@
from datachain.data_storage.serializer import Serializable
from datachain.dataset import (
DatasetDependency,
DatasetListRecord,
DatasetListVersion,
DatasetRecord,
DatasetStatus,
DatasetVersion,
@@ -59,6 +61,8 @@ class AbstractMetastore(ABC, Serializable):

schema: "schema.Schema"
dataset_class: type[DatasetRecord] = DatasetRecord
dataset_list_class: type[DatasetListRecord] = DatasetListRecord
dataset_list_version_class: type[DatasetListVersion] = DatasetListVersion
dependency_class: type[DatasetDependency] = DatasetDependency
job_class: type[Job] = Job

@@ -166,11 +170,11 @@ def remove_dataset_version(
"""

@abstractmethod
def list_datasets(self) -> Iterator[DatasetRecord]:
def list_datasets(self) -> Iterator[DatasetListRecord]:
"""Lists all datasets."""

@abstractmethod
def list_datasets_by_prefix(self, prefix: str) -> Iterator["DatasetRecord"]:
def list_datasets_by_prefix(self, prefix: str) -> Iterator["DatasetListRecord"]:
"""Lists all datasets which names start with prefix."""

@abstractmethod
@@ -348,6 +352,14 @@ def _dataset_fields(self) -> list[str]:
if c.name # type: ignore [attr-defined]
]

@cached_property
def _dataset_list_fields(self) -> list[str]:
return [
c.name # type: ignore [attr-defined]
for c in self._datasets_columns()
if c.name in self.dataset_list_class.__dataclass_fields__ # type: ignore [attr-defined]
]

@classmethod
def _datasets_versions_columns(cls) -> list["SchemaItem"]:
"""Datasets versions table columns."""
@@ -390,6 +402,15 @@ def _dataset_version_fields(self) -> list[str]:
if c.name # type: ignore [attr-defined]
]

@cached_property
def _dataset_list_version_fields(self) -> list[str]:
return [
c.name # type: ignore [attr-defined]
for c in self._datasets_versions_columns()
if c.name # type: ignore [attr-defined]
in self.dataset_list_version_class.__dataclass_fields__
]

@classmethod
def _datasets_dependencies_columns(cls) -> list["SchemaItem"]:
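_dataset_list_fields and _dataset_list_version_fields both apply the same trick: keep only the table columns whose names appear in the trimmed dataclass's __dataclass_fields__. A standalone illustration with made-up column names (not the project's real schema):

from dataclasses import dataclass
from datetime import datetime

@dataclass
class ListRecordStub:  # stand-in for DatasetListRecord's declared fields
    id: int
    name: str
    created_at: datetime

# stand-in for the names of the columns returned by _datasets_columns()
table_columns = ["id", "name", "schema", "feature_schema", "created_at"]

list_fields = [c for c in table_columns if c in ListRecordStub.__dataclass_fields__]
print(list_fields)  # ['id', 'name', 'created_at'] -- the wide columns are skipped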
"""Datasets dependencies table columns."""
Expand Down Expand Up @@ -671,7 +692,25 @@ def _parse_datasets(self, rows) -> Iterator["DatasetRecord"]:
if dataset:
yield dataset

def _base_dataset_query(self):
def _parse_list_dataset(self, rows) -> Optional[DatasetListRecord]:
versions = [self.dataset_list_class.parse(*r) for r in rows]
if not versions:
return None
return reduce(lambda ds, version: ds.merge_versions(version), versions)

def _parse_dataset_list(self, rows) -> Iterator["DatasetListRecord"]:
# grouping rows by dataset id
for _, g in groupby(rows, lambda r: r[0]):
dataset = self._parse_list_dataset(list(g))
if dataset:
yield dataset
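
_parse_dataset_list expects the rows to arrive ordered by dataset id: itertools.groupby batches consecutive rows that share the same first column, and functools.reduce folds each batch into a single record via merge_versions. A simplified, self-contained sketch of that fold using stand-in classes:

from dataclasses import dataclass
from functools import reduce
from itertools import groupby

@dataclass
class RecordSketch:
    id: int
    name: str
    versions: list

    def merge_versions(self, other: "RecordSketch") -> "RecordSketch":
        # collapse duplicates and keep the versions sorted
        self.versions = sorted(set(self.versions + other.versions))
        return self

def parse(dataset_id, name, version):
    # one record per row, carrying that row's single version
    return RecordSketch(dataset_id, name, [version])

# rows must already be ordered by dataset id: groupby only batches
# consecutive rows that share a key
rows = [(1, "cats", 1), (1, "cats", 2), (2, "dogs", 1)]

for _, group in groupby(rows, key=lambda r: r[0]):
    record = reduce(lambda ds, other: ds.merge_versions(other), [parse(*r) for r in group])
    print(record)
# RecordSketch(id=1, name='cats', versions=[1, 2])
# RecordSketch(id=2, name='dogs', versions=[1])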

def _get_dataset_query(
self,
dataset_fields: list[str],
dataset_version_fields: list[str],
isouter: bool = True,
):
if not (
self.db.has_table(self._datasets.name)
and self.db.has_table(self._datasets_versions.name)
@@ -680,23 +719,36 @@ def _base_dataset_query(self):

d = self._datasets
dv = self._datasets_versions

query = self._datasets_select(
*(getattr(d.c, f) for f in self._dataset_fields),
*(getattr(dv.c, f) for f in self._dataset_version_fields),
*(getattr(d.c, f) for f in dataset_fields),
*(getattr(dv.c, f) for f in dataset_version_fields),
)
j = d.join(dv, d.c.id == dv.c.dataset_id, isouter=True)
j = d.join(dv, d.c.id == dv.c.dataset_id, isouter=isouter)
return query.select_from(j)

def list_datasets(self) -> Iterator["DatasetRecord"]:
def _base_dataset_query(self):
return self._get_dataset_query(
self._dataset_fields, self._dataset_version_fields
)

def _base_list_datasets_query(self):
return self._get_dataset_query(
self._dataset_list_fields, self._dataset_list_version_fields, isouter=False
)

def list_datasets(self) -> Iterator["DatasetListRecord"]:
"""Lists all datasets."""
yield from self._parse_datasets(self.db.execute(self._base_dataset_query()))
yield from self._parse_dataset_list(
self.db.execute(self._base_list_datasets_query())
)

def list_datasets_by_prefix(
self, prefix: str, conn=None
) -> Iterator["DatasetRecord"]:
query = self._base_dataset_query()
) -> Iterator["DatasetListRecord"]:
query = self._base_list_datasets_query()
query = query.where(self._datasets.c.name.startswith(prefix))
yield from self._parse_datasets(self.db.execute(query))
yield from self._parse_dataset_list(self.db.execute(query))

def get_dataset(self, name: str, conn=None) -> DatasetRecord:
"""Gets a single dataset by name"""
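_get_dataset_query now takes the column lists and the join type as parameters; the listing path passes isouter=False, presumably because a dataset only matters for listing once it has at least one version, while the full query keeps the LEFT OUTER join so version-less datasets still come back. A minimal SQLAlchemy Core sketch of that difference, using illustrative in-memory tables rather than the project's schema:

import sqlalchemy as sa

metadata = sa.MetaData()
datasets = sa.Table(
    "datasets", metadata,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("name", sa.Text),
)
versions = sa.Table(
    "datasets_versions", metadata,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("dataset_id", sa.Integer, sa.ForeignKey("datasets.id")),
    sa.Column("version", sa.Integer),
)

def dataset_query(dataset_fields, version_fields, isouter=True):
    # mirrors the shape of _get_dataset_query: pick columns from both tables,
    # then join datasets to their versions (LEFT OUTER or INNER)
    query = sa.select(
        *(datasets.c[f] for f in dataset_fields),
        *(versions.c[f] for f in version_fields),
    )
    j = datasets.join(versions, datasets.c.id == versions.c.dataset_id, isouter=isouter)
    return query.select_from(j)

engine = sa.create_engine("sqlite://")
metadata.create_all(engine)
with engine.begin() as conn:
    conn.execute(datasets.insert(), [{"id": 1, "name": "cats"}, {"id": 2, "name": "empty"}])
    conn.execute(versions.insert(), [{"id": 1, "dataset_id": 1, "version": 1}])
    print(conn.execute(dataset_query(["id", "name"], ["version"])).fetchall())
    # outer join keeps the version-less dataset: [(1, 'cats', 1), (2, 'empty', None)]
    print(conn.execute(dataset_query(["id", "name"], ["version"], isouter=False)).fetchall())
    # inner join drops it: [(1, 'cats', 1)]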
156 changes: 142 additions & 14 deletions src/datachain/dataset.py
@@ -15,7 +15,9 @@
from datachain.sql.types import NAME_TYPES_MAPPING, SQLType

T = TypeVar("T", bound="DatasetRecord")
LT = TypeVar("LT", bound="DatasetListRecord")
V = TypeVar("V", bound="DatasetVersion")
LV = TypeVar("LV", bound="DatasetListVersion")
DD = TypeVar("DD", bound="DatasetDependency")

DATASET_PREFIX = "ds://"
@@ -264,6 +266,59 @@ def from_dict(cls, d: dict[str, Any]) -> "DatasetVersion":
return cls(**kwargs)


@dataclass
class DatasetListVersion:
id: int
uuid: str
dataset_id: int
version: int
status: int
created_at: datetime
finished_at: Optional[datetime]
error_message: str
error_stack: str
num_objects: Optional[int]
size: Optional[int]
query_script: str = ""
job_id: Optional[str] = None

@classmethod
def parse(
cls: type[LV],
id: int,
uuid: str,
dataset_id: int,
version: int,
status: int,
created_at: datetime,
finished_at: Optional[datetime],
error_message: str,
error_stack: str,
num_objects: Optional[int],
size: Optional[int],
query_script: str = "",
job_id: Optional[str] = None,
):
return cls(
id,
uuid,
dataset_id,
version,
status,
created_at,
finished_at,
error_message,
error_stack,
num_objects,
size,
query_script,
job_id,
)

def __hash__(self):
return hash(f"{self.dataset_id}_{self.version}")


@dataclass
class DatasetRecord:
id: int
@@ -447,20 +502,6 @@ def uri(self, version: int) -> str:
identifier = self.identifier(version)
return f"{DATASET_PREFIX}{identifier}"

@property
def is_bucket_listing(self) -> bool:
"""
For bucket listing we implicitly create underlying dataset to hold data. This
method is checking if this is one of those datasets.
"""
from datachain.client import Client

# TODO refactor and maybe remove method in
# https://github.com/iterative/datachain/issues/318
return Client.is_data_source_uri(self.name) or self.name.startswith(
LISTING_PREFIX
)

@property
def versions_values(self) -> list[int]:
"""
@@ -499,5 +540,92 @@ def from_dict(cls, d: dict[str, Any]) -> "DatasetRecord":
return cls(**kwargs, versions=versions)


@dataclass
class DatasetListRecord:
id: int
name: str
description: Optional[str]
labels: list[str]
versions: list[DatasetListVersion]
created_at: Optional[datetime] = None

@classmethod
def parse( # noqa: PLR0913
cls: type[LT],
id: int,
name: str,
description: Optional[str],
labels: str,
created_at: datetime,
version_id: int,
version_uuid: str,
version_dataset_id: int,
version: int,
version_status: int,
version_created_at: datetime,
version_finished_at: Optional[datetime],
version_error_message: str,
version_error_stack: str,
version_num_objects: Optional[int],
version_size: Optional[int],
version_query_script: Optional[str],
version_job_id: Optional[str] = None,
) -> "DatasetListRecord":
labels_lst: list[str] = json.loads(labels) if labels else []

dataset_version = DatasetListVersion.parse(
version_id,
version_uuid,
version_dataset_id,
version,
version_status,
version_created_at,
version_finished_at,
version_error_message,
version_error_stack,
version_num_objects,
version_size,
version_query_script, # type: ignore[arg-type]
version_job_id,
)

return cls(
id,
name,
description,
labels_lst,
[dataset_version],
created_at,
)

def merge_versions(self, other: "DatasetListRecord") -> "DatasetListRecord":
"""Merge versions from another dataset"""
if other.id != self.id:
raise RuntimeError("Cannot merge versions of datasets with different ids")
if not other.versions:
# nothing to merge
return self
if not self.versions:
self.versions = []

self.versions = list(set(self.versions + other.versions))
self.versions.sort(key=lambda v: v.version)
return self
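
merge_versions collapses duplicate versions with set() and then re-sorts by version number. A tiny illustration of that dedup-and-sort step, with plain tuples standing in for DatasetListVersion objects:

existing = [(1, 1), (1, 2)]  # (dataset_id, version) pairs already on the record
incoming = [(1, 2), (1, 3)]  # versions carried by another row's record

merged = list(set(existing + incoming))  # duplicates collapse
merged.sort(key=lambda v: v[1])          # sort by version, as merge_versions does
print(merged)  # [(1, 1), (1, 2), (1, 3)]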

@property
def is_bucket_listing(self) -> bool:
"""
For bucket listing we implicitly create underlying dataset to hold data. This
method is checking if this is one of those datasets.
"""
from datachain.client import Client

# TODO refactor and maybe remove method in
# https://github.com/iterative/datachain/issues/318
return Client.is_data_source_uri(self.name) or self.name.startswith(
LISTING_PREFIX
)


class RowDict(dict):
pass
10 changes: 7 additions & 3 deletions src/datachain/lib/dataset_info.py
@@ -5,7 +5,11 @@

from pydantic import Field, field_validator

from datachain.dataset import DatasetRecord, DatasetStatus, DatasetVersion
from datachain.dataset import (
DatasetListRecord,
DatasetListVersion,
DatasetStatus,
)
from datachain.job import Job
from datachain.lib.data_model import DataModel
from datachain.utils import TIME_ZERO
@@ -57,8 +61,8 @@ def validate_metrics(cls, v):
@classmethod
def from_models(
cls,
dataset: DatasetRecord,
version: DatasetVersion,
dataset: DatasetListRecord,
version: DatasetListVersion,
job: Optional[Job],
) -> "Self":
return cls(
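from_models is duck-typed: it only reads attributes that the list-flavoured record and version classes also expose, which is presumably why the narrower types can be dropped in here. A hedged sketch of the pattern (field names are illustrative; DatasetInfo's real field list is not shown in this diff):

from dataclasses import dataclass
from types import SimpleNamespace
from typing import Optional

@dataclass
class InfoSketch:  # stand-in for DatasetInfo
    name: str
    version: int
    num_objects: Optional[int]

    @classmethod
    def from_models(cls, dataset, version) -> "InfoSketch":
        # works with any record/version objects exposing these attributes,
        # full or list-flavoured alike
        return cls(name=dataset.name, version=version.version, num_objects=version.num_objects)

ds = SimpleNamespace(name="cats")
ver = SimpleNamespace(version=2, num_objects=1000)
print(InfoSketch.from_models(ds, ver))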
