Skip to content

Commit

Permalink
Merge pull request #2519 from devitocodes/maybe-fix-coeffs
Browse files Browse the repository at this point in the history
compiler: Tweak custom coefficients error handling
  • Loading branch information
mloubout authored Jan 20, 2025
2 parents da2c9a4 + 35d0142 commit fa903e4
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 12 deletions.
27 changes: 16 additions & 11 deletions devito/finite_differences/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
from sympy import S, finite_diff_weights, cacheit, sympify, Function, Rational

from devito.logger import warning
from devito.tools import Tag, as_tuple
from devito.types.dimension import StencilDimension

Expand Down Expand Up @@ -260,6 +261,18 @@ def generate_indices(expr, dim, order, side=None, matvec=None, x0=None, nweights
-------
An IndexSet, representing an ordered list of indices.
"""
# Check size of input weights
if nweights > 0:
do, dw = order + 1 + order % 2, nweights
if do < dw:
raise ValueError(f"More weights ({nweights}) provided than the maximum"
f"stencil size ({order + 1}) for order {order} scheme")
elif do > dw:
warning(f"Less weights ({nweights}) provided than the stencil size"
f"({order + 1}) for order {order} scheme."
" Reducing order to {nweights//2}")
order = nweights - nweights % 2

# Evaluation point
x0 = sympify(((x0 or {}).get(dim) or expr.indices_ref[dim]))

Expand All @@ -276,23 +289,15 @@ def generate_indices(expr, dim, order, side=None, matvec=None, x0=None, nweights
side = side or centered

# Indices range
o_min = int(np.ceil(mid - order/2)) + side.val
o_max = int(np.floor(mid + order/2)) + side.val
r = (nweights or order) / 2
o_min = int(np.ceil(mid - r)) + side.val
o_max = int(np.floor(mid + r)) + side.val
if o_max == o_min:
if dim.is_Time or not expr.is_Staggered:
o_max += 1
else:
o_min -= 1

if nweights > 0 and (o_max - o_min + 1) != nweights:
# We cannot infer how the stencil should be centered
# if nweights is more than one extra point.
assert nweights == (o_max - o_min + 1) + 1
# In the "one extra" case we need to pad with one point to symmetrize
if (o_max - mid) > (mid - o_min):
o_min -= 1
else:
o_max += 1
# StencilDimension and expression
d = make_stencil_dimension(expr, o_min, o_max)
iexpr = expr.indices_ref[dim] + d * dim.spacing
Expand Down
31 changes: 30 additions & 1 deletion tests/test_unexpansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from conftest import assert_structure, get_params, get_arrays, check_array
from devito import (Buffer, Eq, Function, TimeFunction, Grid, Operator,
cos, sin)
Coefficient, Substitutions, cos, sin)
from devito.finite_differences import Weights
from devito.arch.compiler import OneapiCompiler
from devito.ir import Expression, FindNodes, FindSymbols
Expand Down Expand Up @@ -76,6 +76,35 @@ def test_multiple_cross_derivs(self, coeffs, expected):
weights = {f for f in functions if isinstance(f, Weights)}
assert len(weights) == expected

@pytest.mark.parametrize('order', [1, 2])
@pytest.mark.parametrize('nweight', [None, +4, -4])
def test_legacy_api(self, order, nweight):
grid = Grid(shape=(51, 51, 51))
x, y, z = grid.dimensions

nweight = 0 if nweight is None else nweight
so = 8

u = TimeFunction(name='u', grid=grid, space_order=so,
coefficients='symbolic')

w0 = np.arange(so + 1 + nweight) + 1
wstr = '{' + ', '.join([f"{w:1.1f}F" for w in w0]) + '}'
wdef = f'[{so + 1 + nweight}] __attribute__ ((aligned (64)))'

coeffs_x_p1 = Coefficient(order, u, x, w0)

coeffs = Substitutions(coeffs_x_p1)

eqn = Eq(u, u.dx.dy + u.dx2 + .37, coefficients=coeffs)

if nweight > 0:
with pytest.raises(ValueError):
op = Operator(eqn, opt=('advanced', {'expand': False}))
else:
op = Operator(eqn, opt=('advanced', {'expand': False}))
assert f'{wdef} = {wstr}' in str(op)


class Test1Pass:

Expand Down

0 comments on commit fa903e4

Please sign in to comment.