diff --git a/pySDC/projects/DAE/misc/ProblemDAE.py b/pySDC/projects/DAE/misc/ProblemDAE.py
index 161c5644cf..17e3e47eb6 100644
--- a/pySDC/projects/DAE/misc/ProblemDAE.py
+++ b/pySDC/projects/DAE/misc/ProblemDAE.py
@@ -56,28 +56,17 @@ def solve_system(self, impl_sys, u0, t):
         """
         me = self.dtype_u(self.init)
 
-        def implSysAsNumpy(unknowns, **kwargs):
-            me.diff[:] = unknowns[: np.size(me.diff)].reshape(me.diff.shape)
-            me.alg[:] = unknowns[np.size(me.diff) :].reshape(me.alg.shape)
+        def implSysFlatten(unknowns, **kwargs):
+            me[:] = unknowns.reshape(me.shape)
             sys = impl_sys(me, **kwargs)
-            return np.append(sys.diff.flatten(), sys.alg.flatten())  # TODO: more efficient way?
+            return sys.flatten()
 
-        if type(me) == DAEMesh:
-            opt = root(
-                implSysAsNumpy,
-                np.append(u0.diff.flatten(), u0.alg.flatten()),
-                method='hybr',
-                tol=self.newton_tol,
-            )
-            me.diff[:] = opt.x[: np.size(me.diff)].reshape(me.diff.shape)
-            me.alg[:] = opt.x[np.size(me.diff) :].reshape(me.alg.shape)
-        else:
-            opt = root(
-                impl_sys,
-                u0,
-                method='hybr',
-                tol=self.newton_tol,
-            )
-            me[:] = opt.x
+        opt = root(
+            implSysFlatten,
+            u0.flatten(),
+            method='hybr',
+            tol=self.newton_tol,
+        )
+        me[:] = opt.x.reshape(me.shape)
         self.work_counters['newton'].niter += opt.nfev
         return me
diff --git a/pySDC/projects/DAE/sweepers/fully_implicit_DAE.py b/pySDC/projects/DAE/sweepers/fully_implicit_DAE.py
index 35f9d2fbed..22cea1f035 100644
--- a/pySDC/projects/DAE/sweepers/fully_implicit_DAE.py
+++ b/pySDC/projects/DAE/sweepers/fully_implicit_DAE.py
@@ -107,11 +107,8 @@ def implSystem(params):
                     System to be solved as implicit function.
                 """
 
-                if type(params) == DAEMesh:
-                    params_mesh = P.dtype_f(params)
-                else:
-                    params_mesh = P.dtype_f(P.init)
-                    params_mesh[:] = params
+                params_mesh = P.dtype_f(P.init)
+                params_mesh[:] = params.reshape(params_mesh.shape)
 
                 # build parameters to pass to implicit function
                 local_u_approx = P.dtype_f(u_approx)