diff --git a/datafusion/common/src/utils/proxy.rs b/datafusion/common/src/utils/proxy.rs index 5d14a1517129..b32164f682fa 100644 --- a/datafusion/common/src/utils/proxy.rs +++ b/datafusion/common/src/utils/proxy.rs @@ -17,7 +17,10 @@ //! [`VecAllocExt`] and [`RawTableAllocExt`] to help tracking of memory allocations -use hashbrown::raw::{Bucket, RawTable}; +use hashbrown::{ + hash_table::HashTable, + raw::{Bucket, RawTable}, +}; use std::mem::size_of; /// Extension trait for [`Vec`] to account for allocations. @@ -173,3 +176,71 @@ impl RawTableAllocExt for RawTable { } } } + +/// Extension trait for hash browns [`HashTable`] to account for allocations. +pub trait HashTableAllocExt { + /// Item type. + type T; + + /// Insert new element into table and increase + /// `accounting` by any newly allocated bytes. + /// + /// Returns the bucket where the element was inserted. + /// Note that allocation counts capacity, not size. + /// + /// # Example: + /// ``` + /// # use datafusion_common::utils::proxy::HashTableAllocExt; + /// # use hashbrown::hash_table::HashTable; + /// let mut table = HashTable::new(); + /// let mut allocated = 0; + /// let hash_fn = |x: &u32| (*x as u64) % 1000; + /// // pretend 0x3117 is the hash value for 1 + /// table.insert_accounted(1, hash_fn, &mut allocated); + /// assert_eq!(allocated, 64); + /// + /// // insert more values + /// for i in 0..100 { table.insert_accounted(i, hash_fn, &mut allocated); } + /// assert_eq!(allocated, 400); + /// ``` + fn insert_accounted( + &mut self, + x: Self::T, + hasher: impl Fn(&Self::T) -> u64, + accounting: &mut usize, + ); +} + +impl HashTableAllocExt for HashTable +where + T: Eq, +{ + type T = T; + + fn insert_accounted( + &mut self, + x: Self::T, + hasher: impl Fn(&Self::T) -> u64, + accounting: &mut usize, + ) { + let hash = hasher(&x); + + // NOTE: `find_entry` does NOT grow! + match self.find_entry(hash, |y| y == &x) { + Ok(_occupied) => {} + Err(_absent) => { + if self.len() == self.capacity() { + // need to request more memory + let bump_elements = self.capacity().max(16); + let bump_size = bump_elements * size_of::(); + *accounting = (*accounting).checked_add(bump_size).expect("overflow"); + + self.reserve(bump_elements, &hasher); + } + + // still need to insert the element since first try failed + self.entry(hash, |y| y == &x, hasher).insert(x); + } + } + } +} diff --git a/datafusion/execution/src/memory_pool/mod.rs b/datafusion/execution/src/memory_pool/mod.rs index 5bf30b724d0b..45d467f133bf 100644 --- a/datafusion/execution/src/memory_pool/mod.rs +++ b/datafusion/execution/src/memory_pool/mod.rs @@ -23,7 +23,9 @@ use std::{cmp::Ordering, sync::Arc}; mod pool; pub mod proxy { - pub use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt}; + pub use datafusion_common::utils::proxy::{ + HashTableAllocExt, RawTableAllocExt, VecAllocExt, + }; } pub use pool::*; diff --git a/datafusion/physical-expr-common/src/binary_map.rs b/datafusion/physical-expr-common/src/binary_map.rs index 59280a3abbdb..8febbdd5b1f9 100644 --- a/datafusion/physical-expr-common/src/binary_map.rs +++ b/datafusion/physical-expr-common/src/binary_map.rs @@ -28,7 +28,7 @@ use arrow::array::{ use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; use arrow::datatypes::DataType; use datafusion_common::hash_utils::create_hashes; -use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt}; +use datafusion_common::utils::proxy::{HashTableAllocExt, VecAllocExt}; use std::any::type_name; use std::fmt::Debug; use std::mem::{size_of, swap}; @@ -215,7 +215,7 @@ where /// Should the output be String or Binary? output_type: OutputType, /// Underlying hash set for each distinct value - map: hashbrown::raw::RawTable>, + map: hashbrown::hash_table::HashTable>, /// Total size of the map in bytes map_size: usize, /// In progress arrow `Buffer` containing all values @@ -246,7 +246,7 @@ where pub fn new(output_type: OutputType) -> Self { Self { output_type, - map: hashbrown::raw::RawTable::with_capacity(INITIAL_MAP_CAPACITY), + map: hashbrown::hash_table::HashTable::with_capacity(INITIAL_MAP_CAPACITY), map_size: 0, buffer: BufferBuilder::new(INITIAL_BUFFER_CAPACITY), offsets: vec![O::default()], // first offset is always 0 @@ -387,7 +387,7 @@ where let inline = value.iter().fold(0usize, |acc, &x| acc << 8 | x as usize); // is value is already present in the set? - let entry = self.map.get_mut(hash, |header| { + let entry = self.map.find_mut(hash, |header| { // compare value if hashes match if header.len != value_len { return false; @@ -425,7 +425,7 @@ where // value is not "small" else { // Check if the value is already present in the set - let entry = self.map.get_mut(hash, |header| { + let entry = self.map.find_mut(hash, |header| { // compare value if hashes match if header.len != value_len { return false; diff --git a/datafusion/physical-expr-common/src/binary_view_map.rs b/datafusion/physical-expr-common/src/binary_view_map.rs index 8af35510dd6c..4148c5ffa7c7 100644 --- a/datafusion/physical-expr-common/src/binary_view_map.rs +++ b/datafusion/physical-expr-common/src/binary_view_map.rs @@ -24,7 +24,7 @@ use arrow::array::cast::AsArray; use arrow::array::{Array, ArrayBuilder, ArrayRef, GenericByteViewBuilder}; use arrow::datatypes::{BinaryViewType, ByteViewType, DataType, StringViewType}; use datafusion_common::hash_utils::create_hashes; -use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt}; +use datafusion_common::utils::proxy::{HashTableAllocExt, VecAllocExt}; use std::fmt::Debug; use std::sync::Arc; @@ -122,7 +122,7 @@ where /// Should the output be StringView or BinaryView? output_type: OutputType, /// Underlying hash set for each distinct value - map: hashbrown::raw::RawTable>, + map: hashbrown::hash_table::HashTable>, /// Total size of the map in bytes map_size: usize, @@ -148,7 +148,7 @@ where pub fn new(output_type: OutputType) -> Self { Self { output_type, - map: hashbrown::raw::RawTable::with_capacity(INITIAL_MAP_CAPACITY), + map: hashbrown::hash_table::HashTable::with_capacity(INITIAL_MAP_CAPACITY), map_size: 0, builder: GenericByteViewBuilder::new(), random_state: RandomState::new(), @@ -274,7 +274,7 @@ where // get the value as bytes let value: &[u8] = value.as_ref(); - let entry = self.map.get_mut(hash, |header| { + let entry = self.map.find_mut(hash, |header| { let v = self.builder.get_value(header.view_idx); if v.len() != value.len() {