Skip to content

Commit

Permalink
gem: attach dtype to every node
Browse files Browse the repository at this point in the history
Co-authored-by: David A. Ham <[email protected]>
  • Loading branch information
ksagiyam and dham committed Nov 19, 2024
1 parent 8c1c4c0 commit 2ad805b
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 50 deletions.
111 changes: 73 additions & 38 deletions gem/gem.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,17 @@ def __call__(self, *args, **kwargs):
if not hasattr(obj, 'free_indices'):
obj.free_indices = unique(chain(*[c.free_indices
for c in obj.children]))
# Set dtype if not set already.
if not hasattr(obj, 'dtype'):
obj.dtype = obj.inherit_dtype_from_children(obj.children)

return obj


class Node(NodeBase, metaclass=NodeMeta):
"""Abstract GEM node class."""

__slots__ = ('free_indices',)
__slots__ = ('free_indices', 'dtype')

def is_equal(self, other):
"""Common subexpression eliminating equality predicate.
Expand Down Expand Up @@ -153,16 +156,46 @@ def __mod__(self, other):
def __rmod__(self, other):
return as_gem_uint(other).__mod__(self)

@staticmethod
def inherit_dtype_from_children(children):
if any(c.dtype is None for c in children):
# Set dtype = None will let _assign_dtype()
# assign the default dtype for this node later.
return
else:
return numpy.result_type(*(c.dtype for c in children))


class Terminal(Node):
"""Abstract class for terminal GEM nodes."""

__slots__ = ()
__slots__ = ('_dtype',)

children = ()

is_equal = NodeBase.is_equal

@property
def dtype(self):
"""dtype of the node.
We only need to set dtype (or _dtype) on terminal nodes, and
other nodes inherit dtype from their children.
Currently dtype is significant only for nodes under index DAGs
(DAGs underneath `VariableIndex`s representing indices), and
`VariableIndex` checks if the dtype of the node that it wraps is
of uint_type. _assign_dtype() will then assign uint_type to those nodes.
dtype can be `None` otherwise, and _assign_dtype() will assign
the default dtype to those nodes.
"""
if hasattr(self, '_dtype'):
return self._dtype
else:
raise AttributeError(f"Must set _dtype on terminal node, {type(self)}")


class Scalar(Node):
"""Abstract class for scalar-valued GEM nodes."""
Expand All @@ -181,6 +214,7 @@ class Failure(Terminal):
def __init__(self, shape, exception):
self.shape = shape
self.exception = exception
self._dtype = None


class Constant(Terminal):
Expand All @@ -190,35 +224,36 @@ class Constant(Terminal):
- array: numpy array of values
- value: float or complex value (scalars only)
"""
__slots__ = ('dtype',)
__back__ = ('dtype',)
pass


class Zero(Constant):
"""Symbolic zero tensor"""

__slots__ = ('shape',)
__front__ = ('shape',)
__back__ = ('dtype',)

def __init__(self, shape=(), dtype=float):
def __init__(self, shape=(), dtype=None):
self.shape = shape
self.dtype = dtype
self._dtype = dtype

@property
def value(self):
assert not self.shape
return numpy.array(0, dtype=self.dtype).item()
return numpy.array(0, dtype=self.dtype or float).item()


class Identity(Constant):
"""Identity matrix"""

__slots__ = ('dim',)
__front__ = ('dim',)
__back__ = ('dtype',)

def __init__(self, dim, dtype=float):
def __init__(self, dim, dtype=None):
self.dim = dim
self.dtype = dtype
self._dtype = dtype

@property
def shape(self):
Expand All @@ -234,6 +269,7 @@ class Literal(Constant):

__slots__ = ('array',)
__front__ = ('array',)
__back__ = ('dtype',)

def __new__(cls, array, dtype=None):
array = asarray(array)
Expand All @@ -245,14 +281,12 @@ def __init__(self, array, dtype=None):
# Assume float or complex.
try:
self.array = array.astype(float, casting="safe")
self.dtype = float
except TypeError:
self.array = array.astype(complex)
self.dtype = complex
else:
# Can be int, etc.
self.array = array.astype(dtype)
self.dtype = dtype
self._dtype = self.array.dtype

def is_equal(self, other):
if type(self) is not type(other):
Expand All @@ -277,13 +311,14 @@ def shape(self):
class Variable(Terminal):
"""Symbolic variable tensor"""

__slots__ = ('name', 'shape', 'dtype')
__front__ = ('name', 'shape', 'dtype')
__slots__ = ('name', 'shape')
__front__ = ('name', 'shape')
__back__ = ('dtype',)

def __init__(self, name, shape, dtype=None):
self.name = name
self.shape = shape
self.dtype = dtype
self._dtype = dtype


class Sum(Scalar):
Expand All @@ -300,8 +335,7 @@ def __new__(cls, a, b):
return a

if isinstance(a, Constant) and isinstance(b, Constant):
dtype = numpy.result_type(a.dtype, b.dtype)
return Literal(a.value + b.value, dtype=dtype)
return Literal(a.value + b.value, dtype=Node.inherit_dtype_from_children([a, b]))

self = super(Sum, cls).__new__(cls)
self.children = a, b
Expand All @@ -325,8 +359,7 @@ def __new__(cls, a, b):
return a

if isinstance(a, Constant) and isinstance(b, Constant):
dtype = numpy.result_type(a.dtype, b.dtype)
return Literal(a.value * b.value, dtype=dtype)
return Literal(a.value * b.value, dtype=Node.inherit_dtype_from_children([a, b]))

self = super(Product, cls).__new__(cls)
self.children = a, b
Expand All @@ -350,8 +383,7 @@ def __new__(cls, a, b):
return a

if isinstance(a, Constant) and isinstance(b, Constant):
dtype = numpy.result_type(a.dtype, b.dtype)
return Literal(a.value / b.value, dtype=dtype)
return Literal(a.value / b.value, dtype=Node.inherit_dtype_from_children([a, b]))

self = super(Division, cls).__new__(cls)
self.children = a, b
Expand All @@ -364,18 +396,17 @@ class FloorDiv(Scalar):
def __new__(cls, a, b):
assert not a.shape
assert not b.shape
# TODO: Attach dtype property to Node and check that
# numpy.result_dtype(a.dtype, b.dtype) is uint type.
# dtype is currently attached only to {Constant, Variable}.
dtype = Node.inherit_dtype_from_children([a, b])
if dtype is not uint_type:
raise ValueError(f"dtype ({dtype}) is not unit_type ({uint_type})")
# Constant folding
if isinstance(b, Zero):
raise ValueError("division by zero")
if isinstance(a, Zero):
return Zero(dtype=a.dtype)
return Zero(dtype=dtype)
if isinstance(b, Constant) and b.value == 1:
return a
if isinstance(a, Constant) and isinstance(b, Constant):
dtype = numpy.result_type(a.dtype, b.dtype)
return Literal(a.value // b.value, dtype=dtype)
self = super(FloorDiv, cls).__new__(cls)
self.children = a, b
Expand All @@ -388,18 +419,17 @@ class Remainder(Scalar):
def __new__(cls, a, b):
assert not a.shape
assert not b.shape
# TODO: Attach dtype property to Node and check that
# numpy.result_dtype(a.dtype, b.dtype) is uint type.
# dtype is currently attached only to {Constant, Variable}.
dtype = Node.inherit_dtype_from_children([a, b])
if dtype is not uint_type:
raise ValueError(f"dtype ({dtype}) is not uint_type ({uint_type})")
# Constant folding
if isinstance(b, Zero):
raise ValueError("division by zero")
if isinstance(a, Zero):
return Zero(dtype=a.dtype)
return Zero(dtype=dtype)
if isinstance(b, Constant) and b.value == 1:
return Zero(dtype=b.dtype)
return Zero(dtype=dtype)
if isinstance(a, Constant) and isinstance(b, Constant):
dtype = numpy.result_type(a.dtype, b.dtype)
return Literal(a.value % b.value, dtype=dtype)
self = super(Remainder, cls).__new__(cls)
self.children = a, b
Expand All @@ -412,18 +442,16 @@ class Power(Scalar):
def __new__(cls, base, exponent):
assert not base.shape
assert not exponent.shape
dtype = Node.inherit_dtype_from_children([base, exponent])

# Constant folding
if isinstance(base, Zero):
dtype = numpy.result_type(base.dtype, exponent.dtype)
if isinstance(exponent, Zero):
raise ValueError("cannot solve 0^0")
return Zero(dtype=dtype)
elif isinstance(exponent, Zero):
dtype = numpy.result_type(base.dtype, exponent.dtype)
return Literal(1, dtype=dtype)
elif isinstance(base, Constant) and isinstance(exponent, Constant):
dtype = numpy.result_type(base.dtype, exponent.dtype)
return Literal(base.value ** exponent.value, dtype=dtype)

self = super(Power, cls).__new__(cls)
Expand Down Expand Up @@ -483,6 +511,7 @@ def __init__(self, op, a, b):

self.operator = op
self.children = a, b
self.dtype = None # Do not inherit dtype from children.


class LogicalNot(Scalar):
Expand Down Expand Up @@ -529,6 +558,7 @@ def __new__(cls, condition, then, else_):
self = super(Conditional, cls).__new__(cls)
self.children = condition, then, else_
self.shape = then.shape
self.dtype = Node.inherit_dtype_from_children([then, else_])
return self


Expand Down Expand Up @@ -591,6 +621,8 @@ class VariableIndex(IndexBase):
def __init__(self, expression):
assert isinstance(expression, Node)
assert not expression.shape
if expression.dtype is not uint_type:
raise ValueError(f"expression.dtype ({expression.dtype}) is not uint_type ({uint_type})")
self.expression = expression

def __eq__(self, other):
Expand Down Expand Up @@ -846,6 +878,7 @@ class ListTensor(Node):
def __new__(cls, array):
array = asarray(array)
assert numpy.prod(array.shape)
dtype = Node.inherit_dtype_from_children(tuple(array.flat))

# Handle children with shape
child_shape = array.flat[0].shape
Expand All @@ -861,7 +894,7 @@ def __new__(cls, array):

# Constant folding
if all(isinstance(elem, Constant) for elem in array.flat):
return Literal(numpy.vectorize(attrgetter('value'))(array))
return Literal(numpy.vectorize(attrgetter('value'))(array), dtype=dtype)

self = super(ListTensor, cls).__new__(cls)
self.array = array
Expand Down Expand Up @@ -907,9 +940,9 @@ class Concatenate(Node):
__slots__ = ('children',)

def __new__(cls, *children):
dtype = Node.inherit_dtype_from_children(children)
if all(isinstance(child, Zero) for child in children):
size = int(sum(numpy.prod(child.shape, dtype=int) for child in children))
dtype = numpy.result_type(*(child.dtype for child in children))
return Zero((size,), dtype=dtype)

self = super(Concatenate, cls).__new__(cls)
Expand All @@ -924,8 +957,9 @@ def shape(self):
class Delta(Scalar, Terminal):
__slots__ = ('i', 'j')
__front__ = ('i', 'j')
__back__ = ('dtype',)

def __new__(cls, i, j):
def __new__(cls, i, j, dtype=None):
assert isinstance(i, IndexBase)
assert isinstance(j, IndexBase)

Expand All @@ -948,6 +982,7 @@ def __new__(cls, i, j):
elif isinstance(index, VariableIndex):
raise NotImplementedError("Can not make Delta with VariableIndex")
self.free_indices = tuple(unique(free_indices))
self._dtype = dtype
return self


Expand Down
2 changes: 1 addition & 1 deletion tests/test_pickle_gem.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

@pytest.mark.parametrize('protocol', range(3))
def test_pickle_gem(protocol):
f = gem.VariableIndex(gem.Indexed(gem.Variable('facet', (2,)), (1,)))
f = gem.VariableIndex(gem.Indexed(gem.Variable('facet', (2,), dtype=gem.uint_type), (1,)))
q = gem.Index()
r = gem.Index()
_1 = gem.Indexed(gem.Literal(numpy.random.rand(3, 6, 8)), (f, q, r))
Expand Down
8 changes: 4 additions & 4 deletions tsfc/kernel_interface/firedrake_loopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,11 @@ def __init__(self, scalar_type, interior_facet=False):

# Cell orientation
if self.interior_facet:
cell_orientations = gem.Variable("cell_orientations", (2,))
cell_orientations = gem.Variable("cell_orientations", (2,), dtype=gem.uint_type)
self._cell_orientations = (gem.Indexed(cell_orientations, (0,)),
gem.Indexed(cell_orientations, (1,)))
else:
cell_orientations = gem.Variable("cell_orientations", (1,))
cell_orientations = gem.Variable("cell_orientations", (1,), dtype=gem.uint_type)
self._cell_orientations = (gem.Indexed(cell_orientations, (0,)),)

def _coefficient(self, coefficient, name):
Expand Down Expand Up @@ -257,12 +257,12 @@ def __init__(self, integral_data_info, scalar_type,

# Facet number
if integral_type in ['exterior_facet', 'exterior_facet_vert']:
facet = gem.Variable('facet', (1,))
facet = gem.Variable('facet', (1,), dtype=gem.uint_type)
self._entity_number = {None: gem.VariableIndex(gem.Indexed(facet, (0,)))}
facet_orientation = gem.Variable('facet_orientation', (1,), dtype=gem.uint_type)
self._entity_orientation = {None: gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (0,)))}
elif integral_type in ['interior_facet', 'interior_facet_vert']:
facet = gem.Variable('facet', (2,))
facet = gem.Variable('facet', (2,), dtype=gem.uint_type)
self._entity_number = {
'+': gem.VariableIndex(gem.Indexed(facet, (0,))),
'-': gem.VariableIndex(gem.Indexed(facet, (1,)))
Expand Down
Loading

0 comments on commit 2ad805b

Please sign in to comment.