Skip to content

Commit

Permalink
#4776 initial fix
Browse files Browse the repository at this point in the history
  • Loading branch information
rtimms committed Jan 27, 2025
1 parent 43d67db commit 122ca4c
Show file tree
Hide file tree
Showing 4 changed files with 400 additions and 7 deletions.
7 changes: 7 additions & 0 deletions src/pybamm/discretisations/discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,13 @@ def process_symbol(self, symbol):
]
else:
discretised_symbol.secondary_mesh = None

# Assign tertiary mesh
if symbol.domains["tertiary"] != []:
discretised_symbol.tertiary_mesh = self.mesh[symbol.domains["tertiary"]]
else:
discretised_symbol.tertiary_mesh = None

return discretised_symbol

def _process_symbol(self, symbol):
Expand Down
9 changes: 8 additions & 1 deletion src/pybamm/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,11 @@ def generic_deserialise(cls, instance, properties):
else:
var.secondary_mesh = None

if var.domains["tertiary"] != []:
var.tertiary_mesh = properties["mesh"][var.domains["tertiary"]]
else:
var.tertiary_mesh = None

if properties["geometry"]:
instance._geometry = pybamm.Geometry(properties["geometry"])
else:
Expand Down Expand Up @@ -875,8 +880,10 @@ def set_initial_conditions_from(self, solution, inplace=True, return_type="model
final_state_eval = final_state[:, -1]
elif final_state.ndim == 3:
final_state_eval = final_state[:, :, -1].flatten(order="F")
elif final_state.ndim == 4:
final_state_eval = final_state[:, :, :, -1].flatten(order="F")
else:
raise NotImplementedError("Variable must be 0D, 1D, or 2D")
raise NotImplementedError("Variable must be 0D, 1D, 2D, or 3D")
elif isinstance(var, pybamm.Concatenation):
children = []
for child in var.orphans:
Expand Down
238 changes: 233 additions & 5 deletions src/pybamm/solvers/processed_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,6 +860,221 @@ def _interp_setup(self, entries, t):
return entries, coords_for_interp


class ProcessedVariable3D(ProcessedVariable):
"""
An object that can be evaluated at arbitrary (scalars or vectors) t and x, and
returns the (interpolated) value of the base variable at that t and x.
Parameters
----------
base_variables : list of :class:`pybamm.Symbol`
A list of base variables with a method `evaluate(t,y)`, each entry of which
returns the value of that variable for that particular sub-solution.
A Solution can be comprised of sub-solutions which are the solutions of
different models.
Note that this can be any kind of node in the expression tree, not
just a :class:`pybamm.Variable`.
When evaluated, returns an array of size (m,n)
base_variables_casadi : list of :class:`casadi.Function`
A list of casadi functions. When evaluated, returns the same thing as
`base_Variable.evaluate` (but more efficiently).
solution : :class:`pybamm.Solution`
The solution object to be used to create the processed variables
"""

def __init__(
self,
base_variables,
base_variables_casadi,
solution,
time_integral: Optional[pybamm.ProcessedVariableTimeIntegral] = None,
):
self.dimensions = 3
super().__init__(
base_variables,
base_variables_casadi,
solution,
time_integral=time_integral,
)
first_dim_nodes = self.mesh.nodes
first_dim_edges = self.mesh.edges
second_dim_nodes = self.base_variables[0].secondary_mesh.nodes
third_dim_nodes = self.base_variables[0].tertiary_mesh.nodes
if self.base_eval_size // (len(second_dim_nodes) * len(third_dim_nodes)) == len(
first_dim_nodes
):
first_dim_pts = first_dim_nodes
elif self.base_eval_size // (
len(second_dim_nodes) * len(third_dim_nodes)
) == len(first_dim_edges):
first_dim_pts = first_dim_edges

second_dim_pts = second_dim_nodes
third_dim_pts = third_dim_nodes
self.first_dim_size = len(first_dim_pts)
self.second_dim_size = len(second_dim_pts)
self.third_dim_size = len(third_dim_pts)

def _observe_raw_python(self):
"""
Initialise a 3D object that depends on x, y, and z or x, r, and R.
"""
pybamm.logger.debug("Observing the variable raw data in Python")
first_dim_size, second_dim_size, t_size = self._shape(self.t_pts)
entries = np.empty((first_dim_size, second_dim_size, t_size))

# Evaluate the base_variable index-by-index
idx = 0
for ts, ys, inputs, base_var_casadi in zip(
self.all_ts, self.all_ys, self.all_inputs_casadi, self.base_variables_casadi
):
for inner_idx, t in enumerate(ts):
t = ts[inner_idx]
y = ys[:, inner_idx]
entries[:, :, idx] = np.reshape(
base_var_casadi(t, y, inputs).full(),
[first_dim_size, second_dim_size],
order="F",
)
idx += 1
return entries

def _interp_setup(self, entries, t):
"""
Initialise a 3D object that depends on x, y, and z, or x, r, and R.
"""
first_dim_nodes = self.mesh.nodes
first_dim_edges = self.mesh.edges
second_dim_nodes = self.base_variables[0].secondary_mesh.nodes
second_dim_edges = self.base_variables[0].secondary_mesh.edges
third_dim_nodes = self.base_variables[0].tertiary_mesh.nodes
third_dim_edges = self.base_variables[0].tertiary_mesh.edges
if self.base_eval_size // (len(second_dim_nodes) * len(third_dim_nodes)) == len(
first_dim_nodes
):
first_dim_pts = first_dim_nodes
elif self.base_eval_size // (
len(second_dim_nodes) * len(third_dim_nodes)
) == len(first_dim_edges):
first_dim_pts = first_dim_edges

second_dim_pts = second_dim_nodes
third_dim_pts = third_dim_nodes

# add points outside first dimension domain for extrapolation to
# boundaries
extrap_space_first_dim_left = np.array(
[2 * first_dim_pts[0] - first_dim_pts[1]]
)
extrap_space_first_dim_right = np.array(
[2 * first_dim_pts[-1] - first_dim_pts[-2]]
)
first_dim_pts = np.concatenate(
[extrap_space_first_dim_left, first_dim_pts, extrap_space_first_dim_right]
)
extrap_entries_left = np.expand_dims(2 * entries[0] - entries[1], axis=0)
extrap_entries_right = np.expand_dims(2 * entries[-1] - entries[-2], axis=0)
entries_for_interp = np.concatenate(
[extrap_entries_left, entries, extrap_entries_right], axis=0
)

# add points outside second dimension domain for extrapolation to
# boundaries
extrap_space_second_dim_left = np.array(
[2 * second_dim_pts[0] - second_dim_pts[1]]
)
extrap_space_second_dim_right = np.array(
[2 * second_dim_pts[-1] - second_dim_pts[-2]]
)
second_dim_pts = np.concatenate(
[
extrap_space_second_dim_left,
second_dim_pts,
extrap_space_second_dim_right,
]
)
extrap_entries_second_dim_left = np.expand_dims(
2 * entries_for_interp[:, 0, :] - entries_for_interp[:, 1, :], axis=1
)
extrap_entries_second_dim_right = np.expand_dims(
2 * entries_for_interp[:, -1, :] - entries_for_interp[:, -2, :], axis=1
)
entries_for_interp = np.concatenate(
[
extrap_entries_second_dim_left,
entries_for_interp,
extrap_entries_second_dim_right,
],
axis=1,
)

# add points outside tertiary dimension domain for extrapolation to
# boundaries
extrap_space_third_dim_left = np.array(
[2 * third_dim_pts[0] - third_dim_pts[1]]
)
extrap_space_third_dim_right = np.array(
[2 * third_dim_pts[-1] - third_dim_pts[-2]]
)
third_dim_pts = np.concatenate(
[
extrap_space_third_dim_left,
third_dim_pts,
extrap_space_third_dim_right,
]
)
extrap_entries_third_dim_left = np.expand_dims(
2 * entries_for_interp[:, :, 0] - entries_for_interp[:, :, 1], axis=2
)
extrap_entries_third_dim_right = np.expand_dims(
2 * entries_for_interp[:, :, -1] - entries_for_interp[:, :, -2], axis=2
)
entries_for_interp = np.concatenate(
[
extrap_entries_third_dim_left,
entries_for_interp,
extrap_entries_third_dim_right,
],
axis=2,
)

self.spatial_variable_names = {
k: self._process_spatial_variable_names(v)
for k, v in self.spatial_variables.items()
}

self.first_dimension = self.spatial_variable_names["primary"]
self.second_dimension = self.spatial_variable_names["secondary"]
self.third_dimension = self.spatial_variable_names["tertiary"]

# assign attributes for reference
first_dim_pts_for_interp = first_dim_pts
second_dim_pts_for_interp = second_dim_pts
third_dim_pts_for_interp = third_dim_pts

# Set pts to edges for nicer plotting
self.first_dim_pts = first_dim_edges
self.second_dim_pts = second_dim_edges
self.third_dim_pts = third_dim_edges

# save attributes for interpolation
coords_for_interp = {
self.first_dimension: first_dim_pts_for_interp,
self.second_dimension: second_dim_pts_for_interp,
self.third_dimension: third_dim_pts_for_interp,
"t": t,
}

return entries_for_interp, coords_for_interp

def _shape(self, t):
first_dim_size = self.first_dim_size
second_dim_size = self.second_dim_size
third_dim_size = self.third_dim_size
t_size = len(t)
return [first_dim_size, second_dim_size, third_dim_size, t_size]


def process_variable(base_variables, *args, **kwargs):
mesh = base_variables[0].mesh
domain = base_variables[0].domain
Expand All @@ -875,6 +1090,15 @@ def process_variable(base_variables, *args, **kwargs):
and isinstance(mesh, pybamm.ScikitSubMesh2D)
):
return ProcessedVariable2DSciKitFEM(base_variables, *args, **kwargs)
if (
base_variables[0].secondary_mesh
and "current collector" in base_variables[0].domains["secondary"]
and isinstance(base_variables[0].secondary_mesh, pybamm.ScikitSubMesh2D)
):
raise NotImplementedError(
"3D variables with secondary domain 'current collector' using the ScikitFEM"
" discretisation are not supported as processed variables"
)

# check variable shape
if len(base_eval_shape) == 0 or base_eval_shape[0] == 1:
Expand All @@ -896,11 +1120,15 @@ def process_variable(base_variables, *args, **kwargs):
]:
return ProcessedVariable2D(base_variables, *args, **kwargs)

# Raise error for 3D variable
raise NotImplementedError(
f"Shape not recognized for {base_variables[0]}"
+ "(note processing of 3D variables is not yet implemented)"
)
# Try some shapes that could make the variable a 3D variable
tertiary_pts = base_variables[0].tertiary_mesh.nodes
if base_eval_size // (len(second_dim_pts) * len(tertiary_pts)) in [
len(first_dim_nodes),
len(first_dim_edges),
]:
return ProcessedVariable3D(base_variables, *args, **kwargs)

raise NotImplementedError(f"Shape not recognized for {base_variables[0]}")


def _is_f_contiguous(all_ys):
Expand Down
Loading

0 comments on commit 122ca4c

Please sign in to comment.