diff --git a/gem/gem.py b/gem/gem.py index c37e4352..ded6bf65 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -56,6 +56,9 @@ def __call__(self, *args, **kwargs): if not hasattr(obj, 'free_indices'): obj.free_indices = unique(chain(*[c.free_indices for c in obj.children])) + # Set dtype if not set already. + if not hasattr(obj, 'dtype'): + obj.dtype = obj.inherit_dtype_from_children(obj.children) return obj @@ -63,7 +66,7 @@ def __call__(self, *args, **kwargs): class Node(NodeBase, metaclass=NodeMeta): """Abstract GEM node class.""" - __slots__ = ('free_indices',) + __slots__ = ('free_indices', 'dtype') def is_equal(self, other): """Common subexpression eliminating equality predicate. @@ -153,16 +156,46 @@ def __mod__(self, other): def __rmod__(self, other): return as_gem_uint(other).__mod__(self) + @staticmethod + def inherit_dtype_from_children(children): + if any(c.dtype is None for c in children): + # Set dtype = None will let _assign_dtype() + # assign the default dtype for this node later. + return + else: + return numpy.result_type(*(c.dtype for c in children)) + class Terminal(Node): """Abstract class for terminal GEM nodes.""" - __slots__ = () + __slots__ = ('_dtype',) children = () is_equal = NodeBase.is_equal + @property + def dtype(self): + """dtype of the node. + + We only need to set dtype (or _dtype) on terminal nodes, and + other nodes inherit dtype from their children. + + Currently dtype is significant only for nodes under index DAGs + (DAGs underneath `VariableIndex`s representing indices), and + `VariableIndex` checks if the dtype of the node that it wraps is + of uint_type. _assign_dtype() will then assign uint_type to those nodes. + + dtype can be `None` otherwise, and _assign_dtype() will assign + the default dtype to those nodes. + + """ + if hasattr(self, '_dtype'): + return self._dtype + else: + raise AttributeError(f"Must set _dtype on terminal node, {type(self)}") + class Scalar(Node): """Abstract class for scalar-valued GEM nodes.""" @@ -181,6 +214,7 @@ class Failure(Terminal): def __init__(self, shape, exception): self.shape = shape self.exception = exception + self._dtype = None class Constant(Terminal): @@ -190,8 +224,7 @@ class Constant(Terminal): - array: numpy array of values - value: float or complex value (scalars only) """ - __slots__ = ('dtype',) - __back__ = ('dtype',) + pass class Zero(Constant): @@ -199,15 +232,16 @@ class Zero(Constant): __slots__ = ('shape',) __front__ = ('shape',) + __back__ = ('dtype',) - def __init__(self, shape=(), dtype=float): + def __init__(self, shape=(), dtype=None): self.shape = shape - self.dtype = dtype + self._dtype = dtype @property def value(self): assert not self.shape - return numpy.array(0, dtype=self.dtype).item() + return numpy.array(0, dtype=self.dtype or float).item() class Identity(Constant): @@ -215,10 +249,11 @@ class Identity(Constant): __slots__ = ('dim',) __front__ = ('dim',) + __back__ = ('dtype',) - def __init__(self, dim, dtype=float): + def __init__(self, dim, dtype=None): self.dim = dim - self.dtype = dtype + self._dtype = dtype @property def shape(self): @@ -234,6 +269,7 @@ class Literal(Constant): __slots__ = ('array',) __front__ = ('array',) + __back__ = ('dtype',) def __new__(cls, array, dtype=None): array = asarray(array) @@ -245,14 +281,12 @@ def __init__(self, array, dtype=None): # Assume float or complex. try: self.array = array.astype(float, casting="safe") - self.dtype = float except TypeError: self.array = array.astype(complex) - self.dtype = complex else: # Can be int, etc. self.array = array.astype(dtype) - self.dtype = dtype + self._dtype = self.array.dtype def is_equal(self, other): if type(self) is not type(other): @@ -277,13 +311,14 @@ def shape(self): class Variable(Terminal): """Symbolic variable tensor""" - __slots__ = ('name', 'shape', 'dtype') - __front__ = ('name', 'shape', 'dtype') + __slots__ = ('name', 'shape') + __front__ = ('name', 'shape') + __back__ = ('dtype',) def __init__(self, name, shape, dtype=None): self.name = name self.shape = shape - self.dtype = dtype + self._dtype = dtype class Sum(Scalar): @@ -300,8 +335,7 @@ def __new__(cls, a, b): return a if isinstance(a, Constant) and isinstance(b, Constant): - dtype = numpy.result_type(a.dtype, b.dtype) - return Literal(a.value + b.value, dtype=dtype) + return Literal(a.value + b.value, dtype=Node.inherit_dtype_from_children([a, b])) self = super(Sum, cls).__new__(cls) self.children = a, b @@ -325,8 +359,7 @@ def __new__(cls, a, b): return a if isinstance(a, Constant) and isinstance(b, Constant): - dtype = numpy.result_type(a.dtype, b.dtype) - return Literal(a.value * b.value, dtype=dtype) + return Literal(a.value * b.value, dtype=Node.inherit_dtype_from_children([a, b])) self = super(Product, cls).__new__(cls) self.children = a, b @@ -350,8 +383,7 @@ def __new__(cls, a, b): return a if isinstance(a, Constant) and isinstance(b, Constant): - dtype = numpy.result_type(a.dtype, b.dtype) - return Literal(a.value / b.value, dtype=dtype) + return Literal(a.value / b.value, dtype=Node.inherit_dtype_from_children([a, b])) self = super(Division, cls).__new__(cls) self.children = a, b @@ -364,18 +396,17 @@ class FloorDiv(Scalar): def __new__(cls, a, b): assert not a.shape assert not b.shape - # TODO: Attach dtype property to Node and check that - # numpy.result_dtype(a.dtype, b.dtype) is uint type. - # dtype is currently attached only to {Constant, Variable}. + dtype = Node.inherit_dtype_from_children([a, b]) + if dtype is not uint_type: + raise ValueError(f"dtype ({dtype}) is not unit_type ({uint_type})") # Constant folding if isinstance(b, Zero): raise ValueError("division by zero") if isinstance(a, Zero): - return Zero(dtype=a.dtype) + return Zero(dtype=dtype) if isinstance(b, Constant) and b.value == 1: return a if isinstance(a, Constant) and isinstance(b, Constant): - dtype = numpy.result_type(a.dtype, b.dtype) return Literal(a.value // b.value, dtype=dtype) self = super(FloorDiv, cls).__new__(cls) self.children = a, b @@ -388,18 +419,17 @@ class Remainder(Scalar): def __new__(cls, a, b): assert not a.shape assert not b.shape - # TODO: Attach dtype property to Node and check that - # numpy.result_dtype(a.dtype, b.dtype) is uint type. - # dtype is currently attached only to {Constant, Variable}. + dtype = Node.inherit_dtype_from_children([a, b]) + if dtype is not uint_type: + raise ValueError(f"dtype ({dtype}) is not uint_type ({uint_type})") # Constant folding if isinstance(b, Zero): raise ValueError("division by zero") if isinstance(a, Zero): - return Zero(dtype=a.dtype) + return Zero(dtype=dtype) if isinstance(b, Constant) and b.value == 1: - return Zero(dtype=b.dtype) + return Zero(dtype=dtype) if isinstance(a, Constant) and isinstance(b, Constant): - dtype = numpy.result_type(a.dtype, b.dtype) return Literal(a.value % b.value, dtype=dtype) self = super(Remainder, cls).__new__(cls) self.children = a, b @@ -412,18 +442,16 @@ class Power(Scalar): def __new__(cls, base, exponent): assert not base.shape assert not exponent.shape + dtype = Node.inherit_dtype_from_children([base, exponent]) # Constant folding if isinstance(base, Zero): - dtype = numpy.result_type(base.dtype, exponent.dtype) if isinstance(exponent, Zero): raise ValueError("cannot solve 0^0") return Zero(dtype=dtype) elif isinstance(exponent, Zero): - dtype = numpy.result_type(base.dtype, exponent.dtype) return Literal(1, dtype=dtype) elif isinstance(base, Constant) and isinstance(exponent, Constant): - dtype = numpy.result_type(base.dtype, exponent.dtype) return Literal(base.value ** exponent.value, dtype=dtype) self = super(Power, cls).__new__(cls) @@ -483,6 +511,7 @@ def __init__(self, op, a, b): self.operator = op self.children = a, b + self.dtype = None # Do not inherit dtype from children. class LogicalNot(Scalar): @@ -529,6 +558,7 @@ def __new__(cls, condition, then, else_): self = super(Conditional, cls).__new__(cls) self.children = condition, then, else_ self.shape = then.shape + self.dtype = Node.inherit_dtype_from_children([then, else_]) return self @@ -591,6 +621,8 @@ class VariableIndex(IndexBase): def __init__(self, expression): assert isinstance(expression, Node) assert not expression.shape + if expression.dtype is not uint_type: + raise ValueError(f"expression.dtype ({expression.dtype}) is not uint_type ({uint_type})") self.expression = expression def __eq__(self, other): @@ -846,6 +878,7 @@ class ListTensor(Node): def __new__(cls, array): array = asarray(array) assert numpy.prod(array.shape) + dtype = Node.inherit_dtype_from_children(tuple(array.flat)) # Handle children with shape child_shape = array.flat[0].shape @@ -861,7 +894,7 @@ def __new__(cls, array): # Constant folding if all(isinstance(elem, Constant) for elem in array.flat): - return Literal(numpy.vectorize(attrgetter('value'))(array)) + return Literal(numpy.vectorize(attrgetter('value'))(array), dtype=dtype) self = super(ListTensor, cls).__new__(cls) self.array = array @@ -907,9 +940,9 @@ class Concatenate(Node): __slots__ = ('children',) def __new__(cls, *children): + dtype = Node.inherit_dtype_from_children(children) if all(isinstance(child, Zero) for child in children): size = int(sum(numpy.prod(child.shape, dtype=int) for child in children)) - dtype = numpy.result_type(*(child.dtype for child in children)) return Zero((size,), dtype=dtype) self = super(Concatenate, cls).__new__(cls) @@ -924,8 +957,9 @@ def shape(self): class Delta(Scalar, Terminal): __slots__ = ('i', 'j') __front__ = ('i', 'j') + __back__ = ('dtype',) - def __new__(cls, i, j): + def __new__(cls, i, j, dtype=None): assert isinstance(i, IndexBase) assert isinstance(j, IndexBase) @@ -948,6 +982,7 @@ def __new__(cls, i, j): elif isinstance(index, VariableIndex): raise NotImplementedError("Can not make Delta with VariableIndex") self.free_indices = tuple(unique(free_indices)) + self._dtype = dtype return self diff --git a/tests/test_pickle_gem.py b/tests/test_pickle_gem.py index 73e39cac..beb101f9 100644 --- a/tests/test_pickle_gem.py +++ b/tests/test_pickle_gem.py @@ -6,7 +6,7 @@ @pytest.mark.parametrize('protocol', range(3)) def test_pickle_gem(protocol): - f = gem.VariableIndex(gem.Indexed(gem.Variable('facet', (2,)), (1,))) + f = gem.VariableIndex(gem.Indexed(gem.Variable('facet', (2,), dtype=gem.uint_type), (1,))) q = gem.Index() r = gem.Index() _1 = gem.Indexed(gem.Literal(numpy.random.rand(3, 6, 8)), (f, q, r)) diff --git a/tsfc/kernel_interface/firedrake_loopy.py b/tsfc/kernel_interface/firedrake_loopy.py index cc35a7c5..0969854c 100644 --- a/tsfc/kernel_interface/firedrake_loopy.py +++ b/tsfc/kernel_interface/firedrake_loopy.py @@ -85,11 +85,11 @@ def __init__(self, scalar_type, interior_facet=False): # Cell orientation if self.interior_facet: - cell_orientations = gem.Variable("cell_orientations", (2,)) + cell_orientations = gem.Variable("cell_orientations", (2,), dtype=gem.uint_type) self._cell_orientations = (gem.Indexed(cell_orientations, (0,)), gem.Indexed(cell_orientations, (1,))) else: - cell_orientations = gem.Variable("cell_orientations", (1,)) + cell_orientations = gem.Variable("cell_orientations", (1,), dtype=gem.uint_type) self._cell_orientations = (gem.Indexed(cell_orientations, (0,)),) def _coefficient(self, coefficient, name): @@ -257,12 +257,12 @@ def __init__(self, integral_data_info, scalar_type, # Facet number if integral_type in ['exterior_facet', 'exterior_facet_vert']: - facet = gem.Variable('facet', (1,)) + facet = gem.Variable('facet', (1,), dtype=gem.uint_type) self._entity_number = {None: gem.VariableIndex(gem.Indexed(facet, (0,)))} facet_orientation = gem.Variable('facet_orientation', (1,), dtype=gem.uint_type) self._entity_orientation = {None: gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (0,)))} elif integral_type in ['interior_facet', 'interior_facet_vert']: - facet = gem.Variable('facet', (2,)) + facet = gem.Variable('facet', (2,), dtype=gem.uint_type) self._entity_number = { '+': gem.VariableIndex(gem.Indexed(facet, (0,))), '-': gem.VariableIndex(gem.Indexed(facet, (1,))) diff --git a/tsfc/loopy.py b/tsfc/loopy.py index d19da423..08efaede 100644 --- a/tsfc/loopy.py +++ b/tsfc/loopy.py @@ -47,7 +47,7 @@ def _assign_dtype(expression, self): @_assign_dtype.register(gem.Terminal) def _assign_dtype_terminal(expression, self): - return {self.scalar_type} + return {expression.dtype or self.scalar_type} @_assign_dtype.register(gem.Variable) @@ -59,7 +59,7 @@ def _assign_dtype_variable(expression, self): @_assign_dtype.register(gem.Identity) @_assign_dtype.register(gem.Delta) def _assign_dtype_real(expression, self): - return {self.real_type} + return {expression.dtype or self.real_type} @_assign_dtype.register(gem.Literal) @@ -70,15 +70,15 @@ def _assign_dtype_identity(expression, self): @_assign_dtype.register(gem.Power) def _assign_dtype_power(expression, self): # Conservative - return {self.scalar_type} + return {expression.dtype or self.scalar_type} @_assign_dtype.register(gem.MathFunction) def _assign_dtype_mathfunction(expression, self): if expression.name in {"abs", "real", "imag"}: - return {self.real_type} + return {expression.dtype or self.real_type} elif expression.name == "sqrt": - return {self.scalar_type} + return {expression.dtype or self.scalar_type} else: return set.union(*map(self, expression.children)) @@ -87,7 +87,7 @@ def _assign_dtype_mathfunction(expression, self): @_assign_dtype.register(gem.MaxValue) def _assign_dtype_minmax(expression, self): # UFL did correctness checking - return {self.real_type} + return {expression.dtype or self.real_type} @_assign_dtype.register(gem.Conditional) @@ -100,7 +100,7 @@ def _assign_dtype_conditional(expression, self): @_assign_dtype.register(gem.LogicalAnd) @_assign_dtype.register(gem.LogicalOr) def _assign_dtype_logical(expression, self): - return {numpy.int8} + return {expression.dtype or numpy.int8} def assign_dtypes(expressions, scalar_type):