diff --git a/devito/finite_differences/derivative.py b/devito/finite_differences/derivative.py index 68a92b2b0b..22984af6fb 100644 --- a/devito/finite_differences/derivative.py +++ b/devito/finite_differences/derivative.py @@ -478,6 +478,9 @@ def _eval_at(self, func): setup where one could have Eq(u(x + h_x/2), v(x).dx)) in which case v(x).dx has to be computed at x=x + h_x/2. """ + # No staggering, don't waste time + if not self.expr.staggered and not func.staggered: + return self # If an x0 already exists or evaluating at the same function (i.e u = u.dx) # do not overwrite it if self.x0 or self.side is not None or func.function is self.expr.function: diff --git a/devito/finite_differences/differentiable.py b/devito/finite_differences/differentiable.py index 63f5e9e5df..503393b089 100644 --- a/devito/finite_differences/differentiable.py +++ b/devito/finite_differences/differentiable.py @@ -95,13 +95,36 @@ def dtype(self): @cached_property def indices(self): - return tuple(filter_ordered(flatten(getattr(i, 'indices', ()) - for i in self._args_diff))) + if not self._args_diff: + return DimensionTuple() + + # Get indices of all args and merge them + mapper = {} + for a in self._args_diff: + for d, i in a.indices.getters.items(): + mapper.setdefault(d, []).append(i) + + # Filter unique indices + mapper = {k: v[0] if len(v) == 1 else tuple(filter_ordered(v)) + for k, v in mapper.items()} + + return DimensionTuple(*mapper.values(), getters=tuple(mapper.keys())) @cached_property def dimensions(self): - return tuple(filter_ordered(flatten(getattr(i, 'dimensions', ()) - for i in self._args_diff))) + if not self._args_diff: + return DimensionTuple() + + # Use the staggering of the highest priority function + return highest_priority(self).dimensions + + @cached_property + def staggered(self): + if not self._args_diff: + return None + + # Use the staggering of the highest priority function + return highest_priority(self).staggered @cached_property def root_dimensions(self): @@ -117,11 +140,6 @@ def indices_ref(self): return DimensionTuple(*self.dimensions, getters=self.dimensions) return highest_priority(self).indices_ref - @cached_property - def staggered(self): - return tuple(filter_ordered(flatten(getattr(i, 'staggered', ()) - for i in self._args_diff))) - @cached_property def is_Staggered(self): return any([getattr(i, 'is_Staggered', False) for i in self._args_diff]) @@ -474,13 +492,21 @@ def has_free(self, *patterns): return all(i in self.free_symbols for i in patterns) -def highest_priority(DiffOp): +def highest_priority(diff_op): + if not diff_op._args_diff: + return diff_op + # We want to get the object with highest priority # We also need to make sure that the object with the largest # set of dimensions is used when multiple ones with the same # priority appear prio = lambda x: (getattr(x, '_fd_priority', 0), len(x.dimensions)) - return sorted(DiffOp._args_diff, key=prio, reverse=True)[0] + prio_func = sorted(diff_op._args_diff, key=prio, reverse=True)[0] + + # The highest priority must be a Function + if not isinstance(prio_func, AbstractFunction): + return highest_priority(prio_func) + return prio_func class DifferentiableOp(Differentiable): @@ -548,8 +574,11 @@ class DifferentiableFunction(DifferentiableOp): def __new__(cls, *args, **kwargs): return cls.__sympy_class__.__new__(cls, *args, **kwargs) - def _eval_at(self, func): - return self + @property + def _fd_priority(self): + if highest_priority(self) is self: + return super()._fd_priority + return highest_priority(self)._fd_priority class Add(DifferentiableOp, sympy.Add): @@ -633,26 +662,12 @@ def _gather_for_diff(self): if len(set(f.staggered for f in self._args_diff)) == 1: return self - func_args = highest_priority(self) - new_args = [] - ref_inds = func_args.indices_ref.getters - - for f in self.args: - if f not in self._args_diff \ - or f is func_args \ - or isinstance(f, DifferentiableFunction): - new_args.append(f) - else: - ind_f = f.indices_ref.getters - mapper = {ind_f.get(d, d): ref_inds.get(d, d) - for d in self.dimensions - if ind_f.get(d, d) is not ref_inds.get(d, d)} - if mapper: - new_args.append(f.subs(mapper)) - else: - new_args.append(f) - - return self.func(*new_args, evaluate=False) + derivs, other = split(self.args, lambda a: isinstance(a, sympy.Derivative)) + if len(derivs) == 0: + return self._eval_at(highest_priority(self)) + else: + other = self.func(*other)._eval_at(highest_priority(self)) + return self.func(other, *derivs) class Pow(DifferentiableOp, sympy.Pow): @@ -1034,6 +1049,9 @@ def __new__(cls, *args, base=None, **kwargs): obj = super().__new__(cls, *args, **kwargs) try: + if base is obj: + # In some rare cases (rebuild?) base may be obj itself + base = base.base obj.base = base except AttributeError: # This might happen if e.g. one attempts a (re)construction with @@ -1061,6 +1079,10 @@ def _eval_at(self, func): # and should not be re-evaluated at a different location return self + @property + def indices_ref(self): + return self.base.indices_ref + class diffify: @@ -1184,6 +1206,29 @@ def _(expr, x0, **kwargs): return expr.func(interp_for_fd(expr.expr, x0_expr, **kwargs)) +@interp_for_fd.register(Mul) +def _(expr, x0, **kwargs): + # For a Mul expression, we interpolate the whole expression + # Do we actually need interpolation + if all(expr.indices[d] is i for d, i in x0.items()): + return expr + + # Split args between those that need interp and those that don't + def test0(a): + return all(a.indices[d] is i for d, i in x0.items() if d in a.dimensions) + + oa, ia = split(expr._args_diff, + lambda a: isinstance(a, sympy.Derivative) or test0(a)) + oa = oa + tuple(a for a in expr.args if a not in expr._args_diff) + + # Interpolate the necessary args + d_dims = tuple((d, 0) for d in x0) + fd_order = tuple(expr.interp_order for d in x0) + iexpr = expr.func(*ia).diff(*d_dims, fd_order=fd_order, x0=x0, **kwargs) + + return expr.func(iexpr, *oa) + + @interp_for_fd.register(sympy.Expr) def _(expr, x0, **kwargs): if expr.args: @@ -1194,7 +1239,8 @@ def _(expr, x0, **kwargs): @interp_for_fd.register(AbstractFunction) def _(expr, x0, **kwargs): - x0_expr = {d: v for d, v in x0.items() if v.has(d)} + x0_expr = {d: v for d, v in x0.items() if v.has(d) + and expr.indices[d] is not v} if x0_expr: return expr.subs({expr.indices[d]: v for d, v in x0_expr.items()}) else: diff --git a/devito/finite_differences/tools.py b/devito/finite_differences/tools.py index 3fe9b4f4ab..69e66ce4e6 100644 --- a/devito/finite_differences/tools.py +++ b/devito/finite_differences/tools.py @@ -50,10 +50,9 @@ def check_input(func): def wrapper(expr, *args, **kwargs): try: return S.Zero if expr.is_Number else func(expr, *args, **kwargs) - except AttributeError: - raise ValueError( - f"'{expr}' must be of type Differentiable, not {type(expr)}" - ) from None + except Exception as e: + raise type(e)(f"Error while computing finite-difference for expr={expr}: " + f"{e}") from e return wrapper diff --git a/examples/seismic/tti/operators.py b/examples/seismic/tti/operators.py index 7ce74f1663..f6ec8d34bb 100644 --- a/examples/seismic/tti/operators.py +++ b/examples/seismic/tti/operators.py @@ -280,7 +280,6 @@ def kernel_staggered_2d(model, u, v, **kwargs): epsilon = 1 + 2 * epsilon delta = sqrt(1 + 2 * delta) s = model.grid.stepping_dim.spacing - x, z = model.grid.dimensions # Get source qu = kwargs.get('qu', 0) @@ -291,14 +290,14 @@ def kernel_staggered_2d(model, u, v, **kwargs): if forward: # Stencils - phdx = costheta * u.dx - sintheta * u.dyc + phdx = costheta * u.dx - sintheta * u.dy u_vx = Eq(vx.forward, dampl * vx - dampl * s * phdx) - pvdz = sintheta * v.dxc + costheta * v.dy + pvdz = sintheta * v.dx + costheta * v.dy u_vz = Eq(vz.forward, dampl * vz - dampl * s * pvdz) - dvx = costheta * vx.forward.dx - sintheta * vx.forward.dyc - dvz = sintheta * vz.forward.dxc + costheta * vz.forward.dy + dvx = costheta * vx.forward.dx - sintheta * vx.forward.dy + dvz = sintheta * vz.forward.dx + costheta * vz.forward.dy # u and v equations pv_eq = Eq(v.forward, dampl * (v - s / m * (delta * dvx + dvz)) + s / m * qv) @@ -306,16 +305,16 @@ def kernel_staggered_2d(model, u, v, **kwargs): s / m * qu) else: # Stencils - phdx = ((costheta*epsilon*u).dx - (sintheta*epsilon*u).dyc + - (costheta*delta*v).dx - (sintheta*delta*v).dyc) + a = epsilon * u + delta * v + phdx = (costheta * a).dx - (sintheta * a).dy u_vx = Eq(vx.backward, dampl * vx + dampl * s * phdx) - pvdz = ((sintheta*delta*u).dxc + (costheta*delta*u).dy + - (sintheta*v).dxc + (costheta*v).dy) + b = delta * u + v + pvdz = (sintheta * b).dx + (costheta * b).dy u_vz = Eq(vz.backward, dampl * vz + dampl * s * pvdz) - dvx = (costheta * vx.backward).dx - (sintheta * vx.backward).dyc - dvz = (sintheta * vz.backward).dxc + (costheta * vz.backward).dy + dvx = (costheta * vx.backward).dx - (sintheta * vx.backward).dy + dvz = (sintheta * vz.backward).dx + (costheta * vz.backward).dy # u and v equations pv_eq = Eq(v.backward, dampl * (v + s / m * dvz)) @@ -356,24 +355,24 @@ def kernel_staggered_3d(model, u, v, **kwargs): if forward: # Stencils phdx = (costheta * cosphi * u.dx + - costheta * sinphi * u.dyc - - sintheta * u.dzc) + costheta * sinphi * u.dy - + sintheta * u.dz) u_vx = Eq(vx.forward, dampl * vx - dampl * s * phdx) - phdy = -sinphi * u.dxc + cosphi * u.dy + phdy = -sinphi * u.dx + cosphi * u.dy u_vy = Eq(vy.forward, dampl * vy - dampl * s * phdy) - pvdz = (sintheta * cosphi * v.dxc + - sintheta * sinphi * v.dyc + + pvdz = (sintheta * cosphi * v.dx + + sintheta * sinphi * v.dy + costheta * v.dz) u_vz = Eq(vz.forward, dampl * vz - dampl * s * pvdz) dvx = (costheta * cosphi * vx.forward.dx + - costheta * sinphi * vx.forward.dyc - - sintheta * vx.forward.dzc) - dvy = -sinphi * vy.forward.dxc + cosphi * vy.forward.dy - dvz = (sintheta * cosphi * vz.forward.dxc + - sintheta * sinphi * vz.forward.dyc + + costheta * sinphi * vx.forward.dy - + sintheta * vx.forward.dz) + dvy = -sinphi * vy.forward.dx + cosphi * vy.forward.dy + dvz = (sintheta * cosphi * vz.forward.dx + + sintheta * sinphi * vz.forward.dy + costheta * vz.forward.dz) # u and v equations pv_eq = Eq(v.forward, dampl * (v - s / m * (delta * (dvx + dvy) + dvz)) + @@ -383,30 +382,27 @@ def kernel_staggered_3d(model, u, v, **kwargs): delta * dvz)) + s / m * qu) else: # Stencils - phdx = ((costheta * cosphi * epsilon*u).dx + - (costheta * sinphi * epsilon*u).dyc - - (sintheta * epsilon*u).dzc + (costheta * cosphi * delta*v).dx + - (costheta * sinphi * delta*v).dyc - - (sintheta * delta*v).dzc) + a = epsilon * u + delta * v + phdx = ((costheta * cosphi * a).dx + + (costheta * sinphi * a).dy - + (sintheta * a).dz) u_vx = Eq(vx.backward, dampl * vx + dampl * s * phdx) - phdy = (-(sinphi * epsilon*u).dxc + (cosphi * epsilon*u).dy - - (sinphi * delta*v).dxc + (cosphi * delta*v).dy) + phdy = (-(sinphi * a).dx + (cosphi * a).dy) u_vy = Eq(vy.backward, dampl * vy + dampl * s * phdy) - pvdz = ((sintheta * cosphi * delta*u).dxc + - (sintheta * sinphi * delta*u).dyc + - (costheta * delta*u).dz + (sintheta * cosphi * v).dxc + - (sintheta * sinphi * v).dyc + - (costheta * v).dz) + b = delta * u + v + pvdz = ((sintheta * cosphi * b).dx + + (sintheta * sinphi * b).dy + + (costheta * b).dz) u_vz = Eq(vz.backward, dampl * vz + dampl * s * pvdz) dvx = ((costheta * cosphi * vx.backward).dx + - (costheta * sinphi * vx.backward).dyc - - (sintheta * vx.backward).dzc) - dvy = (-sinphi * vy.backward).dxc + (cosphi * vy.backward).dy - dvz = ((sintheta * cosphi * vz.backward).dxc + - (sintheta * sinphi * vz.backward).dyc + + (costheta * sinphi * vx.backward).dy - + (sintheta * vx.backward).dz) + dvy = (-sinphi * vy.backward).dx + (cosphi * vy.backward).dy + dvz = ((sintheta * cosphi * vz.backward).dx + + (sintheta * sinphi * vz.backward).dy + (costheta * vz.backward).dz) # u and v equations pv_eq = Eq(v.backward, dampl * (v + s / m * dvz)) diff --git a/tests/test_derivatives.py b/tests/test_derivatives.py index a18771d245..8909d5a9fa 100644 --- a/tests/test_derivatives.py +++ b/tests/test_derivatives.py @@ -9,7 +9,7 @@ ) from devito.finite_differences import Derivative, Differentiable, diffify from devito.finite_differences.differentiable import ( - Add, DiffDerivative, EvalDerivative, IndexDerivative, IndexSum, Weights + Add, DiffDerivative, EvalDerivative, IndexDerivative, IndexSum, Weights, interp_for_fd ) from devito.symbolics import indexify, retrieve_indexed from devito.types.dimension import StencilDimension @@ -921,7 +921,7 @@ def test_param_stagg_add(self): assert simplify(eq0.evaluate.rhs - expect0) == 0 # Expects to evaluate c11 and txy at xp then the derivative at yp - expect1 = (c11._subs(x, xp).evaluate * txx._subs(x, xp).evaluate).dy.evaluate + expect1 = (interp_for_fd((c11 * txx), {x: xp}).evaluate).dy.evaluate assert simplify(eq1.evaluate.rhs - expect1) == 0 # Addition should apply the same logic as above for each term