three_interpolate_npu_init
lihao7212148 committed Oct 23, 2023
1 parent df2dadb commit dfabcec
Showing 2 changed files with 139 additions and 4 deletions.
56 changes: 52 additions & 4 deletions mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp
@@ -6,24 +6,72 @@ using namespace std;
void three_interpolate_forward_npu(int b, int c, int m, int n,
                                   const Tensor points, const Tensor idx,
                                   const Tensor weight, Tensor out) {
  auto originDtype = points.scalar_type();
  at::Tensor pointsCast = points;
  at::Tensor weightCast = weight;
  at::Tensor outCast = out;

  TORCH_CHECK((originDtype == at::kFloat || originDtype == at::kHalf),
              "three_interpolate_forward ascend only supports fp32 and fp16.");

  // The NPU kernel computes in fp32, so fp16 inputs are cast up front.
  if (originDtype == at::ScalarType::Half) {
    pointsCast = points.to(at::kFloat);
    weightCast = weight.to(at::kFloat);
    outCast = out.to(at::kFloat);
  }

  auto point_c_trans = pointsCast.transpose(1, 2);

  OpCommand cmd;
  cmd.Name("ThreeInterpolate")
      .Input(point_c_trans)
      .Input(idx)
      .Input(weightCast)
      .Output(outCast)
      .Run();

  auto output = outCast.view({b, n, c}).transpose(1, 2);
  auto res = NpuUtils::format_contiguous(output);
  // copy_ casts the result back to the caller's original dtype when needed.
  out.copy_(res);
}

void three_interpolate_backward_npu(int b, int c, int n, int m,
                                    const Tensor grad_out, const Tensor idx,
                                    const Tensor weight, Tensor grad_points) {
  auto originDtype = grad_out.scalar_type();
  at::Tensor gradOutCast = grad_out;
  at::Tensor weightCast = weight;
  at::Tensor gradPointsCast = grad_points;

  TORCH_CHECK((originDtype == at::kFloat || originDtype == at::kHalf),
              "three_interpolate_backward ascend only supports fp32 and fp16.");

  // The NPU kernel computes in fp32, so fp16 inputs are cast up front.
  if (originDtype == at::ScalarType::Half) {
    gradOutCast = grad_out.to(at::kFloat);
    weightCast = weight.to(at::kFloat);
    gradPointsCast = grad_points.to(at::kFloat);
  }

  OpCommand cmd;
  cmd.Name("ThreeInterpolateBackward")
      .Input(gradOutCast)
      .Input(idx)
      .Input(weightCast)
      .Output(gradPointsCast)
      .Attr("m", m)
      .Run();

  // Write the fp32 result back into the caller's (possibly fp16) buffer.
  if (originDtype == at::ScalarType::Half) {
    grad_points.copy_(gradPointsCast);
  }
}

void three_interpolate_forward_impl(int b, int c, int m, int n,
                                    const Tensor points, const Tensor idx,
                                    const Tensor weight, Tensor out);

void three_interpolate_backward_impl(int b, int c, int n, int m,
                                     const Tensor grad_out, const Tensor idx,
                                     const Tensor weight, Tensor grad_points);

REGISTER_NPU_IMPL(three_interpolate_forward_impl,
                  three_interpolate_forward_npu);

REGISTER_NPU_IMPL(three_interpolate_backward_impl,
                  three_interpolate_backward_npu);
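
For reference, the op registered above is reached from Python through mmcv.ops.three_interpolate. The following is a minimal, illustrative usage sketch only; the shapes follow the convention exercised by the tests below, the random values are placeholders, and it assumes torch_npu is installed and initialized.

import torch
from mmcv.ops import three_interpolate

# features: (B, C, M), idx: (B, N, 3) int32 neighbour indices, weight: (B, N, 3)
features = torch.rand(2, 5, 6, dtype=torch.float).to('npu')
idx = torch.randint(0, 6, (2, 12, 3), dtype=torch.int32).to('npu')
weight = torch.rand(2, 12, 3, dtype=torch.float).to('npu')

out = three_interpolate(features, idx, weight)  # (B, C, N) = (2, 5, 12)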
87 changes: 87 additions & 0 deletions tests/test_ops/test_three_interpolate.py
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
import numpy as np

from mmcv.ops import three_interpolate
from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE
@@ -95,3 +96,89 @@ def test_three_interpolate(dtype, device):
        device=device)

    assert torch.allclose(output, expected_output, 1e-3, 1e-4)


def three_interpolate_forward_golden(features, idx, weight):
    bs, cs, ms = features.shape
    ns = idx.shape[1]

    dtype = features.dtype
    if dtype == np.float16:
        features = features.astype(np.float32)
        weight = weight.astype(np.float32)

    output = np.zeros((bs, cs, ns), dtype=np.float32)
    for b in range(bs):
        for c in range(cs):
            for n in range(ns):
                output[b][c][n] = \
                    features[b][c][idx[b][n][0]] * weight[b][n][0] \
                    + features[b][c][idx[b][n][1]] * weight[b][n][1] \
                    + features[b][c][idx[b][n][2]] * weight[b][n][2]
    return output
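

The same reference result can be computed without Python loops, which helps for the larger parametrized shapes. A vectorized NumPy sketch (hypothetical helper, not part of this commit; fp16-to-fp32 handling omitted for brevity):

def three_interpolate_forward_golden_vec(features, idx, weight):
    # features: (B, C, M), idx: (B, N, 3), weight: (B, N, 3) -> (B, C, N)
    bs, cs, ms = features.shape
    ns = idx.shape[1]
    # gather features at the three neighbour indices: result (B, C, N, 3)
    idx_exp = np.broadcast_to(idx.reshape(bs, 1, ns * 3), (bs, cs, ns * 3))
    gathered = np.take_along_axis(features, idx_exp, axis=2)
    gathered = gathered.reshape(bs, cs, ns, 3)
    # weighted sum over the three neighbours
    return (gathered * weight[:, None, :, :]).sum(-1)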


def three_interpolate_backward_golden(grad_output, idx, weight, features):
    bs, cs, ns = grad_output.shape
    ms = features.shape[2]

    dtype = features.dtype
    if dtype == np.float16:
        features = features.astype(np.float32)
        weight = weight.astype(np.float32)

    grad_point = np.zeros((bs, cs, ms), dtype=np.float32)
    for b in range(bs):
        for c in range(cs):
            for n in range(ns):
                grad_point[b][c][idx[b][n][0]] = \
                    grad_point[b][c][idx[b][n][0]] + \
                    grad_output[b][c][n] * weight[b][n][0]
                grad_point[b][c][idx[b][n][1]] = \
                    grad_point[b][c][idx[b][n][1]] + \
                    grad_output[b][c][n] * weight[b][n][1]
                grad_point[b][c][idx[b][n][2]] = \
                    grad_point[b][c][idx[b][n][2]] + \
                    grad_output[b][c][n] * weight[b][n][2]
    return grad_point
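

The backward reference is a scatter-add over the neighbour indices, so it can also be vectorized. A sketch using np.add.at, which accumulates correctly when indices repeat (hypothetical helper, not part of this commit):

def three_interpolate_backward_golden_vec(grad_output, idx, weight, features):
    # grad_output: (B, C, N), idx/weight: (B, N, 3), features: (B, C, M)
    bs, cs, ns = grad_output.shape
    ms = features.shape[2]
    grad_point = np.zeros((bs, cs, ms), dtype=np.float32)
    # per-neighbour contributions: (B, C, N, 3)
    contrib = grad_output.astype(np.float32)[:, :, :, None] * \
        weight.astype(np.float32)[:, None, :, :]
    batch = np.arange(bs)[:, None, None, None]
    chan = np.arange(cs)[None, :, None, None]
    # unbuffered scatter-add of each neighbour's contribution
    np.add.at(grad_point, (batch, chan, idx[:, None, :, :]), contrib)
    return grad_point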


def torch_type_trans(dtype):
    if dtype == torch.half:
        return np.float16
    elif dtype == torch.float:
        return np.float32
    else:
        return np.float64


@pytest.mark.parametrize('dtype', [torch.half, torch.float])
@pytest.mark.parametrize('device', [
    pytest.param(
        'npu',
        marks=pytest.mark.skipif(
            not IS_NPU_AVAILABLE, reason='requires NPU support'))
])
@pytest.mark.parametrize('shape', [(2, 5, 6, 6), (10, 10, 10, 10),
                                   (20, 21, 13, 4), (2, 10, 2, 18),
                                   (10, 602, 910, 200), (600, 100, 300, 101)])
def test_three_interpolate_npu_dynamic_shape(dtype, device, shape):
    bs = shape[0]
    cs = shape[1]
    ms = shape[2]
    ns = shape[3]

    features = np.random.uniform(-10.0, 10.0,
                                 (bs, cs, ms)).astype(torch_type_trans(dtype))
    idx = np.random.randint(0, ms, size=(bs, ns, 3), dtype=np.int32)
    weight = np.random.uniform(-10.0, 10.0,
                               (bs, ns, 3)).astype(torch_type_trans(dtype))

    features_npu = torch.tensor(features, dtype=dtype).to(device)
    idx_npu = torch.tensor(idx, dtype=torch.int32).to(device)
    weight_npu = torch.tensor(weight, dtype=dtype).to(device)

    expected_output = three_interpolate_forward_golden(features, idx, weight)
    output = three_interpolate(features_npu, idx_npu, weight_npu)
    assert np.allclose(output.cpu().numpy(), expected_output, 1e-3, 1e-4)
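

The backward golden reference is not exercised by the test above. A hedged sketch of how it could be checked through autograd, assuming mmcv's three_interpolate Function provides gradients with respect to features on NPU (illustrative only, not part of this commit):

    features_grad = torch.tensor(features, dtype=dtype).to(device).requires_grad_()
    output = three_interpolate(features_grad, idx_npu, weight_npu)
    grad_out = np.ones((bs, cs, ns), dtype=torch_type_trans(dtype))
    output.backward(torch.tensor(grad_out, dtype=dtype).to(device))
    expected_grad = three_interpolate_backward_golden(grad_out, idx, weight, features)
    assert np.allclose(features_grad.grad.cpu().numpy(), expected_grad, 1e-3, 1e-4)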
