Skip to content

Commit

Permalink
Vector operations (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
kj3moraes authored Apr 12, 2024
1 parent faabb22 commit a360c8e
Show file tree
Hide file tree
Showing 5 changed files with 344 additions and 117 deletions.
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,6 @@
name = "minigrad"
version = "0.1.0"
edition = "2021"

[dependencies]
candle-core = "0.4.1"
9 changes: 5 additions & 4 deletions src/grad.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
use crate::variable::Variable;
use candle_core::Tensor;

#[derive(Debug)]
pub struct Gradient {
derivatives: Vec<f64>,
pub(crate) derivatives: Vec<Tensor>,
}

impl Gradient {
pub fn from(derivatives: Vec<f64>) -> Self {
pub fn from(derivatives: Vec<Tensor>) -> Self {
Self { derivatives }
}

pub fn wrt(&self, var: &Variable) -> f64 {
self.derivatives[var.index]
pub fn wrt(&self, var: &Variable) -> &Tensor {
&self.derivatives[var.index]
}
}
77 changes: 47 additions & 30 deletions src/ops.rs
Original file line number Diff line number Diff line change
@@ -1,58 +1,75 @@
use std::ops::{Add, Mul, Sub};
use candle_core::{DType, Device, Shape, Tensor};

use crate::tape::convert_to_tensor;
use crate::variable::Variable;
use std::ops::{Add, Mul, Sub};

impl<'a> Add for Variable<'a> {
type Output = Variable<'a>;
fn add(self, rhs: Self) -> Self::Output {
let position = self
.tape
.unwrap()
.push_binary(1.0, self.index, 1.0, rhs.index);
let new_value = self.value + rhs.value;
Variable::new(self.tape.unwrap(), position, new_value)
let new_value = (&self.value + &rhs.value).unwrap();
let n = rhs.value.shape().dims()[1];
println!("THe second dimension sizie is {}", n);
let position = self.tape.unwrap().push_binary(
Tensor::from_slice(&[1.0], (1, 1), &candle_core::Device::Cpu).unwrap(),
self.index,
Tensor::eye(n, DType::F64, &Device::Cpu).unwrap(),
rhs.index,
new_value.shape().clone(),
);
Variable::new_tensor(self.tape.unwrap(), position, new_value)
}
}

// this is for self - rhs
impl<'a> Sub for Variable<'a> {
type Output = Variable<'a>;
fn sub(self, rhs: Self) -> Self::Output {
let position = self
.tape
.unwrap()
.push_binary(1.0, self.index, -1.0, rhs.index);
let new_value = self.value - rhs.value;
Variable::new(self.tape.unwrap(), position, new_value)
let new_value = (&self.value - &rhs.value).unwrap();
let position = self.tape.unwrap().push_binary(
rhs.value.t().unwrap().ones_like().unwrap(),
self.index,
(-1.0 * self.value.t().unwrap().ones_like().unwrap()).unwrap(),
rhs.index,
new_value.shape().clone(),
);
Variable::new_tensor(self.tape.unwrap(), position, new_value)
}
}

impl<'a> Mul for Variable<'a> {
type Output = Variable<'a>;
fn mul(self, rhs: Self) -> Self::Output {
let position = self
.tape
.unwrap()
.push_binary(rhs.value, self.index, self.value, rhs.index);
let new_value = self.value * rhs.value;
Variable::new(self.tape.unwrap(), position, new_value)
let new_value = self.value.matmul(&rhs.value).unwrap();
let position = self.tape.unwrap().push_binary(
rhs.value.t().unwrap().clone(),
self.index,
self.value.t().unwrap().clone(),
rhs.index,
new_value.shape().clone(),
);
Variable::new_tensor(self.tape.unwrap(), position, new_value)
}
}

impl<'a> Mul<Variable<'a>> for f64 {
type Output = Variable<'a>;
fn mul(self, rhs: Variable<'a>) -> Self::Output {
let position = rhs.tape.unwrap().push_unary(self, rhs.index);
let new_value = self * rhs.value;
Variable::new(rhs.tape.unwrap(), position, new_value)
let new_value = (self * rhs.value).unwrap();
let position = rhs.tape.unwrap().push_unary(
convert_to_tensor(self),
rhs.index,
new_value.shape().clone(),
);
Variable::new_tensor(rhs.tape.unwrap(), position, new_value)
}
}

impl<'a> Mul<Variable<'a>> for i32 {
type Output = Variable<'a>;
fn mul(self, rhs: Variable<'a>) -> Self::Output {
let position = rhs.tape.unwrap().push_unary(self as f64, rhs.index);
let new_value = self as f64 * rhs.value;
Variable::new(rhs.tape.unwrap(), position, new_value)
}
}
// impl<'a> Mul<Variable<'a>> for i32 {
// type Output = Variable<'a>;
// fn mul(self, rhs: Variable<'a>) -> Self::Output {
// let position = rhs.tape.unwrap().push_unary(self as f64, rhs.index);
// let new_value = self as f64 * rhs.value;
// Variable::new(rhs.tape.unwrap(), position, new_value)
// }
// }
55 changes: 43 additions & 12 deletions src/tape.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,29 @@
use crate::variable::Variable;
use candle_core::{Device, Shape, Tensor};
use std::{cell::RefCell, fmt::Debug};

#[derive(Clone, Copy, Debug)]
pub fn convert_to_tensor(value: f64) -> Tensor {
Tensor::from_slice(&[value], (1, 1), &Device::Cpu).unwrap()
}

#[derive(Clone, Debug)]
pub(crate) struct Node {
pub(crate) weight: [f64; 2],
pub(crate) weight: [Tensor; 2],
pub(crate) deps: [usize; 2],

pub(crate) is_leaf: bool,
/// shape of the Variable this node symbolizes in the
pub(crate) shape: Shape,
}

impl Node {
pub fn from(weight: [f64; 2], deps: [usize; 2]) -> Self {
Self { weight, deps }
pub fn from(weight: [Tensor; 2], deps: [usize; 2], shape: Shape, is_leaf: bool) -> Self {
Self {
weight,
deps,
shape,
is_leaf,
}
}
}

Expand All @@ -31,28 +45,45 @@ impl Tape {
self.nodes.borrow().len()
}

pub fn var(&self, value: f64) -> Variable {
Variable::new(&self, self.push_leaf(), value)
pub fn var(&self, value: Tensor) -> Variable {
Variable::new_tensor(&self, self.push_leaf(value.shape().clone()), value)
}

pub fn push_leaf(&self) -> usize {
pub fn push_leaf(&self, shape: Shape) -> usize {
let mut nodes = self.nodes.borrow_mut();
let len = nodes.len();
nodes.push(Node::from([0.0, 0.0], [len, len]));
nodes.push(Node::from(
[convert_to_tensor(0.0), convert_to_tensor(0.0)],
[len, len],
shape,
true,
));
len
}

pub fn push_unary(&self, weight: f64, pos: usize) -> usize {
pub fn push_unary(&self, weight: Tensor, pos: usize, shape: Shape) -> usize {
let mut nodes = self.nodes.borrow_mut();
let len = nodes.len();
nodes.push(Node::from([weight, 0.0], [pos, len]));
nodes.push(Node::from(
[weight, convert_to_tensor(0.0)],
[pos, len],
shape,
false,
));
len
}

pub fn push_binary(&self, weight0: f64, pos0: usize, weight1: f64, pos1: usize) -> usize {
pub fn push_binary(
&self,
weight0: Tensor,
pos0: usize,
weight1: Tensor,
pos1: usize,
shape: Shape,
) -> usize {
let mut nodes = self.nodes.borrow_mut();
let len = nodes.len();
nodes.push(Node::from([weight0, weight1], [pos0, pos1]));
nodes.push(Node::from([weight0, weight1], [pos0, pos1], shape, false));
len
}
}
Loading

0 comments on commit a360c8e

Please sign in to comment.