Skip to content

Commit

Permalink
added sin, cos and sub operations
Browse files Browse the repository at this point in the history
  • Loading branch information
kj3moraes committed Apr 6, 2024
1 parent ad495da commit ec8fb3b
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 9 deletions.
25 changes: 16 additions & 9 deletions src/ops.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::ops::{Add, Mul};
use std::ops::{Add, Mul, Sub};

use crate::variable::Variable;

Expand All @@ -14,6 +14,19 @@ impl<'a> Add for Variable<'a> {
}
}

// this is for self - rhs
impl<'a> Sub for Variable<'a> {
type Output = Variable<'a>;
fn sub(self, rhs: Self) -> Self::Output {
let position = self
.tape
.unwrap()
.push_binary(1.0, self.index, -1.0, rhs.index);
let new_value = self.value - rhs.value;
Variable::new(self.tape.unwrap(), position, new_value)
}
}

impl<'a> Mul for Variable<'a> {
type Output = Variable<'a>;
fn mul(self, rhs: Self) -> Self::Output {
Expand All @@ -26,12 +39,10 @@ impl<'a> Mul for Variable<'a> {
}
}

// Implementation for f64 * Variable
impl<'a> Mul<Variable<'a>> for f64 {
type Output = Variable<'a>;
fn mul(self, rhs: Variable<'a>) -> Self::Output {
let len = rhs.tape.unwrap().len();
let position = rhs.tape.unwrap().push_binary(0.0, len, self, rhs.index);
let position = rhs.tape.unwrap().push_unary(self, rhs.index);
let new_value = self * rhs.value;
Variable::new(rhs.tape.unwrap(), position, new_value)
}
Expand All @@ -40,11 +51,7 @@ impl<'a> Mul<Variable<'a>> for f64 {
impl<'a> Mul<Variable<'a>> for i32 {
type Output = Variable<'a>;
fn mul(self, rhs: Variable<'a>) -> Self::Output {
let len = rhs.tape.unwrap().len();
let position = rhs
.tape
.unwrap()
.push_binary(0.0, len, self as f64, rhs.index);
let position = rhs.tape.unwrap().push_unary(self as f64, rhs.index);
let new_value = self as f64 * rhs.value;
Variable::new(rhs.tape.unwrap(), position, new_value)
}
Expand Down
33 changes: 33 additions & 0 deletions src/variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,24 @@ impl<'a> Variable<'a> {
new_value,
)
}

pub fn sin(&self) -> Variable {
Variable::new(
self.tape.unwrap(),
self.tape.unwrap().push_unary(self.value.cos(), self.index),
self.value.sin(),
)
}

pub fn cos(&self) -> Variable {
Variable::new(
self.tape.unwrap(),
self.tape
.unwrap()
.push_unary(-1.0 * self.value.sin(), self.index),
self.value.cos(),
)
}
}

#[cfg(test)]
Expand Down Expand Up @@ -139,4 +157,19 @@ mod tests {
assert!((grad.wrt(&x) - 1.0).abs() <= 1e-15);
assert!((grad.wrt(&y) - 1.0).abs() <= 1e-15);
}

#[test]
fn test_multiple_operations() {
let t = Tape::new();
let x = t.var(0.5);
let y = t.var(4.2);
let z = x * y - x.sin();
let grad = z.grad();

// Check that the calculated value is correct
assert!((z.value - 1.620574461395797).abs() <= 1e-15);
// Assert that the gradients calculated are correct as well.
assert!((grad.wrt(&x) - (y - x.cos()).value).abs() <= 1e-15);
assert!((grad.wrt(&y) - x.value).abs() <= 1e-15);
}
}

0 comments on commit ec8fb3b

Please sign in to comment.