Skip to content

Commit

Permalink
fix: prevent 2 native connections on same jid
Browse files Browse the repository at this point in the history
  • Loading branch information
valeriansaliou committed Aug 3, 2024
1 parent 7ad8b66 commit 8685775
Showing 1 changed file with 30 additions and 4 deletions.
34 changes: 30 additions & 4 deletions src-tauri/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

use futures::stream::{SplitSink, SplitStream, StreamExt};
use futures::SinkExt;
use jid::FullJid;
use jid::{BareJid, FullJid};
use log::{debug, error, info, warn};
use serde::Serialize;
use std::collections::HashMap;
Expand Down Expand Up @@ -52,6 +52,8 @@ pub enum ConnectionState {
pub enum ConnectError {
#[error("Invalid JID, cannot connect")]
InvalidJid,
#[error("Another connection is bound on the JID")]
AnotherConnectionBound,
#[error("Connection identifier already exists")]
ConnectionAlreadyExists,
}
Expand Down Expand Up @@ -87,6 +89,7 @@ pub enum PollOutputError {
* ************************************************************************* */

struct ConnectionClient {
jid: BareJid,
sender: UnboundedSender<Packet>,
read_handle: JoinHandle<()>,
write_handle: JoinHandle<()>,
Expand Down Expand Up @@ -282,16 +285,38 @@ pub fn connect<R: Runtime>(
) -> Result<(), ConnectError> {
info!("Connection #{} connect requested on JID: {}", id, jid);

// Parse JID
let jid_full = FullJid::new(jid).or(Err(ConnectError::InvalidJid))?;
let jid_bare = jid_full.to_bare();

// Assert that connection identifier does not already exist
if state.connections.read().unwrap().contains_key(id) {
return Err(ConnectError::ConnectionAlreadyExists);
}

// Parse JID
let jid = FullJid::new(jid).or(Err(ConnectError::InvalidJid))?;
// Assert that another connection with this JID does not already exist in \
// the global state. This prevents connection manager mis-uses where the \
// implementor client would request multiple parallel connections on the \
// same JID.
{
// Scan all connections in the state
let state_connections = state.connections.read().unwrap();

for (connection_id, connection) in (&*state_connections).into_iter() {
// Found another active connection in the state on the same JID?
if jid_bare == connection.jid {
error!(
"Connection #{} connect request found to conflict with: #{}",
id, connection_id
);

return Err(ConnectError::AnotherConnectionBound);
}
}
};

// Create new client
let mut client = Client::new(jid, password);
let mut client = Client::new(jid_full, password);

// Connections are single-use only
client.set_reconnect(false);
Expand Down Expand Up @@ -340,6 +365,7 @@ pub fn connect<R: Runtime>(
state_connections.insert(
id.to_string(),
ConnectionClient {
jid: jid_bare,
sender: tx,
read_handle,
write_handle,
Expand Down

0 comments on commit 8685775

Please sign in to comment.