diff --git a/pySDC/projects/DAE/misc/ProblemDAE.py b/pySDC/projects/DAE/misc/ProblemDAE.py index 161c5644cf..17e3e47eb6 100644 --- a/pySDC/projects/DAE/misc/ProblemDAE.py +++ b/pySDC/projects/DAE/misc/ProblemDAE.py @@ -56,28 +56,17 @@ def solve_system(self, impl_sys, u0, t): """ me = self.dtype_u(self.init) - def implSysAsNumpy(unknowns, **kwargs): - me.diff[:] = unknowns[: np.size(me.diff)].reshape(me.diff.shape) - me.alg[:] = unknowns[np.size(me.diff) :].reshape(me.alg.shape) + def implSysFlatten(unknowns, **kwargs): + me[:] = unknowns.reshape(me.shape) sys = impl_sys(me, **kwargs) - return np.append(sys.diff.flatten(), sys.alg.flatten()) # TODO: more efficient way? + return sys.flatten() - if type(me) == DAEMesh: - opt = root( - implSysAsNumpy, - np.append(u0.diff.flatten(), u0.alg.flatten()), - method='hybr', - tol=self.newton_tol, - ) - me.diff[:] = opt.x[: np.size(me.diff)].reshape(me.diff.shape) - me.alg[:] = opt.x[np.size(me.diff) :].reshape(me.alg.shape) - else: - opt = root( - impl_sys, - u0, - method='hybr', - tol=self.newton_tol, - ) - me[:] = opt.x + opt = root( + implSysFlatten, + u0.flatten(), + method='hybr', + tol=self.newton_tol, + ) + me[:] = opt.x.reshape(me.shape) self.work_counters['newton'].niter += opt.nfev return me diff --git a/pySDC/projects/DAE/sweepers/fully_implicit_DAE.py b/pySDC/projects/DAE/sweepers/fully_implicit_DAE.py index 35f9d2fbed..22cea1f035 100644 --- a/pySDC/projects/DAE/sweepers/fully_implicit_DAE.py +++ b/pySDC/projects/DAE/sweepers/fully_implicit_DAE.py @@ -107,11 +107,8 @@ def implSystem(params): System to be solved as implicit function. """ - if type(params) == DAEMesh: - params_mesh = P.dtype_f(params) - else: - params_mesh = P.dtype_f(P.init) - params_mesh[:] = params + params_mesh = P.dtype_f(P.init) + params_mesh[:] = params.reshape(params_mesh.shape) # build parameters to pass to implicit function local_u_approx = P.dtype_f(u_approx)