diff --git a/pySDC/playgrounds/ML_initial_guess/README.md b/pySDC/playgrounds/ML_initial_guess/README.md new file mode 100644 index 0000000000..617578512a --- /dev/null +++ b/pySDC/playgrounds/ML_initial_guess/README.md @@ -0,0 +1,19 @@ +Machine learning initial guesses for SDC +---------------------------------------- + +Most linear solves in SDC are performed in the first iteration. Afterwards, SDC provides good initial guesses by itself, which is actually one of its strengths. +To get a better initial guess for "free", and to stay hip, we want to do this with machine learning. + +This playground is very much a work in progress! +The first thing we did was to build a simple datatype for PyTorch that we can use in pySDC. Keep in mind that it is very inefficient and I don't think it works with MPI yet. But it's good enough for counting iterations. Once we have a proof of concept, we should refine this. +Then, we set up a simple heat equation with this datatype in `heat.py`. +The crucial new function is `ML_predict`, which loads an already trained model and evaluates it. +This, in turn, is called during `predict` in the sweeper. (See `sweeper.py`.) +But we need to train the model, of course. This is done in `ml_heat.py`. + +How to move on with this project: +================================= +The first thing you might want to do is to fix the neural network that solves the heat equation. Our first try was too simplistic. +The next thing would be to not predict the solution at a single node, but at all collocation nodes simultaneously. Maybe actually start with this. +If you get a proof of concept, you can clean up the datatype so that it is also fast. +You can do a "physics-informed" learning process of predicting the entire collocation solution by means of the residual. This is very generic, actually. 
class Heat1DFDTensor(ptype):
    """
    One-dimensional finite-difference heat equation using the pySDC-PyTorch interface.

    The discrete operator ``nu * d^2/dx^2`` is assembled once on a grid over ``[0, 1]`` and stored as a dense
    ``torch`` matrix. Solutions and right-hand sides use the ``Tensor`` datatype. This is playground code and
    still work in progress.

    Attributes
    ----------
    A_ : sparse matrix
        Finite-difference matrix returned by ``problem_helper``, scaled by ``nu``.
    A : torch.Tensor
        Dense copy of ``A_``.
    Id : torch.Tensor
        Dense identity matrix of size ``nvars``.
    xvalues : torch.Tensor
        Grid points in double precision.
    """

    dtype_u = Tensor
    dtype_f = Tensor

    def __init__(
        self,
        nvars=256,
        nu=1.0,
        freq=4,
        stencil_type='center',
        order=2,
        lintol=1e-12,
        liniter=10000,
        solver_type='direct',
        bc='periodic',
        bcParams=None,
    ):
        """
        Set up grid, operators and registered parameters.

        Parameters
        ----------
        nvars : int
            Number of degrees of freedom.
        nu : float
            Diffusion coefficient; scales the finite-difference matrix.
        freq : int or tuple of int
            Spatial frequency of the initial condition. ``-1`` forces periodic boundary conditions.
        stencil_type : str
            Forwarded to ``problem_helper.get_finite_difference_matrix``.
        order : int
            Forwarded to ``problem_helper.get_finite_difference_matrix``.
        lintol : float
            Stored as a parameter; only the not-yet-ported iterative solver would use it.
        liniter : int
            Stored as a parameter; only the not-yet-ported iterative solver would use it.
        solver_type : str
            Only ``'direct'`` is implemented in ``solve_system``.
        bc : str
            Boundary condition type forwarded to ``problem_helper``.
        bcParams : object
            Currently unused.

        Raises
        ------
        ProblemError
            If ``nvars`` or ``freq`` have the wrong type, ``freq`` has the wrong length, or an odd frequency is
            combined with periodic boundary conditions.
        """
        # make sure parameters have the correct types
        if not isinstance(nvars, (int, tuple)):
            raise ProblemError('nvars should be either tuple or int')
        if not isinstance(freq, (int, tuple)):
            raise ProblemError('freq should be either tuple or int')

        ndim = 1

        # eventually extend freq to other dimension
        if isinstance(freq, int):
            freq = (freq,) * ndim
        if len(freq) != ndim:
            raise ProblemError(f'len(freq)={len(freq)}, different to ndim={ndim}')

        # check values for freq and nvars
        for f in freq:
            if ndim == 1 and f == -1:
                # a frequency of -1 selects periodic boundary conditions
                bc = 'periodic'
                break
            if f % 2 != 0 and bc == 'periodic':
                raise ProblemError('need even number of frequencies due to periodic BCs')

        # invoke super init, passing number of dofs
        super().__init__(init=(torch.empty(size=(nvars,), dtype=torch.double), None, np.dtype('float64')))

        dx, xvalues = problem_helper.get_1d_grid(size=nvars, bc=bc, left_boundary=0.0, right_boundary=1.0)

        self.A_, _ = problem_helper.get_finite_difference_matrix(
            derivative=2,
            order=order,
            stencil_type=stencil_type,
            dx=dx,
            size=nvars,
            dim=ndim,
            bc=bc,
        )
        self.A_ *= nu
        self.A = torch.tensor(self.A_.todense())

        self.xvalues = torch.tensor(xvalues, dtype=torch.double)
        self.Id = torch.eye(nvars, dtype=torch.double)

        # trained network used by `ML_predict`; loaded lazily on first use
        self._ml_model = None

        # store attribute and register them as parameters
        # NOTE: `locals()` is read by name here, so the argument names above must not be renamed
        self._makeAttributeAndRegister('nvars', 'stencil_type', 'order', 'bc', 'nu', localVars=locals(), readOnly=True)
        self._makeAttributeAndRegister('freq', 'lintol', 'liniter', 'solver_type', localVars=locals())

        if self.solver_type != 'direct':
            self.work_counters[self.solver_type] = WorkCounter()

    @property
    def ndim(self):
        """Number of dimensions of the spatial problem"""
        return 1

    @property
    def dx(self):
        """Size of the mesh (in all dimensions)"""
        return self.xvalues[1] - self.xvalues[0]

    @property
    def grids(self):
        """ND grids associated to the problem"""
        x = self.xvalues
        if self.ndim == 1:
            return x
        if self.ndim == 2:
            return x[None, :], x[:, None]
        if self.ndim == 3:
            return x[None, :, None], x[:, None, None], x[None, None, :]

    def eval_f(self, u, t):
        """
        Routine to evaluate the right-hand side of the problem.

        Parameters
        ----------
        u : dtype_u
            Current values.
        t : float
            Current time.

        Returns
        -------
        f : dtype_f
            Values of the right-hand side of the problem.
        """
        f = self.f_init
        f[:] = torch.matmul(self.A, u)
        return f

    def _get_ml_model(self):
        """Return the trained network, reading it from disk only on the first call."""
        model = getattr(self, '_ml_model', None)
        if model is None:
            model = HeatEquationModel(self)
            # NOTE(review): torch.load unpickles the file; only load model files from a trusted source
            model.load_state_dict(torch.load('heat_equation_model.pth'))
            model.eval()
            self._ml_model = model
        return model

    def ML_predict(self, u0, t0, dt):
        """
        Predict the solution at t0+dt given initial conditions u0

        Parameters
        ----------
        u0 : dtype_u
            Initial condition handed to the network.
        t0 : float
            Time belonging to ``u0``.
        dt : float
            Distance in time to predict ahead.

        Returns
        -------
        sol : dtype_u
            Network prediction converted to double precision.
        """
        model = self._get_ml_model()

        # inference only: do not record an autograd graph for the prediction
        with torch.no_grad():
            predicted_state = model(u0, t0, dt)

        sol = self.u_init
        sol[:] = predicted_state.double()[:]
        return sol

    def solve_system(self, rhs, factor, u0, t):
        r"""
        Simple linear solver for :math:`(I-factor\cdot A)\vec{u}=\vec{rhs}`.

        Parameters
        ----------
        rhs : dtype_f
            Right-hand side for the linear system.
        factor : float
            Abbrev. for the local stepsize (or any other factor required).
        u0 : dtype_u
            Initial guess for the iterative solver.
        t : float
            Current time (e.g. for time-dependent BCs).

        Returns
        -------
        sol : dtype_u
            The solution of the linear solver.

        Raises
        ------
        ValueError
            If ``solver_type`` is anything other than ``'direct'``.
        """
        solver_type = self.solver_type

        # TODO: implement a torch equivalent of CG using u0, lintol, liniter and the work counter
        if solver_type != 'direct':
            raise ValueError(f'solver type "{solver_type}" not known!')

        sol = self.u_init
        sol[:] = torch.linalg.solve(self.Id - factor * self.A, rhs.flatten()).reshape(self.nvars)
        return sol

    def u_exact(self, t, **kwargs):
        r"""
        Routine to compute the exact solution at time :math:`t`.

        Parameters
        ----------
        t : float
            Time of the exact solution.

        Returns
        -------
        sol : dtype_u
            The exact solution.

        Raises
        ------
        NotImplementedError
            For negative frequencies, for which no exact solution is implemented here.
        """
        if 'u_init' in kwargs.keys() or 't_init' in kwargs.keys():
            self.logger.warning(
                f'{type(self).__name__} uses an analytic exact solution from t=0. If you try to compute the local error, you will get the global error instead!'
            )

        ndim, freq, nu, dx, sol = self.ndim, self.freq, self.nu, self.dx, self.u_init

        if ndim == 1:
            x = self.grids
            rho = (2.0 - 2.0 * torch.cos(np.pi * freq[0] * dx)) / dx**2
            if freq[0] > 0:
                sol[:] = torch.sin(np.pi * freq[0] * x) * torch.exp(-t * nu * rho)
            elif freq[0] < 0:
                # previously this silently returned zeros, which is not a solution of the problem
                raise NotImplementedError(f'no exact solution implemented for freq={freq[0]}')
            # freq[0] == 0: the sine mode vanishes, so the zero-initialised `sol` is already correct
        else:
            raise NotImplementedError

        return sol


def main():
    """
    A simple test program to setup a full step instance

    Runs a single time step with the neural-network initial guess and prints the maximum error with respect to
    the exact solution. Requires a trained model file (see `ML_predict`).
    """
    dt = 1e-2

    level_params = dict()
    level_params['restol'] = 1e-10
    level_params['dt'] = dt

    sweeper_params = dict()
    sweeper_params['quad_type'] = 'RADAU-RIGHT'
    sweeper_params['num_nodes'] = 3
    sweeper_params['QI'] = 'LU'
    sweeper_params['initial_guess'] = 'NN'

    problem_params = dict()

    step_params = dict()
    step_params['maxiter'] = 20

    description = dict()
    description['problem_class'] = Heat1DFDTensor
    description['problem_params'] = problem_params
    description['sweeper_class'] = GenericImplicitML_IG
    description['sweeper_params'] = sweeper_params
    description['level_params'] = level_params
    description['step_params'] = step_params

    controller = controller_nonMPI(num_procs=1, controller_params={'logger_level': 20}, description=description)

    P = controller.MS[0].levels[0].prob

    uinit = P.u_exact(0)
    uend, _ = controller.run(u0=uinit, t0=0, Tend=dt)
    u_exact = P.u_exact(dt)
    print("error ", torch.abs(u_exact - uend).max())


if __name__ == "__main__":
    main()
class Train_pySDC:
    """
    Interface between PyTorch and pySDC for training models.

    Attributes:
     - problem: An instantiated problem from pySDC that allows evaluating the exact solution.
                This should have the same parameters as the problem you run in pySDC later.
     - model: A PyTorch model with some neural network to train, specific to the problem
     - use_exact: If True, targets come from ``problem.u_exact``; otherwise from ``problem.solve_system``
    """

    def __init__(self, problem, model, use_exact=True):
        """Store problem and model and switch the model to training mode."""
        self.problem = problem
        self.model = model
        self.use_exact = use_exact  # use exact solution in problem class or backward Euler solution

        self.model.train(True)

    def generate_initial_condition(self, t):
        """Return the problem's exact solution at time ``t``."""
        return self.problem.u_exact(t)

    def generate_target_condition(self, initial_condition, t, dt):
        """
        Return the state the model should predict for ``t + dt``.

        With ``use_exact`` this is the exact solution at ``t + dt``; otherwise it is the result of
        ``problem.solve_system`` with ``initial_condition`` as right-hand side and ``dt`` as factor.
        """
        if self.use_exact:
            return self.problem.u_exact(t + dt)
        return self.problem.solve_system(initial_condition, dt, initial_condition, t)

    def _prepare_sample(self, initial_condition, t, dt):
        """Fill in random ``t`` / ``dt`` and the matching initial and target states where not supplied."""
        # draw t before dt so the random stream is consumed in a fixed order
        t = torch.rand(1) if t is None else t
        dt = torch.rand(1) if dt is None else dt
        if initial_condition is None:
            initial_condition = self.generate_initial_condition(t)
        target_condition = self.generate_target_condition(initial_condition, t, dt)
        return initial_condition, target_condition, t, dt

    def train_model(self, initial_condition=None, t=None, dt=None, num_epochs=1000, lr=0.001, log_every=1):
        """
        Train the model on a single (initial condition, target) pair using Adam and an MSE loss.

        Args:
            initial_condition: State at time ``t``; generated from the problem if None.
            t: Start time; drawn with ``torch.rand(1)`` if None.
            dt: Step size; drawn with ``torch.rand(1)`` if None.
            num_epochs (int): Number of optimizer steps.
            lr (float): Learning rate for Adam.
            log_every (int): Print the loss every ``log_every`` epochs; 0 or None disables printing.
                             The default of 1 prints every epoch.
        """
        model = self.model

        criterion = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr=lr)

        initial_condition, target_condition, t, dt = self._prepare_sample(initial_condition, t, dt)

        for epoch in range(num_epochs):
            predicted_state = model(initial_condition, t, dt)
            loss = criterion(predicted_state.float(), target_condition.float())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if log_every and (epoch + 1) % log_every == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

    def plot(self, initial_condition=None, t=None, dt=None):
        """
        Plot initial condition, target and model prediction over the problem's grid.

        Args:
            initial_condition: State at time ``t``; generated from the problem if None.
            t: Start time; drawn with ``torch.rand(1)`` if None.
            dt: Step size; drawn with ``torch.rand(1)`` if None.
        """
        initial_condition, target, t, dt = self._prepare_sample(initial_condition, t, dt)
        model_prediction = self.model(initial_condition, t, dt)

        fig, ax = plt.subplots()
        ax.plot(self.problem.xvalues, initial_condition, label='ic')
        ax.plot(self.problem.xvalues, target, label='target')
        ax.plot(self.problem.xvalues, model_prediction.detach().numpy(), label='model')
        # float() is needed because a one-element tensor does not accept the '.2e' format spec
        ax.set_title(f't={float(t):.2e}, dt={float(dt):.2e}')
        ax.legend()
class HeatEquationModel(nn.Module):
    """
    Very simple model to learn the heat equation. Beware! It's too simple.
    Some machine learning expert please fix this!

    The network is a single hidden layer with ReLU activation. Its input is the state concatenated with two
    vectors of the same shape holding ``t`` and ``dt``, hence ``3 * nvars`` input features.
    """

    def __init__(self, problem, hidden_size=64):
        """
        Args:
            problem: Object with an ``nvars`` attribute giving the number of degrees of freedom.
            hidden_size (int): Width of the hidden layer.
        """
        # initialise the Module machinery before assigning any attributes
        super().__init__()

        self.input_size = problem.nvars * 3
        self.output_size = problem.nvars

        self.fc1 = nn.Linear(self.input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, self.output_size)

        # Initialize weights (example)
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, x, t, dt):
        """
        Predict the state at ``t + dt`` from the state ``x`` at time ``t``.

        Args:
            x: State with ``nvars`` entries along the last axis; leading batch axes are allowed.
            t: Time belonging to ``x``; broadcast against ``x``.
            dt: Step size; broadcast against ``x``.

        Returns:
            Tensor with ``nvars`` entries along the last axis.
        """
        # prepare individual tensors
        x = x.float()
        _t = torch.ones_like(x) * t
        _dt = torch.ones_like(x) * dt

        # Concatenate along the feature axis; for 1D input this equals dim=0, and it also supports batches
        features = torch.cat((x, _t, _dt), dim=-1)

        return self.fc2(self.relu(self.fc1(features)))


def train_at_collocation_nodes():
    """
    For the first proof of concept, we want to train the model specifically to the collocation nodes we use in SDC.
    If successful, the initial guess would already be the exact solution and we would need no SDC iterations.
    Alas, this neural network is too simple... We need **you** to fix it!

    Trains one model for 50 epochs per step size, plots the result for each step size and writes the weights to
    'heat_equation_model.pth' in the current working directory.
    """
    # step sizes to train on: hard-coded node values scaled by 1e-2
    collocation_nodes = np.array([0.15505102572168285, 1, 0.6449489742783183]) * 1e-2

    # imported here rather than at module level, since heat.py itself imports this module
    from pySDC.playgrounds.ML_initial_guess.heat import Heat1DFDTensor

    prob = Heat1DFDTensor()
    model = HeatEquationModel(prob)
    trainer = Train_pySDC(prob, model, use_exact=True)
    for dt in collocation_nodes:
        trainer.train_model(num_epochs=50, t=0, dt=dt)
    for dt in collocation_nodes:
        trainer.plot(t=0, dt=dt)
    torch.save(model.state_dict(), 'heat_equation_model.pth')
    plt.show()


if __name__ == '__main__':
    train_at_collocation_nodes()
class GenericImplicitML_IG(generic_implicit):
    """Implicit sweeper whose predictor can fill the nodes with a machine-learned initial guess."""

    def predict(self):
        """
        Initialise node with machine learning initial guess

        When the ``initial_guess`` parameter is anything other than 'NN', the predictor of the parent class is
        used. Otherwise every collocation node is filled with the problem's ``ML_predict`` result, started from
        the value at the left boundary of the step, and the right-hand side is evaluated there.
        """
        if self.params.initial_guess != 'NN':
            return super().predict()

        lvl = self.level
        prob = lvl.prob

        # evaluate RHS at left point
        lvl.f[0] = prob.eval_f(lvl.u[0], lvl.time)

        for m in range(1, self.coll.num_nodes + 1):
            # offset of node m from the start of the step
            tau = lvl.dt * self.coll.nodes[m - 1]
            lvl.u[m] = prob.ML_predict(lvl.u[0], lvl.time, tau)
            lvl.f[m] = prob.eval_f(lvl.u[m], lvl.time + tau)

        # indicate that this level is now ready for sweeps
        lvl.status.unlocked = True
        lvl.status.updated = True
try:
    # TODO : mpi4py cannot be imported before dolfin when using fenics mesh
    # see https://github.com/Parallel-in-Time/pySDC/pull/285#discussion_r1145850590
    # This should be dealt with at some point
    from mpi4py import MPI
except ImportError:
    MPI = None


class Tensor(torch.Tensor):
    """
    Wrapper for PyTorch tensor.
    Be aware that this is totally WIP! Should be fine to count iterations, but desperately needs cleaning up if this project goes much further!

    TODO: Have to update `torch/multiprocessing/reductions.py` in order to share this datatype across processes.

    Attributes:
        _comm: MPI communicator or None
    """

    def __new__(cls, init, val=0.0, *args, **kwargs):
        """
        Instantiates new datatype.

        Args:
            init: either another Tensor (which is copied) or a tuple whose first entry is a torch tensor used as
                  template for shape and dtype and whose second entry is the communicator
            val: value to initialize with; only used when ``init`` is a tuple

        Returns:
            obj of type Tensor

        Raises:
            NotImplementedError: if ``init`` is neither a Tensor nor a tuple
        """
        if isinstance(init, Tensor):
            # clone explicitly, as the tuple branch does, so the copy does not share memory with `init`
            obj = super().__new__(cls, init.clone())
            # results of torch operations are Tensors without a `_comm` attribute
            obj._comm = getattr(init, '_comm', None)
        elif isinstance(init, tuple):
            # TODO: validate the communicator and dtype entries of the tuple
            obj = super().__new__(cls, init[0].clone())
            obj.fill_(val)
            obj._comm = init[1]
        else:
            raise NotImplementedError(type(init))
        return obj

    @property
    def comm(self):
        """
        Getter for the communicator; None if this tensor never had one assigned
        """
        return getattr(self, '_comm', None)

    def __array_finalize__(self, obj):
        """
        NumPy-style hook that copies the communicator from ``obj``.

        NOTE(review): this is a NumPy protocol; PyTorch does not appear to call it, which is why `comm` falls
        back to None for tensors created by torch operations.
        """
        if obj is None:
            return
        self._comm = getattr(obj, '_comm', None)

    def __abs__(self):
        """
        Overloading the abs operator

        Returns:
            float: absolute maximum of all tensor values (across all ranks if a communicator is set)
        """
        local_absval = float(torch.amax(torch.abs(self)))

        comm = self.comm
        if comm is not None and comm.Get_size() > 1:
            return float(max(comm.allreduce(sendobj=local_absval, op=MPI.MAX), 0.0))

        return local_absval

    def isend(self, dest=None, tag=None, comm=None):
        """
        Routine for sending data forward in time (non-blocking)

        Args:
            dest (int): target rank
            tag (int): communication tag
            comm: communicator

        Returns:
            request handle
        """
        return comm.Issend(self[:], dest=dest, tag=tag)

    def irecv(self, source=None, tag=None, comm=None):
        """
        Routine for receiving in time

        Args:
            source (int): source rank
            tag (int): communication tag
            comm: communicator

        Returns:
            request handle of the non-blocking receive
        """
        return comm.Irecv(self[:], source=source, tag=tag)

    def bcast(self, root=None, comm=None):
        """
        Routine for broadcasting values

        Args:
            root (int): process with value to broadcast
            comm: communicator

        Returns:
            broadcasted values
        """
        comm.Bcast(self[:], root=root)
        return self