Skip to content

Commit

Permalink
feat: improve analyzer plot functions
Browse files Browse the repository at this point in the history
  • Loading branch information
suhdonghwi committed Jun 26, 2021
1 parent 528468a commit 8247751
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 89 deletions.
26 changes: 11 additions & 15 deletions analysis/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,23 +70,23 @@ def plot_fitness_max(cases):
plt.ylabel("Fitness")


def plot_succ_gens(cases):
def plot_succ_gens(cases, fitness_threshold):
succ_gens = []
for gens in cases:
for (i, gen) in enumerate(gens):
if gen.fitness_max >= 3.9:
if gen.fitness_max >= fitness_threshold:
succ_gens.append(i)
break

print("Success : " + str(len(succ_gens)))
sns.histplot(succ_gens, kde=True)


def plot_size(cases):
def plot_size(cases, fitness_threshold):
sizes = []
for gens in cases:
for (i, gen) in enumerate(gens):
if gen.fitness_max >= 3.9:
if gen.fitness_max >= fitness_threshold:
sizes.append(gen.best_edges_count)
break

Expand All @@ -95,27 +95,23 @@ def plot_size(cases):


case1 = split_cases("./analysis/output.txt")
case2 = split_cases("./analysis/output-no.txt")
fitness_threshold = 3.95
cases = [(case1, "")]

cases = [case1, case2]
labels = ["With destructive mutation", "Without destructive mutation"]

for case in cases:
for case, _ in cases:
plot_fitness_max(case)
plt.show()


for case, label in zip(cases, labels):
plot_succ_gens(case)
for case, label in cases:
plot_succ_gens(case, fitness_threshold)

plt.title("survival rate = " + label)
plt.xlabel("Generation")
plt.show()


for case, label in zip(cases, labels):
plot_size(case)
for case, label in cases:
plot_size(case, fitness_threshold)

plt.title("survival rate = " + label)
plt.xlabel("Size")
plt.show()
72 changes: 38 additions & 34 deletions examples/sin-no-gui.rs
Original file line number Diff line number Diff line change
@@ -1,48 +1,52 @@
#![recursion_limit = "512"]
mod helper;

use petgraph::dot::{Config, Dot};

use neat::network::Network;
use neat::{innovation_record::InnovationRecord, network::feedforward::Feedforward, pool::Pool};

pub fn main() {
let args = helper::cli::get_arguments();
let params = helper::read_parameters_file("./params/sin.toml");

let mut innov_record = InnovationRecord::new(params.input_number, params.output_number);
let mut pool = Pool::<Feedforward>::new(params, args.verbosity, &mut innov_record);

loop {
let best_genome = pool.evaluate(|_, network| {
let n = 50;
let mut error_sum = 0.0;

for i in -n..=n {
let x = i as f64 / n as f64;

let output = network.activate(&[x]).unwrap()[0];
let expected = (x * std::f64::consts::PI).sin();
let err = output - expected;

error_sum += err * err;
for _ in 0..10 {
println!("<Case Start>");

let params = helper::read_parameters_file("./params/sin.toml");
let mut innov_record = InnovationRecord::new(params.input_number, params.output_number);
let mut pool = Pool::<Feedforward>::new(params, args.verbosity, &mut innov_record);

for _ in 0..500 {
pool.evaluate(|_, network| {
let n = 50;
let mut error_sum = 0.0;

for i in -n..=n {
let x = i as f64 / n as f64;

let output = network.activate(&[x]).unwrap()[0];
let expected = (x * std::f64::consts::PI).sin();
let err = output - expected;

error_sum += err * err;
}

let error_mean = error_sum / (n * 2 + 1) as f64;
network.evaluate(4.0 - error_mean);
});

/*
if best_genome.fitness().unwrap() > 3.999 {
let dot = Dot::with_attr_getters(
best_genome.graph().inner_data(),
&[Config::NodeNoLabel, Config::EdgeNoLabel],
&|_, data| format!("label = \"{:.2}\"", data.weight().get_weight()),
&|_, (index, _)| format!("label = \"{}\"", index.index()),
);
println!("{:?}", dot);
break;
}
*/

let error_mean = error_sum / (n * 2 + 1) as f64;
network.evaluate(4.0 - error_mean);
});

if best_genome.fitness().unwrap() > 3.999 {
let dot = Dot::with_attr_getters(
best_genome.graph().inner_data(),
&[Config::NodeNoLabel, Config::EdgeNoLabel],
&|_, data| format!("label = \"{:.2}\"", data.weight().get_weight()),
&|_, (index, _)| format!("label = \"{}\"", index.index()),
);
println!("{:?}", dot);
break;
pool.evolve(&mut innov_record);
}

pool.evolve(&mut innov_record);
}
}
10 changes: 5 additions & 5 deletions examples/sin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ impl MainState {
font,
);
let mut sin_points = Vec::new();
for i in -100..=100 {
sin_points.push(na::Point2::new(i as f32 / 100.0, 0.0));
for i in -50..=50 {
sin_points.push(na::Point2::new(i as f32 / 50.0, 0.0));
}

MainState {
Expand All @@ -70,19 +70,19 @@ impl event::EventHandler for MainState {
fn update(&mut self, ctx: &mut ggez::Context) -> ggez::GameResult {
self.timer += ggez::timer::delta(ctx);

if self.timer >= Duration::from_secs_f64(0.05) {
if self.timer >= Duration::from_secs_f64(0.00) {
let generation = self.pool.generation();
let mut best_network = self
.pool
.evaluate(|_, network| {
let n = 100;
let n = 50;
let mut error_sum = 0.0;

for i in -n..=n {
let x = 1.0 * i as f64 / n as f64;

let output = network.activate(&[x]).unwrap()[0];
let expected = (1.0 * x * std::f64::consts::PI).sin() * 0.5;
let expected = (2.0 * x * std::f64::consts::PI).sin() * 0.5;
let err = output - expected;

error_sum += err * err;
Expand Down
45 changes: 22 additions & 23 deletions examples/xor-no-gui.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,36 +5,35 @@ use neat::network::Network;
use neat::{innovation_record::InnovationRecord, network::feedforward::Feedforward, pool::Pool};

pub fn main() {
println!("<Case Start>");
for _ in 0..500 {
println!("<Case Start>");

let args = helper::cli::get_arguments();
let params = helper::read_parameters_file("./params/xor.toml");
let args = helper::cli::get_arguments();
let params = helper::read_parameters_file("./params/xor.toml");

let mut innov_record = InnovationRecord::new(2, 1);
let mut pool = Pool::<Feedforward>::new(params, args.verbosity, &mut innov_record);
let mut innov_record = InnovationRecord::new(2, 1);
let mut pool = Pool::<Feedforward>::new(params, args.verbosity, &mut innov_record);

let data = vec![
(vec![0.0, 0.0], 0.0),
(vec![0.0, 1.0], 1.0),
(vec![1.0, 0.0], 1.0),
(vec![1.0, 1.0], 0.0),
];
let data = vec![
(vec![0.0, 0.0], 0.0),
(vec![0.0, 1.0], 1.0),
(vec![1.0, 0.0], 1.0),
(vec![1.0, 1.0], 0.0),
];

for _ in 0..100 {
let best_genome = pool.evaluate(|_, network| {
let mut err = 0.0;
for _ in 0..200 {
pool.evaluate(|_, network| {
let mut err = 0.0;

for (inputs, expected) in &data {
let output = network.activate(inputs).unwrap()[0];
err += (output - expected) * (output - expected);
}
for (inputs, expected) in &data {
let output = network.activate(inputs).unwrap()[0];
err += (output - expected) * (output - expected);
}

network.evaluate(4.0 - err);
});
network.evaluate(4.0 - err);
});

if best_genome.fitness().unwrap() > 3.9 {
break;
pool.evolve(&mut innov_record);
}
pool.evolve(&mut innov_record);
}
}
14 changes: 7 additions & 7 deletions params/sin.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
input_number = 1
output_number = 1
population = 150
population = 300

hidden_activation = 'Tanh'
output_activation = 'Linear'
Expand All @@ -17,16 +17,16 @@ remove_node = 0.2
weight_min = -3.0
weight_max = 3.0

perturb_min = -0.5
perturb_max = 0.5
perturb_min = -0.1
perturb_max = 0.1

[speciation]
c1 = 1.0
c2 = 0.5
compatibility_threshold = 1.5
elitism = 2
compatibility_threshold = 2.0
elitism = 5

survival_rate = 0.2
survival_rate = 0.4

[reproduction]
crossover_rate = 0.75
crossover_rate = 0.0
8 changes: 4 additions & 4 deletions params/xor.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@ remove_node = 0.2
weight_min = -10.0
weight_max = 10.0

perturb_min = -2.0
perturb_max = 2.0
perturb_min = -1.0
perturb_max = 1.0

[speciation]
c1 = 1.0
c2 = 0.5
compatibility_threshold = 15.0
elitism = 5

survival_rate = 0.5
survival_rate = 0.1

[reproduction]
crossover_rate = 0.75
crossover_rate = 0.3
2 changes: 1 addition & 1 deletion src/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ impl<'a, T: Network + Debug + Clone> Pool<T> {

species_set = species_set
.into_iter()
.filter(|s| s.genome_count() > 2)
.filter(|s| s.genome_count() > 1)
.collect();
if species_set.is_empty() {
panic!("remaining species_set size is 0; maybe compatibility threshold is too small?");
Expand Down

0 comments on commit 8247751

Please sign in to comment.