Skip to content

Commit

Permalink
Add numRowRes Argument to InnerJoin and SemiJoin
Browse files Browse the repository at this point in the history
- Introduced `numRowRes` as a parameter for the `InnerJoin` and `SemiJoin` kernel functions, specifying the number of rows to allocate for the result.
- In `InnerJoin`:
  - If `numRowRes` is -1, the result size is set to `numRowRhs * numRowLhs`.
  - Otherwise, the result size is determined by `numRowRes`.
- In `SemiJoin`:
  - If `numRowRes` is -1, the result size defaults to `numRowLhs`.
  - Otherwise, the result size is determined by `numRowRes`.
- Updated DaphneDSL:
  - Added `numRowRes` as an optional parameter for `innerJoin` and `semiJoin` built-in functions.
  - If not provided, `numRowRes` defaults to -1, which is passed to DaphneIR operations.
- Modified DaphneIR:
  - Made `numRowRes` a mandatory argument for `InnerJoinOp` and `SemiJoinOp`.
- Implementation Updates:
  - Updated `DaphneDSLBuiltins.cpp` to handle default `numRowRes` values.
  - Set `numRowRes` to -1 in `SQLVisitor.cpp` for compatibility.
  - Adjusted `kernels.json` to reflect the new parameter in `innerJoin` and `semiJoin`.
- Added script-level test cases to validate the new functionality.
- Addresses issue `daphne-eu#901` by allowing users to specify result size to prevent over-allocation.
  • Loading branch information
saminbassiri committed Nov 18, 2024
1 parent 576bde3 commit e814a8d
Show file tree
Hide file tree
Showing 13 changed files with 95 additions and 18 deletions.
4 changes: 2 additions & 2 deletions src/ir/daphneir/DaphneOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1076,7 +1076,7 @@ def Daphne_InnerJoinOp : Daphne_Op<"innerJoin", [
DataTypeFrm, ValueTypesConcat,
DeclareOpInterfaceMethods<InferFrameLabelsOpInterface>,
]> {
let arguments = (ins FrameOrU:$lhs, FrameOrU:$rhs, StrScalar:$lhsOn, StrScalar:$rhsOn);
let arguments = (ins FrameOrU:$lhs, FrameOrU:$rhs, StrScalar:$lhsOn, StrScalar:$rhsOn, Size:$numRowRes);
let results = (outs FrameOrU:$res);
}

Expand Down Expand Up @@ -1120,7 +1120,7 @@ def Daphne_SemiJoinOp : Daphne_Op<"semiJoin", [
DeclareOpInterfaceMethods<InferTypesOpInterface>,
NumColsFromArg
]> {
let arguments = (ins FrameOrU:$lhs, FrameOrU:$rhs, StrScalar:$lhsOn, StrScalar:$rhsOn);
let arguments = (ins FrameOrU:$lhs, FrameOrU:$rhs, StrScalar:$lhsOn, StrScalar:$rhsOn, Size:$numRowRes);
let results = (outs FrameOrU:$res, MatrixOf<[Size]>:$lhsTids);
}

Expand Down
18 changes: 14 additions & 4 deletions src/parser/daphnedsl/DaphneDSLBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -993,13 +993,18 @@ antlrcpp::Any DaphneDSLBuiltins::build(mlir::Location loc, const std::string &fu
builder.create<CartesianOp>(loc, FrameType::get(builder.getContext(), colTypes), args[0], args[1]));
}
if (func == "innerJoin") {
checkNumArgsExact(loc, func, numArgs, 4);
checkNumArgsMin(loc, func, numArgs, 4);
std::vector<mlir::Type> colTypes;
mlir::Value numRowRes;
for (int i = 0; i < 2; i++)
for (mlir::Type t : args[i].getType().dyn_cast<FrameType>().getColumnTypes())
colTypes.push_back(t);
if (numArgs == 5)
numRowRes = utils.castSI64If(args[4]);
else
numRowRes = builder.create<ConstantOp>(loc, int64_t(-1));
return static_cast<mlir::Value>(builder.create<InnerJoinOp>(loc, FrameType::get(builder.getContext(), colTypes),
args[0], args[1], args[2], args[3]));
args[0], args[1], args[2], args[3], numRowRes));
}
if (func == "fullOuterJoin")
return createJoinOp<FullOuterJoinOp>(loc, func, args);
Expand All @@ -1011,14 +1016,19 @@ antlrcpp::Any DaphneDSLBuiltins::build(mlir::Location loc, const std::string &fu
// TODO Reconcile this with the other join ops, but we need it to work
// quickly now.
// return createJoinOp<SemiJoinOp>(loc, func, args);
checkNumArgsExact(loc, func, numArgs, 4);
checkNumArgsMin(loc, func, numArgs, 4);
mlir::Value lhs = args[0];
mlir::Value rhs = args[1];
mlir::Value lhsOn = args[2];
mlir::Value rhsOn = args[3];
mlir::Value numRowRes;
if (numArgs == 5)
numRowRes = utils.castSI64If(args[4]);
else
numRowRes = builder.create<ConstantOp>(loc, int64_t(-1));
return builder
.create<SemiJoinOp>(loc, FrameType::get(builder.getContext(), {utils.unknownType}), utils.matrixOfSizeType,
lhs, rhs, lhsOn, rhsOn)
lhs, rhs, lhsOn, rhsOn, numRowRes)
.getResults();
}
if (func == "groupJoin") {
Expand Down
5 changes: 4 additions & 1 deletion src/parser/sql/SQLVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -551,8 +551,11 @@ antlrcpp::Any SQLVisitor::visitInnerJoin(SQLGrammarParser::InnerJoinContext *ctx
mlir::Value rhsName = valueOrErrorOnVisit(ctx->rhs);
mlir::Value lhsName = valueOrErrorOnVisit(ctx->lhs);

mlir::Value numRowRes =
static_cast<mlir::Value>(builder.create<mlir::daphne::ConstantOp>(queryLoc, static_cast<int64_t>(-1)));

return static_cast<mlir::Value>(
builder.create<mlir::daphne::InnerJoinOp>(loc, t, currentFrame, tojoin, rhsName, lhsName));
builder.create<mlir::daphne::InnerJoinOp>(loc, t, currentFrame, tojoin, rhsName, lhsName, numRowRes));
}

std::vector<mlir::Value> rhsNames;
Expand Down
13 changes: 12 additions & 1 deletion src/runtime/local/kernels/InnerJoin.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,19 @@ inline void innerJoin(
const Frame *lhs, const Frame *rhs,
// input column names
const char *lhsOn, const char *rhsOn,
// result size
int64_t numRowRes,
// context
DCTX(ctx)) {

// Find out the value types of the columns to process.
ValueTypeCode vtcLhsOn = lhs->getColumnType(lhsOn);
ValueTypeCode vtcRhsOn = rhs->getColumnType(rhsOn);

// Perhaps check if res already allocated.
const size_t numRowRhs = rhs->getNumRows();
const size_t numRowLhs = lhs->getNumRows();
const size_t totalRows = numRowRhs * numRowLhs;
const size_t totalRows = numRowRes == -1 ? numRowRhs * numRowLhs : numRowRes;
const size_t numColRhs = rhs->getNumCols();
const size_t numColLhs = lhs->getNumCols();
const size_t totalCols = numColRhs + numColLhs;
Expand Down Expand Up @@ -126,15 +129,23 @@ inline void innerJoin(
row_idx_r, ctx);
hit = hit || innerJoinProbeIf<double, double>(vtcLhsOn, vtcRhsOn, res, lhs, rhs, lhsOn, rhsOn, row_idx_l,
row_idx_r, ctx);
hit = hit || innerJoinProbeIf<std::string, std::string>(vtcLhsOn, vtcRhsOn, res, lhs, rhs, lhsOn, rhsOn,
row_idx_l, row_idx_r, ctx);
if (hit) {
for (size_t idx_c = 0; idx_c < numColLhs; idx_c++) {
innerJoinSet<std::string>(schema[col_idx_res], res, lhs, row_idx_res, col_idx_res, row_idx_l, idx_c,
ctx);
innerJoinSet<int64_t>(schema[col_idx_res], res, lhs, row_idx_res, col_idx_res, row_idx_l, idx_c,
ctx);
innerJoinSet<double>(schema[col_idx_res], res, lhs, row_idx_res, col_idx_res, row_idx_l, idx_c,
ctx);
col_idx_res++;
}
for (size_t idx_c = 0; idx_c < numColRhs; idx_c++) {

innerJoinSet<std::string>(schema[col_idx_res], res, rhs, row_idx_res, col_idx_res, row_idx_r, idx_c,
ctx);

innerJoinSet<int64_t>(schema[col_idx_res], res, rhs, row_idx_res, col_idx_res, row_idx_r, idx_c,
ctx);

Expand Down
21 changes: 15 additions & 6 deletions src/runtime/local/kernels/SemiJoin.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ void semiJoinCol(
Frame *&res, DenseMatrix<VTTid> *&resLhsTid,
// arguments
const DenseMatrix<VTLhs> *argLhs, const DenseMatrix<VTRhs> *argRhs,
// result size
int64_t numRowRes,
// context
DCTX(ctx)) {
if (argLhs->getNumCols() != 1)
Expand All @@ -72,11 +74,14 @@ void semiJoinCol(
// Create the output data objects.
if (res == nullptr) {
ValueTypeCode schema[] = {ValueTypeUtils::codeFor<VTLhs>};
res = DataObjectFactory::create<Frame>(numArgLhs, 1, schema, nullptr, false);
const size_t resSize = numRowRes == -1 ? numArgLhs : numRowRes;
res = DataObjectFactory::create<Frame>(resSize, 1, schema, nullptr, false);
}
auto resLhs = res->getColumn<VTLhs>(0);
if (resLhsTid == nullptr)
resLhsTid = DataObjectFactory::create<DenseMatrix<VTTid>>(numArgLhs, 1, false);
if (resLhsTid == nullptr) {
const size_t resLhsTidSize = numRowRes == -1 ? numArgLhs : numRowRes;
resLhsTid = DataObjectFactory::create<DenseMatrix<VTTid>>(resLhsTidSize, 1, false);
}

size_t pos = 0;
for (size_t i = 0; i < numArgLhs; i++) {
Expand Down Expand Up @@ -107,11 +112,13 @@ void semiJoinColIf(
const Frame *lhs, const Frame *rhs,
// input column names
const char *lhsOn, const char *rhsOn,
// result size
int64_t numRowRes,
// context
DCTX(ctx)) {
if (vtcLhs == ValueTypeUtils::codeFor<VTLhs> && vtcRhs == ValueTypeUtils::codeFor<VTRhs>) {
semiJoinCol<VTLhs, VTRhs, VTTid>(res, resLhsTid, lhs->getColumn<VTLhs>(lhsOn), rhs->getColumn<VTRhs>(rhsOn),
ctx);
numRowRes, ctx);
}
}

Expand All @@ -127,6 +134,8 @@ void semiJoin(
const Frame *lhs, const Frame *rhs,
// input column names
const char *lhsOn, const char *rhsOn,
// result size
int64_t numRowRes,
// context
DCTX(ctx)) {
// Find out the value types of the columns to process.
Expand All @@ -136,8 +145,8 @@ void semiJoin(
// Call the semiJoin-kernel on columns for the actual combination of
// value types.
// Repeat this for all type combinations...
semiJoinColIf<int64_t, int64_t, VTLhsTid>(vtcLhsOn, vtcRhsOn, res, lhsTid, lhs, rhs, lhsOn, rhsOn, ctx);
semiJoinColIf<int64_t, int64_t, VTLhsTid>(vtcLhsOn, vtcRhsOn, res, lhsTid, lhs, rhs, lhsOn, rhsOn, ctx);
semiJoinColIf<int64_t, int64_t, VTLhsTid>(vtcLhsOn, vtcRhsOn, res, lhsTid, lhs, rhs, lhsOn, rhsOn, numRowRes, ctx);
semiJoinColIf<int64_t, int64_t, VTLhsTid>(vtcLhsOn, vtcRhsOn, res, lhsTid, lhs, rhs, lhsOn, rhsOn, numRowRes, ctx);

// Set the column labels of the result frame.
std::string labels[] = {lhsOn};
Expand Down
10 changes: 9 additions & 1 deletion src/runtime/local/kernels/kernels.json
Original file line number Diff line number Diff line change
Expand Up @@ -2974,7 +2974,11 @@
{
"type": "const char *",
"name": "rhsOn"
}
},
{
"type": "int64_t",
"name": "numRowRes"
}
]
},
"instantiations": [[]]
Expand Down Expand Up @@ -4267,6 +4271,10 @@
{
"type": "const char *",
"name": "rhsOn"
},
{
"type": "int64_t",
"name": "numRowRes"
}
]
},
Expand Down
2 changes: 2 additions & 0 deletions test/api/cli/operations/OperationsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ MAKE_TEST_CASE("fill", 1)
MAKE_TEST_CASE("gemv", 1)
MAKE_TEST_CASE("idxMax", 1)
MAKE_TEST_CASE("idxMin", 1)
MAKE_TEST_CASE("innerJoin", 1)
MAKE_TEST_CASE("isNan", 1)
MAKE_TEST_CASE("lower", 1)
MAKE_TEST_CASE("mean", 1)
Expand All @@ -59,6 +60,7 @@ MAKE_TEST_CASE("rbind", 1)
MAKE_TEST_CASE("recode", 4)
MAKE_TEST_CASE("replace", 1)
MAKE_TEST_CASE("reverse", 1)
MAKE_TEST_CASE("semiJoin", 1)
MAKE_TEST_CASE("seq", 2)
MAKE_TEST_CASE("solve", 1)
MAKE_TEST_CASE("sqrt", 1)
Expand Down
14 changes: 14 additions & 0 deletions test/api/cli/operations/innerJoin_1.daphne
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Test innerJoin with and without the optional fifth argument (result size).
# Left input: columns a = [1, 2] and b = [3, 4].
f1 = createFrame(
[1, 2], [3, 4],
"a", "b"
);
# Right input: columns c = [3, 4, 5] and d = [6, 7, 8].
f2 = createFrame(
[3, 4, 5], [6, 7, 8],
"c", "d"
);

# Join f1 and f2 on column "b" of f1 and column "c" of f2.
# f3 omits the result size; f4 passes 2 as the result size.
f3 = innerJoin(f1, f2, "b", "c");
f4 = innerJoin(f1, f2, "b", "c", 2);
# Print both results so that the two call variants can be compared.
print(f3);
print(f4);
6 changes: 6 additions & 0 deletions test/api/cli/operations/innerJoin_1.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Frame(2x4, [a:int64_t, b:int64_t, c:int64_t, d:int64_t])
1 3 3 6
2 4 4 7
Frame(2x4, [a:int64_t, b:int64_t, c:int64_t, d:int64_t])
1 3 3 6
2 4 4 7
8 changes: 8 additions & 0 deletions test/api/cli/operations/semiJoin_1.daphne
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Test semiJoin with and without the optional fifth argument (result size).
# Left input: columns a = [1, 2] and b = [3, 4].
f1 = createFrame([ 1, 2 ], [ 3, 4 ], "a", "b");
# Right input: columns c = [3, 4, 5] and d = [6, 7, 8].
f2 = createFrame([ 3, 4, 5 ], [ 6, 7, 8 ], "c", "d");

# semiJoin yields two results, bound here to keys* and tids*.
# The first call omits the result size; the second passes 2 as the result size.
keys1, tids1 = semiJoin(f1, f2, "b", "c");
keys2, tids2 = semiJoin(f1, f2, "b", "c", 2);
# Use the returned tids to select rows of f1 and print them for both variants.
print(f1[tids1, ]);
print(f1[tids2, ]);
6 changes: 6 additions & 0 deletions test/api/cli/operations/semiJoin_1.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Frame(2x2, [a:int64_t, b:int64_t])
1 3
2 4
Frame(2x2, [a:int64_t, b:int64_t])
1 3
2 4
4 changes: 2 additions & 2 deletions test/runtime/local/kernels/InnerJoinTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

#include <cstdint>

TEST_CASE("innerJoin", TAG_KERNELS) {
TEST_CASE("InnerJoin", TAG_KERNELS) {
auto lhsC0 = genGivenVals<DenseMatrix<int64_t>>(4, {1, 2, 3, 4});
auto lhsC1 = genGivenVals<DenseMatrix<double>>(4, {11.0, 22.0, 33.0, 44.00});
std::vector<Structure *> lhsCols = {lhsC0, lhsC1};
Expand All @@ -46,7 +46,7 @@ TEST_CASE("innerJoin", TAG_KERNELS) {
auto rhs = DataObjectFactory::create<Frame>(rhsCols, rhsLabels);

Frame *res = nullptr;
innerJoin(res, lhs, rhs, "a", "c", nullptr);
innerJoin(res, lhs, rhs, "a", "c", -1, nullptr);

// Check the meta data.
CHECK(res->getNumRows() == 2);
Expand Down
2 changes: 1 addition & 1 deletion test/runtime/local/kernels/SemiJoinTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ TEST_CASE("SemiJoin", TAG_KERNELS) {
// res
Frame *res = nullptr;
DenseMatrix<int64_t> *lhsTid = nullptr;
semiJoin(res, lhsTid, lhs, rhs, "a", "c", nullptr);
semiJoin(res, lhsTid, lhs, rhs, "a", "c", -1, nullptr);

CHECK(*res == *expRes);
CHECK(*lhsTid == *expTid);
Expand Down

0 comments on commit e814a8d

Please sign in to comment.