From 3b9fc9ea94b27ddeb5af1e005ef93a1bd253366c Mon Sep 17 00:00:00 2001 From: Xiaocheng Liao Date: Fri, 2 Feb 2024 20:45:28 +1300 Subject: [PATCH] add PrimitiveTree.expression and its test --- deap/gp.py | 65 +++++++++++++++++++++++++++++++-- tests/test_gp.py | 95 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 157 insertions(+), 3 deletions(-) create mode 100644 tests/test_gp.py diff --git a/deap/gp.py b/deap/gp.py index ccc47b6b..a0cfc2a4 100644 --- a/deap/gp.py +++ b/deap/gp.py @@ -28,6 +28,7 @@ import types import warnings from inspect import isclass +from typing import Dict, Callable from collections import defaultdict, deque from functools import partial, wraps @@ -185,6 +186,64 @@ def searchSubtree(self, begin): end += 1 return slice(begin, end) + def expression(self, sym_mapping: Dict[str, Callable] = None, **kwargs) -> str: + """ + Return the simplified mathematical expression according to the defined mappings + + Author: Rabbyt + + Parameters: + - sym_mapping (python dict): A dictionary that maps strings to functions. + + Returns: + - str: The simplified mathematical expression. + + Example: + >>> tree = gp.PrimitiveTree.from_string("min(sqrt(pow(x0, 2)), min(x1, add(x0, x0)))", pset=pset) + >>> func_sym_mapping = { + 'add': operator.add, + 'pow': operator.pow, + 'sqrt': sp.sqrt, + 'min': sp.Min + } + >>> terminal_sym_mapping = { + 'x0': sp.Symbol('x0', negative=True), + 'x1': sp.Symbol('x0', positive=True) + } + >>> tree.expression(func_sym_mapping, **terminal_sym_mapping) + 2*x0 + """ + try: + import sympy + except ImportError: + raise ImportError("Sympy needs to be installed " + f"before calling {PrimitiveTree.__name__}.expression.") + if sym_mapping is None: + sym_mapping = {} + string = "" + stack = [] + for node in self: + stack.append((node, [])) + while len(stack[-1][1]) == stack[-1][0].arity: + prim, args = stack.pop() + if isinstance(prim, Primitive): + sym_func = sym_mapping.get(prim.name) + if sym_func is None: + string = sympy.Function(prim.name)(*args) + else: + string = sym_func(*args) + else: + if prim.conv_fct is str: + string = kwargs.get(prim.value) + if string is None: + string = sympy.Symbol(prim.value) + else: + string = prim.value + if len(stack) == 0: + break + stack[-1][1].append(string) + return str(string) + class Primitive(object): """Class that encapsulates a primitive and when called with arguments it @@ -363,9 +422,9 @@ def addPrimitive(self, primitive, in_types, ret_type, name=None): assert name not in self.context or \ self.context[name] is primitive, \ - "Primitives are required to have a unique name. " \ - "Consider using the argument 'name' to rename your " \ - "second '%s' primitive." % (name,) + "Primitives are required to have a unique name. " \ + "Consider using the argument 'name' to rename your " \ + "second '%s' primitive." % (name,) self._add(prim) self.context[prim.name] = primitive diff --git a/tests/test_gp.py b/tests/test_gp.py new file mode 100644 index 00000000..7c28e94f --- /dev/null +++ b/tests/test_gp.py @@ -0,0 +1,95 @@ +# This file is part of DEAP. +# +# DEAP is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as +# published by the Free Software Foundation, either version 3 of +# the License, or (at your option) any later version. +# +# DEAP is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with DEAP. If not, see . + +import unittest +import re +import operator +from deap import gp + +try: + import sympy as sp +except ImportError: + sp = None + + +class TestGP(unittest.TestCase): + def setUp(self): + pass + + @unittest.skipUnless(sp, 'Sympy is not installed') + def test_gp_expression_func(self): + pset = gp.PrimitiveSet("MAIN", arity=5, prefix='x') + pset.addPrimitive(operator.add, 2) + pset.addPrimitive(operator.sub, 2) + pset.addPrimitive(operator.mul, 2) + pset.addPrimitive(operator.truediv, 2) + pset.addPrimitive(operator.neg, 1) + + tree = gp.PrimitiveTree.from_string("add(neg(x0), mul(x1, truediv(x2, sub(x3, x4))))", pset=pset) + func_sym_mapping = { + 'add': operator.add, + 'sub': operator.sub, + 'mul': operator.mul, + 'neg': operator.neg + } + expr_string = tree.expression(func_sym_mapping) + expr_string = re.sub(r'\s', '', expr_string) + self.assertEqual(expr_string, '-x0+x1*truediv(x2,x3-x4)') + + func_sym_mapping['truediv'] = operator.truediv + + expr_string = tree.expression(func_sym_mapping) + expr_string = re.sub(r'\s', '', expr_string) + self.assertEqual(expr_string, '-x0+x1*x2/(x3-x4)') + + @unittest.skipUnless(sp, 'Sympy is not installed') + def test_gp_expression_symbol(self): + pset = gp.PrimitiveSet("MAIN", arity=2, prefix='x') + pset.addPrimitive(operator.add, 2) + pset.addPrimitive(operator.pow, 2) + pset.addPrimitive(sp.sqrt, 1) + pset.addPrimitive(min, 2) + + tree = gp.PrimitiveTree.from_string("min(sqrt(pow(x0, 2)), min(x1, add(x0, x0)))", pset=pset) + func_sym_mapping = { + 'add': operator.add, + 'pow': operator.pow, + 'sqrt': sp.sqrt, + 'min': sp.Min + } + + terminal_sym_mapping = { + 'x0': sp.Symbol('x0', real=True) + } + expr_string = tree.expression(func_sym_mapping, **terminal_sym_mapping) + expr_string = re.sub(r'\s', '', expr_string) + self.assertEqual(expr_string, 'Min(2*x0,x1,Abs(x0))') + + terminal_sym_mapping = { + 'x0': sp.Symbol('x0', positive=True) + } + expr_string = tree.expression(func_sym_mapping, **terminal_sym_mapping) + expr_string = re.sub(r'\s', '', expr_string) + self.assertEqual(expr_string, 'Min(x0,x1)') + + terminal_sym_mapping = { + 'x0': sp.Symbol('x0', negative=True), + 'x1': sp.Symbol('x0', positive=True) + } + expr_string = tree.expression(func_sym_mapping, **terminal_sym_mapping) + expr_string = re.sub(r'\s', '', expr_string) + self.assertEqual(expr_string, '2*x0') + +