diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 4846d0a5e046b..754d4eb0ea94f 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -28,13 +28,13 @@ use crate::joins::utils::{ calculate_join_output_ordering, get_final_indices_from_bit_map, need_produce_result_in_final, JoinHashMap, JoinHashMapType, }; -use crate::DisplayAs; use crate::{ coalesce_batches::concat_batches, coalesce_partitions::CoalescePartitionsExec, expressions::Column, expressions::PhysicalSortExpr, hash_utils::create_hashes, + joins::stream_join_utils::StreamJoinStateResult, joins::utils::{ adjust_right_output_partitioning, build_join_schema, check_join_is_valid, estimate_join_statistics, partitioned_join_output_partitioning, @@ -44,6 +44,7 @@ use crate::{ DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalExpr, RecordBatchStream, SendableRecordBatchStream, Statistics, }; +use crate::{handle_state, DisplayAs}; use super::{ utils::{OnceAsync, OnceFut}, @@ -618,15 +619,14 @@ impl ExecutionPlan for HashJoinExec { on_right, filter: self.filter.clone(), join_type: self.join_type, - left_fut, - visited_left_side: None, right: right_stream, column_indices: self.column_indices.clone(), random_state: self.random_state.clone(), join_metrics, null_equals_null: self.null_equals_null, - is_exhausted: false, reservation, + state: HashJoinStreamState::WaitBuildSide, + build_side_state: BuildSideState::Initial(BuildSideInitialState { left_fut }), })) } @@ -789,6 +789,87 @@ where Ok(()) } +/// Represents state of HashJoin build-side. +enum BuildSideState { + /// Indicates that build-side not collected yet + Initial(BuildSideInitialState), + /// Indicates that build-side data has been collected + Ready(BuildSideReadyState), +} + +struct BuildSideInitialState { + left_fut: OnceFut, +} + +struct BuildSideReadyState { + left_data: Arc, + visited_left_side: BooleanBufferBuilder, +} + +impl BuildSideState { + /// Tries to extract BuildSideInitialState from BuildSideState enum. + /// Returns an error if state is not Initial. + fn try_into_initial_mut(&mut self) -> Result<&mut BuildSideInitialState> { + match self { + BuildSideState::Initial(state) => Ok(state), + _ => Err(DataFusionError::Internal( + "Expected build side in initial state".to_string(), + )), + } + } + + /// Tries to extract BuildSideReadyState from BuildSideState enum. + /// Returns an error if state is not Ready. + fn try_into_ready(&self) -> Result<&BuildSideReadyState> { + match self { + BuildSideState::Ready(state) => Ok(state), + _ => Err(DataFusionError::Internal( + "Expected build side in ready state".to_string(), + )), + } + } + + /// Tries to extract BuildSideReadyState from BuildSideState enum. + /// Returns an error if state is not Ready. + fn try_into_ready_mut(&mut self) -> Result<&mut BuildSideReadyState> { + match self { + BuildSideState::Ready(state) => Ok(state), + _ => Err(DataFusionError::Internal( + "Expected build side in ready state".to_string(), + )), + } + } +} + +/// Represents state of HashJoinStream +enum HashJoinStreamState { + /// Initial state for HashJoinStream indicating that build-side data not collected yet + WaitBuildSide, + /// Indicates that build-side has been collected, and stream is ready for fetching probe-side + FetchProbeBatch, + /// Indicates that non-empty batch has been fetched from probe-side, and is ready to be processed + ProcessProbeBatch(ProcessProbeBatchState), + /// Indicates that probe-side has been fully processed + ExhaustedProbeSide, + /// Indicates that HashJoinStream execution is completed + Completed, +} + +struct ProcessProbeBatchState { + batch: RecordBatch, +} + +impl HashJoinStreamState { + fn try_as_process_probe_batch(&self) -> Result<&ProcessProbeBatchState> { + match self { + HashJoinStreamState::ProcessProbeBatch(state) => Ok(state), + _ => Err(DataFusionError::Internal( + "Expected hash join stream in ProcessProbeBatch state".to_string(), + )), + } + } +} + /// [`Stream`] for [`HashJoinExec`] that does the actual join. /// /// This stream: @@ -808,20 +889,10 @@ struct HashJoinStream { filter: Option, /// type of the join (left, right, semi, etc) join_type: JoinType, - /// future which builds hash table from left side - left_fut: OnceFut, - /// Which left (probe) side rows have been matches while creating output. - /// For some OUTER joins, we need to know which rows have not been matched - /// to produce the correct output. - visited_left_side: Option, /// right (probe) input right: SendableRecordBatchStream, /// Random state used for hashing initialization random_state: RandomState, - /// The join output is complete. For outer joins, this is used to - /// distinguish when the input stream is exhausted and when any unmatched - /// rows are output. - is_exhausted: bool, /// Metrics join_metrics: BuildProbeJoinMetrics, /// Information of index and left / right placement of columns @@ -830,6 +901,10 @@ struct HashJoinStream { null_equals_null: bool, /// Memory reservation reservation: MemoryReservation, + /// Stream state + state: HashJoinStreamState, + /// Build side state + build_side_state: BuildSideState, } impl RecordBatchStream for HashJoinStream { @@ -1064,24 +1139,63 @@ pub fn equal_rows_arr( impl HashJoinStream { /// Separate implementation function that unpins the [`HashJoinStream`] so - /// that partial borrows work correctly + /// that partial borrows work correctly. + /// + /// Performs state transitions for HashJoinStream according to following diagram: + /// + /// WaitBuildSide + /// │ + /// ▼ + /// ┌─► FetchProbeBatch ───► ExhaustedProbeSide ───► Completed + /// │ │ + /// │ ▼ + /// └─ ProcessProbeBatch + /// fn poll_next_impl( &mut self, cx: &mut std::task::Context<'_>, - ) -> Poll>> { + ) -> Poll>> + where + Self: Send, + { + loop { + return match self.state { + HashJoinStreamState::WaitBuildSide => { + handle_state!(ready!(self.collect_build_side(cx))) + } + HashJoinStreamState::FetchProbeBatch => { + handle_state!(ready!(self.fetch_probe_batch(cx))) + } + HashJoinStreamState::ProcessProbeBatch(_) => { + handle_state!(self.process_probe_batch()) + } + HashJoinStreamState::ExhaustedProbeSide => { + handle_state!(self.process_unmatched_build_batch()) + } + HashJoinStreamState::Completed => Poll::Ready(None), + }; + } + } + + /// Collects build-side data by polling `OnceFut` future from initial build-side state + /// + /// Updates build-side state to `Ready`, and state to `FetchProbeSide` + fn collect_build_side( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { let build_timer = self.join_metrics.build_time.timer(); // build hash table from left (build) side, if not yet done - let left_data = match ready!(self.left_fut.get(cx)) { - Ok(left_data) => left_data, - Err(e) => return Poll::Ready(Some(Err(e))), - }; + let left_data = ready!(self + .build_side_state + .try_into_initial_mut()? + .left_fut + .get_shared(cx))?; build_timer.done(); // Reserving memory for visited_left_side bitmap in case it hasn't been initialized yet // and join_type requires to store it - if self.visited_left_side.is_none() - && need_produce_result_in_final(self.join_type) - { + if need_produce_result_in_final(self.join_type) { // TODO: Replace `ceil` wrapper with stable `div_cell` after // https://github.com/rust-lang/rust/issues/88581 let visited_bitmap_size = bit_util::ceil(left_data.num_rows(), 8); @@ -1089,124 +1203,167 @@ impl HashJoinStream { self.join_metrics.build_mem_used.add(visited_bitmap_size); } - let visited_left_side = self.visited_left_side.get_or_insert_with(|| { + let visited_left_side = if need_produce_result_in_final(self.join_type) { let num_rows = left_data.num_rows(); - if need_produce_result_in_final(self.join_type) { - // Some join types need to track which row has be matched or unmatched: - // `left semi` join: need to use the bitmap to produce the matched row in the left side - // `left` join: need to use the bitmap to produce the unmatched row in the left side with null - // `left anti` join: need to use the bitmap to produce the unmatched row in the left side - // `full` join: need to use the bitmap to produce the unmatched row in the left side with null - let mut buffer = BooleanBufferBuilder::new(num_rows); - buffer.append_n(num_rows, false); - buffer - } else { - BooleanBufferBuilder::new(0) - } + // Some join types need to track which row has be matched or unmatched: + // `left semi` join: need to use the bitmap to produce the matched row in the left side + // `left` join: need to use the bitmap to produce the unmatched row in the left side with null + // `left anti` join: need to use the bitmap to produce the unmatched row in the left side + // `full` join: need to use the bitmap to produce the unmatched row in the left side with null + let mut buffer = BooleanBufferBuilder::new(num_rows); + buffer.append_n(num_rows, false); + buffer + } else { + BooleanBufferBuilder::new(0) + }; + + self.state = HashJoinStreamState::FetchProbeBatch; + self.build_side_state = BuildSideState::Ready(BuildSideReadyState { + left_data, + visited_left_side, }); + + Poll::Ready(Ok(StreamJoinStateResult::Continue)) + } + + /// Fetches next batch from probe-side + /// + /// If non-empty batch has been fetched, updates state to `ProcessProbeBatchState`, + /// otherwise updates state to `ExhaustedProbeSide` + fn fetch_probe_batch( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { + match ready!(self.right.poll_next_unpin(cx)) { + None => { + self.state = HashJoinStreamState::ExhaustedProbeSide; + } + Some(Ok(batch)) => { + self.state = + HashJoinStreamState::ProcessProbeBatch(ProcessProbeBatchState { + batch, + }); + } + Some(Err(err)) => return Poll::Ready(Err(err)), + }; + + Poll::Ready(Ok(StreamJoinStateResult::Continue)) + } + + /// Joins current probe batch with build-side data and produces batch with matched output + /// + /// Updates state to `FetchProbeBatch` + fn process_probe_batch( + &mut self, + ) -> Result>> { + let state = self.state.try_as_process_probe_batch()?; + let build_side = self.build_side_state.try_into_ready_mut()?; + + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(state.batch.num_rows()); + let timer = self.join_metrics.join_time.timer(); + let mut hashes_buffer = vec![]; - // get next right (probe) input batch - self.right - .poll_next_unpin(cx) - .map(|maybe_batch| match maybe_batch { - // one right batch in the join loop - Some(Ok(batch)) => { - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(batch.num_rows()); - let timer = self.join_metrics.join_time.timer(); - - // get the matched two indices for the on condition - let left_right_indices = build_equal_condition_join_indices( - left_data.hash_map(), - left_data.batch(), - &batch, - &self.on_left, - &self.on_right, - &self.random_state, - self.null_equals_null, - &mut hashes_buffer, - self.filter.as_ref(), - JoinSide::Left, - None, - ); - - let result = match left_right_indices { - Ok((left_side, right_side)) => { - // set the left bitmap - // and only left, full, left semi, left anti need the left bitmap - if need_produce_result_in_final(self.join_type) { - left_side.iter().flatten().for_each(|x| { - visited_left_side.set_bit(x as usize, true); - }); - } - - // adjust the two side indices base on the join type - let (left_side, right_side) = adjust_indices_by_join_type( - left_side, - right_side, - batch.num_rows(), - self.join_type, - ); - - let result = build_batch_from_indices( - &self.schema, - left_data.batch(), - &batch, - &left_side, - &right_side, - &self.column_indices, - JoinSide::Left, - ); - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - Some(result) - } - Err(err) => Some(exec_err!( - "Fail to build join indices in HashJoinExec, error:{err}" - )), - }; - timer.done(); - result - } - None => { - let timer = self.join_metrics.join_time.timer(); - if need_produce_result_in_final(self.join_type) && !self.is_exhausted - { - // use the global left bitmap to produce the left indices and right indices - let (left_side, right_side) = get_final_indices_from_bit_map( - visited_left_side, - self.join_type, - ); - let empty_right_batch = - RecordBatch::new_empty(self.right.schema()); - // use the left and right indices to produce the batch result - let result = build_batch_from_indices( - &self.schema, - left_data.batch(), - &empty_right_batch, - &left_side, - &right_side, - &self.column_indices, - JoinSide::Left, - ); - - if let Ok(ref batch) = result { - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(batch.num_rows()); - - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - } - timer.done(); - self.is_exhausted = true; - Some(result) - } else { - // end of the join loop - None - } + // get the matched two indices for the on condition + let left_right_indices = build_equal_condition_join_indices( + build_side.left_data.hash_map(), + build_side.left_data.batch(), + &state.batch, + &self.on_left, + &self.on_right, + &self.random_state, + self.null_equals_null, + &mut hashes_buffer, + self.filter.as_ref(), + JoinSide::Left, + None, + ); + + let result = match left_right_indices { + Ok((left_side, right_side)) => { + // set the left bitmap + // and only left, full, left semi, left anti need the left bitmap + if need_produce_result_in_final(self.join_type) { + left_side.iter().flatten().for_each(|x| { + build_side.visited_left_side.set_bit(x as usize, true); + }); } - Some(err) => Some(err), - }) + + // adjust the two side indices base on the join type + let (left_side, right_side) = adjust_indices_by_join_type( + left_side, + right_side, + state.batch.num_rows(), + self.join_type, + ); + + let result = build_batch_from_indices( + &self.schema, + build_side.left_data.batch(), + &state.batch, + &left_side, + &right_side, + &self.column_indices, + JoinSide::Left, + ); + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(state.batch.num_rows()); + result + } + Err(err) => { + exec_err!("Fail to build join indices in HashJoinExec, error:{err}") + } + }; + timer.done(); + + self.state = HashJoinStreamState::FetchProbeBatch; + + Ok(StreamJoinStateResult::Ready(Some(result?))) + } + + /// Processes unmatched build-side rows for certain join types and produces output batch is required + /// + /// Updates state to `Completed` + fn process_unmatched_build_batch( + &mut self, + ) -> Result>> { + let timer = self.join_metrics.join_time.timer(); + + if !need_produce_result_in_final(self.join_type) { + self.state = HashJoinStreamState::Completed; + + return Ok(StreamJoinStateResult::Continue); + } + + let build_side = self.build_side_state.try_into_ready()?; + + // use the global left bitmap to produce the left indices and right indices + let (left_side, right_side) = + get_final_indices_from_bit_map(&build_side.visited_left_side, self.join_type); + let empty_right_batch = RecordBatch::new_empty(self.right.schema()); + // use the left and right indices to produce the batch result + let result = build_batch_from_indices( + &self.schema, + build_side.left_data.batch(), + &empty_right_batch, + &left_side, + &right_side, + &self.column_indices, + JoinSide::Left, + ); + + if let Ok(ref batch) = result { + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); + + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + } + timer.done(); + + self.state = HashJoinStreamState::Completed; + + Ok(StreamJoinStateResult::Ready(Some(result?))) } } diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 5e01ca227cf5a..3a047a8d5b937 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -849,6 +849,23 @@ impl OnceFut { ), } } + + /// Get the result of the computation as a shared reference if it is ready, without consuming it + pub(crate) fn get_shared(&mut self, cx: &mut Context<'_>) -> Poll>> { + if let OnceFutState::Pending(fut) = &mut self.state { + let r = ready!(fut.poll_unpin(cx)); + self.state = OnceFutState::Ready(r); + } + + // Cannot use loop as this would trip up the borrow checker + match &self.state { + OnceFutState::Pending(_) => unreachable!(), + OnceFutState::Ready(r) => Poll::Ready( + r.clone() + .map_err(|e| DataFusionError::External(Box::new(e.clone()))), + ), + } + } } /// Some type `join_type` of join need to maintain the matched indices bit map for the left side, and