Return pagination
oeway committed Nov 21, 2024
1 parent 64c4c97 commit f45928a
Showing 5 changed files with 103 additions and 7 deletions.
2 changes: 1 addition & 1 deletion hypha/VERSION
@@ -1,3 +1,3 @@
{
"version": "0.20.39.post11"
"version": "0.20.39.post12"
}
66 changes: 61 additions & 5 deletions hypha/artifact.py
@@ -196,6 +196,8 @@ async def list_children(
mode: str = "AND",
limit: int = 100,
order_by: str = None,
pagination: bool = False,
silent: bool = False,
user_info: self.store.login_optional = Depends(self.store.login_optional),
):
"""List child artifacts of a specified artifact."""
@@ -212,6 +214,8 @@ async def list_children(
keywords=keywords,
filters=filters,
mode=mode,
pagination=pagination,
silent=silent,
context={"user": user_info.model_dump(), "ws": workspace},
)
except KeyError:
@@ -1607,6 +1611,7 @@ async def search_by_vector(
limit: int = 10,
with_payload: bool = True,
with_vectors: bool = False,
pagination: bool = False,
context: dict = None,
):
user_info = UserInfo.model_validate(context["user"])
@@ -1635,6 +1640,16 @@
with_payload=with_payload,
with_vectors=with_vectors,
)
if pagination:
count = await self._vectordb_client.count(
collection_name=f"{artifact.workspace}/{artifact.alias}"
)
return {
"total": count.count,
"items": search_results,
"offset": offset,
"limit": limit,
}
return search_results
except Exception as e:
raise e
@@ -1650,6 +1665,7 @@ async def search_by_text(
limit: int = 10,
with_payload: bool = True,
with_vectors: bool = False,
pagination: bool = False,
context: dict = None,
):
user_info = UserInfo.model_validate(context["user"])
@@ -1676,6 +1692,16 @@
with_payload=with_payload,
with_vectors=with_vectors,
)
if pagination:
count = await self._vectordb_client.count(
collection_name=f"{artifact.workspace}/{artifact.alias}"
)
return {
"total": count.count,
"items": search_results,
"offset": offset,
"limit": limit,
}
return search_results
except Exception as e:
raise e
@@ -1957,7 +1983,7 @@ async def list_files(
self,
artifact_id: str,
dir_path: str = None,
-        max_length: int = 1000,
+        limit: int = 1000,
version: str = None,
context: dict = None,
):
@@ -1999,7 +2025,7 @@ async def list_files(
s3_client,
s3_config["bucket"],
full_path,
-                    max_length=max_length,
+                    max_length=limit,
)
return items
except Exception as e:
@@ -2015,8 +2041,9 @@ async def list_children(
mode="AND",
offset: int = 0,
limit: int = 100,
-        order_by=None,
-        silent=False,
+        order_by: str = None,
+        silent: bool = False,
+        pagination: bool = False,
context: dict = None,
):
"""
@@ -2058,16 +2085,26 @@ async def list_children(
query = select(
*[getattr(ArtifactModel, field) for field in list_fields]
).where(ArtifactModel.parent_id == parent_artifact.id)
count_query = select(func.count()).where(
ArtifactModel.parent_id == parent_artifact.id
)
else:
# If list_fields is empty or not specified, select all columns
query = select(ArtifactModel).where(
ArtifactModel.parent_id == parent_artifact.id
)
count_query = select(func.count()).where(
ArtifactModel.parent_id == parent_artifact.id
)
else:
query = select(ArtifactModel).where(
ArtifactModel.parent_id == None,
ArtifactModel.workspace == context["ws"],
)
count_query = select(func.count()).where(
ArtifactModel.parent_id == None,
ArtifactModel.workspace == context["ws"],
)
conditions = []

# Handle keyword-based search across manifest fields
@@ -2189,6 +2226,7 @@ async def list_children(
)

query = query.where(stage_condition)
count_query = count_query.where(stage_condition)

# Combine conditions based on mode (AND/OR)
if conditions:
@@ -2197,6 +2235,18 @@
if mode == "OR"
else query.where(and_(*conditions))
)
count_query = (
count_query.where(or_(*conditions))
if mode == "OR"
else count_query.where(and_(*conditions))
)

if pagination:
# Execute the count query
result = await session.execute(count_query)
total_count = result.scalar()
else:
total_count = None

# Pagination and ordering
order_field_map = {
Expand Down Expand Up @@ -2237,7 +2287,13 @@ async def list_children(
session, parent_artifact.id, "view_count"
)
await session.commit()

if pagination:
return {
"items": results,
"total": total_count,
"offset": offset,
"limit": limit,
}
return results

except Exception as e:
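Taken together, the artifact.py changes make pagination opt-in: with `pagination=True`, `list_children`, `search_by_vector`, and `search_by_text` return a dict with `items`, `total`, `offset`, and `limit` instead of a bare list, and `list_files` now takes `limit` instead of `max_length`. A minimal usage sketch, assuming an already-connected `artifact_manager` proxy and existing `collection` / `vector_collection` artifacts (illustrative names mirroring the tests further below):

# Hedged sketch of the opt-in pagination added in this commit; the
# artifact_manager proxy, collection ids, and query vector are assumed.
page = await artifact_manager.list(
    parent_id=collection.id,
    limit=20,
    pagination=True,  # opt in: wraps the items in a metadata dict
)
# Expected shape: {"items": [...], "total": <int>, "offset": <int>, "limit": 20}
print(page["total"], len(page["items"]))

# The vector search endpoints take the same flag and report the collection's
# total point count alongside the returned hits.
hits = await artifact_manager.search_by_vector(
    artifact_id=vector_collection.id,
    query_vector=query_vector,
    limit=10,
    pagination=True,
)
print(hits["total"], len(hits["items"]))

Callers that omit the flag keep getting the plain list, so the change is backward compatible.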
23 changes: 22 additions & 1 deletion hypha/core/workspace.py
@@ -860,6 +860,9 @@ async def search_services(
None,
description="Order by field, default is score if embedding or text_query is provided.",
),
pagination: Optional[bool] = Field(
False, description="Enable pagination, return metadata with total count."
),
context: Optional[dict] = None,
):
"""
@@ -928,12 +931,30 @@
query, query_params=query_params
)

# Handle pagination
if pagination:
count_query = Query(query_string).paging(0, 0).dialect(2)
count_results = await self._redis.ft("service_info_index").search(
count_query, query_params=query_params
)
total_count = count_results.total
else:
total_count = None

# Convert results to dictionaries and return
services = [
ServiceInfo.from_redis_dict(vars(doc), in_bytes=False)
for doc in results.docs
]
-            return [service.model_dump() for service in services]
+            if pagination:
+                return {
+                    "items": [service.model_dump() for service in services],
+                    "total": total_count,
+                    "offset": offset,
+                    "limit": limit,
+                }
+            else:
+                return [service.model_dump() for service in services]

@schema_method
async def list_services(
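The same pattern is applied to `search_services` in workspace.py: the bare list stays the default, and `pagination=True` adds a total obtained from a zero-size count query against the RediSearch index. A hedged usage sketch, with the `api` workspace client and query text assumed (mirroring the search_services test further below):

# Hedged sketch of paginated service search; `api` is assumed to be a
# connected Hypha workspace client.
results = await api.search_services(
    text_query="natural language processing",
    limit=5,
    pagination=True,
)
# Expected shape: {"items": [...], "total": <int>, "offset": <int>, "limit": 5}
for svc in results["items"]:
    print(svc["id"], svc.get("score"))

# Without the flag the call keeps returning a plain list, so existing callers
# are unaffected.
services = await api.search_services(text_query="natural language processing", limit=5)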
16 changes: 16 additions & 0 deletions tests/test_artifact.py
@@ -88,6 +88,14 @@ async def test_artifact_vector_collection(
)
assert len(search_results) <= 2

results = await artifact_manager.search_by_vector(
artifact_id=vector_collection.id,
query_vector=query_vector,
limit=2,
pagination=True,
)
assert results["total"] == 3

query_filter = {
"should": None,
"min_should": None,
@@ -248,6 +256,14 @@ async def test_sqlite_create_and_search_artifacts(

assert len(search_results) == len(datasets)

results = await artifact_manager.list(
parent_id=collection.id,
filters={"stage": True, "manifest": {"description": "*dataset*"}},
pagination=True,
)
assert results["total"] == len(datasets)
assert len(results["items"]) == len(datasets)

# list application only
search_results = await artifact_manager.list(
parent_id=collection.id, filters={"stage": True, "type": "application"}
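Beyond asserting the total, the metadata also makes it easy to walk a large collection page by page. A sketch of that pattern, assuming the same `artifact_manager` and `collection` fixtures as the test above and that the `offset`/`limit` arguments of `list_children` are forwarded through the client-side `list` call:

# Hedged sketch: iterate over all children page by page using the new
# total/offset/limit metadata (fixture names assumed from the test above).
offset, limit = 0, 50
items = []
while True:
    page = await artifact_manager.list(
        parent_id=collection.id,
        offset=offset,
        limit=limit,
        pagination=True,
    )
    items.extend(page["items"])
    offset += limit
    if offset >= page["total"]:
        break
# `items` now holds every child reported by the server at the time of listing.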
3 changes: 3 additions & 0 deletions tests/test_server.py
@@ -382,6 +382,9 @@ async def test_service_search(fastapi_server_redis_1, test_user_token):
assert "natural language processing" in services[0]["docs"]
assert services[0]["score"] < services[1]["score"]

results = await api.search_services(text_query=text_query, limit=3, pagination=True)
assert results["total"] >= 1

embedding = np.ones(384).astype(np.float32)
await api.register_service(
{
