Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion firedrake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def init_petsc():
from firedrake.deflation import DeflatedSNES, Deflation # noqa: F401
from firedrake.exceptions import ( # noqa: F401
FiredrakeException, ConvergenceError, MismatchingDomainError,
VertexOnlyMeshMissingPointsError, DofNotDefinedError
VertexOnlyMeshMissingPointsError, DofNotDefinedError, DofTypeError,
)
from firedrake.function import ( # noqa: F401
Function, PointNotInDomainError,
Expand Down
6 changes: 6 additions & 0 deletions firedrake/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ class DofNotDefinedError(FiredrakeException):
"""


class DofTypeError(FiredrakeException):
"""Raised when an operation is attempted on a degree of freedom (DoF)
type which is not supported.
"""


class VertexOnlyMeshMissingPointsError(FiredrakeException):
"""Exception raised when 1 or more points are not found by a
:func:`~.VertexOnlyMesh` in its parent mesh.
Expand Down
205 changes: 128 additions & 77 deletions firedrake/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from FIAT.reference_element import Point

from finat.element_factory import create_element, as_fiat_cell
from finat.ufl import TensorElement, VectorElement, MixedElement
from finat.ufl import TensorElement, VectorElement, MixedElement, FiniteElementBase
from finat.fiat_elements import ScalarFiatElement
from finat.quadrature import QuadratureRule
from finat.quadrature_element import QuadratureElement
Expand All @@ -36,7 +36,7 @@

from firedrake.utils import IntType, ScalarType, cached_property, known_pyop2_safe, tuplify
from firedrake.tsfc_interface import extract_numbered_coefficients, _cachedir
from firedrake.ufl_expr import Argument, Coargument, action
from firedrake.ufl_expr import Argument, Coargument, TrialFunction, TestFunction, action
from firedrake.mesh import MissingPointsBehaviour, VertexOnlyMeshTopology, MeshGeometry, MeshTopology, VertexOnlyMesh
from firedrake.petsc import PETSc
from firedrake.halo import _get_mtype
Expand All @@ -49,7 +49,8 @@
from firedrake.function import Function
from firedrake.cofunction import Cofunction
from firedrake.exceptions import (
DofNotDefinedError, VertexOnlyMeshMissingPointsError, NonUniqueMeshSequenceError
DofNotDefinedError, VertexOnlyMeshMissingPointsError, NonUniqueMeshSequenceError,
DofTypeError,
)

from mpi4py import MPI
Expand Down Expand Up @@ -426,17 +427,6 @@ def __init__(self, expr: Interpolate):
else:
self.access = op2.WRITE

# TODO check V.finat_element.is_lagrange() once https://github.com/firedrakeproject/fiat/pull/200 is released
target_element = self.target_space.ufl_element()
if not ((isinstance(target_element, MixedElement)
and all(sub.mapping() == "identity" for sub in target_element.sub_elements))
or target_element.mapping() == "identity"):
# Identity mapping between reference cell and physical coordinates
# implies point evaluation nodes.
raise NotImplementedError(
"Can only cross-mesh interpolate into spaces with point evaluation nodes."
)

if self.allow_missing_dofs:
self.missing_points_behaviour = MissingPointsBehaviour.IGNORE
else:
Expand All @@ -445,26 +435,62 @@ def __init__(self, expr: Interpolate):
if self.source_mesh.geometric_dimension != self.target_mesh.geometric_dimension:
raise ValueError("Geometric dimensions of source and destination meshes must match.")

dest_element = self.target_space.ufl_element()
def _get_element(self, V: WithGeometry) -> FiniteElementBase:
"""Return the element of the function space V. If V is tensor/vector valued,
return the base scalar element.

Parameters
----------
V
A :class:`.WithGeometry` function space.

Returns
-------
FiniteElementBase
The base element of V.
"""
dest_element = V.ufl_element()
if isinstance(dest_element, MixedElement):
if isinstance(dest_element, VectorElement | TensorElement):
# In this case all sub elements are equal
base_element = dest_element.sub_elements[0]
if base_element.reference_value_shape != ():
raise NotImplementedError(
"Can't yet cross-mesh interpolate onto function spaces made from VectorElements "
"or TensorElements made from sub elements with value shape other than ()."
)
self.dest_element = base_element
return dest_element.sub_elements[0]
else:
raise NotImplementedError("Interpolation with MixedFunctionSpace requires MixedInterpolator.")
else:
# scalar fiat/finat element
self.dest_element = dest_element
return dest_element

def _fs_type(self, V: WithGeometry) -> Callable[..., WithGeometry]:
"""Returns a callable which returns a function space matching the type of V.

Parameters
----------
V
A :class:`.WithGeometry` function space.

Returns
-------
Callable
A callable which returns a :class:`.WithGeometry` matching the type of V.
"""
# Get the correct type of function space
shape = V.value_shape
if len(shape) == 0:
return FunctionSpace
elif len(shape) == 1:
return partial(VectorFunctionSpace, dim=shape[0])
else:
symmetry = V.ufl_element().symmetry()
Copy link
Contributor

@pbrubeck pbrubeck Jan 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a comment, the Quadrature space may or not enforce the symmetry, but it is always preferable to enforce it. This means that VomOntoVomInterpolator needs to support interpolation of symmetric tensors.

Also related to this, some elements are intrinsically symmetric-valued (all basis functions are symmetric) even when the ufl_element().symmetry() is None, for example HHJ and Regge. This is similar to the case of intrinsically vector-valued elements like RT, where the basis functions are vector-valued but they don't need a VectorElement.

What I am trying to say with this is that when we have intrisic symmetry we can always reduce the size of the Quadrature space by further inspecting the finat elements. But for now, the current approach to construct the Quadrature space seems fine.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we turn this into an issue.

return partial(TensorFunctionSpace, shape=shape, symmetry=symmetry)

def _get_symbolic_expressions(self, target_space: WithGeometry) -> tuple[Interpolate, Interpolate]:
"""Return symbolic ``Interpolate`` expressions for point evaluation of the `target_space`s
dofs in the source mesh, and the corresponding input-ordering interpolation.

def _get_symbolic_expressions(self) -> tuple[Interpolate, Interpolate]:
"""Return the symbolic ``Interpolate`` expressions for point evaluation and
re-ordering into the input-ordering VertexOnlyMesh.
Parameters
----------
target_space
The :class:`.WithGeometry` function space which we are interpolating into.

Returns
-------
Expand All @@ -474,14 +500,20 @@ def _get_symbolic_expressions(self) -> tuple[Interpolate, Interpolate]:

Raises
------
DofNotDefinedError
If any DoFs in the target mesh cannot be defined in the source mesh.
DoFNotDefinedError
If any of the target spaces dofs cannot be defined in the source mesh.
DoFTypeError
If the target space does not have point-evaluation dofs.
"""
from firedrake.assemble import assemble
if not target_space.finat_element.has_pointwise_dual_basis:
raise DofTypeError(f"FunctionSpace {target_space} must have point-evaluation dofs.")

# Immerse coordinates of target space point evaluation dofs in src_mesh
target_space_vec = VectorFunctionSpace(self.target_mesh, self.dest_element)
f_dest_node_coords = assemble(interpolate(self.target_mesh.coordinates, target_space_vec))
dest_node_coords = f_dest_node_coords.dat.data_ro.reshape(-1, self.target_mesh.geometric_dimension)
target_mesh = target_space.mesh().unique()
target_space_vec = VectorFunctionSpace(target_mesh, self._get_element(target_space))
f_dest_node_coords = assemble(interpolate(target_mesh.coordinates, target_space_vec))
dest_node_coords = f_dest_node_coords.dat.data_ro.reshape(-1, target_mesh.geometric_dimension)
try:
vom = VertexOnlyMesh(
self.source_mesh,
Expand All @@ -496,17 +528,8 @@ def _get_symbolic_expressions(self) -> tuple[Interpolate, Interpolate]:
"This may be because the target mesh covers a larger domain than the "
"source mesh. To disable this error, set allow_missing_dofs=True.")

# Get the correct type of function space
shape = self.target_space.ufl_function_space().value_shape
if len(shape) == 0:
fs_type = FunctionSpace
elif len(shape) == 1:
fs_type = partial(VectorFunctionSpace, dim=shape[0])
else:
symmetry = self.target_space.ufl_element().symmetry()
fs_type = partial(TensorFunctionSpace, shape=shape, symmetry=symmetry)

# Get expression for point evaluation at the dest_node_coords
fs_type = self._fs_type(target_space)
P0DG_vom = fs_type(vom, "DG", 0)
point_eval = interpolate(self.operand, P0DG_vom)

Expand All @@ -523,11 +546,15 @@ def _get_callable(self, tensor=None, bcs=None, mat_type=None, sub_mat_type=None)
raise NotImplementedError("bcs not implemented for cross-mesh interpolation.")
mat_type = mat_type or "aij"

# self.ufl_interpolate.function_space() is None in the 0-form case
V_dest = self.ufl_interpolate.function_space() or self.target_space
f = tensor or Function(V_dest)
# Interpolate into intermediate quadrature space for non-identity mapped elements
if into_quadrature_space := not self.target_space.finat_element.has_pointwise_dual_basis:
target_space = self.target_space.quadrature_space()
f = Function(target_space.dual() if self.ufl_interpolate.is_adjoint else target_space)
else:
target_space = self.target_space
f = tensor or Function(self.ufl_interpolate.function_space() or target_space)

point_eval, point_eval_input_ordering = self._get_symbolic_expressions()
point_eval, point_eval_input_ordering = self._get_symbolic_expressions(target_space)
P0DG_vom_input_ordering = point_eval_input_ordering.argument_slots()[0].function_space().dual()

if self.rank == 2:
Expand All @@ -536,47 +563,65 @@ def _get_callable(self, tensor=None, bcs=None, mat_type=None, sub_mat_type=None)
# `self.point_eval_interpolate` and the permutation
# given by `self.to_input_ordering_interpolate`.
if self.ufl_interpolate.is_adjoint:
symbolic = action(point_eval, point_eval_input_ordering)
interp_expr = action(point_eval, point_eval_input_ordering)
else:
symbolic = action(point_eval_input_ordering, point_eval)
interp_expr = action(point_eval_input_ordering, point_eval)

def callable() -> PETSc.Mat:
return assemble(symbolic, mat_type=mat_type).petscmat
res = assemble(interp_expr, mat_type=mat_type).petscmat
if into_quadrature_space:
source_space = self.ufl_interpolate.function_space()
if self.ufl_interpolate.is_adjoint:
I = AssembledMatrix((Argument(source_space, 0), Argument(target_space.dual(), 1)), None, res)
return assemble(action(I, interpolate(TestFunction(target_space), self.target_space))).petscmat
else:
I = AssembledMatrix((Argument(target_space.dual(), 0), Argument(source_space, 1)), None, res)
return assemble(action(interpolate(TrialFunction(target_space), self.target_space), I)).petscmat
else:
return res

elif self.ufl_interpolate.is_adjoint:
assert self.rank == 1
# f_src is a cofunction on V_dest.dual
cofunc = self.dual_arg
assert isinstance(cofunc, Cofunction)

# Our first adjoint operation is to assign the dat values to a
# P0DG cofunction on our input ordering VOM.
f_input_ordering = Cofunction(P0DG_vom_input_ordering.dual())
f_input_ordering.dat.data_wo[:] = cofunc.dat.data_ro[:]

# The rest of the adjoint interpolation is the composition
# of the adjoint interpolators in the reverse direction.
# We don't worry about skipping over missing points here
# because we're going from the input ordering VOM to the original VOM
# and all points from the input ordering VOM are in the original.

def callable() -> Cofunction:
if into_quadrature_space:
cofunc = assemble(interpolate(TestFunction(target_space), self.dual_arg))
f_target = Cofunction(point_eval.function_space())
else:
cofunc = self.dual_arg
f_target = f

assert isinstance(cofunc, Cofunction)

# Our first adjoint operation is to assign the dat values to a
# P0DG cofunction on our input ordering VOM.
f_input_ordering = Cofunction(P0DG_vom_input_ordering.dual())
f_input_ordering.dat.data_wo[:] = cofunc.dat.data_ro[:]

# The rest of the adjoint interpolation is the composition
# of the adjoint interpolators in the reverse direction.
# We don't worry about skipping over missing points here
# because we're going from the input ordering VOM to the original VOM
# and all points from the input ordering VOM are in the original.
f_src_at_src_node_coords = assemble(action(point_eval_input_ordering, f_input_ordering))
assemble(action(point_eval, f_src_at_src_node_coords), tensor=f)
return f
assemble(action(point_eval, f_src_at_src_node_coords), tensor=f_target)
return f_target
else:
assert self.rank in {0, 1}
# We create the input-ordering Function before interpolating so we can
# set default missing values if required.
f_point_eval_input_ordering = Function(P0DG_vom_input_ordering)
if self.default_missing_val is not None:
f_point_eval_input_ordering.assign(self.default_missing_val)
elif self.allow_missing_dofs:
# If we allow missing points there may be points in the target
# mesh that are not in the source mesh. If we don't specify a
# default missing value we set these to NaN so we can identify
# them later.
f_point_eval_input_ordering.dat.data_wo[:] = numpy.nan

def callable() -> Function | Number:
# We create the input-ordering Function before interpolating so we can
# set default missing values if required.
f_point_eval_input_ordering = Function(P0DG_vom_input_ordering)
if self.default_missing_val is not None:
f_point_eval_input_ordering.assign(self.default_missing_val)
elif self.allow_missing_dofs:
# If we allow missing points there may be points in the target
# mesh that are not in the source mesh. If we don't specify a
# default missing value we set these to NaN so we can identify
# them later.
f_point_eval_input_ordering.dat.data_wo[:] = numpy.nan

assemble(action(point_eval_input_ordering, point_eval), tensor=f_point_eval_input_ordering)
# We assign these values to the output function
if self.allow_missing_dofs and self.default_missing_val is None:
Expand All @@ -585,12 +630,18 @@ def callable() -> Function | Number:
else:
f.dat.data_wo[:] = f_point_eval_input_ordering.dat.data_ro[:]

if into_quadrature_space:
f_target = Function(self.target_space)
assemble(interpolate(f, self.target_space), tensor=f_target)
else:
f_target = f

if self.rank == 0:
# We take the action of the dual_arg on the interpolated function
assert isinstance(self.dual_arg, Cofunction)
return assemble(action(self.dual_arg, f))
return assemble(action(self.dual_arg, f_target))
else:
return f
return f_target
return callable

@property
Expand Down
Loading
Loading