-
Notifications
You must be signed in to change notification settings - Fork 301
/
TensorView.h
110 lines (93 loc) · 3.2 KB
/
TensorView.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
//-----------------------------------------------------------------------------
//
// Copyright (c) Microsoft Corporation. All rights reserved.
//
//-----------------------------------------------------------------------------
#pragma once
// Provides a non-owning typed view of an N-dimensional strided tensor over a linear array of elements.
template <typename T, size_t N>
class TensorViewND
{
public:
using Extents = TensorExtents<N>;
TensorViewND() = default;
TensorViewND(dml::Span<T> data, Extents sizes, Extents strides)
: m_data(data)
, m_sizes(sizes)
, m_strides(strides)
{
#if _DEBUG
// Ensure the buffer is large enough
uint32_t offsetOfLastElement = 0;
for (size_t i = 0; i < N; ++i)
{
assert(m_sizes[i] > 0); // Zero size is invalid
offsetOfLastElement += (m_sizes[i] - 1) * m_strides[i];
}
assert(static_cast<uint32_t>(data.size()) > offsetOfLastElement);
#endif
}
TensorViewND(dml::Span<T> data, Extents sizes)
: TensorViewND(data, sizes, TensorUtil::GetPackedStrides(sizes))
{
}
const Extents& Sizes() const
{
return m_sizes;
}
const Extents& Strides() const
{
return m_strides;
}
uint32_t ElementCount() const
{
return TensorUtil::GetElementCount(m_sizes);
}
// Access an element by e.g. NCHW coordinate using an Extents of the appropriate dimension. Example:
// float x = view(NchwExtents(1, 2, 3, 4));
T& operator()(Extents indices) const
{
return At(indices);
}
// Access an element by e.g. NCHW coordinate. The meaning of the indices depends on how many are provided. Examples:
// float a = view(1, 2); // HW: H=1, W=2
// float b = view(1, 2, 3); // CHW: C=1, H=2, W=3
// float c = view(1, 2, 3, 4); // NCHW: N=1, C=2, H=3, W=4
// float d = view(1, 2, 3, 4, 5); // NCDHW: N=1, C=2, D=3, H=4, W=5
//
// If you provide fewer indices than the dimensionality of the view, the missing indices are treated as 0.
template <typename... Ts>
T& operator()(Ts&&... indices) const
{
static_assert(sizeof...(indices) >= 2, "Too few indices provided.");
static_assert(sizeof...(indices) <= N, "Too many indices provided.");
return At(Extents(std::forward<Ts>(indices)...));
}
// Access an element by linear index.
T& operator[](uint32_t elementIndex) const
{
Extents indices = TensorUtil::GetElementIndices(elementIndex, m_sizes);
return At(indices);
}
protected:
T& At(Extents indices) const
{
// Bounds check
for (size_t i = 0; i < N; ++i)
{
assert(indices[i] < m_sizes[i]);
}
uint32_t elementOffset = TensorUtil::GetElementOffset(indices, m_strides);
assert(elementOffset < static_cast<uint32_t>(m_data.size()));
return m_data[elementOffset];
}
private:
dml::Span<T> m_data = {};
Extents m_sizes = {};
Extents m_strides = {};
};
// 4D tensor is the default
template <typename T>
using TensorView = TensorViewND<T, 4>;
template <typename T>
using TensorView5D = TensorViewND<T, 5>;