Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement __iter__ and __aiter__ in DocumentCollection #231

Merged
merged 2 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/fmu/sumo/explorer/explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def get_case_by_uuid(self, uuid: str) -> Case:
Case: case object
"""
metadata = self._utils.get_object(uuid, _CASE_FIELDS)
return Case(self._sumo, metadata)
return Case(self._sumo, metadata, self._pit)

async def get_case_by_uuid_async(self, uuid: str) -> Case:
"""Get case object by uuid
Expand All @@ -126,7 +126,7 @@ async def get_case_by_uuid_async(self, uuid: str) -> Case:
Case: case object
"""
metadata = await self._utils.get_object_async(uuid, _CASE_FIELDS)
return Case(self._sumo, metadata)
return Case(self._sumo, metadata, self._pit)

def get_surface_by_uuid(self, uuid: str) -> Surface:
"""Get surface object by uuid
Expand Down
71 changes: 50 additions & 21 deletions src/fmu/sumo/explorer/objects/_document_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,40 @@ def __init__(
self._type = doc_type
self._sumo = sumo
self._query = self._init_query(doc_type, query)
self._pit = pit

self._pit = pit
self._new_pit_id = None
self._after = None
self._curr_index = 0
self._len = None
self._items = []
self._field_values = {}
self._query = self._init_query(doc_type, query)
self._select = select

def __iter__(self):
    """Begin a synchronous pass over the collection.

    Rewinds the cursor to the first position and hands back the
    collection itself, which acts as its own iterator.

    Returns:
        The collection instance
    """
    self._curr_index = 0
    return self

def __next__(self):
    """Return the document at the cursor and move the cursor forward.

    Returns:
        The item found at the current cursor position

    Raises:
        StopIteration: once the cursor has reached the collection size
    """
    # Guard clause: nothing left to hand out.
    if self._curr_index >= self.__len__():
        raise StopIteration

    item = self.__getitem__(self._curr_index)
    self._curr_index += 1
    return item

def __aiter__(self):
    """Begin an asynchronous pass over the collection.

    Rewinds the cursor to the first position and hands back the
    collection itself, which acts as its own async iterator.

    Returns:
        The collection instance
    """
    self._curr_index = 0
    return self

async def __anext__(self):
    """Return the document at the cursor and move the cursor forward (async).

    Returns:
        The item found at the current cursor position

    Raises:
        StopAsyncIteration: once the cursor has reached the collection size
    """
    # Guard clause: nothing left to hand out.
    if self._curr_index >= await self.length_async():
        raise StopAsyncIteration

    item = await self.getitem_async(self._curr_index)
    self._curr_index += 1
    return item

def __len__(self) -> int:
"""Get size of document collection

Expand Down Expand Up @@ -61,17 +85,14 @@ def __getitem__(self, index: int) -> Dict:
Returns:
A document at a given index
"""
if index >= self.__len__():
if index > self.__len__():
raise IndexError

if len(self._items) <= index:
while len(self._items) <= index:
prev_len = len(self._items)
self._next_batch()
curr_len = len(self._items)
while len(self._items) <= index:
hits_size = self._next_batch()

if prev_len == curr_len:
raise IndexError
if hits_size == 0:
raise IndexError

return self._items[index]

Expand All @@ -84,17 +105,14 @@ async def getitem_async(self, index: int) -> Dict:
Returns:
A document at a given index
"""
if index >= await self.length_async():
if index > await self.length_async():
raise IndexError

if len(self._items) <= index:
while len(self._items) <= index:
prev_len = len(self._items)
await self._next_batch_async()
curr_len = len(self._items)
while len(self._items) <= index:
hits_size = await self._next_batch_async()

if prev_len == curr_len:
raise IndexError
if hits_size == 0:
raise IndexError

return self._items[index]

Expand Down Expand Up @@ -162,18 +180,23 @@ def _next_batch(self) -> List[Dict]:
query["search_after"] = self._after

if self._pit is not None:
query["pit"] = self._pit.get_pit_object()
query["pit"] = self._pit.get_pit_object(self._new_pit_id)

res = self._sumo.post("/search", json=query).json()
hits = res["hits"]

if self._pit is not None:
self._new_pit_id = res["pit_id"]

if self._len is None:
self._len = hits["total"]["value"]

if len(hits["hits"]) > 0:
self._after = hits["hits"][-1]["sort"]
self._items.extend(hits["hits"])

return len(hits["hits"])

async def _next_batch_async(self) -> List[Dict]:
"""Get next batch of documents

Expand All @@ -196,10 +219,14 @@ async def _next_batch_async(self) -> List[Dict]:
query["search_after"] = self._after

if self._pit is not None:
query["pit"] = self._pit.get_pit_object()
query["pit"] = self._pit.get_pit_object(self._new_pit_id)

res = await self._sumo.post_async("/search", json=query)
hits = res.json()["hits"]
data = res.json()
hits = data["hits"]

if self._pit is not None:
self._new_pit_id = data["pit_id"]

if self._len is None:
self._len = hits["total"]["value"]
Expand All @@ -208,6 +235,8 @@ async def _next_batch_async(self) -> List[Dict]:
self._after = hits["hits"][-1]["sort"]
self._items.extend(hits["hits"])

return len(hits["hits"])

def _init_query(self, doc_type: str, query: Dict = None) -> Dict:
"""Initialize base filter for document collection

Expand Down
4 changes: 4 additions & 0 deletions src/fmu/sumo/explorer/objects/case_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ def __getitem__(self, index: int) -> Case:
doc = super().__getitem__(index)
return Case(self._sumo, doc, self._pit)

async def getitem_async(self, index: int) -> Case:
    """Get case object by index (async)

    Args:
        index (int): index of the document in the collection

    Returns:
        Case: case object built from the document at the given index
    """
    doc = await super().getitem_async(index)
    # Pass self._pit along, exactly as the synchronous __getitem__ does, so
    # a Case is constructed the same way through both access paths.
    return Case(self._sumo, doc, self._pit)

def filter(
self,
uuid: Union[str, List[str]] = None,
Expand Down
8 changes: 6 additions & 2 deletions src/fmu/sumo/explorer/objects/cube_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ def __getitem__(self, index) -> Cube:
doc = super().__getitem__(index)
return Cube(self._sumo, doc)

async def getitem_async(self, index: int) -> Cube:
    """Get cube object by index (async)

    Args:
        index (int): index of the document in the collection

    Returns:
        Cube: cube object built from the document at the given index
    """
    document = await super().getitem_async(index)
    return Cube(self._sumo, document)

@property
def timestamps(self) -> List[str]:
"""List of unique timestamps in CubeCollection"""
Expand Down Expand Up @@ -129,7 +133,7 @@ def filter(
time: TimeFilter = None,
uuid: Union[str, List[str], bool] = None,
is_observation: bool = None,
is_prediction: bool = None
is_prediction: bool = None,
) -> "CubeCollection":
"""Filter cubes

Expand All @@ -156,7 +160,7 @@ def filter(
time=time,
uuid=uuid,
is_observation=is_observation,
is_prediction=is_prediction
is_prediction=is_prediction,
)

return CubeCollection(self._sumo, self._case_uuid, query, self._pit)
4 changes: 4 additions & 0 deletions src/fmu/sumo/explorer/objects/dictionary_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ def __getitem__(self, index) -> Dictionary:
doc = super().__getitem__(index)
return Dictionary(self._sumo, doc)

async def getitem_async(self, index: int) -> Dictionary:
    """Get dictionary object by index (async)

    Args:
        index (int): index of the document in the collection

    Returns:
        Dictionary: dictionary object built from the document at the given index
    """
    document = await super().getitem_async(index)
    return Dictionary(self._sumo, document)

def filter(
self,
name: Union[str, List[str], bool] = None,
Expand Down
4 changes: 4 additions & 0 deletions src/fmu/sumo/explorer/objects/polygons_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ def __getitem__(self, index) -> Polygons:
doc = super().__getitem__(index)
return Polygons(self._sumo, doc)

async def getitem_async(self, index: int) -> Polygons:
    """Get polygons object by index (async)

    Args:
        index (int): index of the document in the collection

    Returns:
        Polygons: polygons object built from the document at the given index
    """
    document = await super().getitem_async(index)
    return Polygons(self._sumo, document)

def filter(
self,
name: Union[str, List[str], bool] = None,
Expand Down
10 changes: 8 additions & 2 deletions src/fmu/sumo/explorer/objects/surface_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def __getitem__(self, index) -> Surface:
doc = super().__getitem__(index)
return Surface(self._sumo, doc)

async def getitem_async(self, index: int) -> Surface:
    """Get surface object by index (async)

    Args:
        index (int): index of the document in the collection

    Returns:
        Surface: surface object built from the document at the given index
    """
    document = await super().getitem_async(index)
    return Surface(self._sumo, document)

@property
def timestamps(self) -> List[str]:
"""List of unique timestamps in SurfaceCollection"""
Expand Down Expand Up @@ -141,7 +145,9 @@ def _aggregate(self, operation: str) -> xtgeo.RegularSurface:

async def _aggregate_async(self, operation: str) -> xtgeo.RegularSurface:
if operation not in self._aggregation_cache:
objects = await self._utils.get_objects_async(500, self._query, ["_id"])
objects = await self._utils.get_objects_async(
500, self._query, ["_id"]
)
object_ids = list(map(lambda obj: obj["_id"], objects))

res = await self._sumo.post_async(
Expand Down Expand Up @@ -291,4 +297,4 @@ def p90(self) -> xtgeo.RegularSurface:

async def p90_async(self) -> xtgeo.RegularSurface:
"""Perform a percentile aggregation"""
return await self._aggregate_async("p90")
return await self._aggregate_async("p90")
4 changes: 4 additions & 0 deletions src/fmu/sumo/explorer/objects/table_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ def __getitem__(self, index) -> Table:
doc = super().__getitem__(index)
return Table(self._sumo, doc)

async def getitem_async(self, index: int) -> Table:
    """Get table object by index (async)

    Args:
        index (int): index of the document in the collection

    Returns:
        Table: table object built from the document at the given index
    """
    document = await super().getitem_async(index)
    return Table(self._sumo, document)

@property
def columns(self) -> List[str]:
"""List of unique column names"""
Expand Down
7 changes: 5 additions & 2 deletions src/fmu/sumo/explorer/pit.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@ def __get_pit_id(self, keep_alive) -> str:
res = self._sumo.post("/pit", params={"keep-alive": keep_alive})
return res.json()["id"]

def get_pit_object(self) -> Dict:
def get_pit_object(self, pit_id: str = None) -> Dict:
"""Get the pit object

Returns:
Dict: dict with id and info about how long to keep alive
"""
return {"id": self._pit_id, "keep_alive": self._keep_alive}
return {
"id": pit_id if pit_id is not None else self._pit_id,
"keep_alive": self._keep_alive,
}
Loading