Skip to content

Commit

Permalink
RUST-1222 Cancel in-progress operations when SDAM heartbeats time out (
Browse files Browse the repository at this point in the history
  • Loading branch information
isabelatkinson authored Nov 22, 2024
1 parent 450c8a3 commit e3df089
Show file tree
Hide file tree
Showing 57 changed files with 1,034 additions and 503 deletions.
4 changes: 2 additions & 2 deletions src/client/auth/aws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ async fn authenticate_stream_inner(
);
let client_first = sasl_start.into_command();

let server_first_response = conn.send_command(client_first, None).await?;
let server_first_response = conn.send_message(client_first).await?;

let server_first = ServerFirst::parse(server_first_response.auth_response_body(MECH_NAME)?)?;
server_first.validate(&nonce)?;
Expand Down Expand Up @@ -135,7 +135,7 @@ async fn authenticate_stream_inner(

let client_second = sasl_continue.into_command();

let server_second_response = conn.send_command(client_second, None).await?;
let server_second_response = conn.send_message(client_second).await?;
let server_second = SaslResponse::parse(
MECH_NAME,
server_second_response.auth_response_body(MECH_NAME)?,
Expand Down
2 changes: 1 addition & 1 deletion src/client/auth/oidc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -880,7 +880,7 @@ async fn send_sasl_command(
conn: &mut Connection,
command: crate::cmap::Command,
) -> Result<SaslResponse> {
let response = conn.send_command(command, None).await?;
let response = conn.send_message(command).await?;
SaslResponse::parse(
MONGODB_OIDC_STR,
response.auth_response_body(MONGODB_OIDC_STR)?,
Expand Down
2 changes: 1 addition & 1 deletion src/client/auth/plain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ pub(crate) async fn authenticate_stream(
)
.into_command();

let response = conn.send_command(sasl_start, None).await?;
let response = conn.send_message(sasl_start).await?;
let sasl_response = SaslResponse::parse("PLAIN", response.auth_response_body("PLAIN")?)?;

if !sasl_response.done {
Expand Down
6 changes: 3 additions & 3 deletions src/client/auth/scram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ impl ScramVersion {

let command = client_first.to_command(self);

let server_first = conn.send_command(command, None).await?;
let server_first = conn.send_message(command).await?;

Ok(FirstRound {
client_first,
Expand Down Expand Up @@ -215,7 +215,7 @@ impl ScramVersion {

let command = client_final.to_command();

let server_final_response = conn.send_command(command, None).await?;
let server_final_response = conn.send_message(command).await?;
let server_final = ServerFinal::parse(server_final_response.auth_response_body("SCRAM")?)?;
server_final.validate(salted_password.as_slice(), &client_final, self)?;

Expand All @@ -231,7 +231,7 @@ impl ScramVersion {
);
let command = noop.into_command();

let server_noop_response = conn.send_command(command, None).await?;
let server_noop_response = conn.send_message(command).await?;
let server_noop_response_document: Document =
server_noop_response.auth_response_body("SCRAM")?;

Expand Down
2 changes: 1 addition & 1 deletion src/client/auth/x509.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ pub(crate) async fn send_client_first(
) -> Result<RawCommandResponse> {
let command = build_client_first(credential, server_api);

conn.send_command(command, None).await
conn.send_message(command).await
}

/// Performs X.509 authentication for a given stream.
Expand Down
7 changes: 3 additions & 4 deletions src/client/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -614,13 +614,12 @@ impl Client {
}

let should_redact = cmd.should_redact();
let should_compress = cmd.should_compress();

let cmd_name = cmd.name.clone();
let target_db = cmd.target_db.clone();

#[allow(unused_mut)]
let mut message = Message::from_command(cmd, Some(request_id))?;
let mut message = Message::try_from(cmd)?;
message.request_id = Some(request_id);
#[cfg(feature = "in-use-encryption")]
{
let guard = self.inner.csfle.read().await;
Expand Down Expand Up @@ -652,7 +651,7 @@ impl Client {
.await;

let start_time = Instant::now();
let command_result = match connection.send_message(message, should_compress).await {
let command_result = match connection.send_message(message).await {
Ok(response) => {
async fn handle_response<T: Operation>(
client: &Client,
Expand Down
62 changes: 42 additions & 20 deletions src/cmap/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@ use derive_where::derive_where;
use serde::Serialize;
use tokio::{
io::BufStream,
sync::{mpsc, Mutex},
sync::{
broadcast::{self, error::RecvError},
mpsc,
Mutex,
},
};

use self::wire::{Message, MessageFlags};
Expand Down Expand Up @@ -171,12 +175,44 @@ impl Connection {
self.error.is_some()
}

/// Sends `message` on this connection, racing the send against `cancellation_receiver`.
///
/// If a cancellation notice is received (or the receiver reports `RecvError::Lagged`) before
/// `send_message` completes, a `ConnectionPoolCleared` error naming `self.address` is stored
/// in `self.error` and returned. Otherwise the result of `send_message` is returned unchanged.
pub(crate) async fn send_message_with_cancellation(
&mut self,
message: impl TryInto<Message, Error = impl Into<Error>>,
cancellation_receiver: &mut broadcast::Receiver<()>,
) -> Result<RawCommandResponse> {
tokio::select! {
// `biased` makes `select!` poll the branches in the order they are written, so the
// cancellation branch below is always checked before the send is driven forward.
biased;

// A lagged error indicates that more heartbeats failed than the channel's capacity
// between checking out this connection and executing the operation. If this occurs,
// then proceed with cancelling the operation. RecvError::Closed can be ignored, as
// the sender (and by extension the connection pool) dropping does not indicate that
// the operation should be cancelled.
//
// A `RecvError::Closed` result does not match the pattern below, so `select!` disables
// this branch and keeps waiting on the send alone.
Ok(_) | Err(RecvError::Lagged(_)) = cancellation_receiver.recv() => {
let error: Error = ErrorKind::ConnectionPoolCleared {
message: format!(
"Connection to {} interrupted due to server monitor timeout",
self.address,
)
}.into();
// Keep a copy of the error on the connection itself as well as returning it.
// NOTE(review): presumably this stored error is what gets the connection closed when
// it is returned to the pool — confirm against the check-in path.
self.error = Some(error.clone());
Err(error)
}
// This future is not cancellation safe because it contains calls to methods that are
// not cancellation safe (e.g. AsyncReadExt::read_exact). However, in the case that
// this future is cancelled because a cancellation message was received, this
// connection will be closed upon being returned to the pool, so any data loss on its
// underlying stream is not an issue.
result = self.send_message(message) => result,
}
}

pub(crate) async fn send_message(
&mut self,
message: Message,
// This value is only read if a compression feature flag is enabled.
#[allow(unused_variables)] can_compress: bool,
message: impl TryInto<Message, Error = impl Into<Error>>,
) -> Result<RawCommandResponse> {
let message = message.try_into().map_err(Into::into)?;

if self.more_to_come {
return Err(Error::internal(format!(
"attempted to send a new message to {} but moreToCome bit was set",
Expand All @@ -192,7 +228,7 @@ impl Connection {
feature = "snappy-compression"
))]
let write_result = match self.compressor {
Some(ref compressor) if can_compress => {
Some(ref compressor) if message.should_compress => {
message
.write_op_compressed_to(&mut self.stream, compressor)
.await
Expand Down Expand Up @@ -232,21 +268,6 @@ impl Connection {
))
}

/// Executes a `Command` and returns a `RawCommandResponse` containing the result from the
/// server.
///
/// An `Ok(...)` result simply means the server received the command and that the driver
/// received the response; it does not imply anything about the success of the command
/// itself.
///
/// `request_id` is converted into an `Option<i32>` and passed to `Message::from_command`
/// together with the command.
pub(crate) async fn send_command(
&mut self,
command: Command,
request_id: impl Into<Option<i32>>,
) -> Result<RawCommandResponse> {
// Read the compression flag first: `command` is moved into `Message::from_command` below.
let to_compress = command.should_compress();
let message = Message::from_command(command, request_id.into())?;
self.send_message(message, to_compress).await
}

/// Receive the next message from the connection.
/// This will return an error if the previous response on this connection did not include the
/// moreToCome flag.
Expand Down Expand Up @@ -378,6 +399,7 @@ pub(crate) struct PendingConnection {
pub(crate) generation: PoolGeneration,
pub(crate) event_emitter: CmapEventEmitter,
pub(crate) time_created: Instant,
pub(crate) cancellation_receiver: Option<broadcast::Receiver<()>>,
}

impl PendingConnection {
Expand Down
65 changes: 51 additions & 14 deletions src/cmap/conn/pooled.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,18 @@ use std::{
};

use derive_where::derive_where;
use tokio::sync::{mpsc, Mutex};
use tokio::sync::{broadcast, mpsc, Mutex};

use super::{
CmapEventEmitter,
Connection,
ConnectionGeneration,
ConnectionInfo,
Message,
PendingConnection,
PinnedConnectionHandle,
PoolManager,
RawCommandResponse,
};
use crate::{
bson::oid::ObjectId,
Expand Down Expand Up @@ -50,7 +52,7 @@ pub(crate) struct PooledConnection {
}

/// The state of a pooled connection.
#[derive(Clone, Debug)]
#[derive(Debug)]
enum PooledConnectionState {
/// The state associated with a connection checked into the connection pool.
CheckedIn { available_time: Instant },
Expand All @@ -59,6 +61,10 @@ enum PooledConnectionState {
CheckedOut {
/// The manager used to check this connection back into the pool.
pool_manager: PoolManager,

/// The receiver to receive a cancellation notice. Only present on non-load-balanced
/// connections.
cancellation_receiver: Option<broadcast::Receiver<()>>,
},

/// The state associated with a pinned connection.
Expand Down Expand Up @@ -140,6 +146,24 @@ impl PooledConnection {
.and_then(|sd| sd.service_id)
}

/// Sends a message on this connection.
///
/// When the connection is in the `CheckedOut` state and holds a cancellation receiver, the
/// send is performed via `send_message_with_cancellation` on the inner connection using that
/// receiver. In every other case the message is sent with the inner connection's plain
/// `send_message`.
pub(crate) async fn send_message(
&mut self,
message: impl TryInto<Message, Error = impl Into<Error>>,
) -> Result<RawCommandResponse> {
match self.state {
// `ref mut` borrows the receiver in place so it stays stored in `self.state` after the
// call.
PooledConnectionState::CheckedOut {
cancellation_receiver: Some(ref mut cancellation_receiver),
..
} => {
self.connection
.send_message_with_cancellation(message, cancellation_receiver)
.await
}
// Any other state, or `CheckedOut` with no receiver: send without a cancellation race.
_ => self.connection.send_message(message).await,
}
}

/// Updates the state of the connection to indicate that it is checked into the pool.
pub(crate) fn mark_checked_in(&mut self) {
if !matches!(self.state, PooledConnectionState::CheckedIn { .. }) {
Expand All @@ -155,8 +179,15 @@ impl PooledConnection {
}

/// Updates the state of the connection to indicate that it is checked out of the pool.
pub(crate) fn mark_checked_out(&mut self, pool_manager: PoolManager) {
self.state = PooledConnectionState::CheckedOut { pool_manager };
pub(crate) fn mark_checked_out(
&mut self,
pool_manager: PoolManager,
cancellation_receiver: Option<broadcast::Receiver<()>>,
) {
self.state = PooledConnectionState::CheckedOut {
pool_manager,
cancellation_receiver,
};
}

/// Whether this connection is idle.
Expand All @@ -175,15 +206,14 @@ impl PooledConnection {
Instant::now().duration_since(available_time) >= max_idle_time
}

/// Nullifies the internal state of this connection and returns it in a new [PooledConnection].
/// If a state is provided, then the new connection will contain that state; otherwise, this
/// connection's state will be cloned.
fn take(&mut self, state: impl Into<Option<PooledConnectionState>>) -> Self {
/// Nullifies the internal state of this connection and returns it in a new [PooledConnection]
/// with the given state.
fn take(&mut self, new_state: PooledConnectionState) -> Self {
Self {
connection: self.connection.take(),
generation: self.generation,
event_emitter: self.event_emitter.clone(),
state: state.into().unwrap_or_else(|| self.state.clone()),
state: new_state,
}
}

Expand All @@ -196,7 +226,9 @@ impl PooledConnection {
self.id
)))
}
PooledConnectionState::CheckedOut { ref pool_manager } => {
PooledConnectionState::CheckedOut {
ref pool_manager, ..
} => {
let (tx, rx) = mpsc::channel(1);
self.state = PooledConnectionState::Pinned {
// Mark the connection as in-use while the operation currently using the
Expand Down Expand Up @@ -286,10 +318,11 @@ impl Drop for PooledConnection {
// Nothing needs to be done when a checked-in connection is dropped.
PooledConnectionState::CheckedIn { .. } => Ok(()),
// A checked-out connection should be sent back to the connection pool.
PooledConnectionState::CheckedOut { pool_manager } => {
PooledConnectionState::CheckedOut { pool_manager, .. } => {
let pool_manager = pool_manager.clone();
let mut dropped_connection = self.take(None);
dropped_connection.mark_checked_in();
let dropped_connection = self.take(PooledConnectionState::CheckedIn {
available_time: Instant::now(),
});
pool_manager.check_in(dropped_connection)
}
// A pinned connection should be returned to its pinner or to the connection pool.
Expand Down Expand Up @@ -339,7 +372,11 @@ impl Drop for PooledConnection {
}
// The pinner of this connection has been dropped while the connection was
// sitting in its channel, so the connection should be returned to the pool.
PinnedState::Returned { .. } => pool_manager.check_in(self.take(None)),
PinnedState::Returned { .. } => {
pool_manager.check_in(self.take(PooledConnectionState::CheckedIn {
available_time: Instant::now(),
}))
}
}
}
};
Expand Down
Loading

0 comments on commit e3df089

Please sign in to comment.