Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
KevinCortacero committed Jan 5, 2025
1 parent 4362daa commit 8ea1c7e
Show file tree
Hide file tree
Showing 42 changed files with 514 additions and 482 deletions.
8 changes: 2 additions & 6 deletions src/kartezio/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,7 @@
import matplotlib.pyplot as plt
import numpy as np

from kartezio.core.components import Genotype
from kartezio.enums import JSON_ELITE
from kartezio.evolution.decoder import Decoder
from kartezio.helpers import Observer
from kartezio.utils.directory import Directory

# from kartezio.utils.io import JsonSaver
from kartezio.utils.json_handler import json_write
Expand Down Expand Up @@ -164,11 +160,11 @@ def __init__(self, filename, dataset, preprocessing, fitness):
)
self.fitness = fitness.__to_dict__()

def set_decoder(self, decoder: Decoder):
def set_decoder(self, decoder):
self.decoder = decoder.__to_dict__()

def on_new_parent(self, iteration, event_content):
elite: Genotype = event_content.individuals[0].genotype
elite = event_content.individuals[0].genotype
json_data = {
"iteration": iteration,
"dataset": self.dataset,
Expand Down
4 changes: 3 additions & 1 deletion src/kartezio/cli/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
def main():
parser = argparse.ArgumentParser()
parser.add_argument("name", help="Name of your dataset", type=str)
parser.add_argument("label_name", help="Name of one entity of your set", type=str)
parser.add_argument(
"label_name", help="Name of one entity of your set", type=str
)
parser.add_argument(
"--input_type",
help="Type of your input files",
Expand Down
3 changes: 2 additions & 1 deletion src/kartezio/cli/pack.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import argparse

from kartezio.utils.io import pack_one_directory
from numena.io.drive import Directory

from kartezio.utils.io import pack_one_directory


def main():
parser = argparse.ArgumentParser()
Expand Down
179 changes: 65 additions & 114 deletions src/kartezio/core/components.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
from abc import ABC, abstractmethod
from pprint import pprint
from typing import Dict, List
Expand Down Expand Up @@ -128,69 +129,73 @@ def instantiate(group_name: str, component_name: str, *args, **kwargs):
@staticmethod
def name_of(component_class: type) -> str:
if component_class.__name__ not in Components._reverse.keys():
# print(f"Component '{component_class.__name__}'
# not properly registered, please make sure use
# '@register' over your Class definition.")
return component_class.__name__
return Components._reverse[component_class.__name__].split("/")[1]

@staticmethod
def list(group_name: str):
if group_name not in Components._registry.keys():
raise KeyError(
f"Fundamental Component '{group_name}' not found in the registry."
)
return Components._registry[group_name].keys()

@staticmethod
def display():
pprint(Components._registry)

def add_as(self, fundamental: type, replace: type = None):
"""
Register a component to the Components registry.

Args:
fundamental (type): The fundamental type of the component.
replace (type): If not None, replace an existing component with the type.
def register(fundamental: type, replace: type = None):
"""
Register a component to the Components registry.
Returns:
Callable: A decorator for registering the component.
"""
fundamental_name = fundamental.__name__

def inner(item_cls):
name = item_cls.__name__
if Components._contains(fundamental_name, name):
if not replace:
raise KeyError(
f"""Error registering {fundamental_name} called '{name}'.
Here is the list of all registered {fundamental_name} components:
\n{Components._registry[fundamental_name].keys()}.
\n > Replace it using 'replace=True' in @register, or use another name.
"""
)
if replace:
replace_name = replace.__name__
if Components._contains(fundamental_name, replace_name):
print(
f"Component '{fundamental_name}/{replace_name}' will be replaced by '{name}'"
)
Components.add(fundamental_name, replace_name, item_cls)
else:
Components.add(fundamental_name, name, item_cls)
return item_cls
Args:
fundamental (type): The fundamental type of the component.
replace (type): If not None, replace an existing component with the type.
return inner
Returns:
Callable: A decorator for registering the component.
"""
fundamental_name = fundamental.__name__

def inner(item_cls):
name = item_cls.__name__
if Components._contains(fundamental_name, name):
if not replace:
raise KeyError(
f"""Error registering {fundamental_name} called '{name}'.
Here is the list of all registered {fundamental_name} components:
\n{Components._registry[fundamental_name].keys()}.
\n > Replace it using 'replace=True' in @register, or use another name.
"""
)
if replace:
replace_name = replace.__name__
if Components._contains(fundamental_name, replace_name):
print(
f"Component '{fundamental_name}/{replace_name}' will be replaced by '{name}'"
)
Components.add(fundamental_name, replace_name, item_cls)
else:
Components.add(fundamental_name, name, item_cls)
return item_cls

def declare(self):
"""
Register a fundamental component to the Components registry.
return inner

Returns:
Callable: A decorator for registering the fundamental component.
"""

def inner(item_cls):
Components.add_component(item_cls.__name__)
return item_cls
def fundamental():
"""
Register a fundamental component to the Components registry.
return inner
Returns:
Callable: A decorator for registering the fundamental component.
"""

def inner(item_cls):
Components.add_component(item_cls.__name__)
return item_cls

registry = Components()
return inner


def load_component(
Expand Down Expand Up @@ -224,7 +229,6 @@ def dump_component(component: KartezioComponent) -> Dict:
return base_dict


@registry.declare()
class Node(KartezioComponent, ABC):
"""
Abstract base class for a Node in the CGP framework.
Expand All @@ -233,7 +237,7 @@ class Node(KartezioComponent, ABC):
pass


@registry.declare()
@fundamental()
class Preprocessing(Node, ABC):
"""
Preprocessing node, called before training loop.
Expand Down Expand Up @@ -266,7 +270,7 @@ def then(self, preprocessing: "Preprocessing"):
return self


@registry.declare()
@fundamental()
class Primitive(Node, ABC):
"""
Primitive function called inside the CGP Graph.
Expand All @@ -287,7 +291,7 @@ def __to_dict__(self) -> Dict:
return {"name": self.name}


@registry.declare()
@fundamental()
class Genotype(KartezioComponent):
"""
Represents the genotype for Cartesian Genetic Programming (CGP).
Expand Down Expand Up @@ -393,7 +397,7 @@ def clone(self) -> "Genotype":
return copy.deepcopy(self)


@registry.declare()
@fundamental()
class Reducer(Node, ABC):
def batch(self, x: List):
y = []
Expand All @@ -406,7 +410,7 @@ def reduce(self, x):
pass


@registry.declare()
@fundamental()
class Endpoint(Node, ABC):
"""
Represents the final node in a CGP graph, responsible for producing the final outputs.
Expand Down Expand Up @@ -443,12 +447,12 @@ def __from_dict__(cls, dict_infos: Dict) -> "Endpoint":

@classmethod
def from_config(cls, config):
return registry.instantiate(
return Components.instantiate(
cls.__name__, config["name"], **config["args"]
)


@registry.declare()
@fundamental()
class Fitness(KartezioComponent, ABC):
def __init__(self, reduction="mean"):
super().__init__()
Expand Down Expand Up @@ -496,7 +500,7 @@ def __from_dict__(cls, dict_infos: Dict) -> "Fitness":
)


@registry.declare()
@fundamental()
class Library(KartezioComponent):
def __init__(self, rtype):
super().__init__()
Expand Down Expand Up @@ -603,13 +607,13 @@ def size(self):
return len(self._primitives)


@registry.declare()
@fundamental()
class Mutation(KartezioComponent, ABC):
def __init__(self, adapter):
super().__init__()
self.adapter = adapter
self.parameters = MutationUniform()
self.edges_weights = MutationEdgesUniform()
self.parameters = None # MutationUniform()
self.edges_weights = None # MutationEdgesUniform()

def random_parameters(self, chromosome: int):
return np.random.randint(
Expand Down Expand Up @@ -692,59 +696,6 @@ def __to_dict__(self) -> Dict:
return {}


@registry.declare()
@fundamental()
class Initialization(KartezioComponent, ABC):
""" """

def __init__(self):
super().__init__()


@registry.add_as(Initialization)
class CopyGenotype(Initialization):
@classmethod
def __from_dict__(cls, dict_infos: Dict) -> "CopyGenotype":
pass

def __init__(self, genotype, shape, n_functions):
super().__init__(shape, n_functions)
self.genotype = genotype

def mutate(self, genotype):
return self.genotype.clone()


@registry.add_as(Initialization)
class RandomInit(Initialization, Mutation):
"""
Can be used to initialize genome (genome) randomly
"""

@classmethod
def __from_dict__(cls, dict_infos: Dict) -> "RandomInit":
pass

def __init__(self, adapter):
super(Initialization).__init__()
super(Mutation).__init__(adapter)

def mutate(self, genotype: Genotype):
# mutate genes
for chromosome in self.adapter.chromosomes_infos.keys():
for node in range(self.adapter.n_nodes):
self.mutate_function(genotype, chromosome, node)
self.mutate_edges(genotype, chromosome, node)
self.mutate_parameters(genotype, chromosome, node)
# mutate outputs
for output in range(self.adapter.n_outputs):
self.mutate_output(genotype, output)
return genotype

def random(self):
genotype = self.adapter.new_genotype()
return self.mutate(genotype)


if __name__ == "__main__":
registry.display()
print("Done!")
pass
Loading

0 comments on commit 8ea1c7e

Please sign in to comment.