Skip to content

Commit

Permalink
Refactor MqttError type
Browse files Browse the repository at this point in the history
  • Loading branch information
fafhrd91 committed Sep 18, 2023
1 parent f874e8b commit 2bae4ec
Show file tree
Hide file tree
Showing 18 changed files with 190 additions and 136 deletions.
4 changes: 4 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changes

## [0.12.0] - 2023-09-18

* Refactor MqttError

## [0.11.4] - 2023-08-10

* Update ntex deps
Expand Down
16 changes: 8 additions & 8 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "ntex-mqtt"
version = "0.11.4"
version = "0.12.0"
authors = ["ntex contributors <[email protected]>"]
description = "Client and Server framework for MQTT v5 and v3.1.1 protocols"
documentation = "https://docs.rs/ntex-mqtt"
Expand All @@ -9,12 +9,12 @@ categories = ["network-programming"]
keywords = ["MQTT", "IoT", "messaging"]
license = "MIT"
exclude = [".gitignore", ".travis.yml", ".cargo/config"]
edition = "2018"
edition = "2021"

[dependencies]
ntex = "0.7.3"
ntex-util = "0.3.1"
bitflags = "1.3"
ntex = "0.7.4"
ntex-util = "0.3.2"
bitflags = "2.4"
log = "0.4"
pin-project-lite = "0.2"
serde = { version = "1.0", features = ["derive"] }
Expand All @@ -23,12 +23,12 @@ thiserror = "1.0"

[dev-dependencies]
env_logger = "0.10"
ntex-tls = "0.3.0"
ntex-tls = "0.3.1"
rustls = "0.21"
rustls-pemfile = "1.0"
openssl = "0.10"
ntex = { version = "0.7.3", features = ["tokio", "rustls", "openssl"] }
test-case = "3"
ntex = { version = "0.7.4", features = ["tokio", "rustls", "openssl"] }
test-case = "3.2"

[profile.dev]
lto = "off" # cannot build tests with "thin"
Expand Down
32 changes: 23 additions & 9 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,29 @@ pub enum MqttError<E> {
/// Publish handler service error
#[error("Service error")]
Service(E),
/// Handshake error
#[error("Mqtt handshake error: {}", _0)]
Handshake(#[from] HandshakeError<E>),
}

/// Errors which can occur during mqtt connection handshake.
#[derive(Debug, thiserror::Error)]

Check warning on line 19 in src/error.rs

View check run for this annotation

Codecov / codecov/patch

src/error.rs#L19

Added line #L19 was not covered by tests
pub enum HandshakeError<E> {
/// Handshake service error
#[error("Handshake service error")]
Service(E),
/// Protocol error
#[error("Mqtt protocol error: {}", _0)]
Protocol(#[from] ProtocolError),
/// Handshake timeout
#[error("Handshake timeout")]
HandshakeTimeout,
Timeout,
/// Peer disconnect
#[error("Peer is disconnected, error: {:?}", _0)]
Disconnected(Option<io::Error>),
/// Server error
#[error("Server error: {}", _0)]
ServerError(&'static str),
Server(&'static str),
}

/// Protocol level errors
Expand Down Expand Up @@ -54,6 +65,7 @@ enum ViolationInner {
#[error("{message}; received packet with type `{packet_type:b}`")]
UnexpectedPacket { packet_type: u8, message: &'static str },
}

impl ProtocolViolationError {
pub(crate) fn reason(&self) -> DisconnectReasonCode {
match self.inner {
Expand Down Expand Up @@ -87,30 +99,32 @@ impl ProtocolError {

impl<E> From<io::Error> for MqttError<E> {
fn from(err: io::Error) -> Self {
MqttError::Disconnected(Some(err))
MqttError::Handshake(HandshakeError::Disconnected(Some(err)))

Check warning on line 102 in src/error.rs

View check run for this annotation

Codecov / codecov/patch

src/error.rs#L102

Added line #L102 was not covered by tests
}
}

impl<E> From<Either<io::Error, io::Error>> for MqttError<E> {
fn from(err: Either<io::Error, io::Error>) -> Self {
MqttError::Disconnected(Some(err.into_inner()))
MqttError::Handshake(HandshakeError::Disconnected(Some(err.into_inner())))

Check warning on line 108 in src/error.rs

View check run for this annotation

Codecov / codecov/patch

src/error.rs#L108

Added line #L108 was not covered by tests
}
}

impl<E> From<Either<DecodeError, io::Error>> for MqttError<E> {
impl<E> From<Either<DecodeError, io::Error>> for HandshakeError<E> {
fn from(err: Either<DecodeError, io::Error>) -> Self {
match err {
Either::Left(err) => MqttError::Protocol(ProtocolError::Decode(err)),
Either::Right(err) => MqttError::Disconnected(Some(err)),
Either::Left(err) => HandshakeError::Protocol(ProtocolError::Decode(err)),
Either::Right(err) => HandshakeError::Disconnected(Some(err)),

Check warning on line 116 in src/error.rs

View check run for this annotation

Codecov / codecov/patch

src/error.rs#L115-L116

Added lines #L115 - L116 were not covered by tests
}
}
}

impl<E> From<Either<EncodeError, io::Error>> for MqttError<E> {
fn from(err: Either<EncodeError, io::Error>) -> Self {
match err {
Either::Left(err) => MqttError::Protocol(ProtocolError::Encode(err)),
Either::Right(err) => MqttError::Disconnected(Some(err)),
Either::Left(err) => {
MqttError::Handshake(HandshakeError::Protocol(ProtocolError::Encode(err)))

Check warning on line 125 in src/error.rs

View check run for this annotation

Codecov / codecov/patch

src/error.rs#L124-L125

Added lines #L124 - L125 were not covered by tests
}
Either::Right(err) => MqttError::Handshake(HandshakeError::Disconnected(Some(err))),

Check warning on line 127 in src/error.rs

View check run for this annotation

Codecov / codecov/patch

src/error.rs#L127

Added line #L127 was not covered by tests
}
}
}
Expand Down
29 changes: 17 additions & 12 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use ntex::time::{Deadline, Millis, Seconds};
use ntex::util::{join, ready, BoxFuture, Ready};

use crate::version::{ProtocolVersion, VersionCodec};
use crate::{error::MqttError, v3, v5};
use crate::{error::HandshakeError, error::MqttError, v3, v5};

/// Mqtt Server
pub struct MqttServer<V3, V5, Err, InitErr> {
Expand Down Expand Up @@ -437,7 +437,11 @@ where
MqttServerImplStateProject::Version { ref mut item } => {
match item.as_mut().unwrap().2.poll_elapsed(cx) {
Poll::Pending => (),
Poll::Ready(_) => return Poll::Ready(Err(MqttError::HandshakeTimeout)),
Poll::Ready(_) => {
return Poll::Ready(Err(MqttError::Handshake(
HandshakeError::Timeout,
)))

Check warning on line 443 in src/server.rs

View check run for this annotation

Codecov / codecov/patch

src/server.rs#L441-L443

Added lines #L441 - L443 were not covered by tests
}
}

let st = item.as_mut().unwrap();
Expand All @@ -458,16 +462,17 @@ where
unreachable!()
}
Err(RecvError::WriteBackpressure) => {
ready!(st.0.poll_flush(cx, false))
.map_err(|e| MqttError::Disconnected(Some(e)))?;
ready!(st.0.poll_flush(cx, false)).map_err(|e| {
MqttError::Handshake(HandshakeError::Disconnected(Some(e)))
})?;

Check warning on line 467 in src/server.rs

View check run for this annotation

Codecov / codecov/patch

src/server.rs#L465-L467

Added lines #L465 - L467 were not covered by tests
continue;
}
Err(RecvError::Decoder(err)) => {
Poll::Ready(Err(MqttError::Protocol(err.into())))
}
Err(RecvError::PeerGone(err)) => {
Poll::Ready(Err(MqttError::Disconnected(err)))
}
Err(RecvError::Decoder(err)) => Poll::Ready(Err(MqttError::Handshake(
HandshakeError::Protocol(err.into()),
))),
Err(RecvError::PeerGone(err)) => Poll::Ready(Err(
MqttError::Handshake(HandshakeError::Disconnected(err)),
)),

Check warning on line 475 in src/server.rs

View check run for this annotation

Codecov / codecov/patch

src/server.rs#L470-L475

Added lines #L470 - L475 were not covered by tests
};
}
}
Expand Down Expand Up @@ -504,9 +509,9 @@ impl<Err, InitErr> Service<(IoBoxed, Deadline)> for DefaultProtocolServer<Err, I
type Future<'f> = Ready<Self::Response, Self::Error> where Self: 'f;

fn call<'a>(&'a self, _: (IoBoxed, Deadline), _: ServiceCtx<'a, Self>) -> Self::Future<'a> {
Ready::Err(MqttError::Disconnected(Some(io::Error::new(
Ready::Err(MqttError::Handshake(HandshakeError::Disconnected(Some(io::Error::new(

Check warning on line 512 in src/server.rs

View check run for this annotation

Codecov / codecov/patch

src/server.rs#L512

Added line #L512 was not covered by tests
io::ErrorKind::Other,
format!("Protocol is not supported: {:?}", self.ver),
))))
)))))

Check warning on line 515 in src/server.rs

View check run for this annotation

Codecov / codecov/patch

src/server.rs#L515

Added line #L515 was not covered by tests
}
}
2 changes: 2 additions & 0 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ prim_enum! {
}

bitflags::bitflags! {
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]

Check warning on line 35 in src/types.rs

View check run for this annotation

Codecov / codecov/patch

src/types.rs#L35

Added line #L35 was not covered by tests
pub struct ConnectFlags: u8 {
const USERNAME = 0b1000_0000;
const PASSWORD = 0b0100_0000;
Expand All @@ -43,6 +44,7 @@ bitflags::bitflags! {
}

bitflags::bitflags! {
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]

Check warning on line 47 in src/types.rs

View check run for this annotation

Codecov / codecov/patch

src/types.rs#L47

Added line #L47 was not covered by tests
pub struct ConnectAckFlags: u8 {
const SESSION_PRESENT = 0b0000_0001;
}
Expand Down
6 changes: 1 addition & 5 deletions src/v3/client/connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,11 +205,7 @@ where
}

async fn _connect(&self) -> Result<Client, ClientError<codec::ConnectAck>> {
let io: IoBoxed = self
.connector
.call(Connect::new(self.address.clone()))
.await?
.into();
let io: IoBoxed = self.connector.call(Connect::new(self.address.clone())).await?.into();
let pkt = self.pkt.clone();
let max_send = self.max_send;
let max_receive = self.max_receive;
Expand Down
29 changes: 17 additions & 12 deletions src/v3/client/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ use ntex::io::DispatchItem;
use ntex::service::{Pipeline, Service, ServiceCall, ServiceCtx};
use ntex::util::{inflight::InFlightService, BoxFuture, Either, HashSet, Ready};

use crate::error::{HandshakeError, MqttError, ProtocolError};
use crate::v3::shared::{Ack, MqttShared};
use crate::v3::{codec, control::ControlResultKind, publish::Publish};
use crate::{error::MqttError, error::ProtocolError};

use super::control::{ControlMessage, ControlResult};

Expand Down Expand Up @@ -90,8 +90,7 @@ where
self.inner.sink.close();
let inner = self.inner.clone();
*shutdown = Some(Box::pin(async move {
let _ =
Pipeline::new(&inner.control).call(ControlMessage::closed()).await;
let _ = Pipeline::new(&inner.control).call(ControlMessage::closed()).await;
}));
}

Expand Down Expand Up @@ -120,9 +119,9 @@ where
if let Some(pid) = packet_id {
if !inner.inflight.borrow_mut().insert(pid) {
log::trace!("Duplicated packet id for publish packet: {:?}", pid);
return Either::Right(Either::Left(Ready::Err(
MqttError::ServerError("Duplicated packet id for publish packet"),
)));
return Either::Right(Either::Left(Ready::Err(MqttError::Handshake(
HandshakeError::Server("Duplicated packet id for publish packet"),
))));
}
}
Either::Left(PublishResponse {
Expand All @@ -135,21 +134,27 @@ where
}
DispatchItem::Item((codec::Packet::PublishAck { packet_id }, _)) => {
if let Err(e) = self.inner.sink.pkt_ack(Ack::Publish(packet_id)) {
Either::Right(Either::Left(Ready::Err(MqttError::Protocol(e))))
Either::Right(Either::Left(Ready::Err(MqttError::Handshake(
HandshakeError::Protocol(e),
))))

Check warning on line 139 in src/v3/client/dispatcher.rs

View check run for this annotation

Codecov / codecov/patch

src/v3/client/dispatcher.rs#L137-L139

Added lines #L137 - L139 were not covered by tests
} else {
Either::Right(Either::Left(Ready::Ok(None)))
}
}
DispatchItem::Item((codec::Packet::SubscribeAck { packet_id, status }, _)) => {
if let Err(e) = self.inner.sink.pkt_ack(Ack::Subscribe { packet_id, status }) {
Either::Right(Either::Left(Ready::Err(MqttError::Protocol(e))))
Either::Right(Either::Left(Ready::Err(MqttError::Handshake(
HandshakeError::Protocol(e),
))))

Check warning on line 148 in src/v3/client/dispatcher.rs

View check run for this annotation

Codecov / codecov/patch

src/v3/client/dispatcher.rs#L146-L148

Added lines #L146 - L148 were not covered by tests
} else {
Either::Right(Either::Left(Ready::Ok(None)))
}
}
DispatchItem::Item((codec::Packet::UnsubscribeAck { packet_id }, _)) => {
if let Err(e) = self.inner.sink.pkt_ack(Ack::Unsubscribe(packet_id)) {
Either::Right(Either::Left(Ready::Err(MqttError::Protocol(e))))
Either::Right(Either::Left(Ready::Err(MqttError::Handshake(
HandshakeError::Protocol(e),
))))

Check warning on line 157 in src/v3/client/dispatcher.rs

View check run for this annotation

Codecov / codecov/patch

src/v3/client/dispatcher.rs#L155-L157

Added lines #L155 - L157 were not covered by tests
} else {
Either::Right(Either::Left(Ready::Ok(None)))
}
Expand All @@ -161,10 +166,10 @@ where
| codec::Packet::Unsubscribe { .. }),
_,
)) => Either::Right(Either::Left(Ready::Err(
ProtocolError::unexpected_packet(
HandshakeError::Protocol(ProtocolError::unexpected_packet(

Check warning on line 169 in src/v3/client/dispatcher.rs

View check run for this annotation

Codecov / codecov/patch

src/v3/client/dispatcher.rs#L169

Added line #L169 was not covered by tests
pkt.packet_type(),
"Packet of the type is not expected from server",
)
))

Check warning on line 172 in src/v3/client/dispatcher.rs

View check run for this annotation

Codecov / codecov/patch

src/v3/client/dispatcher.rs#L172

Added line #L172 was not covered by tests
.into(),
))),
DispatchItem::Item((pkt, _)) => {
Expand Down Expand Up @@ -377,7 +382,7 @@ mod tests {
))));
let err = f.await.err().unwrap();
match err {
MqttError::ServerError(msg) => {
MqttError::Handshake(HandshakeError::Server(msg)) => {
assert!(msg == "Duplicated packet id for publish packet")
}
_ => panic!(),
Expand Down
41 changes: 26 additions & 15 deletions src/v3/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use ntex::service::{self, Pipeline, Service, ServiceCall, ServiceCtx, ServiceFac
use ntex::util::buffer::{BufferService, BufferServiceError};
use ntex::util::{inflight::InFlightService, join, BoxFuture, Either, HashSet, Ready};

use crate::error::{MqttError, ProtocolError};
use crate::error::{HandshakeError, MqttError, ProtocolError};
use crate::types::QoS;

use super::control::{
Expand Down Expand Up @@ -46,19 +46,23 @@ where
let fut = join(factories.0.create(session.clone()), factories.1.create(session));
let (publish, control) = fut.await;

let publish = publish.map_err(|e| MqttError::Service(e.into()))?;
let control = control.map_err(|e| MqttError::Service(e.into()))?;
let publish =
publish.map_err(|e| MqttError::Handshake(HandshakeError::Service(e.into())))?;
let control =
control.map_err(|e| MqttError::Handshake(HandshakeError::Service(e.into())))?;

let control = BufferService::new(
16,
// limit number of in-flight messages
InFlightService::new(1, control),
)
.map_err(|err| match err {
BufferServiceError::Service(e) => MqttError::Service(E::from(e)),
BufferServiceError::RequestCanceled => {
MqttError::ServerError("Request handling has been canceled")
BufferServiceError::Service(e) => {
MqttError::Handshake(HandshakeError::Service(E::from(e)))

Check warning on line 61 in src/v3/dispatcher.rs

View check run for this annotation

Codecov / codecov/patch

src/v3/dispatcher.rs#L60-L61

Added lines #L60 - L61 were not covered by tests
}
BufferServiceError::RequestCanceled => MqttError::Handshake(
HandshakeError::Server("Request handling has been canceled"),
),

Check warning on line 65 in src/v3/dispatcher.rs

View check run for this annotation

Codecov / codecov/patch

src/v3/dispatcher.rs#L63-L65

Added lines #L63 - L65 were not covered by tests
});

Ok(
Expand Down Expand Up @@ -145,8 +149,7 @@ where
self.inner.sink.close();
let inner = self.inner.clone();
*shutdown = Some(Box::pin(async move {
let _ =
Pipeline::new(&inner.control).call(ControlMessage::closed()).await;
let _ = Pipeline::new(&inner.control).call(ControlMessage::closed()).await;
}));
}

Expand Down Expand Up @@ -256,10 +259,14 @@ where
}

if !self.inner.inflight.borrow_mut().insert(packet_id) {
log::trace!("Duplicated packet id for unsubscribe packet: {:?}", packet_id);
return Either::Right(Either::Left(Ready::Err(MqttError::ServerError(
"Duplicated packet id for unsubscribe packet",
))));
log::trace!("Duplicated packet id for subscribe packet: {:?}", packet_id);
return Either::Right(Either::Right(ControlResponse::new(
ControlMessage::proto_error(ProtocolError::generic_violation(
"Duplicated packet id for subscribe packet",
)),
&self.inner,
ctx,
)));

Check warning on line 269 in src/v3/dispatcher.rs

View check run for this annotation

Codecov / codecov/patch

src/v3/dispatcher.rs#L262-L269

Added lines #L262 - L269 were not covered by tests
}

Either::Right(Either::Right(ControlResponse::new(
Expand All @@ -284,9 +291,13 @@ where

if !self.inner.inflight.borrow_mut().insert(packet_id) {
log::trace!("Duplicated packet id for unsubscribe packet: {:?}", packet_id);
return Either::Right(Either::Left(Ready::Err(MqttError::ServerError(
"Duplicated packet id for unsubscribe packet",
))));
return Either::Right(Either::Right(ControlResponse::new(
ControlMessage::proto_error(ProtocolError::generic_violation(
"Duplicated packet id for unsubscribe packet",
)),
&self.inner,
ctx,
)));

Check warning on line 300 in src/v3/dispatcher.rs

View check run for this annotation

Codecov / codecov/patch

src/v3/dispatcher.rs#L294-L300

Added lines #L294 - L300 were not covered by tests
}

Either::Right(Either::Right(ControlResponse::new(
Expand Down
Loading

0 comments on commit 2bae4ec

Please sign in to comment.