diff --git a/pySDC/implementations/datatype_classes/MultiComponentMesh.py b/pySDC/implementations/datatype_classes/MultiComponentMesh.py index 029a8613d6..990ab130fb 100644 --- a/pySDC/implementations/datatype_classes/MultiComponentMesh.py +++ b/pySDC/implementations/datatype_classes/MultiComponentMesh.py @@ -1,3 +1,5 @@ +import numpy as np + from pySDC.implementations.datatype_classes.mesh import mesh @@ -13,9 +15,15 @@ class MultiComponentMesh(mesh): components = [] - def __new__(cls, init, *args, **kwargs): - if isinstance(init, tuple): + def __new__(cls, init, val=0.0, offset=0, buffer=None, strides=None, order=None, *args, **kwargs): + if isinstance(init, tuple) and isinstance(init[0], int): obj = super().__new__(cls, ((len(cls.components), init[0]), *init[1:]), *args, **kwargs) + elif isinstance(init, tuple) and isinstance(init[0], tuple): + obj = np.ndarray.__new__( + cls, (len(cls.components), *init[0]), dtype=init[2], buffer=buffer, offset=offset, strides=strides, order=order + ) + obj.fill(val) + obj._comm = init[1] else: obj = super().__new__(cls, init, *args, **kwargs)