Skip to content

Commit

Permalink
fem: explicitly set use_canonical_quadrature_point_ordering=False for…
Browse files Browse the repository at this point in the history
… special node types
  • Loading branch information
ksagiyam committed Nov 22, 2024
1 parent 31cca14 commit 1266026
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 20 deletions.
36 changes: 23 additions & 13 deletions tsfc/fem.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,18 @@
class ContextBase(ProxyKernelInterface):
"""Common UFL -> GEM translation context."""

keywords = ('ufl_cell',
'fiat_cell',
'integral_type',
'integration_dim',
'entity_ids',
'argument_multiindices',
'facetarea',
'index_cache',
'scalar_type')
keywords = (
'ufl_cell',
'fiat_cell',
'integral_type',
'integration_dim',
'entity_ids',
'argument_multiindices',
'facetarea',
'index_cache',
'scalar_type',
'use_canonical_quadrature_point_ordering',
)

def __init__(self, interface, **kwargs):
ProxyKernelInterface.__init__(self, interface)
Expand Down Expand Up @@ -112,6 +115,9 @@ def translator(self):

@cached_property
def use_canonical_quadrature_point_ordering(self):
# Directly set use_canonical_quadrature_point_ordering = False in context
# for translation of special nodes, e.g., CellVolume, FacetArea, CellOrigin, and CellVertices,
# as quadrature point ordering is not relevant for those node types.
return isinstance(self.fiat_cell, UFCHexahedron) and self.integral_type in ['exterior_facet', 'interior_facet']


Expand Down Expand Up @@ -161,6 +167,7 @@ def jacobian_at(self, point):
expr = NegativeRestricted(expr)
config = {"point_set": PointSingleton(point)}
config.update(self.config)
config.update(use_canonical_quadrature_point_ordering=False) # Not relevant.
context = PointSetContext(**config)
expr = self.preprocess(expr, context)
return map_expr_dag(context.translator, expr)
Expand All @@ -173,6 +180,7 @@ def detJ_at(self, point):
expr = NegativeRestricted(expr)
config = {"point_set": PointSingleton(point)}
config.update(self.config)
config.update(use_canonical_quadrature_point_ordering=False) # Not relevant.
context = PointSetContext(**config)
expr = self.preprocess(expr, context)
return map_expr_dag(context.translator, expr)
Expand Down Expand Up @@ -221,6 +229,7 @@ def physical_edge_lengths(self):
expr = ufl.as_vector([ufl.sqrt(ufl.dot(expr[i, :], expr[i, :])) for i in range(num_edges)])
config = {"point_set": PointSingleton(cell.make_points(sd, 0, sd+1)[0])}
config.update(self.config)
config.update(use_canonical_quadrature_point_ordering=False) # Not relevant.
context = PointSetContext(**config)
expr = self.preprocess(expr, context)
return map_expr_dag(context.translator, expr)
Expand All @@ -243,6 +252,7 @@ def physical_points(self, point_set, entity=None):
if entity is not None:
config.update({name: getattr(self.interface, name)
for name in ["integration_dim", "entity_ids"]})
config.update(use_canonical_quadrature_point_ordering=False) # Not relevant.
context = PointSetContext(**config)
expr = self.preprocess(expr, context)
mapped = map_expr_dag(context.translator, expr)
Expand Down Expand Up @@ -528,7 +538,7 @@ def translate_cellvolume(terminal, mt, ctx):

config = {name: getattr(ctx, name)
for name in ["ufl_cell", "index_cache", "scalar_type"]}
config.update(interface=interface, quadrature_degree=degree)
config.update(interface=interface, quadrature_degree=degree, use_canonical_quadrature_point_ordering=False)
expr, = compile_ufl(integrand, PointSetContext(**config), point_sum=True)
return expr

Expand All @@ -542,7 +552,7 @@ def translate_facetarea(terminal, mt, ctx):
config = {name: getattr(ctx, name)
for name in ["ufl_cell", "integration_dim", "scalar_type",
"entity_ids", "index_cache"]}
config.update(interface=ctx, quadrature_degree=degree)
config.update(interface=ctx, quadrature_degree=degree, use_canonical_quadrature_point_ordering=False)
expr, = compile_ufl(integrand, PointSetContext(**config), point_sum=True)
return expr

Expand All @@ -556,7 +566,7 @@ def translate_cellorigin(terminal, mt, ctx):

config = {name: getattr(ctx, name)
for name in ["ufl_cell", "index_cache", "scalar_type"]}
config.update(interface=ctx, point_set=point_set)
config.update(interface=ctx, point_set=point_set, use_canonical_quadrature_point_ordering=False)
context = PointSetContext(**config)
return context.translator(expression)

Expand All @@ -569,7 +579,7 @@ def translate_cell_vertices(terminal, mt, ctx):

config = {name: getattr(ctx, name)
for name in ["ufl_cell", "index_cache", "scalar_type"]}
config.update(interface=ctx, point_set=ps)
config.update(interface=ctx, point_set=ps, use_canonical_quadrature_point_ordering=False)
context = PointSetContext(**config)
expr = context.translator(ufl_expr)

Expand Down
9 changes: 6 additions & 3 deletions tsfc/kernel_interface/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def compile_gem(self, ctx):
# Let the kernel interface inspect the optimised IR to register
# what kind of external data is required (e.g., cell orientations,
# cell sizes, etc.).
oriented, needs_cell_sizes, tabulations = self.register_requirements(expressions)
oriented, needs_cell_sizes, tabulations, need_facet_orientation = self.register_requirements(expressions)

# Extract Variables that are actually used
active_variables = gem.extract_type(expressions, gem.Variable)
Expand All @@ -227,7 +227,7 @@ def compile_gem(self, ctx):
impero_c = impero_utils.compile_gem(assignments, index_ordering, remove_zeros=True)
except impero_utils.NoopError:
impero_c = None
return impero_c, oriented, needs_cell_sizes, tabulations, active_variables
return impero_c, oriented, needs_cell_sizes, tabulations, active_variables, need_facet_orientation

def fem_config(self):
"""Return a dictionary used with fem.compile_ufl.
Expand Down Expand Up @@ -431,6 +431,7 @@ def check_requirements(ir):
in one pass."""
cell_orientations = False
cell_sizes = False
facet_orientation = False
rt_tabs = {}
for node in traversal(ir):
if isinstance(node, gem.Variable):
Expand All @@ -440,7 +441,9 @@ def check_requirements(ir):
cell_sizes = True
elif node.name.startswith("rt_"):
rt_tabs[node.name] = node.shape
return cell_orientations, cell_sizes, tuple(sorted(rt_tabs.items()))
elif node.name == "facet_orientation":
facet_orientation = True
return cell_orientations, cell_sizes, tuple(sorted(rt_tabs.items())), facet_orientation


def prepare_constant(constant, number):
Expand Down
10 changes: 6 additions & 4 deletions tsfc/kernel_interface/firedrake_loopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def set_coefficient_numbers(self, coefficient_numbers):
def register_requirements(self, ir):
"""Inspect what is referenced by the IR that needs to be
provided by the kernel interface."""
self.oriented, self.cell_sizes, self.tabulations = check_requirements(ir)
self.oriented, self.cell_sizes, self.tabulations, _ = check_requirements(ir)

def set_output(self, o):
"""Produce the kernel return argument"""
Expand Down Expand Up @@ -368,7 +368,7 @@ def construct_kernel(self, name, ctx, log=False):
:arg log: bool if the Kernel should be profiled with Log events
:returns: :class:`Kernel` object
"""
impero_c, oriented, needs_cell_sizes, tabulations, active_variables = self.compile_gem(ctx)
impero_c, oriented, needs_cell_sizes, tabulations, active_variables, need_facet_orientation = self.compile_gem(ctx)
if impero_c is None:
return self.construct_empty_kernel(name)
info = self.integral_data_info
Expand Down Expand Up @@ -418,8 +418,10 @@ def construct_kernel(self, name, ctx, log=False):
elif info.integral_type in ["interior_facet", "interior_facet_vert"]:
int_loopy_arg = lp.GlobalArg("facet", numpy.uint32, shape=(2,))
args.append(kernel_args.InteriorFacetKernelArg(int_loopy_arg))
# Will generalise this in the submesh PR.
if fem.PointSetContext(**self.fem_config()).use_canonical_quadrature_point_ordering:
# The submesh PR will introduce a robust mechanism to check if a Variable
# is actually used in the final form of the expression, so there will be
# no need to get "need_facet_orientation" from self.compile_gem().
if need_facet_orientation:
if info.integral_type == "exterior_facet":
ext_ornt_loopy_arg = lp.GlobalArg("facet_orientation", gem.uint_type, shape=(1,))
args.append(kernel_args.ExteriorFacetOrientationKernelArg(ext_ornt_loopy_arg))
Expand Down

0 comments on commit 1266026

Please sign in to comment.