diff --git a/src/data/CassandraBackend.hpp b/src/data/CassandraBackend.hpp index 03bc7b007..fde981bd6 100644 --- a/src/data/CassandraBackend.hpp +++ b/src/data/CassandraBackend.hpp @@ -909,7 +909,7 @@ class BasicCassandraBackend : public BackendInterface { } } - executor_.write(std::move(statements)); + executor_.writeEach(std::move(statements)); } void diff --git a/src/data/cassandra/impl/ExecutionStrategy.hpp b/src/data/cassandra/impl/ExecutionStrategy.hpp index 91ee8336e..b85e6d76e 100644 --- a/src/data/cassandra/impl/ExecutionStrategy.hpp +++ b/src/data/cassandra/impl/ExecutionStrategy.hpp @@ -35,6 +35,7 @@ #include #include +#include #include #include #include @@ -192,10 +193,24 @@ class DefaultExecutionStrategy { template void write(PreparedStatementType const& preparedStatement, Args&&... args) + { + auto statement = preparedStatement.bind(std::forward(args)...); + write(std::move(statement)); + } + + /** + * @brief Non-blocking query execution used for writing data. + * + * Retries forever with retry policy specified by @ref AsyncExecutor + * + * @param statement Statement to execute + * @throw DatabaseTimeout on timeout + */ + void + write(StatementType&& statement) { auto const startTime = std::chrono::steady_clock::now(); - auto statement = preparedStatement.bind(std::forward(args)...); incrementOutstandingRequestCount(); counters_->registerWriteStarted(); @@ -213,6 +228,24 @@ class DefaultExecutionStrategy { ); } + /** + * @brief Non-blocking query execution used for writing data. Constrast with write, this method does not execute + * the statements in a batch. + * + * Retries forever with retry policy specified by @ref AsyncExecutor. + * + * @param statements Vector of statements to execute + * @throw DatabaseTimeout on timeout + */ + void + writeEach(std::vector&& statements) + { + if (statements.empty()) + return; + + std::ranges::for_each(std::move(statements), [this](auto& statement) { this->write(std::move(statement)); }); + } + /** * @brief Non-blocking batched query execution used for writing data. * diff --git a/tests/unit/data/cassandra/ExecutionStrategyTests.cpp b/tests/unit/data/cassandra/ExecutionStrategyTests.cpp index 5aeacba6b..ea5ea7ceb 100644 --- a/tests/unit/data/cassandra/ExecutionStrategyTests.cpp +++ b/tests/unit/data/cassandra/ExecutionStrategyTests.cpp @@ -405,6 +405,47 @@ TEST_F(BackendCassandraExecutionStrategyTest, WriteMultipleAndCallSyncSucceeds) thread.join(); } +TEST_F(BackendCassandraExecutionStrategyTest, WriteEachAndCallSyncSucceeds) +{ + auto strat = makeStrategy(); + auto const totalRequests = 1024u; + auto const numStatements = 16u; + auto callCount = std::atomic_uint{0u}; + + auto work = std::optional{ctx}; + auto thread = std::thread{[this]() { ctx.run(); }}; + + ON_CALL(handle, asyncExecute(A(), A&&>())) + .WillByDefault([this, &callCount](auto const&, auto&& cb) { + // run on thread to emulate concurrency model of real asyncExecute + boost::asio::post(ctx, [&callCount, cb = std::forward(cb)] { + ++callCount; + cb({}); // pretend we got data + }); + return FakeFutureWithCallback{}; + }); + EXPECT_CALL( + handle, + asyncExecute( + A(), + A&&>() + ) + ) + .Times(totalRequests * numStatements); // numStatements per write call + EXPECT_CALL(*counters, registerWriteStarted()).Times(totalRequests * numStatements); + EXPECT_CALL(*counters, registerWriteFinished(testing::_)).Times(totalRequests * numStatements); + + auto makeStatements = [] { return std::vector(16); }; + for (auto i = 0u; i < totalRequests; ++i) + strat.writeEach(makeStatements()); + + strat.sync(); // make sure all above writes are finished + EXPECT_EQ(callCount, totalRequests * numStatements); // all requests should finish + + work.reset(); + thread.join(); +} + TEST_F(BackendCassandraExecutionStrategyTest, StatsCallsCountersReport) { auto strat = makeStrategy();