Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin Cortacero committed Jan 4, 2025
1 parent 85f3362 commit 4362daa
Show file tree
Hide file tree
Showing 4 changed files with 781 additions and 225 deletions.
124 changes: 68 additions & 56 deletions src/kartezio/core/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,58 +138,59 @@ def name_of(component_class: type) -> str:
def display():
pprint(Components._registry)

def add_as(self, fundamental: type, replace: type = None):
"""
Register a component to the Components registry.
def register(fundamental: type, replace: type = None):
"""
Register a component to the Components registry.
Args:
fundamental (type): The fundamental type of the component.
replace (type): If not None, replace an existing component with the type.
Args:
fundamental (type): The fundamental type of the component.
replace (type): If not None, replace an existing component with the type.
Returns:
Callable: A decorator for registering the component.
"""
fundamental_name = fundamental.__name__

def inner(item_cls):
name = item_cls.__name__
if Components._contains(fundamental_name, name):
if not replace:
raise KeyError(
f"""Error registering {fundamental_name} called '{name}'.
Here is the list of all registered {fundamental_name} components:
\n{Components._registry[fundamental_name].keys()}.
\n > Replace it using 'replace=True' in @register, or use another name.
"""
)
if replace:
replace_name = replace.__name__
if Components._contains(fundamental_name, replace_name):
print(
f"Component '{fundamental_name}/{replace_name}' will be replaced by '{name}'"
)
Components.add(fundamental_name, replace_name, item_cls)
else:
Components.add(fundamental_name, name, item_cls)
return item_cls

Returns:
Callable: A decorator for registering the component.
"""
fundamental_name = fundamental.__name__

def inner(item_cls):
name = item_cls.__name__
if Components._contains(fundamental_name, name):
if not replace:
raise KeyError(
f"""Error registering {fundamental_name} called '{name}'.
Here is the list of all registered {fundamental_name} components:
\n{Components._registry[fundamental_name].keys()}.
\n > Replace it using 'replace=True' in @register, or use another name.
"""
)
if replace:
replace_name = replace.__name__
if Components._contains(fundamental_name, replace_name):
print(
f"Component '{fundamental_name}/{replace_name}' will be replaced by '{name}'"
)
Components.add(fundamental_name, replace_name, item_cls)
else:
Components.add(fundamental_name, name, item_cls)
return item_cls
return inner

return inner
def declare(self):
"""
Register a fundamental component to the Components registry.
Returns:
Callable: A decorator for registering the fundamental component.
"""

def component():
"""
Register a fundamental component to the Components registry.
def inner(item_cls):
Components.add_component(item_cls.__name__)
return item_cls

Returns:
Callable: A decorator for registering the fundamental component.
"""
return inner

def inner(item_cls):
Components.add_component(item_cls.__name__)
return item_cls

return inner
registry = Components()


def load_component(
Expand Down Expand Up @@ -223,7 +224,7 @@ def dump_component(component: KartezioComponent) -> Dict:
return base_dict


@component()
@registry.declare()
class Node(KartezioComponent, ABC):
"""
Abstract base class for a Node in the CGP framework.
Expand All @@ -232,7 +233,7 @@ class Node(KartezioComponent, ABC):
pass


@component()
@registry.declare()
class Preprocessing(Node, ABC):
"""
Preprocessing node, called before training loop.
Expand Down Expand Up @@ -265,7 +266,7 @@ def then(self, preprocessing: "Preprocessing"):
return self


@component()
@registry.declare()
class Primitive(Node, ABC):
"""
Primitive function called inside the CGP Graph.
Expand All @@ -286,7 +287,7 @@ def __to_dict__(self) -> Dict:
return {"name": self.name}


@component()
@registry.declare()
class Genotype(KartezioComponent):
"""
Represents the genotype for Cartesian Genetic Programming (CGP).
Expand Down Expand Up @@ -392,7 +393,7 @@ def clone(self) -> "Genotype":
return copy.deepcopy(self)


@component()
@registry.declare()
class Reducer(Node, ABC):
def batch(self, x: List):
y = []
Expand All @@ -405,7 +406,7 @@ def reduce(self, x):
pass


@component()
@registry.declare()
class Endpoint(Node, ABC):
"""
Represents the final node in a CGP graph, responsible for producing the final outputs.
Expand All @@ -425,7 +426,6 @@ def __init__(self, inputs: List[KType]):

@classmethod
def __from_dict__(cls, dict_infos: Dict) -> "Endpoint":
from kartezio.core.endpoints import Endpoint
"""
Create an Endpoint instance from a dictionary representation.
Expand All @@ -441,8 +441,14 @@ def __from_dict__(cls, dict_infos: Dict) -> "Endpoint":
**dict_infos["args"],
)

@classmethod
def from_config(cls, config):
return registry.instantiate(
cls.__name__, config["name"], **config["args"]
)


@component()
@registry.declare()
class Fitness(KartezioComponent, ABC):
def __init__(self, reduction="mean"):
super().__init__()
Expand Down Expand Up @@ -482,14 +488,15 @@ def evaluate(self, y_true, y_pred):
@classmethod
def __from_dict__(cls, dict_infos: Dict) -> "Fitness":
from kartezio.core.fitness import Fitness

return Components.instantiate(
"Fitness",
dict_infos["name"],
**dict_infos["args"],
)


@component()
@registry.declare()
class Library(KartezioComponent):
def __init__(self, rtype):
super().__init__()
Expand Down Expand Up @@ -596,7 +603,7 @@ def size(self):
return len(self._primitives)


@component()
@registry.declare()
class Mutation(KartezioComponent, ABC):
def __init__(self, adapter):
super().__init__()
Expand Down Expand Up @@ -685,15 +692,15 @@ def __to_dict__(self) -> Dict:
return {}


@component()
@registry.declare()
class Initialization(KartezioComponent, ABC):
""" """

def __init__(self):
super().__init__()


@register(Initialization)
@registry.add_as(Initialization)
class CopyGenotype(Initialization):
@classmethod
def __from_dict__(cls, dict_infos: Dict) -> "CopyGenotype":
Expand All @@ -707,7 +714,7 @@ def mutate(self, genotype):
return self.genotype.clone()


@register(Initialization)
@registry.add_as(Initialization)
class RandomInit(Initialization, Mutation):
"""
Can be used to initialize genome (genome) randomly
Expand Down Expand Up @@ -736,3 +743,8 @@ def mutate(self, genotype: Genotype):
def random(self):
genotype = self.adapter.new_genotype()
return self.mutate(genotype)


if __name__ == "__main__":
registry.display()
print("Done!")
Loading

0 comments on commit 4362daa

Please sign in to comment.