Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Naming convention r_base::aggregate:: #169

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/include/rfuns_extension.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,12 @@ ScalarFunctionSet base_r_gte();

ScalarFunctionSet base_r_is_na();
ScalarFunctionSet base_r_as_integer();
ScalarFunctionSet base_r_as_numeric();

// sum
AggregateFunctionSet base_r_sum();
AggregateFunctionSet base_r_min();
AggregateFunctionSet base_r_max();
AggregateFunctionSet base_r_aggregate_sum();
AggregateFunctionSet base_r_aggregate_min();
AggregateFunctionSet base_r_aggregate_max();

ScalarFunctionSet binary_dispatch(ScalarFunctionSet fn) ;

Expand Down
204 changes: 152 additions & 52 deletions src/rfuns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <math.h>
#include <climits>
#include <cmath>

namespace duckdb {
namespace rfuns {
Expand All @@ -29,7 +30,7 @@ void BaseRAddFunctionDouble(DataChunk &args, ExpressionState &state, Vector &res

BinaryExecutor::ExecuteWithNulls<double, double, double>(
parts.lefts, parts.rights, result, args.size(), [&](double left, double right, ValidityMask &mask, idx_t idx) {
if (isnan(left) || isnan(right)) {
if (std::isnan(left) || std::isnan(right)) {
mask.SetInvalid(idx);
return 0.0;
}
Expand All @@ -38,7 +39,7 @@ void BaseRAddFunctionDouble(DataChunk &args, ExpressionState &state, Vector &res
}

double ExecuteBaseRPlusFunctionIntDouble(int32_t left, double right, ValidityMask &mask, idx_t idx) {
if (isnan(right)) {
if (std::isnan(right)) {
mask.SetInvalid(idx);
return 0.0;
}
Expand Down Expand Up @@ -86,81 +87,106 @@ ScalarFunctionSet base_r_add() {
#include <math.h>
#include <climits>
#include <limits>
#include <cmath>

namespace duckdb {
namespace rfuns {

namespace {

template <typename T>
int32_t check_range(T value, ValidityMask &mask, idx_t idx) {
int32_t check_int_range(T value, ValidityMask &mask, idx_t idx) {
if (value > std::numeric_limits<int32_t>::max() || value < std::numeric_limits<int32_t>::min() ) {
mask.SetInvalid(idx);
}

return static_cast<int32_t>(value);
}

template <typename T>
int32_t cast(T input, ValidityMask &mask, idx_t idx) {
return static_cast<int32_t>(input);
template <typename FROM, typename TO>
TO cast(FROM input, ValidityMask &mask, idx_t idx) {
return static_cast<TO>(input);
}

template <>
int32_t cast<double>(double input, ValidityMask &mask, idx_t idx) {
if (isnan(input)) {
int32_t cast<double, int32_t>(double input, ValidityMask &mask, idx_t idx) {
if (std::isnan(input)) {
mask.SetInvalid(idx);
}
return check_range(input, mask, idx);
return check_int_range(input, mask, idx);
}

template <>
int32_t cast<string_t>(string_t input, ValidityMask &mask, idx_t idx) {
double cast<string_t, double>(string_t input, ValidityMask &mask, idx_t idx) {
double result;
if (!TryDoubleCast<double>(input.GetData(), input.GetSize(), result, false)) {
mask.SetInvalid(idx);
}

return cast<double>(result, mask, idx);
return result;
}

template <>
int32_t cast<string_t, int32_t>(string_t input, ValidityMask &mask, idx_t idx) {
auto dbl = cast<string_t, double>(input, mask, idx);
return cast<double, int32_t>(dbl, mask, idx);
}

template <>
int32_t cast<date_t, int32_t>(date_t input, ValidityMask &mask, idx_t idx) {
return input.days;
}

template <>
int32_t cast<date_t>(date_t input, ValidityMask &mask, idx_t idx) {
double cast<date_t, double>(date_t input, ValidityMask &mask, idx_t idx) {
return input.days;
}

template <>
int32_t cast<timestamp_t>(timestamp_t input, ValidityMask &mask, idx_t idx) {
return check_range(Timestamp::GetEpochSeconds(input), mask, idx);
int32_t cast<timestamp_t, int32_t>(timestamp_t input, ValidityMask &mask, idx_t idx) {
return check_int_range(Timestamp::GetEpochSeconds(input), mask, idx);
}

template <LogicalTypeId TYPE>
ScalarFunction AsIntegerFunction() {
template <>
double cast<timestamp_t, double>(timestamp_t input, ValidityMask &mask, idx_t idx) {
return check_int_range(Timestamp::GetEpochSeconds(input), mask, idx);
}

template <LogicalTypeId TYPE, LogicalTypeId RESULT_TYPE>
ScalarFunction AsNumberFunction() {
using physical_type = typename physical<TYPE>::type;
using result_type = typename physical<RESULT_TYPE>::type;

auto fun = [](DataChunk &args, ExpressionState &state, Vector &result) {
UnaryExecutor::ExecuteWithNulls<physical_type, int32_t>(
args.data[0], result, args.size(), cast<physical_type>
UnaryExecutor::ExecuteWithNulls<physical_type, result_type>(
args.data[0], result, args.size(), cast<physical_type, result_type>
);
};
return ScalarFunction({TYPE}, LogicalType::INTEGER, fun);
return ScalarFunction({TYPE}, RESULT_TYPE, fun);
}

}
template <LogicalTypeId RESULT_TYPE>
ScalarFunctionSet as_number(std::string name) {
ScalarFunctionSet set(name);

ScalarFunctionSet base_r_as_integer() {
ScalarFunctionSet set("r_base::as.integer");
set.AddFunction(AsNumberFunction<LogicalType::BOOLEAN , RESULT_TYPE>());
set.AddFunction(AsNumberFunction<LogicalType::INTEGER , RESULT_TYPE>());
set.AddFunction(AsNumberFunction<LogicalType::DOUBLE , RESULT_TYPE>());
set.AddFunction(AsNumberFunction<LogicalType::VARCHAR , RESULT_TYPE>());
set.AddFunction(AsNumberFunction<LogicalType::DATE , RESULT_TYPE>());
set.AddFunction(AsNumberFunction<LogicalType::TIMESTAMP , RESULT_TYPE>());

set.AddFunction(AsIntegerFunction<LogicalType::BOOLEAN>());
set.AddFunction(AsIntegerFunction<LogicalType::INTEGER>());
set.AddFunction(AsIntegerFunction<LogicalType::DOUBLE>());
return set;
}

set.AddFunction(AsIntegerFunction<LogicalType::VARCHAR>());
}

set.AddFunction(AsIntegerFunction<LogicalType::DATE>());
set.AddFunction(AsIntegerFunction<LogicalType::TIMESTAMP>());
ScalarFunctionSet base_r_as_integer() {
return as_number<LogicalTypeId::INTEGER>("r_base::as.integer");
}

return set;
ScalarFunctionSet base_r_as_numeric() {
return as_number<LogicalTypeId::DOUBLE>("r_base::as.numeric");
}

}
Expand Down Expand Up @@ -206,19 +232,12 @@ ScalarFunctionSet binary_dispatch(ScalarFunctionSet fn) {
#include <math.h>
#include <climits>
#include <iostream>
#include <cmath>

namespace duckdb {
namespace rfuns {

void isna_double(DataChunk &args, ExpressionState &state, Vector &result) {
auto count = args.size();
auto input = args.data[0];
auto mask = FlatVector::Validity(input);
auto* data = FlatVector::GetData<double>(input);

result.SetVectorType(VectorType::FLAT_VECTOR);
auto result_data = FlatVector::GetData<bool>(result);

void isna_double_loop(idx_t count, const double* data, bool* result_data, ValidityMask mask) {
idx_t base_idx = 0;
auto entry_count = ValidityMask::EntryCount(count);
for (idx_t entry_idx = 0; entry_idx < entry_count; entry_idx++) {
Expand Down Expand Up @@ -250,14 +269,52 @@ void isna_double(DataChunk &args, ExpressionState &state, Vector &result) {
}
}

void isna_any(DataChunk &args, ExpressionState &state, Vector &result) {
void isna_double(DataChunk &args, ExpressionState &state, Vector &result) {
auto count = args.size();
auto input = args.data[0];
auto mask = FlatVector::Validity(input);

result.SetVectorType(VectorType::FLAT_VECTOR);
auto result_data = FlatVector::GetData<bool>(result);
switch(input.GetVectorType()) {
case VectorType::FLAT_VECTOR: {
result.SetVectorType(VectorType::FLAT_VECTOR);

isna_double_loop(
count,
FlatVector::GetData<double>(input),
FlatVector::GetData<bool>(result),
FlatVector::Validity(input)
);

break;
}

case VectorType::CONSTANT_VECTOR: {
result.SetVectorType(VectorType::CONSTANT_VECTOR);
auto result_data = ConstantVector::GetData<bool>(result);
auto ldata = ConstantVector::GetData<double>(input);

*result_data = ConstantVector::IsNull(input) || isnan(*ldata);

break;
}

default: {
UnifiedVectorFormat vdata;
input.ToUnifiedFormat(count, vdata);
result.SetVectorType(VectorType::FLAT_VECTOR);

isna_double_loop(
count,
UnifiedVectorFormat::GetData<double>(vdata),
FlatVector::GetData<bool>(result),
vdata.validity
);

break;
}
}
}

void isna_any_loop(idx_t count, bool* result_data, ValidityMask mask) {
if (mask.AllValid()) {
for (idx_t i = 0; i < count; i++) {
result_data[i] = false;
Expand Down Expand Up @@ -289,6 +346,47 @@ void isna_any(DataChunk &args, ExpressionState &state, Vector &result) {
}
}
}

}

void isna_any(DataChunk &args, ExpressionState &state, Vector &result) {
auto count = args.size();
auto input = args.data[0];

switch(input.GetVectorType()) {
case VectorType::FLAT_VECTOR: {
result.SetVectorType(VectorType::FLAT_VECTOR);
isna_any_loop(
count,
FlatVector::GetData<bool>(result),
FlatVector::Validity(input)
);

break;
}

case VectorType::CONSTANT_VECTOR: {
result.SetVectorType(VectorType::CONSTANT_VECTOR);
auto result_data = ConstantVector::GetData<bool>(result);
*result_data = ConstantVector::IsNull(input);

break;
}

default : {
UnifiedVectorFormat vdata;
input.ToUnifiedFormat(count, vdata);
result.SetVectorType(VectorType::FLAT_VECTOR);
isna_any_loop(
count,
FlatVector::GetData<bool>(result),
vdata.validity
);

break;
}
}

}


Expand Down Expand Up @@ -447,12 +545,12 @@ AggregateFunctionSet base_r_minmax(std::string name) {
return set;
}

AggregateFunctionSet base_r_min() {
return base_r_minmax<RMinOperation>("r_base::min");
AggregateFunctionSet base_r_aggregate_min() {
return base_r_minmax<RMinOperation>("r_base::aggregate::min");
}

AggregateFunctionSet base_r_max() {
return base_r_minmax<RMaxOperation>("r_base::max");
AggregateFunctionSet base_r_aggregate_max() {
return base_r_minmax<RMaxOperation>("r_base::aggregate::max");
}


Expand All @@ -464,6 +562,7 @@ AggregateFunctionSet base_r_max() {
#include <math.h>
#include <climits>
#include <iostream>
#include <cmath>

namespace duckdb {
namespace rfuns {
Expand Down Expand Up @@ -627,7 +726,7 @@ bool set_null(T value, ValidityMask &mask, idx_t idx) {

template <>
bool set_null<double>(double value, ValidityMask &mask, idx_t idx) {
if (isnan(value)) {
if (std::isnan(value)) {
mask.SetInvalid(idx);
return true;
}
Expand Down Expand Up @@ -772,10 +871,11 @@ static void register_rfuns(DatabaseInstance &instance) {

ExtensionUtil::RegisterFunction(instance, base_r_is_na());
ExtensionUtil::RegisterFunction(instance, base_r_as_integer());
ExtensionUtil::RegisterFunction(instance, base_r_as_numeric());

ExtensionUtil::RegisterFunction(instance, base_r_sum());
ExtensionUtil::RegisterFunction(instance, base_r_min());
ExtensionUtil::RegisterFunction(instance, base_r_max());
ExtensionUtil::RegisterFunction(instance, base_r_aggregate_sum());
ExtensionUtil::RegisterFunction(instance, base_r_aggregate_min());
ExtensionUtil::RegisterFunction(instance, base_r_aggregate_max());
}
} // namespace rfuns

Expand Down Expand Up @@ -926,8 +1026,8 @@ void add_RSum(AggregateFunctionSet& set, const LogicalType& type) {
));
}

AggregateFunctionSet base_r_sum() {
AggregateFunctionSet set("r_base::sum");
AggregateFunctionSet base_r_aggregate_sum() {
AggregateFunctionSet set("r_base::aggregate::sum");

add_RSum(set, LogicalType::BOOLEAN);
add_RSum(set, LogicalType::INTEGER);
Expand Down
Loading