Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Arm backend: Add initial support for right shift #7006

Merged
merged 1 commit into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/arm/arm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def ethosu_compile_spec(
self.compiler_flags.append(extra_flags)

base_tosa_version = "TOSA-0.80.0+BI"
if "U55" in config:
if "u55" in config:
# Add the Ethos-U55 extension marker
base_tosa_version += "+u55"
self.tosa_version = TosaSpecification.create_from_string(base_tosa_version)
Expand Down
7 changes: 6 additions & 1 deletion backends/arm/operator_support/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,9 @@

# pyre-unsafe

from . import mean_dim_support, tosa_supported_operators, var_correction_support # noqa
from . import ( # noqa
mean_dim_support,
right_shift_support,
tosa_supported_operators,
var_correction_support,
)
35 changes: 35 additions & 0 deletions backends/arm/operator_support/right_shift_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import logging

import torch.fx as fx
from executorch.backends.arm.operator_support.tosa_supported_operators import (
register_tosa_support_check,
SupportedTOSAOperatorCheck,
)
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


@register_tosa_support_check
class RightShiftSupported(SupportedTOSAOperatorCheck):
targets = [exir_ops.edge.aten.__rshift__.Scalar]

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
]

def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):

# TODO MLETORCH-525 Remove warning
if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset:
logging.warning(f"{node.target} may introduce one-off errors.")
return True
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
op_reciprocal,
op_relu,
op_repeat,
op_rshift,
op_rsqrt,
op_select,
op_sigmoid,
Expand Down
99 changes: 99 additions & 0 deletions backends/arm/operators/op_rshift.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
from executorch.backends.arm.tosa_specification import Tosa_0_80
from executorch.backends.arm.tosa_utils import tosa_shape
from serializer.tosa_serializer import TosaOp


@register_node_visitor
class RshiftVisitor(NodeVisitor):
target = "aten.__rshift__.Scalar"

def define_node(
self,
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
is_quant_node: bool,
) -> None:
input_shape = inputs[0].shape
input_0_rank = len(input_shape)
shift_expanded_shape = [1] * input_0_rank
dtype = node.meta["val"].dtype
attr = ts.TosaSerializerAttribute()
cast_input = False
cast_output = False
round = False
cast_type = dtype
if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
# U55 only supports INT32 and round == True
# TODO MLETORCH-525 Emulate round == False with different decomposition
if dtype != torch.int32:
cast_input = True
cast_output = True
cast_type = torch.int32
round = True
attr.ArithmeticRightShiftAttribute(round=round)

if cast_input:
# input needs to be casted to INT32
shift_input = tosa_graph.addIntermediate(
shape=tosa_shape(input_shape, inputs[0].dim_order),
dtype=map_dtype(cast_type),
)
tosa_graph.addOperator(
TosaOp.Op().CAST,
[inputs[0].name],
[shift_input.name],
None,
)
else:
shift_input = inputs[0]
if cast_output:
# add intermediate tensor for right shift
shift = tosa_graph.addIntermediate(
shape=tosa_shape(input_shape, inputs[0].dim_order),
dtype=map_dtype(cast_type),
)
else:
shift = output
# create tensor with same rank as inputs[0]
data = torch.full(
shift_expanded_shape, fill_value=inputs[1].number, dtype=dtype
)
shift_const_name = node.name + "-shift_const"
tosa_graph.addConst(
shift_expanded_shape,
map_dtype(cast_type),
data.detach().numpy(),
shift_const_name,
)
# add right shift operator
tosa_graph.addOperator(
TosaOp.Op().ARITHMETIC_RIGHT_SHIFT,
[shift_input.name, shift_const_name],
[shift.name],
attr,
)
if cast_output:
# cast output to original output dtype
tosa_graph.addOperator(
TosaOp.Op().CAST,
[shift.name],
[output.name],
None,
)
90 changes: 90 additions & 0 deletions backends/arm/test/ops/test_rshift.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from parameterized import parameterized


class TestRshift(unittest.TestCase):
"""
Tests arithmetic right shift
"""

class Rshift(torch.nn.Module):
test_data = [
((torch.IntTensor(5, 5), 2),),
((torch.IntTensor(1, 2, 3, 4), 3),),
((torch.ShortTensor(1, 5, 3, 4), 5),),
((torch.CharTensor(10, 12, 3, 4), 1),),
]

def forward(self, x: torch.Tensor, shift: int):
return x >> shift

def _test_rshift_tosa_MI(self, test_data):
(
ArmTester(
self.Rshift(),
example_inputs=test_data,
compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
)
.export()
.to_edge_transform_and_lower()
.to_executorch()
.run_method_and_compare_outputs(inputs=test_data)
)

def _test_rshift_tosa_BI(self, test_data):
(
ArmTester(
self.Rshift(),
example_inputs=test_data,
compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
)
.quantize()
.export()
.to_edge_transform_and_lower()
.to_executorch()
# TODO MLETORCH-250 Increase flexibility of ArmTester to handle int IO
# .run_method_and_compare_outputs(inputs=test_data)
)

def _test_rshift_ethosu_BI(self, test_data, compile_spec):
return (
ArmTester(
self.Rshift(),
example_inputs=test_data,
compile_spec=compile_spec,
)
.quantize()
.export()
.to_edge_transform_and_lower()
.to_executorch()
)

@parameterized.expand(Rshift.test_data)
def test_rshift_tosa_MI(self, test_data):
self._test_rshift_tosa_MI(test_data)

@parameterized.expand(Rshift.test_data)
def test_rshift_tosa_BI(self, test_data):
self._test_rshift_tosa_BI(test_data)

# TODO Enable FVP testing
@parameterized.expand(Rshift.test_data)
def test_rshift_u55_BI(self, test_data):
compile_spec = common.get_u55_compile_spec()
self._test_rshift_ethosu_BI(test_data, compile_spec)

# TODO Enable FVP testing
@parameterized.expand(Rshift.test_data)
def test_rshift_u85_BI(self, test_data):
compile_spec = common.get_u85_compile_spec()
self._test_rshift_ethosu_BI(test_data, compile_spec)
Loading