diff --git a/crates/cubecl-common/src/benchmark.rs b/crates/cubecl-common/src/benchmark.rs
index 57e837e6..c606e926 100644
--- a/crates/cubecl-common/src/benchmark.rs
+++ b/crates/cubecl-common/src/benchmark.rs
@@ -204,7 +204,7 @@ pub trait Benchmark {
         // Warmup
         let args = self.prepare();
 
-        for _ in 0..10 {
+        for _ in 0..self.num_samples() {
             self.execute(args.clone());
         }
 
diff --git a/crates/cubecl-core/src/frontend/cmma.rs b/crates/cubecl-core/src/frontend/cmma.rs
index 7b79c87d..a8e44d88 100644
--- a/crates/cubecl-core/src/frontend/cmma.rs
+++ b/crates/cubecl-core/src/frontend/cmma.rs
@@ -405,6 +405,57 @@ pub mod execute {
     }
 }
 
+/// Cast a fragment to another type.
+#[allow(unused_variables)]
+pub fn cast<C: CubePrimitive, O: CubePrimitive>(input: &Matrix<C>) -> Matrix<O> {
+    unexpanded!()
+}
+
+/// Module containing the expand function for [cast()].
+pub mod cast {
+    use super::*;
+
+    /// Expand method of [cast()].
+    #[allow(unused_variables)]
+    pub fn expand<C: CubePrimitive, O: CubePrimitive>(
+        context: &mut CubeContext,
+        input: MatrixExpand<C>,
+    ) -> MatrixExpand<O> {
+        let ident = input.ident;
+
+        if core::any::TypeId::of::<C>() == core::any::TypeId::of::<O>() {
+            return MatrixExpand {
+                elem: input.elem,
+                ident,
+                _c: PhantomData,
+            };
+        }
+        let input = *input.elem;
+        let input_mat = match input.kind {
+            ir::VariableKind::Matrix { mat, .. } => mat,
+            _ => unreachable!(),
+        };
+
+        let elem = context.create_matrix(ir::Matrix {
+            ident,
+            m: input_mat.m,
+            n: input_mat.n,
+            k: input_mat.k,
+            elem: O::as_elem(),
+            layout: MatrixLayout::Undefined,
+        });
+
+        let output = MatrixExpand {
+            ident,
+            elem,
+            _c: PhantomData,
+        };
+        context.register(Instruction::new(ir::CoopMma::Cast { input }, *output.elem));
+
+        output
+    }
+}
+
 impl From<ir::CoopMma> for Operation {
     fn from(value: ir::CoopMma) -> Self {
         Operation::CoopMma(value)
diff --git a/crates/cubecl-core/src/ir/cmma.rs b/crates/cubecl-core/src/ir/cmma.rs
index 9588d3ef..152a0382 100644
--- a/crates/cubecl-core/src/ir/cmma.rs
+++ b/crates/cubecl-core/src/ir/cmma.rs
@@ -69,6 +69,8 @@ pub enum CoopMma {
         stride: Variable,
         layout: MatrixLayout,
     },
+    /// Cast a fragment to another type.
+    Cast { input: Variable },
 }
 
 impl Display for CoopMma {
@@ -99,6 +101,9 @@ impl Display for CoopMma {
                 "matrix_store({}, stride: {}, layout: {:?})",
                 mat, stride, layout
             ),
+            CoopMma::Cast { input } => {
+                write!(f, "matrix_cast(input: {})", input)
+            }
         }
     }
 }
diff --git a/crates/cubecl-core/src/ir/processing.rs b/crates/cubecl-core/src/ir/processing.rs
index 8222c3cf..ec6f7236 100644
--- a/crates/cubecl-core/src/ir/processing.rs
+++ b/crates/cubecl-core/src/ir/processing.rs
@@ -312,6 +312,9 @@ impl ScopeProcessing {
                     CoopMma::Store { stride, .. } => {
                         sanitize_constant_scalar_ref_elem(stride, u32::as_elem());
                     }
+                    CoopMma::Cast { .. } => {
+                        // Nothing to do.
+                    }
                 },
             });
         self
diff --git a/crates/cubecl-core/src/runtime_tests/cmma.rs b/crates/cubecl-core/src/runtime_tests/cmma.rs
index c76454e9..e53475fe 100644
--- a/crates/cubecl-core/src/runtime_tests/cmma.rs
+++ b/crates/cubecl-core/src/runtime_tests/cmma.rs
@@ -87,6 +87,29 @@ pub fn kernel_simple_tf32(lhs: &Array<tf32>, rhs: &Array<tf32>, out: &mut Array<f32>) {
     );
 }
 
+#[cube(launch)]
+pub fn cast_matrix(input: &Array<f32>, out: &mut Array<f16>) {
+    let acc = unsafe {
+        cmma::Matrix::<f32>::uninitialized(
+            cmma::MatrixIdent::Accumulator,
+            16,
+            16,
+            16,
+            cmma::MatrixLayout::Undefined,
+        )
+    };
+    cmma::load_with_layout(&acc, &input.to_slice(), 16, cmma::MatrixLayout::RowMajor);
+
+    let output = cmma::cast::<f32, f16>(&acc);
+
+    cmma::store(
+        &mut out.to_slice_mut(),
+        &output,
+        16,
+        cmma::MatrixLayout::RowMajor,
+    );
+}
+
 pub fn test_simple_1<R: Runtime>(
     client: ComputeClient<R::Server, R::Channel>,
     cube_dimensions: CubeDim,
@@ -151,6 +174,43 @@ pub fn test_simple_1(
     assert_eq!(expected, actual);
 }
 
+pub fn test_cmma_cast_acc<R: Runtime>(
+    client: ComputeClient<R::Server, R::Channel>,
+    cube_dimensions: CubeDim,
+) {
+    if !client.properties().feature_enabled(Feature::Cmma {
+        a: Elem::Float(FloatKind::F16),
+        b: Elem::Float(FloatKind::F16),
+        c: Elem::Float(FloatKind::F32),
+        m: 16,
+        k: 16,
+        n: 16,
+    }) {
+        // We can't execute the test, skip.
+        return;
+    }
+
+    let input: Vec<f32> = (0..256).map(|i| i as f32).collect();
+    let input = client.create(f32::as_bytes(&input));
+    let out = client.empty(core::mem::size_of::<f16>() * 256);
+
+    unsafe {
+        cast_matrix::launch::<R>(
+            &client,
+            CubeCount::Static(1, 1, 1),
+            cube_dimensions,
+            ArrayArg::from_raw_parts::<f32>(&input, 256, 1),
+            ArrayArg::from_raw_parts::<f16>(&out, 256, 1),
+        )
+    };
+
+    let actual = client.read_one(out.binding());
+    let actual = f16::from_bytes(&actual);
+    let expected: Vec<f16> = (0..256).map(|i| f16::from_f32(i as f32)).collect();
+
+    assert_eq!(actual, expected);
+}
+
 pub fn test_simple_tf32<R: Runtime>(
     client: ComputeClient<R::Server, R::Channel>,
     cube_dimensions: CubeDim,
@@ -160,7 +220,7 @@ pub fn test_simple_tf32(
         b: Elem::Float(FloatKind::TF32),
         c: Elem::Float(FloatKind::F32),
         m: 16,
-        k: 8,
+        k: 16,
         n: 16,
     }) {
         // We can't execute the test, skip.
@@ -243,5 +303,15 @@ macro_rules! testgen_cmma {
                 cube_dimensions,
             );
         }
+
+        #[test]
+        fn test_cmma_cast_acc() {
+            let client = TestRuntime::client(&Default::default());
+            let cube_dimensions = CubeDim::new(32, 1, 1);
+            cubecl_core::runtime_tests::cmma::test_cmma_cast_acc::<TestRuntime>(
+                client,
+                cube_dimensions,
+            );
+        }
     };
 }
diff --git a/crates/cubecl-cpp/src/cuda/wmma/cuda_compiler.rs b/crates/cubecl-cpp/src/cuda/wmma/cuda_compiler.rs
index dc725b23..9222300d 100644
--- a/crates/cubecl-cpp/src/cuda/wmma/cuda_compiler.rs
+++ b/crates/cubecl-cpp/src/cuda/wmma/cuda_compiler.rs
@@ -59,8 +59,7 @@ impl WmmaCompiler<CudaDialect<Self>> for CudaWmmaCompiler {
     fn supported_wmma_combinations(arch: &Self::Architecture) -> SupportedWmmaCombinations {
         let mut result: SupportedWmmaCombinations = vec![];
         if arch.version >= WMMA_MINIMUM_VERSION {
-            //               m   n   k
-            let tdims = vec![(16, 16, 16), (32, 16, 8), (8, 16, 32)];
+            let tdims = vec![(16, 16, 16), (32, 8, 16), (8, 32, 16)];
             // Types fully supported.
            let types = vec![
                (
diff --git a/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs b/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs
index d11ea370..accb853e 100644
--- a/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs
+++ b/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs
@@ -203,6 +203,15 @@ for (uint i = 0; i < uint(8); ++i) {{
                     "for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
  const uint rowIdx = elemIdx * uint(2) + threadIdx.x / uint(16);
  {output_ident}[{output_idx}] = {frag}[{frag_idx}];
+}}
+"
+                )
+            }
+            WmmaInstruction::Cast { input, output } => {
+                write!(
+                    f,
+                    "for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
+  {output}[elemIdx] = {input}[elemIdx];
 }}
 "
                 )
diff --git a/crates/cubecl-cpp/src/shared/base.rs b/crates/cubecl-cpp/src/shared/base.rs
index cef5aafa..64fedb65 100644
--- a/crates/cubecl-cpp/src/shared/base.rs
+++ b/crates/cubecl-cpp/src/shared/base.rs
@@ -326,6 +326,10 @@ impl<D: Dialect> CppCompiler<D> {
                     .compile_matrix_layout(layout)
                     .expect("Layout required for store instruction"),
             }),
+            gpu::CoopMma::Cast { input } => Instruction::Wmma(WmmaInstruction::Cast {
+                input: self.compile_variable(input),
+                output: out,
+            }),
         }
     }
diff --git a/crates/cubecl-cpp/src/shared/mma.rs b/crates/cubecl-cpp/src/shared/mma.rs
index a14cf777..56b82879 100644
--- a/crates/cubecl-cpp/src/shared/mma.rs
+++ b/crates/cubecl-cpp/src/shared/mma.rs
@@ -122,6 +122,11 @@ pub enum WmmaInstruction {
         stride: Variable,
         layout: FragmentLayout,
     },
+    /// Cast a fragment to another type.
+    Cast {
+        input: Variable,
+        output: Variable,
+    },
 }
 
 impl Display for FragmentLayout {
@@ -287,6 +292,12 @@ pub mod wmma_api_base {
                     )
                 }
             }
+            WmmaInstruction::Cast { input, output } => {
+                writeln!(
+                    f,
+                    "for(int t=0; t<{input}.num_elements; t++) {{ {output}.x[t] = {input}.x[t]; }}"
+                )
+            }
         }
     }
 }
diff --git a/crates/cubecl-linalg/src/matmul/components/stage/multi_buffer/base.rs b/crates/cubecl-linalg/src/matmul/components/stage/multi_buffer/base.rs
index ddb92e39..fd723151 100644
--- a/crates/cubecl-linalg/src/matmul/components/stage/multi_buffer/base.rs
+++ b/crates/cubecl-linalg/src/matmul/components/stage/multi_buffer/base.rs
@@ -99,7 +99,7 @@ where
             stage_config.stage_dim(Ident::Out).tile_num_elements() / out_smem_line_size;
         let start = num_tile_lines * UNIT_POS_Y;
 
-        let mut out_smem = SharedMemory::<EA>::new_lined(
+        let mut out_smem = SharedMemory::<EG>::new_lined(
            num_tile_lines * stage_config.num_planes(),
            out_smem_line_size,
        );
@@ -109,7 +109,7 @@ where
             let accumulator = acc.index(accumulator_iter);
             let mut smem_slice = out_smem.slice_mut(start, start + num_tile_lines);
             TMM::read_accumulator(accumulator, &mut smem_slice, stage_config.to_tmm_config());
-            SW::write::<EA, G>(
+            SW::write::<EG, G>(
                 out,
                 smem_slice.to_slice(),
                 UNIT_POS_Y,
diff --git a/crates/cubecl-linalg/src/matmul/components/stage/single_buffer/base.rs b/crates/cubecl-linalg/src/matmul/components/stage/single_buffer/base.rs
index 7aa4c8ff..a0d75f5d 100644
--- a/crates/cubecl-linalg/src/matmul/components/stage/single_buffer/base.rs
+++ b/crates/cubecl-linalg/src/matmul/components/stage/single_buffer/base.rs
@@ -123,7 +123,7 @@ where
             stage_config.stage_dim(Ident::Out).tile_num_elements() / out_smem_line_size;
         let start = num_tile_lines * UNIT_POS_Y;
 
-        let mut out_smem = SharedMemory::<EA>::new_lined(
+        let mut out_smem = SharedMemory::<EG>::new_lined(
            num_tile_lines * stage_config.num_planes(),
            out_smem_line_size,
        );
@@ -133,7 +133,7 @@ where
             let accumulator = acc.index(accumulator_iter);
             let mut smem_slice = out_smem.slice_mut(start, start + num_tile_lines);
             TMM::read_accumulator(accumulator, &mut smem_slice, stage_config.to_tmm_config());
-            SW::write::<EA, G>(
+            SW::write::<EG, G>(
                 out,
                 smem_slice.to_slice(),
                 UNIT_POS_Y,
diff --git a/crates/cubecl-linalg/src/matmul/components/tile/accelerated.rs b/crates/cubecl-linalg/src/matmul/components/tile/accelerated.rs
index 935015df..cee95c40 100644
--- a/crates/cubecl-linalg/src/matmul/components/tile/accelerated.rs
+++ b/crates/cubecl-linalg/src/matmul/components/tile/accelerated.rs
@@ -223,7 +223,8 @@ fn read_accumulator<O: Numeric, C: Numeric>(
     slice: &mut SliceMut<Line<C>>,
     #[comptime] n: u32,
 ) {
-    cmma::store(slice, out, n, cmma::MatrixLayout::RowMajor);
+    let acc = cmma::cast::<O, C>(out);
+    cmma::store(slice, &acc, n, cmma::MatrixLayout::RowMajor);
 }
 
 fn check_availability(
diff --git a/crates/cubecl-opt/src/instructions.rs b/crates/cubecl-opt/src/instructions.rs
index b31ed108..470be4ba 100644
--- a/crates/cubecl-opt/src/instructions.rs
+++ b/crates/cubecl-opt/src/instructions.rs
@@ -235,6 +235,9 @@ impl Optimizer {
                 visit_read(self, mat);
                 visit_read(self, stride);
             }
+            CoopMma::Cast { input } => {
+                visit_read(self, input);
+            }
         }
     }
diff --git a/crates/cubecl-spirv/src/cmma.rs b/crates/cubecl-spirv/src/cmma.rs
index 1d03f2ad..bcbe4335 100644
--- a/crates/cubecl-spirv/src/cmma.rs
+++ b/crates/cubecl-spirv/src/cmma.rs
@@ -32,6 +32,7 @@ impl<T: SpirvTarget> SpirvCompiler<T> {
                 layout,
                 ..
             } => self.compile_store(mat, out, stride, layout),
+            CoopMma::Cast { input } => self.compile_cast(input, out),
         }
     }
 
@@ -143,6 +144,23 @@ impl<T: SpirvTarget> SpirvCompiler<T> {
         self.store(mat_d.id, mat_d_id, None, vec![]).unwrap();
     }
 
+    fn compile_cast(&mut self, input: core::Variable, output: core::Variable) {
+        let input = self.compile_variable(input);
+        let output = self.compile_variable(output);
+
+        let input = self.matrix_var(&input).2;
+        let output = self.matrix_var(&output).2;
+
+        let input_ty = self.item(&input).id(self);
+        let output_ty = self.item(&output).id(self);
+
+        let fragment_id = self.load(input_ty, None, input.id, None, vec![]).unwrap();
+
+        let frag_new = self.f_convert(output_ty, None, fragment_id).unwrap();
+
+        self.store(output.id, frag_new, None, vec![]).unwrap();
+    }
+
     fn matrix_var(&mut self, var: &Variable) -> (u16, u8, Matrix) {
         let (id, depth) = match var {
             Variable::CoopMatrix(id, depth, _) => (*id, *depth),
diff --git a/crates/cubecl-wgpu/src/lib.rs b/crates/cubecl-wgpu/src/lib.rs
index 5a0a47ca..84cbd47e 100644
--- a/crates/cubecl-wgpu/src/lib.rs
+++ b/crates/cubecl-wgpu/src/lib.rs
@@ -39,4 +39,5 @@ mod tests_spirv {
     cubecl_core::testgen_all!(f32: [f16, flex32, f32, f64], i32: [i8, i16, i32, i64], u32: [u8, u16, u32, u64]);
     cubecl_linalg::testgen_plane_mma!([f16, flex32, f32], f32);
     cubecl_linalg::testgen_tiling2d!([f16, flex32, f32, f64]);
+    cubecl_linalg::testgen_cmma_matmul!([f16]);
 }
diff --git a/crates/cubecl/benches/matmul.rs b/crates/cubecl/benches/matmul.rs
index 195dabd8..150faf20 100644
--- a/crates/cubecl/benches/matmul.rs
+++ b/crates/cubecl/benches/matmul.rs
@@ -55,8 +55,7 @@ struct MatmulBench<R: Runtime, E> {
 fn run<R: Runtime, E: Float>(device: R::Device, strategy: matmul::Strategy) {
     let client = R::client(&device);
 
-    {
-        let (b, m, n, k) = (2, 4096, 4096, 4096);
+    for (b, m, n, k) in [(2, 4096, 4096, 4096), (2, 4096, 2040, 4096)] {
         let bench = MatmulBench::<R, E> {
             b,
             m,