Skip to content

Commit

Permalink
Matmul: better error message (minor PR) (#313)
Browse files Browse the repository at this point in the history
  • Loading branch information
louisfd authored Nov 27, 2024
1 parent 6de4938 commit 5296f55
Show file tree
Hide file tree
Showing 10 changed files with 30 additions and 15 deletions.
2 changes: 1 addition & 1 deletion crates/cubecl-linalg/src/matmul/components/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub trait MatmulKernel<I: Numeric, O: Numeric> {
/// Checks if the client can handle the features used in this computation
fn check_availability<R: Runtime>(
client: &ComputeClient<R::Server, R::Channel>,
) -> Result<(), &str>;
) -> Result<(), String>;

/// Create config for this matmul, given launch information
fn make_config(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ impl<EG: Numeric, ES: Numeric, GMM: global::Matmul<EG, ES>, S: SpanMatmul, C: Cu

fn check_availability<R: Runtime>(
client: &ComputeClient<R::Server, R::Channel>,
) -> Result<(), &str> {
) -> Result<(), String> {
GMM::check_availability::<R>(client)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ impl<EG: Numeric, ES: Numeric, GMM: global::Matmul<EG, ES>, C: CubeDispatch> Mat

fn check_availability<R: Runtime>(
client: &ComputeClient<R::Server, R::Channel>,
) -> Result<(), &str> {
) -> Result<(), String> {
GMM::check_availability::<R>(client)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ where

fn check_availability<R: Runtime>(
client: &ComputeClient<R::Server, R::Channel>,
) -> Result<(), &str> {
) -> Result<(), String> {
SMM::check_availability::<R>(client)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ where

fn check_availability<R: Runtime>(
client: &ComputeClient<R::Server, R::Channel>,
) -> Result<(), &str> {
) -> Result<(), String> {
SMM::check_availability::<R>(client)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ where

fn check_availability<R: Runtime>(
client: &ComputeClient<R::Server, R::Channel>,
) -> Result<(), &str> {
) -> Result<(), String> {
TMM::check_availability::<R>(client)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ where

fn check_availability<R: Runtime>(
client: &ComputeClient<R::Server, R::Channel>,
) -> Result<(), &str> {
) -> Result<(), String> {
TMM::check_availability::<R>(client)
}

Expand Down
19 changes: 15 additions & 4 deletions crates/cubecl-linalg/src/matmul/components/tile/accelerated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ macro_rules! instruction {

fn check_availability<R: Runtime>(
client: &ComputeClient<R::Server, R::Channel>,
) -> Result<(), &str> {
) -> Result<(), String> {
check_availability::<I, O, R>(Self::M, Self::N, Self::K, client)
}

Expand Down Expand Up @@ -232,7 +232,7 @@ fn check_availability<I: Numeric, O: Numeric, R: Runtime>(
n: u32,
k: u32,
client: &ComputeClient<R::Server, R::Channel>,
) -> Result<(), &str> {
) -> Result<(), String> {
if !client.properties().feature_enabled(Feature::Cmma {
a: I::as_elem(),
b: I::as_elem(),
Expand All @@ -241,7 +241,14 @@ fn check_availability<I: Numeric, O: Numeric, R: Runtime>(
k: k as u8,
n: n as u8,
}) {
return Err("Cmma not supported.");
return Err(format!(
"Cmma on inputs {:?} and outputs {:?} with shape m={:?}, n={:?}, k={:?} not supported.",
I::as_elem(),
O::as_elem(),
m,
n,
k
));
}

if !(client
Expand All @@ -251,7 +258,11 @@ fn check_availability<I: Numeric, O: Numeric, R: Runtime>(
.properties()
.feature_enabled(Feature::Type(O::as_elem())))
{
return Err("Types not supported.");
return Err(format!(
"Types {:?} and/or {:?} not supported.",
I::as_elem(),
O::as_elem()
));
}

Ok(())
Expand Down
10 changes: 7 additions & 3 deletions crates/cubecl-linalg/src/matmul/components/tile/plane.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,9 +361,9 @@ impl<I: Numeric, O: Numeric, const M: u32, const N: u32, const K: u32> MatmulKer

fn check_availability<R: Runtime>(
client: &ComputeClient<R::Server, R::Channel>,
) -> Result<(), &str> {
) -> Result<(), String> {
if !client.properties().feature_enabled(Feature::Plane) {
return Err("Planes not supported.");
return Err("Plane operations not supported.".to_string());
}

if !(client
Expand All @@ -373,7 +373,11 @@ impl<I: Numeric, O: Numeric, const M: u32, const N: u32, const K: u32> MatmulKer
.properties()
.feature_enabled(Feature::Type(O::as_elem())))
{
return Err("Types not supported.");
return Err(format!(
"Types {:?} and/or {:?} not supported.",
I::as_elem(),
O::as_elem()
));
}

Ok(())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ pub trait Algorithm<EG: Numeric> {

fn check_availability<R: Runtime>(
client: &ComputeClient<R::Server, R::Channel>,
) -> Result<(), &str> {
) -> Result<(), String> {
Self::BatchMatmul::check_availability::<R>(client)
}

Expand Down

0 comments on commit 5296f55

Please sign in to comment.