diff --git a/backends/arm/arm_backend.py b/backends/arm/arm_backend.py index 06207611e0..b96d5b2ea4 100644 --- a/backends/arm/arm_backend.py +++ b/backends/arm/arm_backend.py @@ -89,7 +89,7 @@ def ethosu_compile_spec( self.compiler_flags.append(extra_flags) base_tosa_version = "TOSA-0.80.0+BI" - if "U55" in config: + if "u55" in config: # Add the Ethos-U55 extension marker base_tosa_version += "+u55" self.tosa_version = TosaSpecification.create_from_string(base_tosa_version) diff --git a/backends/arm/operator_support/__init__.py b/backends/arm/operator_support/__init__.py index 0a88bc45aa..c133ce8003 100644 --- a/backends/arm/operator_support/__init__.py +++ b/backends/arm/operator_support/__init__.py @@ -5,4 +5,9 @@ # pyre-unsafe -from . import mean_dim_support, tosa_supported_operators, var_correction_support # noqa +from . import ( # noqa + mean_dim_support, + right_shift_support, + tosa_supported_operators, + var_correction_support, +) diff --git a/backends/arm/operator_support/right_shift_support.py b/backends/arm/operator_support/right_shift_support.py new file mode 100644 index 0000000000..ee8d5965a1 --- /dev/null +++ b/backends/arm/operator_support/right_shift_support.py @@ -0,0 +1,35 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import logging + +import torch.fx as fx +from executorch.backends.arm.operator_support.tosa_supported_operators import ( + register_tosa_support_check, + SupportedTOSAOperatorCheck, +) +from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification +from executorch.exir.dialects._ops import ops as exir_ops + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + + +@register_tosa_support_check +class RightShiftSupported(SupportedTOSAOperatorCheck): + targets = [exir_ops.edge.aten.__rshift__.Scalar] + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80.0+BI"), + TosaSpecification.create_from_string("TOSA-0.80.0+MI"), + ] + + def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification): + + # TODO MLETORCH-525 Remove warning + if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset: + logging.warning(f"{node.target} may introduce one-off errors.") + return True diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index 988765990d..a5c2dd8dc5 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -27,6 +27,7 @@ op_reciprocal, op_relu, op_repeat, + op_rshift, op_rsqrt, op_select, op_sigmoid, diff --git a/backends/arm/operators/op_rshift.py b/backends/arm/operators/op_rshift.py new file mode 100644 index 0000000000..94b3f8b86d --- /dev/null +++ b/backends/arm/operators/op_rshift.py @@ -0,0 +1,99 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List + +import serializer.tosa_serializer as ts +import torch +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg +from executorch.backends.arm.tosa_specification import Tosa_0_80 +from executorch.backends.arm.tosa_utils import tosa_shape +from serializer.tosa_serializer import TosaOp + + +@register_node_visitor +class RshiftVisitor(NodeVisitor): + target = "aten.__rshift__.Scalar" + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + input_shape = inputs[0].shape + input_0_rank = len(input_shape) + shift_expanded_shape = [1] * input_0_rank + dtype = node.meta["val"].dtype + attr = ts.TosaSerializerAttribute() + cast_input = False + cast_output = False + round = False + cast_type = dtype + if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset: + # U55 only supports INT32 and round == True + # TODO MLETORCH-525 Emulate round == False with different decomposition + if dtype != torch.int32: + cast_input = True + cast_output = True + cast_type = torch.int32 + round = True + attr.ArithmeticRightShiftAttribute(round=round) + + if cast_input: + # input needs to be casted to INT32 + shift_input = tosa_graph.addIntermediate( + shape=tosa_shape(input_shape, inputs[0].dim_order), + dtype=map_dtype(cast_type), + ) + tosa_graph.addOperator( + TosaOp.Op().CAST, + [inputs[0].name], + [shift_input.name], + None, + ) + else: + shift_input = inputs[0] + if cast_output: + # add intermediate tensor for right shift + shift = tosa_graph.addIntermediate( + shape=tosa_shape(input_shape, inputs[0].dim_order), + dtype=map_dtype(cast_type), + ) + else: + shift = output + # create tensor with same rank as inputs[0] + data = torch.full( + shift_expanded_shape, fill_value=inputs[1].number, dtype=dtype + ) + shift_const_name = node.name + "-shift_const" + tosa_graph.addConst( + shift_expanded_shape, + map_dtype(cast_type), + data.detach().numpy(), + shift_const_name, + ) + # add right shift operator + tosa_graph.addOperator( + TosaOp.Op().ARITHMETIC_RIGHT_SHIFT, + [shift_input.name, shift_const_name], + [shift.name], + attr, + ) + if cast_output: + # cast output to original output dtype + tosa_graph.addOperator( + TosaOp.Op().CAST, + [shift.name], + [output.name], + None, + ) diff --git a/backends/arm/test/ops/test_rshift.py b/backends/arm/test/ops/test_rshift.py new file mode 100644 index 0000000000..dfbd0fdb3e --- /dev/null +++ b/backends/arm/test/ops/test_rshift.py @@ -0,0 +1,90 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester +from parameterized import parameterized + + +class TestRshift(unittest.TestCase): + """ + Tests arithmetic right shift + """ + + class Rshift(torch.nn.Module): + test_data = [ + ((torch.IntTensor(5, 5), 2),), + ((torch.IntTensor(1, 2, 3, 4), 3),), + ((torch.ShortTensor(1, 5, 3, 4), 5),), + ((torch.CharTensor(10, 12, 3, 4), 1),), + ] + + def forward(self, x: torch.Tensor, shift: int): + return x >> shift + + def _test_rshift_tosa_MI(self, test_data): + ( + ArmTester( + self.Rshift(), + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + ) + .export() + .to_edge_transform_and_lower() + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data) + ) + + def _test_rshift_tosa_BI(self, test_data): + ( + ArmTester( + self.Rshift(), + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + ) + .quantize() + .export() + .to_edge_transform_and_lower() + .to_executorch() + # TODO MLETORCH-250 Increase flexibility of ArmTester to handle int IO + # .run_method_and_compare_outputs(inputs=test_data) + ) + + def _test_rshift_ethosu_BI(self, test_data, compile_spec): + return ( + ArmTester( + self.Rshift(), + example_inputs=test_data, + compile_spec=compile_spec, + ) + .quantize() + .export() + .to_edge_transform_and_lower() + .to_executorch() + ) + + @parameterized.expand(Rshift.test_data) + def test_rshift_tosa_MI(self, test_data): + self._test_rshift_tosa_MI(test_data) + + @parameterized.expand(Rshift.test_data) + def test_rshift_tosa_BI(self, test_data): + self._test_rshift_tosa_BI(test_data) + + # TODO Enable FVP testing + @parameterized.expand(Rshift.test_data) + def test_rshift_u55_BI(self, test_data): + compile_spec = common.get_u55_compile_spec() + self._test_rshift_ethosu_BI(test_data, compile_spec) + + # TODO Enable FVP testing + @parameterized.expand(Rshift.test_data) + def test_rshift_u85_BI(self, test_data): + compile_spec = common.get_u85_compile_spec() + self._test_rshift_ethosu_BI(test_data, compile_spec)