Skip to content

Commit

Permalink
Move ShMem persisting flag to a new constructor (#2649)
Browse files Browse the repository at this point in the history
* moving shmem persisting to take an owned value, adding test

* clean code updates

* adding imports conditionally

* fixing tests

* moving persistent mmap shmem to custom constructor

* excluding miri properly

* fixing formatting
  • Loading branch information
riesentoaster authored Nov 3, 2024
1 parent 89cff63 commit d4fbe17
Showing 1 changed file with 110 additions and 72 deletions.
182 changes: 110 additions & 72 deletions libafl_bolts/src/shmem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -838,20 +838,93 @@ pub mod unix_shmem {
pub fn filename_path(&self) -> &Option<[u8; MAX_MMAP_FILENAME_LEN]> {
&self.filename_path
}
}

impl ShMem for MmapShMem {
fn id(&self) -> ShMemId {
self.id
}
}

impl Deref for MmapShMem {
type Target = [u8];

fn deref(&self) -> &[u8] {
// # Safety
// No user-provided potentially unsafe parameters.
unsafe { slice::from_raw_parts(self.map, self.map_size) }
}
}

impl DerefMut for MmapShMem {
fn deref_mut(&mut self) -> &mut [u8] {
// # Safety
// No user-provided potentially unsafe parameters.
unsafe { slice::from_raw_parts_mut(self.map, self.map_size) }
}
}

impl Drop for MmapShMem {
fn drop(&mut self) {
// # Safety
// No user-provided potentially unsafe parameters.
// Mutable borrow so no possible race.
unsafe {
assert!(
!self.map.is_null(),
"Map should never be null for MmapShMem (on Drop)"
);

munmap(self.map as *mut _, self.map_size);
self.map = ptr::null_mut();

assert!(
self.shm_fd != -1,
"FD should never be -1 for MmapShMem (on Drop)"
);

// None in case we didn't [`shm_open`] this ourselves, but someone sent us the FD.
// log::info!("Dropping {:#?}", self.filename_path);
// if let Some(filename_path) = self.filename_path {
// shm_unlink(filename_path.as_ptr() as *const _);
// }
// We cannot shm_unlink here!
// unlike unix common shmem we don't have refcounter.
// so there's no guarantee that there's no other process still using it.
}
}
}

/// If called, the shared memory will also be available in subprocesses.
/// A [`ShMemProvider`] which uses [`shm_open`] and [`mmap`] to provide shared memory mappings.
#[cfg(unix)]
#[derive(Clone, Debug)]
pub struct MmapShMemProvider {}

impl MmapShMemProvider {
/// Creates a new shared memory mapping, which is available in other processes.
///
/// Only available on UNIX systems at the moment.
///
/// You likely want to pass the [`crate::shmem::ShMemDescription`] and reopen the shared memory in the child process using [`crate::shmem::ShMemProvider::shmem_from_description`].
/// You likely want to pass the [`crate::shmem::ShMemDescription`] of the returned [`ShMem`]
/// and reopen the shared memory in the child process using [`crate::shmem::ShMemProvider::shmem_from_description`].
///
/// # Errors
///
/// This function will return an error if the appropriate flags could not be extracted or set.
pub fn persist_for_child_processes(&self) -> Result<&Self, Error> {
#[cfg(any(unix, doc))]
pub fn new_shmem_persistent(
&mut self,
map_size: usize,
) -> Result<<Self as ShMemProvider>::ShMem, Error> {
let shmem = self.new_shmem(map_size)?;

let fd = shmem.shm_fd;

// # Safety
// No user-provided potentially unsafe parameters.
// FFI Calls.
unsafe {
let flags = fcntl(self.shm_fd, libc::F_GETFD);
let flags = fcntl(fd, libc::F_GETFD);

if flags == -1 {
return Err(Error::os_error(
Expand All @@ -860,23 +933,17 @@ pub mod unix_shmem {
));
}

if fcntl(self.shm_fd, libc::F_SETFD, flags & !libc::FD_CLOEXEC) == -1 {
if fcntl(fd, libc::F_SETFD, flags & !libc::FD_CLOEXEC) == -1 {
return Err(Error::os_error(
io::Error::last_os_error(),
"Failed to set FD flags",
));
}
}

Ok(self)
Ok(shmem)
}
}

/// A [`ShMemProvider`] which uses [`shm_open`] and [`mmap`] to provide shared memory mappings.
#[cfg(unix)]
#[derive(Clone, Debug)]
pub struct MmapShMemProvider {}

unsafe impl Send for MmapShMemProvider {}

#[cfg(unix)]
Expand Down Expand Up @@ -919,61 +986,6 @@ pub mod unix_shmem {
}
}

impl ShMem for MmapShMem {
fn id(&self) -> ShMemId {
self.id
}
}

impl Deref for MmapShMem {
type Target = [u8];

fn deref(&self) -> &[u8] {
// # Safety
// No user-provided potentially unsafe parameters.
unsafe { slice::from_raw_parts(self.map, self.map_size) }
}
}

impl DerefMut for MmapShMem {
fn deref_mut(&mut self) -> &mut [u8] {
// # Safety
// No user-provided potentially unsafe parameters.
unsafe { slice::from_raw_parts_mut(self.map, self.map_size) }
}
}

impl Drop for MmapShMem {
fn drop(&mut self) {
// # Safety
// No user-provided potentially unsafe parameters.
// Mutable borrow so no possible race.
unsafe {
assert!(
!self.map.is_null(),
"Map should never be null for MmapShMem (on Drop)"
);

munmap(self.map as *mut _, self.map_size);
self.map = ptr::null_mut();

assert!(
self.shm_fd != -1,
"FD should never be -1 for MmapShMem (on Drop)"
);

// None in case we didn't [`shm_open`] this ourselves, but someone sent us the FD.
// log::info!("Dropping {:#?}", self.filename_path);
// if let Some(filename_path) = self.filename_path {
// shm_unlink(filename_path.as_ptr() as *const _);
// }
// We cannot shm_unlink here!
// unlike unix common shmem we don't have refcounter.
// so there's no guarantee that there's no other process still using it.
}
}
}

/// The default sharedmap impl for unix using shmctl & shmget
#[derive(Clone, Debug)]
pub struct CommonUnixShMem {
Expand Down Expand Up @@ -1622,16 +1634,42 @@ mod tests {

use crate::{
shmem::{ShMemProvider, StdShMemProvider},
AsSlice, AsSliceMut,
AsSlice, AsSliceMut, Error,
};

#[test]
#[serial]
#[cfg_attr(miri, ignore)]
fn test_shmem_service() {
let mut provider = StdShMemProvider::new().unwrap();
let mut map = provider.new_shmem(1024).unwrap();
fn test_shmem_service() -> Result<(), Error> {
let mut provider = StdShMemProvider::new()?;
let mut map = provider.new_shmem(1024)?;
map.as_slice_mut()[0] = 1;
assert!(map.as_slice()[0] == 1);
assert_eq!(1, map.as_slice()[0]);
Ok(())
}

#[test]
#[cfg(all(unix, not(miri)))]
#[cfg_attr(miri, ignore)]
fn test_persist_shmem() -> Result<(), Error> {
use std::thread;

use crate::shmem::{MmapShMemProvider, ShMem as _};

let mut provider = MmapShMemProvider::new()?;
let mut shmem = provider.new_shmem_persistent(1)?;
shmem.fill(0);

let description = shmem.description();

let handle = thread::spawn(move || -> Result<(), Error> {
let mut provider = MmapShMemProvider::new()?;
let mut shmem = provider.shmem_from_description(description)?;
shmem.as_slice_mut()[0] = 1;
Ok(())
});
handle.join().unwrap()?;
assert_eq!(1, shmem.as_slice()[0]);
Ok(())
}
}

0 comments on commit d4fbe17

Please sign in to comment.