Skip to content

Commit

Permalink
Make Message::write_out safe
Browse files Browse the repository at this point in the history
I initially thought it was sound because we hold a mutable reference to
prevent the Message being changed under us, but interior mutability
can violate that.
  • Loading branch information
bmerry committed Mar 28, 2024
1 parent 8cf3f69 commit 79ba67c
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 43 deletions.
96 changes: 57 additions & 39 deletions src/format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ where
/// `target` must not be empty.
#[inline]
#[must_use]
unsafe fn append_byte(target: Out<[u8]>, value: u8) -> Out<[u8]> {
fn append_byte(target: Out<[u8]>, value: u8) -> Out<[u8]> {
let (prefix, suffix) = target.split_at_out(1);
prefix.get_unchecked_out(0).write(value);
prefix.get_out(0).unwrap().write(value);
suffix
}

Expand All @@ -90,10 +90,13 @@ where

/// Write the message into a buffer.
///
/// # Safety
/// It returns any unused part of the buffer.
///
/// # Panics
///
/// The target must have size of at least [write_size](Self::write_size).
pub unsafe fn write_out(&self, mut target: Out<[u8]>) {
/// This will panic if the target is smaller than the value returned by
/// [write_size](Self::write_size).
pub fn write_out<'a>(&self, mut target: Out<'a, [u8]>) -> Out<'a, [u8]> {
target = Self::append_byte(target, Self::type_symbol(self.mtype));
target = Self::append_bytes(target, self.name.as_ref());
if let Some(mid) = self.mid {
Expand All @@ -120,7 +123,7 @@ where
}
}
}
let _ = Self::append_byte(target, b'\n');
Self::append_byte(target, b'\n')
}

/// Get the number of bytes needed by [write_out](Self::write_out).
Expand Down Expand Up @@ -150,42 +153,15 @@ where
bytes.0
}

/// Get the size and a callback to write the message.
///
/// The callback panics if the provided buffer doesn't match the returned
/// size.
///
/// # Example
///
/// ```
/// use uninit::prelude::*;
/// # use _lib::message::{Message, MessageType};
/// # let message: Message<&[u8], &[u8]> = Message::new(MessageType::Request, &b""[..], None, vec![]);
///
/// let (size, callback) = message.write_size_callback();
/// let mut out = vec![0u8; size];
/// callback(out.as_out());
/// ```
pub fn write_size_callback(&self) -> (usize, impl Fn(Out<[u8]>) + '_) {
let size = self.write_size();
let callback = move |out: Out<[u8]>| {
if out.len() != size {
panic!("Buffer has the wrong size");
}
// SAFETY: this lambda captures &self, so the length cannot change.
unsafe {
self.write_out(out);
}
};
(size, callback)
}

/// Encode the message to a [Vec]
pub fn to_vec(&self) -> Vec<u8> {
let (size, callback) = self.write_size_callback();
let size = self.write_size();
let mut vec = Vec::with_capacity(size);
callback(vec.get_backing_buffer());
// SAFETY: we've used the callback to initialize all elements.
let remain = self.write_out(vec.get_backing_buffer());
if !remain.is_empty() {
panic!("Size of message changed during formatting.");
}
// SAFETY: we've verified that write_out initialized all elements.
unsafe {
vec.set_len(size);
}
Expand All @@ -197,6 +173,9 @@ where
mod test {
use super::*;

use rstest::*;
use std::cell::Cell;

/// Create a Message that requires more than usize bytes.
#[test]
#[should_panic(expected = "message size should not exceed usize::MAX")]
Expand All @@ -223,4 +202,43 @@ mod test {
Message::new(MessageType::Request, &b"big message"[..], None, arguments);
message.write_size();
}

/// Evil Message that uses interior mutability to change length dynamically
#[rstest]
#[case(100)]
#[case(-100)]
#[should_panic]
fn change_size(#[case] delta: isize) {
#[derive(Clone)]
struct EvilData {
length: Cell<isize>,
delta: isize,
}

impl EvilData {
fn new(initial: isize, delta: isize) -> Self {
EvilData {
length: Cell::new(initial),
delta,
}
}
}

impl AsRef<[u8]> for EvilData {
fn as_ref(&self) -> &[u8] {
// Change by delta every call
let cur = self.length.get();
self.length.set(cur + self.delta);
return &[b'x'; 10000][..cur as usize];
}
}

let message: Message<EvilData, &[u8]> = Message::new(
MessageType::Request,
EvilData::new(5000, delta),
None,
Vec::new(),
);
let _ = message.to_vec();
}
}
15 changes: 11 additions & 4 deletions src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

//! The basic katcp message type

use pyo3::exceptions::PyValueError;
use pyo3::exceptions::{PyRuntimeError, PyValueError};
use pyo3::gc::PyVisit;
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedBytes;
Expand Down Expand Up @@ -169,10 +169,17 @@ impl PyMessage {
mid: self.mid,
arguments,
};
let (size, callback) = message.write_size_callback();
let size = message.write_size();
PyBytes::new_bound_with(py, size, |bytes: &mut [u8]| {
callback(bytes.as_out());
Ok(())
let remain = message.write_out(bytes.as_out());
if !remain.is_empty() {
// This should be unreachable, because we hold the GIL.
Err(PyRuntimeError::new_err(
"Message changed size during formatting",
))
} else {
Ok(())
}
})
}
}
Expand Down

0 comments on commit 79ba67c

Please sign in to comment.