Skip to content

Commit

Permalink
fix some problems that appeared after I resolved merge conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielmfern committed Aug 16, 2022
1 parent f4988cd commit c261a50
Show file tree
Hide file tree
Showing 14 changed files with 640 additions and 1,827 deletions.
636 changes: 19 additions & 617 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 1 addition & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "intricate"
version = "0.2.2"
version = "0.3.0"
edition = "2021"
license = "MIT"
authors = ["Gabriel Miranda"]
Expand All @@ -12,8 +12,6 @@ readme = "README.md"
[dependencies]
rayon = "1.5.3"
rand = "0.8.5"
wgpu = "0.13.1"
async-trait = "0.1.56"
savefile-derive="0.10"
savefile="0.10"
opencl3="0.8.1"
Expand Down
51 changes: 1 addition & 50 deletions examples/xor/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,55 +50,6 @@ fn main() -> () {
epochs: 5000,
},
)
.await
.unwrap();

// for saving Intricate uses the 'savefile' crate
// that simply needs to call the 'save_file' function to the path you want
// for the layers in the model and then load the layers and instiate the model again
// the reason we do this is because the model can't really be easily Sized by the compiler
// because the model can have any type of layer
// just call the function bellow
xor_model.layers[0]
.save("xor-model-first-dense.bin", 0)
.unwrap();
xor_model.layers[2]
.save("xor-model-second-dense.bin", 0)
.unwrap();

// as for loading we can just call the 'load_file' function
// on each of the layers like this:
let mut first_dense: Box<DenseF32> = Box::new(DenseF32::dummy());
first_dense.load("xor-model-first-dense.bin", 0).unwrap();
let mut second_dense: Box<DenseF32> = Box::new(DenseF32::dummy());
second_dense.load("xor-model-second-dense.bin", 0).unwrap();

let mut new_layers: Vec<Box<dyn Layer<f32>>> = Vec::new();
new_layers.push(first_dense);
new_layers.push(Box::new(TanHF32::new()));
new_layers.push(second_dense);
new_layers.push(Box::new(TanHF32::new()));

let mut loaded_xor_model = ModelF32::new(new_layers);

let loaded_model_prediction = loaded_xor_model.predict(&training_inputs, &None, &None).await;
let model_prediction = xor_model.predict(&training_inputs, &None, &None).await;

assert_eq!(loaded_model_prediction, model_prediction);
}

fn main() {
// just wait for the everything to run before stopping
pollster::block_on(run());
}
=======
&mut TrainingOptions {
learning_rate: 0.1,
loss_algorithm: MeanSquared::new(), // The Mean Squared loss function
should_print_information: true, // Should be verbose
epochs: 5000,
},
)
.unwrap();

// for saving Intricate uses the 'savefile' crate
Expand All @@ -119,4 +70,4 @@ fn main() {
let loaded_model_prediction = loaded_xor_model.get_last_prediction().unwrap();

assert_eq!(loaded_model_prediction, model_prediction);
}
}
Loading

0 comments on commit c261a50

Please sign in to comment.