Skip to content

Commit

Permalink
Refactor dkg state and its use from trigger
Browse files Browse the repository at this point in the history
  • Loading branch information
pool2win committed Nov 28, 2024
1 parent cb83c84 commit 45b4728
Show file tree
Hide file tree
Showing 7 changed files with 231 additions and 90 deletions.
39 changes: 8 additions & 31 deletions src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,14 @@ impl Node {
accept_ready_tx: oneshot::Sender<()>,
) {
log::debug!("Starting... {}", self.bind_address);
let node_id = self.get_node_id().clone();
let state = self.state.clone();
let echo_broadcast_handle = self.echo_broadcast_handle.clone();
// let interval = tokio::time::interval(tokio::time::Duration::from_secs(15));
tokio::spawn(async move {
dkg::trigger::run_dkg_trigger(15000, node_id, state, echo_broadcast_handle, None).await;
});

if self.connect_to_seeds().await.is_err() {
log::info!("Connecting to seeds failed.");
return;
Expand Down Expand Up @@ -214,21 +222,6 @@ impl Node {
let delivery_timeout = self.delivery_timeout;
let reliable_sender = reliable_sender_handle.clone();
initialize_handshake(node_id, state, reliable_sender, delivery_timeout).await;

let node_id = self.get_node_id().clone();
let state = self.state.clone();
let echo_broadcast_handle = self.echo_broadcast_handle.clone();
let reliable_sender_handle = reliable_sender_handle.clone();
tokio::spawn(async move {
dkg::trigger::run_dkg_trigger(
15000,
node_id,
state,
echo_broadcast_handle,
reliable_sender_handle,
)
.await;
});
}
}

Expand Down Expand Up @@ -257,22 +250,6 @@ impl Node {
self.echo_broadcast_handle.clone(),
)
.await;

let node_id = self.get_node_id().clone();
let state = self.state.clone();
let echo_broadcast_handle = self.echo_broadcast_handle.clone();
let reliable_sender_handle = reliable_sender_handle.clone();
let interval = tokio::time::interval(tokio::time::Duration::from_secs(15));
tokio::spawn(async move {
dkg::trigger::run_dkg_trigger(
15000,
node_id,
state,
echo_broadcast_handle,
reliable_sender_handle,
)
.await;
});
} else {
log::debug!("Failed to connect to seed {}", seed);
return Err("Failed to connect to seed".into());
Expand Down
7 changes: 6 additions & 1 deletion src/node/commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,12 @@ mod command_tests {
#[tokio::test]
async fn it_should_run_node_with_command_rx() {
let ctx = EchoBroadcastHandle::start_context();
ctx.expect().returning(EchoBroadcastHandle::default);
ctx.expect().returning(|| {
let mut mock = EchoBroadcastHandle::default();
mock.expect_clone()
.returning(|| EchoBroadcastHandle::default());
mock
});

let (exector, command_rx) = CommandExecutor::new();
let mut node = Node::new()
Expand Down
9 changes: 7 additions & 2 deletions src/node/protocol/dkg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,16 @@ pub(crate) mod trigger;
use crate::node::state::State;

/// Get the max and min signers for the DKG
/// Use the expected number of members in dkg state
pub(crate) async fn get_max_min_signers(state: &State) -> (usize, usize) {
let members = state.membership_handle.get_members().await.unwrap();
let num_members = members.len() + 1;
log::debug!(
"Num members in get max min signers {}",
state.dkg_state.get_expected_members().await.unwrap()
);
let num_members = state.dkg_state.get_expected_members().await.unwrap_or(0) + 1;
(num_members, (num_members * 2).div_ceil(3))
}

#[cfg(test)]
mod tests {
use frost_secp256k1 as frost;
Expand Down
8 changes: 6 additions & 2 deletions src/node/protocol/dkg/round_one.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ async fn build_round1_package(
let participant_identifier = frost::Identifier::derive(sender_id.as_bytes()).unwrap();
let rng = thread_rng();
log::debug!("SIGNERS: {} {}", max_signers, min_signers);

let result = frost::keys::dkg::part1(
participant_identifier,
max_signers as u16,
Expand Down Expand Up @@ -133,8 +134,11 @@ impl Service<Message> for Package {
}),
_message_id,
) => {
log::debug!("Received round one package");
log::info!("Received message {:?}", message);
log::info!(
"Received round one message from {} \n {:?}",
from_sender_id,
message
);
let identifier = frost::Identifier::derive(from_sender_id.as_bytes()).unwrap();
state
.dkg_state
Expand Down
154 changes: 138 additions & 16 deletions src/node/protocol/dkg/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,18 @@ impl State {
expected_members,
}
}

/// Reset the DKG state for starting a new round
/// Retain the current key package and public key package as
/// they are replaced only on successful completion of DKG
pub async fn reset(&mut self, expected_members: usize) {
self.in_progress = true;
self.received_round1_packages = Round1Map::new();
self.received_round2_packages = Round2Map::new();
self.round1_secret_package = None;
self.round2_secret_package = None;
self.expected_members = expected_members;
}
}

/// Message for state handle to actor communication
Expand Down Expand Up @@ -98,6 +110,15 @@ pub(crate) enum StateMessage {

/// Get the public key package
GetPublicKeyPackage(oneshot::Sender<Option<frost::keys::PublicKeyPackage>>),

/// Set the expected members count
SetExpectedMembers(usize, oneshot::Sender<()>),

/// Get the expected members count
GetExpectedMembers(oneshot::Sender<usize>),

/// Reset the state
ResetState(usize, oneshot::Sender<()>),
}

pub(crate) struct Actor {
Expand Down Expand Up @@ -156,6 +177,17 @@ impl Actor {
StateMessage::GetPublicKeyPackage(respond_to) => {
self.get_public_key_package(respond_to);
}
StateMessage::SetExpectedMembers(count, respond_to) => {
self.state.expected_members = count;
let _ = respond_to.send(());
}
StateMessage::GetExpectedMembers(respond_to) => {
let _ = respond_to.send(self.state.expected_members);
}
StateMessage::ResetState(expected_members, respond_to) => {
self.state.reset(expected_members).await;
let _ = respond_to.send(());
}
}
}
}
Expand All @@ -170,6 +202,10 @@ impl Actor {
.received_round1_packages
.insert(identifier, package);
let received_count = self.state.received_round1_packages.len();
log::info!(
"Received round1 packages count WHEN ADDING = {}",
received_count
);
let _ = respond_to.send(received_count == self.state.expected_members);
}

Expand Down Expand Up @@ -249,28 +285,33 @@ impl Actor {
}
}

/// Start the DKG state actor and return the sender handle
///
/// This function creates a new channel, spawns the actor task to process messages,
/// and returns the sender end of the channel.
pub(crate) fn start_dkg_actor(expected_members: Option<usize>) -> mpsc::Sender<StateMessage> {
let (sender, receiver) = mpsc::channel(1);
let mut actor = Actor::new(receiver, expected_members.unwrap_or(0));

log::debug!("Actor spawning......");
// Spawn the actor task
tokio::spawn(async move {
actor.run().await;
});

sender
}

#[derive(Clone, Debug)]
pub(crate) struct StateHandle {
sender: mpsc::Sender<StateMessage>,
pub(crate) expected_members: Option<usize>,
}

impl StateHandle {
/// Create a new state handle and spawn the actor
pub fn new(expected_members: Option<usize>) -> Self {
let (sender, receiver) = mpsc::channel(1);
// default expected members = 1, as we count ourselves
let mut actor = Actor::new(receiver, expected_members.unwrap_or(1));

// Spawn the actor task
tokio::spawn(async move {
actor.run().await;
});

Self {
sender,
expected_members,
}
let sender = start_dkg_actor(expected_members);
Self { sender }
}

/// Add round1 package to state
Expand Down Expand Up @@ -400,6 +441,40 @@ impl StateHandle {
let _ = self.sender.send(message).await;
rx.await
}

/// Set the expected members count
pub async fn set_expected_members(
&self,
count: usize,
) -> Result<(), oneshot::error::RecvError> {
let (tx, rx) = oneshot::channel();
let message = StateMessage::SetExpectedMembers(count, tx);
let _ = self.sender.send(message).await;
rx.await
}

/// Get the expected members count
pub async fn get_expected_members(&self) -> Result<usize, oneshot::error::RecvError> {
let (tx, rx) = oneshot::channel();
let message = StateMessage::GetExpectedMembers(tx);
let _ = self.sender.send(message).await;
rx.await
}

/// Reset the state
pub async fn reset_state(
&self,
expected_members: usize,
) -> Result<(), oneshot::error::RecvError> {
log::debug!(
"Resetting DKG state with expected members = {}",
expected_members
);
let (tx, rx) = oneshot::channel();
let message = StateMessage::ResetState(expected_members, tx);
let _ = self.sender.send(message).await;
rx.await
}
}

#[cfg(test)]
Expand Down Expand Up @@ -570,6 +645,28 @@ mod dkg_state_tests {
let retrieved_package = rx2.await.unwrap();
assert_eq!(retrieved_package, Some(public_key_package));
}

#[tokio::test]
async fn test_actor_expected_members() {
let (_tx, rx) = mpsc::channel(1);
let mut actor = Actor::new(rx, 3);

// Test initial value
let (tx, rx) = oneshot::channel();
actor.state.expected_members = 3;
let _ = tx.send(actor.state.expected_members);
assert_eq!(rx.await.unwrap(), 3);

// Test setting new value
let (tx, _rx) = oneshot::channel();
actor.state.expected_members = 5;
let _ = tx.send(());

// Test getting updated value
let (tx, rx) = oneshot::channel();
let _ = tx.send(actor.state.expected_members);
assert_eq!(rx.await.unwrap(), 5);
}
}

#[cfg(test)]
Expand All @@ -582,8 +679,8 @@ mod dkg_state_handle_tests {

#[tokio::test]
async fn test_state_handle_new() {
let handle = StateHandle::new(Some(0));
assert!(handle.sender.capacity() > 0);
let state_handle = StateHandle::new(Some(1));
assert!(state_handle.sender.capacity() > 0);
}

#[tokio::test]
Expand Down Expand Up @@ -860,4 +957,29 @@ mod dkg_state_handle_tests {
let received_packages = state_handle.get_received_round2_packages().await.unwrap();
assert_eq!(received_packages.len(), 2);
}

#[tokio::test]
async fn test_state_handle_expected_members() {
let state_handle = StateHandle::new(Some(3));

// Test initial value
let initial_count = state_handle.get_expected_members().await.unwrap();
assert_eq!(initial_count, 3);

// Test setting new value
assert!(state_handle.set_expected_members(5).await.is_ok());

// Test getting updated value
let updated_count = state_handle.get_expected_members().await.unwrap();
assert_eq!(updated_count, 5);
}

#[tokio::test]
async fn test_state_handle_default_expected_members() {
// Test that None defaults to 0 expected member
let state_handle = StateHandle::new(None);

let count = state_handle.get_expected_members().await.unwrap();
assert_eq!(count, 0);
}
}
Loading

0 comments on commit 45b4728

Please sign in to comment.