From 51898afc0b7d850b2028877c8c100425721585cc Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Mon, 18 Nov 2024 21:25:53 +0200 Subject: [PATCH 1/2] primitives: improve compat with pre-dataclass expressions --- pytential/symbolic/primitives.py | 60 ++++++++++++++++++++++++++++---- 1 file changed, 53 insertions(+), 7 deletions(-) diff --git a/pytential/symbolic/primitives.py b/pytential/symbolic/primitives.py index 459f75b77..5b61c6778 100644 --- a/pytential/symbolic/primitives.py +++ b/pytential/symbolic/primitives.py @@ -20,10 +20,11 @@ THE SOFTWARE. """ +from collections.abc import Iterable from dataclasses import field from warnings import warn from functools import partial -from typing import Any, Union, Literal +from typing import Any, Literal, Union import numpy as np @@ -520,7 +521,9 @@ class NodeCoordinateComponent(DiscretizationProperty): """The axis index this node coordinate represents, i.e. 0 for $x$, etc.""" # FIXME: this is added for backwards compatibility with pre-dataclass expressions - def __init__(self, ambient_axis: int, dofdesc: DOFDescriptorLike) -> None: + def __init__(self, + ambient_axis: int, + dofdesc: DOFDescriptorLike | None = None) -> None: object.__setattr__(self, "ambient_axis", ambient_axis) super().__init__(dofdesc) # type: ignore[arg-type] @@ -578,7 +581,7 @@ def make_op(operand_i): def __init__(self, ref_axes: tuple[tuple[int, int], ...], operand: Expression, - dofdesc: DOFDescriptorLike) -> None: + dofdesc: DOFDescriptorLike | None = None) -> None: object.__setattr__(self, "ref_axes", ref_axes) object.__setattr__(self, "operand", operand) super().__init__(dofdesc) # type: ignore[arg-type] @@ -1236,7 +1239,9 @@ class SingleScalarOperandExpressionWithWhere(Expression): operand: Operand """An expression or an array on which to apply the operation.""" - dofdesc: DOFDescriptor + + # pylint: disable-next=invalid-field-call + dofdesc: DOFDescriptor = field(default_factory=lambda: DEFAULT_DOFDESC) """The descriptor for the geometry where the *operand* is defined.""" def __new__(cls, @@ -1347,9 +1352,13 @@ class IterativeInverse(Expression): """The right-hand side variable used in the linear solve.""" variable_name: str """The name of the variable to solve for.""" - extra_vars: dict[str, Variable] + + # pylint: disable-next=invalid-field-call + extra_vars: dict[str, Variable] = field(default_factory=dict) """A dictionary of additional variables required to define the operator.""" - dofdesc: DOFDescriptor + + # pylint: disable-next=invalid-field-call + dofdesc: DOFDescriptor = field(default_factory=lambda: DEFAULT_DOFDESC) """A descriptor for the geometry on which the solution is defined.""" def __post_init__(self) -> None: @@ -1450,7 +1459,7 @@ def hashable_kernel_args(kernel_arguments): return tuple(hashable_args) -@expr_dataclass(hash=False) +@expr_dataclass(init=False, hash=False) class IntG(Expression): r""" .. math:: @@ -1528,6 +1537,43 @@ class IntG(Expression): them. """ + def __init__( + self, + target_kernel: Kernel, + source_kernels: Iterable[Kernel], + densities: Iterable[Expression], + qbx_forced_limit: QBXForcedLimit, + source: DOFDescriptorLike | None = None, + target: DOFDescriptorLike | None = None, + kernel_arguments: dict[str, Any] | None = None, + **kwargs: Any + ) -> None: + if kernel_arguments is None: + kernel_arguments = {} + + if kwargs: + warn(f"Passing named '**kwargs' to {type(self).__name__!r} is " + "deprecated and will result in an error in 2025. Use the " + "'kernel_arguments' argument instead.", + DeprecationWarning, stacklevel=2) + + kernel_arguments = kernel_arguments.copy() + for name, value in kwargs.items(): + if name in kernel_arguments: + raise ValueError(f"'{name}' already set in 'kernel_arguments'") + + kernel_arguments[name] = value + + object.__setattr__(self, "target_kernel", target_kernel) + object.__setattr__(self, "source_kernels", source_kernels) + object.__setattr__(self, "densities", densities) + object.__setattr__(self, "qbx_forced_limit", qbx_forced_limit) + object.__setattr__(self, "source", source) + object.__setattr__(self, "target", target) + object.__setattr__(self, "kernel_arguments", kernel_arguments) + + self.__post_init__() + def __post_init__(self) -> None: if self.qbx_forced_limit not in {-1, +1, -2, +2, "avg", None}: raise ValueError( From baba46fa9af2222a69b1816438274a443e1baa32 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Mon, 18 Nov 2024 21:27:25 +0200 Subject: [PATCH 2/2] mappers: fix int_g changed flag --- pytential/symbolic/mappers.py | 2 +- test/test_symbolic.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytential/symbolic/mappers.py b/pytential/symbolic/mappers.py index 13f3e78d1..7e9aaa7f6 100644 --- a/pytential/symbolic/mappers.py +++ b/pytential/symbolic/mappers.py @@ -61,7 +61,7 @@ def rec_int_g_arguments(mapper, expr): name: mapper.rec(arg) for name, arg in expr.kernel_arguments.items() } - changed = ( + changed = not ( all(d is orig for d, orig in zip(densities, expr.densities, strict=True)) and all( arg is orig for arg, orig in zip( diff --git a/test/test_symbolic.py b/test/test_symbolic.py index cb24fd10b..551664f78 100644 --- a/test/test_symbolic.py +++ b/test/test_symbolic.py @@ -500,7 +500,7 @@ def test_mapper_int_g_term_collector(op_name, k=0): raise ValueError(f"unknown operator name: {op_name}") from pytential.symbolic.mappers import flatten - assert expr_only_intgs == flatten(expected_expr) + assert flatten(expr_only_intgs) == flatten(expected_expr) # }}}