Skip to content

Commit

Permalink
SETUP: tell ruff to check long lines too.
Browse files Browse the repository at this point in the history
  • Loading branch information
mhvk committed Nov 18, 2024
1 parent 62e6559 commit d3eb4c6
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 17 deletions.
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,7 @@ extend-select = [
"ARG", # flake8-unused-arguments
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"EM", # flake8-errmsg
"EXE", # flake8-executable
"E", # errors
"G", # flake8-logging-format
"I", # isort
"ICN", # flake8-import-conventions
Expand Down
44 changes: 37 additions & 7 deletions quantity/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,11 @@ def __pow__(self, exp, mod=None):
if exp.imag == 0:
exp = exp.real

return self.__class__(self.value ** exp, self.unit ** exp)
return self.__class__(self.value**exp, self.unit**exp)

def _operate(self, other, op, units_helper):
if not isinstance(other, QUANTITY_OR_NUMBER) and (
not has_array_namespace(other)
not has_array_namespace(other)
):
# HACK: unit should take care of this!
if not isinstance(other, u.UnitBase):
Expand Down Expand Up @@ -130,7 +130,9 @@ def __array_ufunc__(self, function, method, *inputs, **kwargs):
if out is not None:
# If pre-allocated output is used, check it is suitable.
# This also returns array view, to ensure we don't loop back.
out_values = self._check_output(out, unit if len(out) > 1 else [unit], function=function)
out_values = self._check_output(
out, unit if len(out) > 1 else [unit], function=function
)
# Ensure output argument remains a tuple.
kwargs["out"] = out_values
if len(out) == 1:
Expand All @@ -145,7 +147,10 @@ def __array_ufunc__(self, function, method, *inputs, **kwargs):
kwargs["initial"] = self._to_own_unit(kwargs["initial"], unit=unit)

input_values = [get_value_and_unit(in_)[0] for in_ in inputs]
input_values = [v if conv is None else conv(v) for (v, conv) in zip(input_values, converters, strict=True)]
input_values = [
v if conv is None else conv(v)
for (v, conv) in zip(input_values, converters, strict=True)
]
try:
xp = self.value.__array_namespace__()
except AttributeError:
Expand Down Expand Up @@ -191,7 +196,10 @@ def _check_output(self, output, unit, function=None):
if unit == self._dimensionless_unit:
return output

msg = "Cannot store output with unit '{}'{} in {} instance. Use {} instance instead."
msg = (
"Cannot store output with unit '{}'{} in {} instance. "
"Use {} instance instead."
)
raise u.UnitTypeError(
msg.format(
unit,
Expand Down Expand Up @@ -239,17 +247,36 @@ def _result_as_quantity(self, result, unit, out):

DEFERRED = "dtype", "device", "ndim", "shape", "size"
SAME_UNIT = "mT", "T", "__abs__", "__neg__", "__pos__", "__getitem__"
OPERATORS = "__add__", "__floordiv__", "__matmul__", "__mod__", "__mul__", "__pow__", "__sub__", "__truediv__"
OPERATORS = (
"__add__",
"__floordiv__",
"__matmul__",
"__mod__",
"__mul__",
"__pow__",
"__sub__",
"__truediv__",
)
COMPARISONS = "__eq__", "__ge__", "__gt__", "__le__", "__lt__", "__ne__"
DEFER_DIMENSIONLESS = "__complex__", "__float__", "__int__"
NOT_IMPLEMENTED = "__and__", "__bool__", "__index__", "__invert__", "__lshift__", "__or__", "__rshift__", "__xor__"
NOT_IMPLEMENTED = (
"__and__",
"__bool__",
"__index__",
"__invert__",
"__lshift__",
"__or__",
"__rshift__",
"__xor__",
)
TODO = "__dlpack__", "__dlpack_device__", "to_device"


def make_deferred(attr):
# Use array_api_compat getter if available (size, device), since
# some array formats provide inconsistent implementations.
getter = getattr(array_api_compat, attr, operator.attrgetter(attr))

def deferred(self):
return getter(self.value)

Expand Down Expand Up @@ -279,11 +306,13 @@ def same_unit(self):
def make__op__(op, helper):
if op.startswith("__r"):
op_func = getattr(operator, op.replace("__r", "__"))

def wrapped_helper(u1, u2):
return helper(op_func, u2, u1)

else:
op_func = getattr(operator, op)

def wrapped_helper(u1, u2):
return helper(op_func, u1, u2)

Expand Down Expand Up @@ -341,5 +370,6 @@ def defer_dimensionless(self):

return defer_dimensionless


for attr in DEFER_DIMENSIONLESS:
setattr(Quantity, attr, make_defer_dimensionless(attr))
2 changes: 1 addition & 1 deletion quantity/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst
from __future__ import annotations

import array_api_compat
from functools import cached_property

import array_api_compat
import astropy.units as u
import numpy as np
from astropy.utils.decorators import classproperty
Expand Down
15 changes: 8 additions & 7 deletions quantity/tests/test_quantity.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""Test the Quantity class and related."""

from __future__ import annotations

import array_api_compat
Expand All @@ -14,7 +15,6 @@


class QuantityCreationTests:

@classmethod
def setup_class(cls):
cls.value = cls.xp.asarray(np.arange(10.0))
Expand Down Expand Up @@ -65,7 +65,6 @@ def test_getitem(self):


class QuantityOperationTests:

xp = np

@classmethod
Expand Down Expand Up @@ -164,14 +163,16 @@ def test_power(self):
q = self.q1**3
self.assert_equal(q, 1489.355, "m^3", decimal=3)

@pytest.mark.parametrize("exponent", [2, 2.0, np.uint64(2), np.int32(2), np.float32(2)])
@pytest.mark.parametrize(
"exponent", [2, 2.0, np.uint64(2), np.int32(2), np.float32(2)]
)
def test_non_standard_power(self, exponent):
q = self.q1 ** exponent
q = self.q1**exponent
self.assert_equal(q, 130.4164, "m^2", decimal=5)

def test_quantity_exponent(self):
exponent = Quantity(self.xp.asarray(2.0), u.one)
q = self.q1 ** exponent
q = self.q1**exponent
self.assert_equal(q, 130.4164, "m^2", decimal=5)

def test_matrix_multiplication(self):
Expand Down Expand Up @@ -215,7 +216,7 @@ def test_abs(self):
self.assert_equal(abs_q, self.xp.abs(a), "m / s")

def test_incompatible_units(self):
"""When trying to add or subtract units that aren't compatible, throw an error"""
"""Raise when trying to add or subtract incompatible units."""
q2 = Quantity(self.q2.value, unit=u.second)
with pytest.raises(u.UnitsError):
self.q1 + q2
Expand Down Expand Up @@ -298,7 +299,7 @@ def test_numeric_converters(self):
assert int(q1) == 1

q2 = Quantity(self.xp.asarray(1.25), u.km / u.m)
assert float(q2) == 1250.
assert float(q2) == 1250.0
assert int(q2) == 1250

def test_numeric_converters_fail_on_array(self):
Expand Down

0 comments on commit d3eb4c6

Please sign in to comment.