Commit

Support ragged aux_labels in k2.compose (#686)
* Support ragged aux_labels in k2.compose

* Wrap ArgMaxPerSublist.

* Output new2old map for k2.ragged.unique_sequences

* Throw an exception instead of calling `abort()` to print Python
stacktrace.

* Fixes after review.

* More fixes after review.
csukuangfj authored Mar 18, 2021
1 parent 171ddc1 commit b33eab3
Showing 13 changed files with 399 additions and 67 deletions.
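
Before the per-file diffs, a hedged usage sketch of the user-facing change (not part of the commit; the FSAs and labels are made up for illustration): k2.compose() now accepts an `a_fsa` whose `aux_labels` is a k2.RaggedInt, as long as a_fsa.requires_grad is False.

import k2

# a_fsa emits a *sequence* of output labels on its first arc, hence the
# ragged aux_labels; the final arc must carry -1.
a_fsa = k2.Fsa.from_str('''0 1 1 0.0
1 2 -1 0.0
2''')
a_fsa.aux_labels = k2.RaggedInt('[ [3 4] [-1] ]')

# b_fsa consumes the label sequence 3 4 -1; it is arc-sorted because the
# epsilon-aware (and CPU) code path asserts ARC_SORTED on b_fsa.
b_fsa = k2.arc_sort(k2.Fsa.from_str('''0 1 3 0.0
1 2 4 0.0
2 3 -1 0.0
3'''))

c_fsa = k2.compose(a_fsa, b_fsa)  # previously required torch.Tensor aux_labels on a_fsa
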
9 changes: 7 additions & 2 deletions k2/csrc/log.h
@@ -25,6 +25,7 @@
#include <cstdlib>
#include <mutex> // NOLINT
#include <sstream>
#include <stdexcept>
#include <string>

#include "k2/csrc/macros.h"
@@ -128,7 +129,7 @@ class Logger {
}
}

K2_CUDA_HOSTDEV ~Logger() {
K2_CUDA_HOSTDEV ~Logger() noexcept(false) {
printf("\n");
if (level_ == FATAL) {
#if defined(__CUDA_ARCH__)
@@ -143,7 +144,11 @@
printf("\n\n%s\n", stack_trace.c_str());
}
fflush(nullptr);
abort();
// abort();
//
// NOTE: abort() will terminate the program immediately without
// printing the Python stack backtrace.
throw std::runtime_error("Some bad things happened.");
#endif
}
}
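
A hedged sketch of what the log.h change means on the Python side (not part of the commit). It assumes pybind11's default translation of std::runtime_error into Python's RuntimeError; the failing call below is only a hypothetical trigger for a K2_CHECK/K2_LOG(FATAL) failure.

import k2

try:
    # A malformed ragged-tensor string as a hypothetical failure path; any
    # call that trips a K2_CHECK behaves the same way.
    bad = k2.RaggedInt('[ [1 2 ]')
except RuntimeError as e:  # the exact exception type for this trigger is an assumption
    # Previously the process would abort() before Python could print its own
    # stack trace; now the failure is an ordinary, catchable exception.
    print('k2 reported:', e)
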
8 changes: 6 additions & 2 deletions k2/csrc/ragged_ops.cu
@@ -2219,12 +2219,13 @@ Ragged<int32_t> ComputeHash(Ragged<int32_t> &src) {
}

Ragged<int32_t> UniqueSequences(Ragged<int32_t> &src,
Ragged<int32_t> *num_repeats /*=nullptr*/) {
Ragged<int32_t> *num_repeats /*=nullptr*/,
Array1<int32_t> *new2old_indexes /*=nullptr*/) {
ContextPtr &c = src.Context();
if (src.NumAxes() == 2) {
// Put 'fake' layer at front, process, then remove.
Ragged<int32_t> temp = Unsqueeze(src, 0);
return UniqueSequences(temp, num_repeats).RemoveAxis(0);
return UniqueSequences(temp, num_repeats, new2old_indexes).RemoveAxis(0);
}
Array1<int64_t> hashes = ComputeHash<int64_t>(src);
int32_t hashes_dim = hashes.Dim();
@@ -2272,6 +2273,9 @@ Ragged<int32_t> UniqueSequences(Ragged<int32_t> &src,
*num_repeats = Ragged<int32_t>(GetLayer(ans.shape, ans.NumAxes() - 3),
num_repeats_array);
}
if (new2old_indexes != nullptr) {
*new2old_indexes = std::move(new2unsorted);
}
return ans;
}

15 changes: 14 additions & 1 deletion k2/csrc/ragged_ops.h
@@ -1346,6 +1346,14 @@ Array1<T> ComputeHash(Ragged<int32_t> &src);
repeats (i.e., multiplicity) of each output sequence.
The caller does not need to pre-allocate it. It is
allocated inside the function.
@param [out] new2old_indexes
If not NULL, on return new2old_indexes[i] contains
the index of the original input sublist from which the
i-th output sublist was taken.
If `src` has 2 axes, this array contains `src_idx0`;
if `src` has 3 axes, this array contains `src_idx01`.
CAUTION: For repeated sublists, only one of them is kept.
The choice of which one to keep is **deterministic** and
is an implementation detail.
@return Returns a tensor with the same number of axes as `src` and
possibly fewer elements due to removing repeated sequences on the
@@ -1356,9 +1364,14 @@ Array1<T> ComputeHash(Ragged<int32_t> &src);
be present in the output, as it relies on a hash and ignores collisions.
If several sequences have the same hash, only one of them is kept, even
if the actual content in the sequence is different.
CAUTION: Even if there are no repeated sequences, the output may be different
from `src`. That is, `new2old_indexes` may NOT be an identity map even if
nothing was removed.
*/
Ragged<int32_t> UniqueSequences(Ragged<int32_t> &src,
Ragged<int32_t> *num_repeats = nullptr);
Ragged<int32_t> *num_repeats = nullptr,
Array1<int32_t> *new2old_indexes = nullptr);

/* Compute exclusive sum per sub-list.
*
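
A hedged Python sketch of the new2old_indexes semantics documented above, using the wrapper this commit adds in k2/python/k2/ragged/ops.py (not part of the diff; the numbers are illustrative).

import k2
import k2.ragged as k2r

# Three sublists inside one super-list; the first two are identical.
src = k2.RaggedInt('[ [ [1 2] [1 2] [3] ] ]')

unique, num_repeats, new2old = k2r.unique_sequences(
    src, need_num_repeats=True, need_new2old_indexes=True)

# `unique` keeps a single copy of [1 2] plus [3];
# `num_repeats` records how often each kept sublist occurred in `src`
# (2 for [1 2], 1 for [3]);
# `new2old[i]` is the src_idx01 of the input sublist that the i-th output
# sublist was taken from.  As cautioned above, the output order may differ
# from the input order even when nothing is removed.
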
39 changes: 34 additions & 5 deletions k2/python/csrc/torch/ragged_ops.cu
@@ -9,6 +9,7 @@
* See LICENSE for clarification regarding multiple authors
*/

#include <tuple>
#include <utility>
#include <vector>

@@ -208,17 +209,27 @@ static void PybindGetLayer(py::module &m) {
static void PybindUniqueSequences(py::module &m) {
m.def(
"unique_sequences",
[](Ragged<int32_t> &src, bool need_num_repeats = true)
-> std::pair<Ragged<int32_t>, torch::optional<Ragged<int32_t>>> {
[](Ragged<int32_t> &src, bool need_num_repeats = true,
bool need_new2old_indexes = false)
-> std::tuple<Ragged<int32_t>, torch::optional<Ragged<int32_t>>,
torch::optional<torch::Tensor>> {
Ragged<int32_t> num_repeats;
Array1<int32_t> new2old_indexes;
Ragged<int32_t> ans =
UniqueSequences(src, need_num_repeats ? &num_repeats : nullptr);
UniqueSequences(src, need_num_repeats ? &num_repeats : nullptr,
need_new2old_indexes ? &new2old_indexes : nullptr);

torch::optional<Ragged<int32_t>> num_repeats_tensor;
if (need_num_repeats) num_repeats_tensor = num_repeats;
return std::make_pair(ans, num_repeats_tensor);

torch::optional<torch::Tensor> new2old_indexes_tensor;
if (need_new2old_indexes)
new2old_indexes_tensor = ToTensor(new2old_indexes);

return std::make_tuple(ans, num_repeats_tensor, new2old_indexes_tensor);
},
py::arg("src"), py::arg("need_num_repeats"));
py::arg("src"), py::arg("need_num_repeats") = true,
py::arg("need_new2old_indexes") = false);
}

static void PybindIndex(py::module &m) {
@@ -263,6 +274,23 @@ static void PybindRegularRaggedShape(py::module &m) {
py::arg("dim0"), py::arg("dim1"));
}

template <typename T>
static void PybindArgMaxPerSublist(py::module &m) {
m.def(
"argmax_per_sublist",
[](Ragged<T> &src, T initial_value) -> torch::Tensor {
int32_t last_axis = src.NumAxes() - 1;
const Array1<int32_t> &row_splits_array = src.RowSplits(last_axis);
int32_t num_rows = row_splits_array.Dim() - 1;

Array1<int32_t> indexes(src.Context(), num_rows);
ArgMaxPerSublist(src, initial_value, &indexes);

return ToTensor(indexes);
},
py::arg("src"), py::arg("initial_value"));
}

} // namespace k2

void PybindRaggedOps(py::module &m) {
@@ -282,4 +310,5 @@ void PybindRaggedOps(py::module &m) {
PybindUniqueSequences(m);
PybindIndex(m);
PybindRegularRaggedShape(m);
PybindArgMaxPerSublist<float>(m);
}
1 change: 1 addition & 0 deletions k2/python/k2/__init__.py
@@ -1,4 +1,5 @@
import torch # noqa
from _k2 import RaggedFloat # TODO(fangjun): move it to k2.ragged
from _k2 import RaggedInt # TODO(fangjun): move it to k2.ragged
from _k2 import simple_ragged_index_select

12 changes: 10 additions & 2 deletions k2/python/k2/fsa.py
@@ -973,6 +973,7 @@ def clone(self) -> 'Fsa':
setattr(ans, name, value.clone())

for name, value in self.named_non_tensor_attr():
# Caution: We are not using `deepcopy` for `value`!
setattr(ans, name, value)

# Just copy elements of the _cache that we might already have..
@@ -998,17 +999,24 @@ def detach(self) -> 'Fsa':
ans = Fsa(self.arcs, properties=self.properties)

for name, value in self.named_tensor_attr(include_scores=False):
setattr(ans, name, value.detach())
if isinstance(value, torch.Tensor):
setattr(ans, name, value.detach())
else:
assert isinstance(value, k2.RaggedInt)
# For ragged tensors, they are copied over.
# Caution: Deep copy is not used!
setattr(ans, name, value)

for name, value in self.named_non_tensor_attr():
# Caution: We are not using `deepcopy` for `value`!
setattr(ans, name, value)

# Just copy elements of the _cache that we might already have..
# These don't directly participate in autograd, and are not supposed to
# be modified by the user, so this should be safe (i.e. it should
# be safe to do this without clone(); these are mostly not tensors
# anyway.
for name, value in self._cache:
for name, value in self._cache.items():
ans._cache[name] = value
return ans

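
A hedged sketch of what the Fsa.detach() change enables (not part of the commit; the FSA is made up): a ragged aux_labels attribute is now carried over by reference instead of hitting a .detach() call that only torch.Tensor supports.

import k2

fsa = k2.Fsa.from_str('''0 1 1 0.5
1 2 -1 0.0
2''')
fsa.aux_labels = k2.RaggedInt('[ [10 11] [-1] ]')

detached = fsa.detach()
# Before this change the loop assumed every named tensor attribute was a
# torch.Tensor and called .detach() on it; now ragged attributes are copied
# over by reference (no deep copy), exactly as the comments above warn.
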
44 changes: 34 additions & 10 deletions k2/python/k2/fsa_algo.py
@@ -292,7 +292,10 @@ def compose(a_fsa: Fsa,
input FSAs do not need to be arc sorted.
Note:
`a_fsa.aux_labels` is required to be defined.
`a_fsa.aux_labels` is required to be defined and it can be either
a `torch.Tensor` or a ragged tensor of type `k2.RaggedInt`.
If it is a ragged tensor, then it requires that a_fsa.requires_grad is
False.
For both FSAs, the `aux_labels` attribute is interpreted as output labels,
(olabels), and the composition involves matching the olabels of a_fsa with
@@ -329,17 +332,24 @@
The result of composing a_fsa and b_fsa. `len(out_fsa.shape)` is 2
if and only if the two input FSAs are single FSAs;
otherwise, `len(out_fsa.shape)` is 3.
'''
assert hasattr(a_fsa, 'aux_labels')

assert isinstance(a_fsa.aux_labels, torch.Tensor)
if a_fsa.requires_grad:
assert isinstance(a_fsa.aux_labels, torch.Tensor)
a_fsa_inv = a_fsa.invert()
else:
# k2.invert() does not support autograd.
# The current use case is for decoding, which does not need autograd.
# We may extend it to support autograd if needed.
a_fsa_inv = invert(a_fsa)

a_fsa_inv = a_fsa.invert()
if treat_epsilons_specially is True or a_fsa_inv.is_cpu():
# the GPU version does not need to sort the input FSA
a_fsa_inv = arc_sort(a_fsa_inv)

if treat_epsilons_specially is True or b_fsa.is_cpu():
# the GPU version does not need to sort the input FSA
assert b_fsa.properties & fsa_properties.ARC_SORTED != 0

need_arc_map = True
@@ -350,6 +360,7 @@
out_fsa = Fsa(ragged_arc)
if inner_labels is not None:
# out_fsa.`inner_labels` = out_fsa.labels.clone()
# need a clone here since `Fsa.labels` is a reference
setattr(out_fsa, inner_labels, out_fsa.labels.clone())

if hasattr(b_fsa, 'aux_labels'):
@@ -358,7 +369,17 @@ def compose(a_fsa: Fsa,
# need a clone here since `Fsa.labels` is a reference
out_fsa.aux_labels = out_fsa.labels.clone()

out_fsa.labels = index(a_fsa_inv.aux_labels, a_arc_map)
if isinstance(a_fsa_inv.aux_labels, torch.Tensor):
out_fsa.labels = index(a_fsa_inv.aux_labels, a_arc_map)
else:
assert isinstance(a_fsa_inv.aux_labels, k2.RaggedInt)
# Refer to the following URLs for an example:
# a_fsa: https://git.io/Jqbob
# b_fsa: https://git.io/JqbKL
# out_fsa: https://git.io/JqbK3
out_fsa.labels = out_fsa.aux_labels
out_fsa.aux_labels = index(a_fsa_inv.aux_labels, a_arc_map)
out_fsa = invert(out_fsa)

for name, a_value in a_fsa_inv.named_tensor_attr():
if name in ('aux_labels', inner_labels):
@@ -673,16 +694,19 @@ def invert(fsa: Fsa) -> Fsa:
Returns:
The inverted Fsa, it's top-sorted if `fsa` is top-sorted.
'''
assert fsa.is_cpu()
# FIXME(fangjun): support autograd and update k2.compose.
assert fsa.requires_grad is False
if isinstance(fsa.aux_labels, torch.Tensor):
return fsa.invert()
else:
assert isinstance(fsa.aux_labels, k2.RaggedInt)
need_arc_map = False
ragged_arc, aux_labels, _ = _k2.invert(fsa.arcs, fsa.aux_labels,
need_arc_map)
return Fsa(ragged_arc, aux_labels)
need_arc_map = True
ragged_arc, aux_labels, arc_map = _k2.invert(fsa.arcs, fsa.aux_labels,
need_arc_map)
out_fsa = k2.utils.fsa_from_unary_function_tensor(
fsa, ragged_arc, arc_map)
out_fsa.aux_labels = aux_labels
return out_fsa


def random_paths(fsas: Fsa, use_double_scores: bool,
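
A hedged sketch of the reworked k2.invert() above (not part of the commit; the FSA is illustrative). With ragged aux_labels it now asks _k2.invert for an arc_map and propagates the remaining tensor attributes through fsa_from_unary_function_tensor(); autograd is still unsupported, as the FIXME notes.

import torch
import k2

fsa = k2.Fsa.from_str('''0 1 1 0.5
1 2 -1 0.0
2''')
fsa.aux_labels = k2.RaggedInt('[ [10 11] [-1] ]')
fsa.tags = torch.tensor([7, 8], dtype=torch.int32)  # an extra per-arc attribute (hypothetical)

inv = k2.invert(fsa)
# inv.labels come from the old ragged aux_labels (10, 11, ... expanded onto
# arcs), inv.aux_labels holds the old labels, and `tags` is re-indexed
# through the returned arc_map.
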
1 change: 1 addition & 0 deletions k2/python/k2/ragged/__init__.py
@@ -1,6 +1,7 @@
# please sort imported functions alphabetically
from .autograd import normalize_scores
from .ops import append
from .ops import argmax_per_sublist
from .ops import create_ragged2
from .ops import get_layer
from .ops import index
44 changes: 42 additions & 2 deletions k2/python/k2/ragged/ops.py
@@ -188,7 +188,9 @@ def get_layer(src: _k2.RaggedShape, layer: int) -> _k2.RaggedShape:
return _k2.get_layer(src, layer)


def unique_sequences(src: _k2.RaggedInt, need_num_repeats: bool = True
def unique_sequences(src: _k2.RaggedInt,
need_num_repeats: bool = True,
need_new2old_indexes: bool = False
) -> Tuple[_k2.RaggedInt, Optional[_k2.RaggedInt], Optional[torch.Tensor]]: # noqa
'''Remove repeated sequences.
@@ -203,12 +205,27 @@ def unique_sequences(src: _k2.RaggedInt, need_num_repeats: bool = True
If several sequences have the same hash, only one of them is kept, even
if the actual content in the sequence is different.
Caution:
Even if there are no repeated sequences, the output may be different
from `src`. That is, `new2old_indexes` may NOT be an identity map even if
nothing was removed.
Args:
src:
The input ragged tensor. Must have `src.num_axes() == 2`
or `src.num_axes() == 3`.
need_num_repeats:
If True, it also returns the number of repeats of each sequence.
need_new2old_indexes:
If true, it returns an extra 1-D tensor `new2old_indexes`.
If `src` has 2 axes, this tensor contains `src_idx0`;
if `src` has 3 axes, this tensor contains `src_idx01`.
Caution:
For repeated sublists, only one of them is kept.
The choice of which one to keep is **deterministic** and
is an implementation detail.
Returns:
Returns a tuple containing:
@@ -223,8 +240,13 @@ def unique_sequences(src: _k2.RaggedInt, need_num_repeats: bool = True
num_repeats.num_elements() == ans.dim0().
If ans.num_axes() is 3, then num_repeats.dim0() == ans.dim0() and
num_repeats.num_elements() == ans.tot_size(1).
- new2old_indexes: A 1-D tensor whose i-th element specifies the index
of the input sublist that the i-th output sublist corresponds to.
'''
return _k2.unique_sequences(src, need_num_repeats=need_num_repeats)
return _k2.unique_sequences(src,
need_num_repeats=need_num_repeats,
need_new2old_indexes=need_new2old_indexes)


def regular_ragged_shape(dim0: int, dim1: int) -> _k2.RaggedShape:
@@ -245,3 +267,21 @@ def regular_ragged_shape(dim0: int, dim1: int) -> _k2.RaggedShape:
Return a ragged shape with 2 axes.
'''
return _k2.regular_ragged_shape(dim0, dim1)


def argmax_per_sublist(src: _k2.RaggedFloat,
initial_value: float = torch.finfo(torch.float32).min
) -> torch.Tensor: # noqa
'''Compute the argmax per sublist for a ragged tensor.
The argmax is computed on the last axis.
Args:
src:
The input ragged tensor.
initial_value:
The initial value used to compute the argmax.
Returns:
Return a 1-D tensor with dtype torch.int32.
'''
return _k2.argmax_per_sublist(src, initial_value)
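
A hedged usage sketch for the new wrapper above (not part of the commit). It assumes _k2.RaggedFloat accepts the same '[ ... ]' string form as _k2.RaggedInt; the values are illustrative.

import k2
import k2.ragged as k2r

scores = k2.RaggedFloat('[ [0.1 2.5 0.3] [1.0] [0.7 0.2] ]')  # assumed string constructor
best = k2r.argmax_per_sublist(scores)

# `best` is a 1-D torch.int32 tensor with one entry per sublist, giving the
# position of that sublist's largest element; whether the position is an
# index into the flattened values array or within the sublist follows the
# underlying ArgMaxPerSublist and is left as an assumption here.
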