Skip to content

Commit

Permalink
Merge branch 'main' into chamel/multi-branch-hyper-visco
Browse files Browse the repository at this point in the history
  • Loading branch information
cmhamel authored Nov 5, 2024
2 parents 75569bd + 4f154bf commit 510afb6
Showing 1 changed file with 44 additions and 18 deletions.
62 changes: 44 additions & 18 deletions optimism/FunctionSpace.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
from collections import namedtuple
import numpy as onp

import jax
import jax.numpy as np
from jax.scipy.linalg import solve

from jaxtyping import Array, Float
from optimism import Interpolants
from optimism import Mesh
from optimism import QuadratureRule
from typing import Tuple
import equinox as eqx
import jax
import jax.numpy as np
import numpy as onp


class EssentialBC(eqx.Module):
nodeSet: str
component: int

FunctionSpace = namedtuple('FunctionSpace', ['shapes', 'vols', 'shapeGrads', 'mesh', 'quadratureRule', 'isAxisymmetric'])
FunctionSpace.__doc__ = \

class FunctionSpace(eqx.Module):
"""Data needed for calculus on functions in the discrete function space.
In describing the shape of the attributes, ``ne`` is the number of
Expand All @@ -32,8 +37,12 @@
isAxisymmetric: boolean indicating if the function space data are
axisymmetric.
"""

EssentialBC = namedtuple('EssentialBC', ['nodeSet', 'component'])
shapes: Float[Array, "ne nqpe nn"]
vols: Float[Array, "ne nqpe"]
shapeGrads: Float[Array, "ne nqpe nn nd"]
mesh: Mesh.Mesh
quadratureRule: QuadratureRule.QuadratureRule
isAxisymmetric: bool


def construct_function_space(mesh, quadratureRule, mode2D='cartesian'):
Expand Down Expand Up @@ -354,25 +363,42 @@ def integrate_function_on_edges(functionSpace, func, U, quadRule, edges):
return np.sum(integrate_on_edges(functionSpace, func, U, quadRule, edges))


class DofManager:
class DofManager(eqx.Module):
# TODO get type hints below correct
# TODO this one could be moved to jax types if we move towards
# TODO jit safe preconditioners/solvers
fieldShape: Tuple[int, int]
isBc: any
isUnknown: any
ids: any
unknownIndices: any
bcIndices: any
dofToUnknown: any
HessRowCoords: any
HessColCoords: any
hessian_bc_mask: any

def __init__(self, functionSpace, dim, EssentialBCs):
self.fieldShape = Mesh.num_nodes(functionSpace.mesh), dim
self.isBc = onp.full(self.fieldShape, False, dtype=bool)
isBc = onp.full(self.fieldShape, False, dtype=bool)
for ebc in EssentialBCs:
self.isBc[functionSpace.mesh.nodeSets[ebc.nodeSet], ebc.component] = True
isBc[functionSpace.mesh.nodeSets[ebc.nodeSet], ebc.component] = True
self.isBc = isBc
self.isUnknown = ~self.isBc

self.ids = np.arange(self.isBc.size).reshape(self.fieldShape)
self.ids = onp.arange(self.isBc.size).reshape(self.fieldShape)

self.unknownIndices = self.ids[self.isUnknown]
self.bcIndices = self.ids[self.isBc]

ones = np.ones(self.isBc.size, dtype=int) * -1
self.dofToUnknown = ones.at[self.unknownIndices].set(np.arange(self.unknownIndices.size))
ones = onp.ones(self.isBc.size, dtype=int) * -1
dofToUnknown = ones
dofToUnknown[self.unknownIndices] = onp.arange(self.unknownIndices.size)
self.dofToUnknown = dofToUnknown

self.HessRowCoords, self.HessColCoords = self._make_hessian_coordinates(functionSpace.mesh.conns)

self.hessian_bc_mask = self._make_hessian_bc_mask(functionSpace.mesh.conns)
self.hessian_bc_mask = self._make_hessian_bc_mask(onp.array(functionSpace.mesh.conns))


def get_bc_size(self):
Expand Down Expand Up @@ -424,7 +450,7 @@ def _make_hessian_coordinates(self, conns):
rowCoords[rangeBegin:rangeEnd] = elHessCoords.ravel()
colCoords[rangeBegin:rangeEnd] = elHessCoords.T.ravel()

rangeBegin += np.square(nElUnknowns[e])
rangeBegin += onp.square(nElUnknowns[e])
return rowCoords, colCoords


Expand Down

0 comments on commit 510afb6

Please sign in to comment.