From 450de67d747e576b00d6cd97e0da5fe7beab57d2 Mon Sep 17 00:00:00 2001
From: Maxime Tremblay
Date: Wed, 27 Nov 2024 20:11:23 -0500
Subject: [PATCH] Import reduce naive from burn (#314)

* import reduce from burn
* remove autotune for now
* impl complete test for naive reduce
* Add line support to naive reduction and test
* clean and reorganize code and add doc
* Add comments to test
* run cargo fmt
* Fix ArgMin and ArgMax and unlock tests
* add clippy exception for comptime if
---
 crates/cubecl-reduce/src/instructions.rs |   5 +
 crates/cubecl-reduce/src/lib.rs          |   6 +-
 crates/cubecl-reduce/src/naive.rs        | 224 ++++++++++
 crates/cubecl-reduce/src/sum.rs          | 110 -----
 crates/cubecl-reduce/src/test.rs         | 530 +++++++++++++++--------
 5 files changed, 588 insertions(+), 287 deletions(-)
 create mode 100644 crates/cubecl-reduce/src/instructions.rs
 create mode 100644 crates/cubecl-reduce/src/naive.rs
 delete mode 100644 crates/cubecl-reduce/src/sum.rs

diff --git a/crates/cubecl-reduce/src/instructions.rs b/crates/cubecl-reduce/src/instructions.rs
new file mode 100644
index 00000000..ffc0eef0
--- /dev/null
+++ b/crates/cubecl-reduce/src/instructions.rs
@@ -0,0 +1,5 @@
+pub struct ReduceArgMax;
+pub struct ReduceArgMin;
+pub struct ReduceMean;
+pub struct ReduceSum;
+pub struct ReduceProd;
diff --git a/crates/cubecl-reduce/src/lib.rs b/crates/cubecl-reduce/src/lib.rs
index 5f157485..dcd0be2f 100644
--- a/crates/cubecl-reduce/src/lib.rs
+++ b/crates/cubecl-reduce/src/lib.rs
@@ -1,4 +1,8 @@
-pub mod sum;
+mod instructions;
+mod naive;
 
 #[cfg(feature = "export_tests")]
 pub mod test;
+
+pub use instructions::*;
+pub use naive::*;
diff --git a/crates/cubecl-reduce/src/naive.rs b/crates/cubecl-reduce/src/naive.rs
new file mode 100644
index 00000000..d655a847
--- /dev/null
+++ b/crates/cubecl-reduce/src/naive.rs
@@ -0,0 +1,224 @@
+use cubecl_core as cubecl;
+use cubecl_core::prelude::*;
+
+use crate::{ReduceArgMax, ReduceArgMin, ReduceMean, ReduceProd, ReduceSum};
+
+/// An instruction for the [reduce_naive](reduce_naive) algorithm.
+#[cube]
+pub trait ReduceNaiveInstruction<EI: Numeric>: Send + Sync + 'static {
+    /// The reduction accumulator.
+    /// The implementation works on lines. Most likely, the accumulator is `Line<T>`
+    /// for some `CubePrimitive` type `T` instead of simply `T`.
+    type Accumulator: CubeType;
+
+    /// Initialize the accumulator with a null value for the reduction.
+    ///
+    /// This may be called many times during the reduction. It is required
+    /// that reducing the initial accumulator any number of times does not change the outcome
+    /// of the reduction. For example, adding 0s in a sum does not change the outcome.
+    fn init_accumulator(line_size: u32) -> Self::Accumulator;
+
+    /// Reduce `current_value` into `accumulator`.
+    fn accumulate(accumulator: &mut Self::Accumulator, current_value: Line<EI>, i: u32);
+
+    /// Write the result of the reduction stored in `accumulator` into `output[index]`.
+    fn write<EO: Numeric>(
+        output: &mut Tensor<Line<EO>>,
+        accumulator: Self::Accumulator,
+        index: u32,
+        shape_reduce_dim: u32,
+    );
+}
+
+/// A naive implementation of the reduction algorithm.
+///
+/// Each thread with absolute position P is responsible
+/// for computing the reduction corresponding to index P of the `output`.
+#[cube]
+pub fn reduce_naive<RD: ReduceNaiveInstruction<EI>, EI: Numeric, EO: Numeric>(
+    input: &Tensor<Line<EI>>,
+    output: &mut Tensor<Line<EO>>,
+    dim: u32,
+) {
+    if ABSOLUTE_POS >= output.len() * output.line_size() {
+        return;
+    }
+
+    // Compute the first index where the current thread starts its reduction.
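    // (Illustrative example, not part of the original patch, assuming line_size == 1:
    //  for an input of shape [2, 3] with row-major strides [3, 1], reduced over dim == 1,
    //  the output has shape [2, 1] and strides [1, 1]. The thread with ABSOLUTE_POS == 1
    //  then maps to output coordinate (1, 0), so offset_input == 1 * 3 == 3 and the loop
    //  below visits input indices 3, 4 and 5, i.e. the second row of the input.)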
+    // First, compute the coordinate corresponding to the ABSOLUTE_POS element of the output tensor.
+    // Then, use the strides of the input tensor to find the index of the same coordinate
+    // in the input tensor.
+    let mut offset_input = 0;
+    for axis in 0..input.rank() {
+        let coordinate = (ABSOLUTE_POS / output.stride(axis)) % output.shape(axis);
+        offset_input += coordinate * input.stride(axis);
+    }
+
+    // Reduce all the lines along `dim` for the previously computed offset.
+    let mut accumulator = RD::init_accumulator(input.line_size());
+    for i in 0..input.shape(dim) {
+        let index = i * input.stride(dim) + offset_input;
+        RD::accumulate(
+            &mut accumulator,
+            unsafe { *input.index_unchecked(index) },
+            i,
+        );
+    }
+
+    // Write the local outcome into output.
+    RD::write::<EO>(output, accumulator, ABSOLUTE_POS, input.shape(dim));
+}
+
+// Implementations for common instructions.
+
+#[cube]
+impl<EI: Numeric> ReduceNaiveInstruction<EI> for ReduceSum {
+    type Accumulator = Line<EI>;
+
+    fn init_accumulator(line_size: u32) -> Line<EI> {
+        Line::empty(line_size).fill(EI::from_int(0))
+    }
+
+    fn accumulate(accumulator: &mut Self::Accumulator, current_value: Line<EI>, _i: u32) {
+        *accumulator += current_value;
+    }
+
+    fn write<EO: Numeric>(
+        output: &mut Tensor<Line<EO>>,
+        accumulator: Self::Accumulator,
+        index: u32,
+        _shape_reduce_dim: u32,
+    ) {
+        output[index] = Line::cast_from(accumulator);
+    }
+}
+
+#[cube]
+impl<EI: Numeric> ReduceNaiveInstruction<EI> for ReduceProd {
+    type Accumulator = Line<EI>;
+
+    fn init_accumulator(line_size: u32) -> Line<EI> {
+        Line::empty(line_size).fill(EI::from_int(1))
+    }
+
+    fn accumulate(accumulator: &mut Self::Accumulator, current_value: Line<EI>, _i: u32) {
+        *accumulator *= current_value;
+    }
+
+    fn write<EO: Numeric>(
+        output: &mut Tensor<Line<EO>>,
+        accumulator: Self::Accumulator,
+        index: u32,
+        _shape_reduce_dim: u32,
+    ) {
+        output[index] = Line::cast_from(accumulator);
+    }
+}
+
+#[cube]
+impl<EI: Numeric> ReduceNaiveInstruction<EI> for ReduceMean {
+    type Accumulator = Line<EI>;
+
+    fn init_accumulator(line_size: u32) -> Self::Accumulator {
+        Line::empty(line_size).fill(EI::from_int(0))
+    }
+
+    fn accumulate(accumulator: &mut Self::Accumulator, current_value: Line<EI>, _i: u32) {
+        *accumulator += current_value;
+    }
+
+    fn write<EO: Numeric>(
+        output: &mut Tensor<Line<EO>>,
+        accumulator: Self::Accumulator,
+        index: u32,
+        shape_reduce_dim: u32,
+    ) {
+        output[index] = Line::cast_from(
+            accumulator / Line::empty(output.line_size()).fill(EI::cast_from(shape_reduce_dim)),
+        );
+    }
+}
+
+#[cube]
+impl<EI: Numeric> ReduceNaiveInstruction<EI> for ReduceArgMax {
+    type Accumulator = (Line<EI>, Line<u32>);
+
+    fn init_accumulator(line_size: u32) -> Self::Accumulator {
+        (
+            // TODO: switch to using f32::NEG_INFINITY when it's supported: https://github.com/tracel-ai/cubecl/issues/68
+            Line::empty(line_size).fill(EI::MIN),
+            Line::empty(line_size).fill(0u32),
+        )
+    }
+
+    fn accumulate(accumulator: &mut Self::Accumulator, current_value: Line<EI>, i: u32) {
+        let (max, index) = accumulator;
+        #[allow(clippy::collapsible_else_if)]
+        if comptime!(current_value.size() > 1) {
+            #[unroll]
+            for k in 0..current_value.size() {
+                if current_value[k] > max[k] {
+                    max[k] = current_value[k];
+                    index[k] = i;
+                }
+            }
+        } else {
+            if current_value > *max {
+                *max = current_value;
+                *index = Line::new(i);
+            }
+        }
+    }
+
+    fn write<EO: Numeric>(
+        output: &mut Tensor<Line<EO>>,
+        accumulator: Self::Accumulator,
+        index: u32,
+        _shape_reduce_dim: u32,
+    ) {
+        let (_, position) = accumulator;
+        output[index] = Line::cast_from(position)
+    }
+}
+
+#[cube]
+impl<EI: Numeric> ReduceNaiveInstruction<EI> for ReduceArgMin {
+    type Accumulator = (Line<EI>, Line<u32>);
+
+    fn init_accumulator(line_size: u32) -> Self::Accumulator {
+        (
+            // TODO: switch to using f32::INFINITY when it's supported: https://github.com/tracel-ai/cubecl/issues/68
+            Line::empty(line_size).fill(EI::MAX),
+            Line::empty(line_size).fill(0u32),
+        )
+    }
+
+    fn accumulate(accumulator: &mut Self::Accumulator, current_value: Line<EI>, i: u32) {
+        let (min, index) = accumulator;
+        #[allow(clippy::collapsible_else_if)]
+        if comptime!(current_value.size() > 1) {
+            #[unroll]
+            for k in 0..current_value.size() {
+                if current_value[k] < min[k] {
+                    min[k] = current_value[k];
+                    index[k] = i;
+                }
+            }
+        } else {
+            if current_value < *min {
+                *min = current_value;
+                *index = Line::new(i);
+            }
+        }
+    }
+
+    fn write<EO: Numeric>(
+        output: &mut Tensor<Line<EO>>,
+        accumulator: Self::Accumulator,
+        index: u32,
+        _shape_reduce_dim: u32,
+    ) {
+        let (_, position) = accumulator;
+        output[index] = Line::cast_from(position)
+    }
+}
diff --git a/crates/cubecl-reduce/src/sum.rs b/crates/cubecl-reduce/src/sum.rs
deleted file mode 100644
index 4b964d86..00000000
--- a/crates/cubecl-reduce/src/sum.rs
+++ /dev/null
@@ -1,110 +0,0 @@
-use cubecl_core as cubecl;
-use cubecl_core::prelude::*;
-
-#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
-pub struct ReduceConfig {
-    pub line_size: u32,
-    pub max_num_planes: u32,
-}
-
-/// Compute the sum of all elements of `input` and write it to the first element of `output`.
-///
-/// This doesn't reduce values across lines. For a version that does, use [reduce_sum_lined].
-///
-/// This is a work in progress toward a more general multi-dimensional reduce kernel.
-#[cube(launch_unchecked)]
-pub fn reduce_sum<N: Numeric>(
-    input: &Tensor<Line<N>>,
-    output: &mut Tensor<Line<N>>,
-    #[comptime] config: ReduceConfig,
-) {
-    reduce_sum_vector(&input.to_slice(), &mut output.to_slice_mut(), config);
-}
-
-/// Compute the sum of all elements of `input` and write it to the first element of `output`.
-///
-/// This reduces values across lines. For a version that doesn't, use [reduce_sum].
-///
-/// This is a work in progress toward a more general multi-dimensional reduce kernel.
-#[cube(launch_unchecked)]
-pub fn reduce_sum_lined<N: Numeric>(
-    input: &Tensor<Line<N>>,
-    output: &mut Tensor<N>,
-    #[comptime] config: ReduceConfig,
-) {
-    let mut tmp = SharedMemory::new_lined(1, config.line_size);
-    reduce_sum_vector(&input.to_slice(), &mut tmp.to_slice_mut(), config);
-    reduce_sum_lines(&tmp.to_slice(), &mut output.to_slice_mut(), 1_u32);
-}
-
-/// Compute the sum of all elements of `input` and write it to the first element of `output`.
-#[cube]
-pub fn reduce_sum_vector<N: Numeric>(
-    input: &Slice<Line<N>>,
-    output: &mut SliceMut<Line<N>>,
-    #[comptime] config: ReduceConfig,
-) {
-    let plane_id = UNIT_POS / PLANE_DIM;
-    let num_planes = div_ceil(CUBE_DIM, PLANE_DIM);
-
-    // Compute the number of required iterations to reduce all lines when reducing CUBE_DIM lines per iteration.
-    let num_iterations = div_ceil(input.len(), CUBE_DIM);
-
-    let mut memory = SharedMemory::new_lined(config.max_num_planes, input[0].size());
-    memory[plane_id] = Line::empty(config.line_size).fill(N::from_int(0));
-
-    // For each iteration, each plane reduces PLANE_DIM lines into a single line. Then, we accumulate the results
-    // into the memory. Thus, after the loop, the reduction of the memory yields the expected output.
- for i in 0..num_iterations { - let index = i * CUBE_DIM + plane_id * PLANE_DIM + UNIT_POS_PLANE; - let value = select( - index < input.len(), - input[index], - Line::empty(config.line_size).fill(N::from_int(0)), - ); - let sum = plane_sum(value); - if UNIT_POS_PLANE == 0 { - memory[plane_id] += sum; - } - } - - // Make sure that each local sum is completed and written to memory. - sync_units(); - - // Sum each elements in memory - let sum = plane_sum(select( - UNIT_POS_PLANE < num_planes, - memory[UNIT_POS_PLANE], - Line::empty(config.line_size).fill(N::from_int(0)), - )); - if UNIT_POS == 0 { - output[0] = sum; - } -} - -/// For each line, sum all elements and write the result into the corresponding element of output. -#[cube] -pub fn reduce_sum_lines( - input: &Slice>, - output: &mut SliceMut, - #[comptime] length: u32, -) { - if UNIT_POS < length { - let line = input[UNIT_POS]; - - let mut sum = N::from_int(0); - - #[unroll] - for k in 0..line.size() { - sum += line[k]; - } - - output[UNIT_POS] = sum; - } -} - -// Integer division rounded up. -#[cube] -fn div_ceil(a: u32, b: u32) -> u32 { - a / b + ((a % b) > 0) as u32 -} diff --git a/crates/cubecl-reduce/src/test.rs b/crates/cubecl-reduce/src/test.rs index 148ab8a9..1a5f75ac 100644 --- a/crates/cubecl-reduce/src/test.rs +++ b/crates/cubecl-reduce/src/test.rs @@ -1,26 +1,28 @@ #![allow(missing_docs)] -use cubecl_core::{prelude::*, Feature}; +use cubecl_core as cubecl; +use cubecl_core::prelude::*; -use crate::sum::{reduce_sum, reduce_sum_lined, ReduceConfig}; +use crate::{ + reduce_naive, ReduceArgMax, ReduceArgMin, ReduceMean, ReduceNaiveInstruction, ReduceProd, + ReduceSum, +}; -#[macro_export] -macro_rules! impl_test_reduce_sum_vector { - ($float:ident, [$(($num_values:expr, $cube_size:expr, $line_size:expr)),*]) => { - ::paste::paste! { - $( - #[test] - pub fn []() { - TestCase::<$float>::sum_vector(32, 32, 1).run::(&Default::default()); - } - )* - } - }; +// Simple kernel to launch tests. +#[cube(launch_unchecked)] +pub fn naive_reduce_dim_kernel>( + input: &Tensor>, + output: &mut Tensor>, + dim: u32, +) { + reduce_naive::(input, output, dim) } +// This macro generate all the tests. #[macro_export] macro_rules! testgen_reduce { - ([$($float:ident),*]) => { + // Generate all the tests for a list of types. + ([$($float:ident), *]) => { mod test_reduce { use super::*; ::paste::paste! { @@ -33,212 +35,388 @@ macro_rules! testgen_reduce { } }; + // Generate all the tests for a specific float type. 
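    // (Illustrative invocation, assumed rather than taken from this patch: a runtime crate
    //  that defines a `TestRuntime` type alias for its runtime would call
    //  `cubecl_reduce::testgen_reduce!([f32, f64]);`, expanding the list arm above once per
    //  float type through the single-type arm below.)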
($float:ident) => { - use super::*; - use cubecl_core::as_type; - use cubecl_core::prelude::Float; - use cubecl_core::CubeCount; use cubecl_reduce::test::TestCase; + use cubecl_core::prelude::CubeCount; - $crate::impl_test_reduce_sum_vector!( + $crate::impl_test_reduce!( $float, [ - (32, 32, 1), - (64, 32, 1), - (100, 32, 1), - (1000, 32, 1), - (2048, 32, 1), - (32, 64, 1), - (64, 64, 1), - (100, 64, 1), - (1000, 64, 1), - (2048, 64, 1), - (32, 1024, 1), - (64, 1024, 1), - (100, 1024, 1), - (1000, 1024, 1), - (2048, 1024, 1), - (32, 32, 2), - (64, 32, 2), - (100, 32, 2), - (1000, 32, 2), - (2048, 32, 2), - (32, 64, 2), - (64, 64, 2), - (100, 64, 2), - (1000, 64, 2), - (2048, 64, 2), - (32, 1024, 2), - (64, 1024, 2), - (100, 1024, 2), - (1000, 1024, 2), - (2048, 1024, 2), - (32, 32, 4), - (64, 32, 4), - (100, 32, 4), - (1000, 32, 4), - (2048, 32, 4), - (32, 64, 4), - (64, 64, 4), - (100, 64, 4), - (1000, 64, 4), - (2048, 64, 4), - (32, 1024, 4), - (64, 1024, 4), - (100, 1024, 4), - (1000, 1024, 4), - (2048, 1024, 4) + { + id: "reduce_columns_small_matrix_row_major", + shape: [4, 8], + stride: [8, 1], + reduce_dim: 0, + cube_count: CubeCount::Static(1, 1, 1), + cube_dim: CubeDim::new(4, 8, 1), + line_size: 1, + }, + { + id: "reduce_columns_large_matrix_row_major", + shape: [8, 256], + stride: [256, 1], + reduce_dim: 1, + cube_count: CubeCount::Static(8, 1, 1), + cube_dim: CubeDim::new(16, 16, 1), + line_size: 1, + }, + { + id: "reduce_rows_large_matrix_row_major", + shape: [8, 256], + stride: [256, 1], + reduce_dim: 0, + cube_count: CubeCount::Static(8, 1, 1), + cube_dim: CubeDim::new(16, 16, 1), + line_size: 1, + }, + { + id: "rank_three_tensor", + shape: [16, 16, 16], + stride: [1, 256, 16], + reduce_dim: 2, + cube_count: CubeCount::Static(4, 1, 1), + cube_dim: CubeDim::new(16, 16, 1), + line_size: 1, + }, + { + id: "rank_three_tensor_unexact_shape", + shape: [11, 12, 13], + stride: [156, 13, 1], + reduce_dim: 1, + cube_count: CubeCount::Static(4, 1, 1), + cube_dim: CubeDim::new(16, 16, 1), + line_size: 1, + }, + { + id: "reduce_rows_large_matrix_row_major_line_size_four", + shape: [32, 64], + stride: [64, 1], + reduce_dim: 0, + cube_count: CubeCount::Static(8, 1, 1), + cube_dim: CubeDim::new(16, 16, 1), + line_size: 4, + } ] ); }; } -#[derive(Debug)] -pub struct TestTensorParts { - pub values: Vec, - pub stride: Vec, - pub shape: Vec, - pub line_size: u8, -} +// For a given tensor description and cube settings +// run the tests for `ReduceSum`, `ReduceProd`, `ReduceMean`, `ReduceArgMax` and `ReduceArgMin` +// for all implementations. +// For each test, a reference reduction is computed on the CPU to compare the outcome of the kernel. +#[macro_export] +macro_rules! impl_test_reduce { + ( + $float:ident, + [ + $( + { + id: $id:literal, + shape: $shape:expr, + stride: $stride:expr, + reduce_dim: $reduce_dim:expr, + cube_count: $cube_count:expr, + cube_dim: $cube_dim:expr, + line_size: $line_size:expr, + } + ),* + ]) => { + ::paste::paste! 
{ + $( + #[test] + pub fn [< reduce_sum_dim_naive_ $id >]() { + let test = TestCase { + shape: $shape.into(), + stride: $stride.into(), + reduce_dim: $reduce_dim, + cube_count: $cube_count, + cube_dim: $cube_dim, + line_size:$line_size + }; + test.test_sum_dim_naive::<$float, TestRuntime>(&Default::default()); + } -impl TestTensorParts { - pub fn new_vector(values: Vec) -> Self { - let shape = vec![values.len()]; - Self { - values, - stride: vec![1], - shape, - line_size: 1, - } - } + #[test] + pub fn [< reduce_prod_dim_naive_ $id >]() { + let test = TestCase { + shape: $shape.into(), + stride: $stride.into(), + reduce_dim: $reduce_dim, + cube_count: $cube_count, + cube_dim: $cube_dim, + line_size:$line_size + }; + test.test_prod_dim_naive::<$float, TestRuntime>(&Default::default()); + } - pub fn range_vector(stop: usize) -> Self { - let values = (0..stop).map(|x| N::new(x as f32)).collect(); - Self::new_vector(values) - } + #[test] + pub fn [< reduce_mean_dim_naive_ $id >]() { + let test = TestCase { + shape: $shape.into(), + stride: $stride.into(), + reduce_dim: $reduce_dim, + cube_count: $cube_count, + cube_dim: $cube_dim, + line_size:$line_size + }; + test.test_mean_dim_naive::<$float, TestRuntime>(&Default::default()); + } - pub fn zero_vector(size: usize) -> Self { - let values = vec![N::new(0.0); size]; - Self::new_vector(values) - } + #[test] + pub fn [< reduce_argmax_dim_naive_ $id >]() { + let test = TestCase { + shape: $shape.into(), + stride: $stride.into(), + reduce_dim: $reduce_dim, + cube_count: $cube_count, + cube_dim: $cube_dim, + line_size:$line_size + }; + test.test_argmax_dim_naive::<$float, TestRuntime>(&Default::default()); + } - pub fn with_line_size(mut self, line_size: u8) -> Self { - self.line_size = line_size; - self - } + #[test] + pub fn [< reduce_argmin_dim_naive_ $id >]() { + let test = TestCase { + shape: $shape.into(), + stride: $stride.into(), + reduce_dim: $reduce_dim, + cube_count: $cube_count, + cube_dim: $cube_dim, + line_size:$line_size + }; + test.test_argmin_dim_naive::<$float, TestRuntime>(&Default::default()); + } + )* + } + }; } #[derive(Debug)] -pub struct TestCase { - pub input: TestTensorParts, - pub output: TestTensorParts, - pub expected: Vec, +pub struct TestCase { + pub shape: Vec, + pub stride: Vec, + pub reduce_dim: u32, + pub line_size: u8, pub cube_count: CubeCount, pub cube_dim: CubeDim, - pub sum_dim: u32, - pub reduce_lines: bool, } -impl TestCase { - pub fn new(input: TestTensorParts, output: TestTensorParts, expected: Vec) -> Self { - Self { - input, - output, - expected, - cube_count: CubeCount::Static(1, 1, 1), - cube_dim: CubeDim::new(32, 1, 1), - sum_dim: 0, - reduce_lines: false, - } +impl TestCase { + pub fn test_sum_dim_naive(&self, device: &R::Device) + where + F: Float + CubeElement + std::fmt::Display, + R: Runtime, + { + let input_values: Vec = self.random_input_values(); + let expected_values = self.cpu_sum_dim(&input_values); + self.run_test::(device, input_values, expected_values) } - /// ASSUMPTION: line_size divide num_values exactly - pub fn sum_vector(num_values: usize, cube_size: u32, line_size: usize) -> Self + pub fn test_prod_dim_naive(&self, device: &R::Device) where - F: Float, + F: Float + CubeElement + std::fmt::Display, + R: Runtime, { - // Compute the sums on the cpu. 
- let values_per_sum = num_values / line_size; - let partial_sum = values_per_sum * (values_per_sum - 1) / 2; - let mut sums = vec![0; line_size]; - - #[allow(clippy::needless_range_loop)] - for k in 0..line_size { - sums[k] = partial_sum + values_per_sum * k; - } - let sums = sums.into_iter().map(|s| F::new(s as f32)).collect(); - - let mut test = TestCase::new( - // input - TestTensorParts::range_vector(num_values), - // output - TestTensorParts::zero_vector(line_size), - // expected - sums, - ); - test.cube_dim = CubeDim::new(cube_size, 1, 1); - test + let input_values: Vec = self.random_input_values(); + let expected_values = self.cpu_prod_dim(&input_values); + self.run_test::(device, input_values, expected_values) } - pub fn run(self, device: &R::Device) + pub fn test_mean_dim_naive(&self, device: &R::Device) where F: Float + CubeElement + std::fmt::Display, + R: Runtime, + { + let input_values: Vec = self.random_input_values(); + let expected_values = self.cpu_mean_dim(&input_values); + self.run_test::(device, input_values, expected_values) + } + + pub fn test_argmax_dim_naive(&self, device: &R::Device) + where + F: Float + CubeElement + std::fmt::Display, + R: Runtime, + { + let input_values: Vec = self.random_input_values(); + let expected_values = self.cpu_argmax_dim(&input_values); + self.run_test::(device, input_values, expected_values) + } + + pub fn test_argmin_dim_naive(&self, device: &R::Device) + where + F: Float + CubeElement + std::fmt::Display, + R: Runtime, + { + let input_values: Vec = self.random_input_values(); + let expected_values = self.cpu_argmin_dim(&input_values); + self.run_test::(device, input_values, expected_values) + } + + pub fn run_test( + &self, + device: &R::Device, + input_values: Vec, + expected_values: Vec, + ) where + I: Numeric + CubeElement + std::fmt::Display, + O: Numeric + CubeElement + std::fmt::Display, + R: Runtime, + K: ReduceNaiveInstruction, { let client = R::client(device); - if !client.properties().feature_enabled(Feature::Plane) { - // Can't execute the test. - return; - } - let input_handle = client.create(F::as_bytes(&self.input.values)); - let output_handle = client.create(F::as_bytes(&self.output.values)); + let input_handle = client.create(I::as_bytes(&input_values)); - let config = ReduceConfig { - line_size: self.input.line_size as u32, - max_num_planes: self.cube_dim.num_elems() - / client.properties().hardware_properties().plane_size_min, - }; + // Zero initialize a tensor with the same shape as input + // except for the `self.reduce_dim` axis where the shape is 1. 
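        // (For example, with the `reduce_columns_large_matrix_row_major` case defined above,
        //  an input of shape [8, 256] reduced over dim 1 produces an output of shape [8, 1].)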
+ let output_handle = + client.create(O::as_bytes(&vec![O::from_int(0); expected_values.len()])); + let mut output_shape = self.shape.clone(); + output_shape[self.reduce_dim as usize] = 1; + let output_stride = self.output_stride(); unsafe { - let input_tensor = TensorArg::from_raw_parts::( + let input_tensor = TensorArg::from_raw_parts::( &input_handle, - &self.input.stride, - &self.input.shape, - self.input.line_size, + &self.stride, + &self.shape, + self.line_size, ); - let output_tensor = TensorArg::from_raw_parts::( + let output_tensor = TensorArg::from_raw_parts::( &output_handle, - &self.output.stride, - &self.output.shape, - self.output.line_size, + &output_stride, + &output_shape, + self.line_size, ); - if self.reduce_lines { - reduce_sum_lined::launch_unchecked::( - &client, - self.cube_count, - self.cube_dim, - input_tensor, - output_tensor, - config, - ); - } else { - reduce_sum::launch_unchecked::( - &client, - self.cube_count, - self.cube_dim, - input_tensor, - output_tensor, - config, - ); - } + naive_reduce_dim_kernel::launch_unchecked::( + &client, + self.cube_count.clone(), + self.cube_dim.clone(), + input_tensor, + output_tensor, + ScalarArg::new(self.reduce_dim.clone()), + ); } let binding = output_handle.binding(); let bytes = client.read_one(binding); - let output_values = F::from_bytes(&bytes); + let output_values = O::from_bytes(&bytes); + + assert_approx_equal_abs(output_values, &expected_values, 1e-7); + } + + fn cpu_sum_dim(&self, values: &[F]) -> Vec { + let mut expected = vec![F::new(0.0); self.num_output_values()]; + for input_index in 0..values.len() { + let output_index = self.to_output_index(input_index); + expected[output_index] += values[input_index]; + } + expected + } + + fn cpu_prod_dim(&self, values: &[F]) -> Vec { + let mut expected = vec![F::new(1.0); self.num_output_values()]; + for value_index in 0..values.len() { + let output_index = self.to_output_index(value_index); + expected[output_index] *= values[value_index]; + } + expected + } + + fn cpu_mean_dim(&self, values: &[F]) -> Vec { + self.cpu_sum_dim(values) + .into_iter() + .map(|sum| sum / F::new(self.shape[self.reduce_dim as usize] as f32)) + .collect() + } + + fn cpu_argmax_dim(&self, values: &[F]) -> Vec { + let mut expected = vec![(F::MIN, 0_u32); self.num_output_values()]; + for input_index in 0..values.len() { + let output_index = self.to_output_index(input_index); + let (best, _) = expected[output_index]; + let candidate = values[input_index]; + if candidate > best { + let coordinate = self.to_input_coordinate(input_index / self.line_size as usize); + expected[output_index] = (candidate, coordinate[self.reduce_dim as usize] as u32); + } + } + expected.into_iter().map(|(_, i)| i).collect() + } + + fn cpu_argmin_dim(&self, values: &[F]) -> Vec { + let mut expected = vec![(F::MAX, 0_u32); self.num_output_values()]; + for input_index in 0..values.len() { + let output_index = self.to_output_index(input_index); + let (best, _) = expected[output_index]; + let candidate = values[input_index]; + if candidate < best { + let coordinate = self.to_input_coordinate(input_index / self.line_size as usize); + expected[output_index] = (candidate, coordinate[self.reduce_dim as usize] as u32); + } + } + expected.into_iter().map(|(_, i)| i).collect() + } + + fn num_output_values(&self) -> usize { + self.line_size as usize * self.shape.iter().product::() + / self.shape[self.reduce_dim as usize] + } + + fn to_output_index(&self, input_index: usize) -> usize { + let line_size = self.line_size as usize; + let 
mut coordinate = self.to_input_coordinate(input_index / line_size); + coordinate[self.reduce_dim as usize] = 0; + self.from_output_coordinate(coordinate) * line_size + input_index % line_size + } + + fn to_input_coordinate(&self, index: usize) -> Vec { + self.stride + .iter() + .zip(self.shape.iter()) + .map(|(stride, shape)| (index / stride) % shape) + .collect() + } + + fn from_output_coordinate(&self, coordinate: Vec) -> usize { + coordinate + .into_iter() + .zip(self.output_stride().iter()) + .map(|(c, s)| c * s) + .sum() + } + + fn output_stride(&self) -> Vec { + let dim_stride = self.stride[self.reduce_dim as usize]; + let dim_shape = self.shape[self.reduce_dim as usize]; + self.stride + .iter() + .map(|s| match s.cmp(&dim_stride) { + std::cmp::Ordering::Equal => 1, + std::cmp::Ordering::Greater => s / dim_shape, + std::cmp::Ordering::Less => *s, + }) + .collect() + } + + fn random_input_values(&self) -> Vec { + let size = self.shape.iter().product::() * self.line_size as usize; + + fn lcg(seed: &mut u64) -> f32 { + const A: u64 = 1664525; + const C: u64 = 1013904223; + const M: f64 = 2u64.pow(32) as f64; + + *seed = (A.wrapping_mul(*seed).wrapping_add(C)) % (1u64 << 32); + (*seed as f64 / M * 2.0 - 1.0) as f32 + } - assert_approx_equal_abs(output_values, &self.expected, 1e-9); + let mut seed = 123456789; // Not really important for testing. + (0..size).map(|_| F::new(lcg(&mut seed))).collect() } }
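The following sketch is illustrative and not part of the patch above. The trait `ReduceNaiveInstruction` is the extension point of `reduce_naive`, so a new reduction can be added by providing another marker struct and implementation. The `ReduceMax` struct below is hypothetical, the generic parameters `EI`/`EO` are assumptions matching the trait as shown, and the body only reuses constructs that already appear in the patch (`Line::empty(..).fill(..)`, `comptime!`, `#[unroll]`, `Line::cast_from`).

pub struct ReduceMax;

#[cube]
impl<EI: Numeric> ReduceNaiveInstruction<EI> for ReduceMax {
    type Accumulator = Line<EI>;

    fn init_accumulator(line_size: u32) -> Self::Accumulator {
        // Start from the smallest representable value so that any input replaces it.
        Line::empty(line_size).fill(EI::MIN)
    }

    fn accumulate(accumulator: &mut Self::Accumulator, current_value: Line<EI>, _i: u32) {
        // Keep the elementwise maximum seen so far, mirroring the ReduceArgMax logic
        // without tracking the index of the maximum.
        #[allow(clippy::collapsible_else_if)]
        if comptime!(current_value.size() > 1) {
            #[unroll]
            for k in 0..current_value.size() {
                if current_value[k] > accumulator[k] {
                    accumulator[k] = current_value[k];
                }
            }
        } else {
            if current_value > *accumulator {
                *accumulator = current_value;
            }
        }
    }

    fn write<EO: Numeric>(
        output: &mut Tensor<Line<EO>>,
        accumulator: Self::Accumulator,
        index: u32,
        _shape_reduce_dim: u32,
    ) {
        // Cast the reduced line to the output element type, as ReduceSum does.
        output[index] = Line::cast_from(accumulator);
    }
}

To exercise such an instruction with the harness above, one would add a matching CPU reference (analogous to `cpu_sum_dim`) and a `test_*_dim_naive` wrapper that calls `run_test` with the new instruction as the `K` type parameter.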