3 changes: 3 additions & 0 deletions devito/finite_differences/derivative.py
@@ -478,6 +478,9 @@ def _eval_at(self, func):
setup where one could have Eq(u(x + h_x/2), v(x).dx)) in which case v(x).dx
has to be computed at x=x + h_x/2.
"""
# No staggering, don't waste time
if not self.expr.staggered and not func.staggered:
return self
# If an x0 already exists or evaluating at the same function (i.e u = u.dx)
# do not overwrite it
if self.x0 or self.side is not None or func.function is self.expr.function:
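A minimal usage sketch of the early return added above, assuming Devito's public Grid/Function API; _eval_at is the internal hook this hunk touches and is called directly here only for illustration.

from devito import Grid, Function

grid = Grid(shape=(11, 11))
x, y = grid.dimensions

u = Function(name='u', grid=grid, space_order=2)                  # node-centred
v = Function(name='v', grid=grid, space_order=2)                  # node-centred
us = Function(name='us', grid=grid, staggered=x, space_order=2)   # staggered in x

d = v.dx
# Nothing is staggered on either side: the new guard returns the Derivative as-is
print(d._eval_at(u) is d)    # expected: True
# A staggered target still shifts the evaluation point via x0, as before this PR
print(d._eval_at(us).x0)     # expected: a mapping like {x: x + h_x/2}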
114 changes: 80 additions & 34 deletions devito/finite_differences/differentiable.py
@@ -95,13 +95,36 @@ def dtype(self):

@cached_property
def indices(self):
return tuple(filter_ordered(flatten(getattr(i, 'indices', ())
for i in self._args_diff)))
if not self._args_diff:
return DimensionTuple()

# Get indices of all args and merge them
Contributor: nitpicking: blank line

mapper = {}
for a in self._args_diff:
for d, i in a.indices.getters.items():
mapper.setdefault(d, []).append(i)

# Filter unique indices
Contributor: nitpicking: blank line

mapper = {k: v[0] if len(v) == 1 else tuple(filter_ordered(v))
for k, v in mapper.items()}

return DimensionTuple(*mapper.values(), getters=tuple(mapper.keys()))
Contributor: nitpicking: blank line
Contributor: nitpicking: can just be getters=tuple(mapper)

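In miniature, the merge performed by the new indices property (plain dicts and strings stand in for Dimensions and index expressions, and dict.fromkeys stands in for filter_ordered):

def merge_indices(args_getters):
    mapper = {}
    for getters in args_getters:
        for d, i in getters.items():
            mapper.setdefault(d, []).append(i)
    # A single candidate is kept as-is; several are collapsed to the ordered distinct ones
    return {d: v[0] if len(v) == 1 else tuple(dict.fromkeys(v))
            for d, v in mapper.items()}

print(merge_indices([{'x': 'x', 'y': 'y'},
                     {'x': 'x + h_x/2', 'y': 'y'}]))
# -> {'x': ('x', 'x + h_x/2'), 'y': ('y',)}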

@cached_property
def dimensions(self):
return tuple(filter_ordered(flatten(getattr(i, 'dimensions', ())
for i in self._args_diff)))
if not self._args_diff:
Contributor: why do we need these two lines? highest_priority(self).dimensions should already return it, or there's something inconsistent about it
Contributor Author: Because a generic Differentiable does not necessarily have an _args_diff. For example, 1/h_x is a Differentiable, but it doesn't have any argument that is Differentiable.
Contributor: but the property is still there, so you are not exposing yourself to an AttributeError... (1/h_x)._args_diff will return the empty tuple and it should just be fine, no?
Contributor Author: An empty tuple doesn't have a .staggered or .dimensions property. And on top of that, (1/h_x)._args_diff will return 1/h_x, which will be an infinite recursion.

return DimensionTuple()

# Use the dimensions of the highest priority function
return highest_priority(self).dimensions

@cached_property
def staggered(self):
if not self._args_diff:
Contributor: same story as before, I think... unless I'm missing something
Contributor Author: same

return None

# Use the staggering of the highest priority function
return highest_priority(self).staggered

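A minimal sketch around the thread above, assuming Devito's public API: an expression mixing a Function with a grid-spacing symbol takes its dimensions and staggering from the highest-priority Function, while the empty-_args_diff guard covers bare pieces such as 1/h_x on their own.

from devito import Grid, Function

grid = Grid(shape=(4, 4))
x, y = grid.dimensions
h_x = x.spacing

u = Function(name='u', grid=grid, staggered=x)

expr = u / h_x
print(expr.dimensions)   # (x, y), delegated to u via highest_priority
print(expr.staggered)    # u's staggering (x)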
@cached_property
def root_dimensions(self):
@@ -117,11 +140,6 @@ def indices_ref(self):
return DimensionTuple(*self.dimensions, getters=self.dimensions)
return highest_priority(self).indices_ref

@cached_property
def staggered(self):
return tuple(filter_ordered(flatten(getattr(i, 'staggered', ())
for i in self._args_diff)))

@cached_property
def is_Staggered(self):
return any([getattr(i, 'is_Staggered', False) for i in self._args_diff])
@@ -474,13 +492,21 @@ def has_free(self, *patterns):
return all(i in self.free_symbols for i in patterns)


def highest_priority(DiffOp):
def highest_priority(diff_op):
if not diff_op._args_diff:
return diff_op

# We want to get the object with highest priority
Contributor: nitpicking: blank line

# We also need to make sure that the object with the largest
# set of dimensions is used when multiple ones with the same
# priority appear
prio = lambda x: (getattr(x, '_fd_priority', 0), len(x.dimensions))
return sorted(DiffOp._args_diff, key=prio, reverse=True)[0]
prio_func = sorted(diff_op._args_diff, key=prio, reverse=True)[0]

# The highest priority must be a Function
if not isinstance(prio_func, AbstractFunction):
return highest_priority(prio_func)
return prio_func

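The tie-break added here, in isolation (mock objects rather than Devito ones): equal _fd_priority values are resolved in favour of the operand spanning the most dimensions, and the winner is then unwound recursively until an actual Function is reached.

from dataclasses import dataclass

@dataclass
class Mock:
    name: str
    _fd_priority: int
    dimensions: tuple

args = [Mock('p', 1, ('x',)),             # e.g. a parameter (low priority)
        Mock('u', 2, ('t', 'x')),         # a 2D TimeFunction
        Mock('w', 2, ('t', 'x', 'y'))]    # same priority as u, more dimensions

prio = lambda a: (a._fd_priority, len(a.dimensions))
print(sorted(args, key=prio, reverse=True)[0].name)   # -> 'w'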

class DifferentiableOp(Differentiable):
@@ -548,8 +574,11 @@ class DifferentiableFunction(DifferentiableOp):
def __new__(cls, *args, **kwargs):
return cls.__sympy_class__.__new__(cls, *args, **kwargs)

def _eval_at(self, func):
return self
@property
def _fd_priority(self):
if highest_priority(self) is self:
return super()._fd_priority
return highest_priority(self)._fd_priority


class Add(DifferentiableOp, sympy.Add):
@@ -633,26 +662,12 @@ def _gather_for_diff(self):
if len(set(f.staggered for f in self._args_diff)) == 1:
return self

func_args = highest_priority(self)
new_args = []
ref_inds = func_args.indices_ref.getters

for f in self.args:
if f not in self._args_diff \
or f is func_args \
or isinstance(f, DifferentiableFunction):
new_args.append(f)
else:
ind_f = f.indices_ref.getters
mapper = {ind_f.get(d, d): ref_inds.get(d, d)
for d in self.dimensions
if ind_f.get(d, d) is not ref_inds.get(d, d)}
if mapper:
new_args.append(f.subs(mapper))
else:
new_args.append(f)

return self.func(*new_args, evaluate=False)
derivs, other = split(self.args, lambda a: isinstance(a, sympy.Derivative))
if len(derivs) == 0:
return self._eval_at(highest_priority(self))
else:
other = self.func(*other)._eval_at(highest_priority(self))
return self.func(other, *derivs)

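A toy version of the derivs/other split above (the split helper in the real code): Derivative factors are set aside untouched, the remaining factors are evaluated together at the highest-priority function's location, and the Mul is rebuilt around them.

import sympy

x = sympy.Symbol('x')
f = sympy.Function('f')(x)
g = sympy.Function('g')(x)

args = (2, f, sympy.Derivative(g, x))
derivs = tuple(a for a in args if isinstance(a, sympy.Derivative))
other = tuple(a for a in args if not isinstance(a, sympy.Derivative))

print(derivs)   # (Derivative(g(x), x),)
print(other)    # (2, f(x))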

class Pow(DifferentiableOp, sympy.Pow):
@@ -1034,6 +1049,9 @@ def __new__(cls, *args, base=None, **kwargs):
obj = super().__new__(cls, *args, **kwargs)

try:
if base is obj:
# In some rare cases (rebuild?) base may be obj itself
base = base.base
obj.base = base
except AttributeError:
# This might happen if e.g. one attempts a (re)construction with
@@ -1061,6 +1079,10 @@ def _eval_at(self, func):
# and should not be re-evaluated at a different location
return self

@property
def indices_ref(self):
return self.base.indices_ref


class diffify:

@@ -1184,6 +1206,29 @@ def _(expr, x0, **kwargs):
return expr.func(interp_for_fd(expr.expr, x0_expr, **kwargs))


@interp_for_fd.register(Mul)
def _(expr, x0, **kwargs):
# For a Mul expression, we interpolate the whole expression
# Do we actually need interpolation?
if all(expr.indices[d] is i for d, i in x0.items()):
return expr

# Split args between those that need interp and those that don't
def test0(a):
return all(a.indices[d] is i for d, i in x0.items() if d in a.dimensions)
Contributor: nitpick: blank line


oa, ia = split(expr._args_diff,
lambda a: isinstance(a, sympy.Derivative) or test0(a))
oa = oa + tuple(a for a in expr.args if a not in expr._args_diff)

# Interpolate the necessary args
d_dims = tuple((d, 0) for d in x0)
fd_order = tuple(expr.interp_order for d in x0)
iexpr = expr.func(*ia).diff(*d_dims, fd_order=fd_order, x0=x0, **kwargs)

return expr.func(iexpr, *oa)

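A hedged usage sketch of the new Mul handler, modelled on the updated test in tests/test_derivatives.py; xp, the half-node location, is introduced here purely for illustration.

from devito import Grid, Function, NODE
from devito.finite_differences.differentiable import interp_for_fd

grid = Grid(shape=(5, 5))
x, y = grid.dimensions
xp = x + x.spacing / 2   # target (staggered) location along x

c11 = Function(name='c11', grid=grid, staggered=NODE, space_order=2)
txx = Function(name='txx', grid=grid, staggered=(x, y), space_order=2)

# c11 lives at the node, txx at the cell centre: the Mul handler interpolates
# only the factors that are not already at xp before any differentiation
expr = interp_for_fd(c11 * txx, {x: xp})
print(expr.evaluate)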

@interp_for_fd.register(sympy.Expr)
def _(expr, x0, **kwargs):
if expr.args:
@@ -1194,7 +1239,8 @@ def _(expr, x0, **kwargs):

@interp_for_fd.register(AbstractFunction)
def _(expr, x0, **kwargs):
x0_expr = {d: v for d, v in x0.items() if v.has(d)}
x0_expr = {d: v for d, v in x0.items() if v.has(d)
and expr.indices[d] is not v}
if x0_expr:
return expr.subs({expr.indices[d]: v for d, v in x0_expr.items()})
else:
7 changes: 3 additions & 4 deletions devito/finite_differences/tools.py
@@ -50,10 +50,9 @@ def check_input(func):
def wrapper(expr, *args, **kwargs):
try:
return S.Zero if expr.is_Number else func(expr, *args, **kwargs)
except AttributeError:
raise ValueError(
f"'{expr}' must be of type Differentiable, not {type(expr)}"
) from None
except Exception as e:
raise type(e)(f"Error while computing finite-difference for expr={expr}: "
Contributor: Reporting the type of expr here would make this error message more useful
Contributor Author: Not sure what you mean; the message is raise type(e) ..., that's the type already

f"{e}") from e
return wrapper

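The re-raise pattern above, in isolation: the original exception class is preserved, the message gains the offending expression for context, and the from e keeps the original traceback chained. Note that type(e)(...) assumes the exception class accepts a single message argument.

def compute(expr):
    try:
        return 1 / expr   # stand-in for the finite-difference computation
    except Exception as e:
        raise type(e)(f"Error while computing finite-difference for "
                      f"expr={expr}: {e}") from e

try:
    compute(0)
except ZeroDivisionError as e:   # still the original exception type
    print(e)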

72 changes: 34 additions & 38 deletions examples/seismic/tti/operators.py
@@ -280,7 +280,6 @@ def kernel_staggered_2d(model, u, v, **kwargs):
epsilon = 1 + 2 * epsilon
delta = sqrt(1 + 2 * delta)
s = model.grid.stepping_dim.spacing
x, z = model.grid.dimensions

# Get source
qu = kwargs.get('qu', 0)
@@ -291,31 +290,31 @@ def kernel_staggered_2d(model, u, v, **kwargs):

if forward:
# Stencils
phdx = costheta * u.dx - sintheta * u.dyc
phdx = costheta * u.dx - sintheta * u.dy
u_vx = Eq(vx.forward, dampl * vx - dampl * s * phdx)

pvdz = sintheta * v.dxc + costheta * v.dy
pvdz = sintheta * v.dx + costheta * v.dy
u_vz = Eq(vz.forward, dampl * vz - dampl * s * pvdz)

dvx = costheta * vx.forward.dx - sintheta * vx.forward.dyc
dvz = sintheta * vz.forward.dxc + costheta * vz.forward.dy
dvx = costheta * vx.forward.dx - sintheta * vx.forward.dy
dvz = sintheta * vz.forward.dx + costheta * vz.forward.dy

# u and v equations
pv_eq = Eq(v.forward, dampl * (v - s / m * (delta * dvx + dvz)) + s / m * qv)
ph_eq = Eq(u.forward, dampl * (u - s / m * (epsilon * dvx + delta * dvz)) +
s / m * qu)
else:
# Stencils
phdx = ((costheta*epsilon*u).dx - (sintheta*epsilon*u).dyc +
(costheta*delta*v).dx - (sintheta*delta*v).dyc)
a = epsilon * u + delta * v
phdx = (costheta * a).dx - (sintheta * a).dy
u_vx = Eq(vx.backward, dampl * vx + dampl * s * phdx)

pvdz = ((sintheta*delta*u).dxc + (costheta*delta*u).dy +
(sintheta*v).dxc + (costheta*v).dy)
b = delta * u + v
pvdz = (sintheta * b).dx + (costheta * b).dy
u_vz = Eq(vz.backward, dampl * vz + dampl * s * pvdz)

dvx = (costheta * vx.backward).dx - (sintheta * vx.backward).dyc
dvz = (sintheta * vz.backward).dxc + (costheta * vz.backward).dy
dvx = (costheta * vx.backward).dx - (sintheta * vx.backward).dy
dvz = (sintheta * vz.backward).dx + (costheta * vz.backward).dy

# u and v equations
pv_eq = Eq(v.backward, dampl * (v + s / m * dvz))
@@ -356,24 +355,24 @@ def kernel_staggered_3d(model, u, v, **kwargs):
if forward:
# Stencils
phdx = (costheta * cosphi * u.dx +
costheta * sinphi * u.dyc -
sintheta * u.dzc)
costheta * sinphi * u.dy -
sintheta * u.dz)
u_vx = Eq(vx.forward, dampl * vx - dampl * s * phdx)

phdy = -sinphi * u.dxc + cosphi * u.dy
phdy = -sinphi * u.dx + cosphi * u.dy
u_vy = Eq(vy.forward, dampl * vy - dampl * s * phdy)

pvdz = (sintheta * cosphi * v.dxc +
sintheta * sinphi * v.dyc +
pvdz = (sintheta * cosphi * v.dx +
sintheta * sinphi * v.dy +
costheta * v.dz)
u_vz = Eq(vz.forward, dampl * vz - dampl * s * pvdz)

dvx = (costheta * cosphi * vx.forward.dx +
costheta * sinphi * vx.forward.dyc -
sintheta * vx.forward.dzc)
dvy = -sinphi * vy.forward.dxc + cosphi * vy.forward.dy
dvz = (sintheta * cosphi * vz.forward.dxc +
sintheta * sinphi * vz.forward.dyc +
costheta * sinphi * vx.forward.dy -
sintheta * vx.forward.dz)
dvy = -sinphi * vy.forward.dx + cosphi * vy.forward.dy
dvz = (sintheta * cosphi * vz.forward.dx +
sintheta * sinphi * vz.forward.dy +
costheta * vz.forward.dz)
# u and v equations
pv_eq = Eq(v.forward, dampl * (v - s / m * (delta * (dvx + dvy) + dvz)) +
@@ -383,30 +382,27 @@ def kernel_staggered_3d(model, u, v, **kwargs):
delta * dvz)) + s / m * qu)
else:
# Stencils
phdx = ((costheta * cosphi * epsilon*u).dx +
(costheta * sinphi * epsilon*u).dyc -
(sintheta * epsilon*u).dzc + (costheta * cosphi * delta*v).dx +
(costheta * sinphi * delta*v).dyc -
(sintheta * delta*v).dzc)
a = epsilon * u + delta * v
phdx = ((costheta * cosphi * a).dx +
(costheta * sinphi * a).dy -
(sintheta * a).dz)
u_vx = Eq(vx.backward, dampl * vx + dampl * s * phdx)

phdy = (-(sinphi * epsilon*u).dxc + (cosphi * epsilon*u).dy -
(sinphi * delta*v).dxc + (cosphi * delta*v).dy)
phdy = (-(sinphi * a).dx + (cosphi * a).dy)
u_vy = Eq(vy.backward, dampl * vy + dampl * s * phdy)

pvdz = ((sintheta * cosphi * delta*u).dxc +
(sintheta * sinphi * delta*u).dyc +
(costheta * delta*u).dz + (sintheta * cosphi * v).dxc +
(sintheta * sinphi * v).dyc +
(costheta * v).dz)
b = delta * u + v
pvdz = ((sintheta * cosphi * b).dx +
(sintheta * sinphi * b).dy +
(costheta * b).dz)
u_vz = Eq(vz.backward, dampl * vz + dampl * s * pvdz)

dvx = ((costheta * cosphi * vx.backward).dx +
(costheta * sinphi * vx.backward).dyc -
(sintheta * vx.backward).dzc)
dvy = (-sinphi * vy.backward).dxc + (cosphi * vy.backward).dy
dvz = ((sintheta * cosphi * vz.backward).dxc +
(sintheta * sinphi * vz.backward).dyc +
(costheta * sinphi * vx.backward).dy -
(sintheta * vx.backward).dz)
dvy = (-sinphi * vy.backward).dx + (cosphi * vy.backward).dy
dvz = ((sintheta * cosphi * vz.backward).dx +
(sintheta * sinphi * vz.backward).dy +
(costheta * vz.backward).dz)
# u and v equations
pv_eq = Eq(v.backward, dampl * (v + s / m * dvz))
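A standalone sympy check (not Devito) that the grouping used in the rewritten adjoint stencils, e.g. a = epsilon * u + delta * v, is a pure linearity identity; the accompanying switch from centred .dxc/.dyc derivatives to plain .dx/.dy relies on the staggered-evaluation machinery added in this PR and is not captured by this algebraic check.

import sympy as sp

x, y = sp.symbols('x y')
theta = sp.Function('theta')(x, y)           # spatially varying tilt angle
eps, delta = sp.symbols('epsilon delta')     # Thomsen parameters (constants here)
u = sp.Function('u')(x, y)
v = sp.Function('v')(x, y)

a = eps * u + delta * v
grouped = sp.diff(sp.cos(theta) * a, x) - sp.diff(sp.sin(theta) * a, y)
expanded = (sp.diff(sp.cos(theta) * eps * u, x) - sp.diff(sp.sin(theta) * eps * u, y)
            + sp.diff(sp.cos(theta) * delta * v, x) - sp.diff(sp.sin(theta) * delta * v, y))

print(sp.expand(grouped - expanded))         # -> 0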
4 changes: 2 additions & 2 deletions tests/test_derivatives.py
@@ -9,7 +9,7 @@
)
from devito.finite_differences import Derivative, Differentiable, diffify
from devito.finite_differences.differentiable import (
Add, DiffDerivative, EvalDerivative, IndexDerivative, IndexSum, Weights
Add, DiffDerivative, EvalDerivative, IndexDerivative, IndexSum, Weights, interp_for_fd
)
from devito.symbolics import indexify, retrieve_indexed
from devito.types.dimension import StencilDimension
@@ -921,7 +921,7 @@ def test_param_stagg_add(self):
assert simplify(eq0.evaluate.rhs - expect0) == 0

# Expects to evaluate c11 and txy at xp then the derivative at yp
expect1 = (c11._subs(x, xp).evaluate * txx._subs(x, xp).evaluate).dy.evaluate
expect1 = (interp_for_fd((c11 * txx), {x: xp}).evaluate).dy.evaluate
assert simplify(eq1.evaluate.rhs - expect1) == 0

# Addition should apply the same logic as above for each term