From d412fdcee65174eedca61ea2f19463d9c0f7ad97 Mon Sep 17 00:00:00 2001
From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com>
Date: Fri, 26 Jan 2024 08:28:32 +0100
Subject: [PATCH] Started playground for machine learning generated initial
 guesses for (#394)

SDC
---
 pySDC/playgrounds/ML_initial_guess/README.md  |  19 ++
 pySDC/playgrounds/ML_initial_guess/heat.py    | 262 ++++++++++++++++++
 pySDC/playgrounds/ML_initial_guess/ml_heat.py | 130 +++++++++
 pySDC/playgrounds/ML_initial_guess/sweeper.py |  24 ++
 pySDC/playgrounds/ML_initial_guess/tensor.py  | 131 +++++++++
 5 files changed, 566 insertions(+)
 create mode 100644 pySDC/playgrounds/ML_initial_guess/README.md
 create mode 100644 pySDC/playgrounds/ML_initial_guess/heat.py
 create mode 100644 pySDC/playgrounds/ML_initial_guess/ml_heat.py
 create mode 100644 pySDC/playgrounds/ML_initial_guess/sweeper.py
 create mode 100644 pySDC/playgrounds/ML_initial_guess/tensor.py

diff --git a/pySDC/playgrounds/ML_initial_guess/README.md b/pySDC/playgrounds/ML_initial_guess/README.md
new file mode 100644
index 0000000000..617578512a
--- /dev/null
+++ b/pySDC/playgrounds/ML_initial_guess/README.md
@@ -0,0 +1,19 @@
+Machine learning initial guesses for SDC
+----------------------------------------
+
+Most linear solves in SDC are performed in the first iteration. Afterwards, SDC providing good initial guesses is actually one of its strengths.
+To get a better initial guess for "free", and to stay hip, we want to do this with machine learning.
+
+This playground is very much work in progress!
+The first thing we did was to build a simple datatype for PyTorch that we can use in pySDC. Keep in mind that it is very inefficient and I don't think it works with MPI yet. But it's good enough for counting iterations. Once we have a proof of concept, we should refine this.
+Then, we setup a simple heat equation with this datatype in `heat.py`.
+The crucial new function is `ML_predict`, which loads an already trained model and evaluates it.
+This, in turn, is called during `predict` in the sweeper. (See `sweeper.py`)
+But we need to train the model, of course. This is done in `ml_heat.py`.
+
+How to move on with this project:
+=================================
+The first thing you might want to do is to fix the neural network that solves the heat equation. Our first try was too simplistic.
+The next thing would be to not predict the solution at a single node, but at all collocation nodes simultaneously. Maybe, actually start with this.
+If you get a proof of concept, you can clean up the datatype, such that it is even fast.
+You can do a "physics-informed" learning process of predicting the entire collocation solution by means of the residual. This is very generic, actually.
diff --git a/pySDC/playgrounds/ML_initial_guess/heat.py b/pySDC/playgrounds/ML_initial_guess/heat.py
new file mode 100644
index 0000000000..a6661b922a
--- /dev/null
+++ b/pySDC/playgrounds/ML_initial_guess/heat.py
@@ -0,0 +1,262 @@
+import numpy as np
+import scipy.sparse as sp
+from scipy.sparse.linalg import gmres, spsolve, cg
+import torch
+from pySDC.core.Errors import ProblemError
+from pySDC.core.Problem import ptype, WorkCounter
+from pySDC.helpers import problem_helper
+from pySDC.implementations.datatype_classes.mesh import mesh
+from pySDC.playgrounds.ML_initial_guess.tensor import Tensor
+from pySDC.playgrounds.ML_initial_guess.sweeper import GenericImplicitML_IG
+from pySDC.tutorial.step_1.A_spatial_problem_setup import run_accuracy_check
+from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
+from pySDC.playgrounds.ML_initial_guess.ml_heat import HeatEquationModel
+
+
+class Heat1DFDTensor(ptype):
+    """
+    Very simple 1-dimensional finite differences implementation of a heat equation using the pySDC-PyTorch interface.
+    Still includes some mess.
+    """
+
+    dtype_u = Tensor
+    dtype_f = Tensor
+
+    def __init__(
+        self,
+        nvars=256,
+        nu=1.0,
+        freq=4,
+        stencil_type='center',
+        order=2,
+        lintol=1e-12,
+        liniter=10000,
+        solver_type='direct',
+        bc='periodic',
+        bcParams=None,
+    ):
+        # make sure parameters have the correct types
+        if not type(nvars) in [int, tuple]:
+            raise ProblemError('nvars should be either tuple or int')
+        if not type(freq) in [int, tuple]:
+            raise ProblemError('freq should be either tuple or int')
+
+        ndim = 1
+
+        # eventually extend freq to other dimension
+        if type(freq) is int:
+            freq = (freq,) * ndim
+        if len(freq) != ndim:
+            raise ProblemError(f'len(freq)={len(freq)}, different to ndim={ndim}')
+
+        # check values for freq and nvars
+        for f in freq:
+            if ndim == 1 and f == -1:
+                # use Gaussian initial solution in 1D
+                bc = 'periodic'
+                break
+            if f % 2 != 0 and bc == 'periodic':
+                raise ProblemError('need even number of frequencies due to periodic BCs')
+
+        # invoke super init, passing number of dofs
+        super().__init__(init=(torch.empty(size=(nvars,), dtype=torch.double), None, np.dtype('float64')))
+
+        dx, xvalues = problem_helper.get_1d_grid(size=nvars, bc=bc, left_boundary=0.0, right_boundary=1.0)
+
+        self.A_, _ = problem_helper.get_finite_difference_matrix(
+            derivative=2,
+            order=order,
+            stencil_type=stencil_type,
+            dx=dx,
+            size=nvars,
+            dim=ndim,
+            bc=bc,
+        )
+        self.A_ *= nu
+        self.A = torch.tensor(self.A_.todense())
+
+        self.xvalues = torch.tensor(xvalues, dtype=torch.double)
+        self.Id = torch.tensor((sp.eye(nvars, format='csc')).todense())
+
+        # store attribute and register them as parameters
+        self._makeAttributeAndRegister('nvars', 'stencil_type', 'order', 'bc', 'nu', localVars=locals(), readOnly=True)
+        self._makeAttributeAndRegister('freq', 'lintol', 'liniter', 'solver_type', localVars=locals())
+
+        if self.solver_type != 'direct':
+            self.work_counters[self.solver_type] = WorkCounter()
+
+    @property
+    def ndim(self):
+        """Number of dimensions of the spatial problem"""
+        return 1
+
+    @property
+    def dx(self):
+        """Size of the mesh (in all dimensions)"""
+        return self.xvalues[1] - self.xvalues[0]
+
+    @property
+    def grids(self):
+        """ND grids associated to the problem"""
+        x = self.xvalues
+        if self.ndim == 1:
+            return x
+        if self.ndim == 2:
+            return x[None, :], x[:, None]
+        if self.ndim == 3:
+            return x[None, :, None], x[:, None, None], x[None, None, :]
+
+    def eval_f(self, u, t):
+        """
+        Routine to evaluate the right-hand side of the problem.
+
+        Parameters
+        ----------
+        u : dtype_u
+            Current values.
+        t : float
+            Current time.
+
+        Returns
+        -------
+        f : dtype_f
+            Values of the right-hand side of the problem.
+        """
+        f = self.f_init
+        f[:] = torch.matmul(self.A, u)
+        return f
+
+    def ML_predict(self, u0, t0, dt):
+        """
+        Predict the solution at t0+dt given initial conditions u0
+        """
+        # read in model
+        model = HeatEquationModel(self)
+        model.load_state_dict(torch.load('heat_equation_model.pth'))
+        model.eval()
+
+        # evaluate model
+        predicted_state = model(u0, t0, dt)
+        sol = self.u_init
+        sol[:] = predicted_state.double()[:]
+        return sol
+
+    def solve_system(self, rhs, factor, u0, t):
+        r"""
+        Simple linear solver for :math:`(I-factor\cdot A)\vec{u}=\vec{rhs}`.
+
+        Parameters
+        ----------
+        rhs : dtype_f
+            Right-hand side for the linear system.
+        factor : float
+            Abbrev. for the local stepsize (or any other factor required).
+        u0 : dtype_u
+            Initial guess for the iterative solver.
+        t : float
+            Current time (e.g. for time-dependent BCs).
+
+        Returns
+        -------
+        sol : dtype_u
+            The solution of the linear solver.
+        """
+        solver_type, Id, A, nvars, sol = (
+            self.solver_type,
+            self.Id,
+            self.A,
+            self.nvars,
+            self.u_init,
+        )
+
+        if solver_type == 'direct':
+            sol[:] = torch.linalg.solve(Id - factor * A, rhs.flatten()).reshape(nvars)
+        # TODO: implement torch equivalent of cg
+        # elif solver_type == 'CG':
+        #     sol[:] = cg(
+        #         Id - factor * A,
+        #         rhs.flatten(),
+        #         x0=u0.flatten(),
+        #         tol=lintol,
+        #         maxiter=liniter,
+        #         atol=0,
+        #         callback=self.work_counters[solver_type],
+        #     )[0].reshape(nvars)
+        else:
+            raise ValueError(f'solver type "{solver_type}" not known!')
+
+        return sol
+
+    def u_exact(self, t, **kwargs):
+        r"""
+        Routine to compute the exact solution at time :math:`t`.
+
+        Parameters
+        ----------
+        t : float
+            Time of the exact solution.
+
+        Returns
+        -------
+        sol : dtype_u
+            The exact solution.
+        """
+        if 'u_init' in kwargs.keys() or 't_init' in kwargs.keys():
+            self.logger.warning(
+                f'{type(self).__name__} uses an analytic exact solution from t=0. If you try to compute the local error, you will get the global error instead!'
+            )
+
+        ndim, freq, nu, dx, sol = self.ndim, self.freq, self.nu, self.dx, self.u_init
+
+        if ndim == 1:
+            x = self.grids
+            rho = (2.0 - 2.0 * torch.cos(np.pi * freq[0] * dx)) / dx**2
+            if freq[0] > 0:
+                sol[:] = torch.sin(np.pi * freq[0] * x) * torch.exp(-t * nu * rho)
+        else:
+            raise NotImplementedError
+
+        return sol
+
+
+def main():
+    """
+    A simple test program to setup a full step instance
+    """
+    dt = 1e-2
+
+    level_params = dict()
+    level_params['restol'] = 1e-10
+    level_params['dt'] = dt
+
+    sweeper_params = dict()
+    sweeper_params['quad_type'] = 'RADAU-RIGHT'
+    sweeper_params['num_nodes'] = 3
+    sweeper_params['QI'] = 'LU'
+    sweeper_params['initial_guess'] = 'NN'
+
+    problem_params = dict()
+
+    step_params = dict()
+    step_params['maxiter'] = 20
+
+    description = dict()
+    description['problem_class'] = Heat1DFDTensor
+    description['problem_params'] = problem_params
+    description['sweeper_class'] = GenericImplicitML_IG
+    description['sweeper_params'] = sweeper_params
+    description['level_params'] = level_params
+    description['step_params'] = step_params
+
+    controller = controller_nonMPI(num_procs=1, controller_params={'logger_level': 20}, description=description)
+
+    P = controller.MS[0].levels[0].prob
+
+    uinit = P.u_exact(0)
+    uend, _ = controller.run(u0=uinit, t0=0, Tend=dt)
+    u_exact = P.u_exact(dt)
+    print("error ", torch.abs(u_exact - uend).max())
+
+
+if __name__ == "__main__":
+    main()
diff --git a/pySDC/playgrounds/ML_initial_guess/ml_heat.py b/pySDC/playgrounds/ML_initial_guess/ml_heat.py
new file mode 100644
index 0000000000..c3286869f0
--- /dev/null
+++ b/pySDC/playgrounds/ML_initial_guess/ml_heat.py
@@ -0,0 +1,130 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+class Train_pySDC:
+    """
+    Interface between PyTorch and pySDC for training models.
+
+    Attributes:
+     - problem: An instantiated problem from pySDC that allows evaluating the exact solution.
+                This should have the same parameters as the problem you run in pySDC later.
+     - model: A PyTorch model with some neural network to train, specific to the problem
+    """
+
+    def __init__(self, problem, model, use_exact=True):
+        self.problem = problem
+        self.model = model
+        self.use_exact = use_exact  # use exact solution in problem class or backward Euler solution
+
+        self.model.train(True)
+
+    def generate_initial_condition(self, t):
+        return self.problem.u_exact(t)
+
+    def generate_target_condition(self, initial_condition, t, dt):
+        if self.use_exact:
+            return self.problem.u_exact(t + dt)
+        else:
+            return self.problem.solve_system(initial_condition, dt, initial_condition, t)
+
+    def train_model(self, initial_condition=None, t=None, dt=None, num_epochs=1000, lr=0.001):
+        model = self.model
+
+        criterion = nn.MSELoss()
+        optimizer = optim.Adam(model.parameters(), lr=lr)
+
+        # setup initial and target conditions
+        t = torch.rand(1) if t is None else t
+        dt = torch.rand(1) if dt is None else dt
+        initial_condition = self.generate_initial_condition(t) if initial_condition is None else initial_condition
+        target_condition = self.generate_target_condition(initial_condition, t, dt)
+
+        # do the training
+        for epoch in range(num_epochs):
+            predicted_state = model(initial_condition, t, dt)
+            loss = criterion(predicted_state.float(), target_condition.float())
+
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            if (epoch + 1) % 100 == 0 or True:
+                print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
+
+    def plot(self, initial_condition=None, t=None, dt=None):
+        t = torch.rand(1) if t is None else t
+        dt = torch.rand(1) if dt is None else dt
+        initial_condition = self.generate_initial_condition(t) if initial_condition is None else initial_condition
+        target = self.generate_target_condition(initial_condition, t, dt)
+        model_prediction = self.model(initial_condition, t, dt)
+
+        fig, ax = plt.subplots()
+        ax.plot(self.problem.xvalues, initial_condition, label='ic')
+        ax.plot(self.problem.xvalues, target, label='target')
+        ax.plot(self.problem.xvalues, model_prediction.detach().numpy(), label='model')
+        ax.set_title(f't={t:.2e}, dt={dt:.2e}')
+        ax.legend()
+
+
+class HeatEquationModel(nn.Module):
+    """
+    Very simple model to learn the heat equation. Beware! It's too simple.
+    Some machine learning expert please fix this!
+    """
+
+    def __init__(self, problem, hidden_size=64):
+        self.input_size = problem.nvars * 3
+        self.output_size = problem.nvars
+
+        super().__init__()
+
+        self.fc1 = nn.Linear(self.input_size, hidden_size)
+        self.relu = nn.ReLU()
+        self.fc2 = nn.Linear(hidden_size, self.output_size)
+
+        # Initialize weights (example)
+        nn.init.xavier_uniform_(self.fc1.weight)
+        nn.init.xavier_uniform_(self.fc2.weight)
+
+    def forward(self, x, t, dt):
+        # prepare individual tensors
+        x = x.float()
+        _t = torch.ones_like(x) * t
+        _dt = torch.ones_like(x) * dt
+
+        # Concatenate t and dt with the input x
+        _x = torch.cat((x, _t, _dt), dim=0)
+
+        _x = self.fc1(_x)
+        _x = self.relu(_x)
+        _x = self.fc2(_x)
+        return _x
+
+
+def train_at_collocation_nodes():
+    """
+    For the first proof of concept, we want to train the model specifically to the collocation nodes we use in SDC.
+    If successful, the initial guess would already be the exact solution and we would need no SDC iterations.
+    Alas, this neural network is too simple... We need **you** to fix it!
+    """
+    collocation_nodes = np.array([0.15505102572168285, 1, 0.6449489742783183]) * 1e-2
+
+    from pySDC.playgrounds.ML_initial_guess.heat import Heat1DFDTensor
+
+    prob = Heat1DFDTensor()
+    model = HeatEquationModel(prob)
+    trainer = Train_pySDC(prob, model, use_exact=True)
+    for dt in collocation_nodes:
+        trainer.train_model(num_epochs=50, t=0, dt=dt)
+    for dt in collocation_nodes:
+        trainer.plot(t=0, dt=dt)
+    torch.save(model.state_dict(), 'heat_equation_model.pth')
+    plt.show()
+
+
+if __name__ == '__main__':
+    train_at_collocation_nodes()
diff --git a/pySDC/playgrounds/ML_initial_guess/sweeper.py b/pySDC/playgrounds/ML_initial_guess/sweeper.py
new file mode 100644
index 0000000000..df30979f76
--- /dev/null
+++ b/pySDC/playgrounds/ML_initial_guess/sweeper.py
@@ -0,0 +1,24 @@
+from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit
+
+
+class GenericImplicitML_IG(generic_implicit):
+    def predict(self):
+        """
+        Initialise node with  machine learning initial guess
+        """
+        if self.params.initial_guess != 'NN':
+            return super().predict()
+
+        L = self.level
+        P = L.prob
+
+        # evaluate RHS at left point
+        L.f[0] = P.eval_f(L.u[0], L.time)
+
+        for m in range(1, self.coll.num_nodes + 1):
+            L.u[m] = P.ML_predict(L.u[0], L.time, L.dt * self.coll.nodes[m - 1])
+            L.f[m] = P.eval_f(L.u[m], L.time + L.dt * self.coll.nodes[m - 1])
+
+        # indicate that this level is now ready for sweeps
+        L.status.unlocked = True
+        L.status.updated = True
diff --git a/pySDC/playgrounds/ML_initial_guess/tensor.py b/pySDC/playgrounds/ML_initial_guess/tensor.py
new file mode 100644
index 0000000000..c28c321213
--- /dev/null
+++ b/pySDC/playgrounds/ML_initial_guess/tensor.py
@@ -0,0 +1,131 @@
+import numpy as np
+import torch
+
+from pySDC.core.Errors import DataError
+
+try:
+    # TODO : mpi4py cannot be imported before dolfin when using fenics mesh
+    # see https://github.com/Parallel-in-Time/pySDC/pull/285#discussion_r1145850590
+    # This should be dealt with at some point
+    from mpi4py import MPI
+except ImportError:
+    MPI = None
+
+
+class Tensor(torch.Tensor):
+    """
+    Wrapper for PyTorch tensor.
+    Be aware that this is totally WIP! Should be fine to count iterations, but desperately needs cleaning up if this project goes much further!
+
+    TODO: Have to update `torch/multiprocessing/reductions.py` in order to share this datatype across processes.
+
+    Attributes:
+        _comm: MPI communicator or None
+    """
+
+    @staticmethod
+    def __new__(cls, init, val=0.0, *args, **kwargs):
+        """
+        Instantiates new datatype. This ensures that even when manipulating data, the result is still a mesh.
+
+        Args:
+            init: either another mesh or a tuple containing the dimensions, the communicator and the dtype
+            val: value to initialize
+
+        Returns:
+            obj of type mesh
+
+        """
+        if isinstance(init, Tensor):
+            obj = super().__new__(cls, init)
+            obj[:] = init[:]
+            obj._comm = init._comm
+        elif (
+            isinstance(init, tuple)
+            # and (init[1] is None or isinstance(init[1], MPI.Intracomm))
+            # and isinstance(init[2], np.dtype)
+        ):
+            obj = super().__new__(cls, init[0].clone())
+            obj.fill_(val)
+            obj._comm = init[1]
+        else:
+            raise NotImplementedError(type(init))
+        return obj
+
+    @property
+    def comm(self):
+        """
+        Getter for the communicator
+        """
+        return self._comm
+
+    def __array_finalize__(self, obj):
+        """
+        Finalizing the datatype. Without this, new datatypes do not 'inherit' the communicator.
+        """
+        if obj is None:
+            return
+        self._comm = getattr(obj, '_comm', None)
+
+    def __abs__(self):
+        """
+        Overloading the abs operator
+
+        Returns:
+            float: absolute maximum of all mesh values
+        """
+        # take absolute values of the mesh values
+        local_absval = float(torch.amax(torch.abs(self)))
+
+        if self.comm is not None:
+            if self.comm.Get_size() > 1:
+                global_absval = 0.0
+                global_absval = max(self.comm.allreduce(sendobj=local_absval, op=MPI.MAX), global_absval)
+            else:
+                global_absval = local_absval
+        else:
+            global_absval = local_absval
+
+        return float(global_absval)
+
+    def isend(self, dest=None, tag=None, comm=None):
+        """
+        Routine for sending data forward in time (non-blocking)
+
+        Args:
+            dest (int): target rank
+            tag (int): communication tag
+            comm: communicator
+
+        Returns:
+            request handle
+        """
+        return comm.Issend(self[:], dest=dest, tag=tag)
+
+    def irecv(self, source=None, tag=None, comm=None):
+        """
+        Routine for receiving in time
+
+        Args:
+            source (int): source rank
+            tag (int): communication tag
+            comm: communicator
+
+        Returns:
+            None
+        """
+        return comm.Irecv(self[:], source=source, tag=tag)
+
+    def bcast(self, root=None, comm=None):
+        """
+        Routine for broadcasting values
+
+        Args:
+            root (int): process with value to broadcast
+            comm: communicator
+
+        Returns:
+            broadcasted values
+        """
+        comm.Bcast(self[:], root=root)
+        return self