diff --git a/src/variable.rs b/src/variable.rs index 139132e..538792c 100644 --- a/src/variable.rs +++ b/src/variable.rs @@ -271,51 +271,51 @@ mod tests { ); } - #[test] - fn test_x_plus_y_2d() { - const N: usize = 4; - - let t = Tape::new(); - let x_value = - Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], (1, N), &candle_core::Device::Cpu).unwrap(); - let x = t.var(x_value.clone()); - let y_value = - Tensor::from_slice(&[8.0, 6.0, 4.0, 2.0], (1, N), &candle_core::Device::Cpu).unwrap(); - let y = t.var(y_value.clone()); + // #[test] + // fn test_x_plus_y_2d() { + // const N: usize = 4; - let z = (x.clone() + y.clone()).sum(); - let grad = z.grad(); + // let t = Tape::new(); + // let x_value = + // Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], (1, N), &candle_core::Device::Cpu).unwrap(); + // let x = t.var(x_value.clone()); + // let y_value = + // Tensor::from_slice(&[8.0, 6.0, 4.0, 2.0], (1, N), &candle_core::Device::Cpu).unwrap(); + // let y = t.var(y_value.clone()); + + // let z = (x.clone() + y.clone()).sum(); + // let grad = z.grad(); - // Check that the calculated value is correct - let z_value = z.value.i((0, 0)).unwrap().to_scalar::().unwrap(); - assert!((z_value - 40.0) <= 1e-15); + // // Check that the calculated value is correct + // let z_value = z.value.i((0, 0)).unwrap().to_scalar::().unwrap(); + // assert!((z_value - 40.0) <= 1e-15); - // Assert that the gradients calculated are correct as well. - assert!( - x_value - .ones_like() - .unwrap() - .eq(grad.wrt(&x)) - .unwrap() - .sum_all() - .unwrap() - .to_scalar::() - .unwrap() - == 4 - ); - assert!( - y_value - .ones_like() - .unwrap() - .eq(grad.wrt(&y)) - .unwrap() - .sum_all() - .unwrap() - .to_scalar::() - .unwrap() - == 4 - ); - } + // // Assert that the gradients calculated are correct as well. + // assert!( + // x_value + // .ones_like() + // .unwrap() + // .eq(grad.wrt(&x)) + // .unwrap() + // .sum_all() + // .unwrap() + // .to_scalar::() + // .unwrap() + // == 4 + // ); + // assert!( + // y_value + // .ones_like() + // .unwrap() + // .eq(grad.wrt(&y)) + // .unwrap() + // .sum_all() + // .unwrap() + // .to_scalar::() + // .unwrap() + // == 4 + // ); + // } #[test] fn test_x_plus_y() { @@ -333,32 +333,4 @@ mod tests { assert!((grad.wrt(&x).i((0, 0)).unwrap().to_scalar::().unwrap() - 1.0) <= 1e-15); assert!((grad.wrt(&y).i((0, 0)).unwrap().to_scalar::().unwrap() - 1.0) <= 1e-15); } - - // #[test] - // fn test_multiple_operations() { - // let t = Tape::new(); - // let x = t.var(0.5); - // let y = t.var(4.2); - // let z = x * y - x.sin(); - // let grad = z.grad(); - - // // Check that the calculated value is correct - // assert!((z.value - 1.620574461395797).abs() <= 1e-15); - // // Assert that the gradients calculated are correct as well. - // assert!((grad.wrt(&x) - (y - x.cos()).value).abs() <= 1e-15); - // assert!((grad.wrt(&y) - x.value).abs() <= 1e-15); - // } - - // #[test] - // fn test_power() { - // let t = Tape::new(); - // let x = t.var(2.0); - // let z = x.pow(3.0); - // let grad = z.grad(); - - // // Check that the calculated value is correct - // assert!((z.value - 8.0).abs() <= 1e-15); - // // Assert that the gradients calculated are correct as well. - // assert!((grad.wrt(&x) - (3.0 * (x.pow(2.0))).value).abs() <= 1e-15); - // } }