From dfabcec7cb89c128cef7c365ea353553a6a63b43 Mon Sep 17 00:00:00 2001 From: 18608119613 Date: Mon, 23 Oct 2023 09:59:29 +0800 Subject: [PATCH 1/7] three_interpolate_npu_init --- .../pytorch/npu/three_interpolate_npu.cpp | 56 +++++++++++- tests/test_ops/test_three_interpolate.py | 87 +++++++++++++++++++ 2 files changed, 139 insertions(+), 4 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp b/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp index 0f1b14e7dc..fa07d86bc9 100644 --- a/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp @@ -6,24 +6,72 @@ using namespace std; void three_interpolate_forward_npu(int b, int c, int m, int n, const Tensor points, const Tensor idx, const Tensor weight, Tensor out) { - auto point_c_trans = points.transpose(1, 2); + auto originDtype = points.scalar_type(); + at::Tensor pointsCast = points; + at::Tensor weightCast = weight; + at::Tensor outCast = out; + + TORCH_CHECK((originDtype == at::kFloat || originDtype == at::kHalf), + "three_interpolate_forward ascend only support fp32 and fp16."); + + if (originDtype == at::ScalarType::Half) { + pointsCast = points.to(at::kFloat); + weightCast = weight.to(at::kFloat); + outCast = out.to(at::kFloat); + } + + auto point_c_trans = pointsCast.transpose(1, 2); OpCommand cmd; cmd.Name("ThreeInterpolate") .Input(point_c_trans) .Input(idx) - .Input(weight) - .Output(out) + .Input(weightCast) + .Output(outCast) .Run(); - auto output = out.view({b, n, c}).transpose(1, 2); + auto output = outCast.view({b, n, c}).transpose(1, 2); auto res = NpuUtils::format_contiguous(output); out.copy_(res); } +void three_interpolate_backward_npu(int b, int c, int n, int m, + const Tensor grad_out, const Tensor idx, + const Tensor weight, Tensor grad_points) { + auto originDtype = grad_out.scalar_type(); + at::Tensor gradOutCast = grad_out; + at::Tensor weightCast = weight; + at::Tensor gradPointsCast = grad_points; + + TORCH_CHECK((originDtype == at::kFloat || originDtype == at::kHalf), + "three_interpolate_backward ascend only support fp32 and fp16."); + + if (originDtype == at::ScalarType::Half) { + gradOutCast = grad_out.to(at::kFloat); + weightCast = weight.to(at::kFloat); + gradPointsCast = grad_points.to(at::kFloat); + } + + OpCommand cmd; + cmd.Name("ThreeInterpolateBackward") + .Input(gradOutCast) + .Input(idx) + .Input(weightCast) + .Output(gradPointsCast) + .Attr("m", m) + .Run(); +} + void three_interpolate_forward_impl(int b, int c, int m, int n, const Tensor points, const Tensor idx, const Tensor weight, Tensor out); +void three_interpolate_backward_impl(int b, int c, int n, int m, + const Tensor grad_out, const Tensor idx, + const Tensor weight, Tensor grad_points); + REGISTER_NPU_IMPL(three_interpolate_forward_impl, three_interpolate_forward_npu); + +REGISTER_NPU_IMPL(three_interpolate_backward_impl, + three_interpolate_backward_npu); diff --git a/tests/test_ops/test_three_interpolate.py b/tests/test_ops/test_three_interpolate.py index d27a795ecf..253d25de4b 100644 --- a/tests/test_ops/test_three_interpolate.py +++ b/tests/test_ops/test_three_interpolate.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import pytest import torch +import numpy as np from mmcv.ops import three_interpolate from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE @@ -95,3 +96,89 @@ def test_three_interpolate(dtype, device): device=device) assert torch.allclose(output, expected_output, 1e-3, 1e-4) + + +def three_interpolate_forward_gloden(features, idx, weight): + bs, cs, ms = features.shape + ns = idx.shape[1] + + dtype = features.dtype + if dtype == np.float16: + features = features.astype(np.float32) + weight = weight.astype(np.float32) + + output = np.zeros((bs, cs, ns), dtype=np.float) + for b in range(bs): + for c in range(cs): + for n in range(ns): + output[b][c][n] = \ + features[b][c][idx[b][n][0]] * weight[b][n][0] \ + + features[b][c][idx[b][n][1]] * weight[b][n][1] \ + + features[b][c][idx[b][n][2]] * weight[b][n][2] + return output + + +def three_interpolate_backward_gloden(grad_output, idx, weight, features): + bs, cs, ns = grad_output.shape + ms = features.shape[2] + + dtype = features.dtype + if dtype == np.float16: + features = features.astype(np.float32) + weight = weight.astype(np.float32) + + grad_point = np.zeros((bs, cs, ms), dtype=np.float) + for b in range(bs): + for c in range(cs): + for n in range(ns): + grad_point[b][c][idx[b][n][0]] = \ + grad_point[b][c][idx[b][n][0]] + \ + grad_output[b][c][n] * weight[b][n][0] + grad_point[b][c][idx[b][n][1]] = \ + grad_point[b][c][idx[b][n][1]] + \ + grad_output[b][c][n] * weight[b][n][1] + grad_point[b][c][idx[b][n][2]] = \ + grad_point[b][c][idx[b][n][2]] + \ + grad_output[b][c][n] * weight[b][n][2] + return grad_point + + +def torch_type_trans(dtype): + if dtype == torch.half: + return np.float16 + elif dtype == torch.float: + return np.float32 + else: + return np.float64 + + +@pytest.mark.parametrize('dtype', [torch.half, torch.float]) +@pytest.mark.parametrize('device', [ + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')) +]) +@pytest.mark.parametrize('shape', [(2, 5, 6, 6), (10, 10, 10, 10), + (20, 21, 13, 4), (2, 10, 2, 18), + (10, 602, 910, 200), (600, 100, 300, 101)]) +def test_three_interpolate_npu_dynamic_shape(dtype, device, shape): + bs = shape[0] + cs = shape[1] + ms = shape[2] + ns = shape[3] + + features = np.random.uniform(-10.0, 10.0, + (bs, cs, ms)).astype(torch_type_trans(dtype)) + idx = np.random.uniform(0, ms, size=(bs, ns, 3), dtype=np.int32) + weight = np.random.uniform(-10.0, + 10.0 (bs, ns, + 3)).astype(torch_type_trans(dtype)) + + features_npu = torch.tensor(features, dtype=dtype).to(device) + idx_npu = torch.tensor(idx, dtype=torch.int32).to(device) + weight_npu = torch.tensor(weight, dtype=dtype).to(device) + + expected_output = three_interpolate_forward_gloden(features, idx, weight) + output = three_interpolate(features_npu, idx_npu, weight_npu) + assert np.allclose(output.cpu().numpy(), expected_output, 1e-3, 1e-4) From ee78e3f5cb67546da9af4a75b939fccfeeb2339a Mon Sep 17 00:00:00 2001 From: 18608119613 Date: Mon, 23 Oct 2023 19:26:29 +0800 Subject: [PATCH 2/7] linterror --- tests/test_ops/test_three_interpolate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_ops/test_three_interpolate.py b/tests/test_ops/test_three_interpolate.py index 253d25de4b..0153be0738 100644 --- a/tests/test_ops/test_three_interpolate.py +++ b/tests/test_ops/test_three_interpolate.py @@ -170,7 +170,7 @@ def test_three_interpolate_npu_dynamic_shape(dtype, device, shape): features = np.random.uniform(-10.0, 10.0, (bs, cs, ms)).astype(torch_type_trans(dtype)) - idx = np.random.uniform(0, ms, size=(bs, ns, 3), dtype=np.int32) + idx = np.random.randint(0, ms, size=(bs, ns, 3), dtype=np.int32) weight = np.random.uniform(-10.0, 10.0 (bs, ns, 3)).astype(torch_type_trans(dtype)) From 5d917594aa9648b80f41cc85967837a2e663abf1 Mon Sep 17 00:00:00 2001 From: 18608119613 Date: Mon, 23 Oct 2023 19:29:22 +0800 Subject: [PATCH 3/7] linterr fixed --- tests/test_ops/test_three_interpolate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_ops/test_three_interpolate.py b/tests/test_ops/test_three_interpolate.py index 0153be0738..0975baae32 100644 --- a/tests/test_ops/test_three_interpolate.py +++ b/tests/test_ops/test_three_interpolate.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. +import numpy as np import pytest import torch -import numpy as np from mmcv.ops import three_interpolate from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE From b8d7839ff538bcd8ddbf5ccbc47393e15babc2b7 Mon Sep 17 00:00:00 2001 From: 18608119613 Date: Mon, 23 Oct 2023 19:59:48 +0800 Subject: [PATCH 4/7] linterror --- .../pytorch/npu/three_interpolate_npu.cpp | 32 ++++--------------- tests/test_ops/test_three_interpolate.py | 12 ++----- 2 files changed, 8 insertions(+), 36 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp b/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp index fa07d86bc9..92b967a7e0 100644 --- a/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp @@ -7,30 +7,20 @@ void three_interpolate_forward_npu(int b, int c, int m, int n, const Tensor points, const Tensor idx, const Tensor weight, Tensor out) { auto originDtype = points.scalar_type(); - at::Tensor pointsCast = points; - at::Tensor weightCast = weight; - at::Tensor outCast = out; - TORCH_CHECK((originDtype == at::kFloat || originDtype == at::kHalf), "three_interpolate_forward ascend only support fp32 and fp16."); - if (originDtype == at::ScalarType::Half) { - pointsCast = points.to(at::kFloat); - weightCast = weight.to(at::kFloat); - outCast = out.to(at::kFloat); - } - auto point_c_trans = pointsCast.transpose(1, 2); OpCommand cmd; cmd.Name("ThreeInterpolate") .Input(point_c_trans) .Input(idx) - .Input(weightCast) - .Output(outCast) + .Input(weight) + .Output(out) .Run(); - auto output = outCast.view({b, n, c}).transpose(1, 2); + auto output = out.view({b, n, c}).transpose(1, 2); auto res = NpuUtils::format_contiguous(output); out.copy_(res); } @@ -39,25 +29,15 @@ void three_interpolate_backward_npu(int b, int c, int n, int m, const Tensor grad_out, const Tensor idx, const Tensor weight, Tensor grad_points) { auto originDtype = grad_out.scalar_type(); - at::Tensor gradOutCast = grad_out; - at::Tensor weightCast = weight; - at::Tensor gradPointsCast = grad_points; - TORCH_CHECK((originDtype == at::kFloat || originDtype == at::kHalf), "three_interpolate_backward ascend only support fp32 and fp16."); - if (originDtype == at::ScalarType::Half) { - gradOutCast = grad_out.to(at::kFloat); - weightCast = weight.to(at::kFloat); - gradPointsCast = grad_points.to(at::kFloat); - } - OpCommand cmd; cmd.Name("ThreeInterpolateBackward") - .Input(gradOutCast) + .Input(grad_out) .Input(idx) - .Input(weightCast) - .Output(gradPointsCast) + .Input(weight) + .Output(grad_points) .Attr("m", m) .Run(); } diff --git a/tests/test_ops/test_three_interpolate.py b/tests/test_ops/test_three_interpolate.py index 0975baae32..3de6ddd769 100644 --- a/tests/test_ops/test_three_interpolate.py +++ b/tests/test_ops/test_three_interpolate.py @@ -103,11 +103,7 @@ def three_interpolate_forward_gloden(features, idx, weight): ns = idx.shape[1] dtype = features.dtype - if dtype == np.float16: - features = features.astype(np.float32) - weight = weight.astype(np.float32) - - output = np.zeros((bs, cs, ns), dtype=np.float) + output = np.zeros((bs, cs, ns), dtype=dtype) for b in range(bs): for c in range(cs): for n in range(ns): @@ -123,11 +119,7 @@ def three_interpolate_backward_gloden(grad_output, idx, weight, features): ms = features.shape[2] dtype = features.dtype - if dtype == np.float16: - features = features.astype(np.float32) - weight = weight.astype(np.float32) - - grad_point = np.zeros((bs, cs, ms), dtype=np.float) + grad_point = np.zeros((bs, cs, ms), dtype=dtype) for b in range(bs): for c in range(cs): for n in range(ns): From 78b86c9fb907aed4bd22cb46186d930c00c7856a Mon Sep 17 00:00:00 2001 From: 18608119613 Date: Mon, 23 Oct 2023 20:37:46 +0800 Subject: [PATCH 5/7] fix --- mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp b/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp index 92b967a7e0..fc5990081b 100644 --- a/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp @@ -10,7 +10,7 @@ void three_interpolate_forward_npu(int b, int c, int m, int n, TORCH_CHECK((originDtype == at::kFloat || originDtype == at::kHalf), "three_interpolate_forward ascend only support fp32 and fp16."); - auto point_c_trans = pointsCast.transpose(1, 2); + auto point_c_trans = points.transpose(1, 2); OpCommand cmd; cmd.Name("ThreeInterpolate") From 405d1489d86997cf056c97a4d881287e68641b17 Mon Sep 17 00:00:00 2001 From: 18608119613 Date: Wed, 25 Oct 2023 10:54:25 +0800 Subject: [PATCH 6/7] fix --- mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp b/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp index fc5990081b..9eee537a4e 100644 --- a/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp @@ -32,14 +32,7 @@ void three_interpolate_backward_npu(int b, int c, int n, int m, TORCH_CHECK((originDtype == at::kFloat || originDtype == at::kHalf), "three_interpolate_backward ascend only support fp32 and fp16."); - OpCommand cmd; - cmd.Name("ThreeInterpolateBackward") - .Input(grad_out) - .Input(idx) - .Input(weight) - .Output(grad_points) - .Attr("m", m) - .Run(); + EXEC_NPU_CMD(aclnnThreeInterpolateBackward, &grad_out, &idx, &weight, &grad_points, m); } void three_interpolate_forward_impl(int b, int c, int m, int n, From 4fc8bad6aefddac13afd7017e7d2da590cdd417f Mon Sep 17 00:00:00 2001 From: 18608119613 Date: Wed, 29 Nov 2023 14:34:49 +0800 Subject: [PATCH 7/7] add backward npu adapter --- mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp b/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp index 9eee537a4e..6832dc51f6 100644 --- a/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp @@ -1,4 +1,6 @@ #include "pytorch_npu_helper.hpp" +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" using namespace NPU_NAME_SPACE; using namespace std; @@ -32,7 +34,14 @@ void three_interpolate_backward_npu(int b, int c, int n, int m, TORCH_CHECK((originDtype == at::kFloat || originDtype == at::kHalf), "three_interpolate_backward ascend only support fp32 and fp16."); - EXEC_NPU_CMD(aclnnThreeInterpolateBackward, &grad_out, &idx, &weight, &grad_points, m); + auto grad_x = at::unsqueeze(grad_out, 3); + auto grad_y = at::unsqueeze(grad_points, 3); + + EXEC_NPU_CMD(aclnnThreeInterpolateBackward, grad_x, idx, weight, m, grad_y); + + auto output = at::squeeze(grad_y, 3); + auto res = NpuUtils::format_contiguous(output); + grad_points.copy_(res); } void three_interpolate_forward_impl(int b, int c, int m, int n,