Cast fragment #311

Merged: 9 commits, Nov 27, 2024
2 changes: 1 addition & 1 deletion crates/cubecl-common/src/benchmark.rs
@@ -204,7 +204,7 @@ pub trait Benchmark {
// Warmup
let args = self.prepare();

for _ in 0..10 {
for _ in 0..self.num_samples() {
self.execute(args.clone());
}

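For context, this hunk replaces a hard-coded warmup count with the benchmark's own sample count. A minimal sketch of the trait shape involved: `prepare` and `execute` appear in the hunk above, while the `Args` associated type and the default body are illustrative assumptions, not the actual cubecl-common definition.

```rust
// Illustrative subset of the `Benchmark` trait; the real trait has more items.
trait Benchmark {
    /// Argument bundle passed to each execution (name assumed for illustration).
    type Args: Clone;

    /// Prepares the arguments once, outside the timed region.
    fn prepare(&self) -> Self::Args;

    /// Runs the benchmarked workload once.
    fn execute(&self, args: Self::Args);

    /// Number of samples to collect; after this change the warmup loop also
    /// runs `num_samples()` iterations instead of a fixed 10.
    fn num_samples(&self) -> usize {
        10
    }
}
```
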
51 changes: 51 additions & 0 deletions crates/cubecl-core/src/frontend/cmma.rs
@@ -405,6 +405,57 @@ pub mod execute {
}
}

/// Cast a fragment to another element type.
#[allow(unused_variables)]
pub fn cast<C: CubePrimitive, O: CubePrimitive>(input: &Matrix<C>) -> Matrix<O> {
unexpanded!()
}

/// Module containing the expand function for [cast()].
pub mod cast {
use super::*;

/// Expand method of [cast()].
#[allow(unused_variables)]
pub fn expand<C: CubePrimitive, O: CubePrimitive>(
context: &mut CubeContext,
input: MatrixExpand<C>,
) -> MatrixExpand<O> {
let ident = input.ident;

if core::any::TypeId::of::<C>() == core::any::TypeId::of::<O>() {
return MatrixExpand {
elem: input.elem,
ident,
_c: PhantomData,
};
}
let input = *input.elem;
let input_mat = match input.kind {
ir::VariableKind::Matrix { mat, .. } => mat,
_ => unreachable!(),
};

let elem = context.create_matrix(ir::Matrix {
ident,
m: input_mat.m,
n: input_mat.n,
k: input_mat.k,
elem: O::as_elem(),
layout: MatrixLayout::Undefined,
});

let output = MatrixExpand {
ident,
elem,
_c: PhantomData,
};
context.register(Instruction::new(ir::CoopMma::Cast { input }, *output.elem));

output
}
}

impl From<ir::CoopMma> for Operation {
fn from(value: ir::CoopMma) -> Self {
Operation::CoopMma(value)
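The user-facing call is a one-liner. A sketch of its use inside a kernel, assuming an f32 accumulator fragment `acc` is already loaded (the `cast_matrix` runtime test below is the complete, runnable version):

```rust
// Element-wise cast of an accumulator fragment to f16. When the source and
// target element types are identical, the expand path above short-circuits
// and reuses the existing fragment instead of emitting a Cast instruction.
let half_acc = cmma::cast::<f32, f16>(&acc);
```
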
5 changes: 5 additions & 0 deletions crates/cubecl-core/src/ir/cmma.rs
@@ -69,6 +69,8 @@ pub enum CoopMma {
stride: Variable,
layout: MatrixLayout,
},
/// Cast a fragment to another type.
Cast { input: Variable },
}

impl Display for CoopMma {
@@ -99,6 +101,9 @@ impl Display for CoopMma {
"matrix_store({}, stride: {}, layout: {:?})",
mat, stride, layout
),
CoopMma::Cast { input } => {
write!(f, "matrix_cast(input: {})", input)
}
}
}
}
3 changes: 3 additions & 0 deletions crates/cubecl-core/src/ir/processing.rs
@@ -312,6 +312,9 @@ impl ScopeProcessing {
CoopMma::Store { stride, .. } => {
sanitize_constant_scalar_ref_elem(stride, u32::as_elem());
}
CoopMma::Cast { .. } => {
// Nothing to do.
}
},
});
self
72 changes: 71 additions & 1 deletion crates/cubecl-core/src/runtime_tests/cmma.rs
@@ -87,6 +87,29 @@ pub fn kernel_simple_tf32(lhs: &Array<tf32>, rhs: &Array<tf32>, out: &mut Array<
);
}

#[cube(launch)]
pub fn cast_matrix(input: &Array<f32>, out: &mut Array<f16>) {
let acc = unsafe {
cmma::Matrix::<f32>::uninitialized(
cmma::MatrixIdent::Accumulator,
16,
16,
16,
cmma::MatrixLayout::Undefined,
)
};
cmma::load_with_layout(&acc, &input.to_slice(), 16, cmma::MatrixLayout::RowMajor);

let output = cmma::cast::<f32, f16>(&acc);

cmma::store(
&mut out.to_slice_mut(),
&output,
16,
cmma::MatrixLayout::RowMajor,
);
}

pub fn test_simple_1<R: Runtime>(
client: ComputeClient<R::Server, R::Channel>,
cube_dimensions: CubeDim,
@@ -151,6 +174,43 @@ pub fn test_simple_1<R: Runtime>(
assert_eq!(expected, actual);
}

pub fn test_cmma_cast_acc<R: Runtime>(
client: ComputeClient<R::Server, R::Channel>,
cube_dimensions: CubeDim,
) {
if !client.properties().feature_enabled(Feature::Cmma {
a: Elem::Float(FloatKind::F16),
b: Elem::Float(FloatKind::F16),
c: Elem::Float(FloatKind::F32),
m: 16,
k: 16,
n: 16,
}) {
// We can't execute the test, skip.
return;
}

let input: Vec<f32> = (0..256).map(|i| i as f32).collect();
let input = client.create(f32::as_bytes(&input));
let out = client.empty(core::mem::size_of::<f16>() * 256);

unsafe {
cast_matrix::launch::<R>(
&client,
CubeCount::Static(1, 1, 1),
cube_dimensions,
ArrayArg::from_raw_parts::<f32>(&input, 256, 1),
ArrayArg::from_raw_parts::<f16>(&out, 256, 1),
)
};

let actual = client.read_one(out.binding());
let actual = f16::from_bytes(&actual);
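// Integers 0..=255 are exactly representable in f16, so the exact-equality
// assertion below is sound.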
let expected: Vec<f16> = (0..256).map(|i| f16::from_f32(i as f32)).collect();

assert_eq!(actual, expected);
}

pub fn test_simple_tf32<R: Runtime>(
client: ComputeClient<R::Server, R::Channel>,
cube_dimensions: CubeDim,
@@ -160,7 +220,7 @@ pub fn test_simple_tf32<R: Runtime>(
b: Elem::Float(FloatKind::TF32),
c: Elem::Float(FloatKind::F32),
m: 16,
k: 8,
k: 16,
n: 16,
}) {
// We can't execute the test, skip.
@@ -243,5 +303,15 @@ macro_rules! testgen_cmma {
cube_dimensions,
);
}

#[test]
fn test_cmma_cast_acc() {
let client = TestRuntime::client(&Default::default());
let cube_dimensions = CubeDim::new(32, 1, 1);
cubecl_core::runtime_tests::cmma::test_cmma_cast_acc::<TestRuntime>(
client,
cube_dimensions,
);
}
};
}
3 changes: 1 addition & 2 deletions crates/cubecl-cpp/src/cuda/wmma/cuda_compiler.rs
@@ -59,8 +59,7 @@ impl WmmaCompiler<CudaDialect<Self>> for CudaWmmaCompiler {
fn supported_wmma_combinations(arch: &Self::Architecture) -> SupportedWmmaCombinations {
let mut result: SupportedWmmaCombinations = vec![];
if arch.version >= WMMA_MINIMUM_VERSION {
// m n k
let tdims = vec![(16, 16, 16), (32, 16, 8), (8, 16, 32)];
let tdims = vec![(16, 16, 16), (32, 8, 16), (8, 32, 16)];
// Types fully supported.
let types = vec![
(
Expand Down
9 changes: 9 additions & 0 deletions crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs
@@ -203,6 +203,15 @@ for (uint i = 0; i < uint(8); ++i) {{
"for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
const uint rowIdx = elemIdx * uint(2) + threadIdx.x / uint(16);
{output_ident}[{output_idx}] = {frag}[{frag_idx}];
}}
"
)
}
WmmaInstruction::Cast { input, output } => {
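// Element-wise copy between the raw fragment arrays; the implicit C++
// conversion between the two element types performs the cast.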
write!(
f,
"for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
{output}[elemIdx] = {input}[elemIdx];
}}
"
)
Expand Down
4 changes: 4 additions & 0 deletions crates/cubecl-cpp/src/shared/base.rs
@@ -326,6 +326,10 @@ impl<D: Dialect> CppCompiler<D> {
.compile_matrix_layout(layout)
.expect("Layout required for store instruction"),
}),
gpu::CoopMma::Cast { input } => Instruction::Wmma(WmmaInstruction::Cast {
input: self.compile_variable(input),
output: out,
}),
}
}

Expand Down
11 changes: 11 additions & 0 deletions crates/cubecl-cpp/src/shared/mma.rs
@@ -122,6 +122,11 @@ pub enum WmmaInstruction<D: Dialect> {
stride: Variable<D>,
layout: FragmentLayout<D>,
},
/// Cast a fragment to a different element type.
Cast {
input: Variable<D>,
output: Variable<D>,
},
}

impl<D: Dialect> Display for FragmentLayout<D> {
@@ -287,6 +292,12 @@ pub mod wmma_api_base {
)
}
}
WmmaInstruction::Cast { input, output } => {
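// Convert element by element through the WMMA fragment API's public `x`
// array and `num_elements` member.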
writeln!(
f,
"for(int t=0; t<{input}.num_elements; t++) {{ {output}.x[t] = {input}.x[t]; }}"
)
}
}
}
}
(another file; path truncated in the original diff view)
@@ -99,7 +99,7 @@ where
stage_config.stage_dim(Ident::Out).tile_num_elements() / out_smem_line_size;

let start = num_tile_lines * UNIT_POS_Y;
let mut out_smem = SharedMemory::<EA>::new_lined(
let mut out_smem = SharedMemory::<O>::new_lined(
num_tile_lines * stage_config.num_planes(),
out_smem_line_size,
);
@@ -109,7 +109,7 @@
let accumulator = acc.index(accumulator_iter);
let mut smem_slice = out_smem.slice_mut(start, start + num_tile_lines);
TMM::read_accumulator(accumulator, &mut smem_slice, stage_config.to_tmm_config());
SW::write::<EA, G>(
SW::write::<O, G>(
out,
smem_slice.to_slice(),
UNIT_POS_Y,
(another file; path truncated in the original diff view)
@@ -123,7 +123,7 @@ where
stage_config.stage_dim(Ident::Out).tile_num_elements() / out_smem_line_size;

let start = num_tile_lines * UNIT_POS_Y;
let mut out_smem = SharedMemory::<EA>::new_lined(
let mut out_smem = SharedMemory::<O>::new_lined(
num_tile_lines * stage_config.num_planes(),
out_smem_line_size,
);
@@ -133,7 +133,7 @@
let accumulator = acc.index(accumulator_iter);
let mut smem_slice = out_smem.slice_mut(start, start + num_tile_lines);
TMM::read_accumulator(accumulator, &mut smem_slice, stage_config.to_tmm_config());
SW::write::<EA, G>(
SW::write::<O, G>(
out,
smem_slice.to_slice(),
UNIT_POS_Y,
(another file; path truncated in the original diff view)
@@ -223,7 +223,8 @@ fn read_accumulator<O: Numeric, C: Numeric>(
slice: &mut SliceMut<Line<C>>,
#[comptime] n: u32,
) {
cmma::store(slice, out, n, cmma::MatrixLayout::RowMajor);
let acc = cmma::cast::<O, C>(out);
cmma::store(slice, &acc, n, cmma::MatrixLayout::RowMajor);
}

fn check_availability<I: Numeric, O: Numeric, R: Runtime>(
3 changes: 3 additions & 0 deletions crates/cubecl-opt/src/instructions.rs
@@ -235,6 +235,9 @@ impl Optimizer {
visit_read(self, mat);
visit_read(self, stride);
}
CoopMma::Cast { input } => {
visit_read(self, input);
}
}
}

18 changes: 18 additions & 0 deletions crates/cubecl-spirv/src/cmma.rs
@@ -32,6 +32,7 @@ impl<T: SpirvTarget> SpirvCompiler<T> {
layout,
..
} => self.compile_store(mat, out, stride, layout),
CoopMma::Cast { input } => self.compile_cast(input, out),
}
}

@@ -143,6 +144,23 @@ impl<T: SpirvTarget> SpirvCompiler<T> {
self.store(mat_d.id, mat_d_id, None, vec![]).unwrap();
}

fn compile_cast(&mut self, input: core::Variable, output: core::Variable) {
let input = self.compile_variable(input);
let output = self.compile_variable(output);

let input = self.matrix_var(&input).2;
let output = self.matrix_var(&output).2;

let input_ty = self.item(&input).id(self);
let output_ty = self.item(&output).id(self);

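// Load the whole input fragment, then convert it element-wise; `f_convert`
// emits a SPIR-V OpFConvert, so this path covers float-to-float fragment casts.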
let fragment_id = self.load(input_ty, None, input.id, None, vec![]).unwrap();

let frag_new = self.f_convert(output_ty, None, fragment_id).unwrap();

self.store(output.id, frag_new, None, vec![]).unwrap();
}

fn matrix_var(&mut self, var: &Variable) -> (u16, u8, Matrix) {
let (id, depth) = match var {
Variable::CoopMatrix(id, depth, _) => (*id, *depth),
1 change: 1 addition & 0 deletions crates/cubecl-wgpu/src/lib.rs
@@ -39,4 +39,5 @@ mod tests_spirv {
cubecl_core::testgen_all!(f32: [f16, flex32, f32, f64], i32: [i8, i16, i32, i64], u32: [u8, u16, u32, u64]);
cubecl_linalg::testgen_plane_mma!([f16, flex32, f32], f32);
cubecl_linalg::testgen_tiling2d!([f16, flex32, f32, f64]);
cubecl_linalg::testgen_cmma_matmul!([f16]);
}
3 changes: 1 addition & 2 deletions crates/cubecl/benches/matmul.rs
@@ -55,8 +55,7 @@ struct MatmulBench<R: Runtime, E> {
fn run<R: Runtime, E: Float>(device: R::Device, strategy: matmul::Strategy) {
let client = R::client(&device);

{
let (b, m, n, k) = (2, 4096, 4096, 4096);
for (b, m, n, k) in [(2, 4096, 4096, 4096), (2, 4096, 2040, 4096)] {
let bench = MatmulBench::<R, E> {
b,
m,