From 27e9006e4cd220b32cb20ecd0835bace7f249a84 Mon Sep 17 00:00:00 2001
From: nathaniel
Date: Tue, 26 Nov 2024 15:55:34 -0500
Subject: [PATCH 1/7] Cast fragment

---
 crates/cubecl-core/src/frontend/cmma.rs       | 51 +++++++++++++
 crates/cubecl-core/src/ir/cmma.rs             |  5 ++
 crates/cubecl-core/src/ir/processing.rs       |  3 +
 crates/cubecl-core/src/runtime_tests/cmma.rs  | 72 ++++++++++++++++++-
 .../src/hip/wmma/intrinsic_compiler.rs        |  1 +
 crates/cubecl-cpp/src/shared/base.rs          |  4 ++
 crates/cubecl-cpp/src/shared/mma.rs           | 11 +++
 .../components/stage/multi_buffer/base.rs     |  4 +-
 .../components/stage/single_buffer/base.rs    |  4 +-
 .../src/matmul/components/tile/accelerated.rs |  3 +-
 .../matmul/kernels/matmul/algorithm/cmma.rs   |  1 +
 crates/cubecl-opt/src/instructions.rs         |  1 +
 12 files changed, 154 insertions(+), 6 deletions(-)

diff --git a/crates/cubecl-core/src/frontend/cmma.rs b/crates/cubecl-core/src/frontend/cmma.rs
index e465e75e..2c8ee133 100644
--- a/crates/cubecl-core/src/frontend/cmma.rs
+++ b/crates/cubecl-core/src/frontend/cmma.rs
@@ -405,6 +405,57 @@ pub mod execute {
     }
 }
 
+/// Cast a fragment to another type.
+#[allow(unused_variables)]
+pub fn cast<C: CubePrimitive, O: CubePrimitive>(input: &Matrix<C>) -> Matrix<O> {
+    unexpanded!()
+}
+
+/// Module containing the expand function for [cast()].
+pub mod cast {
+    use super::*;
+
+    /// Expand method of [cast()].
+    #[allow(unused_variables)]
+    pub fn expand<C: CubePrimitive, O: CubePrimitive>(
+        context: &mut CubeContext,
+        input: MatrixExpand<C>,
+    ) -> MatrixExpand<O> {
+        let ident = input.ident;
+
+        if core::any::TypeId::of::<C>() == core::any::TypeId::of::<O>() {
+            return MatrixExpand {
+                elem: input.elem,
+                ident,
+                _c: PhantomData,
+            };
+        }
+        let input = *input.elem;
+        let input_mat = match input.kind {
+            ir::VariableKind::Matrix { id, mat, depth } => mat,
+            _ => unreachable!(),
+        };
+
+        let elem = context.create_matrix(ir::Matrix {
+            ident,
+            m: input_mat.m,
+            n: input_mat.n,
+            k: input_mat.k,
+            elem: O::as_elem(),
+            layout: MatrixLayout::Undefined,
+        });
+
+        let output = MatrixExpand {
+            ident,
+            elem,
+            _c: PhantomData,
+        };
+        context.register(Instruction::new(ir::CoopMma::Cast { input }, *output.elem));
+
+        output
+    }
+}
+
 impl From<ir::CoopMma> for Operation {
     fn from(value: ir::CoopMma) -> Self {
         Operation::CoopMma(value)
diff --git a/crates/cubecl-core/src/ir/cmma.rs b/crates/cubecl-core/src/ir/cmma.rs
index 9588d3ef..152a0382 100644
--- a/crates/cubecl-core/src/ir/cmma.rs
+++ b/crates/cubecl-core/src/ir/cmma.rs
@@ -69,6 +69,8 @@ pub enum CoopMma {
         stride: Variable,
         layout: MatrixLayout,
     },
+    /// Cast a fragment to another type.
+    Cast { input: Variable },
 }
 
 impl Display for CoopMma {
@@ -99,6 +101,9 @@ impl Display for CoopMma {
                 "matrix_store({}, stride: {}, layout: {:?})",
                 mat, stride, layout
            ),
+            CoopMma::Cast { input } => {
+                write!(f, "matrix_cast(input: {})", input)
+            }
         }
     }
 }
diff --git a/crates/cubecl-core/src/ir/processing.rs b/crates/cubecl-core/src/ir/processing.rs
index 8222c3cf..ec6f7236 100644
--- a/crates/cubecl-core/src/ir/processing.rs
+++ b/crates/cubecl-core/src/ir/processing.rs
@@ -312,6 +312,9 @@ impl ScopeProcessing {
                 CoopMma::Store { stride, .. } => {
                     sanitize_constant_scalar_ref_elem(stride, u32::as_elem());
                 }
+                CoopMma::Cast { .. } => {
+                    // Nothing to do.
+                }
             },
         });
         self
diff --git a/crates/cubecl-core/src/runtime_tests/cmma.rs b/crates/cubecl-core/src/runtime_tests/cmma.rs
index c76454e9..acd8103c 100644
--- a/crates/cubecl-core/src/runtime_tests/cmma.rs
+++ b/crates/cubecl-core/src/runtime_tests/cmma.rs
@@ -87,6 +87,29 @@ pub fn kernel_simple_tf32(lhs: &Array<tf32>, rhs: &Array<tf32>, out: &mut Array<f32>) {
     );
 }
 
+#[cube(launch)]
+pub fn cast_matrix(input: &Array<f32>, out: &mut Array<f16>) {
+    let acc = unsafe {
+        cmma::Matrix::<f32>::uninitialized(
+            cmma::MatrixIdent::Accumulator,
+            16,
+            16,
+            16,
+            cmma::MatrixLayout::Undefined,
+        )
+    };
+    cmma::load_with_layout(&acc, &input.to_slice(), 16, cmma::MatrixLayout::RowMajor);
+
+    let output = cmma::cast::<f32, f16>(&acc);
+
+    cmma::store(
+        &mut out.to_slice_mut(),
+        &output,
+        16,
+        cmma::MatrixLayout::RowMajor,
+    );
+}
+
 pub fn test_simple_1<R: Runtime>(
     client: ComputeClient<R::Server, R::Channel>,
     cube_dimensions: CubeDim,
@@ -151,7 +174,7 @@ pub fn test_simple_1<R: Runtime>(
     assert_eq!(expected, actual);
 }
 
-pub fn test_simple_tf32<R: Runtime>(
+pub fn test_cmma_cast_acc<R: Runtime>(
     client: ComputeClient<R::Server, R::Channel>,
     cube_dimensions: CubeDim,
 ) {
@@ -167,6 +190,43 @@ pub fn test_simple_tf32<R: Runtime>(
         return;
     }
 
+    let input: Vec<f32> = (0..256).map(|i| i as f32).collect();
+    let input = client.create(f32::as_bytes(&input));
+    let out = client.empty(core::mem::size_of::<f16>() * 256);
+
+    unsafe {
+        cast_matrix::launch::<R>(
+            &client,
+            CubeCount::Static(1, 1, 1),
+            cube_dimensions,
+            ArrayArg::from_raw_parts::<f32>(&input, 256, 1),
+            ArrayArg::from_raw_parts::<f16>(&out, 256, 1),
+        )
+    };
+
+    let actual = client.read_one(out.binding());
+    let actual = f16::from_bytes(&actual);
+    let expected: Vec<f16> = (0..256).map(|i| f16::from_f32(i as f32)).collect();
+
+    assert_eq!(actual, expected);
+}
+
+pub fn test_simple_tf32<R: Runtime>(
+    client: ComputeClient<R::Server, R::Channel>,
+    cube_dimensions: CubeDim,
+) {
+    if !client.properties().feature_enabled(Feature::Cmma {
+        a: Elem::Float(FloatKind::F16),
+        b: Elem::Float(FloatKind::F16),
+        c: Elem::Float(FloatKind::F32),
+        m: 16,
+        k: 16,
+        n: 16,
+    }) {
+        // We can't execute the test, skip.
+        return;
+    }
+
     let lhs: Vec<f32> = (0..128).map(|i| (i as f32)).collect();
     let rhs: Vec<f32> = (0..128).map(|i| ((i % 8) as f32)).collect();
 
@@ -243,5 +303,15 @@ macro_rules!
testgen_cmma { cube_dimensions, ); } + + #[test] + fn test_cmma_cast_acc() { + let client = TestRuntime::client(&Default::default()); + let cube_dimensions = CubeDim::new(32, 1, 1); + cubecl_core::runtime_tests::cmma::test_cmma_cast_acc::( + client, + cube_dimensions, + ); + } }; } diff --git a/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs b/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs index d11ea370..e4415774 100644 --- a/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs +++ b/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs @@ -207,6 +207,7 @@ for (uint i = 0; i < uint(8); ++i) {{ " ) } + WmmaInstruction::Cast { input, output } => todo!(), } } diff --git a/crates/cubecl-cpp/src/shared/base.rs b/crates/cubecl-cpp/src/shared/base.rs index cef5aafa..64fedb65 100644 --- a/crates/cubecl-cpp/src/shared/base.rs +++ b/crates/cubecl-cpp/src/shared/base.rs @@ -326,6 +326,10 @@ impl CppCompiler { .compile_matrix_layout(layout) .expect("Layout required for store instruction"), }), + gpu::CoopMma::Cast { input } => Instruction::Wmma(WmmaInstruction::Cast { + input: self.compile_variable(input), + output: out, + }), } } diff --git a/crates/cubecl-cpp/src/shared/mma.rs b/crates/cubecl-cpp/src/shared/mma.rs index a14cf777..56b82879 100644 --- a/crates/cubecl-cpp/src/shared/mma.rs +++ b/crates/cubecl-cpp/src/shared/mma.rs @@ -122,6 +122,11 @@ pub enum WmmaInstruction { stride: Variable, layout: FragmentLayout, }, + /// Cast + Cast { + input: Variable, + output: Variable, + }, } impl Display for FragmentLayout { @@ -287,6 +292,12 @@ pub mod wmma_api_base { ) } } + WmmaInstruction::Cast { input, output } => { + writeln!( + f, + "for(int t=0; t<{input}.num_elements; t++) {{ {output}.x[t] = {input}.x[t]; }}" + ) + } } } } diff --git a/crates/cubecl-linalg/src/matmul/components/stage/multi_buffer/base.rs b/crates/cubecl-linalg/src/matmul/components/stage/multi_buffer/base.rs index b7a03053..e4e60bb9 100644 --- a/crates/cubecl-linalg/src/matmul/components/stage/multi_buffer/base.rs +++ b/crates/cubecl-linalg/src/matmul/components/stage/multi_buffer/base.rs @@ -99,7 +99,7 @@ where stage_config.stage_dim(Ident::Out).tile_num_elements() / out_smem_line_size; let start = num_tile_lines * UNIT_POS_Y; - let mut out_smem = SharedMemory::::new_lined( + let mut out_smem = SharedMemory::::new_lined( num_tile_lines * stage_config.num_planes(), out_smem_line_size, ); @@ -109,7 +109,7 @@ where let accumulator = acc.index(accumulator_iter); let mut smem_slice = out_smem.slice_mut(start, start + num_tile_lines); TMM::read_accumulator(accumulator, &mut smem_slice, stage_config.to_tmm_config()); - SW::write::( + SW::write::( out, smem_slice.to_slice(), UNIT_POS_Y, diff --git a/crates/cubecl-linalg/src/matmul/components/stage/single_buffer/base.rs b/crates/cubecl-linalg/src/matmul/components/stage/single_buffer/base.rs index 48fc0770..50d982ac 100644 --- a/crates/cubecl-linalg/src/matmul/components/stage/single_buffer/base.rs +++ b/crates/cubecl-linalg/src/matmul/components/stage/single_buffer/base.rs @@ -111,7 +111,7 @@ where stage_config.stage_dim(Ident::Out).tile_num_elements() / out_smem_line_size; let start = num_tile_lines * UNIT_POS_Y; - let mut out_smem = SharedMemory::::new_lined( + let mut out_smem = SharedMemory::::new_lined( num_tile_lines * stage_config.num_planes(), out_smem_line_size, ); @@ -121,7 +121,7 @@ where let accumulator = acc.index(accumulator_iter); let mut smem_slice = out_smem.slice_mut(start, start + num_tile_lines); TMM::read_accumulator(accumulator, &mut smem_slice, 
stage_config.to_tmm_config()); - SW::write::( + SW::write::( out, smem_slice.to_slice(), UNIT_POS_Y, diff --git a/crates/cubecl-linalg/src/matmul/components/tile/accelerated.rs b/crates/cubecl-linalg/src/matmul/components/tile/accelerated.rs index 97068e8b..fabf58d4 100644 --- a/crates/cubecl-linalg/src/matmul/components/tile/accelerated.rs +++ b/crates/cubecl-linalg/src/matmul/components/tile/accelerated.rs @@ -207,7 +207,8 @@ fn read_accumulator( slice: &mut SliceMut>, #[comptime] n: u32, ) { - cmma::store(slice, out, n, cmma::MatrixLayout::RowMajor); + let acc = cmma::cast::(out); + cmma::store(slice, &acc, n, cmma::MatrixLayout::RowMajor); } fn check_availability( diff --git a/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/cmma.rs b/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/cmma.rs index 17bb1f43..3648401e 100644 --- a/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/cmma.rs +++ b/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/cmma.rs @@ -24,6 +24,7 @@ impl base::Algorithm for Cmma { type EG = EG; type ES = half::f16; type EA = f32; + // type EA = half::f16; type TileMatmul = Accelerated16x16x16; diff --git a/crates/cubecl-opt/src/instructions.rs b/crates/cubecl-opt/src/instructions.rs index b31ed108..6837931f 100644 --- a/crates/cubecl-opt/src/instructions.rs +++ b/crates/cubecl-opt/src/instructions.rs @@ -235,6 +235,7 @@ impl Optimizer { visit_read(self, mat); visit_read(self, stride); } + CoopMma::Cast { input } => todo!(), } } From bc35ac7cef24c0d3cd64e1f6a5ccf89a394700b2 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Tue, 26 Nov 2024 17:29:15 -0500 Subject: [PATCH 2/7] Test --- crates/cubecl-common/src/benchmark.rs | 2 +- .../cubecl-linalg/src/matmul/kernels/matmul/algorithm/cmma.rs | 3 +-- crates/cubecl-linalg/src/matmul/kernels/matmul/base.rs | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/crates/cubecl-common/src/benchmark.rs b/crates/cubecl-common/src/benchmark.rs index 57e837e6..c606e926 100644 --- a/crates/cubecl-common/src/benchmark.rs +++ b/crates/cubecl-common/src/benchmark.rs @@ -204,7 +204,7 @@ pub trait Benchmark { // Warmup let args = self.prepare(); - for _ in 0..10 { + for _ in 0..self.num_samples() { self.execute(args.clone()); } diff --git a/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/cmma.rs b/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/cmma.rs index 33054267..89991b7e 100644 --- a/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/cmma.rs +++ b/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/cmma.rs @@ -12,7 +12,7 @@ use crate::matmul::components::{batch, global}; use super::base; type Stage = stage::S8x8x2; -type Dispatch = batch::SwizzleTransposedDispatch<2>; +type Dispatch = batch::TransposedDispatch; pub struct Cmma { pub _eg: PhantomData, @@ -24,7 +24,6 @@ impl base::Algorithm for Cmma { type EG = EG; type ES = half::f16; type EA = f32; - // type EA = half::f16; type TileMatmul = Accelerated16x16x16; diff --git a/crates/cubecl-linalg/src/matmul/kernels/matmul/base.rs b/crates/cubecl-linalg/src/matmul/kernels/matmul/base.rs index 8b7f80ad..d1c91021 100644 --- a/crates/cubecl-linalg/src/matmul/kernels/matmul/base.rs +++ b/crates/cubecl-linalg/src/matmul/kernels/matmul/base.rs @@ -156,7 +156,7 @@ pub(crate) fn matmul_cube_preparation> let cube_count = D::cube_count(&problem); let advanced_config = AdvancedConfig { - lhs_tiling_order: matmul::components::stage::TilingOrderConfig::ColMajor, + lhs_tiling_order: 
matmul::components::stage::TilingOrderConfig::RowMajor, rhs_tiling_order: matmul::components::stage::TilingOrderConfig::RowMajor, enforced_tile_layout: (None, None), }; From 14e6d5a38506cf0336cd8118ee51bfacf2f0d5d9 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Wed, 27 Nov 2024 08:47:50 -0500 Subject: [PATCH 3/7] WIP --- .../src/hip/wmma/intrinsic_compiler.rs | 10 +++++++++- crates/cubecl-opt/src/instructions.rs | 4 +++- crates/cubecl-spirv/src/cmma.rs | 16 ++++++++++++++++ crates/cubecl-wgpu/src/lib.rs | 1 + 4 files changed, 29 insertions(+), 2 deletions(-) diff --git a/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs b/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs index e4415774..accb853e 100644 --- a/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs +++ b/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs @@ -207,7 +207,15 @@ for (uint i = 0; i < uint(8); ++i) {{ " ) } - WmmaInstruction::Cast { input, output } => todo!(), + WmmaInstruction::Cast { input, output } => { + write!( + f, + "for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{ + {output}[elemIdx] = {input}[elemIdx]; +}} + " + ) + } } } diff --git a/crates/cubecl-opt/src/instructions.rs b/crates/cubecl-opt/src/instructions.rs index 6837931f..470be4ba 100644 --- a/crates/cubecl-opt/src/instructions.rs +++ b/crates/cubecl-opt/src/instructions.rs @@ -235,7 +235,9 @@ impl Optimizer { visit_read(self, mat); visit_read(self, stride); } - CoopMma::Cast { input } => todo!(), + CoopMma::Cast { input } => { + visit_read(self, input); + } } } diff --git a/crates/cubecl-spirv/src/cmma.rs b/crates/cubecl-spirv/src/cmma.rs index 1d03f2ad..e4a4351d 100644 --- a/crates/cubecl-spirv/src/cmma.rs +++ b/crates/cubecl-spirv/src/cmma.rs @@ -32,6 +32,7 @@ impl SpirvCompiler { layout, .. } => self.compile_store(mat, out, stride, layout), + CoopMma::Cast { input } => self.compile_cast(out, input), } } @@ -143,6 +144,21 @@ impl SpirvCompiler { self.store(mat_d.id, mat_d_id, None, vec![]).unwrap(); } + fn compile_cast(&mut self, input: core::Variable, output: core::Variable) { + let input = self.compile_variable(input); + let output = self.compile_variable(output); + + let input = self.matrix_var(&input).2; + let output = self.matrix_var(&output).2; + + let result_ty = self.item(&output); + let ty = result_ty.id(self); + let fragment_id = self.load(ty, None, input.id, None, vec![]).unwrap(); + let frag_new = self.f_convert(ty, None, fragment_id).unwrap(); + + self.store(output.id, frag_new, None, vec![]).unwrap(); + } + fn matrix_var(&mut self, var: &Variable) -> (u16, u8, Matrix) { let (id, depth) = match var { Variable::CoopMatrix(id, depth, _) => (*id, *depth), diff --git a/crates/cubecl-wgpu/src/lib.rs b/crates/cubecl-wgpu/src/lib.rs index 5a0a47ca..84cbd47e 100644 --- a/crates/cubecl-wgpu/src/lib.rs +++ b/crates/cubecl-wgpu/src/lib.rs @@ -39,4 +39,5 @@ mod tests_spirv { cubecl_core::testgen_all!(f32: [f16, flex32, f32, f64], i32: [i8, i16, i32, i64], u32: [u8, u16, u32, u64]); cubecl_linalg::testgen_plane_mma!([f16, flex32, f32], f32); cubecl_linalg::testgen_tiling2d!([f16, flex32, f32, f64]); + cubecl_linalg::testgen_cmma_matmul!([f16]); } From 2b097289181275686bed6f3bf0601202b10e0037 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Wed, 27 Nov 2024 09:18:48 -0500 Subject: [PATCH 4/7] WIP --- crates/cubecl-core/src/runtime_tests/cmma.rs | 6 +++--- crates/cubecl-spirv/src/cmma.rs | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/crates/cubecl-core/src/runtime_tests/cmma.rs 
b/crates/cubecl-core/src/runtime_tests/cmma.rs index acd8103c..64f95263 100644 --- a/crates/cubecl-core/src/runtime_tests/cmma.rs +++ b/crates/cubecl-core/src/runtime_tests/cmma.rs @@ -179,11 +179,11 @@ pub fn test_cmma_cast_acc( cube_dimensions: CubeDim, ) { if !client.properties().feature_enabled(Feature::Cmma { - a: Elem::Float(FloatKind::TF32), - b: Elem::Float(FloatKind::TF32), + a: Elem::Float(FloatKind::F16), + b: Elem::Float(FloatKind::F16), c: Elem::Float(FloatKind::F32), m: 16, - k: 8, + k: 16, n: 16, }) { // We can't execute the test, skip. diff --git a/crates/cubecl-spirv/src/cmma.rs b/crates/cubecl-spirv/src/cmma.rs index e4a4351d..a8f2be9b 100644 --- a/crates/cubecl-spirv/src/cmma.rs +++ b/crates/cubecl-spirv/src/cmma.rs @@ -151,9 +151,10 @@ impl SpirvCompiler { let input = self.matrix_var(&input).2; let output = self.matrix_var(&output).2; - let result_ty = self.item(&output); + let result_ty = self.item(&input); let ty = result_ty.id(self); let fragment_id = self.load(ty, None, input.id, None, vec![]).unwrap(); + let frag_new = self.f_convert(ty, None, fragment_id).unwrap(); self.store(output.id, frag_new, None, vec![]).unwrap(); From d184a73644e376df9807bbfe3d71cb9d7808124e Mon Sep 17 00:00:00 2001 From: nathaniel Date: Wed, 27 Nov 2024 09:59:15 -0500 Subject: [PATCH 5/7] Fix --- crates/cubecl-spirv/src/cmma.rs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/crates/cubecl-spirv/src/cmma.rs b/crates/cubecl-spirv/src/cmma.rs index a8f2be9b..bcbe4335 100644 --- a/crates/cubecl-spirv/src/cmma.rs +++ b/crates/cubecl-spirv/src/cmma.rs @@ -32,7 +32,7 @@ impl SpirvCompiler { layout, .. } => self.compile_store(mat, out, stride, layout), - CoopMma::Cast { input } => self.compile_cast(out, input), + CoopMma::Cast { input } => self.compile_cast(input, out), } } @@ -151,11 +151,12 @@ impl SpirvCompiler { let input = self.matrix_var(&input).2; let output = self.matrix_var(&output).2; - let result_ty = self.item(&input); - let ty = result_ty.id(self); - let fragment_id = self.load(ty, None, input.id, None, vec![]).unwrap(); + let input_ty = self.item(&input).id(self); + let output_ty = self.item(&output).id(self); - let frag_new = self.f_convert(ty, None, fragment_id).unwrap(); + let fragment_id = self.load(input_ty, None, input.id, None, vec![]).unwrap(); + + let frag_new = self.f_convert(output_ty, None, fragment_id).unwrap(); self.store(output.id, frag_new, None, vec![]).unwrap(); } From 3be83c6b17a069fe22b2c2297426c1e801488c16 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Wed, 27 Nov 2024 14:35:10 -0500 Subject: [PATCH 6/7] Fix --- crates/cubecl-cpp/src/cuda/wmma/cuda_compiler.rs | 3 +-- crates/cubecl/benches/matmul.rs | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/crates/cubecl-cpp/src/cuda/wmma/cuda_compiler.rs b/crates/cubecl-cpp/src/cuda/wmma/cuda_compiler.rs index dc725b23..d4d7f0b2 100644 --- a/crates/cubecl-cpp/src/cuda/wmma/cuda_compiler.rs +++ b/crates/cubecl-cpp/src/cuda/wmma/cuda_compiler.rs @@ -59,8 +59,7 @@ impl WmmaCompiler> for CudaWmmaCompiler { fn supported_wmma_combinations(arch: &Self::Architecture) -> SupportedWmmaCombinations { let mut result: SupportedWmmaCombinations = vec![]; if arch.version >= WMMA_MINIMUM_VERSION { - // m n k - let tdims = vec![(16, 16, 16), (32, 16, 8), (8, 16, 32)]; + let tdims = vec![(16, 16, 16), (32, 16, 8), (8, 16, 32), (32, 8, 16)]; // Types fully supported. 
let types = vec![ ( diff --git a/crates/cubecl/benches/matmul.rs b/crates/cubecl/benches/matmul.rs index 195dabd8..150faf20 100644 --- a/crates/cubecl/benches/matmul.rs +++ b/crates/cubecl/benches/matmul.rs @@ -55,8 +55,7 @@ struct MatmulBench { fn run(device: R::Device, strategy: matmul::Strategy) { let client = R::client(&device); - { - let (b, m, n, k) = (2, 4096, 4096, 4096); + for (b, m, n, k) in [(2, 4096, 4096, 4096), (2, 4096, 2040, 4096)] { let bench = MatmulBench:: { b, m, From e9c0d44933d0e6ecd1f1b2facac1568b4957a143 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Wed, 27 Nov 2024 15:53:19 -0500 Subject: [PATCH 7/7] Fix --- crates/cubecl-core/src/runtime_tests/cmma.rs | 4 ++-- crates/cubecl-cpp/src/cuda/wmma/cuda_compiler.rs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/cubecl-core/src/runtime_tests/cmma.rs b/crates/cubecl-core/src/runtime_tests/cmma.rs index 64f95263..e53475fe 100644 --- a/crates/cubecl-core/src/runtime_tests/cmma.rs +++ b/crates/cubecl-core/src/runtime_tests/cmma.rs @@ -216,8 +216,8 @@ pub fn test_simple_tf32( cube_dimensions: CubeDim, ) { if !client.properties().feature_enabled(Feature::Cmma { - a: Elem::Float(FloatKind::F16), - b: Elem::Float(FloatKind::F16), + a: Elem::Float(FloatKind::TF32), + b: Elem::Float(FloatKind::TF32), c: Elem::Float(FloatKind::F32), m: 16, k: 16, diff --git a/crates/cubecl-cpp/src/cuda/wmma/cuda_compiler.rs b/crates/cubecl-cpp/src/cuda/wmma/cuda_compiler.rs index d4d7f0b2..9222300d 100644 --- a/crates/cubecl-cpp/src/cuda/wmma/cuda_compiler.rs +++ b/crates/cubecl-cpp/src/cuda/wmma/cuda_compiler.rs @@ -59,7 +59,7 @@ impl WmmaCompiler> for CudaWmmaCompiler { fn supported_wmma_combinations(arch: &Self::Architecture) -> SupportedWmmaCombinations { let mut result: SupportedWmmaCombinations = vec![]; if arch.version >= WMMA_MINIMUM_VERSION { - let tdims = vec![(16, 16, 16), (32, 16, 8), (8, 16, 32), (32, 8, 16)]; + let tdims = vec![(16, 16, 16), (32, 8, 16), (8, 32, 16)]; // Types fully supported. let types = vec![ (