Skip to content

Commit

Permalink
improve the way the workgroup_sizes are structured in the shaders
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielmfern committed Jul 30, 2022
1 parent 45ba77d commit 2c22fb3
Show file tree
Hide file tree
Showing 6 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "intricate"
version = "0.1.3"
version = "0.1.4"
edition = "2021"
license = "MIT"
keywords = ["neural-networks", "machine-learning", "backpropagation"]
Expand Down
4 changes: 2 additions & 2 deletions src/gpu/shaders/apply_gradients_to_dense_weights.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ fn compute_sample_weight_gradient(sample_index: u32, input_index: u32, output_in
}

@compute
@workgroup_size(255)
@workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
var input_index: u32 = global_id.x;
var output_index: u32 = global_id.y;
Expand All @@ -58,4 +58,4 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
var old_weight: f32 = flattened_layer_weights[flattened_weight_index];

flattened_layer_weights[flattened_weight_index] = old_weight + weight_gradient;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ fn compute_input_to_error_derivative(sample_index: u32, input_index: u32) -> f32
}

@compute
@workgroup_size(255)
@workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
var sample_index: u32 = global_id.x;
var input_index: u32 = global_id.y;
Expand Down
2 changes: 1 addition & 1 deletion src/gpu/shaders/propagate_through_weights_and_biases.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ fn calculate_output_sample_for_all_inputs(sample_index: u32, output_index: u32)
}

@compute
@workgroup_size(255)
@workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
var sample_index = global_id.x;
var output_index = global_id.y;
Expand Down
2 changes: 1 addition & 1 deletion src/tests/gpu_xor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ async fn should_decerase_error_test() {
&training_output_samples,
&TrainingOptionsF32 {
loss_algorithm: Box::new(MeanSquared),
learning_rate: 0.3,
learning_rate: 0.1,
should_print_information: false,
instantiate_gpu: true,
epochs: 0,
Expand Down

0 comments on commit 2c22fb3

Please sign in to comment.