feat: Add ts converter support for aten::all.dim (#1840)
mfeliz-cruise authored Apr 19, 2023
1 parent 6f7627f commit 1d78f43
Showing 2 changed files with 105 additions and 24 deletions.
76 changes: 54 additions & 22 deletions core/conversion/converters/impl/reduce.cpp
@@ -9,6 +9,36 @@ namespace converters {
namespace impl {
namespace {

nvinfer1::ITensor* anyDimImplementation(
ConversionCtx* ctx,
const torch::jit::Node* n,
nvinfer1::ITensor* in_tensor,
int dim,
bool keepdim) {
auto in_dims = in_tensor->getDimensions();
LOG_DEBUG("Dim to reduce (original): " << dim);
dim = dim < 0 ? (in_dims.nbDims + dim) : dim;
LOG_DEBUG("Dim to reduce (converted): " << dim);

uint32_t axis_mask = 1 << dim;
LOG_DEBUG("Axis Mask: " << std::bitset<32>(axis_mask));
LOG_DEBUG("Keep dims: " << keepdim);

// Reduce does not work on bool inputs
if (in_tensor->getType() == nvinfer1::DataType::kBOOL) {
in_tensor = castITensor(ctx, in_tensor, nvinfer1::DataType::kINT32, (util::node_info(n) + "_in").c_str());
}
auto sum_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kSUM, axis_mask, keepdim);

TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n);

sum_layer->setName(util::node_info(n).c_str());
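// A nonzero sum along the reduced axis means at least one element was true,
// so casting the sum back to bool below yields the "any" result for that axis.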
auto out_tensor =
castITensor(ctx, sum_layer->getOutput(0), nvinfer1::DataType::kBOOL, (util::node_info(n) + "_out").c_str());
out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
return out_tensor;
}

auto reduce_registrations TORCHTRT_UNUSED =
RegisterNodeConversionPatterns()
.pattern(
@@ -224,33 +254,35 @@ auto reduce_registrations TORCHTRT_UNUSED =
{"aten::any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in_tensor = args[0].ITensorOrFreeze(ctx);
auto in_dims = in_tensor->getDimensions();
auto dim = args[1].unwrapToInt();
LOG_DEBUG("Dim to reduce (original): " << dim);
dim = dim < 0 ? (in_dims.nbDims + dim) : dim;
LOG_DEBUG("Dim to reduce (converted): " << dim);

uint32_t axis_mask = 1 << dim;
LOG_DEBUG("Axis Mask: " << std::bitset<32>(axis_mask));

auto keepdim = args[2].unwrapToBool();
LOG_DEBUG("Keep dims: " << keepdim);

// Reduce does not work on bool inputs
if (in_tensor->getType() == nvinfer1::DataType::kBOOL) {
in_tensor =
castITensor(ctx, in_tensor, nvinfer1::DataType::kINT32, (util::node_info(n) + "_in").c_str());
}
auto sum_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kSUM, axis_mask, keepdim);

TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n);

sum_layer->setName(util::node_info(n).c_str());
auto out_tensor = castITensor(
ctx, sum_layer->getOutput(0), nvinfer1::DataType::kBOOL, (util::node_info(n) + "_out").c_str());
auto out_tensor = anyDimImplementation(ctx, n, in_tensor, dim, keepdim);
out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
return true;
}})
.pattern(
{"aten::all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
// use Not(Any(Not(input))) to calculate all without a direct all reduction
auto in_tensor = args[0].ITensorOrFreeze(ctx);
auto dim = args[1].unwrapToInt();
auto keepdim = args[2].unwrapToBool();
if (in_tensor->getType() != nvinfer1::DataType::kBOOL) {
// unary not layer only supports bool inputs
in_tensor = castITensor(
ctx, in_tensor, nvinfer1::DataType::kBOOL, (util::node_info(n) + "_in_to_bool").c_str());
}
auto not_input_layer = ctx->net->addUnary(*in_tensor, nvinfer1::UnaryOperation::kNOT);
TORCHTRT_CHECK(not_input_layer, "Unable to create logical_not layer from node: " << *n);
not_input_layer->setName((util::node_info(n) + "_not_in").c_str());
auto not_in = not_input_layer->getOutput(0);
auto any_out = anyDimImplementation(ctx, n, not_in, dim, keepdim);
auto not_output_layer = ctx->net->addUnary(*any_out, nvinfer1::UnaryOperation::kNOT);
TORCHTRT_CHECK(not_output_layer, "Unable to create logical_not layer from node: " << *n);
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], not_output_layer->getOutput(0));
LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
return true;
}});
} // namespace
} // namespace impl
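The new aten::all.dim converter relies on De Morgan's law rather than a dedicated ALL reduction: all(x, dim) is computed as Not(Any(Not(x), dim)). A minimal LibTorch sketch of that identity (a standalone check, not part of this commit; torch::all, torch::any, torch::logical_not, and torch::equal are the stock ATen ops):

#include <torch/torch.h>
#include <iostream>

int main() {
  // Random boolean input, reduced over the last dimension as in the new tests.
  auto x = torch::randint(0, 2, {64, 2}).to(torch::kBool);
  auto direct = torch::all(x, /*dim=*/-1, /*keepdim=*/false);
  auto rewritten = torch::logical_not(torch::any(torch::logical_not(x), /*dim=*/-1, /*keepdim=*/false));
  // Prints "true": the rewrite used by the converter matches the direct reduction.
  std::cout << std::boolalpha << torch::equal(direct, rewritten) << std::endl;
  return 0;
}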
53 changes: 51 additions & 2 deletions tests/core/conversion/converters/test_reduce.cpp
@@ -62,7 +62,7 @@ std::string gen_keepdim_graph(const std::string& op) {
return (%5))IR";
}

void test_body(const std::string& graph, at::Tensor& in) {
void test_body(const std::string& graph, at::Tensor& in, bool dynamic = false) {
auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

@@ -71,7 +71,12 @@ void test_body(const std::string& graph, at::Tensor& in) {

in = at::clone(in);
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
std::vector<at::Tensor> trt_results;
if (dynamic) {
trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in});
} else {
trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
}
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}
} // namespace
@@ -344,6 +349,50 @@ TEST(Converters, ATenAnyDimNegIndexConvertsCorrectly) {
test_body(graph, in);
}

TEST(Converters, ATenAllDimConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor):
%1 : int = prim::Constant[value=-1]()
%3 : bool = prim::Constant[value=0]()
%5 : Tensor = aten::all(%0, %1, %3)
return (%5))IR";
auto in = at::randint(0, 2, {64, 2}, at::kCUDA);
test_body(graph, in);
}

TEST(Converters, ATenAllDimKeepDimConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor):
%1 : int = prim::Constant[value=0]()
%3 : bool = prim::Constant[value=1]()
%5 : Tensor = aten::all(%0, %1, %3)
return (%5))IR";
auto in = at::randint(-2, 2, {2, 32}, at::kCUDA).to(torch::kBool);
test_body(graph, in);
}

TEST(Converters, ATenAllDimAllTrueConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor):
%1 : int = prim::Constant[value=1]()
%3 : bool = prim::Constant[value=0]()
%5 : Tensor = aten::all(%0, %1, %3)
return (%5))IR";
auto in = at::ones({2, 32}, at::kCUDA);
test_body(graph, in);
}

TEST(Converters, ATenAllDimDynamicConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor):
%1 : int = prim::Constant[value=-1]()
%3 : bool = prim::Constant[value=0]()
%5 : Tensor = aten::all(%0, %1, %3)
return (%5))IR";
auto in = at::randint(0, 2, {64, 2}, at::kCUDA).to(torch::kHalf);
test_body(graph, in, true);
}

TEST(Converters, UnpackVarLowersCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
