Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tiledb_field_get_nullable API #5378

Merged
merged 2 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions tiledb/api/c_api/query_field/query_field_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,26 +57,31 @@ tiledb_query_field_handle_t::tiledb_query_field_handle_t(
if (field_name_ == tiledb::sm::constants::coords) {
field_origin_ = std::make_shared<FieldFromDimension>();
type_ = query_->array_schema().domain().dimension_ptr(0)->type();
is_nullable_ = false;
cell_val_num_ = 1;
} else if (field_name_ == tiledb::sm::constants::timestamps) {
field_origin_ = std::make_shared<FieldFromAttribute>();
type_ = tiledb::sm::constants::timestamp_type;
is_nullable_ = false;
cell_val_num_ = 1;
} else if (query_->array_schema().is_attr(field_name_)) {
field_origin_ = std::make_shared<FieldFromAttribute>();
type_ = query_->array_schema().attribute(field_name_)->type();
is_nullable_ = query_->array_schema().attribute(field_name_)->nullable();
cell_val_num_ =
query_->array_schema().attribute(field_name_)->cell_val_num();
} else if (query_->array_schema().is_dim(field_name_)) {
field_origin_ = std::make_shared<FieldFromDimension>();
type_ = query_->array_schema().dimension_ptr(field_name_)->type();
is_nullable_ = false;
cell_val_num_ =
query_->array_schema().dimension_ptr(field_name_)->cell_val_num();
} else if (query_->is_aggregate(field_name_)) {
is_aggregate = true;
field_origin_ = std::make_shared<FieldFromAggregate>();
auto aggregate = query_->get_aggregate(field_name_).value();
type_ = aggregate->output_datatype();
is_nullable_ = aggregate->aggregation_nullable();
cell_val_num_ =
aggregate->aggregation_var_sized() ? tiledb::sm::constants::var_num : 1;
} else {
Expand Down Expand Up @@ -152,6 +157,14 @@ capi_return_t tiledb_field_cell_val_num(
return TILEDB_OK;
}

capi_return_t tiledb_field_get_nullable(
tiledb_query_field_t* field, uint8_t* nullable) {
ensure_query_field_is_valid(field);
ensure_output_pointer_is_valid(nullable);
*nullable = field->is_nullable();
return TILEDB_OK;
}

capi_return_t tiledb_field_origin(
tiledb_query_field_t* field, tiledb_field_origin_t* origin) {
ensure_query_field_is_valid(field);
Expand Down Expand Up @@ -205,6 +218,15 @@ CAPI_INTERFACE(
ctx, field, cell_val_num);
}

CAPI_INTERFACE(
field_get_nullable,
tiledb_ctx_t* ctx,
tiledb_query_field_t* field,
uint8_t* nullable) {
return api_entry_context<tiledb::api::tiledb_field_get_nullable>(
ctx, field, nullable);
}

CAPI_INTERFACE(
field_origin,
tiledb_ctx_t* ctx,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,26 @@ TILEDB_EXPORT capi_return_t tiledb_field_cell_val_num(
tiledb_query_field_t* field,
uint32_t* cell_val_num) TILEDB_NOEXCEPT;

/**
* Retrieves the nullability of a field.
*
* **Example:**
*
* @code{.c}
* uint8_t nullable;
* tiledb_field_get_nullable(ctx, attr, &nullable);
* @endcode
*
* @param[in] ctx The TileDB context.
* @param[in] field The query field handle
* @param[out] nullable Non-zero if `field` is nullable, and zero otherwise
* @return `TILEDB_OK` for success and `TILEDB_ERR` for error.
*/
TILEDB_EXPORT capi_return_t tiledb_field_get_nullable(
tiledb_ctx_t* ctx,
tiledb_query_field_t* field,
uint8_t* nullable) TILEDB_NOEXCEPT;
Comment on lines +155 to +158
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would call it is_nullable.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I considered that, in the end I chose to mimic tiledb_attribute_get_nullable


/**
* Get the origin type of the passed field
* **Example:**
Expand Down
4 changes: 4 additions & 0 deletions tiledb/api/c_api/query_field/query_field_api_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ struct tiledb_query_field_handle_t
std::shared_ptr<FieldOrigin> field_origin_;
tiledb::sm::Datatype type_;
uint32_t cell_val_num_;
bool is_nullable_;
std::shared_ptr<tiledb::sm::QueryChannel> channel_;

public:
Expand All @@ -96,6 +97,9 @@ struct tiledb_query_field_handle_t
uint32_t cell_val_num() {
return cell_val_num_;
}
bool is_nullable() const {
return is_nullable_;
}
tiledb_query_channel_handle_t* channel() {
return tiledb_query_channel_handle_t::make_handle(channel_);
}
Expand Down
106 changes: 95 additions & 11 deletions tiledb/api/c_api/query_field/test/unit_capi_query_field.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,25 +76,31 @@ void QueryFieldFx::write_sparse_array(const std::string& array_name) {

throw_if_setup_failed(tiledb_query_set_layout(ctx, query, TILEDB_UNORDERED));

int32_t a[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
int32_t b[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
uint64_t a_size = 10 * sizeof(int32_t);
uint64_t b_size = 10 * sizeof(int32_t);

int64_t d1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
int64_t d2[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
uint64_t d1_size = 10 * sizeof(int64_t);
uint64_t d2_size = 10 * sizeof(int64_t);
int32_t a[] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
Copy link
Contributor Author

@rroelke rroelke Nov 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The last cell doesn't make it through the write (one of the other inputs only has nine cells) so I elected to remove it to avoid future confusion.

(e.g. a developer expecting a sum of 55 rather than 45, you can probably guess who did that)

int32_t b[] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
uint64_t a_size = 9 * sizeof(int32_t);
uint64_t b_size = 9 * sizeof(int32_t);
uint8_t b_validity[] = {1, 1, 1, 1, 1, 1, 1, 1, 1};
uint64_t b_validity_size = sizeof(b_validity);

int64_t d1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
int64_t d2[] = {1, 1, 1, 1, 1, 1, 1, 1, 1};
uint64_t d1_size = 9 * sizeof(int64_t);
uint64_t d2_size = 9 * sizeof(int64_t);
char c_data[] = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
uint64_t c_size = strlen(c_data);
uint64_t c_data_offsets[] = {0, 5, 8, 13, 17, 21, 26, 31, 36, 40};
uint64_t c_offsets_size = sizeof(c_data_offsets);
uint8_t d_validity[] = {1, 1, 1, 1, 1, 1, 1, 1, 1};
uint64_t d_validity_size = sizeof(b_validity);

throw_if_setup_failed(
tiledb_query_set_data_buffer(ctx, query, "a", a, &a_size));

throw_if_setup_failed(
tiledb_query_set_data_buffer(ctx, query, "b", b, &b_size));
throw_if_setup_failed(tiledb_query_set_validity_buffer(
ctx, query, "b", b_validity, &b_validity_size));

throw_if_setup_failed(
tiledb_query_set_data_buffer(ctx, query, "d1", d1, &d1_size));
Expand All @@ -111,6 +117,8 @@ void QueryFieldFx::write_sparse_array(const std::string& array_name) {
tiledb_query_set_data_buffer(ctx, query, "d", c_data, &c_size));
throw_if_setup_failed(tiledb_query_set_offsets_buffer(
ctx, query, "d", c_data_offsets, &c_offsets_size));
throw_if_setup_failed(tiledb_query_set_validity_buffer(
ctx, query, "d", d_validity, &d_validity_size));

throw_if_setup_failed(tiledb_query_submit(ctx, query));

Expand Down Expand Up @@ -143,6 +151,7 @@ void QueryFieldFx::create_sparse_array(const std::string& array_name) {
throw_if_setup_failed(tiledb_attribute_alloc(ctx, "a", TILEDB_INT32, &a));
tiledb_attribute_t* b = nullptr;
throw_if_setup_failed(tiledb_attribute_alloc(ctx, "b", TILEDB_INT32, &b));
throw_if_setup_failed(tiledb_attribute_set_nullable(ctx, b, true));
tiledb_attribute_t* c = nullptr;
throw_if_setup_failed(
tiledb_attribute_alloc(ctx, "c", TILEDB_STRING_ASCII, &c));
Expand All @@ -153,6 +162,7 @@ void QueryFieldFx::create_sparse_array(const std::string& array_name) {
tiledb_attribute_alloc(ctx, "d", TILEDB_STRING_UTF8, &d));
throw_if_setup_failed(
tiledb_attribute_set_cell_val_num(ctx, d, TILEDB_VAR_NUM));
throw_if_setup_failed(tiledb_attribute_set_nullable(ctx, d, true));

// Create array schema
tiledb_array_schema_t* array_schema = nullptr;
Expand Down Expand Up @@ -301,6 +311,7 @@ TEST_CASE_METHOD(QueryFieldFx, "C API: get_field", "[capi][query_field]") {
tiledb_datatype_t type;
tiledb_field_origin_t origin;
uint32_t cell_val_num = 0;
uint8_t is_nullable = false;
tiledb_query_channel_t* channel = nullptr;

SECTION("Non-existent field") {
Expand All @@ -319,6 +330,8 @@ TEST_CASE_METHOD(QueryFieldFx, "C API: get_field", "[capi][query_field]") {
CHECK(origin == TILEDB_DIMENSION_FIELD);
REQUIRE(tiledb_field_cell_val_num(ctx, field, &cell_val_num) == TILEDB_OK);
CHECK(cell_val_num == 1);
REQUIRE(tiledb_field_get_nullable(ctx, field, &is_nullable) == TILEDB_OK);
CHECK(is_nullable == false);
CHECK(tiledb_query_field_free(ctx, &field) == TILEDB_OK);
}

Expand All @@ -333,6 +346,8 @@ TEST_CASE_METHOD(QueryFieldFx, "C API: get_field", "[capi][query_field]") {
CHECK(origin == TILEDB_ATTRIBUTE_FIELD);
REQUIRE(tiledb_field_cell_val_num(ctx, field, &cell_val_num) == TILEDB_OK);
CHECK(cell_val_num == 1);
REQUIRE(tiledb_field_get_nullable(ctx, field, &is_nullable) == TILEDB_OK);
CHECK(is_nullable == false);
CHECK(tiledb_query_field_free(ctx, &field) == TILEDB_OK);
}

Expand All @@ -346,10 +361,12 @@ TEST_CASE_METHOD(QueryFieldFx, "C API: get_field", "[capi][query_field]") {
CHECK(origin == TILEDB_DIMENSION_FIELD);
REQUIRE(tiledb_field_cell_val_num(ctx, field, &cell_val_num) == TILEDB_OK);
CHECK(cell_val_num == 1);
REQUIRE(tiledb_field_get_nullable(ctx, field, &is_nullable) == TILEDB_OK);
CHECK(is_nullable == false);
CHECK(tiledb_query_field_free(ctx, &field) == TILEDB_OK);
}

SECTION("Attribute field") {
SECTION("Non-nullable attribute field") {
// Check field api works on attribute field
REQUIRE(tiledb_query_get_field(ctx, query, "c", &field) == TILEDB_OK);
REQUIRE(tiledb_field_datatype(ctx, field, &type) == TILEDB_OK);
Expand All @@ -358,10 +375,75 @@ TEST_CASE_METHOD(QueryFieldFx, "C API: get_field", "[capi][query_field]") {
CHECK(origin == TILEDB_ATTRIBUTE_FIELD);
REQUIRE(tiledb_field_cell_val_num(ctx, field, &cell_val_num) == TILEDB_OK);
CHECK(cell_val_num == TILEDB_VAR_NUM);
REQUIRE(tiledb_field_get_nullable(ctx, field, &is_nullable) == TILEDB_OK);
CHECK(is_nullable == false);
CHECK(tiledb_query_field_free(ctx, &field) == TILEDB_OK);
}

SECTION("Aggregate field") {
SECTION("Nullablle attribute field") {
// Check field api works on attribute field
REQUIRE(tiledb_query_get_field(ctx, query, "d", &field) == TILEDB_OK);
REQUIRE(tiledb_field_datatype(ctx, field, &type) == TILEDB_OK);
CHECK(type == TILEDB_STRING_UTF8);
REQUIRE(tiledb_field_origin(ctx, field, &origin) == TILEDB_OK);
CHECK(origin == TILEDB_ATTRIBUTE_FIELD);
REQUIRE(tiledb_field_cell_val_num(ctx, field, &cell_val_num) == TILEDB_OK);
CHECK(cell_val_num == TILEDB_VAR_NUM);
REQUIRE(tiledb_field_get_nullable(ctx, field, &is_nullable) == TILEDB_OK);
CHECK(static_cast<bool>(is_nullable) == true);
CHECK(tiledb_query_field_free(ctx, &field) == TILEDB_OK);
}

SECTION("Aggregate field which might be nullable") {
auto expect_nullable = GENERATE(true, false);
auto attribute = (expect_nullable ? "b" : "a");
// Check field api works on aggregate field
const tiledb_channel_operator_t* operator_sum;
tiledb_channel_operation_t* sum_a;
REQUIRE(tiledb_channel_operator_sum_get(ctx, &operator_sum) == TILEDB_OK);
REQUIRE(
tiledb_create_unary_aggregate(
ctx, query, operator_sum, attribute, &sum_a) == TILEDB_OK);
REQUIRE(
tiledb_query_get_default_channel(ctx, query, &channel) == TILEDB_OK);
REQUIRE(
tiledb_channel_apply_aggregate(ctx, channel, "Sum", sum_a) ==
TILEDB_OK);

SECTION("validate") {
// Check field api works on aggregate field
REQUIRE(tiledb_query_get_field(ctx, query, "Sum", &field) == TILEDB_OK);
REQUIRE(tiledb_field_datatype(ctx, field, &type) == TILEDB_OK);
CHECK(type == TILEDB_INT64);
REQUIRE(tiledb_field_origin(ctx, field, &origin) == TILEDB_OK);
CHECK(origin == TILEDB_AGGREGATE_FIELD);
REQUIRE(
tiledb_field_cell_val_num(ctx, field, &cell_val_num) == TILEDB_OK);
CHECK(cell_val_num == 1);
REQUIRE(tiledb_field_get_nullable(ctx, field, &is_nullable) == TILEDB_OK);
CHECK(static_cast<bool>(is_nullable) == expect_nullable);
CHECK(tiledb_query_field_free(ctx, &field) == TILEDB_OK);
}
SECTION("run query") {
uint64_t sum = 0;
uint64_t size = 8;
uint8_t sum_validity = 0;
uint64_t validity_size = sizeof(uint8_t);
REQUIRE(
tiledb_query_set_data_buffer(ctx, query, "Sum", &sum, &size) ==
TILEDB_OK);
if (expect_nullable) {
REQUIRE(
tiledb_query_set_validity_buffer(
ctx, query, "Sum", &sum_validity, &validity_size) == TILEDB_OK);
}
REQUIRE(tiledb_query_submit(ctx, query) == TILEDB_OK);
CHECK(sum == 45);
}
CHECK(tiledb_query_channel_free(ctx, &channel) == TILEDB_OK);
}

SECTION("Non-nullable Aggregate field") {
// Check field api works on aggregate field
REQUIRE(
tiledb_query_get_default_channel(ctx, query, &channel) == TILEDB_OK);
Expand All @@ -378,6 +460,8 @@ TEST_CASE_METHOD(QueryFieldFx, "C API: get_field", "[capi][query_field]") {
REQUIRE(
tiledb_field_cell_val_num(ctx, field, &cell_val_num) == TILEDB_OK);
CHECK(cell_val_num == 1);
REQUIRE(tiledb_field_get_nullable(ctx, field, &is_nullable) == TILEDB_OK);
CHECK(is_nullable == false);
CHECK(tiledb_query_field_free(ctx, &field) == TILEDB_OK);
}
SECTION("run query") {
Expand Down
Loading