chore/layer/nll: fmt, cleanup, asserts
drahnr committed Sep 12, 2022
1 parent 3da2df8 commit 290c0b7
Showing 3 changed files with 18 additions and 19 deletions.
coaster/tests/shared_memory_specs.rs — 4 changes: 2 additions & 2 deletions
@@ -2,10 +2,10 @@ use coaster as co;
 
 #[cfg(test)]
 mod shared_memory_spec {
-    use super::co::prelude::*;
-    use super::co::tensor::Error;
     #[cfg(features = "cuda")]
     use super::co::frameworks::native::flatbox::FlatBox;
+    use super::co::prelude::*;
+    use super::co::tensor::Error;
 
     #[cfg(features = "cuda")]
     fn write_to_memory<T: Copy>(mem: &mut FlatBox, data: &[T]) {
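A side note on the untouched context lines above: `cfg(features = "cuda")` (plural) is not a recognized predicate key, so these gated items are never compiled in; Cargo features are tested with the singular `feature`. A minimal sketch of the conventional form, assuming the crate defines a `cuda` feature:

```rust
// Conventional cfg gate — note the singular `feature`; with the plural
// `features` the predicate never matches and the item is compiled out.
#[cfg(feature = "cuda")]
use super::co::frameworks::native::flatbox::FlatBox;
```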
juice/src/layers/loss/negative_log_likelihood.rs — 2 changes: 1 addition & 1 deletion
@@ -166,4 +166,4 @@ impl Into<LayerType> for NegativeLogLikelihoodConfig {
     fn into(self) -> LayerType {
         LayerType::NegativeLogLikelihood(self)
     }
-}
\ No newline at end of file
+}
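The conversion itself is unchanged here; only the missing trailing newline is fixed. As an aside (not part of this commit), the standard-library guidance is to implement `From` and let the blanket `impl<T, U: From<T>> Into<U> for T` supply `Into` for free; a sketch under that assumption:

```rust
// Hypothetical alternative to the hand-written Into impl above: implementing
// From gives NegativeLogLikelihoodConfig: Into<LayerType> via the blanket impl.
impl From<NegativeLogLikelihoodConfig> for LayerType {
    fn from(config: NegativeLogLikelihoodConfig) -> Self {
        LayerType::NegativeLogLikelihood(config)
    }
}
```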
juice/tests/layer_specs.rs — 31 changes: 15 additions & 16 deletions
@@ -386,8 +386,8 @@ mod layer_spec {
         .is_err());
     }
 
-    use juice::layers::SequentialConfig;
     use juice::layers::NegativeLogLikelihoodConfig;
+    use juice::layers::SequentialConfig;
 
     #[test]
     fn nll_basic() {
@@ -401,19 +401,15 @@
         let nll_layer_cfg = NegativeLogLikelihoodConfig { num_classes: 10 };
         let nll_cfg = LayerConfig::new("nll", nll_layer_cfg);
         classifier_cfg.add_layer(nll_cfg);
-        let mut network = Layer::from_config(
-            native_backend.clone(),
-            &LayerConfig::new("foo", classifier_cfg),
-        );
-        let labels_data = (0..(BATCH_SIZE * KLASS_COUNT))
-            .into_iter()
-            .map(|x| x as f32)
-            .collect::<Vec<f32>>();
+        let mut network = Layer::from_config(native_backend.clone(), &LayerConfig::new("foo", classifier_cfg));
         let desc = [BATCH_SIZE, KLASS_COUNT];
         let desc: &[usize] = &desc[..];
         let mut input = SharedTensor::<f32>::new(&desc);
         let mem = input.write_only(native_backend.device()).unwrap();
-        let input_data = (0..(KLASS_COUNT * BATCH_SIZE)).into_iter().map(|x| x as f32 * 3.77).collect::<Vec<f32>>();
+        let input_data = (0..(KLASS_COUNT * BATCH_SIZE))
+            .into_iter()
+            .map(|x| x as f32 * 3.77)
+            .collect::<Vec<f32>>();
         let input_data = &input_data[..];
         juice::util::write_to_memory(mem, input_data);
 
@@ -435,11 +431,14 @@ mod layer_spec {
             std::sync::Arc::new(std::sync::RwLock::new(labels)),
         ];
 
-        let output = network.forward(input.as_slice());
-
-        let x = output[0].read().unwrap();
-        dbg!(&x);
-        let out = x.read(native_backend.device()).unwrap();
-        dbg!(out.as_slice::<f32>());
+        let out = network.forward(input.as_slice());
+        assert_eq!(out.len(), 1);
+        let out = &out[0];
+        let out = out.read().unwrap();
+        assert_eq!(out.desc().dims(), &vec![BATCH_SIZE, 1]);
+        let out = out.read(native_backend.device()).unwrap();
+        let out_mem = out.as_slice::<f32>();
+        assert_eq!(out_mem.len(), BATCH_SIZE);
+        assert!(out_mem[0] < 0_f32);
     }
 }
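The rewritten tail replaces the `dbg!` prints with assertions that pin down the contract: exactly one output tensor, shaped `[BATCH_SIZE, 1]`, one loss value per batch item, and a negative first entry. The sign assertion makes sense if the layer picks the input at the label's class index and negates it, treating inputs as log-probabilities; that reading is an assumption about juice's kernel, not a quote of it. Since the test feeds raw positive values (`x as f32 * 3.77`), the negated pick comes out negative. A standalone sketch:

```rust
// Hypothetical stand-alone model of per-item negative log likelihood, not
// juice's actual kernel: select the input at the label's class index, negate.
fn nll_per_item(input: &[f32], num_classes: usize, labels: &[f32]) -> Vec<f32> {
    labels
        .iter()
        .enumerate()
        .map(|(i, &label)| -input[i * num_classes + label as usize])
        .collect()
}

fn main() {
    // Two batch items, two classes, raw positive inputs as in the test above.
    let input = vec![0.0, 3.77, 7.54, 11.31];
    let labels = vec![1.0, 0.0];
    // Item 0 picks class 1 (3.77), item 1 picks class 0 (7.54); both negated.
    assert_eq!(nll_per_item(&input, 2, &labels), vec![-3.77, -7.54]);
}
```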
