Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add complex numbers #849

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ half = { version = "2.3.1", optional = true, features = ["num-traits", "rand_dis
gemm = { version = "0.15.4", default-features = false, optional = true }
rayon = { version = "1.7.0", optional = true }
libm = "0.2.7"
num-complex = {version = "0.4.0", optional = true}

[dev-dependencies]
tempfile = "3.3.0"
Expand Down Expand Up @@ -70,6 +71,8 @@ test-f64 = []
test-integrations = []
ci-check = ["cudarc?/ci-check"]

complex = ["dep:num-complex"]

[[bench]]
name = "batchnorm2d"
harness = false
Expand Down
112 changes: 111 additions & 1 deletion src/dtypes/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Module for data type related traits and structs. Contains things like [Unit], [Dtype], and [AMP].
//!
//! When the `f16` feature is enabled, this exports the [f16] type.

Check warning on line 3 in src/dtypes/mod.rs

View workflow job for this annotation

GitHub Actions / cargo-check

unresolved link to `f16`
//!
//! # AMP
//!
Expand All @@ -14,6 +14,110 @@
#[cfg(feature = "f16")]
pub use half::f16;

#[cfg(feature = "complex")]
pub mod complex {
use core::ops::{Deref, DerefMut};

#[cfg(feature = "cuda")]
use cudarc::driver::{DeviceRepr, ValidAsZeroBits};
use num_complex::Complex32;
use num_traits::{FromPrimitive, ToPrimitive};

#[derive(PartialEq, Debug, Default, Clone, Copy)]
pub struct Complex(Complex32);
coreylowman marked this conversation as resolved.
Show resolved Hide resolved
impl Deref for Complex {
type Target = Complex32;

fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for Complex {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
const fn c1() -> Complex {
Complex(Complex32 { re: 1.0, im: 0.0 })
}
impl Complex {
pub const ONE: Complex = c1();
pub fn new(r: f32, i: f32) -> Self {
Self(num_complex::Complex { re: r, im: i })
}
}
impl FromPrimitive for Complex {
fn from_i64(n: i64) -> Option<Self> {
Some(Complex(Complex32::from_i64(n)?))
}

fn from_u64(n: u64) -> Option<Self> {
Some(Complex(Complex32::from_u64(n)?))
}
}
impl ToPrimitive for Complex {
fn to_i64(&self) -> Option<i64> {
self.0.to_i64()
}

fn to_u64(&self) -> Option<u64> {
self.0.to_u64()
}
}

impl std::ops::Add<Self> for Complex {
type Output = Self;
fn add(self, rhs: Self) -> Self::Output {
Self(self.0 + rhs.0)
}
}
impl std::ops::Sub<Self> for Complex {
type Output = Self;

fn sub(self, rhs: Self) -> Self::Output {
Self(self.0 - rhs.0)
}
}
impl std::ops::Mul<Self> for Complex {
type Output = Self;

fn mul(self, rhs: Self) -> Self::Output {
Self(self.0 * rhs.0)
}
}
impl std::ops::Div<Self> for Complex {
type Output = Self;

fn div(self, rhs: Self) -> Self::Output {
Self(self.0 / rhs.0)
}
}
impl std::ops::AddAssign for Complex {
fn add_assign(&mut self, rhs: Self) {
self.0.add_assign(rhs.0)
}
}
impl std::ops::SubAssign for Complex {
fn sub_assign(&mut self, rhs: Self) {
self.0.sub_assign(rhs.0)
}
}
impl std::ops::MulAssign for Complex {
fn mul_assign(&mut self, rhs: Self) {
self.0.mul_assign(rhs.0)
}
}
impl std::ops::DivAssign for Complex {
fn div_assign(&mut self, rhs: Self) {
self.0.div_assign(rhs.0)
}
}
#[cfg(feature = "cuda")]
unsafe impl ValidAsZeroBits for Complex {}
#[cfg(feature = "cuda")]
unsafe impl DeviceRepr for Complex {}
Comment on lines +115 to +118
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can add these to cudarc in a PR there behind a feature flag, that should allow us to not need the wrapper type, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was unaware that you controlled cudarc. That would work yes.

}

/// Represents a type where all 0 bits is a valid pattern.
#[cfg(not(feature = "cuda"))]
pub trait SafeZeros {}
Expand All @@ -30,7 +134,7 @@
+ Default
+ std::fmt::Debug
+ PartialEq
+ PartialOrd
// + PartialOrd
+ Send
+ Sync
+ std::marker::Unpin
Expand Down Expand Up @@ -65,6 +169,8 @@
unit!(bool, true);
#[cfg(feature = "f16")]
unit!(f16, f16::ONE);
#[cfg(feature = "complex")]
unit!(complex::Complex, complex::Complex::ONE);

/// Represents something that has a [Unit].
pub trait HasUnitType {
Expand Down Expand Up @@ -105,6 +211,8 @@
impl Dtype for usize {}
#[cfg(feature = "f16")]
impl Dtype for f16 {}
#[cfg(feature = "complex")]
impl Dtype for complex::Complex {}

/// Represents something that has a [Dtype].
pub trait HasDtype {
Expand All @@ -129,3 +237,5 @@
impl NotMixedPrecision for usize {}
#[cfg(feature = "f16")]
impl NotMixedPrecision for f16 {}
#[cfg(feature = "complex")]
impl NotMixedPrecision for complex::Complex {}
12 changes: 6 additions & 6 deletions src/tensor_ops/cmp/cpu_kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,37 +48,37 @@ impl<Op: CmpOpCpuKernel<E>, E: Unit> ScalarCmpKernel<Op, E> for Cpu {
}
}

impl<E: Unit> CmpOpCpuKernel<E> for EqKernelOp {
impl<E: Unit + PartialOrd> CmpOpCpuKernel<E> for EqKernelOp {
fn func(lhs: E, rhs: E) -> bool {
lhs == rhs
}
}

impl<E: Unit> CmpOpCpuKernel<E> for NeKernelOp {
impl<E: Unit + PartialOrd> CmpOpCpuKernel<E> for NeKernelOp {
fn func(lhs: E, rhs: E) -> bool {
lhs != rhs
}
}

impl<E: Unit> CmpOpCpuKernel<E> for GtKernelOp {
impl<E: Unit + PartialOrd> CmpOpCpuKernel<E> for GtKernelOp {
fn func(lhs: E, rhs: E) -> bool {
lhs > rhs
}
}

impl<E: Unit> CmpOpCpuKernel<E> for GeKernelOp {
impl<E: Unit + PartialOrd> CmpOpCpuKernel<E> for GeKernelOp {
fn func(lhs: E, rhs: E) -> bool {
lhs >= rhs
}
}

impl<E: Unit> CmpOpCpuKernel<E> for LtKernelOp {
impl<E: Unit + PartialOrd> CmpOpCpuKernel<E> for LtKernelOp {
fn func(lhs: E, rhs: E) -> bool {
lhs < rhs
}
}

impl<E: Unit> CmpOpCpuKernel<E> for LeKernelOp {
impl<E: Unit + PartialOrd> CmpOpCpuKernel<E> for LeKernelOp {
fn func(lhs: E, rhs: E) -> bool {
lhs <= rhs
}
Expand Down
Loading