Skip to content

Commit

Permalink
RUST-1945 Add a with_type method to the Aggregate action (#1100)
Browse files Browse the repository at this point in the history
  • Loading branch information
isabelatkinson authored May 13, 2024
1 parent e06744b commit 31a0750
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 14 deletions.
72 changes: 58 additions & 14 deletions src/action/aggregate.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::time::Duration;
use std::{marker::PhantomData, time::Duration};

use bson::Document;

Expand All @@ -24,15 +24,17 @@ impl Database {
/// See the documentation [here](https://www.mongodb.com/docs/manual/aggregation/) for more
/// information on aggregations.
///
/// `await` will return d[`Result<Cursor<Document>>`] or d[`Result<SessionCursor<Document>>`] if
/// a `ClientSession` is provided.
/// `await` will return d[`Result<Cursor<Document>>`]. If a [`ClientSession`] was provided, the
/// returned cursor will be a [`SessionCursor`]. If [`with_type`](Aggregate::with_type) was
/// called, the returned cursor will be generic over the `T` specified.
#[deeplink]
pub fn aggregate(&self, pipeline: impl IntoIterator<Item = Document>) -> Aggregate {
Aggregate {
target: AggregateTargetRef::Database(self),
pipeline: pipeline.into_iter().collect(),
options: None,
session: ImplicitSession,
_phantom: PhantomData,
}
}
}
Expand All @@ -46,15 +48,17 @@ where
/// See the documentation [here](https://www.mongodb.com/docs/manual/aggregation/) for more
/// information on aggregations.
///
/// `await` will return d[`Result<Cursor<Document>>`] or d[`Result<SessionCursor<Document>>`] if
/// a [`ClientSession`] is provided.
/// `await` will return d[`Result<Cursor<Document>>`]. If a [`ClientSession`] was provided, the
/// returned cursor will be a [`SessionCursor`]. If [`with_type`](Aggregate::with_type) was
/// called, the returned cursor will be generic over the `T` specified.
#[deeplink]
pub fn aggregate(&self, pipeline: impl IntoIterator<Item = Document>) -> Aggregate {
Aggregate {
target: AggregateTargetRef::Collection(CollRef::new(self)),
pipeline: pipeline.into_iter().collect(),
options: None,
session: ImplicitSession,
_phantom: PhantomData,
}
}
}
Expand All @@ -66,8 +70,10 @@ impl crate::sync::Database {
/// See the documentation [here](https://www.mongodb.com/docs/manual/aggregation/) for more
/// information on aggregations.
///
/// [`run`](Aggregate::run) will return d[`Result<crate::sync::Cursor<Document>>`] or
/// d[`Result<crate::sync::SessionCursor<Document>>`] if a [`ClientSession`] is provided.
/// [`run`](Aggregate::run) will return d[Result<crate::sync::Cursor<Document>>`]. If a
/// [`crate::sync::ClientSession`] was provided, the returned cursor will be a
/// [`crate::sync::SessionCursor`]. If [`with_type`](Aggregate::with_type) was called, the
/// returned cursor will be generic over the `T` specified.
#[deeplink]
pub fn aggregate(&self, pipeline: impl IntoIterator<Item = Document>) -> Aggregate {
self.async_database.aggregate(pipeline)
Expand All @@ -84,8 +90,10 @@ where
/// See the documentation [here](https://www.mongodb.com/docs/manual/aggregation/) for more
/// information on aggregations.
///
/// [`run`](Aggregate::run) will return d[`Result<crate::sync::Cursor<Document>>`] or
/// d[`Result<crate::sync::SessionCursor<Document>>`] if a `ClientSession` is provided.
/// [`run`](Aggregate::run) will return d[Result<crate::sync::Cursor<Document>>`]. If a
/// `crate::sync::ClientSession` was provided, the returned cursor will be a
/// `crate::sync::SessionCursor`. If [`with_type`](Aggregate::with_type) was called, the
/// returned cursor will be generic over the `T` specified.
#[deeplink]
pub fn aggregate(&self, pipeline: impl IntoIterator<Item = Document>) -> Aggregate {
self.async_collection.aggregate(pipeline)
Expand All @@ -95,14 +103,15 @@ where
/// Run an aggregation operation. Construct with [`Database::aggregate`] or
/// [`Collection::aggregate`].
#[must_use]
pub struct Aggregate<'a, Session = ImplicitSession> {
pub struct Aggregate<'a, Session = ImplicitSession, T = Document> {
target: AggregateTargetRef<'a>,
pipeline: Vec<Document>,
options: Option<AggregateOptions>,
session: Session,
_phantom: PhantomData<T>,
}

impl<'a, Session> Aggregate<'a, Session> {
impl<'a, Session, T> Aggregate<'a, Session, T> {
option_setters!(options: AggregateOptions;
allow_disk_use: bool,
batch_size: u32,
Expand Down Expand Up @@ -130,15 +139,50 @@ impl<'a> Aggregate<'a, ImplicitSession> {
pipeline: self.pipeline,
options: self.options,
session: ExplicitSession(value.into()),
_phantom: PhantomData,
}
}
}

#[action_impl(sync = crate::sync::Cursor<Document>)]
impl<'a> Action for Aggregate<'a, ImplicitSession> {
impl<'a, Session> Aggregate<'a, Session, Document> {
/// Use the provided type for the returned cursor.
///
/// ```rust
/// # use futures_util::TryStreamExt;
/// # use mongodb::{bson::Document, error::Result, Cursor, Database};
/// # use serde::Deserialize;
/// # async fn run() -> Result<()> {
/// # let database: Database = todo!();
/// # let pipeline: Vec<Document> = todo!();
/// #[derive(Deserialize)]
/// struct PipelineOutput {
/// len: usize,
/// }
///
/// let aggregate_cursor = database
/// .aggregate(pipeline)
/// .with_type::<PipelineOutput>()
/// .await?;
/// let aggregate_results: Vec<PipelineOutput> = aggregate_cursor.try_collect().await?;
/// # Ok(())
/// # }
/// ```
pub fn with_type<T>(self) -> Aggregate<'a, Session, T> {
Aggregate {
target: self.target,
pipeline: self.pipeline,
options: self.options,
session: self.session,
_phantom: PhantomData,
}
}
}

#[action_impl(sync = crate::sync::Cursor<T>)]
impl<'a, T> Action for Aggregate<'a, ImplicitSession, T> {
type Future = AggregateFuture;

async fn execute(mut self) -> Result<Cursor<Document>> {
async fn execute(mut self) -> Result<Cursor<T>> {
resolve_options!(
self.target,
self.options,
Expand Down
43 changes: 43 additions & 0 deletions src/test/coll.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use crate::{
results::DeleteResult,
test::{get_client_options, log_uncaptured, util::TestClient, EventClient},
Collection,
Cursor,
IndexModel,
};

Expand Down Expand Up @@ -1306,3 +1307,45 @@ async fn insert_many_document_sequences() {
let second_batch_len = second_event.command.get_array("documents").unwrap().len();
assert_eq!(first_batch_len + second_batch_len, total_docs);
}

#[tokio::test]
async fn aggregate_with_generics() {
#[derive(Serialize)]
struct A {
str: String,
}

#[derive(Deserialize)]
struct B {
len: i32,
}

let client = TestClient::new().await;
let collection = client
.database("aggregate_with_generics")
.collection::<A>("aggregate_with_generics");

let a = A {
str: "hi".to_string(),
};
let len = a.str.len();
collection.insert_one(&a).await.unwrap();

// Assert at compile-time that the default cursor returned is a Cursor<Document>
let basic_pipeline = vec![doc! { "$match": { "a": 1 } }];
let _: Cursor<Document> = collection.aggregate(basic_pipeline).await.unwrap();

// Assert that data is properly deserialized when using with_type
let project_pipeline = vec![doc! { "$project": {
"str": 1,
"len": { "$strLenBytes": "$str" }
}
}];
let cursor = collection
.aggregate(project_pipeline)
.with_type::<B>()
.await
.unwrap();
let lens: Vec<B> = cursor.try_collect().await.unwrap();
assert_eq!(lens[0].len as usize, len);
}
31 changes: 31 additions & 0 deletions src/test/db.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::cmp::Ord;

use futures::stream::TryStreamExt;
use serde::Deserialize;

use crate::{
action::Action,
Expand All @@ -17,6 +18,7 @@ use crate::{
results::{CollectionSpecification, CollectionType},
test::util::TestClient,
Client,
Cursor,
Database,
};

Expand Down Expand Up @@ -378,3 +380,32 @@ async fn clustered_index_list_collections() {
.unwrap();
assert!(clustered_index_collection.options.clustered_index.is_some());
}

#[tokio::test]
async fn aggregate_with_generics() {
#[derive(Deserialize)]
struct A {
str: String,
}

let client = TestClient::new().await;
let database = client.database("aggregate_with_generics");

if client.server_version_lt(5, 1) {
log_uncaptured(
"skipping aggregate_with_generics: $documents agg stage only available on 5.1+",
);
return;
}

// The cursor returned will contain these documents
let pipeline = vec![doc! { "$documents": [ { "str": "hi" } ] }];

// Assert at compile-time that the default cursor returned is a Cursor<Document>
let _: Cursor<Document> = database.aggregate(pipeline.clone()).await.unwrap();

// Assert that data is properly deserialized when using with_type
let mut cursor = database.aggregate(pipeline).with_type::<A>().await.unwrap();
assert!(cursor.advance().await.unwrap());
assert_eq!(&cursor.deserialize_current().unwrap().str, "hi");
}

0 comments on commit 31a0750

Please sign in to comment.