prepare for Tensor merge
albertz committed Feb 26, 2023
1 parent 6d8d4c4 commit 94292c2
Showing 14 changed files with 16 additions and 16 deletions.
4 changes: 2 additions & 2 deletions nn/array_.py
@@ -111,9 +111,9 @@ def concat(
     if allow_broadcast:
         opts["allow_broadcast"] = True
     else:
-        dims = sources[0][0].shape - {sources[0][1]}
+        dims = sources[0][0].dims_set - {sources[0][1]}
         for src, dim in sources:
-            assert src.shape - {dim} == dims, f"concat {sources}, need allow_broadcast=True"
+            assert src.dims_set - {dim} == dims, f"concat {sources}, need allow_broadcast=True"
     out_dim = sum(d for _, d in sources)
     res = nn.make_layer(
         {"class": "concat", "from": sources, "out_dim": out_dim, **opts},
2 changes: 1 addition & 1 deletion nn/base.py
@@ -172,7 +172,7 @@ def __deepcopy__(self, memo):
         return self
 
     @property
-    def shape(self) -> Set[Dim]:
+    def dims_set(self) -> Set[Dim]:
         """
         :return: shape, as a set of dims.
           The order must not play a role
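For context: the set-valued property is only renamed here, its semantics are unchanged, so call sites simply switch from .shape to .dims_set (presumably freeing the name shape for the upcoming Tensor merge). A minimal sketch of the resulting pattern, with a hypothetical helper require_dim that is not part of returnn_common:

# Minimal sketch (require_dim is a hypothetical helper, not part of
# returnn_common); it mirrors the asserts updated throughout this commit.
from returnn_common import nn


def require_dim(source: nn.Tensor, dim: nn.Dim) -> None:
    """Raise if source does not carry the given dim."""
    if dim not in source.dims_set:  # before this commit: source.shape (a Set[Dim])
        raise ValueError(f"{source} does not have expected dim {dim}")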
2 changes: 1 addition & 1 deletion nn/conformer.py
@@ -146,7 +146,7 @@ def __init__(

     def __call__(self, source: nn.Tensor, *, in_spatial_dim: nn.Dim) -> Tuple[nn.Tensor, nn.Dim]:
         """forward"""
-        assert self.in_dim in source.shape
+        assert self.in_dim in source.dims_set
         in_spatial_dims = [in_spatial_dim, self.in_dim]
         in_dim = self._dummy_in_dim
         x = nn.expand_dim(source, dim=in_dim)
2 changes: 1 addition & 1 deletion nn/conv.py
@@ -134,7 +134,7 @@ def __call__(
         out_spatial_dims: Optional[Sequence[nn.Dim]] = None,
     ) -> Tuple[nn.Tensor, Sequence[nn.Dim]]:
         for in_spatial_dim in in_spatial_dims:
-            if in_spatial_dim not in source.shape:
+            if in_spatial_dim not in source.dims_set:
                 raise ValueError(f"{self}: source {source} does not have spatial dim {in_spatial_dim}")
         out_spatial_dims = out_spatial_dims or self.make_out_spatial_dims(in_spatial_dims)
         layer_dict = {
2 changes: 1 addition & 1 deletion nn/encoder/blstm_cnn.py
@@ -57,7 +57,7 @@ def __init__(self, in_dim: nn.Dim, dim: nn.Dim = nn.FeatureDim("feat", 32), *, f
         self.out_dim = self._final_extra_spatial_dim * dim
 
     def __call__(self, x: nn.Tensor, *, spatial_dim: nn.Dim) -> nn.Tensor:
-        assert self.in_dim in x.shape
+        assert self.in_dim in x.dims_set
         batch_dims = x.batch_dims_ordered((self.in_dim, spatial_dim))
         extra_spatial_dim = self.in_dim
         x = nn.expand_dim(x, dim=self._dummy_feat_dim)
2 changes: 1 addition & 1 deletion nn/hybrid_hmm.py
@@ -69,7 +69,7 @@ def __call__(
raise TypeError(f"unsupported encoder type {type(self.encoder)}")
out_embed = self.out_projection(encoder_output)
if train:
assert out_spatial_dim in targets.shape
assert out_spatial_dim in targets.dims_set
ce_loss = nn.sparse_softmax_cross_entropy_with_logits(logits=out_embed, targets=targets, axis=self.out_dim)
ce_loss.mark_as_loss("ce")
return nn.log_softmax(out_embed, axis=self.out_dim), None
2 changes: 1 addition & 1 deletion nn/linear.py
@@ -26,7 +26,7 @@ def __init__(self, in_dim: nn.Dim, out_dim: nn.Dim, *, with_bias=True):
     def __call__(self, source: nn.Tensor) -> nn.Tensor:
         if not isinstance(source, nn.Tensor):
             raise TypeError(f"{self}: source must be a Tensor but got {type(source)}")
-        if self.in_dim not in source.shape and self.in_dim != source.data.sparse_dim:
+        if self.in_dim not in source.dims_set and self.in_dim != source.data.sparse_dim:
             raise ValueError(f"{self}: input {source} does not have in_dim {self.in_dim}")
         out = nn.dot(source, self.weight, reduce=self.in_dim)
         if self.with_bias:
2 changes: 1 addition & 1 deletion nn/loop.py
@@ -162,7 +162,7 @@ def unstack(self, source: nn.Tensor, *, name: Optional[str] = None) -> nn.Tensor
         from . import rec_unstack
 
         assert self._has_given_axis, "%s: unstack() requires a given axis" % self
-        assert self.axis in source.shape
+        assert self.axis in source.dims_set
         res = rec_unstack(source, axis=self.axis, name=name)
         self.unstacked_refs.append(res)
         return res
4 changes: 2 additions & 2 deletions nn/normalization.py
@@ -127,8 +127,8 @@ def __call__(self, source: nn.Tensor) -> nn.Tensor:
         # which is potentially the use of a fused op,
         # and maybe reordering of dims.
         # https://github.com/rwth-i6/returnn_common/issues/89
-        spatial_dims = source.shape - {nn.batch_dim, self.in_dim}
-        assert len(spatial_dims) == len(source.shape) - 2
+        spatial_dims = source.dims_set - {nn.batch_dim, self.in_dim}
+        assert len(spatial_dims) == len(source.dims_set) - 2
         if any(d.dimension is None for d in spatial_dims):  # any dynamic spatial dim
             if self.use_mask is None:
                 raise ValueError(
2 changes: 1 addition & 1 deletion nn/rec.py
@@ -49,7 +49,7 @@ def __call__(
         :param direction: 1 for forward direction, -1 for backward direction
         :return: out, out_state. out_state is the new or last state.
         """
-        assert self.in_dim in source.shape
+        assert self.in_dim in source.dims_set
         rec_layer_dict = {
             "class": "rec",
             "from": source,
2 changes: 1 addition & 1 deletion nn/transformer.py
@@ -418,7 +418,7 @@ def __call__(
         :return: memory (encoder output), out logits, out labels (only when doing search), final state
         """
         assert (
-            self.model_dim in source.shape
+            self.model_dim in source.dims_set
         ), f"{self}: Input {source} feature dimension is not matching Transformer model dimension {self.model_dim}"
         memory = self.encoder(source, axis=source_spatial_axis)
         search = None
2 changes: 1 addition & 1 deletion nn/utils/dims.py
@@ -29,7 +29,7 @@ def dim_value(dim: nn.Dim) -> Union[nn.Tensor, int]:
     if dim.dimension is not None:
         return dim.dimension
     length_ = nn.length(dim)
-    if not length_.shape:
+    if not length_.dims_set:
         return length_
     return nn.reduce(length_, mode="max", axis=length_.shape_ordered)

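For context, a hedged usage sketch of dim_value as shown above, assuming it is exported as nn.dim_value; the dim is made up for illustration:

from returnn_common import nn

feat_dim = nn.FeatureDim("feat", 32)
assert nn.dim_value(feat_dim) == 32  # static dim -> plain int
# For a dynamic dim (e.g. a time dim with per-sequence lengths), dim_value
# returns an nn.Tensor holding the maximum length instead.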
2 changes: 1 addition & 1 deletion nn/utils/image.py
@@ -61,7 +61,7 @@ def interpolate_bilinear(
     img_dtype = image.dtype
 
     assert all([q.dtype == query_dtype for q in query_points])
-    assert all([q.shape == query_points[0].shape for q in query_points])  # not really necessary but reasonable
+    assert all([q.dims_set == query_points[0].dims_set for q in query_points])  # not really necessary but reasonable
 
     alphas = []
     floors = []
2 changes: 1 addition & 1 deletion nn/utils/label_smoothing.py
@@ -21,7 +21,7 @@ def label_smoothing(prob: nn.Tensor, smoothing: Union[nn.Tensor, float], *, axis
         assert prob.data.sparse_dim == axis
         return nn.smooth_one_hot(prob, label_prob=1.0 - smoothing)
     else:
-        assert axis in prob.shape
+        assert axis in prob.dims_set
         # Make it consistent to the sparse case.
         # Value of 1.0 should result in (1 - smoothing).
         # Value of 0.0 should result in smoothing / (dim - 1).
